platypush/platypush/backend/http/app/ws/_base.py

81 lines
2.1 KiB
Python

from abc import ABC, abstractmethod
from logging import getLogger
from threading import Thread
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
from ..mixins import MessageType, PubSubMixin
logger = getLogger(__name__)
class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
"""
Base class for Tornado websocket endpoints.
"""
def __init__(self, *args, **kwargs):
WebSocketHandler.__init__(self, *args)
PubSubMixin.__init__(self, **kwargs)
Thread.__init__(self)
self._io_loop = IOLoop.current()
def open(self, *_, **__):
auth_status = get_auth_status(self.request)
if auth_status != AuthStatus.OK:
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
return
logger.info(
'Client %s connected to %s', self.request.remote_ip, self.request.path
)
self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
self.start()
def data_received(self, *_, **__):
pass
def on_message(self, message):
return message
@classmethod
@abstractmethod
def app_name(cls) -> str:
raise NotImplementedError()
@classmethod
def path(cls) -> str:
return f'/ws/{cls.app_name()}'
@property
def auth_required(self):
return True
def send(self, msg: MessageType) -> None:
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
self.write_message, self._serialize(msg)
)
def run(self) -> None:
super().run()
self.subscribe(*self._subscriptions)
def on_close(self):
super().on_close()
for channel in self._subscriptions.copy():
self.unsubscribe(channel)
if self._pubsub:
self._pubsub.close()
logger.info(
'Client %s disconnected from %s, reason=%s, message=%s',
self.request.remote_ip,
self.request.path,
self.close_code,
self.close_reason,
)