forked from platypush/platypush
[#398] Removed custom Response
objects from Tensorflow and response docs generation logic.
Closes: #398
This commit is contained in:
parent
77c91aa5e3
commit
98a98ea1dc
7 changed files with 100 additions and 152 deletions
|
@ -50,7 +50,6 @@ Reference
|
|||
backends
|
||||
plugins
|
||||
events
|
||||
responses
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
``tensorflow``
|
||||
=========================================
|
||||
|
||||
.. automodule:: platypush.message.response.tensorflow
|
||||
:members:
|
|
@ -1,9 +0,0 @@
|
|||
|
||||
Responses
|
||||
=========
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Responses:
|
||||
|
||||
platypush/responses/tensorflow.rst
|
|
@ -8,7 +8,6 @@ import pkgutil
|
|||
|
||||
from platypush.backend import Backend
|
||||
from platypush.message.event import Event
|
||||
from platypush.message.response import Response
|
||||
from platypush.plugins import Plugin
|
||||
from platypush.utils.manifest import Manifests
|
||||
from platypush.utils.mock import auto_mocks
|
||||
|
@ -26,10 +25,6 @@ def get_all_events():
|
|||
return _get_modules(Event)
|
||||
|
||||
|
||||
def get_all_responses():
|
||||
return _get_modules(Response)
|
||||
|
||||
|
||||
def _get_modules(base_type: type):
|
||||
ret = set()
|
||||
base_dir = os.path.dirname(inspect.getfile(base_type))
|
||||
|
@ -151,20 +146,11 @@ def generate_events_doc():
|
|||
)
|
||||
|
||||
|
||||
def generate_responses_doc():
|
||||
_generate_components_doc(
|
||||
index_name='responses',
|
||||
package_name='message.response',
|
||||
components=sorted(response for response in get_all_responses() if response),
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
with auto_mocks():
|
||||
generate_plugins_doc()
|
||||
generate_backends_doc()
|
||||
generate_events_doc()
|
||||
generate_responses_doc()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
from typing import Dict, List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from platypush.message.response import Response
|
||||
|
||||
|
||||
class TensorflowResponse(Response):
|
||||
"""
|
||||
Generic Tensorflow response.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, model, model_name: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
:param model: Name of the model.
|
||||
"""
|
||||
super().__init__(
|
||||
*args,
|
||||
output={
|
||||
'model': model_name or model.name,
|
||||
},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
|
||||
class TensorflowTrainResponse(TensorflowResponse):
|
||||
"""
|
||||
Tensorflow model fit/train response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
epochs: List[int],
|
||||
history: Dict[str, List[Union[int, float]]],
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
:param epochs: List of epoch indexes the model has been trained on.
|
||||
:param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
|
||||
the value for of that metric on a specific epoch.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.output['epochs'] = epochs
|
||||
self.output['history'] = history
|
||||
|
||||
|
||||
class TensorflowPredictResponse(TensorflowResponse):
|
||||
"""
|
||||
Tensorflow model prediction response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
prediction: np.ndarray,
|
||||
output_labels: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
|
||||
self.output['outputs'] = [
|
||||
{output_labels[i]: value for i, value in enumerate(p)}
|
||||
for p in prediction
|
||||
]
|
||||
else:
|
||||
self.output['outputs'] = prediction
|
||||
|
||||
if self.model.__class__.__name__ != 'LinearModel':
|
||||
prediction = [int(np.argmax(p)) for p in prediction]
|
||||
if output_labels:
|
||||
self.output['predictions'] = [output_labels[p] for p in prediction]
|
||||
else:
|
||||
self.output['predictions'] = prediction
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
|
@ -20,11 +20,8 @@ from platypush.message.event.tensorflow import (
|
|||
TensorflowTrainStartedEvent,
|
||||
TensorflowTrainEndedEvent,
|
||||
)
|
||||
from platypush.message.response.tensorflow import (
|
||||
TensorflowTrainResponse,
|
||||
TensorflowPredictResponse,
|
||||
)
|
||||
from platypush.plugins import Plugin, action
|
||||
from platypush.schemas.tensorflow import TensorflowTrainSchema
|
||||
|
||||
|
||||
class TensorflowPlugin(Plugin):
|
||||
|
@ -50,11 +47,11 @@ class TensorflowPlugin(Plugin):
|
|||
super().__init__(**kwargs)
|
||||
self.models = {} # str -> Model
|
||||
self._models_lock = threading.RLock()
|
||||
self._model_locks: Dict[str, threading.RLock()] = {}
|
||||
self._model_locks: Dict[str, threading.RLock] = {}
|
||||
self._work_dir = (
|
||||
os.path.abspath(os.path.expanduser(workdir))
|
||||
if workdir
|
||||
else os.path.join(Config.get('workdir'), 'tensorflow')
|
||||
else os.path.join(Config.get_workdir(), 'tensorflow')
|
||||
)
|
||||
|
||||
self._models_dir = os.path.join(self._work_dir, 'models')
|
||||
|
@ -99,6 +96,8 @@ class TensorflowPlugin(Plugin):
|
|||
elif os.path.isdir(model_name):
|
||||
model_dir = model_name
|
||||
model = load_model(model_dir)
|
||||
else:
|
||||
raise FileNotFoundError(f'Model not found: {model_name}')
|
||||
|
||||
assert model, 'Could not find model: {}'.format(model_name)
|
||||
model.input_labels = []
|
||||
|
@ -196,8 +195,7 @@ class TensorflowPlugin(Plugin):
|
|||
loaded, otherwise the model currently in memory will be kept (default: ``False``).
|
||||
:return: The model configuration.
|
||||
"""
|
||||
model = self._load_model(model, reload=reload)
|
||||
return model.get_config()
|
||||
return self._load_model(model, reload=reload).get_config()
|
||||
|
||||
@action
|
||||
def unload(self, model: str) -> None:
|
||||
|
@ -794,12 +792,12 @@ class TensorflowPlugin(Plugin):
|
|||
sample_weight: Optional[Union[np.ndarray, Iterable]] = None,
|
||||
initial_epoch: int = 0,
|
||||
steps_per_epoch: Optional[int] = None,
|
||||
validation_steps: int = None,
|
||||
validation_steps: Optional[int] = None,
|
||||
validation_freq: int = 1,
|
||||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False,
|
||||
) -> TensorflowTrainResponse:
|
||||
) -> dict:
|
||||
"""
|
||||
Trains a model on a dataset for a fixed number of epochs.
|
||||
|
||||
|
@ -915,13 +913,12 @@ class TensorflowPlugin(Plugin):
|
|||
Note that because this implementation relies on multiprocessing, you should not pass non-picklable
|
||||
arguments to the generator as they can't be passed easily to children processes.
|
||||
|
||||
:return: :class:`platypush.message.response.tensorflow.TensorflowTrainResponse`
|
||||
:return: .. schema:: tensorflow.TensorflowTrainSchema
|
||||
"""
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
model_obj = self._load_model(model)
|
||||
inputs, outputs = self._get_dataset(inputs, outputs, model)
|
||||
|
||||
ret = model.fit(
|
||||
ret = model_obj.fit(
|
||||
x=inputs,
|
||||
y=outputs,
|
||||
batch_size=batch_size,
|
||||
|
@ -942,8 +939,14 @@ class TensorflowPlugin(Plugin):
|
|||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
return TensorflowTrainResponse(
|
||||
model=model, model_name=name, epochs=ret.epoch, history=ret.history
|
||||
return dict(
|
||||
TensorflowTrainSchema().dump(
|
||||
{
|
||||
'model': name,
|
||||
'epochs': ret.epoch,
|
||||
'history': ret.history,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@action
|
||||
|
@ -1035,10 +1038,10 @@ class TensorflowPlugin(Plugin):
|
|||
"""
|
||||
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs, outputs = self._get_dataset(inputs, outputs, model)
|
||||
model_obj = self._load_model(model)
|
||||
inputs, outputs = self._get_dataset(inputs, outputs, model_obj)
|
||||
|
||||
ret = model.evaluate(
|
||||
ret = model_obj.evaluate(
|
||||
x=inputs,
|
||||
y=outputs,
|
||||
batch_size=batch_size,
|
||||
|
@ -1052,10 +1055,10 @@ class TensorflowPlugin(Plugin):
|
|||
)
|
||||
|
||||
ret = ret if isinstance(ret, list) else [ret]
|
||||
if not model.metrics_names:
|
||||
if not model_obj.metrics_names:
|
||||
return ret
|
||||
|
||||
return {model.metrics_names[i]: value for i, value in enumerate(ret)}
|
||||
return {model_obj.metrics_names[i]: value for i, value in enumerate(ret)}
|
||||
|
||||
@action
|
||||
def predict(
|
||||
|
@ -1070,7 +1073,7 @@ class TensorflowPlugin(Plugin):
|
|||
max_queue_size: int = 10,
|
||||
workers: int = 1,
|
||||
use_multiprocessing: bool = False,
|
||||
) -> TensorflowPredictResponse:
|
||||
) -> dict:
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
|
@ -1114,7 +1117,7 @@ class TensorflowPlugin(Plugin):
|
|||
Note that because this implementation relies on multiprocessing, you should not pass non-picklable
|
||||
arguments to the generator as they can't be passed easily to children processes.
|
||||
|
||||
:return: :class:`platypush.message.response.tensorflow.TensorflowPredictResponse`. Format:
|
||||
:return: Format:
|
||||
|
||||
- For regression models with no output labels specified: ``outputs`` will contain the output vector:
|
||||
|
||||
|
@ -1158,19 +1161,19 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
"""
|
||||
name = model
|
||||
model = self._load_model(model)
|
||||
inputs = self._get_data(inputs, model)
|
||||
model_obj = self._load_model(model)
|
||||
inputs = self._get_data(inputs, model_obj)
|
||||
if (
|
||||
isinstance(inputs, np.ndarray)
|
||||
and len(model.inputs[0].shape) == len(inputs.shape) + 1
|
||||
and len(model_obj.inputs[0].shape) == len(inputs.shape) + 1
|
||||
and (
|
||||
model.inputs[0].shape[0] is None
|
||||
or model.inputs[0].shape[0].value is None
|
||||
model_obj.inputs[0].shape[0] is None
|
||||
or model_obj.inputs[0].shape[0].value is None
|
||||
)
|
||||
):
|
||||
inputs = np.asarray([inputs])
|
||||
|
||||
ret = model.predict(
|
||||
ret = model_obj.predict(
|
||||
inputs,
|
||||
batch_size=batch_size,
|
||||
verbose=verbose,
|
||||
|
@ -1181,12 +1184,28 @@ class TensorflowPlugin(Plugin):
|
|||
use_multiprocessing=use_multiprocessing,
|
||||
)
|
||||
|
||||
return TensorflowPredictResponse(
|
||||
model=model,
|
||||
model_name=name,
|
||||
prediction=ret,
|
||||
output_labels=model.output_labels,
|
||||
)
|
||||
if (
|
||||
model_obj.output_labels
|
||||
and len(model_obj.output_labels) == model_obj.outputs[-1].shape[-1]
|
||||
):
|
||||
outputs = [
|
||||
{model_obj.output_labels[i]: value for i, value in enumerate(p)}
|
||||
for p in ret
|
||||
]
|
||||
else:
|
||||
outputs = ret
|
||||
|
||||
if model_obj.__class__.__name__ != 'LinearModel':
|
||||
ret = [int(np.argmax(p)) for p in ret]
|
||||
if model_obj.output_labels:
|
||||
ret = [model_obj.output_labels[p] for p in ret]
|
||||
|
||||
return {
|
||||
'model': name,
|
||||
'outputs': outputs,
|
||||
'prediction': ret,
|
||||
'output_labels': model_obj.output_labels,
|
||||
}
|
||||
|
||||
@action
|
||||
def save(self, model: str, overwrite: bool = True, **opts) -> None:
|
||||
|
@ -1212,28 +1231,30 @@ class TensorflowPlugin(Plugin):
|
|||
model_dir = str(pathlib.Path(model_file).parent)
|
||||
elif os.path.isdir(model_file):
|
||||
model_dir = model_file
|
||||
else:
|
||||
raise FileNotFoundError(f'No such model loaded: {model_name}')
|
||||
|
||||
model = self.models.get(model_name, self.models.get(model_dir))
|
||||
assert model, 'No such model loaded: {}'.format(model_name)
|
||||
model_obj = self.models.get(model_name, self.models.get(model_dir))
|
||||
assert model_obj, f'No such model loaded: {model_name}'
|
||||
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with self._lock_model(model_name):
|
||||
labels = {}
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
|
||||
if hasattr(model, 'input_labels') and model.input_labels:
|
||||
labels['input'] = model.input_labels
|
||||
if hasattr(model, 'output_labels') and model.output_labels:
|
||||
if hasattr(model_obj, 'input_labels') and model_obj.input_labels:
|
||||
labels['input'] = model_obj.input_labels
|
||||
if hasattr(model_obj, 'output_labels') and model_obj.output_labels:
|
||||
if hasattr(labels, 'input'):
|
||||
labels['output'] = model.output_labels
|
||||
labels['output'] = model_obj.output_labels
|
||||
else:
|
||||
labels = model.output_labels
|
||||
labels = model_obj.output_labels
|
||||
|
||||
if labels:
|
||||
with open(labels_file, 'w') as f:
|
||||
json.dump(labels, f)
|
||||
|
||||
model.save(
|
||||
model_obj.save(
|
||||
model_name if os.path.isfile(model_name) else model_dir,
|
||||
overwrite=overwrite,
|
||||
options=opts,
|
||||
|
|
36
platypush/schemas/tensorflow.py
Normal file
36
platypush/schemas/tensorflow.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from marshmallow import fields
|
||||
from marshmallow.schema import Schema
|
||||
|
||||
from platypush.schemas import StrippedString
|
||||
|
||||
|
||||
class TensorflowTrainSchema(Schema):
|
||||
"""
|
||||
Schema for TensorFlow model training results.
|
||||
"""
|
||||
|
||||
model = StrippedString(
|
||||
required=True,
|
||||
metadata={
|
||||
"description": "Model name.",
|
||||
"example": "MyModel",
|
||||
},
|
||||
)
|
||||
|
||||
epochs = fields.Int(
|
||||
required=True,
|
||||
metadata={
|
||||
"description": "Number of epochs.",
|
||||
"example": 10,
|
||||
},
|
||||
)
|
||||
|
||||
history = fields.Dict(
|
||||
metadata={
|
||||
"description": "Training history.",
|
||||
"example": {
|
||||
"loss": [0.1, 0.2, 0.3],
|
||||
"accuracy": [0.9, 0.8, 0.7],
|
||||
},
|
||||
},
|
||||
)
|
Loading…
Reference in a new issue