platypush/platypush/utils/__init__.py

563 lines
15 KiB
Python
Raw Normal View History

import ast
import contextlib
import datetime
import hashlib
2017-12-18 01:10:51 +01:00
import importlib
import inspect
2017-12-18 01:10:51 +01:00
import logging
import os
import pathlib
import re
import signal
import socket
import ssl
import urllib.request
from typing import Optional, Tuple, Union
2017-12-18 01:10:51 +01:00
from dateutil import parser, tz
from redis import Redis
from rsa.key import PublicKey, PrivateKey, newkeys
2020-09-27 01:33:38 +02:00
logger = logging.getLogger('utils')
2017-12-18 01:10:51 +01:00
2019-07-01 22:26:04 +02:00
def get_module_and_method_from_action(action):
"""
Input: action=music.mpd.play
Output: ('music.mpd', 'play')
"""
tokens = action.split('.')
module_name = str.join('.', tokens[:-1])
method_name = tokens[-1:][0]
2019-07-01 22:26:04 +02:00
return module_name, method_name
def get_message_class_by_type(msgtype):
"""Gets the class of a message type given as string"""
try:
module = importlib.import_module('platypush.message.' + msgtype)
2018-02-20 22:58:13 +01:00
except ImportError as e:
logger.warning('Unsupported message type %s', msgtype)
raise RuntimeError(e) from e
cls_name = msgtype[0].upper() + msgtype[1:]
try:
msgclass = getattr(module, cls_name)
except AttributeError as e:
logger.warning('No such class in %s: %s', module.__name__, cls_name)
raise RuntimeError(e) from e
return msgclass
def get_event_class_by_type(type): # pylint: disable=redefined-builtin
"""Gets an event class by type name"""
event_module = importlib.import_module('.'.join(type.split('.')[:-1]))
return getattr(event_module, type.split('.')[-1])
2019-07-01 22:26:04 +02:00
def get_plugin_module_by_name(plugin_name):
"""Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc")"""
2019-07-01 22:26:04 +02:00
module_name = 'platypush.plugins.' + plugin_name
try:
return importlib.import_module('platypush.plugins.' + plugin_name)
except ImportError as e:
logger.error('Cannot import %s: %s', module_name, e)
2019-07-01 22:26:04 +02:00
return None
def get_plugin_class_by_name(plugin_name):
"""Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc")"""
2019-07-01 22:26:04 +02:00
module = get_plugin_module_by_name(plugin_name)
if not module:
return
class_name = getattr(
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
)
2019-07-01 22:26:04 +02:00
try:
return getattr(
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
)
2019-07-01 22:26:04 +02:00
except Exception as e:
logger.error('Cannot import class %s: %s', class_name, e)
2019-07-01 22:26:04 +02:00
return None
2020-09-27 01:33:38 +02:00
def get_plugin_name_by_class(plugin) -> Optional[str]:
"""Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class."""
2020-09-19 00:50:22 +02:00
from platypush.plugins import Plugin
if isinstance(plugin, Plugin):
plugin = plugin.__class__
class_name = plugin.__name__
class_tokens = [
token.lower()
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
2020-09-19 00:50:22 +02:00
if token.strip() and token != 'Plugin'
]
return '.'.join(class_tokens)
2020-09-27 01:33:38 +02:00
def get_backend_name_by_class(backend) -> Optional[str]:
"""Gets the common name of a backend (e.g. "http" or "mqtt") given its class."""
2020-09-27 01:33:38 +02:00
from platypush.backend import Backend
if isinstance(backend, Backend):
backend = backend.__class__
class_name = backend.__name__
class_tokens = [
token.lower()
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
2020-09-27 01:33:38 +02:00
if token.strip() and token != 'Backend'
]
return '.'.join(class_tokens)
def set_timeout(seconds, on_timeout):
"""
Set a function to be called if timeout expires without being cleared.
It only works on the main thread.
Params:
seconds -- Timeout in seconds
on_timeout -- Function invoked on timeout unless clear_timeout is called before
"""
2021-04-05 00:58:44 +02:00
def _sighandler(*_):
on_timeout()
signal.signal(signal.SIGALRM, _sighandler)
signal.alarm(seconds)
def clear_timeout():
"""Clear any previously set timeout"""
signal.alarm(0)
2017-12-18 01:10:51 +01:00
def get_hash(s):
"""Get the SHA256 hash hex digest of a string input"""
return hashlib.sha256(s.encode('utf-8')).hexdigest()
def get_decorators(cls, climb_class_hierarchy=False):
"""
Get the decorators of a class as a {"decorator_name": [list of methods]} dictionary
:param cls: Class type
:param climb_class_hierarchy: If set to True (default: False), it will search return the decorators in the parent
classes as well
:type climb_class_hierarchy: bool
"""
decorators = {}
# noinspection PyPep8Naming
def visit_FunctionDef(node):
for n in node.decorator_list:
if isinstance(n, ast.Call):
# noinspection PyUnresolvedReferences
name = n.func.attr if isinstance(n.func, ast.Attribute) else n.func.id
else:
name = n.attr if isinstance(n, ast.Attribute) else n.id
decorators[name] = decorators.get(name, set())
decorators[name].add(node.name)
if climb_class_hierarchy:
targets = inspect.getmro(cls)
else:
targets = [cls]
node_iter = ast.NodeVisitor()
node_iter.visit_FunctionDef = visit_FunctionDef
for target in targets:
with contextlib.suppress(TypeError):
node_iter.visit(ast.parse(inspect.getsource(target)))
return decorators
def get_redis_queue_name_by_message(msg) -> Optional[str]:
"""
Get the Redis queue name for the response(s) associated to a request
message.
2018-09-20 12:49:57 +02:00
:param msg: Input message, as a :class:`platypush.message.request.Request`
object.
"""
from platypush.message.request import Request
2018-09-20 12:49:57 +02:00
if not isinstance(msg, Request):
logger.warning('Not a valid request (type: %s): %s', type(msg), msg)
return None
return f'platypush/responses/{msg.id}' if msg.id else None
2018-09-20 12:49:57 +02:00
def _get_ssl_context(
context_type=None, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
if not context_type:
ssl_context = ssl.create_default_context(cafile=ssl_cafile, capath=ssl_capath)
else:
2019-06-21 02:13:14 +02:00
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
if ssl_cafile or ssl_capath:
ssl_context.load_verify_locations(cafile=ssl_cafile, capath=ssl_capath)
ssl_context.load_cert_chain(
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None,
)
return ssl_context
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None):
"""
Generic builder for SSL context.
"""
return _get_ssl_context(
context_type=None,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def get_ssl_server_context(
ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
"""
Builder for a server-side SSL context.
"""
return _get_ssl_context(
context_type=ssl.PROTOCOL_TLS_SERVER,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def get_ssl_client_context(
ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
"""
Builder for a client-side SSL context.
"""
return _get_ssl_context(
context_type=ssl.PROTOCOL_TLS_CLIENT,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def set_thread_name(name: str):
"""
Set the name of the current thread.
"""
try:
import prctl
prctl.set_name(name) # pylint: disable=no-member
except ImportError:
logger.debug('Unable to set thread name: prctl module is missing')
2019-02-03 17:43:30 +01:00
def find_bins_in_path(bin_name):
"""
Search for a binary in the PATH variable.
"""
return [
os.path.join(p, bin_name)
for p in os.environ.get('PATH', '').split(':')
if os.path.isfile(os.path.join(p, bin_name))
and (os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK))
]
2017-12-18 01:10:51 +01:00
def find_files_by_ext(directory, *exts):
"""
Finds all the files in the given directory with the provided extensions.
"""
if not exts:
raise AttributeError('No extensions provided')
if not os.path.isdir(directory):
raise AttributeError(f'{directory} is not a valid directory')
min_len = len(min(exts, key=len))
max_len = len(max(exts, key=len))
result = []
for _, __, files in os.walk(directory):
for i in range(min_len, max_len + 1):
result += [f for f in files if f[-i:] in exts]
return result
def is_process_alive(pid: int) -> bool:
"""
:param pid: Process ID.
:return: True if the process with the given PID is alive.
"""
try:
os.kill(pid, 0)
return True
except OSError:
return False
def get_ip_or_hostname() -> str:
"""
Get the the default IP address or hostname of the machine.
"""
ip = socket.gethostbyname(socket.gethostname())
if ip.startswith('127.') or ip.startswith('::1'):
2021-01-14 00:15:35 +01:00
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.connect(('10.255.255.255', 1))
ip = sock.getsockname()[0]
sock.close()
2021-04-05 00:58:44 +02:00
except Exception as e:
logger.debug(e)
2021-01-14 00:15:35 +01:00
return ip
def get_mime_type(resource: str) -> Optional[str]:
"""
Get the MIME type of the given resource.
:param resource: The resource to get the MIME type for - it can be a file
path or a URL.
"""
import magic
if resource.startswith('file://'):
offset = len('file://')
resource = resource[offset:]
# noinspection HttpUrlsUsage
if resource.startswith('http://') or resource.startswith('https://'):
with urllib.request.urlopen(resource) as response:
return response.info().get_content_type()
else:
if hasattr(magic, 'detect_from_filename'):
mime = magic.detect_from_filename(resource)
elif hasattr(magic, 'from_file'):
mime = magic.from_file(resource, mime=True)
else:
raise RuntimeError(
'The installed magic version provides neither detect_from_filename nor from_file'
)
2019-06-21 02:13:14 +02:00
if mime:
2021-01-14 00:15:35 +01:00
return mime.mime_type if hasattr(mime, 'mime_type') else mime
return None
def camel_case_to_snake_case(string):
"""
Utility function to convert CamelCase to snake_case.
"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def grouper(n, iterable, fillvalue=None):
"""
Split an iterable in groups of max N elements.
grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
"""
from itertools import zip_longest
args = [iter(iterable)] * n
if fillvalue:
return zip_longest(*args, fillvalue=fillvalue)
for chunk in zip_longest(*args):
yield filter(None, chunk)
def is_functional_procedure(obj) -> bool:
"""
Check if the given object is a functional procedure.
"""
return callable(obj) and hasattr(obj, 'procedure')
def is_functional_hook(obj) -> bool:
"""
Check if the given object is a functional hook.
"""
return callable(obj) and hasattr(obj, 'hook')
def is_functional_cron(obj) -> bool:
"""
Check if the given object is a functional cron.
"""
return callable(obj) and hasattr(obj, 'cron') and hasattr(obj, 'cron_expression')
def run(action, *args, **kwargs):
"""
Run the given action with the given arguments. Example:
>>> from platypush.utils import run
>>> run('music.mpd.play', resource='file:///home/user/music.mp3')
"""
from platypush.context import get_plugin
(module_name, method_name) = get_module_and_method_from_action(action)
plugin = get_plugin(module_name)
method = getattr(plugin, method_name)
response = method(*args, **kwargs)
if response.errors:
raise RuntimeError(response.errors[0])
return response.output
def generate_rsa_key_pair(
key_file: Optional[str] = None, size: int = 2048
) -> Tuple[PublicKey, PrivateKey]:
"""
Generate an RSA key pair.
:param key_file: Target file for the private key (the associated public key will be stored in ``<key_file>.pub``.
If no key file is specified then the public and private keys will be returned in ASCII format in a dictionary
with the following structure:
.. code-block:: json
{
"private": "private key here",
"public": "public key here"
}
:param size: Key size (default: 2048 bits).
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
"""
logger.info('Generating RSA keypair')
pub_key, priv_key = newkeys(size)
logger.info('Generated RSA keypair')
public_key_str = pub_key.save_pkcs1('PEM').decode()
private_key_str = priv_key.save_pkcs1('PEM').decode()
if key_file:
logger.info('Saving private key to %s', key_file)
with open(os.path.expanduser(key_file), 'w') as f1, open(
os.path.expanduser(key_file) + '.pub', 'w'
) as f2:
f1.write(private_key_str)
f2.write(public_key_str)
os.chmod(key_file, 0o600)
return pub_key, priv_key
def get_or_generate_jwt_rsa_key_pair():
"""
Get or generate a JWT RSA key pair.
"""
from platypush.config import Config
key_dir = os.path.join(Config.get('workdir'), 'jwt')
priv_key_file = os.path.join(key_dir, 'id_rsa')
pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2:
return (
PublicKey.load_pkcs1(f1.read().encode()),
PrivateKey.load_pkcs1(f2.read().encode()),
)
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048)
def get_enabled_plugins() -> dict:
"""
Get the enabled plugins.
:return: A dictionary with the enabled plugins, in the format ``name`` ->
:class:`platypush.plugins.Plugin` instance.
"""
from platypush.config import Config
from platypush.context import get_plugin
plugins = {}
for name in Config.get_plugins():
try:
plugin = get_plugin(name)
if plugin:
plugins[name] = plugin
2021-04-05 00:58:44 +02:00
except Exception as e:
logger.warning('Could not initialize plugin %s', name)
logger.exception(e)
return plugins
2023-02-05 22:00:50 +01:00
def get_redis(*args, **kwargs) -> Redis:
"""
Get a Redis client on the basis of the Redis configuration.
The Redis configuration can be loaded from:
1. The ``backend.redis`` configuration (``redis_args`` attribute)
2. The ``redis`` plugin.
"""
from platypush.config import Config
2023-02-05 22:00:50 +01:00
if not (args or kwargs):
kwargs = (
(Config.get('backend.redis') or {}).get('redis_args', {})
or Config.get('redis')
or {}
)
2023-02-05 22:00:50 +01:00
return Redis(*args, **kwargs)
def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
"""
Utility function to convert a datetime/timestamp provided as a
string/integer/float/datetime to a ``datetime.datetime`` instance.
"""
if isinstance(t, (int, float)):
return datetime.datetime.fromtimestamp(t, tz=tz.tzutc())
if isinstance(t, str):
return parser.parse(t)
return t
2019-02-03 17:43:30 +01:00
# vim:sw=4:ts=4:et: