81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
from abc import ABC, abstractmethod
|
|
from logging import getLogger
|
|
from threading import Thread
|
|
|
|
from tornado.ioloop import IOLoop
|
|
from tornado.websocket import WebSocketHandler
|
|
|
|
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
|
|
|
from ..mixins import MessageType, PubSubMixin
|
|
|
|
# Module-level logger used by the websocket route handlers in this file.
logger = getLogger(name=__name__)
|
|
|
|
|
|
class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
    """
    Base class for Tornado websocket endpoints.

    Every instance is also a :class:`threading.Thread`. Once a client has
    passed the auth check in :meth:`open`, the thread is started and
    :meth:`run` subscribes to the instance's pub/sub channels. Messages are
    pushed back to the client through :meth:`send`, which can be called from
    that worker thread because it schedules the actual write on the IOLoop
    captured when the handler was constructed.

    Subclasses must implement :meth:`app_name`, which also determines the
    route :meth:`path` (``/ws/<app_name>``).
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: Positional arguments forwarded to
            :class:`tornado.websocket.WebSocketHandler` (presumably the
            application and the request, as passed by Tornado).
        :param kwargs: Keyword arguments forwarded to :class:`PubSubMixin`
            only; they are *not* passed to ``WebSocketHandler``.
        """
        WebSocketHandler.__init__(self, *args)
        PubSubMixin.__init__(self, **kwargs)
        Thread.__init__(self)
        # Capture the IOLoop of the constructing thread (presumably Tornado's
        # loop thread) so that send() can hop back onto it from the worker.
        self._io_loop = IOLoop.current()

    def open(self, *_, **__):
        """
        Authenticate the client and, on success, start the worker thread.

        Clients that fail the auth check are disconnected with close code
        1008 (Policy Violation), using the message of the returned auth
        status as the close reason.
        """
        auth_status = get_auth_status(self.request)
        # NOTE(review): ``auth_required`` is not consulted here, so every
        # connection is authenticated even if a subclass overrides that
        # property to return False -- confirm this is intended.
        if auth_status != AuthStatus.OK:
            self.close(code=1008, reason=auth_status.value.message)  # Policy Violation
            return

        logger.info(
            'Client %s connected to %s', self.request.remote_ip, self.request.path
        )
        # Give the worker thread a descriptive name: ws:<app>@<client-ip>.
        self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
        self.start()

    def data_received(self, *_, **__):
        # Explicit no-op override of RequestHandler.data_received: this class
        # does not process streamed request bodies.
        pass

    def on_message(self, message):
        """
        Handle an incoming client message.

        The base implementation ignores the message; routes that accept
        client input should override this method.

        :param message: The received message (ignored here).
        :return: Always ``None``. Tornado treats a non-``None`` return value
            from ``on_message`` as an awaitable to schedule, so returning the
            raw ``str``/``bytes`` message would make it raise
            ``BadYieldError`` and break the connection's receive loop.
        """
        return None

    @classmethod
    @abstractmethod
    def app_name(cls) -> str:
        """
        :return: The name of the websocket app, used to build :meth:`path`
            and the worker thread's name.
        """
        raise NotImplementedError()

    @classmethod
    def path(cls) -> str:
        """
        :return: The URL path of this route, i.e. ``/ws/<app_name>``.
        """
        return f'/ws/{cls.app_name()}'

    @property
    def auth_required(self):
        """
        Whether this route requires authentication (``True`` by default).
        """
        return True

    def send(self, msg: MessageType) -> None:
        """
        Send a message to the connected client.

        The message is serialized in the calling thread, and the actual
        ``write_message`` call is scheduled on the handler's IOLoop through
        ``call_soon_threadsafe``, so this may be called from the worker
        thread.

        :param msg: The message to serialize and send.
        """
        self._io_loop.asyncio_loop.call_soon_threadsafe(  # type: ignore
            self.write_message, self._serialize(msg)
        )

    def run(self) -> None:
        """
        Worker thread body: subscribe to this route's pub/sub channels.
        """
        # Presumably resolves to Thread.run(), which does nothing because no
        # target was passed to Thread.__init__().
        super().run()
        self.subscribe(*self._subscriptions)

    def on_close(self):
        """
        Release the pub/sub resources of this connection and log the
        disconnection together with its close code and reason.
        """
        super().on_close()
        # Iterate over a copy: unsubscribe() presumably mutates
        # self._subscriptions.
        for channel in self._subscriptions.copy():
            self.unsubscribe(channel)

        if self._pubsub:
            self._pubsub.close()

        logger.info(
            'Client %s disconnected from %s, reason=%s, message=%s',
            self.request.remote_ip,
            self.request.path,
            self.close_code,
            self.close_reason,
        )
|