The websocket plugin now extends AsyncRunnablePlugin too

This commit is contained in:
Fabio Manganiello 2022-08-15 00:14:52 +02:00
parent 770a14daae
commit 2797ffbe53
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -2,18 +2,20 @@ import asyncio
import json import json
import time import time
from websockets import connect as websocket_connect from typing import Optional, Collection
from websockets import connect as websocket_connect # type: ignore
from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosed
from platypush.context import get_or_create_event_loop, get_bus from platypush.context import get_bus
from platypush.message.event.websocket import WebsocketMessageEvent from platypush.message.event.websocket import WebsocketMessageEvent
from platypush.plugins import Plugin, action from platypush.plugins import AsyncRunnablePlugin, action
from platypush.utils import get_ssl_client_context from platypush.utils import get_ssl_client_context
class WebsocketPlugin(Plugin): class WebsocketPlugin(AsyncRunnablePlugin):
""" """
Plugin to send messages over a websocket connection. Plugin to send and receive messages over websocket connections.
Triggers: Triggers:
@ -22,6 +24,22 @@ class WebsocketPlugin(Plugin):
""" """
def __init__(self, subscriptions: Optional[Collection[str]] = None, **kwargs):
"""
:param subscriptions: List of websocket URLs that should be subscribed
at startup, prefixed by ``ws://`` or ``wss://``.
"""
super().__init__(**kwargs)
self._subscriptions = subscriptions or []
@property
def loop(self):
if not self._loop:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
return self._loop
@action @action
def send( def send(
self, self,
@ -52,6 +70,8 @@ class WebsocketPlugin(Plugin):
otherwise nothing. otherwise nothing.
""" """
msg = self._parse_msg(msg)
async def send(): async def send():
websocket_args = { websocket_args = {
'ssl': self._get_ssl_context( 'ssl': self._get_ssl_context(
@ -70,13 +90,11 @@ class WebsocketPlugin(Plugin):
self.logger.warning('Error on websocket %s: %s', url, err) self.logger.warning('Error on websocket %s: %s', url, err)
if wait_response: if wait_response:
messages = await self._ws_recv(ws, num_messages=1) messages = await self._recv(ws, num_messages=1)
if messages: if messages:
return self._parse_msg(messages[0]) return self._parse_msg(messages[0])
msg = self._parse_msg(msg) return asyncio.run_coroutine_threadsafe(send(), self.loop).result()
loop = get_or_create_event_loop()
return loop.run_until_complete(send())
@action @action
def recv( def recv(
@ -123,14 +141,11 @@ class WebsocketPlugin(Plugin):
} }
async with websocket_connect(url, **websocket_args) as ws: async with websocket_connect(url, **websocket_args) as ws:
return await self._ws_recv( return await self._recv(ws, timeout=timeout, num_messages=num_messages)
ws, timeout=timeout, num_messages=num_messages
)
loop = get_or_create_event_loop() return self.loop.call_soon_threadsafe(recv)
return loop.run_until_complete(recv())
async def _ws_recv(self, ws, timeout=0, num_messages=0): async def _recv(self, ws, timeout=0, num_messages=0):
messages = [] messages = []
time_start = time.time() time_start = time.time()
time_end = time_start + timeout if timeout else 0 time_end = time_start + timeout if timeout else 0
@ -166,6 +181,10 @@ class WebsocketPlugin(Plugin):
return messages return messages
@property
def _should_start_runner(self):
return bool(self._subscriptions)
@staticmethod @staticmethod
def _parse_msg(msg): def _parse_msg(msg):
try: try:
@ -175,11 +194,18 @@ class WebsocketPlugin(Plugin):
return msg return msg
async def listen(self):
async def _recv(url):
async with websocket_connect(url) as ws:
return await self._recv(ws)
await asyncio.wait([_recv(url) for url in set(self._subscriptions)])
@staticmethod @staticmethod
def _get_ssl_context( def _get_ssl_context(
url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
): ):
if url.startswith('wss://'): if url.startswith('wss://') or url.startswith('https://'):
return get_ssl_client_context( return get_ssl_client_context(
ssl_cert=ssl_cert, ssl_cert=ssl_cert,
ssl_key=ssl_key, ssl_key=ssl_key,