diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index 48ae4ab98f..22bc0c0692 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -8,7 +8,6 @@ import logging import os import pathlib import re -import rsa import signal import socket import ssl @@ -17,14 +16,16 @@ from typing import Optional, Tuple, Union from dateutil import parser, tz from redis import Redis -from rsa.key import PublicKey, PrivateKey +from rsa.key import PublicKey, PrivateKey, newkeys logger = logging.getLogger('utils') def get_module_and_method_from_action(action): - """Input : action=music.mpd.play - Output : ('music.mpd', 'play')""" + """ + Input: action=music.mpd.play + Output: ('music.mpd', 'play') + """ tokens = action.split('.') module_name = str.join('.', tokens[:-1]) @@ -38,22 +39,21 @@ def get_message_class_by_type(msgtype): try: module = importlib.import_module('platypush.message.' + msgtype) except ImportError as e: - logger.warning('Unsupported message type {}'.format(msgtype)) - raise RuntimeError(e) + logger.warning('Unsupported message type %s', msgtype) + raise RuntimeError(e) from e cls_name = msgtype[0].upper() + msgtype[1:] try: msgclass = getattr(module, cls_name) except AttributeError as e: - logger.warning('No such class in {}: {}'.format(module.__name__, cls_name)) - raise RuntimeError(e) + logger.warning('No such class in %s: %s', module.__name__, cls_name) + raise RuntimeError(e) from e return msgclass -# noinspection PyShadowingBuiltins -def get_event_class_by_type(type): +def get_event_class_by_type(type): # pylint: disable=redefined-builtin """Gets an event class by type name""" event_module = importlib.import_module('.'.join(type.split('.')[:-1])) return getattr(event_module, type.split('.')[-1]) @@ -66,7 +66,7 @@ def get_plugin_module_by_name(plugin_name): try: return importlib.import_module('platypush.plugins.' + plugin_name) except ImportError as e: - logger.error('Cannot import {}: {}'.format(module_name, str(e))) + logger.error('Cannot import %s: %s', module_name, e) return None @@ -85,7 +85,7 @@ def get_plugin_class_by_name(plugin_name): module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin' ) except Exception as e: - logger.error('Cannot import class {}: {}'.format(class_name, str(e))) + logger.error('Cannot import class %s: %s', class_name, e) return None @@ -191,13 +191,20 @@ def get_decorators(cls, climb_class_hierarchy=False): return decorators -def get_redis_queue_name_by_message(msg): - from platypush.message import Message +def get_redis_queue_name_by_message(msg) -> Optional[str]: + """ + Get the Redis queue name for the response(s) associated to a request + message. - if not isinstance(msg, Message): - logger.warning('Not a valid message (type: {}): {}'.format(type(msg), msg)) + :param msg: Input message, as a :class:`platypush.message.request.Request` + object. + """ + from platypush.message.request import Request - return 'platypush/responses/{}'.format(msg.id) if msg.id else None + if not isinstance(msg, Request): + logger.warning('Not a valid request (type: %s): %s', type(msg), msg) + return None + return f'platypush/responses/{msg.id}' if msg.id else None def _get_ssl_context( @@ -220,6 +227,9 @@ def _get_ssl_context( def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None): + """ + Generic builder for SSL context. + """ return _get_ssl_context( context_type=None, ssl_cert=ssl_cert, @@ -232,6 +242,9 @@ def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=Non def get_ssl_server_context( ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None ): + """ + Builder for a server-side SSL context. + """ return _get_ssl_context( context_type=ssl.PROTOCOL_TLS_SERVER, ssl_cert=ssl_cert, @@ -244,6 +257,9 @@ def get_ssl_server_context( def get_ssl_client_context( ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None ): + """ + Builder for a client-side SSL context. + """ return _get_ssl_context( context_type=ssl.PROTOCOL_TLS_CLIENT, ssl_cert=ssl_cert, @@ -253,19 +269,22 @@ def get_ssl_client_context( ) -def set_thread_name(name): - global logger - +def set_thread_name(name: str): + """ + Set the name of the current thread. + """ try: import prctl - # noinspection PyUnresolvedReferences - prctl.set_name(name) + prctl.set_name(name) # pylint: disable=no-member except ImportError: logger.debug('Unable to set thread name: prctl module is missing') def find_bins_in_path(bin_name): + """ + Search for a binary in the PATH variable. + """ return [ os.path.join(p, bin_name) for p in os.environ.get('PATH', '').split(':') @@ -276,14 +295,14 @@ def find_bins_in_path(bin_name): def find_files_by_ext(directory, *exts): """ - Finds all the files in the given directory with the provided extensions + Finds all the files in the given directory with the provided extensions. """ if not exts: raise AttributeError('No extensions provided') if not os.path.isdir(directory): - raise AttributeError('{} is not a valid directory'.format(directory)) + raise AttributeError(f'{directory} is not a valid directory') min_len = len(min(exts, key=len)) max_len = len(max(exts, key=len)) @@ -296,7 +315,11 @@ def find_files_by_ext(directory, *exts): return result -def is_process_alive(pid): +def is_process_alive(pid: int) -> bool: + """ + :param pid: Process ID. + :return: True if the process with the given PID is alive. + """ try: os.kill(pid, 0) return True @@ -304,7 +327,10 @@ def is_process_alive(pid): return False -def get_ip_or_hostname(): +def get_ip_or_hostname() -> str: + """ + Get the the default IP address or hostname of the machine. + """ ip = socket.gethostbyname(socket.gethostname()) if ip.startswith('127.') or ip.startswith('::1'): try: @@ -318,11 +344,18 @@ def get_ip_or_hostname(): return ip -def get_mime_type(resource): +def get_mime_type(resource: str) -> Optional[str]: + """ + Get the MIME type of the given resource. + + :param resource: The resource to get the MIME type for - it can be a file + path or a URL. + """ import magic if resource.startswith('file://'): - resource = resource[len('file://') :] + offset = len('file://') + resource = resource[offset:] # noinspection HttpUrlsUsage if resource.startswith('http://') or resource.startswith('https://'): @@ -341,8 +374,13 @@ def get_mime_type(resource): if mime: return mime.mime_type if hasattr(mime, 'mime_type') else mime + return None + def camel_case_to_snake_case(string): + """ + Utility function to convert CamelCase to snake_case. + """ s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() @@ -364,18 +402,34 @@ def grouper(n, iterable, fillvalue=None): def is_functional_procedure(obj) -> bool: + """ + Check if the given object is a functional procedure. + """ return callable(obj) and hasattr(obj, 'procedure') def is_functional_hook(obj) -> bool: + """ + Check if the given object is a functional hook. + """ return callable(obj) and hasattr(obj, 'hook') def is_functional_cron(obj) -> bool: + """ + Check if the given object is a functional cron. + """ return callable(obj) and hasattr(obj, 'cron') and hasattr(obj, 'cron_expression') def run(action, *args, **kwargs): + """ + Run the given action with the given arguments. Example: + + >>> from platypush.utils import run + >>> run('music.mpd.play', resource='file:///home/user/music.mp3') + + """ from platypush.context import get_plugin (module_name, method_name) = get_module_and_method_from_action(action) @@ -410,13 +464,13 @@ def generate_rsa_key_pair( :return: A tuple with the generated ``(priv_key_str, pub_key_str)``. """ logger.info('Generating RSA keypair') - pub_key, priv_key = rsa.newkeys(size) + pub_key, priv_key = newkeys(size) logger.info('Generated RSA keypair') public_key_str = pub_key.save_pkcs1('PEM').decode() private_key_str = priv_key.save_pkcs1('PEM').decode() if key_file: - logger.info('Saving private key to {}'.format(key_file)) + logger.info('Saving private key to %s', key_file) with open(os.path.expanduser(key_file), 'w') as f1, open( os.path.expanduser(key_file) + '.pub', 'w' ) as f2: @@ -428,6 +482,9 @@ def generate_rsa_key_pair( def get_or_generate_jwt_rsa_key_pair(): + """ + Get or generate a JWT RSA key pair. + """ from platypush.config import Config key_dir = os.path.join(Config.get('workdir'), 'jwt') @@ -437,8 +494,8 @@ def get_or_generate_jwt_rsa_key_pair(): if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2: return ( - rsa.PublicKey.load_pkcs1(f1.read().encode()), - rsa.PrivateKey.load_pkcs1(f2.read().encode()), + PublicKey.load_pkcs1(f1.read().encode()), + PrivateKey.load_pkcs1(f2.read().encode()), ) pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) @@ -446,6 +503,12 @@ def get_or_generate_jwt_rsa_key_pair(): def get_enabled_plugins() -> dict: + """ + Get the enabled plugins. + + :return: A dictionary with the enabled plugins, in the format ``name`` -> + :class:`platypush.plugins.Plugin` instance. + """ from platypush.config import Config from platypush.context import get_plugin @@ -456,13 +519,22 @@ def get_enabled_plugins() -> dict: if plugin: plugins[name] = plugin except Exception as e: - logger.warning(f'Could not initialize plugin {name}') + logger.warning('Could not initialize plugin %s', name) logger.exception(e) return plugins def get_redis() -> Redis: + """ + Get a Redis client on the basis of the Redis configuration. + + The Redis configuration can be loaded from: + + 1. The ``backend.redis`` configuration (``redis_args`` attribute) + 2. The ``redis`` plugin. + + """ from platypush.config import Config return Redis( @@ -475,6 +547,10 @@ def get_redis() -> Redis: def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime: + """ + Utility function to convert a datetime/timestamp provided as a + string/integer/float/datetime to a ``datetime.datetime`` instance. + """ if isinstance(t, (int, float)): return datetime.datetime.fromtimestamp(t, tz=tz.tzutc()) if isinstance(t, str):