From 37e006d86e4a91f8a18d54be4083421c841fb876 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 12 Oct 2020 01:06:32 +0200 Subject: [PATCH] Refined Tensorflow train methods --- platypush/plugins/tensorflow/__init__.py | 45 ++++++++++++++++++------ 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/platypush/plugins/tensorflow/__init__.py b/platypush/plugins/tensorflow/__init__.py index e6d423be53..bd2f9a29be 100644 --- a/platypush/plugins/tensorflow/__init__.py +++ b/platypush/plugins/tensorflow/__init__.py @@ -634,9 +634,28 @@ class TensorflowPlugin(Plugin): return ret + @classmethod + def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray: + if isinstance(data, str): + if model.output_labels: + label_index = model.output_labels.index(data) + if label_index >= 0: + return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))]) + + return np.array([data]) + + if len(data) > 0 and isinstance(data[0], str): + return np.array([cls._get_outputs(item, model) for item in data]) + + return data + @classmethod def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \ -> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]: + if isinstance(data, List) or isinstance(data, Tuple): + if len(data) and isinstance(data[0], str): + return np.array([cls._get_data(item, model) for item in data]) + if not isinstance(data, str): return data @@ -668,6 +687,8 @@ class TensorflowPlugin(Plugin): elif os.path.isdir(data_file): return cls._get_dir(data_file, model) + return data + @classmethod def _get_dataset(cls, inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], @@ -677,12 +698,13 @@ class TensorflowPlugin(Plugin): Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]: inputs = cls._get_data(inputs, model) if outputs: - outputs = cls._get_data(inputs, model) + outputs = cls._get_outputs(outputs, model) elif isinstance(inputs, dict) and model.output_labels: pairs = [] for i, label in enumerate(model.output_labels): data = inputs.get(label, []) - pairs.extend([(d, i) for d in data]) + pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels))) + for d in data]) random.shuffle(pairs) inputs = np.asarray([p[0] for p in pairs]) @@ -1091,19 +1113,20 @@ class TensorflowPlugin(Plugin): model_name = model model_dir = None - if os.path.isdir(os.path.join(self._work_dir, model_name)): - model_dir = os.path.join(self._work_dir, model_name) + if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models: + model_dir = os.path.join(self._models_dir, model_name) else: - model_name = os.path.abspath(os.path.expanduser(model_name)) - if os.path.isfile(model_name): - model_dir = str(pathlib.Path(model_name).parent) - elif os.path.isdir(model_name): - model_dir = model_name + model_file = os.path.abspath(os.path.expanduser(model_name)) + if os.path.isfile(model_file): + model_dir = str(pathlib.Path(model_file).parent) + elif os.path.isdir(model_file): + model_dir = model_file - assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model) + model = self.models.get(model_name, self.models.get(model_dir)) + assert model, 'No such model loaded: {}'.format(model_name) + pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) with self._lock_model(model_name): - model = self.models[model_name] labels = {} labels_file = os.path.join(model_dir, 'labels.json')