From fa786f56b9e3427c7e74c9fbc70edf8b2767bb63 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Wed, 28 Oct 2020 20:06:45 +0100 Subject: [PATCH] Support for both models saved in .h5/.pb format and as Keras directories --- micmon/model/__init__.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/micmon/model/__init__.py b/micmon/model/__init__.py index 57777bc..6c3f6c8 100644 --- a/micmon/model/__init__.py +++ b/micmon/model/__init__.py @@ -1,5 +1,6 @@ import json import os +import pathlib import numpy as np from typing import List, Optional @@ -45,10 +46,16 @@ class Model: prediction = int(np.argmax(output)) return self.label_names[prediction] if self.label_names else prediction - def save(self, model_dir: str, *args, **kwargs): - model_dir = os.path.abspath(os.path.expanduser(model_dir)) - self._model.save(model_dir, *args, **kwargs) + def save(self, path: str, *args, **kwargs): + path = os.path.abspath(os.path.expanduser(path)) + is_file = path.endswith('.h5') or path.endswith('.pb') + if is_file: + model_dir = str(pathlib.Path(path).parent) + else: + model_dir = path + pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) + self._model.save(path, *args, **kwargs) if self.label_names: labels_file = os.path.join(model_dir, self.labels_file_name) with open(labels_file, 'w') as f: @@ -60,9 +67,15 @@ class Model: json.dump(self.cutoff_frequencies, f) @classmethod - def load(cls, model_dir: str, *args, **kwargs): - model_dir = os.path.abspath(os.path.expanduser(model_dir)) - model = load_model(model_dir, *args, **kwargs) + def load(cls, path: str, *args, **kwargs): + path = os.path.abspath(os.path.expanduser(path)) + is_file = path.endswith('.h5') or path.endswith('.pb') + if is_file: + model_dir = str(pathlib.Path(path).parent) + else: + model_dir = path + + model = load_model(path, *args, **kwargs) labels_file = os.path.join(model_dir, cls.labels_file_name) freq_file = os.path.join(model_dir, cls.freq_file_name) label_names = []