2024-04-07 22:42:01 +02:00
|
|
|
import logging
|
|
|
|
import os
|
2024-04-08 01:54:26 +02:00
|
|
|
from threading import Event, RLock
|
2024-04-07 22:42:01 +02:00
|
|
|
from time import time
|
|
|
|
from typing import Any, Dict, Optional, Sequence
|
|
|
|
|
|
|
|
import pvcheetah
|
|
|
|
import pvleopard
|
|
|
|
import pvporcupine
|
|
|
|
import pvrhino
|
|
|
|
|
2024-04-10 22:26:45 +02:00
|
|
|
from platypush.context import get_plugin
|
2024-04-08 01:54:26 +02:00
|
|
|
from platypush.message.event.assistant import (
|
|
|
|
ConversationTimeoutEvent,
|
|
|
|
HotwordDetectedEvent,
|
|
|
|
SpeechRecognizedEvent,
|
|
|
|
)
|
2024-04-10 22:26:45 +02:00
|
|
|
from platypush.plugins.tts.picovoice import TtsPicovoicePlugin
|
2024-04-07 22:42:01 +02:00
|
|
|
|
2024-04-08 03:02:03 +02:00
|
|
|
from ._context import ConversationContext
|
2024-04-07 22:42:01 +02:00
|
|
|
from ._recorder import AudioRecorder
|
2024-04-08 01:54:26 +02:00
|
|
|
from ._state import AssistantState
|
2024-04-07 22:42:01 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Assistant:
    """
    A facade class that wraps the Picovoice engines under an assistant API.

    The object is both a context manager and an iterator: entering the
    context opens the audio recorder and selects the initial state, each
    iteration step consumes one audio frame and returns the event produced
    by it (or ``None``), and leaving the context releases the recorder and
    the engine handles.
    """

    @staticmethod
    def _default_callback(*_, **__):
        """No-op used as the default value of every ``on_*`` callback."""

    @staticmethod
    def _resolve_callback(callback):
        """
        Return a plain callable for a user-supplied (or default) callback.

        The ``on_*`` defaults in :meth:`__init__` capture
        ``_default_callback`` while the class body is still being executed,
        i.e. as a raw ``staticmethod`` object. Those objects are only
        directly callable on Python >= 3.10, so they are unwrapped here.
        ``None`` is accepted as "no callback". Anything else (functions,
        bound methods, other callables) is returned untouched.
        """
        if callback is None:
            return Assistant._default_callback
        if isinstance(callback, staticmethod):
            return callback.__func__
        return callback

    def __init__(
        self,
        access_key: str,
        stop_event: Event,
        hotword_enabled: bool = True,
        stt_enabled: bool = True,
        intent_enabled: bool = False,
        keywords: Optional[Sequence[str]] = None,
        keyword_paths: Optional[Sequence[str]] = None,
        keyword_model_path: Optional[str] = None,
        frame_expiration: float = 3.0,  # Don't process audio frames older than this
        speech_model_path: Optional[str] = None,
        endpoint_duration: Optional[float] = None,
        enable_automatic_punctuation: bool = False,
        start_conversation_on_hotword: bool = False,
        audio_queue_size: int = 100,
        conversation_timeout: Optional[float] = None,
        on_conversation_start=_default_callback,
        on_conversation_end=_default_callback,
        on_conversation_timeout=_default_callback,
        on_speech_recognized=_default_callback,
        on_hotword_detected=_default_callback,
    ):
        """
        :param access_key: Passed as ``access_key`` when creating the
            Porcupine and Cheetah engines.
        :param stop_event: Event that, once set, makes the assistant stop
            iterating. It is also handed to the audio recorder.
        :param hotword_enabled: Enable the hotword engine (Porcupine).
        :param stt_enabled: Enable the speech-to-text engine (Cheetah).
        :param intent_enabled: Stored on the instance; not otherwise used by
            this class.
        :param keywords: Hotword names. Required if ``hotword_enabled`` is
            true; also used to name the detected hotword in events.
        :param keyword_paths: Optional keyword files (``~`` is expanded).
            Only processed when ``hotword_enabled`` is true.
        :param keyword_model_path: Optional Porcupine model file (``~`` is
            expanded). Only processed when ``hotword_enabled`` is true.
        :param frame_expiration: Audio frames older than this many seconds
            are skipped.
        :param speech_model_path: Optional default Cheetah model file.
        :param endpoint_duration: Passed to Cheetah as
            ``endpoint_duration_sec`` when truthy.
        :param enable_automatic_punctuation: Passed to Cheetah when true.
        :param start_conversation_on_hotword: Switch to speech detection as
            soon as a hotword is detected.
        :param audio_queue_size: Passed as ``queue_size`` to the recorder.
        :param conversation_timeout: Passed as ``timeout`` to the
            conversation context.
        :param on_conversation_start: Called with no arguments when the
            state switches to ``DETECTING_SPEECH``.
        :param on_conversation_end: Called with no arguments when the state
            leaves ``DETECTING_SPEECH``.
        :param on_conversation_timeout: Called with no arguments when a
            conversation ends without a recognized phrase.
        :param on_speech_recognized: Called as ``callback(phrase=...)``.
        :param on_hotword_detected: Called as ``callback(hotword=...)``.
        :raises ValueError: If the hotword engine is enabled but no keywords
            were given.
        :raises FileNotFoundError: If a keyword file or the keyword model
            file does not exist.
        """
        self._access_key = access_key
        self._stop_event = stop_event
        self.logger = logging.getLogger(__name__)
        self.hotword_enabled = hotword_enabled
        self.stt_enabled = stt_enabled
        self.intent_enabled = intent_enabled
        self.keywords = list(keywords or [])
        self.keyword_paths = None
        self.keyword_model_path = None
        self._responding = Event()
        self.frame_expiration = frame_expiration
        self.endpoint_duration = endpoint_duration
        self.enable_automatic_punctuation = enable_automatic_punctuation
        self.start_conversation_on_hotword = start_conversation_on_hotword
        self.audio_queue_size = audio_queue_size
        self._speech_model_path = speech_model_path
        self._speech_model_path_override = None

        # Normalize the hooks so that they are always directly callable,
        # including the staticmethod-wrapped defaults on Python < 3.10.
        self._on_conversation_start = self._resolve_callback(on_conversation_start)
        self._on_conversation_end = self._resolve_callback(on_conversation_end)
        self._on_conversation_timeout = self._resolve_callback(
            on_conversation_timeout
        )
        self._on_speech_recognized = self._resolve_callback(on_speech_recognized)
        self._on_hotword_detected = self._resolve_callback(on_hotword_detected)

        self._recorder = None
        self._state = AssistantState.IDLE
        self._state_lock = RLock()
        self._ctx = ConversationContext(timeout=conversation_timeout)

        if hotword_enabled:
            if not keywords:
                raise ValueError(
                    'You need to provide a list of keywords if the wake-word engine is enabled'
                )

            if keyword_paths:
                keyword_paths = [os.path.expanduser(path) for path in keyword_paths]
                missing_paths = [
                    path for path in keyword_paths if not os.path.isfile(path)
                ]
                if missing_paths:
                    raise FileNotFoundError(f'Keyword files not found: {missing_paths}')

                self.keyword_paths = keyword_paths

            if keyword_model_path:
                keyword_model_path = os.path.expanduser(keyword_model_path)
                if not os.path.isfile(keyword_model_path):
                    raise FileNotFoundError(
                        f'Keyword model file not found: {keyword_model_path}'
                    )

                self.keyword_model_path = keyword_model_path

        # Model path -> model instance cache
        self._cheetah = {}
        self._leopard: Optional[pvleopard.Leopard] = None
        self._porcupine: Optional[pvporcupine.Porcupine] = None
        self._rhino: Optional[pvrhino.Rhino] = None

    @property
    def is_responding(self):
        """``True`` while the "responding" flag is set."""
        return self._responding.is_set()

    @property
    def speech_model_path(self):
        """
        The speech model path currently in use: the override set through
        :meth:`override_speech_model` if any, otherwise the configured one.
        """
        return self._speech_model_path_override or self._speech_model_path

    @property
    def tts(self) -> TtsPicovoicePlugin:
        """The ``tts.picovoice`` plugin; asserts that it is available."""
        p = get_plugin('tts.picovoice')
        assert p, 'Picovoice TTS plugin not configured/found'
        return p

    def set_responding(self, responding: bool):
        """Set or clear the "responding" flag."""
        if responding:
            self._responding.set()
        else:
            self._responding.clear()

    def should_stop(self):
        """``True`` if the stop event has been set."""
        return self._stop_event.is_set()

    def wait_stop(self):
        """Block until the stop event is set."""
        self._stop_event.wait()

    @property
    def state(self) -> AssistantState:
        """The current assistant state, read under the state lock."""
        with self._state_lock:
            return self._state

    @state.setter
    def state(self, state: AssistantState):
        """
        Change the state and run the side effects of the transition.

        Nothing happens if the state doesn't change. Leaving
        ``DETECTING_SPEECH`` stops the TTS and the conversation context,
        clears the speech model override and fires ``on_conversation_end``;
        entering it starts the context and fires ``on_conversation_start``.
        Entering ``DETECTING_HOTWORD`` stops the TTS and resets the context.
        """
        # The lock is re-entrant, so the getter can be used while holding it.
        with self._state_lock:
            prev_state = self._state
            self._state = state
            new_state = self.state

        if prev_state == new_state:
            return

        if prev_state == AssistantState.DETECTING_SPEECH:
            self.tts.stop()
            self._ctx.stop()
            self._speech_model_path_override = None
            self._on_conversation_end()
        elif new_state == AssistantState.DETECTING_SPEECH:
            self._ctx.start()
            self._on_conversation_start()

        if new_state == AssistantState.DETECTING_HOTWORD:
            self.tts.stop()
            self._ctx.reset()

    @property
    def porcupine(self) -> Optional[pvporcupine.Porcupine]:
        """
        The lazily created hotword engine, or ``None`` if hotword detection
        is disabled.
        """
        if not self.hotword_enabled:
            return None

        if not self._porcupine:
            args: Dict[str, Any] = {'access_key': self._access_key}
            if self.keywords:
                args['keywords'] = self.keywords
            if self.keyword_paths:
                args['keyword_paths'] = self.keyword_paths
            if self.keyword_model_path:
                args['model_path'] = self.keyword_model_path

            self._porcupine = pvporcupine.create(**args)

        return self._porcupine

    @property
    def cheetah(self) -> Optional[pvcheetah.Cheetah]:
        """
        The speech-to-text engine for the current speech model path, or
        ``None`` if speech-to-text is disabled.

        Engines are created on first use and cached by model path.
        """
        if not self.stt_enabled:
            return None

        # Resolve the model path once, so that lookup, creation and return
        # all refer to the same cache key.
        model_path = self.speech_model_path
        if not self._cheetah.get(model_path):
            args: Dict[str, Any] = {'access_key': self._access_key}
            if model_path:
                args['model_path'] = model_path
            if self.endpoint_duration:
                args['endpoint_duration_sec'] = self.endpoint_duration
            if self.enable_automatic_punctuation:
                args['enable_automatic_punctuation'] = self.enable_automatic_punctuation

            self._cheetah[model_path] = pvcheetah.create(**args)

        return self._cheetah[model_path]

    def __enter__(self):
        """
        Get the assistant ready to start processing audio frames.
        """
        if self.should_stop():
            return self

        if self._recorder:
            self.logger.info('A recording stream already exists')
        elif self.hotword_enabled or self.stt_enabled:
            # The recorder is sized after the hotword engine if enabled,
            # otherwise after the speech engine.
            engine = self.porcupine or self.cheetah
            sample_rate = engine.sample_rate  # type: ignore
            frame_length = engine.frame_length  # type: ignore
            self._recorder = AudioRecorder(
                stop_event=self._stop_event,
                sample_rate=sample_rate,
                frame_size=frame_length,
                queue_size=self.audio_queue_size,
                channels=1,
            )

            if self.stt_enabled:
                # Make sure that the speech engine for the current model is
                # created and cached before any frame is processed.
                self._cheetah[self.speech_model_path] = self.cheetah

            self._recorder.__enter__()

        if self.porcupine:
            self.state = AssistantState.DETECTING_HOTWORD
        else:
            self.state = AssistantState.DETECTING_SPEECH

        return self

    def __exit__(self, *_):
        """
        Stop the assistant and release all resources.
        """
        try:
            if self._recorder:
                self._recorder.__exit__(*_)
                self._recorder = None

            self.state = AssistantState.IDLE
        finally:
            # The state transition runs the TTS plugin and user callbacks,
            # which may raise: the engine handles must be released anyway.
            self._release_engines()

    def _release_engines(self):
        """Delete all the engine instances held by this object."""
        for model in [*self._cheetah.keys()]:
            cheetah = self._cheetah.pop(model, None)
            if cheetah:
                cheetah.delete()

        if self._leopard:
            self._leopard.delete()
            self._leopard = None

        if self._porcupine:
            self._porcupine.delete()
            self._porcupine = None

        if self._rhino:
            self._rhino.delete()
            self._rhino = None

    def __iter__(self):
        """
        Iterate over processed assistant events.
        """
        return self

    def __next__(self):
        """
        Process the next audio frame and return the corresponding event.

        Frames are read until one can be handed to the engine matching the
        current state; the result of that engine (an event or ``None``) is
        returned.

        :raises StopIteration: If the stop event is set or there is no
            recorder.
        """
        if self.should_stop() or not self._recorder:
            raise StopIteration

        while not self.should_stop():
            data = self._recorder.read()
            if data is None:
                continue

            frame, t = data
            if time() - t > self.frame_expiration:
                self.logger.info(
                    'Skipping audio frame older than %ss', self.frame_expiration
                )
                continue  # The audio frame is too old

            if self.hotword_enabled and self.state == AssistantState.DETECTING_HOTWORD:
                return self._process_hotword(frame)

            if self.stt_enabled and self.state == AssistantState.DETECTING_SPEECH:
                return self._process_speech(frame)

        raise StopIteration

    def _process_hotword(self, frame):
        """
        Run the hotword engine over an audio frame.

        :return: A :class:`HotwordDetectedEvent` if a hotword was detected,
            ``None`` otherwise.
        """
        porcupine = self.porcupine
        if not porcupine:
            return None

        keyword_index = porcupine.process(frame)
        if keyword_index is None:
            return None  # No keyword detected

        if keyword_index >= 0 and self.keywords:
            if self.start_conversation_on_hotword:
                self.state = AssistantState.DETECTING_SPEECH

            self.tts.stop()
            hotword = self.keywords[keyword_index]
            self._on_hotword_detected(hotword=hotword)
            return HotwordDetectedEvent(hotword=hotword)

        return None

    def _process_speech(self, frame):
        """
        Run the speech-to-text engine over an audio frame.

        Partial transcripts are accumulated on the conversation context.
        Once the engine reports the end of the phrase, or the conversation
        times out, the engine is flushed and the conversation is closed.

        :return: A :class:`SpeechRecognizedEvent` if a non-empty phrase was
            recognized, a :class:`ConversationTimeoutEvent` if the
            conversation ended with no text, ``None`` if the conversation is
            still in progress.
        """
        cheetah = self.cheetah
        if not cheetah:
            return None

        event = None
        partial_transcript, self._ctx.is_final = cheetah.process(frame)

        if partial_transcript:
            self._ctx.transcript += partial_transcript
            self.logger.info(
                'Partial transcript: %s, is_final: %s',
                self._ctx.transcript,
                self._ctx.is_final,
            )

        if self._ctx.is_final or self._ctx.timed_out:
            self._ctx.transcript += cheetah.flush() or ''
            phrase = self._ctx.transcript
            # Lower-case the first character of the recognized phrase
            phrase = phrase[:1].lower() + phrase[1:]

            if phrase:
                event = SpeechRecognizedEvent(phrase=phrase)
                self._on_speech_recognized(phrase=phrase)
            else:
                event = ConversationTimeoutEvent()
                self._on_conversation_timeout()

            self._ctx.reset()
            if self.hotword_enabled:
                self.state = AssistantState.DETECTING_HOTWORD

        return event

    def override_speech_model(self, model_path: Optional[str]):
        """
        Use a different speech model path until the override is cleared.

        The override is cleared when the state leaves ``DETECTING_SPEECH``.
        """
        self._speech_model_path_override = model_path
|
|
|
|
|
2024-04-07 22:42:01 +02:00
|
|
|
|
|
|
|
# vim:sw=4:ts=4:et:
|