diff --git a/platypush/backend/http/__init__.py b/platypush/backend/http/__init__.py index 922ad7818b..782bec3e80 100644 --- a/platypush/backend/http/__init__.py +++ b/platypush/backend/http/__init__.py @@ -12,7 +12,8 @@ from tornado.ioloop import IOLoop from platypush.backend import Backend from platypush.backend.http.app import application -from platypush.backend.http.ws import WSEventProxy, events_redis_topic +from platypush.backend.http.ws import scan_routes +from platypush.backend.http.ws.events import events_redis_topic from platypush.bus.redis import RedisBus from platypush.config import Config @@ -263,7 +264,7 @@ class HttpBackend(Backend): container = WSGIContainer(application) server = Application( [ - (r'/ws/events', WSEventProxy), + *[(route.path(), route) for route in scan_routes()], (r'.*', FallbackHandler, {'fallback': container}), ] ) diff --git a/platypush/backend/http/ws.py b/platypush/backend/http/ws.py deleted file mode 100644 index 2dcbee8a7d..0000000000 --- a/platypush/backend/http/ws.py +++ /dev/null @@ -1,75 +0,0 @@ -from logging import getLogger -from threading import Thread -from typing_extensions import override - -from redis import ConnectionError -from tornado.ioloop import IOLoop -from tornado.websocket import WebSocketHandler - -from platypush.config import Config -from platypush.message.event import Event -from platypush.utils import get_redis - -events_redis_topic = f'_platypush/{Config.get("device_id")}/events' # type: ignore -logger = getLogger(__name__) - - -class WSEventProxy(WebSocketHandler, Thread): - """ - Websocket event proxy mapped to ``/ws/events``. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._sub = get_redis().pubsub() - self._io_loop = IOLoop.current() - - @override - def open(self, *_, **__): - logger.info('Started websocket connection with %s', self.request.remote_ip) - self.name = f'ws:events@{self.request.remote_ip}' - self.start() - - @override - def on_message(self, *_, **__): - pass - - @override - def data_received(self, *_, **__): - pass - - @override - def run(self) -> None: - super().run() - self._sub.subscribe(events_redis_topic) - - try: - for msg in self._sub.listen(): - if ( - msg.get('type') != 'message' - and msg.get('channel').decode() != events_redis_topic - ): - continue - - try: - evt = Event.build(msg.get('data').decode()) - except Exception as e: - logger.warning('Error parsing event: %s: %s', msg.get('data'), e) - continue - - self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore - self.write_message, str(evt) - ) - except ConnectionError: - pass - - @override - def on_close(self): - self._sub.unsubscribe(events_redis_topic) - self._sub.close() - logger.info( - 'Websocket connection to %s closed, reason=%s, message=%s', - self.request.remote_ip, - self.close_code, - self.close_reason, - ) diff --git a/platypush/backend/http/ws/__init__.py b/platypush/backend/http/ws/__init__.py new file mode 100644 index 0000000000..6b117c9c66 --- /dev/null +++ b/platypush/backend/http/ws/__init__.py @@ -0,0 +1,4 @@ +from ._base import WSRoute, logger, pubsub_redis_topic +from ._scanner import scan_routes + +__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic', 'scan_routes'] diff --git a/platypush/backend/http/ws/_base.py b/platypush/backend/http/ws/_base.py new file mode 100644 index 0000000000..903513d72a --- /dev/null +++ b/platypush/backend/http/ws/_base.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractclassmethod +from logging import getLogger +from threading import RLock, Thread +from typing import Any, Generator, Iterable, Optional, Union +from typing_extensions import override + +from redis import ConnectionError as RedisConnectionError +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler + +from platypush.config import Config +from platypush.utils import get_redis + +logger = getLogger(__name__) + + +def pubsub_redis_topic(topic: str) -> str: + return f'_platypush/{Config.get("device_id")}/{topic}' # type: ignore + + +class WSRoute(WebSocketHandler, Thread, ABC): + """ + Base class for Tornado websocket endpoints. + """ + + def __init__(self, *args, redis_topics: Optional[Iterable[str]] = None, **kwargs): + super().__init__(*args, **kwargs) + self._redis_topics = set(redis_topics or []) + self._sub = get_redis().pubsub() + self._io_loop = IOLoop.current() + self._sub_lock = RLock() + + @override + def open(self, *_, **__): + logger.info('Started websocket connection with %s', self.request.remote_ip) + self.name = f'ws:{self.app_name()}@{self.request.remote_ip}' + self.start() + + @override + def data_received(self, *_, **__): + pass + + @override + def on_message(self, *_, **__): + pass + + @abstractclassmethod + def app_name(cls) -> str: + raise NotImplementedError() + + @classmethod + def path(cls) -> str: + return f'/ws/{cls.app_name()}' + + def subscribe(self, *topics: str) -> None: + with self._sub_lock: + for topic in topics: + self._sub.subscribe(topic) + self._redis_topics.add(topic) + + def unsubscribe(self, *topics: str) -> None: + with self._sub_lock: + for topic in topics: + if topic in self._redis_topics: + self._sub.unsubscribe(topic) + self._redis_topics.remove(topic) + + def listen(self) -> Generator[Any, None, None]: + try: + for msg in self._sub.listen(): + if ( + msg.get('type') != 'message' + and msg.get('channel').decode() not in self._redis_topics + ): + continue + + yield msg.get('data') + except RedisConnectionError: + return + + def send(self, msg: Union[str, bytes]) -> None: + self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore + self.write_message, msg + ) + + @override + def run(self) -> None: + super().run() + for topic in self._redis_topics: + self._sub.subscribe(topic) + + @override + def on_close(self): + topics = self._redis_topics.copy() + for topic in topics: + self.unsubscribe(topic) + + self._sub.close() + logger.info( + 'Websocket connection to %s closed, reason=%s, message=%s', + self.request.remote_ip, + self.close_code, + self.close_reason, + ) diff --git a/platypush/backend/http/ws/_scanner.py b/platypush/backend/http/ws/_scanner.py new file mode 100644 index 0000000000..a42131bc8b --- /dev/null +++ b/platypush/backend/http/ws/_scanner.py @@ -0,0 +1,35 @@ +import os +import importlib +import inspect +from typing import List, Type + +import pkgutil + +from ._base import WSRoute, logger + + +def scan_routes() -> List[Type[WSRoute]]: + """ + Scans for websocket route objects. + """ + + base_dir = os.path.dirname(__file__) + routes = [] + + for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=__package__ + '.'): + try: + module = importlib.import_module(mod_name) + except Exception as e: + logger.warning('Could not import module %s', mod_name) + logger.exception(e) + continue + + for _, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and not inspect.isabstract(obj) + and issubclass(obj, WSRoute) + ): + routes.append(obj) + + return routes diff --git a/platypush/backend/http/ws/events.py b/platypush/backend/http/ws/events.py new file mode 100644 index 0000000000..42cd75b084 --- /dev/null +++ b/platypush/backend/http/ws/events.py @@ -0,0 +1,33 @@ +from typing_extensions import override + +from platypush.message.event import Event + +from . import WSRoute, logger, pubsub_redis_topic + +events_redis_topic = pubsub_redis_topic('events') + + +class WSEventProxy(WSRoute): + """ + Websocket event proxy mapped to ``/ws/events``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.subscribe(events_redis_topic) + + @classmethod + @override + def app_name(cls) -> str: + return 'events' + + @override + def run(self) -> None: + for msg in self.listen(): + try: + evt = Event.build(msg.decode()) + except Exception as e: + logger.warning('Error parsing event: %s: %s', msg, e) + continue + + self.send(str(evt))