Refactored the new websocket routes.

Defined a `platypush.backend.http.ws` package with all the routes, a
base `WSRoute` class that all the websocket routes can extend, and a
logic in the HTTP backend to automatically scan the package to register
exposed websocket routes.
This commit is contained in:
Fabio Manganiello 2023-05-08 11:45:14 +02:00
parent 56dc8d0972
commit f5fcccb0bd
Signed by: blacklight
GPG key ID: D90FBA7F76362774
6 changed files with 179 additions and 77 deletions

View file

@ -12,7 +12,8 @@ from tornado.ioloop import IOLoop
from platypush.backend import Backend from platypush.backend import Backend
from platypush.backend.http.app import application from platypush.backend.http.app import application
from platypush.backend.http.ws import WSEventProxy, events_redis_topic from platypush.backend.http.ws import scan_routes
from platypush.backend.http.ws.events import events_redis_topic
from platypush.bus.redis import RedisBus from platypush.bus.redis import RedisBus
from platypush.config import Config from platypush.config import Config
@ -263,7 +264,7 @@ class HttpBackend(Backend):
container = WSGIContainer(application) container = WSGIContainer(application)
server = Application( server = Application(
[ [
(r'/ws/events', WSEventProxy), *[(route.path(), route) for route in scan_routes()],
(r'.*', FallbackHandler, {'fallback': container}), (r'.*', FallbackHandler, {'fallback': container}),
] ]
) )

View file

@ -1,75 +0,0 @@
from logging import getLogger
from threading import Thread
from typing_extensions import override
from redis import ConnectionError
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.config import Config
from platypush.message.event import Event
from platypush.utils import get_redis
events_redis_topic = f'_platypush/{Config.get("device_id")}/events' # type: ignore
logger = getLogger(__name__)
class WSEventProxy(WebSocketHandler, Thread):
"""
Websocket event proxy mapped to ``/ws/events``.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._sub = get_redis().pubsub()
self._io_loop = IOLoop.current()
@override
def open(self, *_, **__):
logger.info('Started websocket connection with %s', self.request.remote_ip)
self.name = f'ws:events@{self.request.remote_ip}'
self.start()
@override
def on_message(self, *_, **__):
pass
@override
def data_received(self, *_, **__):
pass
@override
def run(self) -> None:
super().run()
self._sub.subscribe(events_redis_topic)
try:
for msg in self._sub.listen():
if (
msg.get('type') != 'message'
and msg.get('channel').decode() != events_redis_topic
):
continue
try:
evt = Event.build(msg.get('data').decode())
except Exception as e:
logger.warning('Error parsing event: %s: %s', msg.get('data'), e)
continue
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
self.write_message, str(evt)
)
except ConnectionError:
pass
@override
def on_close(self):
self._sub.unsubscribe(events_redis_topic)
self._sub.close()
logger.info(
'Websocket connection to %s closed, reason=%s, message=%s',
self.request.remote_ip,
self.close_code,
self.close_reason,
)

View file

@ -0,0 +1,4 @@
from ._base import WSRoute, logger, pubsub_redis_topic
from ._scanner import scan_routes
__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic', 'scan_routes']

View file

@ -0,0 +1,104 @@
from abc import ABC, abstractclassmethod
from logging import getLogger
from threading import RLock, Thread
from typing import Any, Generator, Iterable, Optional, Union
from typing_extensions import override
from redis import ConnectionError as RedisConnectionError
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.config import Config
from platypush.utils import get_redis
logger = getLogger(__name__)
def pubsub_redis_topic(topic: str) -> str:
return f'_platypush/{Config.get("device_id")}/{topic}' # type: ignore
class WSRoute(WebSocketHandler, Thread, ABC):
"""
Base class for Tornado websocket endpoints.
"""
def __init__(self, *args, redis_topics: Optional[Iterable[str]] = None, **kwargs):
super().__init__(*args, **kwargs)
self._redis_topics = set(redis_topics or [])
self._sub = get_redis().pubsub()
self._io_loop = IOLoop.current()
self._sub_lock = RLock()
@override
def open(self, *_, **__):
logger.info('Started websocket connection with %s', self.request.remote_ip)
self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
self.start()
@override
def data_received(self, *_, **__):
pass
@override
def on_message(self, *_, **__):
pass
@abstractclassmethod
def app_name(cls) -> str:
raise NotImplementedError()
@classmethod
def path(cls) -> str:
return f'/ws/{cls.app_name()}'
def subscribe(self, *topics: str) -> None:
with self._sub_lock:
for topic in topics:
self._sub.subscribe(topic)
self._redis_topics.add(topic)
def unsubscribe(self, *topics: str) -> None:
with self._sub_lock:
for topic in topics:
if topic in self._redis_topics:
self._sub.unsubscribe(topic)
self._redis_topics.remove(topic)
def listen(self) -> Generator[Any, None, None]:
try:
for msg in self._sub.listen():
if (
msg.get('type') != 'message'
and msg.get('channel').decode() not in self._redis_topics
):
continue
yield msg.get('data')
except RedisConnectionError:
return
def send(self, msg: Union[str, bytes]) -> None:
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
self.write_message, msg
)
@override
def run(self) -> None:
super().run()
for topic in self._redis_topics:
self._sub.subscribe(topic)
@override
def on_close(self):
topics = self._redis_topics.copy()
for topic in topics:
self.unsubscribe(topic)
self._sub.close()
logger.info(
'Websocket connection to %s closed, reason=%s, message=%s',
self.request.remote_ip,
self.close_code,
self.close_reason,
)

View file

@ -0,0 +1,35 @@
import os
import importlib
import inspect
from typing import List, Type
import pkgutil
from ._base import WSRoute, logger
def scan_routes() -> List[Type[WSRoute]]:
"""
Scans for websocket route objects.
"""
base_dir = os.path.dirname(__file__)
routes = []
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=__package__ + '.'):
try:
module = importlib.import_module(mod_name)
except Exception as e:
logger.warning('Could not import module %s', mod_name)
logger.exception(e)
continue
for _, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and not inspect.isabstract(obj)
and issubclass(obj, WSRoute)
):
routes.append(obj)
return routes

View file

@ -0,0 +1,33 @@
from typing_extensions import override
from platypush.message.event import Event
from . import WSRoute, logger, pubsub_redis_topic
events_redis_topic = pubsub_redis_topic('events')
class WSEventProxy(WSRoute):
"""
Websocket event proxy mapped to ``/ws/events``.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.subscribe(events_redis_topic)
@classmethod
@override
def app_name(cls) -> str:
return 'events'
@override
def run(self) -> None:
for msg in self.listen():
try:
evt = Event.build(msg.decode())
except Exception as e:
logger.warning('Error parsing event: %s: %s', msg, e)
continue
self.send(str(evt))