More flexible module loading and better lock management for models in Tensorflow plugin

This commit is contained in:
Fabio Manganiello 2020-10-01 17:41:12 +02:00
parent 9b23ab7015
commit 287b6303ae

View file

@ -1,8 +1,10 @@
import json import json
import os import os
import pathlib
import random import random
import shutil import shutil
import threading import threading
from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
@ -57,36 +59,76 @@ class TensorflowPlugin(Plugin):
def __init__(self, workdir: Optional[str] = None, **kwargs): def __init__(self, workdir: Optional[str] = None, **kwargs):
""" """
:param workdir: Working directory for TensorFlow, where models will be stored :param workdir: Working directory for TensorFlow, where models will be stored and looked up by default
(default: PLATYPUSH_WORKDIR/tensorflow). (default: PLATYPUSH_WORKDIR/tensorflow).
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self.models: Dict[str, Model] = {} self.models: Dict[str, Model] = {}
self._models_lock = threading.RLock()
self._model_locks: Dict[str, threading.RLock()] = {} self._model_locks: Dict[str, threading.RLock()] = {}
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \ self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
os.path.join(Config.get('workdir'), 'tensorflow') os.path.join(Config.get('workdir'), 'tensorflow')
self._models_dir = os.path.join(self._work_dir, 'models') self._models_dir = os.path.join(self._work_dir, 'models')
os.makedirs(self._models_dir, mode=0o755, exist_ok=True) pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
def _load_model(self, name: str, reload: bool = False) -> Model: @contextmanager
if name in self.models and not reload: def _lock_model(self, model_name: str):
return self.models[name] with self._models_lock:
if model_name not in self._model_locks:
self._model_locks[model_name] = threading.RLock()
model_dir = os.path.join(self._models_dir, name) try:
assert os.path.isdir(model_dir), 'The model {} does not exist'.format(name) success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
assert success, 'Unable to acquire the model lock'
yield
finally:
# noinspection PyBroadException
try:
self._model_locks[model_name].release()
except:
pass
def _load_model(self, model_name: str, reload: bool = False) -> Model:
if model_name in self.models and not reload:
return self.models[model_name]
model = None
model_dir = None
if os.path.isdir(os.path.join(self._models_dir, model_name)):
model_dir = os.path.join(self._models_dir, model_name)
model = load_model(model_dir) model = load_model(model_dir)
else:
model_name = os.path.abspath(os.path.expanduser(model_name))
if model_name in self.models and not reload:
return self.models[model_name]
if os.path.isfile(model_name):
model_dir = str(pathlib.Path(model_name).parent)
model = load_model(model_name)
elif os.path.isdir(model_name):
model_dir = model_name
model = load_model(model_dir)
assert model, 'Could not find model: {}'.format(model_name)
model.input_labels = [] model.input_labels = []
model.output_labels = [] model.output_labels = []
labels_file = os.path.join(model_dir, 'labels.json') labels_file = os.path.join(model_dir, 'labels.json')
if os.path.isfile(labels_file): if os.path.isfile(labels_file):
with open(labels_file, 'r') as f: with open(labels_file, 'r') as f:
labels = json.load(f) labels = json.load(f)
if isinstance(labels, dict):
if 'input' in labels: if 'input' in labels:
model.input_labels = labels['input'] model.input_labels = labels['input']
if 'output' in labels: if 'output' in labels:
model.output_labels = labels['output'] model.output_labels = labels['output']
elif hasattr(labels, '__iter__'):
model.output_labels = labels
with self._lock_model(model_name):
self.models[model_name] = model
return model return model
@ -142,7 +184,8 @@ class TensorflowPlugin(Plugin):
""" """
(Re)-load a model from the file system. (Re)-load a model from the file system.
:param model: Name of the model. Must be a folder name stored under ``<workdir>/models``. :param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already :param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
loaded, otherwise the model currently in memory will be kept (default: ``False``). loaded, otherwise the model currently in memory will be kept (default: ``False``).
:return: The model configuration. :return: The model configuration.
@ -157,6 +200,7 @@ class TensorflowPlugin(Plugin):
:param model: Name of the model. :param model: Name of the model.
""" """
with self._lock_model(model):
assert model in self.models, 'The model {} is not loaded'.format(model) assert model in self.models, 'The model {} is not loaded'.format(model)
del self.models[model] del self.models[model]
@ -168,6 +212,7 @@ class TensorflowPlugin(Plugin):
:param model: Name of the model. :param model: Name of the model.
""" """
with self._lock_model(model):
if model in self.models: if model in self.models:
del self.models[model] del self.models[model]
@ -388,6 +433,8 @@ class TensorflowPlugin(Plugin):
model.input_labels = input_names or [] model.input_labels = input_names or []
model.output_labels = output_names or [] model.output_labels = output_names or []
with self._lock_model(name):
self.models[name] = model self.models[name] = model
return model.get_config() return model.get_config()
@ -520,6 +567,7 @@ class TensorflowPlugin(Plugin):
**kwargs **kwargs
) )
with self._lock_model(name):
self.models[name] = model self.models[name] = model
return model.get_config() return model.get_config()
@ -551,10 +599,7 @@ class TensorflowPlugin(Plugin):
model.name, size) model.name, size)
colors = input_shape[3:] colors = input_shape[3:]
assert colors, ('The model {} requires a tensor with at least 3 inputs in order to process images: ' + if len(colors) == 0 or colors[0] == 1:
'[WIDTH, HEIGHT, COLORS]').format(model.name)
if colors[0] == 1:
color_mode = 'grayscale' color_mode = 'grayscale'
elif colors[0] == 3: elif colors[0] == 3:
color_mode = 'rgb' color_mode = 'rgb'
@ -664,7 +709,8 @@ class TensorflowPlugin(Plugin):
""" """
Trains a model on a dataset for a fixed number of epochs. Trains a model on a dataset for a fixed number of epochs.
:param model: Name of the existing model to be trained. :param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be: :param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs. - A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -818,7 +864,8 @@ class TensorflowPlugin(Plugin):
""" """
Returns the loss value and metrics values for the model in test model. Returns the loss value and metrics values for the model in test model.
:param model: Name of the existing model to be trained. :param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be: :param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs. - A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -922,7 +969,8 @@ class TensorflowPlugin(Plugin):
""" """
Generates output predictions for the input samples. Generates output predictions for the input samples.
:param model: Name of the existing model to be trained. :param model: Name of the model. It can be a folder name stored under ``<workdir>/models``, or an absolute path
to a model directory or file (Tensorflow directories, Protobuf models and HDF5 files are supported).
:param inputs: Input data. It can be: :param inputs: Input data. It can be:
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs. - A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
@ -1035,22 +1083,38 @@ class TensorflowPlugin(Plugin):
:param overwrite: Overwrite the model files if they already exist. :param overwrite: Overwrite the model files if they already exist.
:param opts: Extra options to be passed to ``Model.save()``. :param opts: Extra options to be passed to ``Model.save()``.
""" """
assert model in self.models, 'No such model in memory: {}'.format(model) model_name = model
model_dir = os.path.join(self._models_dir, model) model_dir = None
model = self.models[model]
os.makedirs(model_dir, exist_ok=True) if os.path.isdir(os.path.join(self._work_dir, model_name)):
model_dir = os.path.join(self._work_dir, model_name)
else:
model_name = os.path.abspath(os.path.expanduser(model_name))
if os.path.isfile(model_name):
model_dir = str(pathlib.Path(model_name).parent)
elif os.path.isdir(model_name):
model_dir = model_name
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
with self._lock_model(model_name):
model = self.models[model_name]
labels = {} labels = {}
labels_file = os.path.join(model_dir, 'labels.json') labels_file = os.path.join(model_dir, 'labels.json')
if model.input_labels: if hasattr(model, 'input_labels') and model.input_labels:
labels['input'] = model.input_labels labels['input'] = model.input_labels
if model.output_labels: if hasattr(model, 'output_labels') and model.output_labels:
if hasattr(labels, 'input'):
labels['output'] = model.output_labels labels['output'] = model.output_labels
else:
labels = model.output_labels
if labels: if labels:
with open(labels_file, 'w') as f: with open(labels_file, 'w') as f:
json.dump(labels, f) json.dump(labels, f)
model.save(model_dir, overwrite=overwrite, options=opts) model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et: