diff --git a/platypush/message/response/tensorflow.py b/platypush/message/response/tensorflow.py index 31bf318a1..75656e5e1 100644 --- a/platypush/message/response/tensorflow.py +++ b/platypush/message/response/tensorflow.py @@ -10,12 +10,12 @@ class TensorflowResponse(Response): """ Generic Tensorflow response. """ - def __init__(self, *args, model: Model, **kwargs): + def __init__(self, *args, model: Model, model_name: Optional[str] = None, **kwargs): """ :param model: Name of the model. """ super().__init__(*args, output={ - 'model': model.name, + 'model': model_name or model.name, }, **kwargs) self.model = model diff --git a/platypush/plugins/tensorflow/__init__.py b/platypush/plugins/tensorflow/__init__.py index 8e6fbdbf4..c903fa4d4 100644 --- a/platypush/plugins/tensorflow/__init__.py +++ b/platypush/plugins/tensorflow/__init__.py @@ -610,7 +610,11 @@ class TensorflowPlugin(Plugin): 'Found: {}'.format(colors[0])) img = image.load_img(image_file, target_size=size, color_mode=color_mode) - return image.img_to_array(img) + data = image.img_to_array(img) + if data.shape[-1] == 1: + # Squeeze extra color channels + data = np.squeeze(data) + return data @classmethod def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]: @@ -847,7 +851,7 @@ class TensorflowPlugin(Plugin): use_multiprocessing=use_multiprocessing, ) - return TensorflowTrainResponse(model=model, epochs=ret.epoch, history=ret.history) + return TensorflowTrainResponse(model=model, model_name=name, epochs=ret.epoch, history=ret.history) @action def evaluate(self, @@ -1071,7 +1075,8 @@ class TensorflowPlugin(Plugin): use_multiprocessing=use_multiprocessing ) - return TensorflowPredictResponse(model=model, prediction=ret, output_labels=model.output_labels) + return TensorflowPredictResponse(model=model, model_name=name, prediction=ret, + output_labels=model.output_labels) @action def save(self, model: str, overwrite: bool = True, **opts) -> None: