diff --git a/platypush/plugins/assistant/picovoice/__init__.py b/platypush/plugins/assistant/picovoice/__init__.py index 9426ba60..d275494d 100644 --- a/platypush/plugins/assistant/picovoice/__init__.py +++ b/platypush/plugins/assistant/picovoice/__init__.py @@ -1,7 +1,10 @@ +import os from typing import Optional, Sequence +from platypush.context import get_plugin from platypush.plugins import RunnablePlugin, action from platypush.plugins.assistant import AssistantPlugin +from platypush.plugins.tts.picovoice import TtsPicovoicePlugin from ._assistant import Assistant from ._state import AssistantState @@ -96,7 +99,12 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin): using a language other than English, you can provide the path to the model file for that language. Model files are available for all the supported languages through the `Picovoice repository - `_. + `_. + You can also use the `Picovoice console + `_ + to train your custom models. You can use a base model and fine-tune + it by boosting the detection of your own words and phrases and edit + the phonetic representation of the words you want to detect. :param endpoint_duration: If set, the assistant will stop listening when no speech is detected for the specified duration (in seconds) after the end of an utterance. @@ -146,15 +154,47 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin): 'on_hotword_detected': self._on_hotword_detected, } + @property + def tts(self) -> TtsPicovoicePlugin: + p = get_plugin('tts.picovoice') + assert p, 'Picovoice TTS plugin not configured/found' + return p + + def _get_tts_plugin(self) -> TtsPicovoicePlugin: + return self.tts + + def _on_response_render_start(self, text: Optional[str]): + if self._assistant: + self._assistant.state = AssistantState.RESPONDING + return super()._on_response_render_start(text) + + def _on_response_render_end(self): + if self._assistant: + self._assistant.state = ( + AssistantState.DETECTING_HOTWORD + if self._assistant.hotword_enabled + else AssistantState.IDLE + ) + + return super()._on_response_render_end() + @action - def start_conversation(self, *_, **__): + def start_conversation(self, *_, model_file: Optional[str] = None, **__): """ - Programmatically start a conversation with the assistant + Programmatically start a conversation with the assistant. + + :param model_file: Override the model file to be used to detect speech + in this conversation. If not set, the configured + ``speech_model_path`` will be used. """ if not self._assistant: self.logger.warning('Assistant not initialized') return + if model_file: + model_file = os.path.expanduser(model_file) + + self._assistant.override_speech_model(model_file) self._assistant.state = AssistantState.DETECTING_SPEECH @action @@ -166,6 +206,8 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin): self.logger.warning('Assistant not initialized') return + self._assistant.override_speech_model(None) + if self._assistant.hotword_enabled: self._assistant.state = AssistantState.DETECTING_HOTWORD else: diff --git a/platypush/plugins/assistant/picovoice/_assistant.py b/platypush/plugins/assistant/picovoice/_assistant.py index 761fe981..11da4c88 100644 --- a/platypush/plugins/assistant/picovoice/_assistant.py +++ b/platypush/plugins/assistant/picovoice/_assistant.py @@ -9,11 +9,13 @@ import pvleopard import pvporcupine import pvrhino +from platypush.context import get_plugin from platypush.message.event.assistant import ( ConversationTimeoutEvent, HotwordDetectedEvent, SpeechRecognizedEvent, ) +from platypush.plugins.tts.picovoice import TtsPicovoicePlugin from ._context import ConversationContext from ._recorder import AudioRecorder @@ -25,6 +27,7 @@ class Assistant: A facade class that wraps the Picovoice engines under an assistant API. """ + @staticmethod def _default_callback(*_, **__): pass @@ -61,11 +64,12 @@ class Assistant: self.keyword_paths = None self.keyword_model_path = None self.frame_expiration = frame_expiration - self.speech_model_path = speech_model_path self.endpoint_duration = endpoint_duration self.enable_automatic_punctuation = enable_automatic_punctuation self.start_conversation_on_hotword = start_conversation_on_hotword self.audio_queue_size = audio_queue_size + self._speech_model_path = speech_model_path + self._speech_model_path_override = None self._on_conversation_start = on_conversation_start self._on_conversation_end = on_conversation_end @@ -103,11 +107,22 @@ class Assistant: self.keyword_model_path = keyword_model_path - self._cheetah: Optional[pvcheetah.Cheetah] = None + # Model path -> model instance cache + self._cheetah = {} self._leopard: Optional[pvleopard.Leopard] = None self._porcupine: Optional[pvporcupine.Porcupine] = None self._rhino: Optional[pvrhino.Rhino] = None + @property + def speech_model_path(self): + return self._speech_model_path_override or self._speech_model_path + + @property + def tts(self) -> TtsPicovoicePlugin: + p = get_plugin('tts.picovoice') + assert p, 'Picovoice TTS plugin not configured/found' + return p + def should_stop(self): return self._stop_event.is_set() @@ -130,12 +145,18 @@ class Assistant: return if prev_state == AssistantState.DETECTING_SPEECH: + self.tts.stop() self._ctx.stop() + self._speech_model_path_override = None self._on_conversation_end() elif new_state == AssistantState.DETECTING_SPEECH: self._ctx.start() self._on_conversation_start() + if new_state == AssistantState.DETECTING_HOTWORD: + self.tts.stop() + self._ctx.reset() + @property def porcupine(self) -> Optional[pvporcupine.Porcupine]: if not self.hotword_enabled: @@ -159,7 +180,7 @@ class Assistant: if not self.stt_enabled: return None - if not self._cheetah: + if not self._cheetah.get(self.speech_model_path): args: Dict[str, Any] = {'access_key': self._access_key} if self.speech_model_path: args['model_path'] = self.speech_model_path @@ -168,9 +189,9 @@ class Assistant: if self.enable_automatic_punctuation: args['enable_automatic_punctuation'] = self.enable_automatic_punctuation - self._cheetah = pvcheetah.create(**args) + self._cheetah[self.speech_model_path] = pvcheetah.create(**args) - return self._cheetah + return self._cheetah[self.speech_model_path] def __enter__(self): if self.should_stop(): @@ -178,10 +199,9 @@ class Assistant: if self._recorder: self.logger.info('A recording stream already exists') - elif self.porcupine or self.cheetah: + elif self.hotword_enabled or self.stt_enabled: sample_rate = (self.porcupine or self.cheetah).sample_rate # type: ignore frame_length = (self.porcupine or self.cheetah).frame_length # type: ignore - self._recorder = AudioRecorder( stop_event=self._stop_event, sample_rate=sample_rate, @@ -190,6 +210,9 @@ class Assistant: channels=1, ) + if self.stt_enabled: + self._cheetah[self.speech_model_path] = self.cheetah + self._recorder.__enter__() if self.porcupine: @@ -205,10 +228,10 @@ class Assistant: self._recorder = None self.state = AssistantState.IDLE - - if self._cheetah: - self._cheetah.delete() - self._cheetah = None + for model in [*self._cheetah.keys()]: + cheetah = self._cheetah.pop(model, None) + if cheetah: + cheetah.delete() if self._leopard: self._leopard.delete() @@ -242,10 +265,10 @@ class Assistant: ) continue # The audio frame is too old - if self.porcupine and self.state == AssistantState.DETECTING_HOTWORD: + if self.hotword_enabled and self.state == AssistantState.DETECTING_HOTWORD: return self._process_hotword(frame) - if self.cheetah and self.state == AssistantState.DETECTING_SPEECH: + if self.stt_enabled and self.state == AssistantState.DETECTING_SPEECH: return self._process_speech(frame) raise StopIteration @@ -283,15 +306,12 @@ class Assistant: ) if self._ctx.is_final or self._ctx.timed_out: - phrase = '' - if self.cheetah: - phrase = self.cheetah.flush() - + phrase = self.cheetah.flush() or '' self._ctx.transcript += phrase phrase = self._ctx.transcript phrase = phrase[:1].lower() + phrase[1:] - if self._ctx.is_final or phrase: + if self._ctx.is_final and phrase: event = SpeechRecognizedEvent(phrase=phrase) self._on_speech_recognized(phrase=phrase) else: @@ -304,5 +324,8 @@ class Assistant: return event + def override_speech_model(self, model_path: Optional[str]): + self._speech_model_path_override = model_path + # vim:sw=4:ts=4:et: diff --git a/platypush/plugins/assistant/picovoice/_context.py b/platypush/plugins/assistant/picovoice/_context.py index 1a534073..e3696601 100644 --- a/platypush/plugins/assistant/picovoice/_context.py +++ b/platypush/plugins/assistant/picovoice/_context.py @@ -9,7 +9,7 @@ class ConversationContext: Context of the conversation process. """ - partial_transcript: str = '' + transcript: str = '' is_final: bool = False timeout: Optional[float] = None t_start: Optional[float] = None @@ -24,7 +24,7 @@ class ConversationContext: self.t_end = time() def reset(self): - self.partial_transcript = '' + self.transcript = '' self.is_final = False self.t_start = None self.t_end = None @@ -32,11 +32,17 @@ class ConversationContext: @property def timed_out(self): return ( - not self.partial_transcript + not self.transcript and not self.is_final and self.timeout and self.t_start and time() - self.t_start > self.timeout + ) or ( + self.transcript + and not self.is_final + and self.timeout + and self.t_start + and time() - self.t_start > self.timeout * 2 ) diff --git a/platypush/plugins/assistant/picovoice/_state.py b/platypush/plugins/assistant/picovoice/_state.py index e0eb7e71..22e1ee74 100644 --- a/platypush/plugins/assistant/picovoice/_state.py +++ b/platypush/plugins/assistant/picovoice/_state.py @@ -9,6 +9,7 @@ class AssistantState(Enum): IDLE = 'idle' DETECTING_HOTWORD = 'detecting_hotword' DETECTING_SPEECH = 'detecting_speech' + RESPONDING = 'responding' # vim:sw=4:ts=4:et: