Extracted common logic for speech-to-text integrations into abstract STT plugin

This commit is contained in:
Fabio Manganiello 2020-03-06 23:11:19 +01:00
parent 7ed847b646
commit e04c6fb921
3 changed files with 314 additions and 209 deletions

View file

@ -34,7 +34,6 @@ class SttDeepspeechBackend(Backend):
try:
plugin: SttDeepspeechPlugin = get_plugin('stt.deepspeech')
with plugin:
plugin.start_detection()
# noinspection PyProtectedMember
plugin._detection_thread.join()
except Exception as e:

View file

@ -0,0 +1,277 @@
import queue
import threading
from abc import ABC, abstractmethod
from typing import Optional, Union, List
import sounddevice as sd
from platypush.context import get_bus
from platypush.message.event.stt import SpeechDetectionStartedEvent, SpeechDetectionStoppedEvent, SpeechStartedEvent, \
SpeechDetectedEvent, HotwordDetectedEvent, ConversationDetectedEvent
from platypush.message.response.stt import SpeechDetectedResponse
from platypush.plugins import Plugin, action
class SttPlugin(ABC, Plugin):
"""
Abstract class for speech-to-text plugins.
Triggers:
* :class:`platypush.message.event.stt.SpeechStartedEvent` when speech starts being detected.
* :class:`platypush.message.event.stt.SpeechDetectedEvent` when speech is detected.
* :class:`platypush.message.event.stt.SpeechDetectionStartedEvent` when speech detection starts.
* :class:`platypush.message.event.stt.SpeechDetectionStoppedEvent` when speech detection stops.
* :class:`platypush.message.event.stt.HotwordDetectedEvent` when a user-defined hotword is detected.
* :class:`platypush.message.event.stt.ConversationDetectedEvent` when speech is detected after a hotword.
"""
_thread_stop_timeout = 10.0
rate = 16000
channels = 1
def __init__(self,
input_device: Optional[Union[int, str]] = None,
hotword: Optional[str] = None,
hotwords: Optional[List[str]] = None,
conversation_timeout: Optional[float] = None,
block_duration: float = 1.0):
"""
:param input_device: PortAudio device index or name that will be used for recording speech (default: default
system audio input device).
:param hotword: When this word is detected, the plugin will trigger a
:class:`platypush.message.event.stt.HotwordDetectedEvent` instead of a
:class:`platypush.message.event.stt.SpeechDetectedEvent` event. You can use these events for hooking other
assistants.
:param hotwords: Use a list of hotwords instead of a single one.
:param conversation_timeout: If ``hotword`` or ``hotwords`` are set and ``conversation_timeout`` is set,
the next speech detected event will trigger a :class:`platypush.message.event.stt.ConversationDetectedEvent`
instead of a :class:`platypush.message.event.stt.SpeechDetectedEvent` event. You can hook custom hooks
here to run any logic depending on the detected speech - it can emulate a kind of
"OK, Google. Turn on the lights" interaction without using an external assistant.
:param block_duration: Duration of the acquired audio blocks (default: 1 second).
"""
super().__init__()
self.input_device = input_device
self.conversation_timeout = conversation_timeout
self.block_duration = block_duration
self.hotwords = set(hotwords or [])
if hotword:
self.hotwords = {hotword}
self._conversation_event = threading.Event()
self._input_stream: Optional[sd.InputStream] = None
self._recording_thread: Optional[threading.Thread] = None
self._detection_thread: Optional[threading.Thread] = None
self._audio_queue: Optional[queue.Queue] = None
def _get_input_device(self, device: Optional[Union[int, str]] = None) -> int:
"""
Get the index of the input device by index or name.
:param device: Device index or name. If None is set then the function will return the index of the
default audio input device.
:return: Index of the audio input device.
"""
if not device:
device = self.input_device
if not device:
return sd.query_hostapis()[0].get('default_input_device')
if isinstance(device, int):
assert device <= len(sd.query_devices())
return device
for i, dev in enumerate(sd.query_devices()):
if dev['name'] == device:
return i
raise AssertionError('Device {} not found'.format(device))
def on_speech_detected(self, speech: str) -> None:
"""
Hook called when speech is detected. Triggers the right event depending on the current context.
:param speech: Detected speech.
"""
speech = speech.strip()
if self._conversation_event.is_set():
event = ConversationDetectedEvent(speech=speech)
elif speech in self.hotwords:
event = HotwordDetectedEvent(hotword=speech)
if self.conversation_timeout:
self._conversation_event.set()
threading.Timer(self.conversation_timeout, lambda: self._conversation_event.clear()).start()
else:
event = SpeechDetectedEvent(speech=speech)
get_bus().post(event)
@staticmethod
def convert_frames(frames: bytes) -> bytes:
"""
Conversion method for raw audio frames. It just returns the input frames as bytes. Override it if required
by your logic.
:param frames: Input audio frames, as bytes.
:return: The audio frames as passed on the input. Override if required.
"""
return frames
def on_detection_started(self):
"""
Method called when the ``detection_thread`` starts. Initialize your context variables and models here if
required.
"""
pass
def on_detection_ended(self):
"""
Method called when the ``detection_thread`` stops. Clean up your context variables and models here.
"""
pass
@abstractmethod
def detect_audio(self, frames) -> str:
"""
Method called within the ``detection_thread`` when new audio frames have been captured. Must be implemented
by the derived classes.
:param frames: Audio frames, as returned by ``convert_frames``.
:return: Detected text, as a string.
"""
raise NotImplementedError
def detection_thread(self) -> None:
"""
This thread reads frames from ``_audio_queue``, performs the speech-to-text detection and calls
"""
current_text = ''
self.logger.debug('Detection thread started')
self.on_detection_started()
while self._audio_queue:
try:
frames = self._audio_queue.get()
frames = self.convert_frames(frames)
except Exception as e:
self.logger.warning('Error while feeding audio to the model: {}'.format(str(e)))
continue
text = self.detect_audio(frames)
if text == current_text:
if current_text:
self.on_speech_detected(current_text)
current_text = ''
else:
if not current_text:
get_bus().post(SpeechStartedEvent())
self.logger.info('Intermediate speech results: [{}]'.format(text))
current_text = text
self.on_detection_ended()
self.logger.debug('Detection thread terminated')
def recording_thread(self, block_duration: float, input_device: Optional[str] = None) -> None:
"""
Recording thread. It reads raw frames from the audio device and dispatches them to ``detection_thread``.
:param block_duration: Audio blocks duration.
:param input_device: Input device
"""
self.logger.debug('Recording thread started')
device = self._get_input_device(input_device)
blocksize = int(self.rate * self.channels * block_duration)
self._input_stream = sd.InputStream(samplerate=self.rate, device=device,
channels=self.channels, dtype='int16', latency=0,
blocksize=blocksize)
self._input_stream.start()
get_bus().post(SpeechDetectionStartedEvent())
while self._input_stream:
try:
frames = self._input_stream.read(self.rate)[0]
except Exception as e:
self.logger.warning('Error while reading from the audio input: {}'.format(str(e)))
continue
self._audio_queue.put(frames)
get_bus().post(SpeechDetectionStoppedEvent())
self.logger.debug('Recording thread terminated')
@abstractmethod
@action
def detect(self, audio_file: str) -> SpeechDetectedResponse:
"""
Perform speech-to-text analysis on an audio file. Must be implemented by the derived classes.
:param audio_file: Path to the audio file.
"""
raise NotImplementedError
def __enter__(self):
"""
Context manager enter. Starts detection and returns self.
"""
self.start_detection()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context manager exit. Stops detection.
"""
self.stop_detection()
@action
def start_detection(self, input_device: Optional[str] = None, seconds: Optional[float] = None,
block_duration: Optional[float] = None) -> None:
"""
Start the speech detection engine.
:param input_device: Audio input device name/index override
:param seconds: If set, then the detection engine will stop after this many seconds, otherwise it'll
start running until ``stop_detection`` is called or application stop.
:param block_duration: ``block_duration`` override.
"""
assert not self._input_stream and not self._recording_thread, 'Speech detection is already running'
block_duration = block_duration or self.block_duration
input_device = input_device if input_device is not None else self.input_device
self._audio_queue = queue.Queue()
self._recording_thread = threading.Thread(
target=lambda: self.recording_thread(block_duration=block_duration, input_device=input_device))
self._recording_thread.start()
self._detection_thread = threading.Thread(target=lambda: self.detection_thread())
self._detection_thread.start()
if seconds:
threading.Timer(seconds, lambda: self.stop_detection()).start()
@action
def stop_detection(self) -> None:
"""
Stop the speech detection engine.
"""
assert self._input_stream, 'Speech detection is not running'
self._input_stream.stop(ignore_errors=True)
self._input_stream.close(ignore_errors=True)
self._input_stream = None
if self._recording_thread:
self._recording_thread.join(timeout=self._thread_stop_timeout)
self._recording_thread = None
self._audio_queue = None
if self._detection_thread:
self._detection_thread.join(timeout=self._thread_stop_timeout)
self._detection_thread = None
# vim:sw=4:ts=4:et:

View file

@ -1,34 +1,20 @@
import queue
import os
import threading
from typing import Optional, Union, List
from typing import Optional, Union
import deepspeech
import numpy as np
import sounddevice as sd
import wave
from platypush.context import get_bus
from platypush.message.event.stt import SpeechDetectionStartedEvent, SpeechDetectionStoppedEvent, SpeechStartedEvent, \
SpeechDetectedEvent, HotwordDetectedEvent, ConversationDetectedEvent
from platypush.message.response.stt import SpeechDetectedResponse
from platypush.plugins import Plugin, action
from platypush.plugins import action
from platypush.plugins.stt import SttPlugin
class SttDeepspeechPlugin(Plugin):
class SttDeepspeechPlugin(SttPlugin):
"""
This plugin performs speech-to-text and speech detection using the
`Mozilla DeepSpeech <https://github.com/mozilla/DeepSpeech>`_ engine.
Triggers:
* :class:`platypush.message.event.stt.SpeechStartedEvent` when speech starts being detected.
* :class:`platypush.message.event.stt.SpeechDetectedEvent` when speech is detected.
* :class:`platypush.message.event.stt.SpeechDetectionStartedEvent` when speech detection starts.
* :class:`platypush.message.event.stt.SpeechDetectionStoppedEvent` when speech detection stops.
* :class:`platypush.message.event.stt.HotwordDetectedEvent` when a user-defined hotword is detected.
* :class:`platypush.message.event.stt.ConversationDetectedEvent` when speech is detected after a hotword.
Requires:
* **deepspeech** (``pip install 'deepspeech>=0.6.0'``)
@ -37,10 +23,6 @@ class SttDeepspeechPlugin(Plugin):
"""
_thread_stop_timeout = 10.0
rate = 16000
channels = 1
def __init__(self,
model_file: str,
lm_file: str,
@ -48,11 +30,7 @@ class SttDeepspeechPlugin(Plugin):
lm_alpha: float = 0.75,
lm_beta: float = 1.85,
beam_width: int = 500,
input_device: Optional[Union[int, str]] = None,
hotword: Optional[str] = None,
hotwords: Optional[List[str]] = None,
conversation_timeout: Optional[float] = None,
block_duration: float = 1.0):
*args, **kwargs):
"""
In order to run the speech-to-text engine you'll need to download the right model files for the
Deepspeech engine that you have installed:
@ -101,27 +79,15 @@ class SttDeepspeechPlugin(Plugin):
:param block_duration: Duration of the acquired audio blocks (default: 1 second).
"""
super().__init__()
super().__init__(*args, **kwargs)
self.model_file = os.path.abspath(os.path.expanduser(model_file))
self.lm_file = os.path.abspath(os.path.expanduser(lm_file))
self.trie_file = os.path.abspath(os.path.expanduser(trie_file))
self.lm_alpha = lm_alpha
self.lm_beta = lm_beta
self.beam_width = beam_width
self.input_device = input_device
self.conversation_timeout = conversation_timeout
self.block_duration = block_duration
self.hotwords = set(hotwords or [])
if hotword:
self.hotwords = {hotword}
self._conversation_event = threading.Event()
self._model: Optional[deepspeech.Model] = None
self._input_stream: Optional[sd.InputStream] = None
self._recording_thread: Optional[threading.Thread] = None
self._detection_thread: Optional[threading.Thread] = None
self._audio_queue: Optional[queue.Queue] = None
self._context = None
def _get_model(self) -> deepspeech.Model:
if not self._model:
@ -130,125 +96,41 @@ class SttDeepspeechPlugin(Plugin):
return self._model
def _detect(self, data: Union[bytes, np.ndarray]) -> str:
data = self._convert_data(data)
model = self._get_model()
return model.stt(data)
def _get_context(self):
if not self._model:
self._model = self._get_model()
if not self._context:
self._context = self._model.createStream()
return self._context
@staticmethod
def _convert_data(data: Union[np.ndarray, bytes]) -> np.ndarray:
return np.frombuffer(data, dtype=np.int16)
def convert_frames(frames: Union[np.ndarray, bytes]) -> np.ndarray:
return np.frombuffer(frames, dtype=np.int16)
def _get_input_device(self, device: Optional[Union[int, str]] = None) -> int:
"""
Get the index of the input device by index or name.
def on_detection_started(self):
self._context = self._get_context()
:param device: Device index or name. If None is set then the function will return the index of the
default audio input device.
:return: Index of the audio input device.
"""
if not device:
device = self.input_device
if not device:
return sd.query_hostapis()[0].get('default_input_device')
def on_detection_ended(self):
if self._model and self._context:
self._model.finishStream()
self._context = None
if isinstance(device, int):
assert device <= len(sd.query_devices())
return device
for i, dev in enumerate(sd.query_devices()):
if dev['name'] == device:
return i
raise AssertionError('Device {} not found'.format(device))
def _on_speech_detected(self, speech: str) -> None:
"""
Hook called when speech is detected. Triggers the right event depending on the current context.
:param speech: Detected speech.
"""
speech = speech.strip()
if self._conversation_event.is_set():
event = ConversationDetectedEvent(speech=speech)
elif speech in self.hotwords:
event = HotwordDetectedEvent(hotword=speech)
if self.conversation_timeout:
self._conversation_event.set()
threading.Timer(self.conversation_timeout, lambda: self._conversation_event.clear()).start()
else:
event = SpeechDetectedEvent(speech=speech)
get_bus().post(event)
def detection_thread(self) -> None:
"""
Speech detection thread. Reads from the ``audio_queue`` and uses the Deepspeech model to detect
speech real-time.
"""
self.logger.debug('Detection thread started')
def detect_audio(self, frames) -> str:
model = self._get_model()
current_text = ''
context = None
context = self._get_context()
model.feedAudioContent(context, frames)
return model.intermediateDecode(context)
while self._audio_queue:
if not context:
context = model.createStream()
def on_speech_detected(self, speech: str) -> None:
super().on_speech_detected(speech)
if not speech:
return
try:
frames = self._audio_queue.get()
frames = self._convert_data(frames)
except Exception as e:
self.logger.warning('Error while feeding audio to the model: {}'.format(str(e)))
continue
model.feedAudioContent(context, frames)
text = model.intermediateDecode(context)
if text == current_text:
if current_text:
self._on_speech_detected(current_text)
model.finishStream(context)
context = None
current_text = ''
else:
if not current_text:
get_bus().post(SpeechStartedEvent())
self.logger.info('Intermediate speech results: [{}]'.format(text))
current_text = text
self.logger.debug('Detection thread terminated')
def recording_thread(self, block_duration: float, input_device: Optional[str] = None) -> None:
"""
Recording thread. It reads raw frames from the audio device and dispatches them to ``detection_thread``.
:param block_duration: Audio blocks duration.
:param input_device: Input device
"""
self.logger.debug('Recording thread started')
device = self._get_input_device(input_device)
blocksize = int(self.rate * self.channels * block_duration)
self._input_stream = sd.InputStream(samplerate=self.rate, device=device,
channels=self.channels, dtype='int16', latency=0,
blocksize=blocksize)
self._input_stream.start()
get_bus().post(SpeechDetectionStartedEvent())
while self._input_stream:
try:
frames = self._input_stream.read(self.rate)[0]
except Exception as e:
self.logger.warning('Error while reading from the audio input: {}'.format(str(e)))
continue
self._audio_queue.put(frames)
get_bus().post(SpeechDetectionStoppedEvent())
self.logger.debug('Recording thread terminated')
model = self._get_model()
context = self._get_context()
model.finishStream(context)
self._context = None
@action
def detect(self, audio_file: str) -> SpeechDetectedResponse:
@ -260,63 +142,10 @@ class SttDeepspeechPlugin(Plugin):
audio_file = os.path.abspath(os.path.expanduser(audio_file))
wav = wave.open(audio_file, 'r')
buffer = wav.readframes(wav.getnframes())
speech = self._detect(buffer)
data = self.convert_frames(buffer)
model = self._get_model()
speech = model.stt(data)
return SpeechDetectedResponse(speech=speech)
def __enter__(self):
"""
Context manager enter. Starts detection and returns self.
"""
self.start_detection()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context manager exit. Stops detection.
"""
self.stop_detection()
@action
def start_detection(self, input_device: Optional[str] = None, seconds: Optional[float] = None,
block_duration: Optional[float] = None) -> None:
"""
Start the speech detection engine.
:param input_device: Audio input device name/index override
:param seconds: If set, then the detection engine will stop after this many seconds, otherwise it'll
start running until ``stop_detection`` is called or application stop.
:param block_duration: ``block_duration`` override.
"""
assert not self._input_stream, 'Speech detection is already running'
block_duration = block_duration or self.block_duration
input_device = input_device if input_device is not None else self.input_device
self._audio_queue = queue.Queue()
self._recording_thread = threading.Thread(
target=lambda: self.recording_thread(block_duration=block_duration, input_device=input_device))
self._recording_thread.start()
self._detection_thread = threading.Thread(target=lambda: self.detection_thread())
self._detection_thread.start()
if seconds:
threading.Timer(seconds, lambda: self.stop_detection()).start()
@action
def stop_detection(self) -> None:
"""
Stop the speech detection engine.
"""
assert self._input_stream, 'Speech detection is not running'
self._input_stream.stop(ignore_errors=True)
self._input_stream.close(ignore_errors=True)
self._input_stream = None
if self._recording_thread:
self._recording_thread.join(timeout=self._thread_stop_timeout)
self._audio_queue = None
if self._detection_thread:
self._detection_thread.join(timeout=self._thread_stop_timeout)
# vim:sw=4:ts=4:et: