# Forked from platypush/platypush
# 106 lines · 3.0 KiB · Python
from abc import ABC, abstractclassmethod
|
|
from logging import getLogger
|
|
from threading import RLock, Thread
|
|
from typing import Any, Generator, Iterable, Optional, Union
|
|
from typing_extensions import override
|
|
|
|
from redis import ConnectionError as RedisConnectionError
|
|
from tornado.ioloop import IOLoop
|
|
from tornado.websocket import WebSocketHandler
|
|
|
|
from platypush.config import Config
|
|
from platypush.utils import get_redis
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
|
|
|
|
|
|
def pubsub_redis_topic(topic: str) -> str:
|
|
return f'_platypush/{Config.get("device_id")}/{topic}' # type: ignore
|
|
|
|
|
|
class WSRoute(WebSocketHandler, Thread, ABC):
    """
    Base class for Tornado websocket endpoints.

    Every handler instance is also a :class:`threading.Thread`. When a client
    connects, :meth:`open` starts the thread; :meth:`run` then subscribes to
    the configured Redis pub/sub topics, whose payloads can be consumed
    through :meth:`listen` and pushed to the client through :meth:`send`.
    When the client disconnects, :meth:`on_close` drops the subscriptions and
    closes the pub/sub object.

    Concrete endpoints must implement :meth:`app_name`, which also determines
    the route path (see :meth:`path`).
    """

    def __init__(self, *args, redis_topics: Optional[Iterable[str]] = None, **kwargs):
        """
        :param redis_topics: Redis pub/sub topics to subscribe to when the
            worker thread starts. More topics can be added later through
            :meth:`subscribe`.
        """
        super().__init__(*args, **kwargs)
        self._redis_topics = set(redis_topics or [])
        self._sub = get_redis().pubsub()
        # The loop that is current at construction time (presumably Tornado's
        # I/O loop thread); ``send`` schedules writes back onto it from the
        # worker thread.
        self._io_loop = IOLoop.current()
        # Serializes access to ``_sub`` / ``_redis_topics`` between the I/O
        # loop thread and the worker thread.
        self._sub_lock = RLock()
        # Set by ``on_close``. A worker thread that only reaches ``run`` after
        # the client has gone must not subscribe on the closed pub/sub object:
        # redis-py would open a new connection and the thread would keep
        # listening forever.
        self._sub_closed = False

    @override
    def open(self, *_, **__):
        """
        Called when a client connects: log the connection, name the worker
        thread after the app and the client address, and start it.
        """
        logger.info('Client %s connected to %s', self.request.remote_ip, self.path())
        self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
        self.start()

    @override
    def data_received(self, *_, **__):
        # No-op: streamed request data is ignored.
        pass

    @override
    def on_message(self, *_, **__):
        # No-op by default; override to handle messages sent by the client.
        pass

    @abstractclassmethod
    def app_name(cls) -> str:
        """
        :return: The name of the websocket app, used both in the route path
            (see :meth:`path`) and in the worker thread name.
        """
        raise NotImplementedError()

    @classmethod
    def path(cls) -> str:
        """
        :return: The route of this endpoint, i.e. ``/ws/<app_name>``.
        """
        return f'/ws/{cls.app_name()}'

    def subscribe(self, *topics: str) -> None:
        """
        Subscribe to the given Redis pub/sub topics and track them.

        :param topics: Topics (channel names) to subscribe to.
        """
        with self._sub_lock:
            for topic in topics:
                self._sub.subscribe(topic)
                self._redis_topics.add(topic)

    def unsubscribe(self, *topics: str) -> None:
        """
        Unsubscribe from the given Redis pub/sub topics. Topics that are not
        currently tracked are ignored.

        :param topics: Topics (channel names) to unsubscribe from.
        """
        with self._sub_lock:
            for topic in topics:
                if topic in self._redis_topics:
                    self._sub.unsubscribe(topic)
                    self._redis_topics.remove(topic)

    def listen(self) -> Generator[Any, None, None]:
        """
        Yield the payloads of the Redis messages received on the tracked
        topics.

        Pub/sub events other than ``message`` (e.g. subscribe/unsubscribe
        confirmations) and messages on topics that are no longer tracked are
        skipped. The generator stops when the Redis connection is lost.
        """
        try:
            for msg in self._sub.listen():
                # Subscription confirmations carry the subscription count,
                # not a payload, so only real data messages go through.
                if msg.get('type') != 'message':
                    continue

                # Drop late messages for topics we have since unsubscribed
                # from. The channel is bytes unless the client decodes
                # responses, in which case it is already a str.
                channel = msg.get('channel')
                if isinstance(channel, bytes):
                    channel = channel.decode()
                if channel not in self._redis_topics:
                    continue

                yield msg.get('data')
        except RedisConnectionError:
            return

    def send(self, msg: Union[str, bytes]) -> None:
        """
        Send a message to the client. Safe to call from any thread.

        Tornado handlers are not thread-safe, so the actual ``write_message``
        call is scheduled on the I/O loop through ``add_callback``, which also
        takes care of the Future returned by ``write_message``.

        :param msg: The message to send.
        """
        self._io_loop.add_callback(self.write_message, msg)

    @override
    def run(self) -> None:
        """
        Worker thread entry point: subscribe to the configured topics.

        Subclasses overriding this should call ``super().run()`` first, so the
        subscriptions are in place before they start consuming :meth:`listen`.
        """
        super().run()
        # Hold the lock while reading the topic set: subscribe/unsubscribe/
        # on_close may run concurrently on the I/O loop thread. Unpacking the
        # set takes a snapshot before subscribe() adds to it.
        with self._sub_lock:
            if not self._sub_closed:
                self.subscribe(*self._redis_topics)

    @override
    def on_close(self):
        """
        Called when the client disconnects: drop all the Redis subscriptions,
        close the pub/sub object and log the disconnection.
        """
        with self._sub_lock:
            self._sub_closed = True
            # Unpacking snapshots the set, so unsubscribe() can remove entries
            # while iterating over its arguments.
            self.unsubscribe(*self._redis_topics)
            self._sub.close()

        logger.info(
            'Client %s disconnected from %s, reason=%s, message=%s',
            self.request.remote_ip,
            self.path(),
            self.close_code,
            self.close_reason,
        )
|