Support for custom TTS engine for rendering assistant response (see #86)

This commit is contained in:
Fabio Manganiello 2020-02-24 20:22:45 +01:00
parent 40a29a8214
commit 87a51b391c
7 changed files with 120 additions and 21 deletions

View file

@ -1,12 +1,26 @@
import threading import threading
from typing import Optional, Dict, Any, Tuple
from platypush.backend import Backend from platypush.backend import Backend
from platypush.context import get_plugin
from platypush.plugins.tts import TtsPlugin
class AssistantBackend(Backend): class AssistantBackend(Backend):
def __init__(self, **kwargs): def __init__(self, tts_plugin: Optional[str] = None, tts_args: Optional[Dict[str, Any]] = None, **kwargs):
"""
Default assistant backend constructor.
:param tts_plugin: If set, and if the assistant returns the processed response as text, then the processed
response will be played through the selected text-to-speech plugin (can be e.g. "``tts``",
"``tts.google``" or any other implementation of :class:`platypush.plugins.tts.TtsPlugin`).
:param tts_args: Extra parameters to pass to the ``say`` method of the selected TTS plugin (e.g.
language, voice or gender).
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self._detection_paused = threading.Event() self._detection_paused = threading.Event()
self.tts_plugin = tts_plugin
self.tts_args = tts_args or {}
def pause_detection(self): def pause_detection(self):
self._detection_paused.set() self._detection_paused.set()
@ -17,5 +31,8 @@ class AssistantBackend(Backend):
def is_detecting(self): def is_detecting(self):
return not self._detection_paused.is_set() return not self._detection_paused.is_set()
def _get_tts_plugin(self) -> Tuple[Optional[TtsPlugin], Dict[str, Any]]:
    """
    Resolve the text-to-speech plugin configured on this backend.

    :return: A ``(plugin, args)`` pair. ``plugin`` is whatever ``get_plugin`` returns for
        ``self.tts_plugin``, or ``None`` when no TTS plugin name is set. ``args`` is always
        ``self.tts_args``.
    """
    plugin = None
    if self.tts_plugin:
        # The plugin is only looked up when a plugin name has been configured.
        plugin = get_plugin(self.tts_plugin)

    return plugin, self.tts_args
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -102,6 +102,11 @@ class AssistantGoogleBackend(AssistantBackend):
elif hasattr(EventType, 'ON_RENDER_RESPONSE') and \ elif hasattr(EventType, 'ON_RENDER_RESPONSE') and \
event.type == EventType.ON_RENDER_RESPONSE: event.type == EventType.ON_RENDER_RESPONSE:
self.bus.post(ResponseEvent(assistant=self, response_text=event.args.get('text'))) self.bus.post(ResponseEvent(assistant=self, response_text=event.args.get('text')))
tts, args = self._get_tts_plugin()
if tts and 'text' in event.args:
self.stop_conversation()
tts.say(text=event.args['text'], **args)
elif hasattr(EventType, 'ON_RESPONDING_STARTED') and \ elif hasattr(EventType, 'ON_RESPONDING_STARTED') and \
event.type == EventType.ON_RESPONDING_STARTED and \ event.type == EventType.ON_RESPONDING_STARTED and \
event.args.get('is_error_response', False) is True: event.args.get('is_error_response', False) is True:
@ -141,6 +146,20 @@ class AssistantGoogleBackend(AssistantBackend):
if self.assistant: if self.assistant:
self.assistant.stop_conversation() self.assistant.stop_conversation()
def set_mic_mute(self, muted):
    """
    Forward a microphone mute/unmute request to the assistant instance.

    If no assistant instance is set, a warning is logged and nothing else happens.

    :param muted: Value passed through to the assistant's ``set_mic_mute``.
    """
    if self.assistant:
        self.assistant.set_mic_mute(muted)
        return

    self.logger.warning('Assistant not running')
def send_text_query(self, query):
    """
    Forward a text query to the assistant instance.

    If no assistant instance is set, a warning is logged and the query is dropped.

    :param query: Value passed through to the assistant's ``send_text_query``.
    """
    if self.assistant:
        self.assistant.send_text_query(query)
        return

    self.logger.warning('Assistant not running')
def run(self): def run(self):
import google.oauth2.credentials import google.oauth2.credentials
from google.assistant.library import Assistant from google.assistant.library import Assistant
@ -148,9 +167,7 @@ class AssistantGoogleBackend(AssistantBackend):
super().run() super().run()
with open(self.credentials_file, 'r') as f: with open(self.credentials_file, 'r') as f:
self.credentials = google.oauth2.credentials.Credentials( self.credentials = google.oauth2.credentials.Credentials(token=None, **json.load(f))
token=None,
**json.load(f))
while not self.should_stop(): while not self.should_stop():
self._has_error = False self._has_error = False

View file

@ -125,6 +125,8 @@ class AssistantSnowboyBackend(AssistantBackend):
'detect_sound': detect_sound, 'detect_sound': detect_sound,
'assistant_plugin': get_plugin(assistant_plugin_name) if assistant_plugin_name else None, 'assistant_plugin': get_plugin(assistant_plugin_name) if assistant_plugin_name else None,
'assistant_language': conf.get('assistant_language'), 'assistant_language': conf.get('assistant_language'),
'tts_plugin': conf.get('tts_plugin'),
'tts_args': conf.get('tts_args', {}),
} }
def hotword_detected(self, hotword): def hotword_detected(self, hotword):
@ -150,12 +152,15 @@ class AssistantSnowboyBackend(AssistantBackend):
detect_sound = model.get('detect_sound') detect_sound = model.get('detect_sound')
assistant_plugin = model.get('assistant_plugin') assistant_plugin = model.get('assistant_plugin')
assistant_language = model.get('assistant_language') assistant_language = model.get('assistant_language')
tts_plugin = model.get('tts_plugin')
tts_args = model.get('tts_args')
if detect_sound: if detect_sound:
threading.Thread(target=sound_thread, args=(detect_sound,)).start() threading.Thread(target=sound_thread, args=(detect_sound,)).start()
if assistant_plugin: if assistant_plugin:
assistant_plugin.start_conversation(language=assistant_language) assistant_plugin.start_conversation(language=assistant_language, tts_plugin=tts_plugin,
tts_args=tts_args)
return callback return callback

View file

@ -10,7 +10,7 @@ class AssistantPlugin(ABC, Plugin):
""" """
@abstractmethod @abstractmethod
def start_conversation(self, *args, language=None, **kwargs): def start_conversation(self, *args, language=None, tts_plugin=None, tts_args=None, **kwargs):
""" """
Start a conversation. Start a conversation.
""" """

View file

@ -2,6 +2,7 @@
.. moduleauthor:: Fabio Manganiello <blacklight86@gmail.com> .. moduleauthor:: Fabio Manganiello <blacklight86@gmail.com>
""" """
from platypush.backend.assistant.google import AssistantGoogleBackend
from platypush.context import get_backend from platypush.context import get_backend
from platypush.plugins import action from platypush.plugins import action
from platypush.plugins.assistant import AssistantPlugin from platypush.plugins.assistant import AssistantPlugin
@ -17,11 +18,11 @@ class AssistantGooglePlugin(AssistantPlugin):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def _get_assistant(self): def _get_assistant(self) -> AssistantGoogleBackend:
return get_backend('assistant.google') return get_backend('assistant.google')
@action @action
def start_conversation(self): def start_conversation(self, **kwargs):
""" """
Programmatically start a conversation with the assistant Programmatically start a conversation with the assistant
""" """
@ -36,5 +37,25 @@ class AssistantGooglePlugin(AssistantPlugin):
assistant = self._get_assistant() assistant = self._get_assistant()
assistant.stop_conversation() assistant.stop_conversation()
@action
def set_mic_mute(self, muted: bool = True):
    """
    Mute or unmute the assistant microphone programmatically.

    The request is delegated to the backend returned by ``_get_assistant``.

    :param muted: True to mute the microphone, False to unmute it (default: True).
    """
    self._get_assistant().set_mic_mute(muted)
@action
def send_text_query(self, query: str):
    """
    Submit a text query to the assistant.

    The request is delegated to the backend returned by ``_get_assistant``.

    :param query: Text of the query to send.
    """
    self._get_assistant().send_text_query(query)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -72,9 +72,7 @@ class SampleAssistant(object):
self.is_new_conversation = True self.is_new_conversation = True
# Create Google Assistant API gRPC client. # Create Google Assistant API gRPC client.
self.assistant = embedded_assistant_pb2_grpc.EmbeddedAssistantStub( self.assistant = embedded_assistant_pb2_grpc.EmbeddedAssistantStub(channel)
channel
)
self.deadline = deadline_sec self.deadline = deadline_sec
self.device_handler = device_handler self.device_handler = device_handler
@ -126,8 +124,7 @@ class SampleAssistant(object):
# This generator yields AssistResponse proto messages # This generator yields AssistResponse proto messages
# received from the gRPC Google Assistant API. # received from the gRPC Google Assistant API.
for resp in self.assistant.Assist(iter_log_assist_requests(), for resp in self.assistant.Assist(iter_log_assist_requests(), self.deadline):
self.deadline):
assistant_helpers.log_assist_response_without_audio(resp) assistant_helpers.log_assist_response_without_audio(resp)
if resp.event_type == END_OF_UTTERANCE: if resp.event_type == END_OF_UTTERANCE:
logging.info('End of audio request detected.') logging.info('End of audio request detected.')
@ -143,6 +140,7 @@ class SampleAssistant(object):
if len(r.transcript.strip())).strip() if len(r.transcript.strip())).strip()
logging.info('Transcript of user request: "%s".', self.detected_speech) logging.info('Transcript of user request: "%s".', self.detected_speech)
if len(resp.audio_out.audio_data) > 0: if len(resp.audio_out.audio_data) > 0:
if not self.conversation_stream.playing: if not self.conversation_stream.playing:
self.conversation_stream.stop_recording() self.conversation_stream.stop_recording()
@ -155,10 +153,12 @@ class SampleAssistant(object):
self.conversation_stream.write(resp.audio_out.audio_data) self.conversation_stream.write(resp.audio_out.audio_data)
elif self.conversation_stream.playing: elif self.conversation_stream.playing:
self.conversation_stream.stop_playback() self.conversation_stream.stop_playback()
if resp.dialog_state_out.conversation_state: if resp.dialog_state_out.conversation_state:
conversation_state = resp.dialog_state_out.conversation_state conversation_state = resp.dialog_state_out.conversation_state
logging.debug('Updating conversation state.') logging.debug('Updating conversation state.')
self.conversation_state = conversation_state self.conversation_state = conversation_state
if resp.dialog_state_out.volume_percentage != 0: if resp.dialog_state_out.volume_percentage != 0:
volume_percentage = resp.dialog_state_out.volume_percentage volume_percentage = resp.dialog_state_out.volume_percentage
logging.info('Setting volume to %s%%', volume_percentage) logging.info('Setting volume to %s%%', volume_percentage)
@ -166,11 +166,13 @@ class SampleAssistant(object):
if self.on_volume_changed: if self.on_volume_changed:
self.on_volume_changed(volume_percentage) self.on_volume_changed(volume_percentage)
if resp.dialog_state_out.microphone_mode == DIALOG_FOLLOW_ON: if resp.dialog_state_out.microphone_mode == DIALOG_FOLLOW_ON:
continue_conversation = True continue_conversation = True
logging.info('Expecting follow-on query from user.') logging.info('Expecting follow-on query from user.')
elif resp.dialog_state_out.microphone_mode == CLOSE_MICROPHONE: elif resp.dialog_state_out.microphone_mode == CLOSE_MICROPHONE:
continue_conversation = False continue_conversation = False
if resp.device_action.device_request_json: if resp.device_action.device_request_json:
device_request = json.loads( device_request = json.loads(
resp.device_action.device_request_json resp.device_action.device_request_json
@ -178,6 +180,7 @@ class SampleAssistant(object):
fs = self.device_handler(device_request) fs = self.device_handler(device_request)
if fs: if fs:
device_actions_futures.extend(fs) device_actions_futures.extend(fs)
if self.display and resp.screen_out.data: if self.display and resp.screen_out.data:
system_browser = browser_helpers.system_browser system_browser = browser_helpers.system_browser
system_browser.display(resp.screen_out.data) system_browser.display(resp.screen_out.data)

View file

@ -4,8 +4,9 @@
import json import json
import os import os
from typing import Optional, Dict, Any
from platypush.context import get_bus from platypush.context import get_bus, get_plugin
from platypush.message.event.assistant import ConversationStartEvent, \ from platypush.message.event.assistant import ConversationStartEvent, \
ConversationEndEvent, SpeechRecognizedEvent, VolumeChangedEvent, \ ConversationEndEvent, SpeechRecognizedEvent, VolumeChangedEvent, \
ResponseEvent ResponseEvent
@ -48,6 +49,8 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
'device_config.json'), 'device_config.json'),
language='en-US', language='en-US',
play_response=True, play_response=True,
tts_plugin=None,
tts_args=None,
**kwargs): **kwargs):
""" """
:param credentials_file: Path to the Google OAuth credentials file :param credentials_file: Path to the Google OAuth credentials file
@ -68,6 +71,12 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
:param play_response: If True (default) then the plugin will play the assistant response upon processed :param play_response: If True (default) then the plugin will play the assistant response upon processed
response. Otherwise nothing will be played - but you may want to handle the ``ResponseEvent`` manually. response. Otherwise nothing will be played - but you may want to handle the ``ResponseEvent`` manually.
:type play_response: bool :type play_response: bool
:param tts_plugin: Optional text-to-speech plugin to be used to process response text.
:type tts_plugin: str
:param tts_args: Optional arguments for the TTS plugin ``say`` method.
:type tts_args: dict
""" """
import googlesamples.assistant.grpc.audio_helpers as audio_helpers import googlesamples.assistant.grpc.audio_helpers as audio_helpers
@ -83,6 +92,8 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
self.credentials_file = credentials_file self.credentials_file = credentials_file
self.device_config = device_config self.device_config = device_config
self.play_response = play_response self.play_response = play_response
self.tts_plugin = tts_plugin
self.tts_args = tts_args or {}
self.assistant = None self.assistant = None
self.interactions = [] self.interactions = []
@ -188,18 +199,26 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
else: else:
self.interactions[-1]['response'] = response self.interactions[-1]['response'] = response
if self.tts_plugin:
tts = get_plugin(self.tts_plugin)
tts.say(response, **self.tts_args)
return handler return handler
@action @action
def start_conversation(self, *args, language=None, **kwargs): def start_conversation(self, *args, language: Optional[str] = None, tts_plugin: Optional[str] = None,
tts_args: Optional[Dict[str, Any]] = None, **kwargs):
""" """
Start a conversation Start a conversation
:param language: Language code override (default: default configured language) :param language: Language code override (default: default configured language).
:type language: str :param tts_plugin: Optional text-to-speech plugin to be used for rendering text.
:param tts_args: Optional arguments for the TTS plugin say method.
:returns: A list of the interactions that happen within the conversation. :returns: A list of the interactions that happen within the conversation.
.. code-block:: json
[ [
{ {
"request": "request 1", "request": "request 1",
@ -212,15 +231,16 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
"response": "response 2" "response": "response 2"
} }
] ]
""" """
from platypush.plugins.assistant.google.lib import SampleAssistant from platypush.plugins.assistant.google.lib import SampleAssistant
if not language: self.tts_plugin = tts_plugin
language = self.language self.tts_args = tts_args
language = language or self.language
play_response = False if self.tts_plugin else self.play_response
self._init_assistant() self._init_assistant()
self.on_conversation_start() self.on_conversation_start()
@ -232,7 +252,7 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
display=None, display=None,
channel=self.grpc_channel, channel=self.grpc_channel,
deadline_sec=self.grpc_deadline, deadline_sec=self.grpc_deadline,
play_response=self.play_response, play_response=play_response,
device_handler=self.device_handler, device_handler=self.device_handler,
on_conversation_start=self.on_conversation_start(), on_conversation_start=self.on_conversation_start(),
on_conversation_end=self.on_conversation_end(), on_conversation_end=self.on_conversation_end(),
@ -262,6 +282,22 @@ class AssistantGooglePushtotalkPlugin(AssistantPlugin):
get_bus().post(ConversationEndEvent(assistant=self)) get_bus().post(ConversationEndEvent(assistant=self))
@action
def set_mic_mute(self, muted: bool = True):
    """
    Mute or unmute the microphone programmatically.

    Muting stops recording on the conversation stream; unmuting starts it again. If there
    is no conversation stream, a warning is logged and nothing else happens.

    :param muted: True to stop recording, False to start recording (default: True).
    """
    if not self.conversation_stream:
        self.logger.warning('The assistant is not running')
        return

    # Pick the stream operation matching the requested state, then invoke it.
    toggle = (self.conversation_stream.stop_recording if muted
              else self.conversation_stream.start_recording)
    toggle()
def _install_device_handlers(self): def _install_device_handlers(self):
import googlesamples.assistant.grpc.device_helpers as device_helpers import googlesamples.assistant.grpc.device_helpers as device_helpers
self.device_handler = device_helpers.DeviceRequestHandler(self.device_id) self.device_handler = device_helpers.DeviceRequestHandler(self.device_id)