diff --git a/platypush/backend/stt/deepspeech.py b/platypush/backend/stt/deepspeech.py index 22916874..1a149be4 100644 --- a/platypush/backend/stt/deepspeech.py +++ b/platypush/backend/stt/deepspeech.py @@ -34,7 +34,6 @@ class SttDeepspeechBackend(Backend): try: plugin: SttDeepspeechPlugin = get_plugin('stt.deepspeech') with plugin: - plugin.start_detection() # noinspection PyProtectedMember plugin._detection_thread.join() except Exception as e: diff --git a/platypush/plugins/stt/__init__.py b/platypush/plugins/stt/__init__.py index e69de29b..f4978248 100644 --- a/platypush/plugins/stt/__init__.py +++ b/platypush/plugins/stt/__init__.py @@ -0,0 +1,277 @@ +import queue +import threading +from abc import ABC, abstractmethod +from typing import Optional, Union, List + +import sounddevice as sd + +from platypush.context import get_bus +from platypush.message.event.stt import SpeechDetectionStartedEvent, SpeechDetectionStoppedEvent, SpeechStartedEvent, \ + SpeechDetectedEvent, HotwordDetectedEvent, ConversationDetectedEvent +from platypush.message.response.stt import SpeechDetectedResponse +from platypush.plugins import Plugin, action + + +class SttPlugin(ABC, Plugin): + """ + Abstract class for speech-to-text plugins. + + Triggers: + + * :class:`platypush.message.event.stt.SpeechStartedEvent` when speech starts being detected. + * :class:`platypush.message.event.stt.SpeechDetectedEvent` when speech is detected. + * :class:`platypush.message.event.stt.SpeechDetectionStartedEvent` when speech detection starts. + * :class:`platypush.message.event.stt.SpeechDetectionStoppedEvent` when speech detection stops. + * :class:`platypush.message.event.stt.HotwordDetectedEvent` when a user-defined hotword is detected. + * :class:`platypush.message.event.stt.ConversationDetectedEvent` when speech is detected after a hotword. + + """ + + _thread_stop_timeout = 10.0 + rate = 16000 + channels = 1 + + def __init__(self, + input_device: Optional[Union[int, str]] = None, + hotword: Optional[str] = None, + hotwords: Optional[List[str]] = None, + conversation_timeout: Optional[float] = None, + block_duration: float = 1.0): + """ + :param input_device: PortAudio device index or name that will be used for recording speech (default: default + system audio input device). + :param hotword: When this word is detected, the plugin will trigger a + :class:`platypush.message.event.stt.HotwordDetectedEvent` instead of a + :class:`platypush.message.event.stt.SpeechDetectedEvent` event. You can use these events for hooking other + assistants. + :param hotwords: Use a list of hotwords instead of a single one. + :param conversation_timeout: If ``hotword`` or ``hotwords`` are set and ``conversation_timeout`` is set, + the next speech detected event will trigger a :class:`platypush.message.event.stt.ConversationDetectedEvent` + instead of a :class:`platypush.message.event.stt.SpeechDetectedEvent` event. You can hook custom hooks + here to run any logic depending on the detected speech - it can emulate a kind of + "OK, Google. Turn on the lights" interaction without using an external assistant. + :param block_duration: Duration of the acquired audio blocks (default: 1 second). + """ + + super().__init__() + self.input_device = input_device + self.conversation_timeout = conversation_timeout + self.block_duration = block_duration + + self.hotwords = set(hotwords or []) + if hotword: + self.hotwords = {hotword} + + self._conversation_event = threading.Event() + self._input_stream: Optional[sd.InputStream] = None + self._recording_thread: Optional[threading.Thread] = None + self._detection_thread: Optional[threading.Thread] = None + self._audio_queue: Optional[queue.Queue] = None + + def _get_input_device(self, device: Optional[Union[int, str]] = None) -> int: + """ + Get the index of the input device by index or name. + + :param device: Device index or name. If None is set then the function will return the index of the + default audio input device. + :return: Index of the audio input device. + """ + if not device: + device = self.input_device + if not device: + return sd.query_hostapis()[0].get('default_input_device') + + if isinstance(device, int): + assert device <= len(sd.query_devices()) + return device + + for i, dev in enumerate(sd.query_devices()): + if dev['name'] == device: + return i + + raise AssertionError('Device {} not found'.format(device)) + + def on_speech_detected(self, speech: str) -> None: + """ + Hook called when speech is detected. Triggers the right event depending on the current context. + + :param speech: Detected speech. + """ + speech = speech.strip() + + if self._conversation_event.is_set(): + event = ConversationDetectedEvent(speech=speech) + elif speech in self.hotwords: + event = HotwordDetectedEvent(hotword=speech) + if self.conversation_timeout: + self._conversation_event.set() + threading.Timer(self.conversation_timeout, lambda: self._conversation_event.clear()).start() + else: + event = SpeechDetectedEvent(speech=speech) + + get_bus().post(event) + + @staticmethod + def convert_frames(frames: bytes) -> bytes: + """ + Conversion method for raw audio frames. It just returns the input frames as bytes. Override it if required + by your logic. + + :param frames: Input audio frames, as bytes. + :return: The audio frames as passed on the input. Override if required. + """ + return frames + + def on_detection_started(self): + """ + Method called when the ``detection_thread`` starts. Initialize your context variables and models here if + required. + """ + pass + + def on_detection_ended(self): + """ + Method called when the ``detection_thread`` stops. Clean up your context variables and models here. + """ + pass + + @abstractmethod + def detect_audio(self, frames) -> str: + """ + Method called within the ``detection_thread`` when new audio frames have been captured. Must be implemented + by the derived classes. + + :param frames: Audio frames, as returned by ``convert_frames``. + :return: Detected text, as a string. + """ + raise NotImplementedError + + def detection_thread(self) -> None: + """ + This thread reads frames from ``_audio_queue``, performs the speech-to-text detection and calls + """ + current_text = '' + self.logger.debug('Detection thread started') + self.on_detection_started() + + while self._audio_queue: + try: + frames = self._audio_queue.get() + frames = self.convert_frames(frames) + except Exception as e: + self.logger.warning('Error while feeding audio to the model: {}'.format(str(e))) + continue + + text = self.detect_audio(frames) + if text == current_text: + if current_text: + self.on_speech_detected(current_text) + + current_text = '' + else: + if not current_text: + get_bus().post(SpeechStartedEvent()) + + self.logger.info('Intermediate speech results: [{}]'.format(text)) + current_text = text + + self.on_detection_ended() + self.logger.debug('Detection thread terminated') + + def recording_thread(self, block_duration: float, input_device: Optional[str] = None) -> None: + """ + Recording thread. It reads raw frames from the audio device and dispatches them to ``detection_thread``. + + :param block_duration: Audio blocks duration. + :param input_device: Input device + """ + self.logger.debug('Recording thread started') + device = self._get_input_device(input_device) + blocksize = int(self.rate * self.channels * block_duration) + self._input_stream = sd.InputStream(samplerate=self.rate, device=device, + channels=self.channels, dtype='int16', latency=0, + blocksize=blocksize) + self._input_stream.start() + get_bus().post(SpeechDetectionStartedEvent()) + + while self._input_stream: + try: + frames = self._input_stream.read(self.rate)[0] + except Exception as e: + self.logger.warning('Error while reading from the audio input: {}'.format(str(e))) + continue + + self._audio_queue.put(frames) + + get_bus().post(SpeechDetectionStoppedEvent()) + self.logger.debug('Recording thread terminated') + + @abstractmethod + @action + def detect(self, audio_file: str) -> SpeechDetectedResponse: + """ + Perform speech-to-text analysis on an audio file. Must be implemented by the derived classes. + + :param audio_file: Path to the audio file. + """ + raise NotImplementedError + + def __enter__(self): + """ + Context manager enter. Starts detection and returns self. + """ + self.start_detection() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Context manager exit. Stops detection. + """ + self.stop_detection() + + @action + def start_detection(self, input_device: Optional[str] = None, seconds: Optional[float] = None, + block_duration: Optional[float] = None) -> None: + """ + Start the speech detection engine. + + :param input_device: Audio input device name/index override + :param seconds: If set, then the detection engine will stop after this many seconds, otherwise it'll + start running until ``stop_detection`` is called or application stop. + :param block_duration: ``block_duration`` override. + """ + assert not self._input_stream and not self._recording_thread, 'Speech detection is already running' + block_duration = block_duration or self.block_duration + input_device = input_device if input_device is not None else self.input_device + self._audio_queue = queue.Queue() + self._recording_thread = threading.Thread( + target=lambda: self.recording_thread(block_duration=block_duration, input_device=input_device)) + + self._recording_thread.start() + self._detection_thread = threading.Thread(target=lambda: self.detection_thread()) + self._detection_thread.start() + + if seconds: + threading.Timer(seconds, lambda: self.stop_detection()).start() + + @action + def stop_detection(self) -> None: + """ + Stop the speech detection engine. + """ + assert self._input_stream, 'Speech detection is not running' + self._input_stream.stop(ignore_errors=True) + self._input_stream.close(ignore_errors=True) + self._input_stream = None + + if self._recording_thread: + self._recording_thread.join(timeout=self._thread_stop_timeout) + self._recording_thread = None + + self._audio_queue = None + if self._detection_thread: + self._detection_thread.join(timeout=self._thread_stop_timeout) + self._detection_thread = None + + +# vim:sw=4:ts=4:et: diff --git a/platypush/plugins/stt/deepspeech.py b/platypush/plugins/stt/deepspeech.py index 3aaa7497..f50bd1db 100644 --- a/platypush/plugins/stt/deepspeech.py +++ b/platypush/plugins/stt/deepspeech.py @@ -1,34 +1,20 @@ -import queue import os -import threading -from typing import Optional, Union, List +from typing import Optional, Union import deepspeech import numpy as np -import sounddevice as sd import wave -from platypush.context import get_bus -from platypush.message.event.stt import SpeechDetectionStartedEvent, SpeechDetectionStoppedEvent, SpeechStartedEvent, \ - SpeechDetectedEvent, HotwordDetectedEvent, ConversationDetectedEvent from platypush.message.response.stt import SpeechDetectedResponse -from platypush.plugins import Plugin, action +from platypush.plugins import action +from platypush.plugins.stt import SttPlugin -class SttDeepspeechPlugin(Plugin): +class SttDeepspeechPlugin(SttPlugin): """ This plugin performs speech-to-text and speech detection using the `Mozilla DeepSpeech `_ engine. - Triggers: - - * :class:`platypush.message.event.stt.SpeechStartedEvent` when speech starts being detected. - * :class:`platypush.message.event.stt.SpeechDetectedEvent` when speech is detected. - * :class:`platypush.message.event.stt.SpeechDetectionStartedEvent` when speech detection starts. - * :class:`platypush.message.event.stt.SpeechDetectionStoppedEvent` when speech detection stops. - * :class:`platypush.message.event.stt.HotwordDetectedEvent` when a user-defined hotword is detected. - * :class:`platypush.message.event.stt.ConversationDetectedEvent` when speech is detected after a hotword. - Requires: * **deepspeech** (``pip install 'deepspeech>=0.6.0'``) @@ -37,10 +23,6 @@ class SttDeepspeechPlugin(Plugin): """ - _thread_stop_timeout = 10.0 - rate = 16000 - channels = 1 - def __init__(self, model_file: str, lm_file: str, @@ -48,11 +30,7 @@ class SttDeepspeechPlugin(Plugin): lm_alpha: float = 0.75, lm_beta: float = 1.85, beam_width: int = 500, - input_device: Optional[Union[int, str]] = None, - hotword: Optional[str] = None, - hotwords: Optional[List[str]] = None, - conversation_timeout: Optional[float] = None, - block_duration: float = 1.0): + *args, **kwargs): """ In order to run the speech-to-text engine you'll need to download the right model files for the Deepspeech engine that you have installed: @@ -101,27 +79,15 @@ class SttDeepspeechPlugin(Plugin): :param block_duration: Duration of the acquired audio blocks (default: 1 second). """ - super().__init__() + super().__init__(*args, **kwargs) self.model_file = os.path.abspath(os.path.expanduser(model_file)) self.lm_file = os.path.abspath(os.path.expanduser(lm_file)) self.trie_file = os.path.abspath(os.path.expanduser(trie_file)) self.lm_alpha = lm_alpha self.lm_beta = lm_beta self.beam_width = beam_width - self.input_device = input_device - self.conversation_timeout = conversation_timeout - self.block_duration = block_duration - - self.hotwords = set(hotwords or []) - if hotword: - self.hotwords = {hotword} - - self._conversation_event = threading.Event() self._model: Optional[deepspeech.Model] = None - self._input_stream: Optional[sd.InputStream] = None - self._recording_thread: Optional[threading.Thread] = None - self._detection_thread: Optional[threading.Thread] = None - self._audio_queue: Optional[queue.Queue] = None + self._context = None def _get_model(self) -> deepspeech.Model: if not self._model: @@ -130,125 +96,41 @@ class SttDeepspeechPlugin(Plugin): return self._model - def _detect(self, data: Union[bytes, np.ndarray]) -> str: - data = self._convert_data(data) - model = self._get_model() - return model.stt(data) + def _get_context(self): + if not self._model: + self._model = self._get_model() + if not self._context: + self._context = self._model.createStream() + + return self._context @staticmethod - def _convert_data(data: Union[np.ndarray, bytes]) -> np.ndarray: - return np.frombuffer(data, dtype=np.int16) + def convert_frames(frames: Union[np.ndarray, bytes]) -> np.ndarray: + return np.frombuffer(frames, dtype=np.int16) - def _get_input_device(self, device: Optional[Union[int, str]] = None) -> int: - """ - Get the index of the input device by index or name. + def on_detection_started(self): + self._context = self._get_context() - :param device: Device index or name. If None is set then the function will return the index of the - default audio input device. - :return: Index of the audio input device. - """ - if not device: - device = self.input_device - if not device: - return sd.query_hostapis()[0].get('default_input_device') + def on_detection_ended(self): + if self._model and self._context: + self._model.finishStream() + self._context = None - if isinstance(device, int): - assert device <= len(sd.query_devices()) - return device - - for i, dev in enumerate(sd.query_devices()): - if dev['name'] == device: - return i - - raise AssertionError('Device {} not found'.format(device)) - - def _on_speech_detected(self, speech: str) -> None: - """ - Hook called when speech is detected. Triggers the right event depending on the current context. - - :param speech: Detected speech. - """ - speech = speech.strip() - - if self._conversation_event.is_set(): - event = ConversationDetectedEvent(speech=speech) - elif speech in self.hotwords: - event = HotwordDetectedEvent(hotword=speech) - if self.conversation_timeout: - self._conversation_event.set() - threading.Timer(self.conversation_timeout, lambda: self._conversation_event.clear()).start() - else: - event = SpeechDetectedEvent(speech=speech) - - get_bus().post(event) - - def detection_thread(self) -> None: - """ - Speech detection thread. Reads from the ``audio_queue`` and uses the Deepspeech model to detect - speech real-time. - """ - self.logger.debug('Detection thread started') + def detect_audio(self, frames) -> str: model = self._get_model() - current_text = '' - context = None + context = self._get_context() + model.feedAudioContent(context, frames) + return model.intermediateDecode(context) - while self._audio_queue: - if not context: - context = model.createStream() + def on_speech_detected(self, speech: str) -> None: + super().on_speech_detected(speech) + if not speech: + return - try: - frames = self._audio_queue.get() - frames = self._convert_data(frames) - except Exception as e: - self.logger.warning('Error while feeding audio to the model: {}'.format(str(e))) - continue - - model.feedAudioContent(context, frames) - text = model.intermediateDecode(context) - - if text == current_text: - if current_text: - self._on_speech_detected(current_text) - model.finishStream(context) - context = None - - current_text = '' - else: - if not current_text: - get_bus().post(SpeechStartedEvent()) - - self.logger.info('Intermediate speech results: [{}]'.format(text)) - current_text = text - - self.logger.debug('Detection thread terminated') - - def recording_thread(self, block_duration: float, input_device: Optional[str] = None) -> None: - """ - Recording thread. It reads raw frames from the audio device and dispatches them to ``detection_thread``. - - :param block_duration: Audio blocks duration. - :param input_device: Input device - """ - self.logger.debug('Recording thread started') - device = self._get_input_device(input_device) - blocksize = int(self.rate * self.channels * block_duration) - self._input_stream = sd.InputStream(samplerate=self.rate, device=device, - channels=self.channels, dtype='int16', latency=0, - blocksize=blocksize) - self._input_stream.start() - get_bus().post(SpeechDetectionStartedEvent()) - - while self._input_stream: - try: - frames = self._input_stream.read(self.rate)[0] - except Exception as e: - self.logger.warning('Error while reading from the audio input: {}'.format(str(e))) - continue - - self._audio_queue.put(frames) - - get_bus().post(SpeechDetectionStoppedEvent()) - self.logger.debug('Recording thread terminated') + model = self._get_model() + context = self._get_context() + model.finishStream(context) + self._context = None @action def detect(self, audio_file: str) -> SpeechDetectedResponse: @@ -260,63 +142,10 @@ class SttDeepspeechPlugin(Plugin): audio_file = os.path.abspath(os.path.expanduser(audio_file)) wav = wave.open(audio_file, 'r') buffer = wav.readframes(wav.getnframes()) - speech = self._detect(buffer) + data = self.convert_frames(buffer) + model = self._get_model() + speech = model.stt(data) return SpeechDetectedResponse(speech=speech) - def __enter__(self): - """ - Context manager enter. Starts detection and returns self. - """ - self.start_detection() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context manager exit. Stops detection. - """ - self.stop_detection() - - @action - def start_detection(self, input_device: Optional[str] = None, seconds: Optional[float] = None, - block_duration: Optional[float] = None) -> None: - """ - Start the speech detection engine. - - :param input_device: Audio input device name/index override - :param seconds: If set, then the detection engine will stop after this many seconds, otherwise it'll - start running until ``stop_detection`` is called or application stop. - :param block_duration: ``block_duration`` override. - """ - assert not self._input_stream, 'Speech detection is already running' - block_duration = block_duration or self.block_duration - input_device = input_device if input_device is not None else self.input_device - self._audio_queue = queue.Queue() - self._recording_thread = threading.Thread( - target=lambda: self.recording_thread(block_duration=block_duration, input_device=input_device)) - - self._recording_thread.start() - self._detection_thread = threading.Thread(target=lambda: self.detection_thread()) - self._detection_thread.start() - - if seconds: - threading.Timer(seconds, lambda: self.stop_detection()).start() - - @action - def stop_detection(self) -> None: - """ - Stop the speech detection engine. - """ - assert self._input_stream, 'Speech detection is not running' - self._input_stream.stop(ignore_errors=True) - self._input_stream.close(ignore_errors=True) - self._input_stream = None - - if self._recording_thread: - self._recording_thread.join(timeout=self._thread_stop_timeout) - - self._audio_queue = None - if self._detection_thread: - self._detection_thread.join(timeout=self._thread_stop_timeout) - # vim:sw=4:ts=4:et: