forked from platypush/platypush
More flexible module loading and better lock management for models in Tensorflow plugin
This commit is contained in:
parent
9b23ab7015
commit
287b6303ae
1 changed files with 105 additions and 41 deletions
|
@ -1,8 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
|
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
|
||||||
|
|
||||||
|
@ -57,36 +59,76 @@ class TensorflowPlugin(Plugin):
|
||||||
|
|
||||||
def __init__(self, workdir: Optional[str] = None, **kwargs):
|
def __init__(self, workdir: Optional[str] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param workdir: Working directory for TensorFlow, where models will be stored
|
:param workdir: Working directory for TensorFlow, where models will be stored and looked up by default
|
||||||
(default: PLATYPUSH_WORKDIR/tensorflow).
|
(default: PLATYPUSH_WORKDIR/tensorflow).
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.models: Dict[str, Model] = {}
|
self.models: Dict[str, Model] = {}
|
||||||
|
self._models_lock = threading.RLock()
|
||||||
self._model_locks: Dict[str, threading.RLock()] = {}
|
self._model_locks: Dict[str, threading.RLock()] = {}
|
||||||
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
|
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
|
||||||
os.path.join(Config.get('workdir'), 'tensorflow')
|
os.path.join(Config.get('workdir'), 'tensorflow')
|
||||||
|
|
||||||
self._models_dir = os.path.join(self._work_dir, 'models')
|
self._models_dir = os.path.join(self._work_dir, 'models')
|
||||||
os.makedirs(self._models_dir, mode=0o755, exist_ok=True)
|
pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
|
||||||
|
|
||||||
def _load_model(self, name: str, reload: bool = False) -> Model:
|
@contextmanager
|
||||||
if name in self.models and not reload:
|
def _lock_model(self, model_name: str):
|
||||||
return self.models[name]
|
with self._models_lock:
|
||||||
|
if model_name not in self._model_locks:
|
||||||
|
self._model_locks[model_name] = threading.RLock()
|
||||||
|
|
||||||
model_dir = os.path.join(self._models_dir, name)
|
try:
|
||||||
assert os.path.isdir(model_dir), 'The model {} does not exist'.format(name)
|
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
|
||||||
|
assert success, 'Unable to acquire the model lock'
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
self._model_locks[model_name].release()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _load_model(self, model_name: str, reload: bool = False) -> Model:
|
||||||
|
if model_name in self.models and not reload:
|
||||||
|
return self.models[model_name]
|
||||||
|
|
||||||
|
model = None
|
||||||
|
model_dir = None
|
||||||
|
|
||||||
|
if os.path.isdir(os.path.join(self._models_dir, model_name)):
|
||||||
|
model_dir = os.path.join(self._models_dir, model_name)
|
||||||
model = load_model(model_dir)
|
model = load_model(model_dir)
|
||||||
|
else:
|
||||||
|
model_name = os.path.abspath(os.path.expanduser(model_name))
|
||||||
|
if model_name in self.models and not reload:
|
||||||
|
return self.models[model_name]
|
||||||
|
|
||||||
|
if os.path.isfile(model_name):
|
||||||
|
model_dir = str(pathlib.Path(model_name).parent)
|
||||||
|
model = load_model(model_name)
|
||||||
|
elif os.path.isdir(model_name):
|
||||||
|
model_dir = model_name
|
||||||
|
model = load_model(model_dir)
|
||||||
|
|
||||||
|
assert model, 'Could not find model: {}'.format(model_name)
|
||||||
model.input_labels = []
|
model.input_labels = []
|
||||||
model.output_labels = []
|
model.output_labels = []
|
||||||
|
|
||||||
labels_file = os.path.join(model_dir, 'labels.json')
|
labels_file = os.path.join(model_dir, 'labels.json')
|
||||||
|
|
||||||
if os.path.isfile(labels_file):
|
if os.path.isfile(labels_file):
|
||||||
with open(labels_file, 'r') as f:
|
with open(labels_file, 'r') as f:
|
||||||
labels = json.load(f)
|
labels = json.load(f)
|
||||||
|
if isinstance(labels, dict):
|
||||||
if 'input' in labels:
|
if 'input' in labels:
|
||||||
model.input_labels = labels['input']
|
model.input_labels = labels['input']
|
||||||
if 'output' in labels:
|
if 'output' in labels:
|
||||||
model.output_labels = labels['output']
|
model.output_labels = labels['output']
|
||||||
|
elif hasattr(labels, '__iter__'):
|
||||||
|
model.output_labels = labels
|
||||||
|
|
||||||
|
with self._lock_model(model_name):
|
||||||
|
self.models[model_name] = model
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -142,7 +184,8 @@ class TensorflowPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
(Re)-load a model from the file system.
|
(Re)-load a model from the file system.
|
||||||
|
|
||||||
:param model: Name of the model. Must be a folder name stored under ``<workdir>/models``.
|
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||||
|
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||||
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
|
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
|
||||||
loaded, otherwise the model currently in memory will be kept (default: ``False``).
|
loaded, otherwise the model currently in memory will be kept (default: ``False``).
|
||||||
:return: The model configuration.
|
:return: The model configuration.
|
||||||
|
@ -157,6 +200,7 @@ class TensorflowPlugin(Plugin):
|
||||||
|
|
||||||
:param model: Name of the model.
|
:param model: Name of the model.
|
||||||
"""
|
"""
|
||||||
|
with self._lock_model(model):
|
||||||
assert model in self.models, 'The model {} is not loaded'.format(model)
|
assert model in self.models, 'The model {} is not loaded'.format(model)
|
||||||
del self.models[model]
|
del self.models[model]
|
||||||
|
|
||||||
|
@ -168,6 +212,7 @@ class TensorflowPlugin(Plugin):
|
||||||
|
|
||||||
:param model: Name of the model.
|
:param model: Name of the model.
|
||||||
"""
|
"""
|
||||||
|
with self._lock_model(model):
|
||||||
if model in self.models:
|
if model in self.models:
|
||||||
del self.models[model]
|
del self.models[model]
|
||||||
|
|
||||||
|
@ -388,6 +433,8 @@ class TensorflowPlugin(Plugin):
|
||||||
|
|
||||||
model.input_labels = input_names or []
|
model.input_labels = input_names or []
|
||||||
model.output_labels = output_names or []
|
model.output_labels = output_names or []
|
||||||
|
|
||||||
|
with self._lock_model(name):
|
||||||
self.models[name] = model
|
self.models[name] = model
|
||||||
return model.get_config()
|
return model.get_config()
|
||||||
|
|
||||||
|
@ -520,6 +567,7 @@ class TensorflowPlugin(Plugin):
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with self._lock_model(name):
|
||||||
self.models[name] = model
|
self.models[name] = model
|
||||||
return model.get_config()
|
return model.get_config()
|
||||||
|
|
||||||
|
@ -551,10 +599,7 @@ class TensorflowPlugin(Plugin):
|
||||||
model.name, size)
|
model.name, size)
|
||||||
|
|
||||||
colors = input_shape[3:]
|
colors = input_shape[3:]
|
||||||
assert colors, ('The model {} requires a tensor with at least 3 inputs in order to process images: ' +
|
if len(colors) == 0 or colors[0] == 1:
|
||||||
'[WIDTH, HEIGHT, COLORS]').format(model.name)
|
|
||||||
|
|
||||||
if colors[0] == 1:
|
|
||||||
color_mode = 'grayscale'
|
color_mode = 'grayscale'
|
||||||
elif colors[0] == 3:
|
elif colors[0] == 3:
|
||||||
color_mode = 'rgb'
|
color_mode = 'rgb'
|
||||||
|
@ -664,7 +709,8 @@ class TensorflowPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
Trains a model on a dataset for a fixed number of epochs.
|
Trains a model on a dataset for a fixed number of epochs.
|
||||||
|
|
||||||
:param model: Name of the existing model to be trained.
|
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||||
|
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||||
:param inputs: Input data. It can be:
|
:param inputs: Input data. It can be:
|
||||||
|
|
||||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||||
|
@ -818,7 +864,8 @@ class TensorflowPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
Returns the loss value and metrics values for the model in test model.
|
Returns the loss value and metrics values for the model in test model.
|
||||||
|
|
||||||
:param model: Name of the existing model to be trained.
|
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||||
|
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||||
:param inputs: Input data. It can be:
|
:param inputs: Input data. It can be:
|
||||||
|
|
||||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||||
|
@ -922,7 +969,8 @@ class TensorflowPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
Generates output predictions for the input samples.
|
Generates output predictions for the input samples.
|
||||||
|
|
||||||
:param model: Name of the existing model to be trained.
|
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||||
|
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||||
:param inputs: Input data. It can be:
|
:param inputs: Input data. It can be:
|
||||||
|
|
||||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||||
|
@ -1035,22 +1083,38 @@ class TensorflowPlugin(Plugin):
|
||||||
:param overwrite: Overwrite the model files if they already exist.
|
:param overwrite: Overwrite the model files if they already exist.
|
||||||
:param opts: Extra options to be passed to ``Model.save()``.
|
:param opts: Extra options to be passed to ``Model.save()``.
|
||||||
"""
|
"""
|
||||||
assert model in self.models, 'No such model in memory: {}'.format(model)
|
model_name = model
|
||||||
model_dir = os.path.join(self._models_dir, model)
|
model_dir = None
|
||||||
model = self.models[model]
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
if os.path.isdir(os.path.join(self._work_dir, model_name)):
|
||||||
|
model_dir = os.path.join(self._work_dir, model_name)
|
||||||
|
else:
|
||||||
|
model_name = os.path.abspath(os.path.expanduser(model_name))
|
||||||
|
if os.path.isfile(model_name):
|
||||||
|
model_dir = str(pathlib.Path(model_name).parent)
|
||||||
|
elif os.path.isdir(model_name):
|
||||||
|
model_dir = model_name
|
||||||
|
|
||||||
|
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
|
||||||
|
|
||||||
|
with self._lock_model(model_name):
|
||||||
|
model = self.models[model_name]
|
||||||
labels = {}
|
labels = {}
|
||||||
labels_file = os.path.join(model_dir, 'labels.json')
|
labels_file = os.path.join(model_dir, 'labels.json')
|
||||||
|
|
||||||
if model.input_labels:
|
if hasattr(model, 'input_labels') and model.input_labels:
|
||||||
labels['input'] = model.input_labels
|
labels['input'] = model.input_labels
|
||||||
if model.output_labels:
|
if hasattr(model, 'output_labels') and model.output_labels:
|
||||||
|
if hasattr(labels, 'input'):
|
||||||
labels['output'] = model.output_labels
|
labels['output'] = model.output_labels
|
||||||
|
else:
|
||||||
|
labels = model.output_labels
|
||||||
|
|
||||||
if labels:
|
if labels:
|
||||||
with open(labels_file, 'w') as f:
|
with open(labels_file, 'w') as f:
|
||||||
json.dump(labels, f)
|
json.dump(labels, f)
|
||||||
|
|
||||||
model.save(model_dir, overwrite=overwrite, options=opts)
|
model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
Loading…
Reference in a new issue