micmon/micmon/model/__init__.py

93 lines
3.4 KiB
Python
Raw Normal View History

2020-10-27 15:21:32 +01:00
import json
import os
import pathlib
2020-10-27 15:21:32 +01:00
import numpy as np
2020-10-28 19:46:20 +01:00
from typing import List, Optional
2020-10-28 22:58:59 +01:00
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import load_model, Model as _Model
2020-10-27 15:21:32 +01:00
from micmon.audio import AudioSegment
from micmon.dataset import Dataset
class Model:
    """
    Wrapper around a Keras model that keeps the class label names and the
    cutoff frequencies used to compute input spectra alongside the model.

    The labels and the cutoff frequencies are persisted as JSON files
    (``labels.json`` and ``freq.json``) in the model directory by
    :meth:`save`, and read back by :meth:`load`.
    """

    # Names of the metadata files stored next to the saved model.
    labels_file_name = 'labels.json'
    freq_file_name = 'freq.json'

    # noinspection PyShadowingNames
    def __init__(self, layers: Optional[List[Layer]] = None, labels: Optional[List[str]] = None,
                 model: Optional[_Model] = None,
                 optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=('accuracy',),
                 low_freq: int = AudioSegment.default_low_freq,
                 high_freq: int = AudioSegment.default_high_freq):
        """
        :param layers: Layers used to build and compile a new ``Sequential``
            model. Takes precedence over ``model`` when both are given.
        :param labels: Names of the output classes, indexed by class id.
        :param model: Already-built model, used as-is (it is not compiled
            here) when ``layers`` is not provided.
        :param optimizer: Optimizer passed to ``compile`` (only with ``layers``).
        :param loss: Loss passed to ``compile`` (only with ``layers``).
        :param metrics: Metrics passed to ``compile`` (only with ``layers``).
        :param low_freq: Lower cutoff frequency passed to
            ``AudioSegment.spectrum`` on prediction.
        :param high_freq: Upper cutoff frequency passed to
            ``AudioSegment.spectrum`` on prediction.
        """
        assert layers or model, 'Either layers or model must be provided'
        self.label_names = labels
        self.cutoff_frequencies = (int(low_freq), int(high_freq))

        if layers:
            self._model = Sequential(layers)
            self._model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        else:
            self._model = model

    def fit(self, dataset: Dataset, *args, **kwargs):
        """
        Train the wrapped model on the training split of ``dataset``.
        Extra arguments are forwarded to the underlying ``fit`` call,
        whose result is returned.
        """
        return self._model.fit(dataset.train_samples, dataset.train_classes, *args, **kwargs)

    def evaluate(self, dataset: Dataset, *args, **kwargs):
        """
        Evaluate the wrapped model on the validation split of ``dataset``.
        Extra arguments are forwarded to the underlying ``evaluate`` call,
        whose result is returned.
        """
        return self._model.evaluate(dataset.validation_samples, dataset.validation_classes, *args, **kwargs)

    def predict(self, audio: AudioSegment):
        """
        Predict the class of an audio segment.

        :return: The label name of the highest-scoring class if label names
            are available, otherwise the integer index of that class.
        """
        low_freq, high_freq = self.cutoff_frequencies
        spectrum = audio.spectrum(low_freq=low_freq, high_freq=high_freq)
        # The model is fed a batch containing this single spectrum.
        output = self._model.predict(np.array([spectrum]))
        prediction = int(np.argmax(output))
        return self.label_names[prediction] if self.label_names else prediction

    @staticmethod
    def _resolve_path(path: str):
        """
        Expand ``path`` and work out the directory holding the metadata files.

        :return: ``(absolute_path, model_dir)``. ``model_dir`` is the parent
            directory when the path ends with ``.h5`` or ``.pb``, otherwise
            the path itself.
        """
        path = os.path.abspath(os.path.expanduser(path))
        if path.endswith(('.h5', '.pb')):
            model_dir = str(pathlib.Path(path).parent)
        else:
            model_dir = path

        return path, model_dir

    def save(self, path: str, *args, **kwargs):
        """
        Save the model to ``path``, together with its labels and cutoff
        frequencies.

        :param path: Target model file (``.h5``/``.pb``) or directory. The
            containing directory is created if it does not exist.
        :param args: Extra positional arguments forwarded to the model's ``save``.
        :param kwargs: Extra keyword arguments forwarded to the model's ``save``.
        """
        path, model_dir = self._resolve_path(path)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
        self._model.save(path, *args, **kwargs)

        # The labels file is only written when there are labels to store.
        if self.label_names:
            labels_file = os.path.join(model_dir, self.labels_file_name)
            with open(labels_file, 'w') as f:
                json.dump(self.label_names, f)

        if self.cutoff_frequencies:
            freq_file = os.path.join(model_dir, self.freq_file_name)
            with open(freq_file, 'w') as f:
                json.dump(self.cutoff_frequencies, f)

    @classmethod
    def load(cls, path: str, *args, **kwargs):
        """
        Load a model previously stored through :meth:`save`.

        :param path: Model file (``.h5``/``.pb``) or directory.
        :param args: Extra positional arguments forwarded to ``load_model``.
        :param kwargs: Extra keyword arguments forwarded to ``load_model``.
        :return: A new instance wrapping the loaded model. Labels default to
            an empty list when the labels file is missing, and the cutoff
            frequencies default to the constructor defaults when the
            frequencies file is missing or empty.
        """
        path, model_dir = cls._resolve_path(path)
        model = load_model(path, *args, **kwargs)

        labels_file = os.path.join(model_dir, cls.labels_file_name)
        freq_file = os.path.join(model_dir, cls.freq_file_name)
        label_names = []
        # Only override the constructor's default frequencies when they have
        # actually been read from disk: indexing an empty list here would
        # raise IndexError for models saved without a frequencies file.
        freq_kwargs = {}

        if os.path.isfile(labels_file):
            with open(labels_file, 'r') as f:
                label_names = json.load(f)

        if os.path.isfile(freq_file):
            with open(freq_file, 'r') as f:
                frequencies = json.load(f)

            if frequencies:
                freq_kwargs = {'low_freq': frequencies[0], 'high_freq': frequencies[1]}

        return cls(model=model, labels=label_names, **freq_kwargs)