forked from platypush/platypush
220 lines
7.2 KiB
Python
220 lines
7.2 KiB
Python
import asyncio
|
|
import json
|
|
import time
|
|
|
|
from typing import Optional, Collection
|
|
|
|
from websockets import connect as websocket_connect # type: ignore
|
|
from websockets.exceptions import ConnectionClosed
|
|
|
|
from platypush.context import get_bus
|
|
from platypush.message.event.websocket import WebsocketMessageEvent
|
|
from platypush.plugins import AsyncRunnablePlugin, action
|
|
from platypush.utils import get_ssl_client_context
|
|
|
|
|
|
class WebsocketPlugin(AsyncRunnablePlugin):
    """
    Plugin to send and receive messages over websocket connections.

    Triggers:

        * :class:`platypush.message.event.websocket.WebsocketMessageEvent` when
          a message is received on a subscribed websocket.

    """

    def __init__(self, subscriptions: Optional[Collection[str]] = None, **kwargs):
        """
        :param subscriptions: List of websocket URLs that should be subscribed
            at startup, prefixed by ``ws://`` or ``wss://``.
        """
        super().__init__(**kwargs)
        self._subscriptions = subscriptions or []

    @property
    def loop(self):
        """
        The event loop used to run the websocket coroutines.

        If no loop has been set yet, a new one is created and also set as the
        current event loop of the calling thread.
        """
        # NOTE(review): coroutines submitted through run_coroutine_threadsafe
        # only complete if this loop is actually being run somewhere
        # (presumably by the AsyncRunnablePlugin runner) - confirm.
        if not self._loop:
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        return self._loop

    @action
    def send(
        self,
        url: str,
        msg,
        ssl_cert=None,
        ssl_key=None,
        ssl_cafile=None,
        ssl_capath=None,
        wait_response=False,
    ):
        """
        Sends a message to a websocket.

        :param url: Websocket URL, e.g. ws://localhost:8765 or wss://localhost:8765
        :param msg: Message to be sent. It can be a list, a dict, or a Message
            object. Strings and bytes are sent as they are.
        :param ssl_cert: Path to the SSL certificate to be used, if the SSL
            connection requires client authentication as well (default: None)
        :param ssl_key: Path to the SSL key to be used, if the SSL connection
            requires client authentication as well (default: None)
        :param ssl_cafile: Path to the certificate authority file if required
            by the SSL configuration (default: None)
        :param ssl_capath: Path to the certificate authority directory if
            required by the SSL configuration (default: None)
        :param wait_response: Set to True if you expect a response to the
            delivered message.
        :return: The received response (JSON-decoded when possible) if
            ``wait_response`` is set to True, otherwise nothing.
        """
        payload = self._serialize_msg(msg)

        async def _send():
            async with self._connect(
                url,
                ssl_cert=ssl_cert,
                ssl_key=ssl_key,
                ssl_cafile=ssl_cafile,
                ssl_capath=ssl_capath,
            ) as ws:
                try:
                    await ws.send(payload)
                except ConnectionClosed as err:
                    self.logger.warning('Error on websocket %s: %s', url, err)
                    # The connection is gone: there is nothing to wait for.
                    return None

                if not wait_response:
                    return None

                # _recv already decodes the messages: don't parse them twice.
                messages = await self._recv(ws, num_messages=1, url=url)
                return messages[0] if messages else None

        return asyncio.run_coroutine_threadsafe(_send(), self.loop).result()

    @action
    def recv(
        self,
        url: str,
        ssl_cert=None,
        ssl_key=None,
        ssl_cafile=None,
        ssl_capath=None,
        num_messages=0,
        timeout=0,
    ):
        """
        Receive one or more messages from a websocket.

        A :class:`platypush.message.event.websocket.WebsocketMessageEvent`
        event will be triggered whenever a new message is received.

        :param url: Websocket URL, e.g. ws://localhost:8765 or wss://localhost:8765
        :param ssl_cert: Path to the SSL certificate to be used, if the SSL
            connection requires client authentication as well (default: None)
        :param ssl_key: Path to the SSL key to be used, if the SSL connection
            requires client authentication as well (default: None)
        :param ssl_cafile: Path to the certificate authority file if required
            by the SSL configuration (default: None)
        :param ssl_capath: Path to the certificate authority directory if
            required by the SSL configuration (default: None)
        :param num_messages: Exit after receiving this number of messages.
            Default: 0, receive forever.
        :param timeout: Message receive timeout in seconds. Default: 0 - no timeout.
        :return: A list with the messages that have been received. If neither
            ``num_messages`` nor ``timeout`` is set, the receiver keeps running
            in the background (messages are only delivered as events) and
            nothing is returned.
        """

        async def _receive():
            async with self._connect(
                url,
                ssl_cert=ssl_cert,
                ssl_key=ssl_key,
                ssl_cafile=ssl_cafile,
                ssl_capath=ssl_capath,
            ) as ws:
                return await self._recv(
                    ws, timeout=timeout, num_messages=num_messages, url=url
                )

        # The coroutine must be *submitted* to the loop as a coroutine object;
        # scheduling the coroutine function as a plain callback never runs it.
        future = asyncio.run_coroutine_threadsafe(_receive(), self.loop)

        if not (num_messages or timeout):
            # Unbounded receive: blocking here would tie up the caller until
            # the remote end closes the connection, so run in the background
            # and make sure failures are at least logged.
            def _on_done(fut):
                if not fut.cancelled() and fut.exception():
                    self.logger.warning(
                        'Error on websocket %s: %s', url, fut.exception()
                    )

            future.add_done_callback(_on_done)
            return None

        return future.result()

    def _connect(
        self, url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
    ):
        """
        Open a websocket connection to ``url``.

        :return: The async context manager returned by ``websockets.connect``,
            configured with an SSL client context when ``url`` is a secure one.
        """
        return websocket_connect(
            url,
            ssl=self._get_ssl_context(
                url,
                ssl_cert=ssl_cert,
                ssl_key=ssl_key,
                ssl_cafile=ssl_cafile,
                ssl_capath=ssl_capath,
            ),
        )

    @staticmethod
    def _get_ws_url(ws) -> str:
        """
        Best-effort reconstruction of the URL of an open websocket connection.

        Only used when the caller did not provide the URL. It relies on
        attributes of the websockets client protocol object, and the host is
        the remote peer address rather than the originally requested name.
        """
        return 'ws{secure}://{host}:{port}{path}'.format(
            secure='s' if getattr(ws, '_secure', False) else '',
            host=ws.remote_address[0],
            port=ws.remote_address[1],
            path=getattr(ws, 'path', ''),
        )

    async def _recv(self, ws, timeout=0, num_messages=0, url=None):
        """
        Receive messages from an open websocket connection.

        Each message is decoded through :meth:`_parse_msg` and posted on the
        bus as a :class:`platypush.message.event.websocket.WebsocketMessageEvent`.

        :param ws: An open websocket connection.
        :param timeout: Overall receive timeout in seconds, counted from the
            moment this method is called. 0 or None means no timeout.
        :param num_messages: Stop after this number of messages. 0 or None
            means receive until the connection is closed or the timeout expires.
        :param url: URL reported in the generated events. If not provided, it
            is reconstructed from the connection object.
        :return: The list of received (decoded) messages.
        """
        if not url:
            url = self._get_ws_url(ws)

        messages = []
        # Monotonic clock: the deadline must not move if the wall clock does.
        deadline = time.monotonic() + timeout if timeout else None

        while not num_messages or len(messages) < num_messages:
            remaining = None
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break

            try:
                msg = await asyncio.wait_for(ws.recv(), remaining)
            except asyncio.TimeoutError:
                # An expired timeout is the expected way to stop a bounded receive.
                self.logger.debug('Receive timeout on websocket %s', url)
                break
            except ConnectionClosed as e:
                self.logger.warning('Error on websocket %s: %s', url, e)
                break

            msg = self._parse_msg(msg)
            messages.append(msg)
            get_bus().post(WebsocketMessageEvent(url=url, message=msg))

        return messages

    @property
    def _should_start_runner(self):
        """The background listener is only needed if there are subscriptions."""
        return bool(self._subscriptions)

    @staticmethod
    def _serialize_msg(msg):
        """
        Convert an outgoing message into a websocket payload.

        ``str`` and ``bytes`` are returned unchanged (they are sent as text
        and binary frames respectively), JSON-serializable values (dicts,
        lists, numbers...) are JSON-encoded, and anything else is converted
        with ``str()``.
        """
        if isinstance(msg, (str, bytes)):
            return msg

        try:
            return json.dumps(msg)
        except (TypeError, ValueError):
            return str(msg)

    @staticmethod
    def _parse_msg(msg):
        """
        Decode a received message.

        JSON payloads are returned as the corresponding Python objects; any
        other payload (plain text, binary data) is returned unchanged.
        """
        try:
            return json.loads(msg)
        except (TypeError, ValueError):
            return msg

    async def listen(self):
        """
        Connect to all the configured subscriptions and forward the messages
        received on them as events, until all the connections are closed.
        """

        async def _subscribe(url: str):
            async with websocket_connect(url) as ws:
                return await self._recv(ws, url=url)

        # Deduplicate the URLs while preserving their configured order.
        urls = list(dict.fromkeys(self._subscriptions))
        # asyncio.wait() rejects bare coroutines on Python 3.11+. gather()
        # wraps them in tasks, and return_exceptions=True makes it wait for
        # every subscription even if one of them fails.
        results = await asyncio.gather(
            *(_subscribe(url) for url in urls), return_exceptions=True
        )

        for url, result in zip(urls, results):
            if isinstance(result, BaseException):
                self.logger.warning('Error on websocket %s: %s', url, result)

    @staticmethod
    def _get_ssl_context(
        url: str, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
    ):
        """
        Build an SSL client context for secure (``wss://`` / ``https://``)
        URLs, or return None for plain-text ones.
        """
        if url.startswith(('wss://', 'https://')):
            return get_ssl_client_context(
                ssl_cert=ssl_cert,
                ssl_key=ssl_key,
                ssl_cafile=ssl_cafile,
                ssl_capath=ssl_capath,
            )

        return None
|
|
|
|
|
|
# vim:sw=4:ts=4:et:
|