Support for both models saved in .h5/.pb format and as Keras directories

This commit is contained in:
Fabio Manganiello 2020-10-28 20:06:45 +01:00
parent 604c315b5e
commit fa786f56b9

View file

@ -1,5 +1,6 @@
import json import json
import os import os
import pathlib
import numpy as np import numpy as np
from typing import List, Optional from typing import List, Optional
@ -45,10 +46,16 @@ class Model:
prediction = int(np.argmax(output)) prediction = int(np.argmax(output))
return self.label_names[prediction] if self.label_names else prediction return self.label_names[prediction] if self.label_names else prediction
def save(self, model_dir: str, *args, **kwargs): def save(self, path: str, *args, **kwargs):
model_dir = os.path.abspath(os.path.expanduser(model_dir)) path = os.path.abspath(os.path.expanduser(path))
self._model.save(model_dir, *args, **kwargs) is_file = path.endswith('.h5') or path.endswith('.pb')
if is_file:
model_dir = str(pathlib.Path(path).parent)
else:
model_dir = path
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
self._model.save(path, *args, **kwargs)
if self.label_names: if self.label_names:
labels_file = os.path.join(model_dir, self.labels_file_name) labels_file = os.path.join(model_dir, self.labels_file_name)
with open(labels_file, 'w') as f: with open(labels_file, 'w') as f:
@ -60,9 +67,15 @@ class Model:
json.dump(self.cutoff_frequencies, f) json.dump(self.cutoff_frequencies, f)
@classmethod @classmethod
def load(cls, model_dir: str, *args, **kwargs): def load(cls, path: str, *args, **kwargs):
model_dir = os.path.abspath(os.path.expanduser(model_dir)) path = os.path.abspath(os.path.expanduser(path))
model = load_model(model_dir, *args, **kwargs) is_file = path.endswith('.h5') or path.endswith('.pb')
if is_file:
model_dir = str(pathlib.Path(path).parent)
else:
model_dir = path
model = load_model(path, *args, **kwargs)
labels_file = os.path.join(model_dir, cls.labels_file_name) labels_file = os.path.join(model_dir, cls.labels_file_name)
freq_file = os.path.join(model_dir, cls.freq_file_name) freq_file = os.path.join(model_dir, cls.freq_file_name)
label_names = [] label_names = []