platypush/platypush/backend/http/app/ws/_base.py

81 lines
2.1 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from logging import getLogger
from threading import Thread
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
from ..mixins import MessageType, PubSubMixin
logger = getLogger(__name__)
class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
"""
Base class for Tornado websocket endpoints.
"""
def __init__(self, *args, **kwargs):
WebSocketHandler.__init__(self, *args)
PubSubMixin.__init__(self, **kwargs)
Thread.__init__(self)
self._io_loop = IOLoop.current()
def open(self, *_, **__):
auth_status = get_auth_status(self.request)
if auth_status != AuthStatus.OK:
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
return
logger.info(
'Client %s connected to %s', self.request.remote_ip, self.request.path
)
self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
self.start()
def data_received(self, *_, **__):
pass
def on_message(self, message):
return message
@classmethod
@abstractmethod
def app_name(cls) -> str:
raise NotImplementedError()
@classmethod
def path(cls) -> str:
return f'/ws/{cls.app_name()}'
@property
def auth_required(self):
return True
def send(self, msg: MessageType) -> None:
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
self.write_message, self._serialize(msg)
)
def run(self) -> None:
super().run()
self.subscribe(*self._subscriptions)
def on_close(self):
super().on_close()
for channel in self._subscriptions.copy():
self.unsubscribe(channel)
if self._pubsub:
self._pubsub.close()
logger.info(
'Client %s disconnected from %s, reason=%s, message=%s',
self.request.remote_ip,
self.request.path,
self.close_code,
self.close_reason,
)