from abc import ABC, abstractclassmethod
from logging import getLogger
from threading import RLock, Thread
from typing import Any, Generator, Iterable, Optional, Union
from typing_extensions import override

from redis import ConnectionError as RedisConnectionError
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler

from platypush.config import Config
from platypush.utils import get_redis

# Module-level logger, named after this module.
logger = getLogger(__name__)


def pubsub_redis_topic(topic: str) -> str:
    """
    Build the fully-qualified Redis pub/sub topic name for ``topic``.

    The name is namespaced under ``_platypush/<device_id>/``, where
    ``device_id`` is read from the application configuration.

    :param topic: Topic name, relative to this device's namespace.
    :return: The namespaced topic string.
    """
    device_id = Config.get("device_id")  # type: ignore
    return f'_platypush/{device_id}/{topic}'


class WSRoute(WebSocketHandler, Thread, ABC):
    """
    Base class for Tornado websocket endpoints.

    Each instance is both a Tornado ``WebSocketHandler`` and a ``Thread``.
    The thread is started when the websocket connection is opened (see
    :meth:`open`), and every instance holds its own Redis pub/sub object, so
    that messages published on Redis topics can be consumed through
    :meth:`listen` and forwarded to the client through :meth:`send`.

    Subclasses must implement :meth:`app_name`, which also determines the
    URL path of the endpoint (see :meth:`path`).
    """

    def __init__(self, *args, redis_topics: Optional[Iterable[str]] = None, **kwargs):
        """
        :param redis_topics: Redis topics that should be subscribed when the
            thread starts running. Defaults to no topics.
        :param args: Positional arguments forwarded to the parent classes.
        :param kwargs: Keyword arguments forwarded to the parent classes.
        """
        super().__init__(*args, **kwargs)
        # Set of the topics tracked by this route.
        self._redis_topics = set(redis_topics or [])
        # Per-connection Redis pub/sub object.
        self._sub = get_redis().pubsub()
        # The IOLoop current at construction time; used by send() to schedule
        # writes from other threads.
        self._io_loop = IOLoop.current()
        # Guards the subscribe/unsubscribe operations and the topics set.
        self._sub_lock = RLock()

    @override
    def open(self, *_, **__):
        """
        Log the new connection, name the thread after the application and
        the remote IP, and start the thread.
        """
        logger.info('Started websocket connection with %s', self.request.remote_ip)
        self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
        self.start()

    @override
    def data_received(self, *_, **__):
        """
        No-op: streamed request data is ignored by the base class.
        """

    @override
    def on_message(self, *_, **__):
        """
        No-op: incoming websocket messages are ignored by the base class.
        """

    # NOTE: ``abc.abstractclassmethod`` is deprecated since Python 3.3 in
    # favour of stacking ``@classmethod`` on top of ``@abstractmethod``. It is
    # kept here because replacing it requires a change to the module imports.
    @abstractclassmethod
    def app_name(cls) -> str:
        """
        :return: The name of the application served by this route. It is
            used to build the endpoint path and the thread name.
        """
        raise NotImplementedError()

    @classmethod
    def path(cls) -> str:
        """
        :return: The URL path of the websocket endpoint, ``/ws/<app_name>``.
        """
        return f'/ws/{cls.app_name()}'

    def subscribe(self, *topics: str) -> None:
        """
        Subscribe the pub/sub object to the given Redis topics and track
        them as subscribed topics.

        :param topics: Names of the Redis topics to subscribe to.
        """
        with self._sub_lock:
            for topic in topics:
                self._sub.subscribe(topic)
                self._redis_topics.add(topic)

    def unsubscribe(self, *topics: str) -> None:
        """
        Unsubscribe the pub/sub object from the given Redis topics. Topics
        that are not currently tracked are ignored.

        :param topics: Names of the Redis topics to unsubscribe from.
        """
        with self._sub_lock:
            for topic in topics:
                if topic in self._redis_topics:
                    self._sub.unsubscribe(topic)
                    self._redis_topics.remove(topic)

    def listen(self) -> Generator[Any, None, None]:
        """
        Yield the payloads of the messages received on the tracked topics.

        Anything that is not an actual published message (e.g. the
        subscribe/unsubscribe confirmations) or that comes from a channel
        that is not tracked by this route is skipped.

        The generator terminates silently if the Redis connection fails.

        :return: A generator of the ``data`` fields of the received messages.
        """
        try:
            for msg in self._sub.listen():
                # Skip control messages (subscription confirmations etc.).
                if msg.get('type') != 'message':
                    continue

                # The channel is decoded only if it is a bytes object, so a
                # missing or already-decoded channel doesn't raise.
                channel = msg.get('channel')
                if isinstance(channel, bytes):
                    channel = channel.decode()

                # Skip messages from channels that are not tracked (anymore).
                if channel not in self._redis_topics:
                    continue

                yield msg.get('data')
        except RedisConnectionError:
            return

    def send(self, msg: Union[str, bytes]) -> None:
        """
        Schedule a message to be written to the websocket on the event loop
        stored at construction time. The write is scheduled through
        ``call_soon_threadsafe`` rather than performed by the caller.

        :param msg: The message to send to the client.
        """
        self._io_loop.asyncio_loop.call_soon_threadsafe(  # type: ignore
            self.write_message, msg
        )

    @override
    def run(self) -> None:
        """
        Thread body: subscribe to the topics configured on the route.

        Subclasses that extend this method would typically call
        ``super().run()`` first and then consume :meth:`listen`.
        """
        super().run()
        # Go through subscribe() so that the operation is guarded by the
        # same lock used by the other subscribe/unsubscribe calls. The
        # argument unpacking snapshots the set before subscribe() touches it.
        self.subscribe(*self._redis_topics)

    @override
    def on_close(self):
        """
        Unsubscribe from all the tracked topics, close the pub/sub object
        and log the termination of the connection.
        """
        try:
            # The argument unpacking snapshots the set, so it's safe for
            # unsubscribe() to remove the topics from it.
            self.unsubscribe(*self._redis_topics)
        finally:
            # Always release the pub/sub object, even if unsubscribing fails.
            self._sub.close()

        logger.info(
            'Websocket connection to %s closed, reason=%s, message=%s',
            self.request.remote_ip,
            self.close_code,
            self.close_reason,
        )