mirror of
https://github.com/BlackLight/micmon.git
synced 2024-11-27 22:25:13 +01:00
Support for both models saved in .h5/.pb format and as Keras directories
This commit is contained in:
parent
604c315b5e
commit
fa786f56b9
1 changed files with 19 additions and 6 deletions
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
@ -45,10 +46,16 @@ class Model:
|
||||||
prediction = int(np.argmax(output))
|
prediction = int(np.argmax(output))
|
||||||
return self.label_names[prediction] if self.label_names else prediction
|
return self.label_names[prediction] if self.label_names else prediction
|
||||||
|
|
||||||
def save(self, model_dir: str, *args, **kwargs):
|
def save(self, path: str, *args, **kwargs):
|
||||||
model_dir = os.path.abspath(os.path.expanduser(model_dir))
|
path = os.path.abspath(os.path.expanduser(path))
|
||||||
self._model.save(model_dir, *args, **kwargs)
|
is_file = path.endswith('.h5') or path.endswith('.pb')
|
||||||
|
if is_file:
|
||||||
|
model_dir = str(pathlib.Path(path).parent)
|
||||||
|
else:
|
||||||
|
model_dir = path
|
||||||
|
|
||||||
|
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
self._model.save(path, *args, **kwargs)
|
||||||
if self.label_names:
|
if self.label_names:
|
||||||
labels_file = os.path.join(model_dir, self.labels_file_name)
|
labels_file = os.path.join(model_dir, self.labels_file_name)
|
||||||
with open(labels_file, 'w') as f:
|
with open(labels_file, 'w') as f:
|
||||||
|
@ -60,9 +67,15 @@ class Model:
|
||||||
json.dump(self.cutoff_frequencies, f)
|
json.dump(self.cutoff_frequencies, f)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, model_dir: str, *args, **kwargs):
|
def load(cls, path: str, *args, **kwargs):
|
||||||
model_dir = os.path.abspath(os.path.expanduser(model_dir))
|
path = os.path.abspath(os.path.expanduser(path))
|
||||||
model = load_model(model_dir, *args, **kwargs)
|
is_file = path.endswith('.h5') or path.endswith('.pb')
|
||||||
|
if is_file:
|
||||||
|
model_dir = str(pathlib.Path(path).parent)
|
||||||
|
else:
|
||||||
|
model_dir = path
|
||||||
|
|
||||||
|
model = load_model(path, *args, **kwargs)
|
||||||
labels_file = os.path.join(model_dir, cls.labels_file_name)
|
labels_file = os.path.join(model_dir, cls.labels_file_name)
|
||||||
freq_file = os.path.join(model_dir, cls.freq_file_name)
|
freq_file = os.path.join(model_dir, cls.freq_file_name)
|
||||||
label_names = []
|
label_names = []
|
||||||
|
|
Loading…
Reference in a new issue