Refined Tensorflow train methods
This commit is contained in:
parent
8d7e790eda
commit
37e006d86e
1 changed files with 34 additions and 11 deletions
|
@ -634,9 +634,28 @@ class TensorflowPlugin(Plugin):
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray:
|
||||||
|
if isinstance(data, str):
|
||||||
|
if model.output_labels:
|
||||||
|
label_index = model.output_labels.index(data)
|
||||||
|
if label_index >= 0:
|
||||||
|
return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))])
|
||||||
|
|
||||||
|
return np.array([data])
|
||||||
|
|
||||||
|
if len(data) > 0 and isinstance(data[0], str):
|
||||||
|
return np.array([cls._get_outputs(item, model) for item in data])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
|
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
|
||||||
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
||||||
|
if isinstance(data, List) or isinstance(data, Tuple):
|
||||||
|
if len(data) and isinstance(data[0], str):
|
||||||
|
return np.array([cls._get_data(item, model) for item in data])
|
||||||
|
|
||||||
if not isinstance(data, str):
|
if not isinstance(data, str):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -668,6 +687,8 @@ class TensorflowPlugin(Plugin):
|
||||||
elif os.path.isdir(data_file):
|
elif os.path.isdir(data_file):
|
||||||
return cls._get_dir(data_file, model)
|
return cls._get_dir(data_file, model)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_dataset(cls,
|
def _get_dataset(cls,
|
||||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||||
|
@ -677,12 +698,13 @@ class TensorflowPlugin(Plugin):
|
||||||
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
|
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
|
||||||
inputs = cls._get_data(inputs, model)
|
inputs = cls._get_data(inputs, model)
|
||||||
if outputs:
|
if outputs:
|
||||||
outputs = cls._get_data(inputs, model)
|
outputs = cls._get_outputs(outputs, model)
|
||||||
elif isinstance(inputs, dict) and model.output_labels:
|
elif isinstance(inputs, dict) and model.output_labels:
|
||||||
pairs = []
|
pairs = []
|
||||||
for i, label in enumerate(model.output_labels):
|
for i, label in enumerate(model.output_labels):
|
||||||
data = inputs.get(label, [])
|
data = inputs.get(label, [])
|
||||||
pairs.extend([(d, i) for d in data])
|
pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels)))
|
||||||
|
for d in data])
|
||||||
|
|
||||||
random.shuffle(pairs)
|
random.shuffle(pairs)
|
||||||
inputs = np.asarray([p[0] for p in pairs])
|
inputs = np.asarray([p[0] for p in pairs])
|
||||||
|
@ -1091,19 +1113,20 @@ class TensorflowPlugin(Plugin):
|
||||||
model_name = model
|
model_name = model
|
||||||
model_dir = None
|
model_dir = None
|
||||||
|
|
||||||
if os.path.isdir(os.path.join(self._work_dir, model_name)):
|
if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models:
|
||||||
model_dir = os.path.join(self._work_dir, model_name)
|
model_dir = os.path.join(self._models_dir, model_name)
|
||||||
else:
|
else:
|
||||||
model_name = os.path.abspath(os.path.expanduser(model_name))
|
model_file = os.path.abspath(os.path.expanduser(model_name))
|
||||||
if os.path.isfile(model_name):
|
if os.path.isfile(model_file):
|
||||||
model_dir = str(pathlib.Path(model_name).parent)
|
model_dir = str(pathlib.Path(model_file).parent)
|
||||||
elif os.path.isdir(model_name):
|
elif os.path.isdir(model_file):
|
||||||
model_dir = model_name
|
model_dir = model_file
|
||||||
|
|
||||||
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
|
model = self.models.get(model_name, self.models.get(model_dir))
|
||||||
|
assert model, 'No such model loaded: {}'.format(model_name)
|
||||||
|
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with self._lock_model(model_name):
|
with self._lock_model(model_name):
|
||||||
model = self.models[model_name]
|
|
||||||
labels = {}
|
labels = {}
|
||||||
labels_file = os.path.join(model_dir, 'labels.json')
|
labels_file = os.path.join(model_dir, 'labels.json')
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue