Performance improvements when loading the Tensorflow plugin.

The Tensorflow module may take a few seconds to load the first time and
slow down the first scan of the plugins.

All the Tensorflow imports should therefore be placed close to where
they are used instead of being defined at the top of the module.
This commit is contained in:
Fabio Manganiello 2023-05-11 19:41:58 +02:00
parent f49ad4c349
commit cfedcd701e
Signed by: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 293 additions and 156 deletions

View file

@ -1,7 +1,6 @@
from typing import Dict, List, Union, Optional
import numpy as np
from tensorflow.keras.models import Model
from platypush.message.response import Response
@ -10,13 +9,18 @@ class TensorflowResponse(Response):
"""
Generic Tensorflow response.
"""
def __init__(self, *args, model: Model, model_name: Optional[str] = None, **kwargs):
def __init__(self, *args, model, model_name: Optional[str] = None, **kwargs):
"""
:param model: Name of the model.
"""
super().__init__(*args, output={
super().__init__(
*args,
output={
'model': model_name or model.name,
}, **kwargs)
},
**kwargs
)
self.model = model
@ -25,7 +29,14 @@ class TensorflowTrainResponse(TensorflowResponse):
"""
Tensorflow model fit/train response.
"""
def __init__(self, *args, epochs: List[int], history: Dict[str, List[Union[int, float]]], **kwargs):
def __init__(
self,
*args,
epochs: List[int],
history: Dict[str, List[Union[int, float]]],
**kwargs
):
"""
:param epochs: List of epoch indexes the model has been trained on.
:param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
@ -40,7 +51,14 @@ class TensorflowPredictResponse(TensorflowResponse):
"""
Tensorflow model prediction response.
"""
def __init__(self, *args, prediction: np.ndarray, output_labels: Optional[List[str]] = None, **kwargs):
def __init__(
self,
*args,
prediction: np.ndarray,
output_labels: Optional[List[str]] = None,
**kwargs
):
super().__init__(*args, **kwargs)
if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:

View file

@ -9,17 +9,21 @@ from datetime import datetime
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras import utils
from platypush.config import Config
from platypush.context import get_bus
from platypush.message.event.tensorflow import TensorflowEpochStartedEvent, TensorflowEpochEndedEvent, \
TensorflowBatchStartedEvent, TensorflowBatchEndedEvent, TensorflowTrainStartedEvent, TensorflowTrainEndedEvent
from platypush.message.response.tensorflow import TensorflowTrainResponse, TensorflowPredictResponse
from platypush.message.event.tensorflow import (
TensorflowEpochStartedEvent,
TensorflowEpochEndedEvent,
TensorflowBatchStartedEvent,
TensorflowBatchEndedEvent,
TensorflowTrainStartedEvent,
TensorflowTrainEndedEvent,
)
from platypush.message.response.tensorflow import (
TensorflowTrainResponse,
TensorflowPredictResponse,
)
from platypush.plugins import Plugin, action
@ -55,7 +59,11 @@ class TensorflowPlugin(Plugin):
_image_extensions = ['jpg', 'jpeg', 'bmp', 'tiff', 'tif', 'png', 'gif']
_numpy_extensions = ['npy', 'npz']
_csv_extensions = ['csv', 'tsv']
_supported_data_file_extensions = [*_csv_extensions, *_numpy_extensions, *_image_extensions]
_supported_data_file_extensions = [
*_csv_extensions,
*_numpy_extensions,
*_image_extensions,
]
def __init__(self, workdir: Optional[str] = None, **kwargs):
"""
@ -63,11 +71,14 @@ class TensorflowPlugin(Plugin):
(default: PLATYPUSH_WORKDIR/tensorflow).
"""
super().__init__(**kwargs)
self.models: Dict[str, Model] = {}
self.models = {} # str -> Model
self._models_lock = threading.RLock()
self._model_locks: Dict[str, threading.RLock()] = {}
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
os.path.join(Config.get('workdir'), 'tensorflow')
self._work_dir = (
os.path.abspath(os.path.expanduser(workdir))
if workdir
else os.path.join(Config.get('workdir'), 'tensorflow')
)
self._models_dir = os.path.join(self._work_dir, 'models')
pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
@ -79,7 +90,7 @@ class TensorflowPlugin(Plugin):
self._model_locks[model_name] = threading.RLock()
try:
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.0)
assert success, 'Unable to acquire the model lock'
yield
finally:
@ -88,7 +99,9 @@ class TensorflowPlugin(Plugin):
except Exception as e:
self.logger.info(f'Model {model_name} lock release error: {e}')
def _load_model(self, model_name: str, reload: bool = False) -> Model:
def _load_model(self, model_name: str, reload: bool = False):
from tensorflow.keras.models import load_model
if model_name in self.models and not reload:
return self.models[model_name]
@ -133,49 +146,66 @@ class TensorflowPlugin(Plugin):
def _generate_callbacks(self, model: str):
from tensorflow.keras.callbacks import LambdaCallback
return [LambdaCallback(
return [
LambdaCallback(
on_epoch_begin=self.on_epoch_begin(model),
on_epoch_end=self.on_epoch_end(model),
on_batch_begin=self.on_batch_begin(model),
on_batch_end=self.on_batch_end(model),
on_train_begin=self.on_train_begin(model),
on_train_end=self.on_train_end(model),
)]
)
]
@staticmethod
def on_epoch_begin(model: str):
def callback(epoch: int, logs: Optional[dict] = None):
get_bus().post(TensorflowEpochStartedEvent(model=model, epoch=epoch, logs=logs))
get_bus().post(
TensorflowEpochStartedEvent(model=model, epoch=epoch, logs=logs)
)
return callback
@staticmethod
def on_epoch_end(model: str):
def callback(epoch: int, logs: Optional[dict] = None):
get_bus().post(TensorflowEpochEndedEvent(model=model, epoch=epoch, logs=logs))
get_bus().post(
TensorflowEpochEndedEvent(model=model, epoch=epoch, logs=logs)
)
return callback
@staticmethod
def on_batch_begin(model: str):
def callback(batch: int, logs: Optional[dict] = None):
get_bus().post(TensorflowBatchStartedEvent(model=model, batch=batch, logs=logs))
get_bus().post(
TensorflowBatchStartedEvent(model=model, batch=batch, logs=logs)
)
return callback
@staticmethod
def on_batch_end(model: str):
def callback(batch, logs: Optional[dict] = None):
get_bus().post(TensorflowBatchEndedEvent(model=model, batch=batch, logs=logs))
get_bus().post(
TensorflowBatchEndedEvent(model=model, batch=batch, logs=logs)
)
return callback
@staticmethod
def on_train_begin(model: str):
def callback(logs: Optional[dict] = None):
get_bus().post(TensorflowTrainStartedEvent(model=model, logs=logs))
return callback
@staticmethod
def on_train_end(model: str):
def callback(logs: Optional[dict] = None):
get_bus().post(TensorflowTrainEndedEvent(model=model, logs=logs))
return callback
@action
@ -220,20 +250,23 @@ class TensorflowPlugin(Plugin):
shutil.rmtree(model_dir)
@action
def create_network(self,
def create_network(
self,
name: str,
layers: List[Union[Layer, Dict[str, Any]]],
layers: list, # Layer or dict representation
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
optimizer: Optional[str] = 'rmsprop',
loss: Optional[Union[str, List[str], Dict[str, str]]] = None,
metrics: Optional[
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]] = None,
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]
] = None,
loss_weights: Optional[Union[List[float], Dict[str, float]]] = None,
sample_weight_mode: Optional[Union[str, List[str], Dict[str, str]]] = None,
weighted_metrics: Optional[List[str]] = None,
target_tensors=None,
**kwargs) -> Dict[str, Any]:
**kwargs,
) -> Dict[str, Any]:
"""
Create a neural network TensorFlow Keras model.
@ -410,6 +443,8 @@ class TensorflowPlugin(Plugin):
"""
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Layer
model = Sequential(name=name)
for layer in layers:
if not isinstance(layer, Layer):
@ -427,7 +462,7 @@ class TensorflowPlugin(Plugin):
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
**kwargs
**kwargs,
)
model.input_labels = input_names or []
@ -438,7 +473,8 @@ class TensorflowPlugin(Plugin):
return model.get_config()
@action
def create_regression(self,
def create_regression(
self,
name: str,
units: int = 1,
input_names: Optional[List[str]] = None,
@ -452,12 +488,14 @@ class TensorflowPlugin(Plugin):
optimizer: Optional[str] = 'rmsprop',
loss: Optional[Union[str, List[str], Dict[str, str]]] = 'mse',
metrics: Optional[
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]] = None,
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]
] = None,
loss_weights: Optional[Union[List[float], Dict[str, float]]] = None,
sample_weight_mode: Optional[Union[str, List[str], Dict[str, str]]] = None,
weighted_metrics: Optional[List[str]] = None,
target_tensors=None,
**kwargs) -> Dict[str, Any]:
**kwargs,
) -> Dict[str, Any]:
"""
Create a linear/logistic regression model.
@ -534,6 +572,7 @@ class TensorflowPlugin(Plugin):
"""
from tensorflow.keras.experimental import LinearModel
model = LinearModel(
units=units,
activation=activation,
@ -542,7 +581,8 @@ class TensorflowPlugin(Plugin):
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=name)
name=name,
)
model.input_names = input_names or []
@ -563,7 +603,7 @@ class TensorflowPlugin(Plugin):
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
**kwargs
**kwargs,
)
with self._lock_model(name):
@ -571,15 +611,17 @@ class TensorflowPlugin(Plugin):
return model.get_config()
@staticmethod
def _layer_from_dict(layer_type: str, *args, **kwargs) -> Layer:
def _layer_from_dict(layer_type: str, *args, **kwargs):
from tensorflow.keras import layers
cls = getattr(layers, layer_type)
assert issubclass(cls, Layer)
assert issubclass(cls, layers.Layer)
return cls(*args, **kwargs)
@staticmethod
def _get_csv_data(data_file: str) -> np.ndarray:
import pandas as pd
return pd.read_csv(data_file).to_numpy()
@staticmethod
@ -591,11 +633,16 @@ class TensorflowPlugin(Plugin):
return list(np.load(data_file).values()).pop()
@classmethod
def _get_image(cls, image_file: str, model: Model) -> np.ndarray:
def _get_image(cls, image_file: str, model) -> np.ndarray:
from tensorflow.keras.preprocessing import image
input_shape = model.inputs[0].shape
size = input_shape[1:3].as_list()
assert len(size) == 2, 'The model {} does not have enough dimensions to process an image (shape: {})'.format(
model.name, size)
assert (
len(size) == 2
), 'The model {} does not have enough dimensions to process an image (shape: {})'.format(
model.name, size
)
colors = input_shape[3:]
if len(colors) == 0 or colors[0] == 1:
@ -605,8 +652,10 @@ class TensorflowPlugin(Plugin):
elif colors[0] == 4:
color_mode = 'rgba'
else:
raise AssertionError('The input tensor should have either 1 (grayscale), 3 (rgb) or 4 (rgba) units. ' +
'Found: {}'.format(colors[0]))
raise AssertionError(
'The input tensor should have either 1 (grayscale), 3 (rgb) or 4 (rgba) units. '
+ 'Found: {}'.format(colors[0])
)
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
data = image.img_to_array(img)
@ -616,11 +665,17 @@ class TensorflowPlugin(Plugin):
return data
@classmethod
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
labels = [f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))]
assert set(model.output_labels) == set(labels),\
'The directory {dir} should contain exactly {n} subfolders named {names}'.format(
dir=directory, n=len(model.output_labels), names=model.output.labels)
def _get_dir(cls, directory: str, model) -> Dict[str, Iterable]:
labels = [
f
for f in os.listdir(directory)
if os.path.isdir(os.path.join(directory, f))
]
assert set(model.output_labels) == set(
labels
), 'The directory {dir} should contain exactly {n} subfolders named {names}'.format(
dir=directory, n=len(model.output_labels), names=model.output.labels
)
ret = {}
for label in labels:
@ -634,12 +689,17 @@ class TensorflowPlugin(Plugin):
return ret
@classmethod
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray:
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model) -> np.ndarray:
if isinstance(data, str):
if model.output_labels:
label_index = model.output_labels.index(data)
if label_index >= 0:
return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))])
return np.array(
[
1 if i == label_index else 0
for i in range(len(model.output_labels))
]
)
return np.array([data])
@ -649,10 +709,14 @@ class TensorflowPlugin(Plugin):
return data
@classmethod
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
if isinstance(data, List) or isinstance(data, Tuple):
if len(data) and isinstance(data[0], str):
def _get_data(
cls,
data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
model,
) -> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
from tensorflow.keras import utils
if isinstance(data, (list, tuple)) and len(data) and isinstance(data[0], str):
return np.array([cls._get_data(item, model) for item in data])
if not isinstance(data, str):
@ -660,15 +724,22 @@ class TensorflowPlugin(Plugin):
if data.startswith('http://') or data.startswith('https://'):
filename = '{timestamp}_{filename}'.format(
timestamp=datetime.now().timestamp(), filename=data.split('/')[-1])
timestamp=datetime.now().timestamp(), filename=data.split('/')[-1]
)
data_file = utils.get_file(filename, data)
else:
data_file = os.path.abspath(os.path.expanduser(data))
extensions = [ext for ext in cls._supported_data_file_extensions if data_file.endswith('.' + ext)]
extensions = [
ext
for ext in cls._supported_data_file_extensions
if data_file.endswith('.' + ext)
]
if os.path.isfile(data_file):
assert extensions, 'Unsupported type for file {}. Supported extensions: {}'.format(
assert (
extensions
), 'Unsupported type for file {}. Supported extensions: {}'.format(
data_file, cls._supported_data_file_extensions
)
@ -689,12 +760,19 @@ class TensorflowPlugin(Plugin):
return data
@classmethod
def _get_dataset(cls,
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
outputs: Optional[Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]],
model: Model) \
-> Tuple[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
def _get_dataset(
cls,
inputs: Union[
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
],
outputs: Optional[
Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]
],
model,
) -> Tuple[
Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]],
]:
inputs = cls._get_data(inputs, model)
if outputs:
outputs = cls._get_outputs(outputs, model)
@ -702,8 +780,18 @@ class TensorflowPlugin(Plugin):
pairs = []
for i, label in enumerate(model.output_labels):
data = inputs.get(label, [])
pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels)))
for d in data])
pairs.extend(
[
(
d,
tuple(
1 if i == j else 0
for j, _ in enumerate(model.output_labels)
),
)
for d in data
]
)
random.shuffle(pairs)
inputs = np.asarray([p[0] for p in pairs])
@ -712,14 +800,17 @@ class TensorflowPlugin(Plugin):
return inputs, outputs
@action
def train(self,
def train(
self,
model: str,
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
inputs: Union[
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
],
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
batch_size: Optional[int] = None,
epochs: int = 1,
verbose: int = 1,
validation_split: float = 0.,
validation_split: float = 0.0,
validation_data: Optional[Tuple[Union[np.ndarray, Iterable]]] = None,
shuffle: Union[bool, str] = True,
class_weight: Optional[Dict[int, float]] = None,
@ -730,7 +821,8 @@ class TensorflowPlugin(Plugin):
validation_freq: int = 1,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False) -> TensorflowTrainResponse:
use_multiprocessing: bool = False,
) -> TensorflowTrainResponse:
"""
Trains a model on a dataset for a fixed number of epochs.
@ -783,7 +875,8 @@ class TensorflowPlugin(Plugin):
Fraction of the training data to be used as validation data. The model will set apart this fraction
of the training data, will not train on it, and will evaluate the loss and any model metrics on this data
at the end of each epoch. The validation data is selected from the last samples in the ``x`` and ``y``
data provided, before shuffling. Not supported when ``x`` is a dataset, generator or ``keras.utils.Sequence`` instance.
data provided, before shuffling. Not supported when ``x`` is a dataset, generator or
``keras.utils.Sequence`` instance.
:param validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch.
The model will not be trained on this data. ``validation_data`` will override ``validation_split``.
@ -872,12 +965,17 @@ class TensorflowPlugin(Plugin):
use_multiprocessing=use_multiprocessing,
)
return TensorflowTrainResponse(model=model, model_name=name, epochs=ret.epoch, history=ret.history)
return TensorflowTrainResponse(
model=model, model_name=name, epochs=ret.epoch, history=ret.history
)
@action
def evaluate(self,
def evaluate(
self,
model: str,
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
inputs: Union[
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
],
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
batch_size: Optional[int] = None,
verbose: int = 1,
@ -885,7 +983,8 @@ class TensorflowPlugin(Plugin):
steps: Optional[int] = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False) -> Union[Dict[str, float], List[float]]:
use_multiprocessing: bool = False,
) -> Union[Dict[str, float], List[float]]:
"""
Returns the loss value and metrics values for the model in test model.
@ -916,10 +1015,10 @@ class TensorflowPlugin(Plugin):
- ``outputs`` doesn't have to be specified.
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely).
If `x` is a dataset, generator, or `keras.utils.Sequence` instance, `y` should not be specified
(since targets will be obtained from `x`).
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or
TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor
targets, or inversely). If `x` is a dataset, generator, or `keras.utils.Sequence` instance, `y`
should not be specified (since targets will be obtained from `x`).
:param batch_size: Number of samples per gradient update. If unspecified, ``batch_size`` will default to 32.
Do not specify the ``batch_size`` if your data is in the form of symbolic tensors, datasets,
@ -972,7 +1071,7 @@ class TensorflowPlugin(Plugin):
callbacks=self._generate_callbacks(name),
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing
use_multiprocessing=use_multiprocessing,
)
ret = ret if isinstance(ret, list) else [ret]
@ -982,15 +1081,19 @@ class TensorflowPlugin(Plugin):
return {model.metrics_names[i]: value for i, value in enumerate(ret)}
@action
def predict(self,
def predict(
self,
model: str,
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
inputs: Union[
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
],
batch_size: Optional[int] = None,
verbose: int = 0,
steps: Optional[int] = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False) -> TensorflowPredictResponse:
use_multiprocessing: bool = False,
) -> TensorflowPredictResponse:
"""
Generates output predictions for the input samples.
@ -1059,8 +1162,8 @@ class TensorflowPlugin(Plugin):
}
- For neural networks: ``outputs`` will contain the list of the output vector like in the case of
regression, and ``predictions`` will store the list of ``argmax`` (i.e. the index of the output unit with the
highest value) or their labels, if the model has output labels:
regression, and ``predictions`` will store the list of ``argmax`` (i.e. the index of the output unit
with the highest value) or their labels, if the model has output labels:
.. code-block:: json
@ -1080,9 +1183,14 @@ class TensorflowPlugin(Plugin):
name = model
model = self._load_model(model)
inputs = self._get_data(inputs, model)
if isinstance(inputs, np.ndarray) and \
len(model.inputs[0].shape) == len(inputs.shape) + 1 and \
(model.inputs[0].shape[0] is None or model.inputs[0].shape[0].value is None):
if (
isinstance(inputs, np.ndarray)
and len(model.inputs[0].shape) == len(inputs.shape) + 1
and (
model.inputs[0].shape[0] is None
or model.inputs[0].shape[0].value is None
)
):
inputs = np.asarray([inputs])
ret = model.predict(
@ -1093,11 +1201,15 @@ class TensorflowPlugin(Plugin):
callbacks=self._generate_callbacks(name),
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing
use_multiprocessing=use_multiprocessing,
)
return TensorflowPredictResponse(model=model, model_name=name, prediction=ret,
output_labels=model.output_labels)
return TensorflowPredictResponse(
model=model,
model_name=name,
prediction=ret,
output_labels=model.output_labels,
)
@action
def save(self, model: str, overwrite: bool = True, **opts) -> None:
@ -1112,7 +1224,10 @@ class TensorflowPlugin(Plugin):
model_name = model
model_dir = None
if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models:
if (
os.path.isdir(os.path.join(self._models_dir, model_name))
or model_name in self.models
):
model_dir = os.path.join(self._models_dir, model_name)
else:
model_file = os.path.abspath(os.path.expanduser(model_name))
@ -1141,7 +1256,11 @@ class TensorflowPlugin(Plugin):
with open(labels_file, 'w') as f:
json.dump(labels, f)
model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
model.save(
model_name if os.path.isfile(model_name) else model_dir,
overwrite=overwrite,
options=opts,
)
# vim:sw=4:ts=4:et: