diff --git a/platypush/bus/__init__.py b/platypush/bus/__init__.py index 8bfc4e6e6..f8610d4ad 100644 --- a/platypush/bus/__init__.py +++ b/platypush/bus/__init__.py @@ -1,15 +1,34 @@ +from collections import defaultdict +from dataclasses import dataclass, field import logging import threading import time from queue import Queue, Empty -from typing import Callable, Type +from typing import Callable, Dict, Iterable, Type +from platypush.message import Message from platypush.message.event import Event logger = logging.getLogger('platypush:bus') +@dataclass +class MessageHandler: + """ + Wrapper for a message callback handler. + """ + + msg_type: Type[Message] + callback: Callable[[Message], None] + kwargs: dict = field(default_factory=dict) + + def match(self, msg: Message) -> bool: + return isinstance(msg, self.msg_type) and all( + getattr(msg, k, None) == v for k, v in self.kwargs.items() + ) + + class Bus: """ Main local bus where the daemon will listen for new messages. @@ -21,7 +40,10 @@ class Bus: self.bus = Queue() self.on_message = on_message self.thread_id = threading.get_ident() - self.event_handlers = {} + self.handlers: Dict[ + Type[Message], Dict[Callable[[Message], None], MessageHandler] + ] = defaultdict(dict) + self._should_stop = threading.Event() def post(self, msg): @@ -38,26 +60,24 @@ class Bus: def stop(self): self._should_stop.set() + def _get_matching_handlers( + self, msg: Message + ) -> Iterable[Callable[[Message], None]]: + return [ + hndl.callback + for cls in type(msg).__mro__ + for hndl in self.handlers.get(cls, []) + if hndl.match(msg) + ] + def _msg_executor(self, msg): def event_handler(event: Event, handler: Callable[[Event], None]): logger.info('Triggering event handler %s', handler.__name__) handler(event) def executor(): - if isinstance(msg, Event): - handlers = self.event_handlers.get( - type(msg), - { - *[ - hndl - for event_type, hndl in self.event_handlers.items() - if isinstance(msg, event_type) - ] - }, - ) - - for hndl in handlers: - threading.Thread(target=event_handler, args=(msg, hndl)) + for hndl in self._get_matching_handlers(msg): + threading.Thread(target=event_handler, args=(msg, hndl)).start() try: if self.on_message: @@ -100,27 +120,25 @@ class Bus: logger.info('Bus service stopped') def register_handler( - self, event_type: Type[Event], handler: Callable[[Event], None] + self, type: Type[Message], handler: Callable[[Message], None], **kwargs ) -> Callable[[], None]: """ - Register an event handler to the bus. + Register a generic handler to the bus. - :param event_type: Event type to subscribe (event inheritance also works). - :param handler: Event handler - a function that takes an Event object as parameter. + :param type: Type of the message to subscribe to (event inheritance also works). + :param handler: Event handler - a function that takes a Message object as parameter. + :param kwargs: Extra filter on the message values. :return: A function that can be called to remove the handler (no parameters required). """ - if event_type not in self.event_handlers: - self.event_handlers[event_type] = set() - - self.event_handlers[event_type].add(handler) + self.handlers[type][handler] = MessageHandler(type, handler, kwargs) def unregister(): - self.unregister_handler(event_type, handler) + self.unregister_handler(type, handler) return unregister def unregister_handler( - self, event_type: Type[Event], handler: Callable[[Event], None] + self, type: Type[Message], handler: Callable[[Message], None] ) -> None: """ Remove an event handler. @@ -128,14 +146,12 @@ class Bus: :param event_type: Event type. :param handler: Existing event handler. """ - if event_type not in self.event_handlers: + if type not in self.handlers: return - if handler in self.event_handlers[event_type]: - self.event_handlers[event_type].remove(handler) - - if len(self.event_handlers[event_type]) == 0: - del self.event_handlers[event_type] + self.handlers[type].pop(handler, None) + if len(self.handlers[type]) == 0: + del self.handlers[type] # vim:sw=4:ts=4:et: