More flexible module loading and better lock management for models in Tensorflow plugin

This commit is contained in:
Fabio Manganiello 2020-10-01 17:41:12 +02:00
parent 9b23ab7015
commit 287b6303ae
1 changed files with 105 additions and 41 deletions

View File

@ -1,8 +1,10 @@
import json
import os
import pathlib
import random
import shutil
import threading
from contextlib import contextmanager
from datetime import datetime
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
@ -57,36 +59,76 @@ class TensorflowPlugin(Plugin):
def __init__(self, workdir: Optional[str] = None, **kwargs):
"""
:param workdir: Working directory for TensorFlow, where models will be stored
:param workdir: Working directory for TensorFlow, where models will be stored and looked up by default
(default: PLATYPUSH_WORKDIR/tensorflow).
"""
super().__init__(**kwargs)
self.models: Dict[str, Model] = {}
self._models_lock = threading.RLock()
self._model_locks: Dict[str, threading.RLock()] = {}
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
os.path.join(Config.get('workdir'), 'tensorflow')
self._models_dir = os.path.join(self._work_dir, 'models')
os.makedirs(self._models_dir, mode=0o755, exist_ok=True)
pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
def _load_model(self, name: str, reload: bool = False) -> Model:
if name in self.models and not reload:
return self.models[name]
@contextmanager
def _lock_model(self, model_name: str):
with self._models_lock:
if model_name not in self._model_locks:
self._model_locks[model_name] = threading.RLock()
model_dir = os.path.join(self._models_dir, name)
assert os.path.isdir(model_dir), 'The model {} does not exist'.format(name)
model = load_model(model_dir)
try:
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
assert success, 'Unable to acquire the model lock'
yield
finally:
# noinspection PyBroadException
try:
self._model_locks[model_name].release()
except:
pass
def _load_model(self, model_name: str, reload: bool = False) -> Model:
if model_name in self.models and not reload:
return self.models[model_name]
model = None
model_dir = None
if os.path.isdir(os.path.join(self._models_dir, model_name)):
model_dir = os.path.join(self._models_dir, model_name)
model = load_model(model_dir)
else:
model_name = os.path.abspath(os.path.expanduser(model_name))
if model_name in self.models and not reload:
return self.models[model_name]
if os.path.isfile(model_name):
model_dir = str(pathlib.Path(model_name).parent)
model = load_model(model_name)
elif os.path.isdir(model_name):
model_dir = model_name
model = load_model(model_dir)
assert model, 'Could not find model: {}'.format(model_name)
model.input_labels = []
model.output_labels = []
labels_file = os.path.join(model_dir, 'labels.json')
if os.path.isfile(labels_file):
with open(labels_file, 'r') as f:
labels = json.load(f)
if 'input' in labels:
model.input_labels = labels['input']
if 'output' in labels:
model.output_labels = labels['output']
if isinstance(labels, dict):
if 'input' in labels:
model.input_labels = labels['input']
if 'output' in labels:
model.output_labels = labels['output']
elif hasattr(labels, '__iter__'):
model.output_labels = labels
with self._lock_model(model_name):
self.models[model_name] = model
return model
@ -142,7 +184,8 @@ class TensorflowPlugin(Plugin):
"""
(Re)-load a model from the file system.
:param model: Name of the model. Must be a folder name stored under ``<workdir>/models``.
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
loaded, otherwise the model currently in memory will be kept (default: ``False``).
:return: The model configuration.
@ -157,8 +200,9 @@ class TensorflowPlugin(Plugin):
:param model: Name of the model.
"""
assert model in self.models, 'The model {} is not loaded'.format(model)
del self.models[model]
with self._lock_model(model):
assert model in self.models, 'The model {} is not loaded'.format(model)
del self.models[model]
@action
def remove(self, model: str) -> None:
@ -168,8 +212,9 @@ class TensorflowPlugin(Plugin):
:param model: Name of the model.
"""
if model in self.models:
del self.models[model]
with self._lock_model(model):
if model in self.models:
del self.models[model]
model_dir = os.path.join(self._models_dir, model)
if os.path.isdir(model_dir):
@ -388,7 +433,9 @@ class TensorflowPlugin(Plugin):
model.input_labels = input_names or []
model.output_labels = output_names or []
self.models[name] = model
with self._lock_model(name):
self.models[name] = model
return model.get_config()
@action
@ -520,7 +567,8 @@ class TensorflowPlugin(Plugin):
**kwargs
)
self.models[name] = model
with self._lock_model(name):
self.models[name] = model
return model.get_config()
@staticmethod
@ -551,10 +599,7 @@ class TensorflowPlugin(Plugin):
model.name, size)
colors = input_shape[3:]
assert colors, ('The model {} requires a tensor with at least 3 inputs in order to process images: ' +
'[WIDTH, HEIGHT, COLORS]').format(model.name)
if colors[0] == 1:
if len(colors) == 0 or colors[0] == 1:
color_mode = 'grayscale'
elif colors[0] == 3:
color_mode = 'rgb'
@ -664,7 +709,8 @@ class TensorflowPlugin(Plugin):
"""
Trains a model on a dataset for a fixed number of epochs.
:param model: Name of the existing model to be trained.
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -818,7 +864,8 @@ class TensorflowPlugin(Plugin):
"""
Returns the loss value and metrics values for the model in test model.
:param model: Name of the existing model to be trained.
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -922,7 +969,8 @@ class TensorflowPlugin(Plugin):
"""
Generates output predictions for the input samples.
:param model: Name of the existing model to be trained.
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -1035,22 +1083,38 @@ class TensorflowPlugin(Plugin):
:param overwrite: Overwrite the model files if they already exist.
:param opts: Extra options to be passed to ``Model.save()``.
"""
assert model in self.models, 'No such model in memory: {}'.format(model)
model_dir = os.path.join(self._models_dir, model)
model = self.models[model]
os.makedirs(model_dir, exist_ok=True)
labels = {}
labels_file = os.path.join(model_dir, 'labels.json')
model_name = model
model_dir = None
if model.input_labels:
labels['input'] = model.input_labels
if model.output_labels:
labels['output'] = model.output_labels
if labels:
with open(labels_file, 'w') as f:
json.dump(labels, f)
if os.path.isdir(os.path.join(self._work_dir, model_name)):
model_dir = os.path.join(self._work_dir, model_name)
else:
model_name = os.path.abspath(os.path.expanduser(model_name))
if os.path.isfile(model_name):
model_dir = str(pathlib.Path(model_name).parent)
elif os.path.isdir(model_name):
model_dir = model_name
model.save(model_dir, overwrite=overwrite, options=opts)
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
with self._lock_model(model_name):
model = self.models[model_name]
labels = {}
labels_file = os.path.join(model_dir, 'labels.json')
if hasattr(model, 'input_labels') and model.input_labels:
labels['input'] = model.input_labels
if hasattr(model, 'output_labels') and model.output_labels:
if hasattr(labels, 'input'):
labels['output'] = model.output_labels
else:
labels = model.output_labels
if labels:
with open(labels_file, 'w') as f:
json.dump(labels, f)
model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
# vim:sw=4:ts=4:et: