diff --git a/platypush/plugins/websocket/__init__.py b/platypush/plugins/websocket/__init__.py index 8999b124..5744c075 100644 --- a/platypush/plugins/websocket/__init__.py +++ b/platypush/plugins/websocket/__init__.py @@ -2,18 +2,20 @@ import asyncio import json import time -from websockets import connect as websocket_connect +from typing import Optional, Collection + +from websockets import connect as websocket_connect # type: ignore from websockets.exceptions import ConnectionClosed -from platypush.context import get_or_create_event_loop, get_bus +from platypush.context import get_bus from platypush.message.event.websocket import WebsocketMessageEvent -from platypush.plugins import Plugin, action +from platypush.plugins import AsyncRunnablePlugin, action from platypush.utils import get_ssl_client_context -class WebsocketPlugin(Plugin): +class WebsocketPlugin(AsyncRunnablePlugin): """ - Plugin to send messages over a websocket connection. + Plugin to send and receive messages over websocket connections. Triggers: @@ -22,6 +24,22 @@ class WebsocketPlugin(Plugin): """ + def __init__(self, subscriptions: Optional[Collection[str]] = None, **kwargs): + """ + :param subscriptions: List of websocket URLs that should be subscribed + at startup, prefixed by ``ws://`` or ``wss://``. + """ + super().__init__(**kwargs) + self._subscriptions = subscriptions or [] + + @property + def loop(self): + if not self._loop: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + return self._loop + @action def send( self, @@ -52,6 +70,8 @@ class WebsocketPlugin(Plugin): otherwise nothing. """ + msg = self._parse_msg(msg) + async def send(): websocket_args = { 'ssl': self._get_ssl_context( @@ -70,13 +90,11 @@ class WebsocketPlugin(Plugin): self.logger.warning('Error on websocket %s: %s', url, err) if wait_response: - messages = await self._ws_recv(ws, num_messages=1) + messages = await self._recv(ws, num_messages=1) if messages: return self._parse_msg(messages[0]) - msg = self._parse_msg(msg) - loop = get_or_create_event_loop() - return loop.run_until_complete(send()) + return asyncio.run_coroutine_threadsafe(send(), self.loop).result() @action def recv( @@ -123,14 +141,11 @@ class WebsocketPlugin(Plugin): } async with websocket_connect(url, **websocket_args) as ws: - return await self._ws_recv( - ws, timeout=timeout, num_messages=num_messages - ) + return await self._recv(ws, timeout=timeout, num_messages=num_messages) - loop = get_or_create_event_loop() - return loop.run_until_complete(recv()) + return self.loop.call_soon_threadsafe(recv) - async def _ws_recv(self, ws, timeout=0, num_messages=0): + async def _recv(self, ws, timeout=0, num_messages=0): messages = [] time_start = time.time() time_end = time_start + timeout if timeout else 0 @@ -166,6 +181,10 @@ class WebsocketPlugin(Plugin): return messages + @property + def _should_start_runner(self): + return bool(self._subscriptions) + @staticmethod def _parse_msg(msg): try: @@ -175,11 +194,18 @@ class WebsocketPlugin(Plugin): return msg + async def listen(self): + async def _recv(url): + async with websocket_connect(url) as ws: + return await self._recv(ws) + + await asyncio.wait([_recv(url) for url in set(self._subscriptions)]) + @staticmethod def _get_ssl_context( url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None ): - if url.startswith('wss://'): + if url.startswith('wss://') or url.startswith('https://'): return get_ssl_client_context( ssl_cert=ssl_cert, ssl_key=ssl_key,