[#398] Removed custom Response objects from Tensorflow and response docs generation logic.

Closes: #398
parent 77c91aa5e3
commit 98a98ea1dc
7 changed files with 100 additions and 152 deletions
@@ -50,7 +50,6 @@ Reference
     backends
     plugins
     events
-    responses

 Indices and tables
 ==================
@@ -1,5 +0,0 @@
-``tensorflow``
-=========================================
-
-.. automodule:: platypush.message.response.tensorflow
-    :members:
@@ -1,9 +0,0 @@
-
-Responses
-=========
-
-.. toctree::
-    :maxdepth: 1
-    :caption: Responses:
-
-    platypush/responses/tensorflow.rst
@@ -8,7 +8,6 @@ import pkgutil

 from platypush.backend import Backend
 from platypush.message.event import Event
-from platypush.message.response import Response
 from platypush.plugins import Plugin
 from platypush.utils.manifest import Manifests
 from platypush.utils.mock import auto_mocks
@@ -26,10 +25,6 @@ def get_all_events():
     return _get_modules(Event)


-def get_all_responses():
-    return _get_modules(Response)
-
-
 def _get_modules(base_type: type):
     ret = set()
     base_dir = os.path.dirname(inspect.getfile(base_type))
@@ -151,20 +146,11 @@ def generate_events_doc():
     )


-def generate_responses_doc():
-    _generate_components_doc(
-        index_name='responses',
-        package_name='message.response',
-        components=sorted(response for response in get_all_responses() if response),
-    )
-
-
 def main():
     with auto_mocks():
         generate_plugins_doc()
         generate_backends_doc()
         generate_events_doc()
-        generate_responses_doc()


 if __name__ == '__main__':
@@ -1,80 +0,0 @@
-from typing import Dict, List, Union, Optional
-
-import numpy as np
-
-from platypush.message.response import Response
-
-
-class TensorflowResponse(Response):
-    """
-    Generic Tensorflow response.
-    """
-
-    def __init__(self, *args, model, model_name: Optional[str] = None, **kwargs):
-        """
-        :param model: Name of the model.
-        """
-        super().__init__(
-            *args,
-            output={
-                'model': model_name or model.name,
-            },
-            **kwargs
-        )
-
-        self.model = model
-
-
-class TensorflowTrainResponse(TensorflowResponse):
-    """
-    Tensorflow model fit/train response.
-    """
-
-    def __init__(
-        self,
-        *args,
-        epochs: List[int],
-        history: Dict[str, List[Union[int, float]]],
-        **kwargs
-    ):
-        """
-        :param epochs: List of epoch indexes the model has been trained on.
-        :param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
-            the value of that metric on a specific epoch.
-        """
-        super().__init__(*args, **kwargs)
-        self.output['epochs'] = epochs
-        self.output['history'] = history
-
-
-class TensorflowPredictResponse(TensorflowResponse):
-    """
-    Tensorflow model prediction response.
-    """
-
-    def __init__(
-        self,
-        *args,
-        prediction: np.ndarray,
-        output_labels: Optional[List[str]] = None,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
-            self.output['outputs'] = [
-                {output_labels[i]: value for i, value in enumerate(p)}
-                for p in prediction
-            ]
-        else:
-            self.output['outputs'] = prediction
-
-        if self.model.__class__.__name__ != 'LinearModel':
-            prediction = [int(np.argmax(p)) for p in prediction]
-            if output_labels:
-                self.output['predictions'] = [output_labels[p] for p in prediction]
-            else:
-                self.output['predictions'] = prediction
-
-
-# vim:sw=4:ts=4:et:
@@ -20,11 +20,8 @@ from platypush.message.event.tensorflow import (
     TensorflowTrainStartedEvent,
     TensorflowTrainEndedEvent,
 )
-from platypush.message.response.tensorflow import (
-    TensorflowTrainResponse,
-    TensorflowPredictResponse,
-)
 from platypush.plugins import Plugin, action
+from platypush.schemas.tensorflow import TensorflowTrainSchema


 class TensorflowPlugin(Plugin):
@@ -50,11 +47,11 @@ class TensorflowPlugin(Plugin):
         super().__init__(**kwargs)
         self.models = {}  # str -> Model
         self._models_lock = threading.RLock()
-        self._model_locks: Dict[str, threading.RLock()] = {}
+        self._model_locks: Dict[str, threading.RLock] = {}
         self._work_dir = (
             os.path.abspath(os.path.expanduser(workdir))
             if workdir
-            else os.path.join(Config.get('workdir'), 'tensorflow')
+            else os.path.join(Config.get_workdir(), 'tensorflow')
         )

         self._models_dir = os.path.join(self._work_dir, 'models')
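A minimal sketch of the type-hint fix in the hunk above (names invented for illustration): an annotation should reference the lock type itself, whereas threading.RLock() creates a lock instance, which doesn't belong in a type position and can be rejected by typing at runtime as well as by static checkers.

    import threading
    from typing import Dict

    # Fixed form: the annotation names threading.RLock itself,
    # while the stored values are lock instances.
    model_locks: Dict[str, threading.RLock] = {}
    model_locks['my_model'] = threading.RLock()

    # The old form, Dict[str, threading.RLock()], subscripted the
    # generic with a lock *instance* instead of the type.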
@@ -99,6 +96,8 @@ class TensorflowPlugin(Plugin):
         elif os.path.isdir(model_name):
             model_dir = model_name
             model = load_model(model_dir)
+        else:
+            raise FileNotFoundError(f'Model not found: {model_name}')

         assert model, 'Could not find model: {}'.format(model_name)
         model.input_labels = []
@@ -196,8 +195,7 @@ class TensorflowPlugin(Plugin):
             loaded, otherwise the model currently in memory will be kept (default: ``False``).
         :return: The model configuration.
         """
-        model = self._load_model(model, reload=reload)
-        return model.get_config()
+        return self._load_model(model, reload=reload).get_config()

     @action
     def unload(self, model: str) -> None:
@@ -794,12 +792,12 @@ class TensorflowPlugin(Plugin):
         sample_weight: Optional[Union[np.ndarray, Iterable]] = None,
         initial_epoch: int = 0,
         steps_per_epoch: Optional[int] = None,
-        validation_steps: int = None,
+        validation_steps: Optional[int] = None,
         validation_freq: int = 1,
         max_queue_size: int = 10,
         workers: int = 1,
         use_multiprocessing: bool = False,
-    ) -> TensorflowTrainResponse:
+    ) -> dict:
         """
         Trains a model on a dataset for a fixed number of epochs.
@@ -915,13 +913,12 @@ class TensorflowPlugin(Plugin):
         Note that because this implementation relies on multiprocessing, you should not pass non-picklable
         arguments to the generator as they can't be passed easily to children processes.

-        :return: :class:`platypush.message.response.tensorflow.TensorflowTrainResponse`
+        :return: .. schema:: tensorflow.TensorflowTrainSchema
         """
         name = model
-        model = self._load_model(model)
-        inputs, outputs = self._get_dataset(inputs, outputs, model)
-        ret = model.fit(
+        model_obj = self._load_model(model)
+        inputs, outputs = self._get_dataset(inputs, outputs, model_obj)
+        ret = model_obj.fit(
             x=inputs,
             y=outputs,
             batch_size=batch_size,
@@ -942,8 +939,14 @@ class TensorflowPlugin(Plugin):
             use_multiprocessing=use_multiprocessing,
         )

-        return TensorflowTrainResponse(
-            model=model, model_name=name, epochs=ret.epoch, history=ret.history
-        )
+        return dict(
+            TensorflowTrainSchema().dump(
+                {
+                    'model': name,
+                    'epochs': ret.epoch,
+                    'history': ret.history,
+                }
+            )
+        )

     @action
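For context, a minimal usage sketch of the new return path (values invented for illustration): marshmallow's Schema.dump() already serializes to a plain dict, so the action now returns a JSON-friendly payload instead of a custom Response object.

    from platypush.schemas.tensorflow import TensorflowTrainSchema

    payload = TensorflowTrainSchema().dump({
        'model': 'MyModel',
        'epochs': 10,
        'history': {'loss': [0.3, 0.2, 0.1]},
    })
    # dump() returns a plain dict; the dict() wrapper in the diff
    # just makes the declared -> dict return type explicit.
    assert isinstance(payload, dict)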
@@ -1035,10 +1038,10 @@ class TensorflowPlugin(Plugin):
         """

         name = model
-        model = self._load_model(model)
-        inputs, outputs = self._get_dataset(inputs, outputs, model)
+        model_obj = self._load_model(model)
+        inputs, outputs = self._get_dataset(inputs, outputs, model_obj)

-        ret = model.evaluate(
+        ret = model_obj.evaluate(
             x=inputs,
             y=outputs,
             batch_size=batch_size,
@@ -1052,10 +1055,10 @@ class TensorflowPlugin(Plugin):
         )

         ret = ret if isinstance(ret, list) else [ret]
-        if not model.metrics_names:
+        if not model_obj.metrics_names:
             return ret

-        return {model.metrics_names[i]: value for i, value in enumerate(ret)}
+        return {model_obj.metrics_names[i]: value for i, value in enumerate(ret)}

     @action
     def predict(
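A toy sketch of the metrics mapping above (assumed names and values, not from the commit): Keras evaluate() returns a scalar or a list of scalars, which get paired with the model's metrics_names.

    metrics_names = ['loss', 'accuracy']  # e.g. from model.metrics_names
    ret = [0.12, 0.97]                    # e.g. from model.evaluate(...)

    ret = ret if isinstance(ret, list) else [ret]
    result = {metrics_names[i]: value for i, value in enumerate(ret)}
    # result == {'loss': 0.12, 'accuracy': 0.97}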
@@ -1070,7 +1073,7 @@ class TensorflowPlugin(Plugin):
         max_queue_size: int = 10,
         workers: int = 1,
         use_multiprocessing: bool = False,
-    ) -> TensorflowPredictResponse:
+    ) -> dict:
         """
         Generates output predictions for the input samples.
@@ -1114,7 +1117,7 @@ class TensorflowPlugin(Plugin):
         Note that because this implementation relies on multiprocessing, you should not pass non-picklable
         arguments to the generator as they can't be passed easily to children processes.

-        :return: :class:`platypush.message.response.tensorflow.TensorflowPredictResponse`. Format:
+        :return: Format:

         - For regression models with no output labels specified: ``outputs`` will contain the output vector:
@@ -1158,19 +1161,19 @@ class TensorflowPlugin(Plugin):
         """
         name = model
-        model = self._load_model(model)
-        inputs = self._get_data(inputs, model)
+        model_obj = self._load_model(model)
+        inputs = self._get_data(inputs, model_obj)
         if (
             isinstance(inputs, np.ndarray)
-            and len(model.inputs[0].shape) == len(inputs.shape) + 1
+            and len(model_obj.inputs[0].shape) == len(inputs.shape) + 1
             and (
-                model.inputs[0].shape[0] is None
-                or model.inputs[0].shape[0].value is None
+                model_obj.inputs[0].shape[0] is None
+                or model_obj.inputs[0].shape[0].value is None
             )
         ):
             inputs = np.asarray([inputs])

-        ret = model.predict(
+        ret = model_obj.predict(
             inputs,
             batch_size=batch_size,
             verbose=verbose,
@@ -1181,12 +1184,28 @@ class TensorflowPlugin(Plugin):
             use_multiprocessing=use_multiprocessing,
         )

-        return TensorflowPredictResponse(
-            model=model,
-            model_name=name,
-            prediction=ret,
-            output_labels=model.output_labels,
-        )
+        if (
+            model_obj.output_labels
+            and len(model_obj.output_labels) == model_obj.outputs[-1].shape[-1]
+        ):
+            outputs = [
+                {model_obj.output_labels[i]: value for i, value in enumerate(p)}
+                for p in ret
+            ]
+        else:
+            outputs = ret
+
+        if model_obj.__class__.__name__ != 'LinearModel':
+            ret = [int(np.argmax(p)) for p in ret]
+            if model_obj.output_labels:
+                ret = [model_obj.output_labels[p] for p in ret]
+
+        return {
+            'model': name,
+            'outputs': outputs,
+            'prediction': ret,
+            'output_labels': model_obj.output_labels,
+        }

     @action
     def save(self, model: str, overwrite: bool = True, **opts) -> None:
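A self-contained sketch of the new classification branch above (toy labels and scores, not from the commit): per-class scores are keyed by output label, then argmax maps each sample to a predicted label.

    import numpy as np

    output_labels = ['cat', 'dog']
    ret = np.array([[0.1, 0.9], [0.8, 0.2]])  # one row of class scores per sample

    # 'outputs': per-sample {label: score} maps
    outputs = [
        {output_labels[i]: value for i, value in enumerate(p)}
        for p in ret
    ]  # -> [{'cat': 0.1, 'dog': 0.9}, {'cat': 0.8, 'dog': 0.2}]

    # 'prediction': argmax index per sample, mapped to its label
    pred = [int(np.argmax(p)) for p in ret]  # -> [1, 0]
    pred = [output_labels[p] for p in pred]  # -> ['dog', 'cat']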
@@ -1212,28 +1231,30 @@ class TensorflowPlugin(Plugin):
             model_dir = str(pathlib.Path(model_file).parent)
         elif os.path.isdir(model_file):
             model_dir = model_file
+        else:
+            raise FileNotFoundError(f'No such model loaded: {model_name}')

-        model = self.models.get(model_name, self.models.get(model_dir))
-        assert model, 'No such model loaded: {}'.format(model_name)
+        model_obj = self.models.get(model_name, self.models.get(model_dir))
+        assert model_obj, f'No such model loaded: {model_name}'
         pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

         with self._lock_model(model_name):
             labels = {}
             labels_file = os.path.join(model_dir, 'labels.json')

-            if hasattr(model, 'input_labels') and model.input_labels:
-                labels['input'] = model.input_labels
-            if hasattr(model, 'output_labels') and model.output_labels:
+            if hasattr(model_obj, 'input_labels') and model_obj.input_labels:
+                labels['input'] = model_obj.input_labels
+            if hasattr(model_obj, 'output_labels') and model_obj.output_labels:
                 if hasattr(labels, 'input'):
-                    labels['output'] = model.output_labels
+                    labels['output'] = model_obj.output_labels
                 else:
-                    labels = model.output_labels
+                    labels = model_obj.output_labels

             if labels:
                 with open(labels_file, 'w') as f:
                     json.dump(labels, f)

-            model.save(
+            model_obj.save(
                 model_name if os.path.isfile(model_name) else model_dir,
                 overwrite=overwrite,
                 options=opts,
platypush/schemas/tensorflow.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from marshmallow import fields
+from marshmallow.schema import Schema
+
+from platypush.schemas import StrippedString
+
+
+class TensorflowTrainSchema(Schema):
+    """
+    Schema for TensorFlow model training results.
+    """
+
+    model = StrippedString(
+        required=True,
+        metadata={
+            "description": "Model name.",
+            "example": "MyModel",
+        },
+    )
+
+    epochs = fields.Int(
+        required=True,
+        metadata={
+            "description": "Number of epochs.",
+            "example": 10,
+        },
+    )
+
+    history = fields.Dict(
+        metadata={
+            "description": "Training history.",
+            "example": {
+                "loss": [0.1, 0.2, 0.3],
+                "accuracy": [0.9, 0.8, 0.7],
+            },
+        },
+    )