Refined Tensorflow train methods

This commit is contained in:
Fabio Manganiello 2020-10-12 01:06:32 +02:00
parent 8d7e790eda
commit 37e006d86e

View file

@ -634,9 +634,28 @@ class TensorflowPlugin(Plugin):
return ret
@classmethod
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray:
if isinstance(data, str):
if model.output_labels:
label_index = model.output_labels.index(data)
if label_index >= 0:
return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))])
return np.array([data])
if len(data) > 0 and isinstance(data[0], str):
return np.array([cls._get_outputs(item, model) for item in data])
return data
@classmethod
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
if isinstance(data, List) or isinstance(data, Tuple):
if len(data) and isinstance(data[0], str):
return np.array([cls._get_data(item, model) for item in data])
if not isinstance(data, str):
return data
@ -668,6 +687,8 @@ class TensorflowPlugin(Plugin):
elif os.path.isdir(data_file):
return cls._get_dir(data_file, model)
return data
@classmethod
def _get_dataset(cls,
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
@ -677,12 +698,13 @@ class TensorflowPlugin(Plugin):
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
inputs = cls._get_data(inputs, model)
if outputs:
outputs = cls._get_data(inputs, model)
outputs = cls._get_outputs(outputs, model)
elif isinstance(inputs, dict) and model.output_labels:
pairs = []
for i, label in enumerate(model.output_labels):
data = inputs.get(label, [])
pairs.extend([(d, i) for d in data])
pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels)))
for d in data])
random.shuffle(pairs)
inputs = np.asarray([p[0] for p in pairs])
@ -1091,19 +1113,20 @@ class TensorflowPlugin(Plugin):
model_name = model
model_dir = None
if os.path.isdir(os.path.join(self._work_dir, model_name)):
model_dir = os.path.join(self._work_dir, model_name)
if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models:
model_dir = os.path.join(self._models_dir, model_name)
else:
model_name = os.path.abspath(os.path.expanduser(model_name))
if os.path.isfile(model_name):
model_dir = str(pathlib.Path(model_name).parent)
elif os.path.isdir(model_name):
model_dir = model_name
model_file = os.path.abspath(os.path.expanduser(model_name))
if os.path.isfile(model_file):
model_dir = str(pathlib.Path(model_file).parent)
elif os.path.isdir(model_file):
model_dir = model_file
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
model = self.models.get(model_name, self.models.get(model_dir))
assert model, 'No such model loaded: {}'.format(model_name)
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
with self._lock_model(model_name):
model = self.models[model_name]
labels = {}
labels_file = os.path.join(model_dir, 'labels.json')