# micmon/micmon/model/__init__.py
#
# 79 lines
# 3.1 KiB
# Python

import json
import os
import numpy as np
from typing import List, Optional, Union
from keras import Sequential, losses, optimizers, metrics
from keras.layers import Layer
from keras.models import load_model, Model as _Model
from micmon.audio import AudioSegment
from micmon.dataset import Dataset
class Model:
    """
    Thin wrapper around a Keras model that is fed the spectrum of an
    ``AudioSegment`` and returns the best-scoring class.

    Besides the Keras model itself, the wrapper keeps track of the label names
    and of the (low, high) cutoff frequencies used to extract the spectrum, and
    persists both as JSON files next to the saved model.
    """

    # Names of the JSON files stored inside the model directory.
    labels_file_name = 'labels.json'
    freq_file_name = 'freq.json'

    # noinspection PyShadowingNames
    def __init__(self, layers: Optional[List[Layer]] = None, labels: Optional[List[str]] = None,
                 model: Optional[_Model] = None, optimizer: Union[str, optimizers.Optimizer] = 'adam',
                 loss: Union[str, losses.Loss] = losses.SparseCategoricalCrossentropy(from_logits=True),
                 metrics: List[Union[str, metrics.Metric]] = ('accuracy',),
                 low_freq: int = AudioSegment.default_low_freq,
                 high_freq: int = AudioSegment.default_high_freq):
        """
        :param layers: Layers used to build and compile a new ``Sequential``
            model. Takes precedence over ``model`` when both are given.
        :param labels: Names of the output classes, indexed by class number.
        :param model: An already built Keras model, used as-is (it is not
            re-compiled) when ``layers`` is not provided.
        :param optimizer: Optimizer passed to ``compile`` (only with ``layers``).
        :param loss: Loss passed to ``compile`` (only with ``layers``).
        :param metrics: Metrics passed to ``compile`` (only with ``layers``).
        :param low_freq: Lower cutoff frequency used when extracting the spectrum.
        :param high_freq: Upper cutoff frequency used when extracting the spectrum.
        """
        # Compare the model against None explicitly instead of relying on the
        # truthiness of an arbitrary model object.
        assert layers or model is not None, 'Either layers or model must be provided'

        self.label_names = labels
        self.cutoff_frequencies = (int(low_freq), int(high_freq))

        if layers:
            self._model = Sequential(layers)
            self._model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        else:
            self._model = model

    def fit(self, dataset: Dataset, *args, **kwargs):
        """
        Train the model on the training split of ``dataset``.

        Extra positional and keyword arguments are forwarded to the underlying
        model's ``fit``, whose return value is returned unchanged.
        """
        return self._model.fit(dataset.train_samples, dataset.train_classes, *args, **kwargs)

    def evaluate(self, dataset: Dataset, *args, **kwargs):
        """
        Evaluate the model on the validation split of ``dataset``.

        Extra positional and keyword arguments are forwarded to the underlying
        model's ``evaluate``, whose return value is returned unchanged.
        """
        return self._model.evaluate(dataset.validation_samples, dataset.validation_classes, *args, **kwargs)

    def predict(self, audio: AudioSegment) -> Union[str, int]:
        """
        Classify a single audio segment.

        :param audio: Segment whose spectrum, restricted to the configured
            cutoff frequencies, is fed to the model.
        :return: The label name of the highest-scoring output if label names
            are available, otherwise its integer index.
        """
        low_freq, high_freq = self.cutoff_frequencies
        spectrum = audio.spectrum(low_freq=low_freq, high_freq=high_freq)

        # The model is given a batch containing only this one spectrum.
        output = self._model.predict(np.array([spectrum]))
        prediction = int(np.argmax(output))
        return self.label_names[prediction] if self.label_names else prediction

    def save(self, model_dir: str, *args, **kwargs) -> None:
        """
        Save the model to ``model_dir``, together with its label names and
        cutoff frequencies.

        :param model_dir: Target directory (``~`` is expanded). The JSON files
            are written inside it after the model has been saved.
        :param args: Extra positional arguments forwarded to the model's ``save``.
        :param kwargs: Extra keyword arguments forwarded to the model's ``save``.
        """
        model_dir = os.path.abspath(os.path.expanduser(model_dir))
        self._model.save(model_dir, *args, **kwargs)

        # The labels file is skipped when there are no label names to store.
        if self.label_names:
            labels_file = os.path.join(model_dir, self.labels_file_name)
            with open(labels_file, 'w', encoding='utf-8') as f:
                json.dump(self.label_names, f)

        if self.cutoff_frequencies:
            freq_file = os.path.join(model_dir, self.freq_file_name)
            with open(freq_file, 'w', encoding='utf-8') as f:
                json.dump(self.cutoff_frequencies, f)

    @classmethod
    def load(cls, model_dir: str, *args, **kwargs) -> 'Model':
        """
        Load a model previously stored in ``model_dir``.

        :param model_dir: Directory containing the saved model (``~`` is expanded).
        :param args: Extra positional arguments forwarded to ``load_model``.
        :param kwargs: Extra keyword arguments forwarded to ``load_model``.
        :return: A new instance wrapping the loaded model. Label names are empty
            if the labels file is missing; the cutoff frequencies fall back to
            the ``AudioSegment`` defaults if the frequencies file is missing.
        """
        model_dir = os.path.abspath(os.path.expanduser(model_dir))
        model = load_model(model_dir, *args, **kwargs)
        labels_file = os.path.join(model_dir, cls.labels_file_name)
        freq_file = os.path.join(model_dir, cls.freq_file_name)

        label_names = []
        # Same defaults as __init__, so that a model directory without a
        # frequencies file can still be loaded instead of raising IndexError.
        low_freq = AudioSegment.default_low_freq
        high_freq = AudioSegment.default_high_freq

        if os.path.isfile(labels_file):
            with open(labels_file, 'r', encoding='utf-8') as f:
                label_names = json.load(f)

        if os.path.isfile(freq_file):
            with open(freq_file, 'r', encoding='utf-8') as f:
                frequencies = json.load(f)
            low_freq, high_freq = frequencies[0], frequencies[1]

        return cls(model=model, labels=label_names, low_freq=low_freq, high_freq=high_freq)