forked from platypush/platypush
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
from typing import Dict, List, Union, Optional
|
|
|
|
import numpy as np
|
|
from tensorflow.keras.models import Model
|
|
|
|
from platypush.message.response import Response
|
|
|
|
|
|
class TensorflowResponse(Response):
    """
    Generic Tensorflow response.
    """

    def __init__(self, *args, model: Model, **kwargs):
        """
        :param model: The Keras model this response refers to. Its ``name`` is exposed as
            ``output['model']`` and the model object itself is kept as ``self.model``.

        Any other positional/keyword arguments are forwarded to :class:`Response`.
        """
        super().__init__(*args, output={
            'model': model.name,
        }, **kwargs)

        # Keep a reference to the model object itself (not just its name) so that
        # subclasses can inspect it after initialization.
        self.model = model
|
|
|
|
|
|
class TensorflowTrainResponse(TensorflowResponse):
    """
    Tensorflow model fit/train response.
    """

    def __init__(self, *args, epochs: List[int], history: Dict[str, List[Union[int, float]]], **kwargs):
        """
        :param epochs: List of epoch indexes the model has been trained on.
        :param history: Train history, as a ``metric -> [values]`` dictionary where each value in ``values`` is
            the value of that metric on a specific epoch.

        Remaining arguments (including the keyword-only ``model``) are forwarded to
        :class:`TensorflowResponse`.
        """
        super().__init__(*args, **kwargs)
        self.output['epochs'] = epochs
        self.output['history'] = history
|
|
|
|
|
|
class TensorflowPredictResponse(TensorflowResponse):
    """
    Tensorflow model prediction response.
    """

    def __init__(self, *args, prediction: np.ndarray, output_labels: Optional[List[str]] = None, **kwargs):
        """
        :param prediction: Output of the model's ``predict``, iterated as one row per input sample.
        :param output_labels: Optional labels for the units of the model's last output layer.

            - If their count matches that layer's size, ``output['values']`` is a list of
              ``label -> value`` dictionaries. Otherwise it is the raw ``prediction``.
            - For models whose class is not named ``LinearModel``, ``output['labels']`` holds,
              for each sample, the label of the highest-scoring unit. It falls back to the
              plain integer index when no labels are given or a label is missing for some
              predicted index.

        Remaining arguments (including the keyword-only ``model``) are forwarded to
        :class:`TensorflowResponse`.
        """
        super().__init__(*args, **kwargs)

        # Only key the values by label when the labels cover exactly the last output layer's units.
        if output_labels and len(output_labels) == self.model.outputs[-1].shape[-1]:
            self.output['values'] = [
                {output_labels[i]: value for i, value in enumerate(p)}
                for p in prediction
            ]
        else:
            self.output['values'] = prediction

        # NOTE(review): argmax is skipped for LinearModel, presumably because its output is
        # not a vector of per-class scores — confirm.
        if self.model.__class__.__name__ != 'LinearModel':
            # Index of the highest-scoring output unit for each sample.
            indices = [int(np.argmax(p)) for p in prediction]

            # Translate indices into labels only if every predicted index has a label.
            # Without this guard, a label list shorter than the output layer raised IndexError.
            if output_labels and all(idx < len(output_labels) for idx in indices):
                self.output['labels'] = [output_labels[idx] for idx in indices]
            else:
                self.output['labels'] = indices
|
|
|
|
|
|
# vim:sw=4:ts=4:et:
|