mirror of https://github.com/BlackLight/micmon.git
54 lines
2.2 KiB
Python
54 lines
2.2 KiB
Python
import os
|
|
import numpy as np
|
|
|
|
from .writer import DatasetWriter
|
|
from ..audio import AudioSegment
|
|
|
|
|
|
class Dataset:
|
|
def __init__(self, samples: np.ndarray, classes: np.ndarray, validation_split: float = 0.,
|
|
low_freq: float = AudioSegment.default_low_freq, high_freq: float = AudioSegment.default_high_freq):
|
|
self.samples = samples
|
|
self.classes = classes
|
|
self.labels = np.sort(np.unique(classes))
|
|
self.validation_split = validation_split
|
|
self.low_freq = low_freq
|
|
self.high_freq = high_freq
|
|
self.train_samples, self.train_classes, self.validation_samples, self.validation_classes = [np.array([])] * 4
|
|
self.shuffle()
|
|
|
|
@classmethod
|
|
def load(cls, npz_path: str, validation_split: float = 0.):
|
|
dataset = np.load(os.path.abspath(os.path.expanduser(npz_path)))
|
|
return cls(samples=dataset['samples'],
|
|
classes=dataset['classes'],
|
|
validation_split=validation_split,
|
|
low_freq=dataset['cutoff_frequencies'][0],
|
|
high_freq=dataset['cutoff_frequencies'][1])
|
|
|
|
@classmethod
|
|
def scan(cls, datasets_path, validation_split: float = 0.):
|
|
datasets_path = os.path.abspath(os.path.expanduser(datasets_path))
|
|
return [
|
|
cls.load(os.path.join(datasets_path, file), validation_split=validation_split)
|
|
for file in os.listdir(datasets_path)
|
|
if os.path.isfile(os.path.join(datasets_path, file))
|
|
and file.endswith('.npz')
|
|
]
|
|
|
|
def shuffle(self):
|
|
data = np.array([
|
|
(self.samples[i], self.classes[i])
|
|
for i in range(len(self.samples))
|
|
], dtype=object)
|
|
|
|
np.random.shuffle(data)
|
|
self.samples = np.array([p[0] for p in data])
|
|
self.classes = np.array([p[1] for p in data])
|
|
|
|
pivot = int(len(data) - (self.validation_split * len(data)))
|
|
self.train_samples = np.array([p[0] for p in data[:pivot]])
|
|
self.train_classes = np.array([p[1] for p in data[:pivot]])
|
|
self.validation_samples = np.array([p[0] for p in data[pivot:]])
|
|
self.validation_classes = np.array([p[1] for p in data[pivot:]])
|