# Forked from platypush/platypush
# Fabio Manganiello, 26ffc0b0e1:
# "This is particularly useful when we want to access the registry from
# another process, like the web server or an external script."
# 499 lines, 14 KiB, Python
import ast
|
|
import contextlib
|
|
import datetime
|
|
import hashlib
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import signal
|
|
import socket
|
|
import ssl
|
|
import urllib.request
|
|
from typing import Optional, Tuple, Union
|
|
|
|
from dateutil import parser, tz
|
|
from redis import Redis
|
|
|
|
# Module-level logger used by the helper functions in this file.
logger = logging.getLogger('utils')
|
|
|
|
|
|
def get_module_and_method_from_action(action):
    """
    Split a fully-qualified action name into its module and method parts.

    Example: ``"music.mpd.play"`` -> ``("music.mpd", "play")``.
    An action without any dot yields an empty module name.
    """
    module_name, _, method_name = action.rpartition('.')
    return module_name, method_name
|
|
|
|
|
|
def get_message_class_by_type(msgtype):
    """
    Resolve a message type name (e.g. ``"request"``) to its class.

    The class is looked up in ``platypush.message.<msgtype>`` and is expected
    to be named after the type with its first letter upper-cased.

    :raises RuntimeError: If the module cannot be imported or does not define
        the expected class.
    """
    module_path = 'platypush.message.' + msgtype
    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        logger.warning('Unsupported message type {}'.format(msgtype))
        raise RuntimeError(err)

    expected_cls = msgtype[0].upper() + msgtype[1:]
    try:
        return getattr(module, expected_cls)
    except AttributeError as err:
        logger.warning('No such class in {}: {}'.format(module.__name__, expected_cls))
        raise RuntimeError(err)
|
|
|
|
|
|
# noinspection PyShadowingBuiltins
def get_event_class_by_type(type):
    """
    Import and return an event class given its fully-qualified dotted name
    (module path followed by the class name).
    """
    module_path, _, class_name = type.rpartition('.')
    return getattr(importlib.import_module(module_path), class_name)
|
|
|
|
|
|
def get_plugin_module_by_name(plugin_name):
    """
    Import the module that implements a plugin, given the plugin's dotted
    name (e.g. ``"music.mpd"`` or ``"media.vlc"``).

    :return: The imported ``platypush.plugins.<plugin_name>`` module, or
        ``None`` if it could not be imported (the error is logged).
    """
    module_name = 'platypush.plugins.' + plugin_name
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        logger.error('Cannot import {}: {}'.format(module_name, str(err)))
        return None
    return module
|
|
|
|
|
|
def get_plugin_class_by_name(plugin_name):
    """
    Get the class of a plugin by name (e.g. "music.mpd" or "media.vlc").

    The class name is derived from the plugin name by capitalizing each
    dot-separated token and appending ``Plugin`` (``music.mpd`` ->
    ``MusicMpdPlugin``).

    :return: The plugin class, or ``None`` if either the module or the class
        cannot be found (the failure is logged).
    """
    module = get_plugin_module_by_name(plugin_name)
    if not module:
        return None

    # Build the expected class *name* first and only look it up inside the
    # try block, so that a missing class is logged and reported as None
    # instead of escaping as an uncaught AttributeError.
    class_name = ''.join(token.capitalize() for token in plugin_name.split('.')) + 'Plugin'
    try:
        return getattr(module, class_name)
    except AttributeError as e:
        logger.error('Cannot import class {}: {}'.format(class_name, str(e)))
        return None
|
|
|
|
|
|
def get_plugin_name_by_class(plugin) -> Optional[str]:
    """
    Get the common name of a plugin (e.g. "music.mpd" or "media.vlc") from
    its class or from one of its instances.

    The class name is split before each upper-case letter, the ``Plugin``
    token is dropped and the remaining tokens are lower-cased and joined with
    dots (``MusicMpdPlugin`` -> ``music.mpd``).
    """
    from platypush.plugins import Plugin

    cls = plugin.__class__ if isinstance(plugin, Plugin) else plugin
    spaced_name = re.sub(r'([A-Z])', r' \1', cls.__name__)

    tokens = []
    for token in spaced_name.split(' '):
        if not token.strip() or token == 'Plugin':
            continue
        tokens.append(token.lower())

    return '.'.join(tokens)
|
|
|
|
|
|
def get_backend_name_by_class(backend) -> Optional[str]:
    """
    Get the common name of a backend (e.g. "http" or "mqtt") from its class
    or from one of its instances.

    The class name is split before each upper-case letter, the ``Backend``
    token is dropped and the remaining tokens are lower-cased and joined with
    dots.
    """
    from platypush.backend import Backend

    if isinstance(backend, Backend):
        backend = backend.__class__

    words = re.sub(r'([A-Z])', r' \1', backend.__name__).split(' ')
    return '.'.join(w.lower() for w in words if w.strip() and w != 'Backend')
|
|
|
|
|
|
def set_timeout(seconds, on_timeout):
    """
    Schedule ``on_timeout`` to be called after ``seconds`` unless the timeout
    is cleared first (see ``clear_timeout``).

    It relies on ``SIGALRM``/``signal.alarm``, so it only works on the main
    thread.

    :param seconds: Timeout in seconds.
    :param on_timeout: Zero-argument callable invoked when the timeout expires.
    """
    signal.signal(signal.SIGALRM, lambda *_: on_timeout())
    signal.alarm(seconds)
|
|
|
|
|
|
def clear_timeout():
    """
    Cancel any pending timeout previously scheduled with ``set_timeout``.

    Passing ``0`` to ``signal.alarm`` cancels the pending alarm.
    """
    signal.alarm(0)
|
|
|
|
|
|
def get_hash(s):
    """Return the hex-encoded SHA-256 digest of the UTF-8 encoding of ``s``."""
    digest = hashlib.sha256()
    digest.update(s.encode('utf-8'))
    return digest.hexdigest()
|
|
|
|
|
|
def get_decorators(cls, climb_class_hierarchy=False):
    """
    Get the decorators of a class as a ``{"decorator_name": {method names}}``
    dictionary.

    Only decorators applied to ``def`` methods of the class body are
    collected; for a call-style decorator such as ``@action(...)`` the name of
    the called object is used.

    :param cls: Class type
    :param climb_class_hierarchy: If set to True (default: False), the
        decorators defined in the parent classes (as returned by
        ``inspect.getmro``) are collected as well.
    :type climb_class_hierarchy: bool
    :return: A dict mapping each decorator name to the set of names of the
        methods it decorates. Classes whose source cannot be retrieved (e.g.
        built-ins such as ``object``) are skipped.
    """
    import textwrap

    decorators = {}

    # noinspection PyPep8Naming
    def visit_FunctionDef(node):
        for n in node.decorator_list:
            if isinstance(n, ast.Call):
                # noinspection PyUnresolvedReferences
                name = n.func.attr if isinstance(n.func, ast.Attribute) else n.func.id
            else:
                name = n.attr if isinstance(n, ast.Attribute) else n.id

            decorators.setdefault(name, set()).add(node.name)

    targets = inspect.getmro(cls) if climb_class_hierarchy else [cls]

    node_iter = ast.NodeVisitor()
    node_iter.visit_FunctionDef = visit_FunctionDef

    for target in targets:
        try:
            source = inspect.getsource(target)
        except (TypeError, OSError):
            # TypeError: built-in classes (e.g. ``object``) have no Python
            # source; OSError: the source file cannot be retrieved.
            continue

        # Classes defined inside a function or another class are returned
        # indented, which ``ast.parse`` would reject: dedent them first.
        node_iter.visit(ast.parse(textwrap.dedent(source)))

    return decorators
|
|
|
|
|
|
def get_redis_queue_name_by_message(msg):
    """
    Get the name of the Redis response queue associated to a message
    (``platypush/responses/<msg.id>``).

    :param msg: A :class:`platypush.message.Message` instance. Other objects
        trigger a warning.
    :return: The queue name, or ``None`` if the object has no (truthy) ``id``.
    """
    from platypush.message import Message

    if not isinstance(msg, Message):
        logger.warning('Not a valid message (type: {}): {}'.format(type(msg), msg))

    # Objects that failed the check above may not have an ``id`` at all:
    # return None for them instead of raising AttributeError.
    msg_id = getattr(msg, 'id', None)
    return 'platypush/responses/{}'.format(msg_id) if msg_id else None
|
|
|
|
|
|
def _get_ssl_context(
    context_type=None, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
    """
    Build an :class:`ssl.SSLContext`.

    :param context_type: Protocol constant used to build the context (e.g.
        ``ssl.PROTOCOL_TLS_SERVER`` or ``ssl.PROTOCOL_TLS_CLIENT``). If not
        set, ``ssl.create_default_context`` is used instead.
    :param ssl_cert: Path to the certificate file (``~`` is expanded). If not
        set, no certificate chain is loaded.
    :param ssl_key: Path to the private key file (``~`` is expanded).
    :param ssl_cafile: Path to a CA bundle file used for verification.
    :param ssl_capath: Path to a directory of CA certificates.
    """
    if not context_type:
        # create_default_context already loads cafile/capath by itself.
        ssl_context = ssl.create_default_context(cafile=ssl_cafile, capath=ssl_capath)
    else:
        # Honour the requested protocol (a server context used to be built
        # even when a client context was requested).
        ssl_context = ssl.SSLContext(context_type)
        if ssl_cafile or ssl_capath:
            ssl_context.load_verify_locations(cafile=ssl_cafile, capath=ssl_capath)

    # A certificate is optional (e.g. plain client contexts): only load the
    # chain when one is provided, instead of failing on a None path.
    if ssl_cert:
        ssl_context.load_cert_chain(
            certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
            keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None,
        )

    return ssl_context
|
|
|
|
|
|
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None):
    """
    Build a general-purpose SSL context (no explicit protocol type) from the
    given certificate, key and CA locations.
    """
    options = dict(
        ssl_cert=ssl_cert,
        ssl_key=ssl_key,
        ssl_cafile=ssl_cafile,
        ssl_capath=ssl_capath,
    )
    return _get_ssl_context(context_type=None, **options)
|
|
|
|
|
|
def get_ssl_server_context(
    ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
    """
    Build an SSL context requesting the ``ssl.PROTOCOL_TLS_SERVER`` protocol
    from the given certificate, key and CA locations.
    """
    options = dict(
        ssl_cert=ssl_cert,
        ssl_key=ssl_key,
        ssl_cafile=ssl_cafile,
        ssl_capath=ssl_capath,
    )
    return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER, **options)
|
|
|
|
|
|
def get_ssl_client_context(
    ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
    """
    Build an SSL context requesting the ``ssl.PROTOCOL_TLS_CLIENT`` protocol
    from the given certificate, key and CA locations.
    """
    options = dict(
        ssl_cert=ssl_cert,
        ssl_key=ssl_key,
        ssl_cafile=ssl_cafile,
        ssl_capath=ssl_capath,
    )
    return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT, **options)
|
|
|
|
|
|
def set_thread_name(name):
    """
    Set the thread name through ``prctl.set_name``.

    ``prctl`` is an optional dependency: if it cannot be imported, a debug
    message is logged and nothing else happens.
    """
    # Only reading the module-level logger here, so no ``global`` is needed.
    try:
        import prctl

        # noinspection PyUnresolvedReferences
        prctl.set_name(name)
    except ImportError:
        logger.debug('Unable to set thread name: prctl module is missing')
|
|
|
|
|
|
def find_bins_in_path(bin_name):
    """
    Find all the executables named ``bin_name`` in the directories listed in
    the ``PATH`` environment variable.

    :param bin_name: Name of the executable to look for.
    :return: The list of matching paths, in ``PATH`` order. On Windows
        (``os.name == 'nt'``) the executable permission is not checked.
    """
    # Use the platform's PATH separator (':' on POSIX, ';' on Windows)
    # instead of a hard-coded ':'.
    candidates = (
        os.path.join(directory, bin_name)
        for directory in os.environ.get('PATH', '').split(os.pathsep)
    )
    return [
        path
        for path in candidates
        if os.path.isfile(path) and (os.name == 'nt' or os.access(path, os.X_OK))
    ]
|
|
|
|
|
|
def find_files_by_ext(directory, *exts):
    """
    Find all the files under ``directory`` (searched recursively) whose name
    ends with one of the provided extensions.

    :param directory: Root directory to scan.
    :param exts: One or more extensions/suffixes (e.g. ``'.mp3'``,
        ``'.tar.gz'``). Empty strings are ignored.
    :return: The list of matching file *names* (base names, without their
        directory). Each file is reported once, even if it matches more than
        one of the extensions.
    :raises AttributeError: If no extensions are provided or ``directory`` is
        not a directory.
    """
    if not exts:
        raise AttributeError('No extensions provided')

    if not os.path.isdir(directory):
        raise AttributeError('{} is not a valid directory'.format(directory))

    # str.endswith accepts a tuple of suffixes: one check per file, so a file
    # matching e.g. both '.gz' and '.tar.gz' is no longer listed twice.
    # Empty suffixes are dropped since endswith('') would match every file.
    suffixes = tuple(ext for ext in exts if ext)
    return [
        f
        for _, __, files in os.walk(directory)
        for f in files
        if f.endswith(suffixes)
    ]
|
|
|
|
|
|
def is_process_alive(pid):
    """
    Check whether a process with the given PID exists.

    The check sends signal ``0`` to the process through ``os.kill``.

    :return: ``True`` if the process exists (including processes owned by
        another user, which we are not allowed to signal), ``False``
        otherwise.
    """
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists, we just lack permission to signal it.
        return True
    except OSError:
        return False
    return True
|
|
|
|
|
|
def get_ip_or_hostname():
    """
    Get the IP address of the local host.

    The address the hostname resolves to is returned, unless it is a loopback
    address. In that case, the local address of a UDP socket connected
    towards ``10.255.255.255`` is returned instead. If that detection fails,
    the error is logged at debug level and the loopback address is returned.
    """
    ip = socket.gethostbyname(socket.gethostname())
    if ip.startswith(('127.', '::1')):
        try:
            # The context manager closes the socket even when connect() fails.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.connect(('10.255.255.255', 1))
                ip = sock.getsockname()[0]
        except Exception as e:
            logger.debug(e)

    return ip
|
|
|
|
|
|
def get_mime_type(resource):
    """
    Get the MIME type of a resource.

    :param resource: An ``http://``/``https://`` URL, a ``file://`` URL or a
        local file path.
    :return: For HTTP(S) URLs, the content type reported by the server; for
        files, the type detected by the installed ``magic`` module, or
        ``None`` if detection yields nothing.
    :raises RuntimeError: If the installed ``magic`` module exposes neither
        ``detect_from_filename`` nor ``from_file``.
    """
    import magic

    file_prefix = 'file://'
    if resource.startswith(file_prefix):
        resource = resource[len(file_prefix):]

    # noinspection HttpUrlsUsage
    if resource.startswith(('http://', 'https://')):
        with urllib.request.urlopen(resource) as response:
            return response.info().get_content_type()

    # Support both the libmagic bindings (detect_from_filename) and the
    # python-magic API (from_file).
    if hasattr(magic, 'detect_from_filename'):
        mime = magic.detect_from_filename(resource)
    elif hasattr(magic, 'from_file'):
        mime = magic.from_file(resource, mime=True)
    else:
        raise RuntimeError(
            'The installed magic version provides neither detect_from_filename nor from_file'
        )

    if not mime:
        return None
    return getattr(mime, 'mime_type', mime)
|
|
|
|
|
|
def camel_case_to_snake_case(string):
    """
    Convert a CamelCase / camelCase identifier to snake_case
    (e.g. ``"HTTPResponseCode"`` -> ``"http_response_code"``).
    """
    # First insert an underscore before every capitalized word, then between
    # a lower-case letter/digit and a following upper-case letter.
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return snake.lower()
|
|
|
|
|
|
def grouper(n, iterable, fillvalue=None):
    """
    Split an iterable in groups of max N elements.

    grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
    grouper(3, 'ABCDEFG') --> ABC DEF G

    :param n: Maximum size of each group.
    :param iterable: The iterable to split.
    :param fillvalue: If not ``None``, the last group is padded with this value
        up to ``n`` elements; otherwise the last group is left shorter.
    :return: An iterator over tuples of at most ``n`` elements.
    """
    from itertools import zip_longest

    # The same iterator repeated n times: zip_longest pulls n consecutive
    # items for each output tuple.
    args = [iter(iterable)] * n

    # This function is deliberately not a generator: a ``return`` inside a
    # generator would silently produce an empty sequence.
    if fillvalue is not None:
        return zip_longest(*args, fillvalue=fillvalue)

    # Pad with a private sentinel and strip only that, so that legitimate
    # falsy items (0, '', None, ...) are preserved.
    pad = object()
    return (
        tuple(item for item in chunk if item is not pad)
        for chunk in zip_longest(*args, fillvalue=pad)
    )
|
|
|
|
|
|
def is_functional_procedure(obj) -> bool:
    """Tell whether ``obj`` is a callable carrying a ``procedure`` attribute."""
    if not callable(obj):
        return False
    return hasattr(obj, 'procedure')
|
|
|
|
|
|
def is_functional_hook(obj) -> bool:
    """Tell whether ``obj`` is a callable carrying a ``hook`` attribute."""
    if not callable(obj):
        return False
    return hasattr(obj, 'hook')
|
|
|
|
|
|
def is_functional_cron(obj) -> bool:
    """
    Tell whether ``obj`` is a callable carrying both the ``cron`` and the
    ``cron_expression`` attributes.
    """
    if not callable(obj):
        return False
    return all(hasattr(obj, attr) for attr in ('cron', 'cron_expression'))
|
|
|
|
|
|
def run(action, *args, **kwargs):
    """
    Run a plugin action by its fully-qualified name (e.g. ``music.mpd.play``)
    and return its output.

    :param action: Action name in the ``<plugin>.<method>`` form.
    :param args: Positional arguments forwarded to the action.
    :param kwargs: Keyword arguments forwarded to the action.
    :return: The ``output`` of the response returned by the action.
    :raises RuntimeError: With the first of the response's errors, if any.
    """
    from platypush.context import get_plugin

    plugin_name, method_name = get_module_and_method_from_action(action)
    action_method = getattr(get_plugin(plugin_name), method_name)
    response = action_method(*args, **kwargs)

    errors = response.errors
    if errors:
        raise RuntimeError(errors[0])
    return response.output
|
|
|
|
|
|
def generate_rsa_key_pair(
    key_file: Optional[str] = None, size: int = 2048
) -> Tuple[str, str]:
    """
    Generate an RSA key pair.

    :param key_file: Target file for the private key (``~`` is expanded). If
        set, the private key is written to it with ``0600`` permissions and
        the public key to ``<key_file>.pub``. If not set, the keys are only
        returned.
    :param size: Key size (default: 2048 bits).
    :return: A tuple with the generated ``(pub_key_str, priv_key_str)``, both
        PEM-encoded (public key in PKCS#1 format, private key in
        TraditionalOpenSSL format, unencrypted).
    """
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend

    public_exp = 65537
    logger.info('Generating RSA {} key pair'.format(size))
    private_key = rsa.generate_private_key(
        public_exponent=public_exp, key_size=size, backend=default_backend()
    )

    private_key_str = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode()

    public_key_str = (
        private_key.public_key()
        .public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.PKCS1,
        )
        .decode()
    )

    if key_file:
        # Expand once and use the same path everywhere (chmod used to receive
        # the unexpanded path, which fails for '~/...' paths).
        priv_key_path = os.path.expanduser(key_file)
        logger.info('Saving private key to {}'.format(key_file))

        # Create the private key file as 0600 right away, so the key is
        # never readable by other users, not even briefly.
        fd = os.open(priv_key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as f1:
            f1.write(private_key_str)
        # The mode passed to os.open does not apply to pre-existing files.
        os.chmod(priv_key_path, 0o600)

        with open(priv_key_path + '.pub', 'w') as f2:
            f2.write(public_key_str)

    return public_key_str, private_key_str
|
|
|
|
|
|
def get_or_generate_jwt_rsa_key_pair():
    """
    Get the RSA key pair stored under ``<workdir>/jwt`` (``id_rsa`` and
    ``id_rsa.pub``), generating a new 2048-bit pair there if either file is
    missing.

    :return: A ``(pub_key_str, priv_key_str)`` tuple.
    """
    from platypush.config import Config

    key_dir = os.path.join(Config.get('workdir'), 'jwt')
    priv_key_file = os.path.join(key_dir, 'id_rsa')
    pub_key_file = priv_key_file + '.pub'

    keys_exist = all(os.path.isfile(path) for path in (priv_key_file, pub_key_file))
    if not keys_exist:
        pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
        return generate_rsa_key_pair(priv_key_file, size=2048)

    with open(pub_key_file, 'r') as pub_f, open(priv_key_file, 'r') as priv_f:
        return pub_f.read(), priv_f.read()
|
|
|
|
|
|
def get_enabled_plugins() -> dict:
    """
    Get the initialized instances of the plugins listed in the configuration.

    Plugins that raise while being initialized are logged and skipped, and so
    are those for which no (truthy) instance is returned.

    :return: A ``{plugin_name: plugin_instance}`` dictionary.
    """
    from platypush.config import Config
    from platypush.context import get_plugin

    enabled = {}
    for plugin_name in Config.get_plugins():
        try:
            instance = get_plugin(plugin_name)
            if instance:
                enabled[plugin_name] = instance
        except Exception as e:
            logger.warning(f'Could not initialize plugin {plugin_name}')
            logger.exception(e)

    return enabled
|
|
|
|
|
|
def get_redis() -> Redis:
    """
    Get a Redis client.

    Its connection arguments are taken from ``redis_args`` in the
    ``backend.redis`` configuration. If those are not set, the top-level
    ``redis`` configuration section is used, and the client defaults apply
    when neither is set.
    """
    from platypush.config import Config

    backend_conf = Config.get('backend.redis') or {}
    redis_args = backend_conf.get('redis_args', {}) or Config.get('redis') or {}
    return Redis(**redis_args)
|
|
|
|
|
|
def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
    """
    Normalize a time value to a :class:`datetime.datetime`.

    - Strings are parsed with ``dateutil.parser.parse``.
    - Numbers are treated as timestamps and converted to UTC-aware datetimes.
    - Anything else is returned unchanged.
    """
    if isinstance(t, str):
        return parser.parse(t)
    if isinstance(t, (int, float)):
        return datetime.datetime.fromtimestamp(t, tz=tz.tzutc())
    return t
|
|
|
|
|
|
# vim:sw=4:ts=4:et:
|