diff --git a/platypush/plugins/tensorflow/__init__.py b/platypush/plugins/tensorflow/__init__.py index 55d39df7..83a68f3f 100644 --- a/platypush/plugins/tensorflow/__init__.py +++ b/platypush/plugins/tensorflow/__init__.py @@ -1010,7 +1010,7 @@ class TensorflowPlugin(Plugin): inputs = self._get_data(inputs, model) if isinstance(inputs, np.ndarray) and \ len(model.inputs[0].shape) == len(inputs.shape) + 1 and \ - model.inputs[0].shape[0] is None: + (model.inputs[0].shape[0] is None or model.inputs[0].shape[0].value is None): inputs = np.asarray([inputs]) ret = model.predict(