platypush/platypush/plugins/assistant/picovoice/_speech/_stt.py

from typing import Callable, Optional, Sequence, Union

import pvcheetah

from platypush.message.event.assistant import (
    ConversationTimeoutEvent,
    SpeechRecognizedEvent,
)

from ._base import BaseProcessor


class SttProcessor(BaseProcessor):
    """
    Implementation of the speech-to-text processor using the Picovoice Cheetah
    engine.
    """

    def __init__(
        self, *args, get_cheetah_args: Callable[[], dict] = lambda: {}, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self._get_cheetah_args = get_cheetah_args
        # model_path -> Cheetah instance cache
        self._cheetah = {
            self._model_path: pvcheetah.create(**self._get_cheetah_args())
        }
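        # Note: the exact keys are supplied by the plugin configuration; a
        # typical payload (an illustrative sketch, not exhaustive) may look
        # like {'access_key': '...', 'model_path': '/path/to/model.pv'},
        # matching the arguments accepted by ``pvcheetah.create``.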

    @property
    def _model_path(self) -> Optional[str]:
        return self._get_cheetah_args().get('model_path')

    @property
    def sample_rate(self) -> int:
        return self._get_cheetah().sample_rate

    @property
    def frame_length(self) -> int:
        return self._get_cheetah().frame_length

    def _get_cheetah(self) -> pvcheetah.Cheetah:
        # Lazily create and cache one Cheetah instance per model path
        with self._state_lock:
            if not self._cheetah.get(self._model_path):
                self.logger.debug(
                    'Creating Cheetah instance for model %s', self._model_path
                )
                self._cheetah[self._model_path] = pvcheetah.create(
                    **self._get_cheetah_args()
                )
                self.logger.debug(
                    'Cheetah instance created for model %s', self._model_path
                )

            return self._cheetah[self._model_path]

    def process(
        self, audio: Sequence[int]
    ) -> Optional[Union[SpeechRecognizedEvent, ConversationTimeoutEvent]]:
        event = None
        cheetah = self._get_cheetah()
        partial_transcript, self._ctx.is_final = cheetah.process(audio)
        last_transcript = self._ctx.transcript

        # Concatenate the partial transcript to the context
        if partial_transcript:
            self._ctx.transcript += partial_transcript
            self.logger.info(
                'Partial transcript: %s, is_final: %s',
                self._ctx.transcript,
                self._ctx.is_final,
            )

        # If the transcript is final or the conversation timed out, then
        # process and return whatever is available in the context
        if self._ctx.is_final or self._ctx.timed_out:
            # Flush any audio still buffered by the engine
            phrase = cheetah.flush() or ''
            self._ctx.transcript += phrase

            if self._ctx.transcript and self._ctx.transcript != last_transcript:
                self.logger.debug('Processed STT transcript: %s', self._ctx.transcript)
                last_transcript = self._ctx.transcript

            # Normalize the phrase: lowercase the first letter and strip
            # surrounding whitespace
            phrase = self._ctx.transcript
            phrase = (phrase[:1].lower() + phrase[1:]).strip()
            event = (
                SpeechRecognizedEvent(phrase=phrase)
                if phrase
                else ConversationTimeoutEvent()
            )

            self.reset()

        return event

    def reset(self):
        if not self._enabled:
            return

        with self._state_lock:
            super().reset()
            # Discard any partial state still held by the engine
            self._get_cheetah().flush()

    def stop(self):
        if not self._enabled:
            return

        super().stop()

        with self._state_lock:
            # Release the native resources held by all cached Cheetah
            # instances and empty the cache
            objs = self._cheetah.copy()
            for key, obj in objs.items():
                obj.delete()
                self._cheetah.pop(key)
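

# Usage sketch (illustrative only, not part of the module): how a caller might
# drive the processor with a microphone loop. ``recorder`` is a hypothetical
# frame source (e.g. a ``pvrecorder.PvRecorder`` configured with this
# processor's ``frame_length``), and any extra ``SttProcessor`` constructor
# arguments required by ``BaseProcessor`` are elided here.
#
#   stt = SttProcessor(get_cheetah_args=lambda: {'access_key': '...'})
#   for frame in recorder:  # each frame: ``frame_length`` 16-bit PCM samples
#       evt = stt.process(frame)
#       if isinstance(evt, SpeechRecognizedEvent):
#           print('Recognized:', evt.args['phrase'])
#       elif isinstance(evt, ConversationTimeoutEvent):
#           break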