mirror of
https://github.com/BlackLight/micmon.git
synced 2024-11-12 20:17:15 +01:00
Support for both models saved in .h5/.pb format and as Keras directories
This commit is contained in:
parent
604c315b5e
commit
fa786f56b9
1 changed files with 19 additions and 6 deletions
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import numpy as np
|
||||
|
||||
from typing import List, Optional
|
||||
|
@ -45,10 +46,16 @@ class Model:
|
|||
prediction = int(np.argmax(output))
|
||||
return self.label_names[prediction] if self.label_names else prediction
|
||||
|
||||
def save(self, model_dir: str, *args, **kwargs):
|
||||
model_dir = os.path.abspath(os.path.expanduser(model_dir))
|
||||
self._model.save(model_dir, *args, **kwargs)
|
||||
def save(self, path: str, *args, **kwargs):
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
is_file = path.endswith('.h5') or path.endswith('.pb')
|
||||
if is_file:
|
||||
model_dir = str(pathlib.Path(path).parent)
|
||||
else:
|
||||
model_dir = path
|
||||
|
||||
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||
self._model.save(path, *args, **kwargs)
|
||||
if self.label_names:
|
||||
labels_file = os.path.join(model_dir, self.labels_file_name)
|
||||
with open(labels_file, 'w') as f:
|
||||
|
@ -60,9 +67,15 @@ class Model:
|
|||
json.dump(self.cutoff_frequencies, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, model_dir: str, *args, **kwargs):
|
||||
model_dir = os.path.abspath(os.path.expanduser(model_dir))
|
||||
model = load_model(model_dir, *args, **kwargs)
|
||||
def load(cls, path: str, *args, **kwargs):
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
is_file = path.endswith('.h5') or path.endswith('.pb')
|
||||
if is_file:
|
||||
model_dir = str(pathlib.Path(path).parent)
|
||||
else:
|
||||
model_dir = path
|
||||
|
||||
model = load_model(path, *args, **kwargs)
|
||||
labels_file = os.path.join(model_dir, cls.labels_file_name)
|
||||
freq_file = os.path.join(model_dir, cls.freq_file_name)
|
||||
label_names = []
|
||||
|
|
Loading…
Reference in a new issue