forked from platypush/platypush
Performance improvements when loading the Tensorflow plugin.
The Tensorflow module may take a few seconds to load the first time and slow down the first scan of the plugins. All the Tensorflow imports should therefore be placed close to where they are used instead of being defined at the top of the module.
This commit is contained in:
parent
f49ad4c349
commit
cfedcd701e
2 changed files with 293 additions and 156 deletions
|
@ -1,7 +1,6 @@
|
|||
from typing import Dict, List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
from platypush.message.response import Response
|
||||
|
||||
|
@ -10,13 +9,18 @@ class TensorflowResponse(Response):
|
|||
"""
|
||||
Generic Tensorflow response.
|
||||
"""
|
||||
def __init__(self, *args, model: Model, model_name: Optional[str] = None, **kwargs):
|
||||
|
||||
def __init__(self, *args, model, model_name: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
super().__init__(*args, output={
|
||||
super().__init__(
|
||||
*args,
|
||||
output={
|
||||
'model': model_name or model.name,
|
||||
}, **kwargs)
|
||||
},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
|
@ -25,7 +29,14 @@ class TensorflowTrainResponse(TensorflowResponse):
|
|||
"""
|
||||
Tensorflow model fit/train response.
|
||||
"""
|
||||
def __init__(self, *args, epochs: List[int], history: Dict[str, List[Union[int, float]]], **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
epochs: List[int],
|
||||
history: Dict[str, List[Union[int, float]]],
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
:param epochs: List of epoch indexes the model has been trained on.
|
||||
:param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
|
||||
|
@ -40,7 +51,14 @@ class TensorflowPredictResponse(TensorflowResponse):
|
|||
"""
|
||||
Tensorflow model prediction response.
|
||||
"""
|
||||
def __init__(self, *args, prediction: np.ndarray, output_labels: Optional[List[str]] = None, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
prediction: np.ndarray,
|
||||
output_labels: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
|
||||
|
|
|
@ -9,17 +9,21 @@ from datetime import datetime
|
|||
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.keras import Model
|
||||
from tensorflow.keras.layers import Layer
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.preprocessing import image
|
||||
from tensorflow.keras import utils
|
||||
|
||||
from platypush.config import Config
|
||||
from platypush.context import get_bus
|
||||
from platypush.message.event.tensorflow import TensorflowEpochStartedEvent, TensorflowEpochEndedEvent, \
|
||||
TensorflowBatchStartedEvent, TensorflowBatchEndedEvent, TensorflowTrainStartedEvent, TensorflowTrainEndedEvent
|
||||
from platypush.message.response.tensorflow import TensorflowTrainResponse, TensorflowPredictResponse
|
||||
from platypush.message.event.tensorflow import (
|
||||
TensorflowEpochStartedEvent,
|
||||
TensorflowEpochEndedEvent,
|
||||
TensorflowBatchStartedEvent,
|
||||
TensorflowBatchEndedEvent,
|
||||
TensorflowTrainStartedEvent,
|
||||
TensorflowTrainEndedEvent,
|
||||
)
|
||||
from platypush.message.response.tensorflow import (
|
||||
TensorflowTrainResponse,
|
||||
TensorflowPredictResponse,
|
||||
)
|
||||
from platypush.plugins import Plugin, action
|
||||
|
||||
|
||||
|
@ -55,7 +59,11 @@ class TensorflowPlugin(Plugin):
|
|||
_image_extensions = ['jpg', 'jpeg', 'bmp', 'tiff', 'tif', 'png', 'gif']
|
||||
_numpy_extensions = ['npy', 'npz']
|
||||
_csv_extensions = ['csv', 'tsv']
|
||||
_supported_data_file_extensions = [*_csv_extensions, *_numpy_extensions, *_image_extensions]
|
||||
_supported_data_file_extensions = [
|
||||
*_csv_extensions,
|
||||
*_numpy_extensions,
|
||||
*_image_extensions,
|
||||
]
|
||||
|
||||
def __init__(self, workdir: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
|
@ -63,11 +71,14 @@ class TensorflowPlugin(Plugin):
|
|||
(default: PLATYPUSH_WORKDIR/tensorflow).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.models: Dict[str, Model] = {}
|
||||
self.models = {} # str -> Model
|
||||
self._models_lock = threading.RLock()
|
||||
self._model_locks: Dict[str, threading.RLock()] = {}
|
||||
self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \
|
||||
os.path.join(Config.get('workdir'), 'tensorflow')
|
||||
self._work_dir = (
|
||||
os.path.abspath(os.path.expanduser(workdir))
|
||||
if workdir
|
||||
else os.path.join(Config.get('workdir'), 'tensorflow')
|
||||
)
|
||||
|
||||
self._models_dir = os.path.join(self._work_dir, 'models')
|
||||
pathlib.Path(self._models_dir).mkdir(mode=0o755, exist_ok=True, parents=True)
|
||||
|
@ -79,7 +90,7 @@ class TensorflowPlugin(Plugin):
|
|||
self._model_locks[model_name] = threading.RLock()
|
||||
|
||||
try:
|
||||
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.)
|
||||
success = self._model_locks[model_name].acquire(blocking=True, timeout=30.0)
|
||||
assert success, 'Unable to acquire the model lock'
|
||||
yield
|
||||
finally:
|
||||
|
@ -88,7 +99,9 @@ class TensorflowPlugin(Plugin):
|
|||
except Exception as e:
|
||||
self.logger.info(f'Model {model_name} lock release error: {e}')
|
||||
|
||||
def _load_model(self, model_name: str, reload: bool = False) -> Model:
|
||||
def _load_model(self, model_name: str, reload: bool = False):
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
if model_name in self.models and not reload:
|
||||
return self.models[model_name]
|
||||
|
||||
|
@ -133,49 +146,66 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
def _generate_callbacks(self, model: str):
|
||||
from tensorflow.keras.callbacks import LambdaCallback
|
||||
return [LambdaCallback(
|
||||
|
||||
return [
|
||||
LambdaCallback(
|
||||
on_epoch_begin=self.on_epoch_begin(model),
|
||||
on_epoch_end=self.on_epoch_end(model),
|
||||
on_batch_begin=self.on_batch_begin(model),
|
||||
on_batch_end=self.on_batch_end(model),
|
||||
on_train_begin=self.on_train_begin(model),
|
||||
on_train_end=self.on_train_end(model),
|
||||
)]
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def on_epoch_begin(model: str):
|
||||
def callback(epoch: int, logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowEpochStartedEvent(model=model, epoch=epoch, logs=logs))
|
||||
get_bus().post(
|
||||
TensorflowEpochStartedEvent(model=model, epoch=epoch, logs=logs)
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
@staticmethod
|
||||
def on_epoch_end(model: str):
|
||||
def callback(epoch: int, logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowEpochEndedEvent(model=model, epoch=epoch, logs=logs))
|
||||
get_bus().post(
|
||||
TensorflowEpochEndedEvent(model=model, epoch=epoch, logs=logs)
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
@staticmethod
|
||||
def on_batch_begin(model: str):
|
||||
def callback(batch: int, logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowBatchStartedEvent(model=model, batch=batch, logs=logs))
|
||||
get_bus().post(
|
||||
TensorflowBatchStartedEvent(model=model, batch=batch, logs=logs)
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
@staticmethod
|
||||
def on_batch_end(model: str):
|
||||
def callback(batch, logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowBatchEndedEvent(model=model, batch=batch, logs=logs))
|
||||
get_bus().post(
|
||||
TensorflowBatchEndedEvent(model=model, batch=batch, logs=logs)
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
@staticmethod
|
||||
def on_train_begin(model: str):
|
||||
def callback(logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowTrainStartedEvent(model=model, logs=logs))
|
||||
|
||||
return callback
|
||||
|
||||
@staticmethod
|
||||
def on_train_end(model: str):
|
||||
def callback(logs: Optional[dict] = None):
|
||||
get_bus().post(TensorflowTrainEndedEvent(model=model, logs=logs))
|
||||
|
||||
return callback
|
||||
|
||||
@action
|
||||
|
@ -220,20 +250,23 @@ class TensorflowPlugin(Plugin):
|
|||
shutil.rmtree(model_dir)
|
||||
|
||||
@action
|
||||
def create_network(self,
|
||||
def create_network(
|
||||
self,
|
||||
name: str,
|
||||
layers: List[Union[Layer, Dict[str, Any]]],
|
||||
layers: list, # Layer or dict representation
|
||||
input_names: Optional[List[str]] = None,
|
||||
output_names: Optional[List[str]] = None,
|
||||
optimizer: Optional[str] = 'rmsprop',
|
||||
loss: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
||||
metrics: Optional[
|
||||
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]] = None,
|
||||
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]
|
||||
] = None,
|
||||
loss_weights: Optional[Union[List[float], Dict[str, float]]] = None,
|
||||
sample_weight_mode: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
||||
weighted_metrics: Optional[List[str]] = None,
|
||||
target_tensors=None,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a neural network TensorFlow Keras model.
|
||||
|
||||
|
@ -410,6 +443,8 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
"""
|
||||
from tensorflow.keras import Sequential
|
||||
from tensorflow.keras.layers import Layer
|
||||
|
||||
model = Sequential(name=name)
|
||||
for layer in layers:
|
||||
if not isinstance(layer, Layer):
|
||||
|
@ -427,7 +462,7 @@ class TensorflowPlugin(Plugin):
|
|||
sample_weight_mode=sample_weight_mode,
|
||||
weighted_metrics=weighted_metrics,
|
||||
target_tensors=target_tensors,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model.input_labels = input_names or []
|
||||
|
@ -438,7 +473,8 @@ class TensorflowPlugin(Plugin):
|
|||
return model.get_config()
|
||||
|
||||
@action
|
||||
def create_regression(self,
|
||||
def create_regression(
|
||||
self,
|
||||
name: str,
|
||||
units: int = 1,
|
||||
input_names: Optional[List[str]] = None,
|
||||
|
@ -452,12 +488,14 @@ class TensorflowPlugin(Plugin):
|
|||
optimizer: Optional[str] = 'rmsprop',
|
||||
loss: Optional[Union[str, List[str], Dict[str, str]]] = 'mse',
|
||||
metrics: Optional[
|
||||
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]] = None,
|
||||
Union[str, List[Union[str, List[str]]], Dict[str, Union[str, List[str]]]]
|
||||
] = None,
|
||||
loss_weights: Optional[Union[List[float], Dict[str, float]]] = None,
|
||||
sample_weight_mode: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
||||
weighted_metrics: Optional[List[str]] = None,
|
||||
target_tensors=None,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a linear/logistic regression model.
|
||||
|
||||
|
@ -534,6 +572,7 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
"""
|
||||
from tensorflow.keras.experimental import LinearModel
|
||||
|
||||
model = LinearModel(
|
||||
units=units,
|
||||
activation=activation,
|
||||
|
@ -542,7 +581,8 @@ class TensorflowPlugin(Plugin):
|
|||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
name=name)
|
||||
name=name,
|
||||
)
|
||||
|
||||
model.input_names = input_names or []
|
||||
|
||||
|
@ -563,7 +603,7 @@ class TensorflowPlugin(Plugin):
|
|||
sample_weight_mode=sample_weight_mode,
|
||||
weighted_metrics=weighted_metrics,
|
||||
target_tensors=target_tensors,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
with self._lock_model(name):
|
||||
|
@ -571,15 +611,17 @@ class TensorflowPlugin(Plugin):
|
|||
return model.get_config()
|
||||
|
||||
@staticmethod
|
||||
def _layer_from_dict(layer_type: str, *args, **kwargs) -> Layer:
|
||||
def _layer_from_dict(layer_type: str, *args, **kwargs):
|
||||
from tensorflow.keras import layers
|
||||
|
||||
cls = getattr(layers, layer_type)
|
||||
assert issubclass(cls, Layer)
|
||||
assert issubclass(cls, layers.Layer)
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_csv_data(data_file: str) -> np.ndarray:
|
||||
import pandas as pd
|
||||
|
||||
return pd.read_csv(data_file).to_numpy()
|
||||
|
||||
@staticmethod
|
||||
|
@ -591,11 +633,16 @@ class TensorflowPlugin(Plugin):
|
|||
return list(np.load(data_file).values()).pop()
|
||||
|
||||
@classmethod
|
||||
def _get_image(cls, image_file: str, model: Model) -> np.ndarray:
|
||||
def _get_image(cls, image_file: str, model) -> np.ndarray:
|
||||
from tensorflow.keras.preprocessing import image
|
||||
|
||||
input_shape = model.inputs[0].shape
|
||||
size = input_shape[1:3].as_list()
|
||||
assert len(size) == 2, 'The model {} does not have enough dimensions to process an image (shape: {})'.format(
|
||||
model.name, size)
|
||||
assert (
|
||||
len(size) == 2
|
||||
), 'The model {} does not have enough dimensions to process an image (shape: {})'.format(
|
||||
model.name, size
|
||||
)
|
||||
|
||||
colors = input_shape[3:]
|
||||
if len(colors) == 0 or colors[0] == 1:
|
||||
|
@ -605,8 +652,10 @@ class TensorflowPlugin(Plugin):
|
|||
elif colors[0] == 4:
|
||||
color_mode = 'rgba'
|
||||
else:
|
||||
raise AssertionError('The input tensor should have either 1 (grayscale), 3 (rgb) or 4 (rgba) units. ' +
|
||||
'Found: {}'.format(colors[0]))
|
||||
raise AssertionError(
|
||||
'The input tensor should have either 1 (grayscale), 3 (rgb) or 4 (rgba) units. '
|
||||
+ 'Found: {}'.format(colors[0])
|
||||
)
|
||||
|
||||
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
|
||||
data = image.img_to_array(img)
|
||||
|
@ -616,11 +665,17 @@ class TensorflowPlugin(Plugin):
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
|
||||
labels = [f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))]
|
||||
assert set(model.output_labels) == set(labels),\
|
||||
'The directory {dir} should contain exactly {n} subfolders named {names}'.format(
|
||||
dir=directory, n=len(model.output_labels), names=model.output.labels)
|
||||
def _get_dir(cls, directory: str, model) -> Dict[str, Iterable]:
|
||||
labels = [
|
||||
f
|
||||
for f in os.listdir(directory)
|
||||
if os.path.isdir(os.path.join(directory, f))
|
||||
]
|
||||
assert set(model.output_labels) == set(
|
||||
labels
|
||||
), 'The directory {dir} should contain exactly {n} subfolders named {names}'.format(
|
||||
dir=directory, n=len(model.output_labels), names=model.output.labels
|
||||
)
|
||||
|
||||
ret = {}
|
||||
for label in labels:
|
||||
|
@ -634,12 +689,17 @@ class TensorflowPlugin(Plugin):
|
|||
return ret
|
||||
|
||||
@classmethod
|
||||
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray:
|
||||
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model) -> np.ndarray:
|
||||
if isinstance(data, str):
|
||||
if model.output_labels:
|
||||
label_index = model.output_labels.index(data)
|
||||
if label_index >= 0:
|
||||
return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))])
|
||||
return np.array(
|
||||
[
|
||||
1 if i == label_index else 0
|
||||
for i in range(len(model.output_labels))
|
||||
]
|
||||
)
|
||||
|
||||
return np.array([data])
|
||||
|
||||
|
@ -649,10 +709,14 @@ class TensorflowPlugin(Plugin):
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
|
||||
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
||||
if isinstance(data, List) or isinstance(data, Tuple):
|
||||
if len(data) and isinstance(data[0], str):
|
||||
def _get_data(
|
||||
cls,
|
||||
data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
model,
|
||||
) -> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
||||
from tensorflow.keras import utils
|
||||
|
||||
if isinstance(data, (list, tuple)) and len(data) and isinstance(data[0], str):
|
||||
return np.array([cls._get_data(item, model) for item in data])
|
||||
|
||||
if not isinstance(data, str):
|
||||
|
@ -660,15 +724,22 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
if data.startswith('http://') or data.startswith('https://'):
|
||||
filename = '{timestamp}_{filename}'.format(
|
||||
timestamp=datetime.now().timestamp(), filename=data.split('/')[-1])
|
||||
timestamp=datetime.now().timestamp(), filename=data.split('/')[-1]
|
||||
)
|
||||
data_file = utils.get_file(filename, data)
|
||||
else:
|
||||
data_file = os.path.abspath(os.path.expanduser(data))
|
||||
|
||||
extensions = [ext for ext in cls._supported_data_file_extensions if data_file.endswith('.' + ext)]
|
||||
extensions = [
|
||||
ext
|
||||
for ext in cls._supported_data_file_extensions
|
||||
if data_file.endswith('.' + ext)
|
||||
]
|
||||
|
||||
if os.path.isfile(data_file):
|
||||
assert extensions, 'Unsupported type for file {}. Supported extensions: {}'.format(
|
||||
assert (
|
||||
extensions
|
||||
), 'Unsupported type for file {}. Supported extensions: {}'.format(
|
||||
data_file, cls._supported_data_file_extensions
|
||||
)
|
||||
|
||||
|
@ -689,12 +760,19 @@ class TensorflowPlugin(Plugin):
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]],
|
||||
model: Model) \
|
||||
-> Tuple[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
|
||||
def _get_dataset(
|
||||
cls,
|
||||
inputs: Union[
|
||||
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
|
||||
],
|
||||
outputs: Optional[
|
||||
Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]
|
||||
],
|
||||
model,
|
||||
) -> Tuple[
|
||||
Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]],
|
||||
]:
|
||||
inputs = cls._get_data(inputs, model)
|
||||
if outputs:
|
||||
outputs = cls._get_outputs(outputs, model)
|
||||
|
@ -702,8 +780,18 @@ class TensorflowPlugin(Plugin):
|
|||
pairs = []
|
||||
for i, label in enumerate(model.output_labels):
|
||||
data = inputs.get(label, [])
|
||||
pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels)))
|
||||
for d in data])
|
||||
pairs.extend(
|
||||
[
|
||||
(
|
||||
d,
|
||||
tuple(
|
||||
1 if i == j else 0
|
||||
for j, _ in enumerate(model.output_labels)
|
||||
),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
)
|
||||
|
||||
random.shuffle(pairs)
|
||||
inputs = np.asarray([p[0] for p in pairs])
|
||||
|
@ -712,14 +800,17 @@ class TensorflowPlugin(Plugin):
|
|||
return inputs, outputs
|
||||
|
||||
@action
|
||||
def train(self,
|
||||
def train(
|
||||
self,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
inputs: Union[
|
||||
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
|
||||
],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
epochs: int = 1,
|
||||
verbose: int = 1,
|
||||
validation_split: float = 0.,
|
||||
validation_split: float = 0.0,
|
||||
validation_data: Optional[Tuple[Union[np.ndarray, Iterable]]] = None,
|
||||
shuffle: Union[bool, str] = True,
|
||||
class_weight: Optional[Dict[int, float]] = None,
|
||||
|
@ -730,7 +821,8 @@ class TensorflowPlugin(Plugin):
|
|||
validation_freq: int = 1,
|
||||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False) -> TensorflowTrainResponse:
|
||||
use_multiprocessing: bool = False,
|
||||
) -> TensorflowTrainResponse:
|
||||
"""
|
||||
Trains a model on a dataset for a fixed number of epochs.
|
||||
|
||||
|
@ -783,7 +875,8 @@ class TensorflowPlugin(Plugin):
|
|||
Fraction of the training data to be used as validation data. The model will set apart this fraction
|
||||
of the training data, will not train on it, and will evaluate the loss and any model metrics on this data
|
||||
at the end of each epoch. The validation data is selected from the last samples in the ``x`` and ``y``
|
||||
data provided, before shuffling. Not supported when ``x`` is a dataset, generator or ``keras.utils.Sequence`` instance.
|
||||
data provided, before shuffling. Not supported when ``x`` is a dataset, generator or
|
||||
``keras.utils.Sequence`` instance.
|
||||
|
||||
:param validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch.
|
||||
The model will not be trained on this data. ``validation_data`` will override ``validation_split``.
|
||||
|
@ -872,12 +965,17 @@ class TensorflowPlugin(Plugin):
|
|||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
return TensorflowTrainResponse(model=model, model_name=name, epochs=ret.epoch, history=ret.history)
|
||||
return TensorflowTrainResponse(
|
||||
model=model, model_name=name, epochs=ret.epoch, history=ret.history
|
||||
)
|
||||
|
||||
@action
|
||||
def evaluate(self,
|
||||
def evaluate(
|
||||
self,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
inputs: Union[
|
||||
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
|
||||
],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
verbose: int = 1,
|
||||
|
@ -885,7 +983,8 @@ class TensorflowPlugin(Plugin):
|
|||
steps: Optional[int] = None,
|
||||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False) -> Union[Dict[str, float], List[float]]:
|
||||
use_multiprocessing: bool = False,
|
||||
) -> Union[Dict[str, float], List[float]]:
|
||||
"""
|
||||
Returns the loss value and metrics values for the model in test model.
|
||||
|
||||
|
@ -916,10 +1015,10 @@ class TensorflowPlugin(Plugin):
|
|||
- ``outputs`` doesn't have to be specified.
|
||||
|
||||
|
||||
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or TensorFlow tensor(s).
|
||||
It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely).
|
||||
If `x` is a dataset, generator, or `keras.utils.Sequence` instance, `y` should not be specified
|
||||
(since targets will be obtained from `x`).
|
||||
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or
|
||||
TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor
|
||||
targets, or inversely). If `x` is a dataset, generator, or `keras.utils.Sequence` instance, `y`
|
||||
should not be specified (since targets will be obtained from `x`).
|
||||
|
||||
:param batch_size: Number of samples per gradient update. If unspecified, ``batch_size`` will default to 32.
|
||||
Do not specify the ``batch_size`` if your data is in the form of symbolic tensors, datasets,
|
||||
|
@ -972,7 +1071,7 @@ class TensorflowPlugin(Plugin):
|
|||
callbacks=self._generate_callbacks(name),
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
ret = ret if isinstance(ret, list) else [ret]
|
||||
|
@ -982,15 +1081,19 @@ class TensorflowPlugin(Plugin):
|
|||
return {model.metrics_names[i]: value for i, value in enumerate(ret)}
|
||||
|
||||
@action
|
||||
def predict(self,
|
||||
def predict(
|
||||
self,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
inputs: Union[
|
||||
str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]
|
||||
],
|
||||
batch_size: Optional[int] = None,
|
||||
verbose: int = 0,
|
||||
steps: Optional[int] = None,
|
||||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False) -> TensorflowPredictResponse:
|
||||
use_multiprocessing: bool = False,
|
||||
) -> TensorflowPredictResponse:
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
|
@ -1059,8 +1162,8 @@ class TensorflowPlugin(Plugin):
|
|||
}
|
||||
|
||||
- For neural networks: ``outputs`` will contain the list of the output vector like in the case of
|
||||
regression, and ``predictions`` will store the list of ``argmax`` (i.e. the index of the output unit with the
|
||||
highest value) or their labels, if the model has output labels:
|
||||
regression, and ``predictions`` will store the list of ``argmax`` (i.e. the index of the output unit
|
||||
with the highest value) or their labels, if the model has output labels:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
|
@ -1080,9 +1183,14 @@ class TensorflowPlugin(Plugin):
|
|||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs = self._get_data(inputs, model)
|
||||
if isinstance(inputs, np.ndarray) and \
|
||||
len(model.inputs[0].shape) == len(inputs.shape) + 1 and \
|
||||
(model.inputs[0].shape[0] is None or model.inputs[0].shape[0].value is None):
|
||||
if (
|
||||
isinstance(inputs, np.ndarray)
|
||||
and len(model.inputs[0].shape) == len(inputs.shape) + 1
|
||||
and (
|
||||
model.inputs[0].shape[0] is None
|
||||
or model.inputs[0].shape[0].value is None
|
||||
)
|
||||
):
|
||||
inputs = np.asarray([inputs])
|
||||
|
||||
ret = model.predict(
|
||||
|
@ -1093,11 +1201,15 @@ class TensorflowPlugin(Plugin):
|
|||
callbacks=self._generate_callbacks(name),
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
return TensorflowPredictResponse(model=model, model_name=name, prediction=ret,
|
||||
output_labels=model.output_labels)
|
||||
return TensorflowPredictResponse(
|
||||
model=model,
|
||||
model_name=name,
|
||||
prediction=ret,
|
||||
output_labels=model.output_labels,
|
||||
)
|
||||
|
||||
@action
|
||||
def save(self, model: str, overwrite: bool = True, **opts) -> None:
|
||||
|
@ -1112,7 +1224,10 @@ class TensorflowPlugin(Plugin):
|
|||
model_name = model
|
||||
model_dir = None
|
||||
|
||||
if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models:
|
||||
if (
|
||||
os.path.isdir(os.path.join(self._models_dir, model_name))
|
||||
or model_name in self.models
|
||||
):
|
||||
model_dir = os.path.join(self._models_dir, model_name)
|
||||
else:
|
||||
model_file = os.path.abspath(os.path.expanduser(model_name))
|
||||
|
@ -1141,7 +1256,11 @@ class TensorflowPlugin(Plugin):
|
|||
with open(labels_file, 'w') as f:
|
||||
json.dump(labels, f)
|
||||
|
||||
model.save(model_name if os.path.isfile(model_name) else model_dir, overwrite=overwrite, options=opts)
|
||||
model.save(
|
||||
model_name if os.path.isfile(model_name) else model_dir,
|
||||
overwrite=overwrite,
|
||||
options=opts,
|
||||
)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Reference in a new issue