diff --git a/platypush/app/_app.py b/platypush/app/_app.py index 8c8bd01af0..76204a145e 100644 --- a/platypush/app/_app.py +++ b/platypush/app/_app.py @@ -181,6 +181,8 @@ class Application: or os.environ.get('PLATYPUSH_REDIS_QUEUE') or RedisBus.DEFAULT_REDIS_QUEUE ) + + os.environ['PLATYPUSH_REDIS_QUEUE'] = self.redis_queue self.config_file = config_file or os.environ.get('PLATYPUSH_CONFIG') self.verbose = verbose self.db_engine = db or os.environ.get('PLATYPUSH_DB') diff --git a/platypush/backend/http/app/utils/bus.py b/platypush/backend/http/app/utils/bus.py index 168e44fb8a..a043b8df41 100644 --- a/platypush/backend/http/app/utils/bus.py +++ b/platypush/backend/http/app/utils/bus.py @@ -1,24 +1,57 @@ +from multiprocessing import Lock + from platypush.bus.redis import RedisBus +from platypush.context import get_bus from platypush.config import Config -from platypush.context import get_backend from platypush.message import Message from platypush.message.request import Request -from platypush.utils import get_redis_conf, get_message_response +from platypush.utils import get_message_response from .logger import logger -_bus = None + +class BusWrapper: # pylint: disable=too-few-public-methods + """ + Lazy singleton wrapper for the bus object. + """ + + def __init__(self): + self._redis_queue = None + self._bus = None + self._bus_lock = Lock() + + @property + def bus(self) -> RedisBus: + """ + Lazy getter/initializer for the bus object. + """ + with self._bus_lock: + if not self._bus: + self._bus = get_bus() + + bus_: RedisBus = self._bus # type: ignore + return bus_ + + def post(self, msg): + """ + Send a message to the bus. + + :param msg: The message to send. + """ + try: + self.bus.post(msg) + except Exception as e: + logger().exception(e) + + +_bus = BusWrapper() def bus(): """ Lazy getter/initializer for the bus object. """ - global _bus # pylint: disable=global-statement - if _bus is None: - redis_queue = get_backend('http').bus.redis_queue # type: ignore - _bus = RedisBus(**get_redis_conf(), redis_queue=redis_queue) - return _bus + return _bus.bus def send_message(msg, wait_for_response=True): diff --git a/platypush/bus/redis.py b/platypush/bus/redis.py index 40cd7350fa..ab836da5fc 100644 --- a/platypush/bus/redis.py +++ b/platypush/bus/redis.py @@ -1,6 +1,5 @@ import logging import threading -from typing import Optional from platypush.bus import Bus from platypush.message import Message @@ -24,25 +23,39 @@ class RedisBus(Bus): self.redis_queue = redis_queue or self.DEFAULT_REDIS_QUEUE self.on_message = on_message self.thread_id = threading.get_ident() + self._pubsub = None + self._pubsub_lock = threading.RLock() - def get(self) -> Optional[Message]: + @property + def pubsub(self): + with self._pubsub_lock: + if not self._pubsub: + self._pubsub = self.redis.pubsub() + return self._pubsub + + def poll(self): """ - Reads one message from the Redis queue + Polls the Redis queue for new messages """ - try: - if self.should_stop(): - return None + with self.pubsub as pubsub: + pubsub.subscribe(self.redis_queue) + try: + for msg in pubsub.listen(): + if msg.get('type') != 'message': + continue - msg = self.redis.blpop(self.redis_queue, timeout=1) - if not msg or msg[1] is None: - return None + if self.should_stop(): + break - msg = msg[1].decode('utf-8') - return Message.build(msg) - except Exception as e: - logger.exception(e) - - return None + try: + data = msg.get('data', b'').decode('utf-8') + parsed_msg = Message.build(data) + if parsed_msg and self.on_message: + self.on_message(parsed_msg) + except Exception as e: + logger.exception(e) + finally: + pubsub.unsubscribe(self.redis_queue) def post(self, msg): """ @@ -51,15 +64,13 @@ class RedisBus(Bus): from redis import exceptions try: - return self.redis.rpush(self.redis_queue, str(msg)) + self.redis.publish(self.redis_queue, str(msg)) except exceptions.ConnectionError as e: if not self.should_stop(): # Raise the exception only if the bus it not supposed to be # stopped raise e - return None - def stop(self): super().stop() self.redis.close() diff --git a/platypush/context/__init__.py b/platypush/context/__init__.py index 91131b7eaf..85857a8edd 100644 --- a/platypush/context/__init__.py +++ b/platypush/context/__init__.py @@ -1,6 +1,7 @@ import asyncio import importlib import logging +import os from dataclasses import dataclass, field from threading import RLock @@ -194,11 +195,20 @@ def get_bus() -> Bus: Get or register the main application bus. """ from platypush.bus.redis import RedisBus + from platypush.utils import get_redis_conf if _ctx.bus: return _ctx.bus - _ctx.bus = RedisBus() + redis_queue = ( + os.environ.get('PLATYPUSH_REDIS_QUEUE') or RedisBus.DEFAULT_REDIS_QUEUE + ) + + _ctx.bus = RedisBus( + redis_queue=redis_queue, + **get_redis_conf(), + ) + return _ctx.bus diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index f7c9c961ca..5f869e3d74 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -22,12 +22,14 @@ from threading import Event, Lock as TLock from typing import Generator, Optional, Tuple, Type, Union from dateutil import parser, tz -from redis import Redis +from redis import ConnectionPool, Redis from rsa.key import PublicKey, PrivateKey, newkeys logger = logging.getLogger('utils') Lock = Union[PLock, TLock] # type: ignore +redis_pools: dict[Tuple[str, int], ConnectionPool] = {} + def get_module_and_method_from_action(action): """ @@ -608,6 +610,29 @@ def get_enabled_backends() -> dict: return backends +def get_redis_pool(*args, **kwargs) -> ConnectionPool: + """ + Get a Redis connection pool on the basis of the Redis configuration. + + The Redis configuration can be loaded from: + + 1. The ``redis`` plugin. + 2. The ``backend.redis`` configuration (``redis_args`` attribute) + + """ + if not (args or kwargs): + kwargs = get_redis_conf() + + pool_key = (kwargs.get('host', 'localhost'), kwargs.get('port', 6379)) + pool = redis_pools.get(pool_key) + + if not pool: + pool = ConnectionPool(*args, **kwargs) + redis_pools[pool_key] = pool + + return pool + + def get_redis_conf() -> dict: """ Get the Redis connection arguments from the configuration. @@ -631,10 +656,7 @@ def get_redis(*args, **kwargs) -> Redis: 2. The ``backend.redis`` configuration (``redis_args`` attribute) """ - if not (args or kwargs): - kwargs = get_redis_conf() - - return Redis(*args, **kwargs) + return Redis(connection_pool=get_redis_pool(*args, **kwargs)) def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime: