Squeeze the extra dimensions in a grayscale image and pass the model name to the response objects

This commit is contained in:
Fabio Manganiello 2020-10-01 18:50:36 +02:00
parent 287b6303ae
commit daaa0050d1
2 changed files with 10 additions and 5 deletions

View file

@ -10,12 +10,12 @@ class TensorflowResponse(Response):
"""
Generic Tensorflow response.
"""
def __init__(self, *args, model: Model, **kwargs):
def __init__(self, *args, model: Model, model_name: Optional[str] = None, **kwargs):
"""
:param model: Name of the model.
"""
super().__init__(*args, output={
'model': model.name,
'model': model_name or model.name,
}, **kwargs)
self.model = model

View file

@ -610,7 +610,11 @@ class TensorflowPlugin(Plugin):
'Found: {}'.format(colors[0]))
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
return image.img_to_array(img)
data = image.img_to_array(img)
if data.shape[-1] == 1:
# Squeeze extra color channels
data = np.squeeze(data)
return data
@classmethod
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
@ -847,7 +851,7 @@ class TensorflowPlugin(Plugin):
use_multiprocessing=use_multiprocessing,
)
return TensorflowTrainResponse(model=model, epochs=ret.epoch, history=ret.history)
return TensorflowTrainResponse(model=model, model_name=name, epochs=ret.epoch, history=ret.history)
@action
def evaluate(self,
@ -1071,7 +1075,8 @@ class TensorflowPlugin(Plugin):
use_multiprocessing=use_multiprocessing
)
return TensorflowPredictResponse(model=model, prediction=ret, output_labels=model.output_labels)
return TensorflowPredictResponse(model=model, model_name=name, prediction=ret,
output_labels=model.output_labels)
@action
def save(self, model: str, overwrite: bool = True, **opts) -> None: