Refined Tensorflow train methods
This commit is contained in:
parent
8d7e790eda
commit
37e006d86e
1 changed files with 34 additions and 11 deletions
|
@ -634,9 +634,28 @@ class TensorflowPlugin(Plugin):
|
|||
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _get_outputs(cls, data: Union[str, np.ndarray, Iterable], model: Model) -> np.ndarray:
|
||||
if isinstance(data, str):
|
||||
if model.output_labels:
|
||||
label_index = model.output_labels.index(data)
|
||||
if label_index >= 0:
|
||||
return np.array([1 if i == label_index else 0 for i in range(len(model.output_labels))])
|
||||
|
||||
return np.array([data])
|
||||
|
||||
if len(data) > 0 and isinstance(data[0], str):
|
||||
return np.array([cls._get_outputs(item, model) for item in data])
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _get_data(cls, data: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]], model: Model) \
|
||||
-> Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]:
|
||||
if isinstance(data, List) or isinstance(data, Tuple):
|
||||
if len(data) and isinstance(data[0], str):
|
||||
return np.array([cls._get_data(item, model) for item in data])
|
||||
|
||||
if not isinstance(data, str):
|
||||
return data
|
||||
|
||||
|
@ -668,6 +687,8 @@ class TensorflowPlugin(Plugin):
|
|||
elif os.path.isdir(data_file):
|
||||
return cls._get_dir(data_file, model)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls,
|
||||
inputs: Union[str, np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]],
|
||||
|
@ -677,12 +698,13 @@ class TensorflowPlugin(Plugin):
|
|||
Optional[Union[np.ndarray, Iterable, Dict[str, Union[Iterable, np.ndarray]]]]]:
|
||||
inputs = cls._get_data(inputs, model)
|
||||
if outputs:
|
||||
outputs = cls._get_data(inputs, model)
|
||||
outputs = cls._get_outputs(outputs, model)
|
||||
elif isinstance(inputs, dict) and model.output_labels:
|
||||
pairs = []
|
||||
for i, label in enumerate(model.output_labels):
|
||||
data = inputs.get(label, [])
|
||||
pairs.extend([(d, i) for d in data])
|
||||
pairs.extend([(d, tuple(1 if i == j else 0 for j, _ in enumerate(model.output_labels)))
|
||||
for d in data])
|
||||
|
||||
random.shuffle(pairs)
|
||||
inputs = np.asarray([p[0] for p in pairs])
|
||||
|
@ -1091,19 +1113,20 @@ class TensorflowPlugin(Plugin):
|
|||
model_name = model
|
||||
model_dir = None
|
||||
|
||||
if os.path.isdir(os.path.join(self._work_dir, model_name)):
|
||||
model_dir = os.path.join(self._work_dir, model_name)
|
||||
if os.path.isdir(os.path.join(self._models_dir, model_name)) or model_name in self.models:
|
||||
model_dir = os.path.join(self._models_dir, model_name)
|
||||
else:
|
||||
model_name = os.path.abspath(os.path.expanduser(model_name))
|
||||
if os.path.isfile(model_name):
|
||||
model_dir = str(pathlib.Path(model_name).parent)
|
||||
elif os.path.isdir(model_name):
|
||||
model_dir = model_name
|
||||
model_file = os.path.abspath(os.path.expanduser(model_name))
|
||||
if os.path.isfile(model_file):
|
||||
model_dir = str(pathlib.Path(model_file).parent)
|
||||
elif os.path.isdir(model_file):
|
||||
model_dir = model_file
|
||||
|
||||
assert model_dir and model_name in self.models, 'No such model loaded: {}'.format(model)
|
||||
model = self.models.get(model_name, self.models.get(model_dir))
|
||||
assert model, 'No such model loaded: {}'.format(model_name)
|
||||
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with self._lock_model(model_name):
|
||||
model = self.models[model_name]
|
||||
labels = {}
|
||||
labels_file = os.path.join(model_dir, 'labels.json')
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue