forked from platypush/platypush
More flexible module loading and better lock management for models in Tensorflow plugin
This commit is contained in:
parent
9b23ab7015
commit
287b6303ae
1 changed files with 105 additions and 41 deletions
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import shutil
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
|
||||
|
||||
|
@ -57,36 +59,76 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
def __init__(self, workdir: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
:param workdir: Working directory for TensorFlow, where models will be stored
|
||||
:param workdir: Working directory for TensorFlow, where models will be stored and looked up by default
|
||||
(default: PLATYPUSH_WORKDIR/tensorflow).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.models: Dict[str, Model] = {}
|
||||
self._models_lock = threading.RLock()
|
||||
self._model_locks: Dict[str, threading.RLock()] = {}
|
||||
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
|
||||
os.path.join(Config.get('workdir'), 'tensorflow')
|
||||
|
||||
self._models_dir = os.path.join(self._work_dir, 'models')
|
||||
os.makedirs(self._models_dir, mode=0o755, exist_ok=True)
|
||||
pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
|
||||
|
||||
def _load_model(self, name: str, reload: bool = False) -> Model:
|
||||
if name in self.models and not reload:
|
||||
return self.models[name]
|
||||
@contextmanager
|
||||
def _lock_model(self, model_name: str):
|
||||
with self._models_lock:
|
||||
if model_name not in self._model_locks:
|
||||
self._model_locks[model_name] = threading.RLock()
|
||||
|
||||
model_dir = os.path.join(self._models_dir, name)
|
||||
assert os.path.isdir(model_dir), 'The model {} does not exist'.format(name)
|
||||
model = load_model(model_dir)
|
||||
try:
|
||||
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
|
||||
assert success, 'Unable to acquire the model lock'
|
||||
yield
|
||||
finally:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._model_locks[model_name].release()
|
||||
except:
|
||||
pass
|
||||
|
||||
def _load_model(self, model_name: str, reload: bool = False) -> Model:
|
||||
if model_name in self.models and not reload:
|
||||
return self.models[model_name]
|
||||
|
||||
model = None
|
||||
model_dir = None
|
||||
|
||||
if os.path.isdir(os.path.join(self._models_dir, model_name)):
|
||||
model_dir = os.path.join(self._models_dir, model_name)
|
||||
model = load_model(model_dir)
|
||||
else:
|
||||
model_name = os.path.abspath(os.path.expanduser(model_name))
|
||||
if model_name in self.models and not reload:
|
||||
return self.models[model_name]
|
||||
|
||||
if os.path.isfile(model_name):
|
||||
model_dir = str(pathlib.Path(model_name).parent)
|
||||
model = load_model(model_name)
|
||||
elif os.path.isdir(model_name):
|
||||
model_dir = model_name
|
||||
model = load_model(model_dir)
|
||||
|
||||
assert model, 'Could not find model: {}'.format(model_name)
|
||||
model.input_labels = []
|
||||
model.output_labels = []
|
||||
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
|
||||
if os.path.isfile(labels_file):
|
||||
with open(labels_file, 'r') as f:
|
||||
labels = json.load(f)
|
||||
if 'input' in labels:
|
||||
model.input_labels = labels['input']
|
||||
if 'output' in labels:
|
||||
model.output_labels = labels['output']
|
||||
if isinstance(labels, dict):
|
||||
if 'input' in labels:
|
||||
model.input_labels = labels['input']
|
||||
if 'output' in labels:
|
||||
model.output_labels = labels['output']
|
||||
elif hasattr(labels, '__iter__'):
|
||||
model.output_labels = labels
|
||||
|
||||
with self._lock_model(model_name):
|
||||
self.models[model_name] = model
|
||||
|
||||
return model
|
||||
|
||||
|
@ -142,7 +184,8 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
(Re)-load a model from the file system.
|
||||
|
||||
:param model: Name of the model. Must be a folder name stored under ``<workdir>/models``.
|
||||
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
|
||||
loaded, otherwise the model currently in memory will be kept (default: ``False``).
|
||||
:return: The model configuration.
|
||||
|
@ -157,8 +200,9 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
assert model in self.models, 'The model {} is not loaded'.format(model)
|
||||
del self.models[model]
|
||||
with self._lock_model(model):
|
||||
assert model in self.models, 'The model {} is not loaded'.format(model)
|
||||
del self.models[model]
|
||||
|
||||
@action
|
||||
def remove(self, model: str) -> None:
|
||||
|
@ -168,8 +212,9 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
if model in self.models:
|
||||
del self.models[model]
|
||||
with self._lock_model(model):
|
||||
if model in self.models:
|
||||
del self.models[model]
|
||||
|
||||
model_dir = os.path.join(self._models_dir, model)
|
||||
if os.path.isdir(model_dir):
|
||||
|
@ -388,7 +433,9 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
model.input_labels = input_names or []
|
||||
model.output_labels = output_names or []
|
||||
self.models[name] = model
|
||||
|
||||
with self._lock_model(name):
|
||||
self.models[name] = model
|
||||
return model.get_config()
|
||||
|
||||
@action
|
||||
|
@ -520,7 +567,8 @@ class TensorflowPlugin(Plugin):
|
|||
**kwargs
|
||||
)
|
||||
|
||||
self.models[name] = model
|
||||
with self._lock_model(name):
|
||||
self.models[name] = model
|
||||
return model.get_config()
|
||||
|
||||
@staticmethod
|
||||
|
@ -551,10 +599,7 @@ class TensorflowPlugin(Plugin):
|
|||
model.name, size)
|
||||
|
||||
colors = input_shape[3:]
|
||||
assert colors, ('The model {} requires a tensor with at least 3 inputs in order to process images: ' +
|
||||
'[WIDTH, HEIGHT, COLORS]').format(model.name)
|
||||
|
||||
if colors[0] == 1:
|
||||
if len(colors) == 0 or colors[0] == 1:
|
||||
color_mode = 'grayscale'
|
||||
elif colors[0] == 3:
|
||||
color_mode = 'rgb'
|
||||
|
@ -664,7 +709,8 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
Trains a model on a dataset for a fixed number of epochs.
|
||||
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -818,7 +864,8 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
Returns the loss value and metrics values for the model in test model.
|
||||
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -922,7 +969,8 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
|
||||
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -1035,22 +1083,38 @@ class TensorflowPlugin(Plugin):
|
|||
:param overwrite: Overwrite the model files if they already exist.
|
||||
:param opts: Extra options to be passed to ``Model.save()``.
|
||||
"""
|
||||
assert model in self.models, 'No such model in memory: {}'.format(model)
|
||||
model_dir = os.path.join(self._models_dir, model)
|
||||
model = self.models[model]
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
labels = {}
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
model_name = model
|
||||
model_dir = None
|
||||
|
||||
if model.input_labels:
|
||||
labels['input'] = model.input_labels
|
||||
if model.output_labels:
|
||||
labels['output'] = model.output_labels
|
||||
if labels:
|
||||
with open(labels_file, 'w') as f:
|
||||
json.dump(labels, f)
|
||||
if os.path.isdir(os.path.join(self._work_dir, model_name)):
|
||||
model_dir = os.path.join(self._work_dir, model_name)
|
||||
else:
|
||||
model_name = os.path.abspath(os.path.expanduser(model_name))
|
||||
if os.path.isfile(model_name):
|
||||
model_dir = str(pathlib.Path(model_name).parent)
|
||||
elif os.path.isdir(model_name):
|
||||
model_dir = model_name
|
||||
|
||||
model.save(model_dir, overwrite=overwrite, options=opts)
|
||||
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
|
||||
|
||||
with self._lock_model(model_name):
|
||||
model = self.models[model_name]
|
||||
labels = {}
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
|
||||
if hasattr(model, 'input_labels') and model.input_labels:
|
||||
labels['input'] = model.input_labels
|
||||
if hasattr(model, 'output_labels') and model.output_labels:
|
||||
if hasattr(labels, 'input'):
|
||||
labels['output'] = model.output_labels
|
||||
else:
|
||||
labels = model.output_labels
|
||||
|
||||
if labels:
|
||||
with open(labels_file, 'w') as f:
|
||||
json.dump(labels, f)
|
||||
|
||||
model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Add table
Reference in a new issue