forked from platypush/platypush
[assistant.picovoice] More features.
- Added wiring between `assistant.picovoice` and `tts.picovoice`. - Added `RESPONDING` status to the assistant. - Added ability to override the default speech model upon `start_conversation`. - Better handling of conversation timeouts. - Cache Cheetah objects in a `model -> object` map - at least the default model should be pre-loaded, since model loading at runtime seems to take a while, and that could impact the ability to detect the speech in the first seconds after a hotword is detected.
This commit is contained in:
parent
af875c996e
commit
aa333db05c
4 changed files with 96 additions and 24 deletions
|
@ -1,7 +1,10 @@
|
|||
import os
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from platypush.context import get_plugin
|
||||
from platypush.plugins import RunnablePlugin, action
|
||||
from platypush.plugins.assistant import AssistantPlugin
|
||||
from platypush.plugins.tts.picovoice import TtsPicovoicePlugin
|
||||
|
||||
from ._assistant import Assistant
|
||||
from ._state import AssistantState
|
||||
|
@ -96,7 +99,12 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
|
|||
using a language other than English, you can provide the path to the
|
||||
model file for that language. Model files are available for all the
|
||||
supported languages through the `Picovoice repository
|
||||
<https://github.com/Picovoice/porcupine/tree/master/lib/common>`_.
|
||||
<https://github.com/Picovoice/cheetah/tree/master/lib/common>`_.
|
||||
You can also use the `Picovoice console
|
||||
<https://console.picovoice.ai/cat>`_
|
||||
to train your custom models. You can use a base model and fine-tune
|
||||
it by boosting the detection of your own words and phrases and edit
|
||||
the phonetic representation of the words you want to detect.
|
||||
:param endpoint_duration: If set, the assistant will stop listening when
|
||||
no speech is detected for the specified duration (in seconds) after
|
||||
the end of an utterance.
|
||||
|
@ -146,15 +154,47 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
|
|||
'on_hotword_detected': self._on_hotword_detected,
|
||||
}
|
||||
|
||||
@property
|
||||
def tts(self) -> TtsPicovoicePlugin:
|
||||
p = get_plugin('tts.picovoice')
|
||||
assert p, 'Picovoice TTS plugin not configured/found'
|
||||
return p
|
||||
|
||||
def _get_tts_plugin(self) -> TtsPicovoicePlugin:
|
||||
return self.tts
|
||||
|
||||
def _on_response_render_start(self, text: Optional[str]):
|
||||
if self._assistant:
|
||||
self._assistant.state = AssistantState.RESPONDING
|
||||
return super()._on_response_render_start(text)
|
||||
|
||||
def _on_response_render_end(self):
|
||||
if self._assistant:
|
||||
self._assistant.state = (
|
||||
AssistantState.DETECTING_HOTWORD
|
||||
if self._assistant.hotword_enabled
|
||||
else AssistantState.IDLE
|
||||
)
|
||||
|
||||
return super()._on_response_render_end()
|
||||
|
||||
@action
|
||||
def start_conversation(self, *_, **__):
|
||||
def start_conversation(self, *_, model_file: Optional[str] = None, **__):
|
||||
"""
|
||||
Programmatically start a conversation with the assistant
|
||||
Programmatically start a conversation with the assistant.
|
||||
|
||||
:param model_file: Override the model file to be used to detect speech
|
||||
in this conversation. If not set, the configured
|
||||
``speech_model_path`` will be used.
|
||||
"""
|
||||
if not self._assistant:
|
||||
self.logger.warning('Assistant not initialized')
|
||||
return
|
||||
|
||||
if model_file:
|
||||
model_file = os.path.expanduser(model_file)
|
||||
|
||||
self._assistant.override_speech_model(model_file)
|
||||
self._assistant.state = AssistantState.DETECTING_SPEECH
|
||||
|
||||
@action
|
||||
|
@ -166,6 +206,8 @@ class AssistantPicovoicePlugin(AssistantPlugin, RunnablePlugin):
|
|||
self.logger.warning('Assistant not initialized')
|
||||
return
|
||||
|
||||
self._assistant.override_speech_model(None)
|
||||
|
||||
if self._assistant.hotword_enabled:
|
||||
self._assistant.state = AssistantState.DETECTING_HOTWORD
|
||||
else:
|
||||
|
|
|
@ -9,11 +9,13 @@ import pvleopard
|
|||
import pvporcupine
|
||||
import pvrhino
|
||||
|
||||
from platypush.context import get_plugin
|
||||
from platypush.message.event.assistant import (
|
||||
ConversationTimeoutEvent,
|
||||
HotwordDetectedEvent,
|
||||
SpeechRecognizedEvent,
|
||||
)
|
||||
from platypush.plugins.tts.picovoice import TtsPicovoicePlugin
|
||||
|
||||
from ._context import ConversationContext
|
||||
from ._recorder import AudioRecorder
|
||||
|
@ -25,6 +27,7 @@ class Assistant:
|
|||
A facade class that wraps the Picovoice engines under an assistant API.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _default_callback(*_, **__):
|
||||
pass
|
||||
|
||||
|
@ -61,11 +64,12 @@ class Assistant:
|
|||
self.keyword_paths = None
|
||||
self.keyword_model_path = None
|
||||
self.frame_expiration = frame_expiration
|
||||
self.speech_model_path = speech_model_path
|
||||
self.endpoint_duration = endpoint_duration
|
||||
self.enable_automatic_punctuation = enable_automatic_punctuation
|
||||
self.start_conversation_on_hotword = start_conversation_on_hotword
|
||||
self.audio_queue_size = audio_queue_size
|
||||
self._speech_model_path = speech_model_path
|
||||
self._speech_model_path_override = None
|
||||
|
||||
self._on_conversation_start = on_conversation_start
|
||||
self._on_conversation_end = on_conversation_end
|
||||
|
@ -103,11 +107,22 @@ class Assistant:
|
|||
|
||||
self.keyword_model_path = keyword_model_path
|
||||
|
||||
self._cheetah: Optional[pvcheetah.Cheetah] = None
|
||||
# Model path -> model instance cache
|
||||
self._cheetah = {}
|
||||
self._leopard: Optional[pvleopard.Leopard] = None
|
||||
self._porcupine: Optional[pvporcupine.Porcupine] = None
|
||||
self._rhino: Optional[pvrhino.Rhino] = None
|
||||
|
||||
@property
|
||||
def speech_model_path(self):
|
||||
return self._speech_model_path_override or self._speech_model_path
|
||||
|
||||
@property
|
||||
def tts(self) -> TtsPicovoicePlugin:
|
||||
p = get_plugin('tts.picovoice')
|
||||
assert p, 'Picovoice TTS plugin not configured/found'
|
||||
return p
|
||||
|
||||
def should_stop(self):
|
||||
return self._stop_event.is_set()
|
||||
|
||||
|
@ -130,12 +145,18 @@ class Assistant:
|
|||
return
|
||||
|
||||
if prev_state == AssistantState.DETECTING_SPEECH:
|
||||
self.tts.stop()
|
||||
self._ctx.stop()
|
||||
self._speech_model_path_override = None
|
||||
self._on_conversation_end()
|
||||
elif new_state == AssistantState.DETECTING_SPEECH:
|
||||
self._ctx.start()
|
||||
self._on_conversation_start()
|
||||
|
||||
if new_state == AssistantState.DETECTING_HOTWORD:
|
||||
self.tts.stop()
|
||||
self._ctx.reset()
|
||||
|
||||
@property
|
||||
def porcupine(self) -> Optional[pvporcupine.Porcupine]:
|
||||
if not self.hotword_enabled:
|
||||
|
@ -159,7 +180,7 @@ class Assistant:
|
|||
if not self.stt_enabled:
|
||||
return None
|
||||
|
||||
if not self._cheetah:
|
||||
if not self._cheetah.get(self.speech_model_path):
|
||||
args: Dict[str, Any] = {'access_key': self._access_key}
|
||||
if self.speech_model_path:
|
||||
args['model_path'] = self.speech_model_path
|
||||
|
@ -168,9 +189,9 @@ class Assistant:
|
|||
if self.enable_automatic_punctuation:
|
||||
args['enable_automatic_punctuation'] = self.enable_automatic_punctuation
|
||||
|
||||
self._cheetah = pvcheetah.create(**args)
|
||||
self._cheetah[self.speech_model_path] = pvcheetah.create(**args)
|
||||
|
||||
return self._cheetah
|
||||
return self._cheetah[self.speech_model_path]
|
||||
|
||||
def __enter__(self):
|
||||
if self.should_stop():
|
||||
|
@ -178,10 +199,9 @@ class Assistant:
|
|||
|
||||
if self._recorder:
|
||||
self.logger.info('A recording stream already exists')
|
||||
elif self.porcupine or self.cheetah:
|
||||
elif self.hotword_enabled or self.stt_enabled:
|
||||
sample_rate = (self.porcupine or self.cheetah).sample_rate # type: ignore
|
||||
frame_length = (self.porcupine or self.cheetah).frame_length # type: ignore
|
||||
|
||||
self._recorder = AudioRecorder(
|
||||
stop_event=self._stop_event,
|
||||
sample_rate=sample_rate,
|
||||
|
@ -190,6 +210,9 @@ class Assistant:
|
|||
channels=1,
|
||||
)
|
||||
|
||||
if self.stt_enabled:
|
||||
self._cheetah[self.speech_model_path] = self.cheetah
|
||||
|
||||
self._recorder.__enter__()
|
||||
|
||||
if self.porcupine:
|
||||
|
@ -205,10 +228,10 @@ class Assistant:
|
|||
self._recorder = None
|
||||
|
||||
self.state = AssistantState.IDLE
|
||||
|
||||
if self._cheetah:
|
||||
self._cheetah.delete()
|
||||
self._cheetah = None
|
||||
for model in [*self._cheetah.keys()]:
|
||||
cheetah = self._cheetah.pop(model, None)
|
||||
if cheetah:
|
||||
cheetah.delete()
|
||||
|
||||
if self._leopard:
|
||||
self._leopard.delete()
|
||||
|
@ -242,10 +265,10 @@ class Assistant:
|
|||
)
|
||||
continue # The audio frame is too old
|
||||
|
||||
if self.porcupine and self.state == AssistantState.DETECTING_HOTWORD:
|
||||
if self.hotword_enabled and self.state == AssistantState.DETECTING_HOTWORD:
|
||||
return self._process_hotword(frame)
|
||||
|
||||
if self.cheetah and self.state == AssistantState.DETECTING_SPEECH:
|
||||
if self.stt_enabled and self.state == AssistantState.DETECTING_SPEECH:
|
||||
return self._process_speech(frame)
|
||||
|
||||
raise StopIteration
|
||||
|
@ -283,15 +306,12 @@ class Assistant:
|
|||
)
|
||||
|
||||
if self._ctx.is_final or self._ctx.timed_out:
|
||||
phrase = ''
|
||||
if self.cheetah:
|
||||
phrase = self.cheetah.flush()
|
||||
|
||||
phrase = self.cheetah.flush() or ''
|
||||
self._ctx.transcript += phrase
|
||||
phrase = self._ctx.transcript
|
||||
phrase = phrase[:1].lower() + phrase[1:]
|
||||
|
||||
if self._ctx.is_final or phrase:
|
||||
if self._ctx.is_final and phrase:
|
||||
event = SpeechRecognizedEvent(phrase=phrase)
|
||||
self._on_speech_recognized(phrase=phrase)
|
||||
else:
|
||||
|
@ -304,5 +324,8 @@ class Assistant:
|
|||
|
||||
return event
|
||||
|
||||
def override_speech_model(self, model_path: Optional[str]):
|
||||
self._speech_model_path_override = model_path
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -9,7 +9,7 @@ class ConversationContext:
|
|||
Context of the conversation process.
|
||||
"""
|
||||
|
||||
partial_transcript: str = ''
|
||||
transcript: str = ''
|
||||
is_final: bool = False
|
||||
timeout: Optional[float] = None
|
||||
t_start: Optional[float] = None
|
||||
|
@ -24,7 +24,7 @@ class ConversationContext:
|
|||
self.t_end = time()
|
||||
|
||||
def reset(self):
|
||||
self.partial_transcript = ''
|
||||
self.transcript = ''
|
||||
self.is_final = False
|
||||
self.t_start = None
|
||||
self.t_end = None
|
||||
|
@ -32,11 +32,17 @@ class ConversationContext:
|
|||
@property
|
||||
def timed_out(self):
|
||||
return (
|
||||
not self.partial_transcript
|
||||
not self.transcript
|
||||
and not self.is_final
|
||||
and self.timeout
|
||||
and self.t_start
|
||||
and time() - self.t_start > self.timeout
|
||||
) or (
|
||||
self.transcript
|
||||
and not self.is_final
|
||||
and self.timeout
|
||||
and self.t_start
|
||||
and time() - self.t_start > self.timeout * 2
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ class AssistantState(Enum):
|
|||
IDLE = 'idle'
|
||||
DETECTING_HOTWORD = 'detecting_hotword'
|
||||
DETECTING_SPEECH = 'detecting_speech'
|
||||
RESPONDING = 'responding'
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Reference in a new issue