micmon/micmon/model/__init__.py

79 lines
3.1 KiB
Python
Raw Normal View History

2020-10-27 15:21:32 +01:00
import json
import os
import numpy as np
2020-10-28 18:12:19 +01:00
from typing import List, Optional, Union
2020-10-27 15:21:32 +01:00
from keras import Sequential, losses, optimizers, metrics
from keras.layers import Layer
from keras.models import load_model, Model as _Model
from micmon.audio import AudioSegment
from micmon.dataset import Dataset
class Model:
    """Wrapper around a Keras model that classifies :class:`AudioSegment` objects.

    Besides the underlying Keras model, an instance keeps:

    * ``label_names`` - optional list used by :meth:`predict` to translate the
      arg-max class index into a label string;
    * ``cutoff_frequencies`` - ``(low_freq, high_freq)`` pair passed to
      ``AudioSegment.spectrum`` before the spectrum is fed to the network.

    :meth:`save` writes both next to the Keras model (as ``labels.json`` and
    ``freq.json``) and :meth:`load` restores them.
    """

    labels_file_name = 'labels.json'
    freq_file_name = 'freq.json'

    # noinspection PyShadowingNames
    def __init__(self, layers: Optional[List[Layer]] = None, labels: Optional[List[str]] = None,
                 model: Optional[_Model] = None, optimizer: Union[str, optimizers.Optimizer] = 'adam',
                 loss: Union[str, losses.Loss] = losses.SparseCategoricalCrossentropy(from_logits=True),
                 metrics: List[Union[str, metrics.Metric]] = ('accuracy',),
                 low_freq: int = AudioSegment.default_low_freq,
                 high_freq: int = AudioSegment.default_high_freq):
        """Build a new Sequential model from ``layers`` or wrap an existing ``model``.

        :param layers: Keras layers for a new ``Sequential`` model. If given (and
            non-empty) it takes precedence over ``model`` and the new model is
            compiled with ``optimizer``, ``loss`` and ``metrics``.
        :param labels: Optional label names, indexed by predicted class index.
        :param model: An already built Keras model, used as-is (not recompiled)
            when ``layers`` is empty or not given.
        :param optimizer: Optimizer passed to ``compile`` (only used with ``layers``).
        :param loss: Loss passed to ``compile`` (only used with ``layers``).
            NOTE: the default ``Loss`` instance is created once, when the class
            is defined, and is shared by every ``Model`` that uses the default.
        :param metrics: Metrics passed to ``compile`` (only used with ``layers``).
        :param low_freq: Lower cutoff frequency for the input spectrum.
        :param high_freq: Upper cutoff frequency for the input spectrum.
        :raises ValueError: if neither ``layers`` nor ``model`` is provided.
        """
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if not (layers or model):
            raise ValueError('Either layers or model must be provided')

        self.label_names = labels
        self.cutoff_frequencies = (int(low_freq), int(high_freq))

        if layers:
            self._model = Sequential(layers)
            self._model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        else:
            self._model = model

    def fit(self, dataset: Dataset, *args, **kwargs):
        """Train on ``dataset.train_samples`` / ``dataset.train_classes``.

        Extra arguments are forwarded to the Keras ``fit`` call, whose result
        is returned unchanged.
        """
        return self._model.fit(dataset.train_samples, dataset.train_classes, *args, **kwargs)

    def evaluate(self, dataset: Dataset, *args, **kwargs):
        """Evaluate on ``dataset.validation_samples`` / ``dataset.validation_classes``.

        Extra arguments are forwarded to the Keras ``evaluate`` call, whose
        result is returned unchanged.
        """
        return self._model.evaluate(dataset.validation_samples, dataset.validation_classes, *args, **kwargs)

    def predict(self, audio: AudioSegment):
        """Classify a single audio segment.

        The segment's spectrum is computed within ``cutoff_frequencies`` and fed
        to the model as a batch of one.

        :return: the label name of the arg-max class if ``label_names`` is
            non-empty, otherwise the arg-max class index as an ``int``.
        """
        low_freq, high_freq = self.cutoff_frequencies
        spectrum = audio.spectrum(low_freq=low_freq, high_freq=high_freq)
        # Wrap in a one-element batch for the Keras predict call.
        output = self._model.predict(np.array([spectrum]))
        prediction = int(np.argmax(output))
        return self.label_names[prediction] if self.label_names else prediction

    def save(self, model_dir: str, *args, **kwargs):
        """Save the Keras model to ``model_dir`` plus the label/frequency metadata.

        ``~`` is expanded and the path made absolute. Extra arguments are
        forwarded to the Keras ``save`` call. ``labels.json`` is only written
        when ``label_names`` is non-empty.
        """
        model_dir = os.path.abspath(os.path.expanduser(model_dir))
        self._model.save(model_dir, *args, **kwargs)

        if self.label_names:
            labels_file = os.path.join(model_dir, self.labels_file_name)
            with open(labels_file, 'w', encoding='utf-8') as f:
                json.dump(self.label_names, f)

        # Always a 2-tuple when set by __init__; the guard only matters if the
        # attribute was overwritten with an empty/None value afterwards.
        if self.cutoff_frequencies:
            freq_file = os.path.join(model_dir, self.freq_file_name)
            with open(freq_file, 'w', encoding='utf-8') as f:
                json.dump(self.cutoff_frequencies, f)

    @classmethod
    def load(cls, model_dir: str, *args, **kwargs):
        """Load a model previously written by :meth:`save`.

        Extra arguments are forwarded to Keras ``load_model``. Both metadata
        files are optional:

        * without ``labels.json`` the model is built with an empty label list,
          so :meth:`predict` returns class indices;
        * without ``freq.json`` the cutoff frequencies fall back to
          ``AudioSegment.default_low_freq`` / ``default_high_freq`` (previously
          this case crashed with ``IndexError``).
        """
        model_dir = os.path.abspath(os.path.expanduser(model_dir))
        model = load_model(model_dir, *args, **kwargs)
        labels_file = os.path.join(model_dir, cls.labels_file_name)
        freq_file = os.path.join(model_dir, cls.freq_file_name)

        label_names = []
        low_freq, high_freq = AudioSegment.default_low_freq, AudioSegment.default_high_freq

        if os.path.isfile(labels_file):
            with open(labels_file, 'r', encoding='utf-8') as f:
                label_names = json.load(f)

        if os.path.isfile(freq_file):
            with open(freq_file, 'r', encoding='utf-8') as f:
                # save() writes exactly [low_freq, high_freq].
                low_freq, high_freq = json.load(f)

        return cls(model=model, labels=label_names, low_freq=low_freq, high_freq=high_freq)