forked from platypush/platypush
The websocket plugin now extends AsyncRunnablePlugin too
This commit is contained in:
parent
770a14daae
commit
2797ffbe53
1 changed files with 42 additions and 16 deletions
|
@ -2,18 +2,20 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from websockets import connect as websocket_connect
|
from typing import Optional, Collection
|
||||||
|
|
||||||
|
from websockets import connect as websocket_connect # type: ignore
|
||||||
from websockets.exceptions import ConnectionClosed
|
from websockets.exceptions import ConnectionClosed
|
||||||
|
|
||||||
from platypush.context import get_or_create_event_loop, get_bus
|
from platypush.context import get_bus
|
||||||
from platypush.message.event.websocket import WebsocketMessageEvent
|
from platypush.message.event.websocket import WebsocketMessageEvent
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import AsyncRunnablePlugin, action
|
||||||
from platypush.utils import get_ssl_client_context
|
from platypush.utils import get_ssl_client_context
|
||||||
|
|
||||||
|
|
||||||
class WebsocketPlugin(Plugin):
|
class WebsocketPlugin(AsyncRunnablePlugin):
|
||||||
"""
|
"""
|
||||||
Plugin to send messages over a websocket connection.
|
Plugin to send and receive messages over websocket connections.
|
||||||
|
|
||||||
Triggers:
|
Triggers:
|
||||||
|
|
||||||
|
@ -22,6 +24,22 @@ class WebsocketPlugin(Plugin):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, subscriptions: Optional[Collection[str]] = None, **kwargs):
|
||||||
|
"""
|
||||||
|
:param subscriptions: List of websocket URLs that should be subscribed
|
||||||
|
at startup, prefixed by ``ws://`` or ``wss://``.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._subscriptions = subscriptions or []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
if not self._loop:
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
|
||||||
|
return self._loop
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
|
@ -52,6 +70,8 @@ class WebsocketPlugin(Plugin):
|
||||||
otherwise nothing.
|
otherwise nothing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
msg = self._parse_msg(msg)
|
||||||
|
|
||||||
async def send():
|
async def send():
|
||||||
websocket_args = {
|
websocket_args = {
|
||||||
'ssl': self._get_ssl_context(
|
'ssl': self._get_ssl_context(
|
||||||
|
@ -70,13 +90,11 @@ class WebsocketPlugin(Plugin):
|
||||||
self.logger.warning('Error on websocket %s: %s', url, err)
|
self.logger.warning('Error on websocket %s: %s', url, err)
|
||||||
|
|
||||||
if wait_response:
|
if wait_response:
|
||||||
messages = await self._ws_recv(ws, num_messages=1)
|
messages = await self._recv(ws, num_messages=1)
|
||||||
if messages:
|
if messages:
|
||||||
return self._parse_msg(messages[0])
|
return self._parse_msg(messages[0])
|
||||||
|
|
||||||
msg = self._parse_msg(msg)
|
return asyncio.run_coroutine_threadsafe(send(), self.loop).result()
|
||||||
loop = get_or_create_event_loop()
|
|
||||||
return loop.run_until_complete(send())
|
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def recv(
|
def recv(
|
||||||
|
@ -123,14 +141,11 @@ class WebsocketPlugin(Plugin):
|
||||||
}
|
}
|
||||||
|
|
||||||
async with websocket_connect(url, **websocket_args) as ws:
|
async with websocket_connect(url, **websocket_args) as ws:
|
||||||
return await self._ws_recv(
|
return await self._recv(ws, timeout=timeout, num_messages=num_messages)
|
||||||
ws, timeout=timeout, num_messages=num_messages
|
|
||||||
)
|
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
return self.loop.call_soon_threadsafe(recv)
|
||||||
return loop.run_until_complete(recv())
|
|
||||||
|
|
||||||
async def _ws_recv(self, ws, timeout=0, num_messages=0):
|
async def _recv(self, ws, timeout=0, num_messages=0):
|
||||||
messages = []
|
messages = []
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
time_end = time_start + timeout if timeout else 0
|
time_end = time_start + timeout if timeout else 0
|
||||||
|
@ -166,6 +181,10 @@ class WebsocketPlugin(Plugin):
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _should_start_runner(self):
|
||||||
|
return bool(self._subscriptions)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_msg(msg):
|
def _parse_msg(msg):
|
||||||
try:
|
try:
|
||||||
|
@ -175,11 +194,18 @@ class WebsocketPlugin(Plugin):
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
async def _recv(url):
|
||||||
|
async with websocket_connect(url) as ws:
|
||||||
|
return await self._recv(ws)
|
||||||
|
|
||||||
|
await asyncio.wait([_recv(url) for url in set(self._subscriptions)])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_ssl_context(
|
def _get_ssl_context(
|
||||||
url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
|
url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
|
||||||
):
|
):
|
||||||
if url.startswith('wss://'):
|
if url.startswith('wss://') or url.startswith('https://'):
|
||||||
return get_ssl_client_context(
|
return get_ssl_client_context(
|
||||||
ssl_cert=ssl_cert,
|
ssl_cert=ssl_cert,
|
||||||
ssl_key=ssl_key,
|
ssl_key=ssl_key,
|
||||||
|
|
Loading…
Reference in a new issue