[#398] Removed custom Response objects from Tensorflow and response docs generation logic.

Closes: #398
This commit is contained in:
Fabio Manganiello 2024-05-15 09:55:17 +02:00
parent 77c91aa5e3
commit 98a98ea1dc
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
7 changed files with 100 additions and 152 deletions

View file

@ -50,7 +50,6 @@ Reference
backends
plugins
events
responses
Indices and tables
==================

View file

@ -1,5 +0,0 @@
``tensorflow``
=========================================
.. automodule:: platypush.message.response.tensorflow
:members:

View file

@ -1,9 +0,0 @@
Responses
=========
.. toctree::
:maxdepth: 1
:caption: Responses:
platypush/responses/tensorflow.rst

View file

@ -8,7 +8,6 @@ import pkgutil
from platypush.backend import Backend
from platypush.message.event import Event
from platypush.message.response import Response
from platypush.plugins import Plugin
from platypush.utils.manifest import Manifests
from platypush.utils.mock import auto_mocks
@ -26,10 +25,6 @@ def get_all_events():
return _get_modules(Event)
def get_all_responses():
return _get_modules(Response)
def _get_modules(base_type: type):
ret = set()
base_dir = os.path.dirname(inspect.getfile(base_type))
@ -151,20 +146,11 @@ def generate_events_doc():
)
def generate_responses_doc():
_generate_components_doc(
index_name='responses',
package_name='message.response',
components=sorted(response for response in get_all_responses() if response),
)
def main():
with auto_mocks():
generate_plugins_doc()
generate_backends_doc()
generate_events_doc()
generate_responses_doc()
if __name__ == '__main__':

View file

@ -1,80 +0,0 @@
from typing import Dict, List, Union, Optional
import numpy as np
from platypush.message.response import Response
class TensorflowResponse(Response):
"""
Generic Tensorflow response.
"""
def __init__(self, *args, model, model_name: Optional[str] = None, **kwargs):
"""
:param model: Name of the model.
"""
super().__init__(
*args,
output={
'model': model_name or model.name,
},
**kwargs
)
self.model = model
class TensorflowTrainResponse(TensorflowResponse):
"""
Tensorflow model fit/train response.
"""
def __init__(
self,
*args,
epochs: List[int],
history: Dict[str, List[Union[int, float]]],
**kwargs
):
"""
:param epochs: List of epoch indexes the model has been trained on.
:param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
the value for of that metric on a specific epoch.
"""
super().__init__(*args, **kwargs)
self.output['epochs'] = epochs
self.output['history'] = history
class TensorflowPredictResponse(TensorflowResponse):
"""
Tensorflow model prediction response.
"""
def __init__(
self,
*args,
prediction: np.ndarray,
output_labels: Optional[List[str]] = None,
**kwargs
):
super().__init__(*args, **kwargs)
if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
self.output['outputs'] = [
{output_labels[i]: value for i, value in enumerate(p)}
for p in prediction
]
else:
self.output['outputs'] = prediction
if self.model.__class__.__name__ != 'LinearModel':
prediction = [int(np.argmax(p)) for p in prediction]
if output_labels:
self.output['predictions'] = [output_labels[p] for p in prediction]
else:
self.output['predictions'] = prediction
# vim:sw=4:ts=4:et:

View file

@ -20,11 +20,8 @@ from platypush.message.event.tensorflow import (
TensorflowTrainStartedEvent,
TensorflowTrainEndedEvent,
)
from platypush.message.response.tensorflow import (
TensorflowTrainResponse,
TensorflowPredictResponse,
)
from platypush.plugins import Plugin, action
from platypush.schemas.tensorflow import TensorflowTrainSchema
class TensorflowPlugin(Plugin):
@ -50,11 +47,11 @@ class TensorflowPlugin(Plugin):
super().__init__(**kwargs)
self.models = {} # str -> Model
self._models_lock = threading.RLock()
self._model_locks: Dict[str, threading.RLock()] = {}
self._model_locks: Dict[str, threading.RLock] = {}
self._work_dir = (
os.path.abspath(os.path.expanduser(workdir))
if workdir
else os.path.join(Config.get('workdir'), 'tensorflow')
else os.path.join(Config.get_workdir(), 'tensorflow')
)
self._models_dir = os.path.join(self._work_dir, 'models')
@ -99,6 +96,8 @@ class TensorflowPlugin(Plugin):
elif os.path.isdir(model_name):
model_dir = model_name
model = load_model(model_dir)
else:
raise FileNotFoundError(f'Model not found: {model_name}')
assert model, 'Could not find model: {}'.format(model_name)
model.input_labels = []
@ -196,8 +195,7 @@ class TensorflowPlugin(Plugin):
loaded, otherwise the model currently in memory will be kept (default: ``False``).
:return: The model configuration.
"""
model = self._load_model(model, reload=reload)
return model.get_config()
return self._load_model(model, reload=reload).get_config()
@action
def unload(self, model: str) -> None:
@ -794,12 +792,12 @@ class TensorflowPlugin(Plugin):
sample_weight: Optional[Union[np.ndarray, Iterable]] = None,
initial_epoch: int = 0,
steps_per_epoch: Optional[int] = None,
validation_steps: int = None,
validation_steps: Optional[int] = None,
validation_freq: int = 1,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
) -> TensorflowTrainResponse:
) -> dict:
"""
Trains a model on a dataset for a fixed number of epochs.
@ -915,13 +913,12 @@ class TensorflowPlugin(Plugin):
Note that because this implementation relies on multiprocessing, you should not pass non-picklable
arguments to the generator as they can't be passed easily to children processes.
:return: :class:`platypush.message.response.tensorflow.TensorflowTrainResponse`
:return: .. schema:: tensorflow.TensorflowTrainSchema
"""
name = model
model = self._load_model(model)
model_obj = self._load_model(model)
inputs, outputs = self._get_dataset(inputs, outputs, model)
ret = model.fit(
ret = model_obj.fit(
x=inputs,
y=outputs,
batch_size=batch_size,
@ -942,8 +939,14 @@ class TensorflowPlugin(Plugin):
use_multiprocessing=use_multiprocessing,
)
return TensorflowTrainResponse(
model=model, model_name=name, epochs=ret.epoch, history=ret.history
return dict(
TensorflowTrainSchema().dump(
{
'model': name,
'epochs': ret.epoch,
'history': ret.history,
}
)
)
@action
@ -1035,10 +1038,10 @@ class TensorflowPlugin(Plugin):
"""
name = model
model = self._load_model(model)
inputs, outputs = self._get_dataset(inputs, outputs, model)
model_obj = self._load_model(model)
inputs, outputs = self._get_dataset(inputs, outputs, model_obj)
ret = model.evaluate(
ret = model_obj.evaluate(
x=inputs,
y=outputs,
batch_size=batch_size,
@ -1052,10 +1055,10 @@ class TensorflowPlugin(Plugin):
)
ret = ret if isinstance(ret, list) else [ret]
if not model.metrics_names:
if not model_obj.metrics_names:
return ret
return {model.metrics_names[i]: value for i, value in enumerate(ret)}
return {model_obj.metrics_names[i]: value for i, value in enumerate(ret)}
@action
def predict(
@ -1070,7 +1073,7 @@ class TensorflowPlugin(Plugin):
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
) -> TensorflowPredictResponse:
) -> dict:
"""
Generates output predictions for the input samples.
@ -1114,7 +1117,7 @@ class TensorflowPlugin(Plugin):
Note that because this implementation relies on multiprocessing, you should not pass non-picklable
arguments to the generator as they can't be passed easily to children processes.
:return: :class:`platypush.message.response.tensorflow.TensorflowPredictResponse`. Format:
:return: Format:
- For regression models with no output labels specified: ``outputs`` will contain the output vector:
@ -1158,19 +1161,19 @@ class TensorflowPlugin(Plugin):
"""
name = model
model = self._load_model(model)
inputs = self._get_data(inputs, model)
model_obj = self._load_model(model)
inputs = self._get_data(inputs, model_obj)
if (
isinstance(inputs, np.ndarray)
and len(model.inputs[0].shape) == len(inputs.shape) + 1
and len(model_obj.inputs[0].shape) == len(inputs.shape) + 1
and (
model.inputs[0].shape[0] is None
or model.inputs[0].shape[0].value is None
model_obj.inputs[0].shape[0] is None
or model_obj.inputs[0].shape[0].value is None
)
):
inputs = np.asarray([inputs])
ret = model.predict(
ret = model_obj.predict(
inputs,
batch_size=batch_size,
verbose=verbose,
@ -1181,12 +1184,28 @@ class TensorflowPlugin(Plugin):
use_multiprocessing=use_multiprocessing,
)
return TensorflowPredictResponse(
model=model,
model_name=name,
prediction=ret,
output_labels=model.output_labels,
)
if (
model_obj.output_labels
and len(model_obj.output_labels) == model_obj.outputs[-1].shape[-1]
):
outputs = [
{model_obj.output_labels[i]: value for i, value in enumerate(p)}
for p in ret
]
else:
outputs = ret
if model_obj.__class__.__name__ != 'LinearModel':
ret = [int(np.argmax(p)) for p in ret]
if model_obj.output_labels:
ret = [model_obj.output_labels[p] for p in ret]
return {
'model': name,
'outputs': outputs,
'prediction': ret,
'output_labels': model_obj.output_labels,
}
@action
def save(self, model: str, overwrite: bool = True, **opts) -> None:
@ -1212,28 +1231,30 @@ class TensorflowPlugin(Plugin):
model_dir = str(pathlib.Path(model_file).parent)
elif os.path.isdir(model_file):
model_dir = model_file
else:
raise FileNotFoundError(f'No such model loaded: {model_name}')
model = self.models.get(model_name, self.models.get(model_dir))
assert model, 'No such model loaded: {}'.format(model_name)
model_obj = self.models.get(model_name, self.models.get(model_dir))
assert model_obj, f'No such model loaded: {model_name}'
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
with self._lock_model(model_name):
labels = {}
labels_file = os.path.join(model_dir, 'labels.json')
if hasattr(model, 'input_labels') and model.input_labels:
labels['input'] = model.input_labels
if hasattr(model, 'output_labels') and model.output_labels:
if hasattr(model_obj, 'input_labels') and model_obj.input_labels:
labels['input'] = model_obj.input_labels
if hasattr(model_obj, 'output_labels') and model_obj.output_labels:
if hasattr(labels, 'input'):
labels['output'] = model.output_labels
labels['output'] = model_obj.output_labels
else:
labels = model.output_labels
labels = model_obj.output_labels
if labels:
with open(labels_file, 'w') as f:
json.dump(labels, f)
model.save(
model_obj.save(
model_name if os.path.isfile(model_name) else model_dir,
overwrite=overwrite,
options=opts,

View file

@ -0,0 +1,36 @@
from marshmallow import fields
from marshmallow.schema import Schema
from platypush.schemas import StrippedString
class TensorflowTrainSchema(Schema):
"""
Schema for TensorFlow model training results.
"""
model = StrippedString(
required=True,
metadata={
"description": "Model name.",
"example": "MyModel",
},
)
epochs = fields.Int(
required=True,
metadata={
"description": "Number of epochs.",
"example": 10,
},
)
history = fields.Dict(
metadata={
"description": "Training history.",
"example": {
"loss": [0.1, 0.2, 0.3],
"accuracy": [0.9, 0.8, 0.7],
},
},
)