[assistant.picovoice] More features.

- Added wiring between `assistant.picovoice` and `tts.picovoice`.

- Added `RESPONDING` status to the assistant.

- Added ability to override the default speech model upon
  `start_conversation`.

- Better handling of conversation timeouts.

- Cache Cheetah objects in a `model -> object` map - at least the
  default model should be pre-loaded, since loading a model at runtime
  seems to take a while, and that could impact the ability to detect
  speech in the first seconds after a hotword is detected.
This commit is contained in:
Fabio Manganiello 2024-04-10 22:26:45 +02:00
parent af875c996e
commit aa333db05c
4 changed files with 96 additions and 24 deletions

View file

@ -1,7 +1,10 @@
import os
from typing import Optional, Sequence from typing import Optional, Sequence
from platypush.context import get_plugin
from platypush.plugins import RunnablePlugin, action from platypush.plugins import RunnablePlugin, action
from platypush.plugins.assistant import AssistantPlugin from platypush.plugins.assistant import AssistantPlugin
from platypush.plugins.tts.picovoice import TtsPicovoicePlugin
from ._assistant import Assistant from ._assistant import Assistant
from ._state import AssistantState from ._state import AssistantState
@ -96,7 +99,12 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
using a language other than English, you can provide the path to the using a language other than English, you can provide the path to the
model file for that language. Model files are available for all the model file for that language. Model files are available for all the
supported languages through the `Picovoice repository supported languages through the `Picovoice repository
<https://github.com/Picovoice/porcupine/tree/master/lib/common>`_. <https://github.com/Picovoice/cheetah/tree/master/lib/common>`_.
You can also use the `Picovoice console
<https://console.picovoice.ai/cat>`_
to train your custom models. You can use a base model and fine-tune
it by boosting the detection of your own words and phrases and editing
the phonetic representation of the words you want to detect.
:param endpoint_duration: If set, the assistant will stop listening when :param endpoint_duration: If set, the assistant will stop listening when
no speech is detected for the specified duration (in seconds) after no speech is detected for the specified duration (in seconds) after
the end of an utterance. the end of an utterance.
@ -146,15 +154,47 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
'on_hotword_detected': self._on_hotword_detected, 'on_hotword_detected': self._on_hotword_detected,
} }
@property
def tts(self) -> TtsPicovoicePlugin:
p = get_plugin('tts.picovoice')
assert p, 'Picovoice TTS plugin not configured/found'
return p
def _get_tts_plugin(self) -> TtsPicovoicePlugin:
return self.tts
def _on_response_render_start(self, text: Optional[str]):
if self._assistant:
self._assistant.state = AssistantState.RESPONDING
return super()._on_response_render_start(text)
def _on_response_render_end(self):
if self._assistant:
self._assistant.state = (
AssistantState.DETECTING_HOTWORD
if self._assistant.hotword_enabled
else AssistantState.IDLE
)
return super()._on_response_render_end()
@action @action
def start_conversation(self, *_, **__): def start_conversation(self, *_, model_file: Optional[str] = None, **__):
""" """
Programmatically start a conversation with the assistant Programmatically start a conversation with the assistant.
:param model_file: Override the model file to be used to detect speech
in this conversation. If not set, the configured
``speech_model_path`` will be used.
""" """
if not self._assistant: if not self._assistant:
self.logger.warning('Assistant not initialized') self.logger.warning('Assistant not initialized')
return return
if model_file:
model_file = os.path.expanduser(model_file)
self._assistant.override_speech_model(model_file)
self._assistant.state = AssistantState.DETECTING_SPEECH self._assistant.state = AssistantState.DETECTING_SPEECH
@action @action
@ -166,6 +206,8 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
self.logger.warning('Assistant not initialized') self.logger.warning('Assistant not initialized')
return return
self._assistant.override_speech_model(None)
if self._assistant.hotword_enabled: if self._assistant.hotword_enabled:
self._assistant.state = AssistantState.DETECTING_HOTWORD self._assistant.state = AssistantState.DETECTING_HOTWORD
else: else:

View file

@ -9,11 +9,13 @@ import pvleopard
import pvporcupine import pvporcupine
import pvrhino import pvrhino
from platypush.context import get_plugin
from platypush.message.event.assistant import ( from platypush.message.event.assistant import (
ConversationTimeoutEvent, ConversationTimeoutEvent,
HotwordDetectedEvent, HotwordDetectedEvent,
SpeechRecognizedEvent, SpeechRecognizedEvent,
) )
from platypush.plugins.tts.picovoice import TtsPicovoicePlugin
from ._context import ConversationContext from ._context import ConversationContext
from ._recorder import AudioRecorder from ._recorder import AudioRecorder
@ -25,6 +27,7 @@ class Assistant:
A facade class that wraps the Picovoice engines under an assistant API. A facade class that wraps the Picovoice engines under an assistant API.
""" """
@staticmethod
def _default_callback(*_, **__): def _default_callback(*_, **__):
pass pass
@ -61,11 +64,12 @@ class Assistant:
self.keyword_paths = None self.keyword_paths = None
self.keyword_model_path = None self.keyword_model_path = None
self.frame_expiration = frame_expiration self.frame_expiration = frame_expiration
self.speech_model_path = speech_model_path
self.endpoint_duration = endpoint_duration self.endpoint_duration = endpoint_duration
self.enable_automatic_punctuation = enable_automatic_punctuation self.enable_automatic_punctuation = enable_automatic_punctuation
self.start_conversation_on_hotword = start_conversation_on_hotword self.start_conversation_on_hotword = start_conversation_on_hotword
self.audio_queue_size = audio_queue_size self.audio_queue_size = audio_queue_size
self._speech_model_path = speech_model_path
self._speech_model_path_override = None
self._on_conversation_start = on_conversation_start self._on_conversation_start = on_conversation_start
self._on_conversation_end = on_conversation_end self._on_conversation_end = on_conversation_end
@ -103,11 +107,22 @@ class Assistant:
self.keyword_model_path = keyword_model_path self.keyword_model_path = keyword_model_path
self._cheetah: Optional[pvcheetah.Cheetah] = None # Model path -> model instance cache
self._cheetah = {}
self._leopard: Optional[pvleopard.Leopard] = None self._leopard: Optional[pvleopard.Leopard] = None
self._porcupine: Optional[pvporcupine.Porcupine] = None self._porcupine: Optional[pvporcupine.Porcupine] = None
self._rhino: Optional[pvrhino.Rhino] = None self._rhino: Optional[pvrhino.Rhino] = None
@property
def speech_model_path(self):
return self._speech_model_path_override or self._speech_model_path
@property
def tts(self) -> TtsPicovoicePlugin:
p = get_plugin('tts.picovoice')
assert p, 'Picovoice TTS plugin not configured/found'
return p
def should_stop(self): def should_stop(self):
return self._stop_event.is_set() return self._stop_event.is_set()
@ -130,12 +145,18 @@ class Assistant:
return return
if prev_state == AssistantState.DETECTING_SPEECH: if prev_state == AssistantState.DETECTING_SPEECH:
self.tts.stop()
self._ctx.stop() self._ctx.stop()
self._speech_model_path_override = None
self._on_conversation_end() self._on_conversation_end()
elif new_state == AssistantState.DETECTING_SPEECH: elif new_state == AssistantState.DETECTING_SPEECH:
self._ctx.start() self._ctx.start()
self._on_conversation_start() self._on_conversation_start()
if new_state == AssistantState.DETECTING_HOTWORD:
self.tts.stop()
self._ctx.reset()
@property @property
def porcupine(self) -> Optional[pvporcupine.Porcupine]: def porcupine(self) -> Optional[pvporcupine.Porcupine]:
if not self.hotword_enabled: if not self.hotword_enabled:
@ -159,7 +180,7 @@ class Assistant:
if not self.stt_enabled: if not self.stt_enabled:
return None return None
if not self._cheetah: if not self._cheetah.get(self.speech_model_path):
args: Dict[str, Any] = {'access_key': self._access_key} args: Dict[str, Any] = {'access_key': self._access_key}
if self.speech_model_path: if self.speech_model_path:
args['model_path'] = self.speech_model_path args['model_path'] = self.speech_model_path
@ -168,9 +189,9 @@ class Assistant:
if self.enable_automatic_punctuation: if self.enable_automatic_punctuation:
args['enable_automatic_punctuation'] = self.enable_automatic_punctuation args['enable_automatic_punctuation'] = self.enable_automatic_punctuation
self._cheetah = pvcheetah.create(**args) self._cheetah[self.speech_model_path] = pvcheetah.create(**args)
return self._cheetah return self._cheetah[self.speech_model_path]
def __enter__(self): def __enter__(self):
if self.should_stop(): if self.should_stop():
@ -178,10 +199,9 @@ class Assistant:
if self._recorder: if self._recorder:
self.logger.info('A recording stream already exists') self.logger.info('A recording stream already exists')
elif self.porcupine or self.cheetah: elif self.hotword_enabled or self.stt_enabled:
sample_rate = (self.porcupine or self.cheetah).sample_rate # type: ignore sample_rate = (self.porcupine or self.cheetah).sample_rate # type: ignore
frame_length = (self.porcupine or self.cheetah).frame_length # type: ignore frame_length = (self.porcupine or self.cheetah).frame_length # type: ignore
self._recorder = AudioRecorder( self._recorder = AudioRecorder(
stop_event=self._stop_event, stop_event=self._stop_event,
sample_rate=sample_rate, sample_rate=sample_rate,
@ -190,6 +210,9 @@ class Assistant:
channels=1, channels=1,
) )
if self.stt_enabled:
self._cheetah[self.speech_model_path] = self.cheetah
self._recorder.__enter__() self._recorder.__enter__()
if self.porcupine: if self.porcupine:
@ -205,10 +228,10 @@ class Assistant:
self._recorder = None self._recorder = None
self.state = AssistantState.IDLE self.state = AssistantState.IDLE
for model in [*self._cheetah.keys()]:
if self._cheetah: cheetah = self._cheetah.pop(model, None)
self._cheetah.delete() if cheetah:
self._cheetah = None cheetah.delete()
if self._leopard: if self._leopard:
self._leopard.delete() self._leopard.delete()
@ -242,10 +265,10 @@ class Assistant:
) )
continue # The audio frame is too old continue # The audio frame is too old
if self.porcupine and self.state == AssistantState.DETECTING_HOTWORD: if self.hotword_enabled and self.state == AssistantState.DETECTING_HOTWORD:
return self._process_hotword(frame) return self._process_hotword(frame)
if self.cheetah and self.state == AssistantState.DETECTING_SPEECH: if self.stt_enabled and self.state == AssistantState.DETECTING_SPEECH:
return self._process_speech(frame) return self._process_speech(frame)
raise StopIteration raise StopIteration
@ -283,15 +306,12 @@ class Assistant:
) )
if self._ctx.is_final or self._ctx.timed_out: if self._ctx.is_final or self._ctx.timed_out:
phrase = '' phrase = self.cheetah.flush() or ''
if self.cheetah:
phrase = self.cheetah.flush()
self._ctx.transcript += phrase self._ctx.transcript += phrase
phrase = self._ctx.transcript phrase = self._ctx.transcript
phrase = phrase[:1].lower() + phrase[1:] phrase = phrase[:1].lower() + phrase[1:]
if self._ctx.is_final or phrase: if self._ctx.is_final and phrase:
event = SpeechRecognizedEvent(phrase=phrase) event = SpeechRecognizedEvent(phrase=phrase)
self._on_speech_recognized(phrase=phrase) self._on_speech_recognized(phrase=phrase)
else: else:
@ -304,5 +324,8 @@ class Assistant:
return event return event
def override_speech_model(self, model_path: Optional[str]):
self._speech_model_path_override = model_path
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -9,7 +9,7 @@ class ConversationContext:
Context of the conversation process. Context of the conversation process.
""" """
partial_transcript: str = '' transcript: str = ''
is_final: bool = False is_final: bool = False
timeout: Optional[float] = None timeout: Optional[float] = None
t_start: Optional[float] = None t_start: Optional[float] = None
@ -24,7 +24,7 @@ class ConversationContext:
self.t_end = time() self.t_end = time()
def reset(self): def reset(self):
self.partial_transcript = '' self.transcript = ''
self.is_final = False self.is_final = False
self.t_start = None self.t_start = None
self.t_end = None self.t_end = None
@ -32,11 +32,17 @@ class ConversationContext:
@property @property
def timed_out(self): def timed_out(self):
return ( return (
not self.partial_transcript not self.transcript
and not self.is_final and not self.is_final
and self.timeout and self.timeout
and self.t_start and self.t_start
and time() - self.t_start > self.timeout and time() - self.t_start > self.timeout
) or (
self.transcript
and not self.is_final
and self.timeout
and self.t_start
and time() - self.t_start > self.timeout * 2
) )

View file

@ -9,6 +9,7 @@ class AssistantState(Enum):
IDLE = 'idle' IDLE = 'idle'
DETECTING_HOTWORD = 'detecting_hotword' DETECTING_HOTWORD = 'detecting_hotword'
DETECTING_SPEECH = 'detecting_speech' DETECTING_SPEECH = 'detecting_speech'
RESPONDING = 'responding'
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et: