Implemented extensive support for neural networks, images and directories [closes #121]
This commit is contained in:
parent
1f1fefca9d
commit
50e372be36
8 changed files with 289 additions and 75 deletions
|
@ -48,6 +48,7 @@ Events
|
|||
platypush/events/serial.rst
|
||||
platypush/events/sound.rst
|
||||
platypush/events/stt.rst
|
||||
platypush/events/tensorflow.rst
|
||||
platypush/events/todoist.rst
|
||||
platypush/events/torrent.rst
|
||||
platypush/events/travisci.rst
|
||||
|
|
5
docs/source/platypush/events/tensorflow.rst
Normal file
5
docs/source/platypush/events/tensorflow.rst
Normal file
|
@ -0,0 +1,5 @@
|
|||
``platypush.message.event.tensorflow``
|
||||
======================================
|
||||
|
||||
.. automodule:: platypush.message.event.tensorflow
|
||||
:members:
|
5
docs/source/platypush/plugins/tensorflow.rst
Normal file
5
docs/source/platypush/plugins/tensorflow.rst
Normal file
|
@ -0,0 +1,5 @@
|
|||
``platypush.plugins.tensorflow``
|
||||
================================
|
||||
|
||||
.. automodule:: platypush.plugins.tensorflow
|
||||
:members:
|
5
docs/source/platypush/responses/tensorflow.rst
Normal file
5
docs/source/platypush/responses/tensorflow.rst
Normal file
|
@ -0,0 +1,5 @@
|
|||
``platypush.message.response.tensorflow``
|
||||
=========================================
|
||||
|
||||
.. automodule:: platypush.message.response.tensorflow
|
||||
:members:
|
|
@ -104,6 +104,7 @@ Plugins
|
|||
platypush/plugins/switch.wemo.rst
|
||||
platypush/plugins/system.rst
|
||||
platypush/plugins/tcp.rst
|
||||
platypush/plugins/tensorflow.rst
|
||||
platypush/plugins/todoist.rst
|
||||
platypush/plugins/torrent.rst
|
||||
platypush/plugins/travisci.rst
|
||||
|
|
|
@ -18,6 +18,7 @@ Responses
|
|||
platypush/responses/qrcode.rst
|
||||
platypush/responses/stt.rst
|
||||
platypush/responses/system.rst
|
||||
platypush/responses/tensorflow.rst
|
||||
platypush/responses/todoist.rst
|
||||
platypush/responses/trello.rst
|
||||
platypush/responses/weather.buienradar.rst
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
from platypush.message.response import Response
|
||||
|
||||
|
@ -7,14 +10,16 @@ class TensorflowResponse(Response):
|
|||
"""
|
||||
Generic Tensorflow response.
|
||||
"""
|
||||
def __init__(self, *args, model: str, **kwargs):
|
||||
def __init__(self, *args, model: Model, **kwargs):
|
||||
"""
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
super().__init__(*args, output={
|
||||
'model': model,
|
||||
'model': model.name,
|
||||
}, **kwargs)
|
||||
|
||||
self.model = model
|
||||
|
||||
|
||||
class TensorflowTrainResponse(TensorflowResponse):
|
||||
"""
|
||||
|
@ -31,4 +36,27 @@ class TensorflowTrainResponse(TensorflowResponse):
|
|||
self.output['history'] = history
|
||||
|
||||
|
||||
class TensorflowPredictResponse(TensorflowResponse):
|
||||
"""
|
||||
Tensorflow model prediction response.
|
||||
"""
|
||||
def __init__(self, *args, prediction: np.ndarray, output_labels: Optional[List[str]] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
|
||||
self.output['values'] = [
|
||||
{output_labels[i]: value for i, value in enumerate(p)}
|
||||
for p in prediction
|
||||
]
|
||||
else:
|
||||
self.output['values'] = prediction
|
||||
|
||||
if self.model.__class__.__name__ != 'LinearModel':
|
||||
prediction = [int(np.argmax(p)) for p in prediction]
|
||||
if output_labels:
|
||||
self.output['labels'] = [output_labels[p] for p in prediction]
|
||||
else:
|
||||
self.output['labels'] = prediction
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -1,18 +1,23 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Union, Optional, Tuple, Iterable
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.keras import Model
|
||||
from tensorflow.keras.layers import Layer
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.preprocessing import image
|
||||
from tensorflow.keras import utils
|
||||
|
||||
from platypush.config import Config
|
||||
from platypush.context import get_bus
|
||||
from platypush.message.event.tensorflow import TensorflowEpochStartedEvent, TensorflowEpochEndedEvent, \
|
||||
TensorflowBatchStartedEvent, TensorflowBatchEndedEvent, TensorflowTrainStartedEvent, TensorflowTrainEndedEvent
|
||||
from platypush.message.response.tensorflow import TensorflowTrainResponse
|
||||
from platypush.message.response.tensorflow import TensorflowTrainResponse, TensorflowPredictResponse
|
||||
from platypush.plugins import Plugin, action
|
||||
|
||||
|
||||
|
@ -45,7 +50,10 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
"""
|
||||
|
||||
_supported_data_file_extensions = ['npy', 'npz', 'csv']
|
||||
_image_extensions = ['jpg', 'jpeg', 'bmp', 'tiff', 'tif', 'png', 'gif']
|
||||
_numpy_extensions = ['npy', 'npz']
|
||||
_csv_extensions = ['csv', 'tsv']
|
||||
_supported_data_file_extensions = [*_csv_extensions, *_numpy_extensions, *_image_extensions]
|
||||
|
||||
def __init__(self, workdir: str = os.path.join(Config.get('workdir'), 'tensorflow'), **kwargs):
|
||||
"""
|
||||
|
@ -65,7 +73,20 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
model_dir = os.path.join(self._models_dir, name)
|
||||
assert os.path.isdir(model_dir), 'The model {} does not exist'.format(name)
|
||||
return load_model(model_dir)
|
||||
model = load_model(model_dir)
|
||||
model.input_labels = []
|
||||
model.output_labels = []
|
||||
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
if os.path.isfile(labels_file):
|
||||
with open(labels_file, 'r') as f:
|
||||
labels = json.load(f)
|
||||
if 'input' in labels:
|
||||
model.input_labels = labels['input']
|
||||
if 'output' in labels:
|
||||
model.output_labels = labels['output']
|
||||
|
||||
return model
|
||||
|
||||
def _generate_callbacks(self, model: str):
|
||||
from tensorflow.keras.callbacks import LambdaCallback
|
||||
|
@ -115,40 +136,40 @@ class TensorflowPlugin(Plugin):
|
|||
return callback
|
||||
|
||||
@action
|
||||
def load(self, name: str, reload: bool = False) -> Dict[str, Any]:
|
||||
def load(self, model: str, reload: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
(Re)-load a model from the file system.
|
||||
|
||||
:param name: Name of the model. Must be a folder name stored under ``<workdir>/models``.
|
||||
:param model: Name of the model. Must be a folder name stored under ``<workdir>/models``.
|
||||
:param reload: If ``True``, the model will be reloaded from the filesystem even if it's been already
|
||||
loaded, otherwise the model currently in memory will be kept (default: ``False``).
|
||||
:return: The model configuration.
|
||||
"""
|
||||
model = self._load_model(name, reload=reload)
|
||||
model = self._load_model(model, reload=reload)
|
||||
return model.get_config()
|
||||
|
||||
@action
|
||||
def unload(self, name: str) -> None:
|
||||
def unload(self, model: str) -> None:
|
||||
"""
|
||||
Remove a loaded model from memory.
|
||||
|
||||
:param name: Name of the model.
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
assert name in self.models, 'The model {} is not loaded'.format(name)
|
||||
del self.models[name]
|
||||
assert model in self.models, 'The model {} is not loaded'.format(model)
|
||||
del self.models[model]
|
||||
|
||||
@action
|
||||
def remove(self, name: str) -> None:
|
||||
def remove(self, model: str) -> None:
|
||||
"""
|
||||
Unload a module and, if stored on the filesystem, remove its resource files as well.
|
||||
WARNING: This operation is not reversible.
|
||||
|
||||
:param name: Name of the model.
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
if name in self.models:
|
||||
del self.models[name]
|
||||
if model in self.models:
|
||||
del self.models[model]
|
||||
|
||||
model_dir = os.path.join(self._models_dir, name)
|
||||
model_dir = os.path.join(self._models_dir, model)
|
||||
if os.path.isdir(model_dir):
|
||||
shutil.rmtree(model_dir)
|
||||
|
||||
|
@ -205,7 +226,7 @@ class TensorflowPlugin(Plugin):
|
|||
]
|
||||
|
||||
:param input_names: List of names for the input units (default: TensorFlow name auto-assign logic).
|
||||
:param output_names: List of names for the output units (default: TensorFlow name auto-assign logic).
|
||||
:param output_names: List of labels for the output units (default: TensorFlow name auto-assign logic).
|
||||
:param optimizer: Optimizer, see <https://keras.io/optimizers/> (default: ``rmsprop``).
|
||||
:param loss: Loss function, see <https://keras.io/losses/>. An objective function is any callable with
|
||||
the signature ``scalar_loss = fn(y_true, y_pred)``. If the model has multiple outputs, you can use a
|
||||
|
@ -360,11 +381,8 @@ class TensorflowPlugin(Plugin):
|
|||
**kwargs
|
||||
)
|
||||
|
||||
if input_names:
|
||||
model.input_names = input_names
|
||||
if output_names:
|
||||
model.output_names = output_names
|
||||
|
||||
model.input_labels = input_names or []
|
||||
model.output_labels = output_names or []
|
||||
self.models[name] = model
|
||||
return model.get_config()
|
||||
|
||||
|
@ -395,7 +413,7 @@ class TensorflowPlugin(Plugin):
|
|||
:param name: Name of the model.
|
||||
:param units: Output dimension (default: 1).
|
||||
:param input_names: List of names for the input units (default: TensorFlow name auto-assign logic).
|
||||
:param output_names: List of names for the output units (default: TensorFlow name auto-assign logic).
|
||||
:param output_names: List of labels for the output units (default: TensorFlow name auto-assign logic).
|
||||
:param activation: Activation function to be used (default: None).
|
||||
:param use_bias: Whether to calculate the bias/intercept for this model. If set
|
||||
to False, no bias/intercept will be used in calculations, e.g., the data
|
||||
|
@ -475,11 +493,13 @@ class TensorflowPlugin(Plugin):
|
|||
bias_regularizer=bias_regularizer,
|
||||
name=name)
|
||||
|
||||
if input_names:
|
||||
model.input_names = input_names
|
||||
model.input_names = input_names or []
|
||||
|
||||
if output_names:
|
||||
assert units == len(output_names)
|
||||
model.output_names = output_names
|
||||
model.output_labels = output_names
|
||||
else:
|
||||
model.output_labels = []
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizer,
|
||||
|
@ -516,31 +536,106 @@ class TensorflowPlugin(Plugin):
|
|||
return list(np.load(data_file).values()).pop()
|
||||
|
||||
@classmethod
|
||||
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]) \
|
||||
def _get_image(cls, image_file: str, model: Model) -> np.ndarray:
|
||||
input_shape = model.inputs[0].shape
|
||||
size = input_shape[1:3].as_list()
|
||||
assert len(size) == 2, 'The model {} does not have enough dimensions to process an image (shape: {})'.format(
|
||||
model.name, size)
|
||||
|
||||
colors = input_shape[3:]
|
||||
assert colors, ('The model {} requires a tensor with at least 3 inputs in order to process images: ' +
|
||||
'[WIDTH, HEIGHT, COLORS]').format(model.name)
|
||||
|
||||
if colors[0] == 1:
|
||||
color_mode = 'grayscale'
|
||||
elif colors[0] == 3:
|
||||
color_mode = 'rgb'
|
||||
elif colors[0] == 4:
|
||||
color_mode = 'rgba'
|
||||
else:
|
||||
raise AssertionError('The input tensor should have either 1 (grayscale), 3 (rgb) or 4 (rgba) units. ' +
|
||||
'Found: {}'.format(colors[0]))
|
||||
|
||||
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
|
||||
return image.img_to_array(img)
|
||||
|
||||
@classmethod
|
||||
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
|
||||
labels = [f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))]
|
||||
assert set(model.output_labels) == set(labels),\
|
||||
'The directory {dir} should contain exactly {n} subfolders named {names}'.format(
|
||||
dir=directory, n=len(model.output_labels), names=model.output.labels)
|
||||
|
||||
ret = {}
|
||||
for label in labels:
|
||||
subdir = os.path.join(directory, label)
|
||||
ret[label] = [
|
||||
cls._get_data(os.path.join(subdir, f), model)
|
||||
for f in os.listdir(subdir)
|
||||
if f.split('.')[-1] in cls._supported_data_file_extensions
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
|
||||
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
||||
if not isinstance(data, str):
|
||||
return data
|
||||
|
||||
if data.startswith('http://') or data.startswith('https://'):
|
||||
filename = '{timestamp}_{filename}'.format(
|
||||
timestamp=datetime.now().timestamp(), filename=data.split('/')[-1])
|
||||
data_file = utils.get_file(filename, data)
|
||||
else:
|
||||
data_file = os.path.abspath(os.path.expanduser(data))
|
||||
|
||||
extensions = [ext for ext in cls._supported_data_file_extensions if data_file.endswith('.' + ext)]
|
||||
assert os.path.isfile(data_file)
|
||||
|
||||
if os.path.isfile(data_file):
|
||||
assert extensions, 'Unsupported type for file {}. Supported extensions: {}'.format(
|
||||
data_file, cls._supported_data_file_extensions
|
||||
)
|
||||
|
||||
extension = extensions.pop()
|
||||
if extension == 'csv':
|
||||
if extension in cls._csv_extensions:
|
||||
return cls._get_csv_data(data_file)
|
||||
if extension == 'npy':
|
||||
return cls._get_numpy_data(data_file)
|
||||
if extension == 'npz':
|
||||
return cls._get_numpy_compressed_data(data_file)
|
||||
if extension in cls._image_extensions:
|
||||
return cls._get_image(data_file, model)
|
||||
|
||||
raise AssertionError('Something went wrong while loading the data file {}'.format(data_file))
|
||||
raise AssertionError('Unsupported file type: {}'.format(data_file))
|
||||
elif os.path.isdir(data_file):
|
||||
return cls._get_dir(data_file, model)
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]],
|
||||
model: Model) \
|
||||
-> Tuple[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
|
||||
inputs = cls._get_data(inputs, model)
|
||||
if outputs:
|
||||
outputs = cls._get_data(inputs, model)
|
||||
elif isinstance(inputs, dict) and model.output_labels:
|
||||
pairs = []
|
||||
for i, label in enumerate(model.output_labels):
|
||||
data = inputs.get(label, [])
|
||||
pairs.extend([(d, i) for d in data])
|
||||
|
||||
random.shuffle(pairs)
|
||||
inputs = np.asarray([p[0] for p in pairs])
|
||||
outputs = np.asarray([p[1] for p in pairs])
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
@action
|
||||
def train(self,
|
||||
name: str,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
|
@ -561,7 +656,7 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
Trains a model on a dataset for a fixed number of epochs.
|
||||
|
||||
:param name: Name of the existing model to be trained.
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -575,9 +670,19 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
- CSV with header (``.csv`` extension``)
|
||||
- Numpy raw or compressed files (``.npy`` or ``.npz`` extension)
|
||||
- Image files
|
||||
- An HTTP URL pointing to one of the file types listed above
|
||||
- Directories with images. If ``inputs`` points to a directory of images then the following
|
||||
conventions are followed:
|
||||
|
||||
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or TensorFlow tensor(s).
|
||||
It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely).
|
||||
- The folder must contain exactly as many subfolders as the output units of your model. If
|
||||
the model has ``output_labels`` then those subfolders should be named as the output labels.
|
||||
Each subfolder will contain training examples that match the associated label (e.g.
|
||||
``positive`` will contain all the positive images and ``negative`` all the negative images).
|
||||
- ``outputs`` doesn't have to be specified.
|
||||
|
||||
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or TensorFlow
|
||||
tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely).
|
||||
If `x` is a dataset, generator, or `keras.utils.Sequence` instance, `y` should not be specified
|
||||
(since targets will be obtained from `x`).
|
||||
|
||||
|
@ -664,10 +769,9 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
:return: :class:`platypush.message.response.tensorflow.TensorflowTrainResponse`
|
||||
"""
|
||||
model = self._load_model(name)
|
||||
inputs = self._get_data(inputs)
|
||||
if outputs:
|
||||
outputs = self._get_data(outputs)
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs, outputs = self._get_dataset(inputs, outputs, model)
|
||||
|
||||
ret = model.fit(
|
||||
x=inputs,
|
||||
|
@ -690,11 +794,11 @@ class TensorflowPlugin(Plugin):
|
|||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
return TensorflowTrainResponse(model=name, epochs=ret.epoch, history=ret.history)
|
||||
return TensorflowTrainResponse(model=model, epochs=ret.epoch, history=ret.history)
|
||||
|
||||
@action
|
||||
def evaluate(self,
|
||||
name: str,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
outputs: Optional[Union[str, np.ndarray, Iterable]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
|
@ -707,7 +811,7 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
Returns the loss value and metrics values for the model in test model.
|
||||
|
||||
:param name: Name of the existing model to be trained.
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -721,6 +825,16 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
- CSV with header (``.csv`` extension``)
|
||||
- Numpy raw or compressed files (``.npy`` or ``.npz`` extension)
|
||||
- Image files
|
||||
- An HTTP URL pointing to one of the file types listed above
|
||||
- Directories with images. If ``inputs`` points to a directory of images then the following
|
||||
conventions are followed:
|
||||
|
||||
- The folder must contain exactly as many subfolders as the output units of your model. If
|
||||
the model has ``output_labels`` then those subfolders should be named as the output labels.
|
||||
Each subfolder will contain training examples that match the associated label (e.g.
|
||||
``positive`` will contain all the positive images and ``negative`` all the negative images).
|
||||
- ``outputs`` doesn't have to be specified.
|
||||
|
||||
|
||||
:param outputs: Target data. Like the input data `x`, it can be a numpy array (or array-like) or TensorFlow tensor(s).
|
||||
|
@ -765,10 +879,9 @@ class TensorflowPlugin(Plugin):
|
|||
otherwise a list with the result test metrics (loss is usually the first value).
|
||||
"""
|
||||
|
||||
model = self._load_model(name)
|
||||
inputs = self._get_data(inputs)
|
||||
if outputs:
|
||||
outputs = self._get_data(outputs)
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs, outputs = self._get_dataset(inputs, outputs, model)
|
||||
|
||||
ret = model.evaluate(
|
||||
x=inputs,
|
||||
|
@ -791,18 +904,18 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
@action
|
||||
def predict(self,
|
||||
name: str,
|
||||
model: str,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
batch_size: Optional[int] = None,
|
||||
verbose: int = 0,
|
||||
steps: Optional[int] = None,
|
||||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False) -> Union[Dict[str, float], List[float]]:
|
||||
use_multiprocessing: bool = False) -> TensorflowPredictResponse:
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
:param name: Name of the existing model to be trained.
|
||||
:param model: Name of the existing model to be trained.
|
||||
:param inputs: Input data. It can be:
|
||||
|
||||
- A numpy array (or array-like), or a list of arrays in case the model has multiple inputs.
|
||||
|
@ -816,7 +929,8 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
- CSV with header (``.csv`` extension``)
|
||||
- Numpy raw or compressed files (``.npy`` or ``.npz`` extension)
|
||||
|
||||
- Image files
|
||||
- An HTTP URL pointing to one of the file types listed above
|
||||
|
||||
:param batch_size: Number of samples per gradient update. If unspecified, ``batch_size`` will default to 32.
|
||||
Do not specify the ``batch_size`` if your data is in the form of symbolic tensors, datasets,
|
||||
|
@ -840,11 +954,57 @@ class TensorflowPlugin(Plugin):
|
|||
Note that because this implementation relies on multiprocessing, you should not pass non-picklable
|
||||
arguments to the generator as they can't be passed easily to children processes.
|
||||
|
||||
:return: ``{output_metric: metric_value}`` dictionary if the ``output_names`` of the model are specified,
|
||||
otherwise a list with the result values.
|
||||
:return: :class:`platypush.message.response.tensorflow.TensorflowPredictResponse`. Format:
|
||||
|
||||
- For regression models with no output labels specified: ``values`` will contain the output vector:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"values": [[3.1415]]
|
||||
}
|
||||
|
||||
- For regression models with output labels specified: ``values`` will be a list of ``{label -> value}``
|
||||
maps:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"values": [
|
||||
{
|
||||
"x": 42.0,
|
||||
"y": 43.0
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
- For neural networks: ``values`` will contain the list of the output vector like in the case of
|
||||
regression, and ``labels`` will store the list of ``argmax`` (i.e. the index of the output unit with the
|
||||
highest value) or their labels, if the model has output labels:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"labels": [
|
||||
"positive"
|
||||
],
|
||||
"values": [
|
||||
{
|
||||
"positive": 0.998,
|
||||
"negative": 0.002
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
"""
|
||||
model = self._load_model(name)
|
||||
inputs = self._get_data(inputs)
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs = self._get_data(inputs, model)
|
||||
if isinstance(inputs, np.ndarray) and \
|
||||
len(model.inputs[0].shape) == len(inputs.shape) + 1 and \
|
||||
model.inputs[0].shape[0] is None:
|
||||
inputs = np.asarray([inputs])
|
||||
|
||||
ret = model.predict(
|
||||
inputs,
|
||||
batch_size=batch_size,
|
||||
|
@ -856,26 +1016,34 @@ class TensorflowPlugin(Plugin):
|
|||
use_multiprocessing=use_multiprocessing
|
||||
)
|
||||
|
||||
if not model.output_names:
|
||||
return ret
|
||||
|
||||
return {model.output_names[i]: value for i, value in enumerate(ret)}
|
||||
return TensorflowPredictResponse(model=model, prediction=ret, output_labels=model.output_labels)
|
||||
|
||||
@action
|
||||
def save(self, name: str, overwrite: bool = True, **opts) -> None:
|
||||
def save(self, model: str, overwrite: bool = True, **opts) -> None:
|
||||
"""
|
||||
Save a model in memory to the filesystem. The model files will be stored under
|
||||
``<WORKDIR>/models/<model_name>``.
|
||||
|
||||
:param name: Model name.
|
||||
:param model: Model name.
|
||||
:param overwrite: Overwrite the model files if they already exist.
|
||||
:param opts: Extra options to be passed to ``Model.save()``.
|
||||
"""
|
||||
assert name in self.models, 'No such model in memory: {}'.format(name)
|
||||
model_dir = os.path.join(self._models_dir, name)
|
||||
name = self.models[name]
|
||||
assert model in self.models, 'No such model in memory: {}'.format(model)
|
||||
model_dir = os.path.join(self._models_dir, model)
|
||||
model = self.models[model]
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
name.save(model_dir, overwrite=overwrite, options=opts)
|
||||
labels = {}
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
|
||||
if model.input_labels:
|
||||
labels['input'] = model.input_labels
|
||||
if model.output_labels:
|
||||
labels['output'] = model.output_labels
|
||||
if labels:
|
||||
with open(labels_file, 'w') as f:
|
||||
json.dump(labels, f)
|
||||
|
||||
model.save(model_dir, overwrite=overwrite, options=opts)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Reference in a new issue