forked from platypush/platypush
441 lines
13 KiB
Python
441 lines
13 KiB
Python
import ast
|
|
import hashlib
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import signal
|
|
import socket
|
|
import ssl
|
|
import urllib.request
|
|
from typing import Optional, Tuple
|
|
|
|
logger = logging.getLogger('utils')
|
|
|
|
|
|
def get_module_and_method_from_action(action):
|
|
""" Input : action=music.mpd.play
|
|
Output : ('music.mpd', 'play') """
|
|
|
|
tokens = action.split('.')
|
|
module_name = str.join('.', tokens[:-1])
|
|
method_name = tokens[-1:][0]
|
|
return module_name, method_name
|
|
|
|
|
|
def get_message_class_by_type(msgtype):
|
|
""" Gets the class of a message type given as string """
|
|
|
|
try:
|
|
module = importlib.import_module('platypush.message.' + msgtype)
|
|
except ImportError as e:
|
|
logger.warning('Unsupported message type {}'.format(msgtype))
|
|
raise RuntimeError(e)
|
|
|
|
cls_name = msgtype[0].upper() + msgtype[1:]
|
|
|
|
try:
|
|
msgclass = getattr(module, cls_name)
|
|
except AttributeError as e:
|
|
logger.warning('No such class in {}: {}'.format(
|
|
module.__name__, cls_name))
|
|
raise RuntimeError(e)
|
|
|
|
return msgclass
|
|
|
|
|
|
# noinspection PyShadowingBuiltins
|
|
def get_event_class_by_type(type):
|
|
""" Gets an event class by type name """
|
|
event_module = importlib.import_module('.'.join(type.split('.')[:-1]))
|
|
return getattr(event_module, type.split('.')[-1])
|
|
|
|
|
|
def get_plugin_module_by_name(plugin_name):
|
|
""" Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc") """
|
|
|
|
module_name = 'platypush.plugins.' + plugin_name
|
|
try:
|
|
return importlib.import_module('platypush.plugins.' + plugin_name)
|
|
except ImportError as e:
|
|
logger.error('Cannot import {}: {}'.format(module_name, str(e)))
|
|
return None
|
|
|
|
|
|
def get_plugin_class_by_name(plugin_name):
|
|
""" Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc") """
|
|
|
|
module = get_plugin_module_by_name(plugin_name)
|
|
if not module:
|
|
return
|
|
|
|
class_name = getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin')
|
|
try:
|
|
return getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin')
|
|
except Exception as e:
|
|
logger.error('Cannot import class {}: {}'.format(class_name, str(e)))
|
|
return None
|
|
|
|
|
|
def get_plugin_name_by_class(plugin) -> Optional[str]:
|
|
"""Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class. """
|
|
|
|
from platypush.plugins import Plugin
|
|
|
|
if isinstance(plugin, Plugin):
|
|
plugin = plugin.__class__
|
|
|
|
class_name = plugin.__name__
|
|
class_tokens = [
|
|
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
|
if token.strip() and token != 'Plugin'
|
|
]
|
|
|
|
return '.'.join(class_tokens)
|
|
|
|
|
|
def get_backend_name_by_class(backend) -> Optional[str]:
|
|
"""Gets the common name of a backend (e.g. "http" or "mqtt") given its class. """
|
|
|
|
from platypush.backend import Backend
|
|
|
|
if isinstance(backend, Backend):
|
|
backend = backend.__class__
|
|
|
|
class_name = backend.__name__
|
|
class_tokens = [
|
|
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
|
if token.strip() and token != 'Backend'
|
|
]
|
|
|
|
return '.'.join(class_tokens)
|
|
|
|
|
|
def set_timeout(seconds, on_timeout):
|
|
"""
|
|
Set a function to be called if timeout expires without being cleared.
|
|
It only works on the main thread.
|
|
|
|
Params:
|
|
seconds -- Timeout in seconds
|
|
on_timeout -- Function invoked on timeout unless clear_timeout is called before
|
|
"""
|
|
|
|
def _sighandler(*_):
|
|
on_timeout()
|
|
|
|
signal.signal(signal.SIGALRM, _sighandler)
|
|
signal.alarm(seconds)
|
|
|
|
|
|
def clear_timeout():
|
|
""" Clear any previously set timeout """
|
|
signal.alarm(0)
|
|
|
|
|
|
def get_hash(s):
|
|
""" Get the SHA256 hash hexdigest of a string input """
|
|
return hashlib.sha256(s.encode('utf-8')).hexdigest()
|
|
|
|
|
|
def get_decorators(cls, climb_class_hierarchy=False):
|
|
"""
|
|
Get the decorators of a class as a {"decorator_name": [list of methods]} dictionary
|
|
:param climb_class_hierarchy: If set to True (default: False), it will search return the decorators in the parent classes as well
|
|
:type climb_class_hierarchy: bool
|
|
"""
|
|
|
|
decorators = {}
|
|
|
|
def visit_FunctionDef(node):
|
|
for n in node.decorator_list:
|
|
if isinstance(n, ast.Call):
|
|
name = n.func.attr if isinstance(n.func, ast.Attribute) else n.func.id
|
|
else:
|
|
name = n.attr if isinstance(n, ast.Attribute) else n.id
|
|
|
|
decorators[name] = decorators.get(name, set())
|
|
decorators[name].add(node.name)
|
|
|
|
if climb_class_hierarchy:
|
|
targets = inspect.getmro(cls)
|
|
else:
|
|
targets = [cls]
|
|
|
|
node_iter = ast.NodeVisitor()
|
|
node_iter.visit_FunctionDef = visit_FunctionDef
|
|
|
|
for target in targets:
|
|
try:
|
|
node_iter.visit(ast.parse(inspect.getsource(target)))
|
|
except TypeError:
|
|
# Ignore built-in classes
|
|
pass
|
|
|
|
return decorators
|
|
|
|
|
|
def get_redis_queue_name_by_message(msg):
|
|
from platypush.message import Message
|
|
|
|
if not isinstance(msg, Message):
|
|
logger.warning('Not a valid message (type: {}): {}'.format(type(msg), msg))
|
|
|
|
return 'platypush/responses/{}'.format(msg.id) if msg.id else None
|
|
|
|
|
|
def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None,
|
|
ssl_cafile=None, ssl_capath=None):
|
|
if not context_type:
|
|
ssl_context = ssl.create_default_context(cafile=ssl_cafile,
|
|
capath=ssl_capath)
|
|
else:
|
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
|
|
|
if ssl_cafile or ssl_capath:
|
|
ssl_context.load_verify_locations(
|
|
cafile=ssl_cafile, capath=ssl_capath)
|
|
|
|
ssl_context.load_cert_chain(
|
|
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
|
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None
|
|
)
|
|
|
|
return ssl_context
|
|
|
|
|
|
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
|
ssl_capath=None):
|
|
return _get_ssl_context(context_type=None,
|
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
|
|
|
|
|
def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
|
ssl_capath=None):
|
|
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER,
|
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
|
|
|
|
|
def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
|
ssl_capath=None):
|
|
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT,
|
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
|
|
|
def set_thread_name(name):
|
|
global logger
|
|
|
|
try:
|
|
import prctl
|
|
# noinspection PyUnresolvedReferences
|
|
prctl.set_name(name)
|
|
except ImportError:
|
|
logger.debug('Unable to set thread name: prctl module is missing')
|
|
|
|
|
|
def find_bins_in_path(bin_name):
|
|
return [os.path.join(p, bin_name)
|
|
for p in os.environ.get('PATH', '').split(':')
|
|
if os.path.isfile(os.path.join(p, bin_name))
|
|
and (os.name == 'nt' or
|
|
os.access(os.path.join(p, bin_name), os.X_OK))]
|
|
|
|
|
|
def find_files_by_ext(directory, *exts):
|
|
"""
|
|
Finds all the files in the given directory with the provided extensions
|
|
"""
|
|
|
|
if not exts:
|
|
raise AttributeError('No extensions provided')
|
|
|
|
if not os.path.isdir(directory):
|
|
raise AttributeError('{} is not a valid directory'.format(directory))
|
|
|
|
min_len = len(min(exts, key=len))
|
|
max_len = len(max(exts, key=len))
|
|
result = []
|
|
|
|
for root, dirs, files in os.walk(directory):
|
|
for i in range(min_len, max_len+1):
|
|
result += [f for f in files if f[-i:] in exts]
|
|
|
|
return result
|
|
|
|
|
|
def is_process_alive(pid):
|
|
try:
|
|
os.kill(pid, 0)
|
|
return True
|
|
except OSError:
|
|
return False
|
|
|
|
|
|
def get_ip_or_hostname():
|
|
ip = socket.gethostbyname(socket.gethostname())
|
|
if ip.startswith('127.'):
|
|
try:
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
sock.connect(('10.255.255.255', 1))
|
|
ip = sock.getsockname()[0]
|
|
sock.close()
|
|
except Exception as e:
|
|
logger.debug(e)
|
|
|
|
return ip
|
|
|
|
|
|
def get_mime_type(resource):
|
|
import magic
|
|
if resource.startswith('file://'):
|
|
resource = resource[len('file://'):]
|
|
|
|
if resource.startswith('http://') or resource.startswith('https://'):
|
|
with urllib.request.urlopen(resource) as response:
|
|
return response.info().get_content_type()
|
|
else:
|
|
if hasattr(magic, 'detect_from_filename'):
|
|
mime = magic.detect_from_filename(resource)
|
|
elif hasattr(magic, 'from_file'):
|
|
mime = magic.from_file(resource, mime=True)
|
|
else:
|
|
raise RuntimeError('The installed magic version provides neither detect_from_filename nor from_file')
|
|
|
|
if mime:
|
|
return mime.mime_type if hasattr(mime, 'mime_type') else mime
|
|
|
|
|
|
def camel_case_to_snake_case(string):
|
|
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
|
|
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
|
|
|
|
|
def grouper(n, iterable, fillvalue=None):
|
|
"""
|
|
Split an iterable in groups of max N elements.
|
|
grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
|
|
"""
|
|
from itertools import zip_longest
|
|
args = [iter(iterable)] * n
|
|
|
|
if fillvalue:
|
|
return zip_longest(fillvalue=fillvalue, *args)
|
|
|
|
for chunk in zip_longest(*args):
|
|
yield filter(None, chunk)
|
|
|
|
|
|
def is_functional_procedure(obj) -> bool:
|
|
return callable(obj) and hasattr(obj, 'procedure')
|
|
|
|
|
|
def is_functional_hook(obj) -> bool:
|
|
return callable(obj) and hasattr(obj, 'hook')
|
|
|
|
|
|
def is_functional_cron(obj) -> bool:
|
|
return callable(obj) and hasattr(obj, 'cron') and hasattr(obj, 'cron_expression')
|
|
|
|
|
|
def run(action, *args, **kwargs):
|
|
from platypush.context import get_plugin
|
|
(module_name, method_name) = get_module_and_method_from_action(action)
|
|
plugin = get_plugin(module_name)
|
|
method = getattr(plugin, method_name)
|
|
response = method(*args, **kwargs)
|
|
|
|
if response.errors:
|
|
raise RuntimeError(response.errors[0])
|
|
|
|
return response.output
|
|
|
|
|
|
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
|
|
"""
|
|
Generate an RSA key pair.
|
|
|
|
:param key_file: Target file for the private key (the associated public key will be stored in ``<key_file>.pub``.
|
|
If no key file is specified then the public and private keys will be returned in ASCII format in a dictionary
|
|
with the following structure:
|
|
|
|
.. code-block:: json
|
|
|
|
{
|
|
"private": "private key here",
|
|
"public": "public key here"
|
|
}
|
|
|
|
:param size: Key size (default: 2048 bits).
|
|
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
|
|
"""
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from cryptography.hazmat.backends import default_backend
|
|
|
|
public_exp = 65537
|
|
private_key = rsa.generate_private_key(
|
|
public_exponent=public_exp,
|
|
key_size=size,
|
|
backend=default_backend()
|
|
)
|
|
|
|
logger.info('Generating RSA {} key pair'.format(size))
|
|
private_key_str = private_key.private_bytes(
|
|
encoding=serialization.Encoding.PEM,
|
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
|
encryption_algorithm=serialization.NoEncryption()
|
|
).decode()
|
|
|
|
public_key_str = private_key.public_key().public_bytes(
|
|
encoding=serialization.Encoding.PEM,
|
|
format=serialization.PublicFormat.PKCS1,
|
|
).decode()
|
|
|
|
if key_file:
|
|
logger.info('Saving private key to {}'.format(key_file))
|
|
with open(os.path.expanduser(key_file), 'w') as f1, \
|
|
open(os.path.expanduser(key_file) + '.pub', 'w') as f2:
|
|
f1.write(private_key_str)
|
|
f2.write(public_key_str)
|
|
os.chmod(key_file, 0o600)
|
|
|
|
return public_key_str, private_key_str
|
|
|
|
|
|
def get_or_generate_jwt_rsa_key_pair():
|
|
from platypush.config import Config
|
|
|
|
key_dir = os.path.join(Config.get('workdir'), 'jwt')
|
|
priv_key_file = os.path.join(key_dir, 'id_rsa')
|
|
pub_key_file = priv_key_file + '.pub'
|
|
|
|
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
|
with open(pub_key_file, 'r') as f1, \
|
|
open(priv_key_file, 'r') as f2:
|
|
return f1.read(), f2.read()
|
|
|
|
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
|
return generate_rsa_key_pair(priv_key_file, size=2048)
|
|
|
|
|
|
def get_enabled_plugins() -> dict:
|
|
from platypush.config import Config
|
|
from platypush.context import get_plugin
|
|
|
|
plugins = {}
|
|
for name, config in Config.get_plugins().items():
|
|
try:
|
|
plugin = get_plugin(name)
|
|
if plugin:
|
|
plugins[name] = plugin
|
|
except Exception as e:
|
|
logger.debug(f'Could not get plugin {name}: {e}')
|
|
|
|
return plugins
|
|
|
|
|
|
# vim:sw=4:ts=4:et:
|