From 26ffc0b0e1995e461d2a17e071fc0c889f4757d8 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Thu, 7 Apr 2022 00:18:11 +0200 Subject: [PATCH] Use Redis instead of an in-process map to store the entity/plugin registry This is particularly useful when we want to access the registry from another process, like the web server or an external script. --- platypush/entities/__init__.py | 10 +- platypush/entities/_registry.py | 36 ++++--- platypush/utils/__init__.py | 160 +++++++++++++++++++------------- 3 files changed, 127 insertions(+), 79 deletions(-) diff --git a/platypush/entities/__init__.py b/platypush/entities/__init__.py index f59ab240b1..361d67e5e1 100644 --- a/platypush/entities/__init__.py +++ b/platypush/entities/__init__.py @@ -1,15 +1,16 @@ import warnings from typing import Collection, Optional -from ._base import Entity +from ._base import Entity, get_entities_registry from ._engine import EntitiesEngine -from ._registry import manages, register_entity_plugin, get_plugin_registry +from ._registry import manages, register_entity_plugin, get_plugin_entity_registry _engine: Optional[EntitiesEngine] = None def init_entities_engine() -> EntitiesEngine: from ._base import init_entities_db + global _engine init_entities_db() _engine = EntitiesEngine() @@ -24,13 +25,14 @@ def publish_entities(entities: Collection[Entity]): _engine.post(*entities) + __all__ = ( 'Entity', 'EntitiesEngine', 'init_entities_engine', 'publish_entities', 'register_entity_plugin', - 'get_plugin_registry', + 'get_plugin_entity_registry', + 'get_entities_registry', 'manages', ) - diff --git a/platypush/entities/_registry.py b/platypush/entities/_registry.py index b8644808f1..53d87234ca 100644 --- a/platypush/entities/_registry.py +++ b/platypush/entities/_registry.py @@ -1,24 +1,38 @@ +import json from datetime import datetime -from typing import Optional, Mapping, Dict, Collection, Type +from typing import Optional, Dict, Collection, Type from platypush.plugins import Plugin -from platypush.utils import get_plugin_name_by_class +from platypush.utils import get_plugin_name_by_class, get_redis from ._base import Entity -_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {} +_entity_registry_varname = '_platypush/plugin_entity_registry' def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin): - plugins = _entity_plugin_registry.get(entity_type, {}) - plugin_name = get_plugin_name_by_class(plugin.__class__) - assert plugin_name - plugins[plugin_name] = plugin - _entity_plugin_registry[entity_type] = plugins + plugin_name = get_plugin_name_by_class(plugin.__class__) or '' + entity_type_name = entity_type.__name__.lower() + redis = get_redis() + registry = get_plugin_entity_registry() + registry_by_plugin = set(registry['by_plugin'].get(plugin_name, [])) + + registry_by_entity_type = set(registry['by_entity_type'].get(entity_type_name, [])) + + registry_by_plugin.add(entity_type_name) + registry_by_entity_type.add(plugin_name) + registry['by_plugin'][plugin_name] = list(registry_by_plugin) + registry['by_entity_type'][entity_type_name] = list(registry_by_entity_type) + redis.mset({_entity_registry_varname: json.dumps(registry)}) -def get_plugin_registry(): - return _entity_plugin_registry.copy() +def get_plugin_entity_registry() -> Dict[str, Dict[str, Collection[str]]]: + redis = get_redis() + registry = redis.mget([_entity_registry_varname])[0] + try: + return json.loads((registry or b'').decode()) + except (TypeError, ValueError): + return {'by_plugin': {}, 'by_entity_type': {}} class EntityManagerMixin: @@ -37,6 +51,7 @@ class EntityManagerMixin: def publish_entities(self, entities: Optional[Collection[Entity]]): from . import publish_entities + entities = self.transform_entities(entities) publish_entities(entities) @@ -59,4 +74,3 @@ def manages(*entities: Type[Entity]): return plugin return wrapper - diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index c75615c855..61f2fc23e5 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -1,4 +1,5 @@ import ast +import contextlib import datetime import hashlib import importlib @@ -20,8 +21,8 @@ logger = logging.getLogger('utils') def get_module_and_method_from_action(action): - """ Input : action=music.mpd.play - Output : ('music.mpd', 'play') """ + """Input : action=music.mpd.play + Output : ('music.mpd', 'play')""" tokens = action.split('.') module_name = str.join('.', tokens[:-1]) @@ -30,7 +31,7 @@ def get_module_and_method_from_action(action): def get_message_class_by_type(msgtype): - """ Gets the class of a message type given as string """ + """Gets the class of a message type given as string""" try: module = importlib.import_module('platypush.message.' + msgtype) @@ -43,8 +44,7 @@ def get_message_class_by_type(msgtype): try: msgclass = getattr(module, cls_name) except AttributeError as e: - logger.warning('No such class in {}: {}'.format( - module.__name__, cls_name)) + logger.warning('No such class in {}: {}'.format(module.__name__, cls_name)) raise RuntimeError(e) return msgclass @@ -52,13 +52,13 @@ def get_message_class_by_type(msgtype): # noinspection PyShadowingBuiltins def get_event_class_by_type(type): - """ Gets an event class by type name """ + """Gets an event class by type name""" event_module = importlib.import_module('.'.join(type.split('.')[:-1])) return getattr(event_module, type.split('.')[-1]) def get_plugin_module_by_name(plugin_name): - """ Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc") """ + """Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc")""" module_name = 'platypush.plugins.' + plugin_name try: @@ -69,22 +69,26 @@ def get_plugin_module_by_name(plugin_name): def get_plugin_class_by_name(plugin_name): - """ Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc") """ + """Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc")""" module = get_plugin_module_by_name(plugin_name) if not module: return - class_name = getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin') + class_name = getattr( + module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin' + ) try: - return getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin') + return getattr( + module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin' + ) except Exception as e: logger.error('Cannot import class {}: {}'.format(class_name, str(e))) return None def get_plugin_name_by_class(plugin) -> Optional[str]: - """Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class. """ + """Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class.""" from platypush.plugins import Plugin @@ -93,7 +97,8 @@ def get_plugin_name_by_class(plugin) -> Optional[str]: class_name = plugin.__name__ class_tokens = [ - token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') + token.lower() + for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') if token.strip() and token != 'Plugin' ] @@ -101,7 +106,7 @@ def get_plugin_name_by_class(plugin) -> Optional[str]: def get_backend_name_by_class(backend) -> Optional[str]: - """Gets the common name of a backend (e.g. "http" or "mqtt") given its class. """ + """Gets the common name of a backend (e.g. "http" or "mqtt") given its class.""" from platypush.backend import Backend @@ -110,7 +115,8 @@ def get_backend_name_by_class(backend) -> Optional[str]: class_name = backend.__name__ class_tokens = [ - token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') + token.lower() + for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') if token.strip() and token != 'Backend' ] @@ -135,12 +141,12 @@ def set_timeout(seconds, on_timeout): def clear_timeout(): - """ Clear any previously set timeout """ + """Clear any previously set timeout""" signal.alarm(0) def get_hash(s): - """ Get the SHA256 hash hex digest of a string input """ + """Get the SHA256 hash hex digest of a string input""" return hashlib.sha256(s.encode('utf-8')).hexdigest() @@ -177,11 +183,8 @@ def get_decorators(cls, climb_class_hierarchy=False): node_iter.visit_FunctionDef = visit_FunctionDef for target in targets: - try: + with contextlib.suppress(TypeError): node_iter.visit(ast.parse(inspect.getsource(target))) - except TypeError: - # Ignore built-in classes - pass return decorators @@ -195,45 +198,57 @@ def get_redis_queue_name_by_message(msg): return 'platypush/responses/{}'.format(msg.id) if msg.id else None -def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None, - ssl_cafile=None, ssl_capath=None): +def _get_ssl_context( + context_type=None, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None +): if not context_type: - ssl_context = ssl.create_default_context(cafile=ssl_cafile, - capath=ssl_capath) + ssl_context = ssl.create_default_context(cafile=ssl_cafile, capath=ssl_capath) else: ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) if ssl_cafile or ssl_capath: - ssl_context.load_verify_locations( - cafile=ssl_cafile, capath=ssl_capath) + ssl_context.load_verify_locations(cafile=ssl_cafile, capath=ssl_capath) ssl_context.load_cert_chain( certfile=os.path.abspath(os.path.expanduser(ssl_cert)), - keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None + keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None, ) return ssl_context -def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, - ssl_capath=None): - return _get_ssl_context(context_type=None, - ssl_cert=ssl_cert, ssl_key=ssl_key, - ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) +def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None): + return _get_ssl_context( + context_type=None, + ssl_cert=ssl_cert, + ssl_key=ssl_key, + ssl_cafile=ssl_cafile, + ssl_capath=ssl_capath, + ) -def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, - ssl_capath=None): - return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER, - ssl_cert=ssl_cert, ssl_key=ssl_key, - ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) +def get_ssl_server_context( + ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None +): + return _get_ssl_context( + context_type=ssl.PROTOCOL_TLS_SERVER, + ssl_cert=ssl_cert, + ssl_key=ssl_key, + ssl_cafile=ssl_cafile, + ssl_capath=ssl_capath, + ) -def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, - ssl_capath=None): - return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT, - ssl_cert=ssl_cert, ssl_key=ssl_key, - ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) +def get_ssl_client_context( + ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None +): + return _get_ssl_context( + context_type=ssl.PROTOCOL_TLS_CLIENT, + ssl_cert=ssl_cert, + ssl_key=ssl_key, + ssl_cafile=ssl_cafile, + ssl_capath=ssl_capath, + ) def set_thread_name(name): @@ -241,6 +256,7 @@ def set_thread_name(name): try: import prctl + # noinspection PyUnresolvedReferences prctl.set_name(name) except ImportError: @@ -251,9 +267,9 @@ def find_bins_in_path(bin_name): return [ os.path.join(p, bin_name) for p in os.environ.get('PATH', '').split(':') - if os.path.isfile(os.path.join(p, bin_name)) and ( - os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK) - )] + if os.path.isfile(os.path.join(p, bin_name)) + and (os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK)) + ] def find_files_by_ext(directory, *exts): @@ -271,7 +287,7 @@ def find_files_by_ext(directory, *exts): max_len = len(max(exts, key=len)) result = [] - for root, dirs, files in os.walk(directory): + for _, __, files in os.walk(directory): for i in range(min_len, max_len + 1): result += [f for f in files if f[-i:] in exts] @@ -302,8 +318,9 @@ def get_ip_or_hostname(): def get_mime_type(resource): import magic + if resource.startswith('file://'): - resource = resource[len('file://'):] + resource = resource[len('file://') :] # noinspection HttpUrlsUsage if resource.startswith('http://') or resource.startswith('https://'): @@ -315,7 +332,9 @@ def get_mime_type(resource): elif hasattr(magic, 'from_file'): mime = magic.from_file(resource, mime=True) else: - raise RuntimeError('The installed magic version provides neither detect_from_filename nor from_file') + raise RuntimeError( + 'The installed magic version provides neither detect_from_filename nor from_file' + ) if mime: return mime.mime_type if hasattr(mime, 'mime_type') else mime @@ -332,6 +351,7 @@ def grouper(n, iterable, fillvalue=None): grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx """ from itertools import zip_longest + args = [iter(iterable)] * n if fillvalue: @@ -355,6 +375,7 @@ def is_functional_cron(obj) -> bool: def run(action, *args, **kwargs): from platypush.context import get_plugin + (module_name, method_name) = get_module_and_method_from_action(action) plugin = get_plugin(module_name) method = getattr(plugin, method_name) @@ -366,7 +387,9 @@ def run(action, *args, **kwargs): return response.output -def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]: +def generate_rsa_key_pair( + key_file: Optional[str] = None, size: int = 2048 +) -> Tuple[str, str]: """ Generate an RSA key pair. @@ -390,27 +413,30 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T public_exp = 65537 private_key = rsa.generate_private_key( - public_exponent=public_exp, - key_size=size, - backend=default_backend() + public_exponent=public_exp, key_size=size, backend=default_backend() ) logger.info('Generating RSA {} key pair'.format(size)) private_key_str = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ).decode() - public_key_str = private_key.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.PKCS1, - ).decode() + public_key_str = ( + private_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.PKCS1, + ) + .decode() + ) if key_file: logger.info('Saving private key to {}'.format(key_file)) - with open(os.path.expanduser(key_file), 'w') as f1, \ - open(os.path.expanduser(key_file) + '.pub', 'w') as f2: + with open(os.path.expanduser(key_file), 'w') as f1, open( + os.path.expanduser(key_file) + '.pub', 'w' + ) as f2: f1.write(private_key_str) f2.write(public_key_str) os.chmod(key_file, 0o600) @@ -426,8 +452,7 @@ def get_or_generate_jwt_rsa_key_pair(): pub_key_file = priv_key_file + '.pub' if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): - with open(pub_key_file, 'r') as f1, \ - open(priv_key_file, 'r') as f2: + with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2: return f1.read(), f2.read() pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) @@ -439,7 +464,7 @@ def get_enabled_plugins() -> dict: from platypush.context import get_plugin plugins = {} - for name, config in Config.get_plugins().items(): + for name in Config.get_plugins(): try: plugin = get_plugin(name) if plugin: @@ -453,11 +478,18 @@ def get_enabled_plugins() -> dict: def get_redis() -> Redis: from platypush.config import Config - return Redis(**(Config.get('backend.redis') or {}).get('redis_args', {})) + + return Redis( + **( + (Config.get('backend.redis') or {}).get('redis_args', {}) + or Config.get('redis') + or {} + ) + ) def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime: - if isinstance(t, int) or isinstance(t, float): + if isinstance(t, (int, float)): return datetime.datetime.fromtimestamp(t, tz=tz.tzutc()) if isinstance(t, str): return parser.parse(t)