Squeeze the extra dimensions in a grayscale image and pass the model name to the response objects
This commit is contained in:
parent
287b6303ae
commit
daaa0050d1
2 changed files with 10 additions and 5 deletions
|
@ -10,12 +10,12 @@ class TensorflowResponse(Response):
|
||||||
"""
|
"""
|
||||||
Generic Tensorflow response.
|
Generic Tensorflow response.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, model: Model, **kwargs):
|
def __init__(self, *args, model: Model, model_name: Optional[str] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param model: Name of the model.
|
:param model: Name of the model.
|
||||||
"""
|
"""
|
||||||
super().__init__(*args, output={
|
super().__init__(*args, output={
|
||||||
'model': model.name,
|
'model': model_name or model.name,
|
||||||
}, **kwargs)
|
}, **kwargs)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
|
@ -610,7 +610,11 @@ class TensorflowPlugin(Plugin):
|
||||||
'Found: {}'.format(colors[0]))
|
'Found: {}'.format(colors[0]))
|
||||||
|
|
||||||
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
|
img = image.load_img(image_file, target_size=size, color_mode=color_mode)
|
||||||
return image.img_to_array(img)
|
data = image.img_to_array(img)
|
||||||
|
if data.shape[-1] == 1:
|
||||||
|
# Squeeze extra color channels
|
||||||
|
data = np.squeeze(data)
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
|
def _get_dir(cls, directory: str, model: Model) -> Dict[str, Iterable]:
|
||||||
|
@ -847,7 +851,7 @@ class TensorflowPlugin(Plugin):
|
||||||
use_multiprocessing=use_multiprocessing,
|
use_multiprocessing=use_multiprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TensorflowTrainResponse(model=model, epochs=ret.epoch, history=ret.history)
|
return TensorflowTrainResponse(model=model, model_name=name, epochs=ret.epoch, history=ret.history)
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def evaluate(self,
|
def evaluate(self,
|
||||||
|
@ -1071,7 +1075,8 @@ class TensorflowPlugin(Plugin):
|
||||||
use_multiprocessing=use_multiprocessing
|
use_multiprocessing=use_multiprocessing
|
||||||
)
|
)
|
||||||
|
|
||||||
return TensorflowPredictResponse(model=model, prediction=ret, output_labels=model.output_labels)
|
return TensorflowPredictResponse(model=model, model_name=name, prediction=ret,
|
||||||
|
output_labels=model.output_labels)
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def save(self, model: str, overwrite: bool = True, **opts) -> None:
|
def save(self, model: str, overwrite: bool = True, **opts) -> None:
|
||||||
|
|
Loading…
Reference in a new issue