Use Redis instead of an in-process map to store the entity/plugin registry

This is particularly useful when we want to access the registry from
another process, like the web server or an external script.
This commit is contained in:
Fabio Manganiello 2022-04-07 00:18:11 +02:00
parent 7b1a63e287
commit 26ffc0b0e1
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
3 changed files with 127 additions and 79 deletions

View file

@ -1,15 +1,16 @@
import warnings import warnings
from typing import Collection, Optional from typing import Collection, Optional
from ._base import Entity from ._base import Entity, get_entities_registry
from ._engine import EntitiesEngine from ._engine import EntitiesEngine
from ._registry import manages, register_entity_plugin, get_plugin_registry from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
_engine: Optional[EntitiesEngine] = None _engine: Optional[EntitiesEngine] = None
def init_entities_engine() -> EntitiesEngine: def init_entities_engine() -> EntitiesEngine:
from ._base import init_entities_db from ._base import init_entities_db
global _engine global _engine
init_entities_db() init_entities_db()
_engine = EntitiesEngine() _engine = EntitiesEngine()
@ -24,13 +25,14 @@ def publish_entities(entities: Collection[Entity]):
_engine.post(*entities) _engine.post(*entities)
__all__ = ( __all__ = (
'Entity', 'Entity',
'EntitiesEngine', 'EntitiesEngine',
'init_entities_engine', 'init_entities_engine',
'publish_entities', 'publish_entities',
'register_entity_plugin', 'register_entity_plugin',
'get_plugin_registry', 'get_plugin_entity_registry',
'get_entities_registry',
'manages', 'manages',
) )

View file

@ -1,24 +1,38 @@
import json
from datetime import datetime from datetime import datetime
from typing import Optional, Mapping, Dict, Collection, Type from typing import Optional, Dict, Collection, Type
from platypush.plugins import Plugin from platypush.plugins import Plugin
from platypush.utils import get_plugin_name_by_class from platypush.utils import get_plugin_name_by_class, get_redis
from ._base import Entity from ._base import Entity
_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {} _entity_registry_varname = '_platypush/plugin_entity_registry'
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin): def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
plugins = _entity_plugin_registry.get(entity_type, {}) plugin_name = get_plugin_name_by_class(plugin.__class__) or ''
plugin_name = get_plugin_name_by_class(plugin.__class__) entity_type_name = entity_type.__name__.lower()
assert plugin_name redis = get_redis()
plugins[plugin_name] = plugin registry = get_plugin_entity_registry()
_entity_plugin_registry[entity_type] = plugins registry_by_plugin = set(registry['by_plugin'].get(plugin_name, []))
registry_by_entity_type = set(registry['by_entity_type'].get(entity_type_name, []))
registry_by_plugin.add(entity_type_name)
registry_by_entity_type.add(plugin_name)
registry['by_plugin'][plugin_name] = list(registry_by_plugin)
registry['by_entity_type'][entity_type_name] = list(registry_by_entity_type)
redis.mset({_entity_registry_varname: json.dumps(registry)})
def get_plugin_registry(): def get_plugin_entity_registry() -> Dict[str, Dict[str, Collection[str]]]:
return _entity_plugin_registry.copy() redis = get_redis()
registry = redis.mget([_entity_registry_varname])[0]
try:
return json.loads((registry or b'').decode())
except (TypeError, ValueError):
return {'by_plugin': {}, 'by_entity_type': {}}
class EntityManagerMixin: class EntityManagerMixin:
@ -37,6 +51,7 @@ class EntityManagerMixin:
def publish_entities(self, entities: Optional[Collection[Entity]]): def publish_entities(self, entities: Optional[Collection[Entity]]):
from . import publish_entities from . import publish_entities
entities = self.transform_entities(entities) entities = self.transform_entities(entities)
publish_entities(entities) publish_entities(entities)
@ -59,4 +74,3 @@ def manages(*entities: Type[Entity]):
return plugin return plugin
return wrapper return wrapper

View file

@ -1,4 +1,5 @@
import ast import ast
import contextlib
import datetime import datetime
import hashlib import hashlib
import importlib import importlib
@ -20,8 +21,8 @@ logger = logging.getLogger('utils')
def get_module_and_method_from_action(action): def get_module_and_method_from_action(action):
""" Input : action=music.mpd.play """Input : action=music.mpd.play
Output : ('music.mpd', 'play') """ Output : ('music.mpd', 'play')"""
tokens = action.split('.') tokens = action.split('.')
module_name = str.join('.', tokens[:-1]) module_name = str.join('.', tokens[:-1])
@ -30,7 +31,7 @@ def get_module_and_method_from_action(action):
def get_message_class_by_type(msgtype): def get_message_class_by_type(msgtype):
""" Gets the class of a message type given as string """ """Gets the class of a message type given as string"""
try: try:
module = importlib.import_module('platypush.message.' + msgtype) module = importlib.import_module('platypush.message.' + msgtype)
@ -43,8 +44,7 @@ def get_message_class_by_type(msgtype):
try: try:
msgclass = getattr(module, cls_name) msgclass = getattr(module, cls_name)
except AttributeError as e: except AttributeError as e:
logger.warning('No such class in {}: {}'.format( logger.warning('No such class in {}: {}'.format(module.__name__, cls_name))
module.__name__, cls_name))
raise RuntimeError(e) raise RuntimeError(e)
return msgclass return msgclass
@ -52,13 +52,13 @@ def get_message_class_by_type(msgtype):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
def get_event_class_by_type(type): def get_event_class_by_type(type):
""" Gets an event class by type name """ """Gets an event class by type name"""
event_module = importlib.import_module('.'.join(type.split('.')[:-1])) event_module = importlib.import_module('.'.join(type.split('.')[:-1]))
return getattr(event_module, type.split('.')[-1]) return getattr(event_module, type.split('.')[-1])
def get_plugin_module_by_name(plugin_name): def get_plugin_module_by_name(plugin_name):
""" Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc") """ """Gets the module of a plugin by name (e.g. "music.mpd" or "media.vlc")"""
module_name = 'platypush.plugins.' + plugin_name module_name = 'platypush.plugins.' + plugin_name
try: try:
@ -69,22 +69,26 @@ def get_plugin_module_by_name(plugin_name):
def get_plugin_class_by_name(plugin_name): def get_plugin_class_by_name(plugin_name):
""" Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc") """ """Gets the class of a plugin by name (e.g. "music.mpd" or "media.vlc")"""
module = get_plugin_module_by_name(plugin_name) module = get_plugin_module_by_name(plugin_name)
if not module: if not module:
return return
class_name = getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin') class_name = getattr(
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
)
try: try:
return getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin') return getattr(
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
)
except Exception as e: except Exception as e:
logger.error('Cannot import class {}: {}'.format(class_name, str(e))) logger.error('Cannot import class {}: {}'.format(class_name, str(e)))
return None return None
def get_plugin_name_by_class(plugin) -> Optional[str]: def get_plugin_name_by_class(plugin) -> Optional[str]:
"""Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class. """ """Gets the common name of a plugin (e.g. "music.mpd" or "media.vlc") given its class."""
from platypush.plugins import Plugin from platypush.plugins import Plugin
@ -93,7 +97,8 @@ def get_plugin_name_by_class(plugin) -> Optional[str]:
class_name = plugin.__name__ class_name = plugin.__name__
class_tokens = [ class_tokens = [
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') token.lower()
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
if token.strip() and token != 'Plugin' if token.strip() and token != 'Plugin'
] ]
@ -101,7 +106,7 @@ def get_plugin_name_by_class(plugin) -> Optional[str]:
def get_backend_name_by_class(backend) -> Optional[str]: def get_backend_name_by_class(backend) -> Optional[str]:
"""Gets the common name of a backend (e.g. "http" or "mqtt") given its class. """ """Gets the common name of a backend (e.g. "http" or "mqtt") given its class."""
from platypush.backend import Backend from platypush.backend import Backend
@ -110,7 +115,8 @@ def get_backend_name_by_class(backend) -> Optional[str]:
class_name = backend.__name__ class_name = backend.__name__
class_tokens = [ class_tokens = [
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ') token.lower()
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
if token.strip() and token != 'Backend' if token.strip() and token != 'Backend'
] ]
@ -135,12 +141,12 @@ def set_timeout(seconds, on_timeout):
def clear_timeout(): def clear_timeout():
""" Clear any previously set timeout """ """Clear any previously set timeout"""
signal.alarm(0) signal.alarm(0)
def get_hash(s): def get_hash(s):
""" Get the SHA256 hash hex digest of a string input """ """Get the SHA256 hash hex digest of a string input"""
return hashlib.sha256(s.encode('utf-8')).hexdigest() return hashlib.sha256(s.encode('utf-8')).hexdigest()
@ -177,11 +183,8 @@ def get_decorators(cls, climb_class_hierarchy=False):
node_iter.visit_FunctionDef = visit_FunctionDef node_iter.visit_FunctionDef = visit_FunctionDef
for target in targets: for target in targets:
try: with contextlib.suppress(TypeError):
node_iter.visit(ast.parse(inspect.getsource(target))) node_iter.visit(ast.parse(inspect.getsource(target)))
except TypeError:
# Ignore built-in classes
pass
return decorators return decorators
@ -195,45 +198,57 @@ def get_redis_queue_name_by_message(msg):
return 'platypush/responses/{}'.format(msg.id) if msg.id else None return 'platypush/responses/{}'.format(msg.id) if msg.id else None
def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None, def _get_ssl_context(
ssl_cafile=None, ssl_capath=None): context_type=None, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
):
if not context_type: if not context_type:
ssl_context = ssl.create_default_context(cafile=ssl_cafile, ssl_context = ssl.create_default_context(cafile=ssl_cafile, capath=ssl_capath)
capath=ssl_capath)
else: else:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
if ssl_cafile or ssl_capath: if ssl_cafile or ssl_capath:
ssl_context.load_verify_locations( ssl_context.load_verify_locations(cafile=ssl_cafile, capath=ssl_capath)
cafile=ssl_cafile, capath=ssl_capath)
ssl_context.load_cert_chain( ssl_context.load_cert_chain(
certfile=os.path.abspath(os.path.expanduser(ssl_cert)), certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None,
) )
return ssl_context return ssl_context
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None):
ssl_capath=None): return _get_ssl_context(
return _get_ssl_context(context_type=None, context_type=None,
ssl_cert=ssl_cert, ssl_key=ssl_key, ssl_cert=ssl_cert,
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, def get_ssl_server_context(
ssl_capath=None): ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER, ):
ssl_cert=ssl_cert, ssl_key=ssl_key, return _get_ssl_context(
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) context_type=ssl.PROTOCOL_TLS_SERVER,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, def get_ssl_client_context(
ssl_capath=None): ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT, ):
ssl_cert=ssl_cert, ssl_key=ssl_key, return _get_ssl_context(
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) context_type=ssl.PROTOCOL_TLS_CLIENT,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_cafile=ssl_cafile,
ssl_capath=ssl_capath,
)
def set_thread_name(name): def set_thread_name(name):
@ -241,6 +256,7 @@ def set_thread_name(name):
try: try:
import prctl import prctl
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
prctl.set_name(name) prctl.set_name(name)
except ImportError: except ImportError:
@ -251,9 +267,9 @@ def find_bins_in_path(bin_name):
return [ return [
os.path.join(p, bin_name) os.path.join(p, bin_name)
for p in os.environ.get('PATH', '').split(':') for p in os.environ.get('PATH', '').split(':')
if os.path.isfile(os.path.join(p, bin_name)) and ( if os.path.isfile(os.path.join(p, bin_name))
os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK) and (os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK))
)] ]
def find_files_by_ext(directory, *exts): def find_files_by_ext(directory, *exts):
@ -271,7 +287,7 @@ def find_files_by_ext(directory, *exts):
max_len = len(max(exts, key=len)) max_len = len(max(exts, key=len))
result = [] result = []
for root, dirs, files in os.walk(directory): for _, __, files in os.walk(directory):
for i in range(min_len, max_len + 1): for i in range(min_len, max_len + 1):
result += [f for f in files if f[-i:] in exts] result += [f for f in files if f[-i:] in exts]
@ -302,8 +318,9 @@ def get_ip_or_hostname():
def get_mime_type(resource): def get_mime_type(resource):
import magic import magic
if resource.startswith('file://'): if resource.startswith('file://'):
resource = resource[len('file://'):] resource = resource[len('file://') :]
# noinspection HttpUrlsUsage # noinspection HttpUrlsUsage
if resource.startswith('http://') or resource.startswith('https://'): if resource.startswith('http://') or resource.startswith('https://'):
@ -315,7 +332,9 @@ def get_mime_type(resource):
elif hasattr(magic, 'from_file'): elif hasattr(magic, 'from_file'):
mime = magic.from_file(resource, mime=True) mime = magic.from_file(resource, mime=True)
else: else:
raise RuntimeError('The installed magic version provides neither detect_from_filename nor from_file') raise RuntimeError(
'The installed magic version provides neither detect_from_filename nor from_file'
)
if mime: if mime:
return mime.mime_type if hasattr(mime, 'mime_type') else mime return mime.mime_type if hasattr(mime, 'mime_type') else mime
@ -332,6 +351,7 @@ def grouper(n, iterable, fillvalue=None):
grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
""" """
from itertools import zip_longest from itertools import zip_longest
args = [iter(iterable)] * n args = [iter(iterable)] * n
if fillvalue: if fillvalue:
@ -355,6 +375,7 @@ def is_functional_cron(obj) -> bool:
def run(action, *args, **kwargs): def run(action, *args, **kwargs):
from platypush.context import get_plugin from platypush.context import get_plugin
(module_name, method_name) = get_module_and_method_from_action(action) (module_name, method_name) = get_module_and_method_from_action(action)
plugin = get_plugin(module_name) plugin = get_plugin(module_name)
method = getattr(plugin, method_name) method = getattr(plugin, method_name)
@ -366,7 +387,9 @@ def run(action, *args, **kwargs):
return response.output return response.output
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]: def generate_rsa_key_pair(
key_file: Optional[str] = None, size: int = 2048
) -> Tuple[str, str]:
""" """
Generate an RSA key pair. Generate an RSA key pair.
@ -390,27 +413,30 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
public_exp = 65537 public_exp = 65537
private_key = rsa.generate_private_key( private_key = rsa.generate_private_key(
public_exponent=public_exp, public_exponent=public_exp, key_size=size, backend=default_backend()
key_size=size,
backend=default_backend()
) )
logger.info('Generating RSA {} key pair'.format(size)) logger.info('Generating RSA {} key pair'.format(size))
private_key_str = private_key.private_bytes( private_key_str = private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption(),
).decode() ).decode()
public_key_str = private_key.public_key().public_bytes( public_key_str = (
private_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1, format=serialization.PublicFormat.PKCS1,
).decode() )
.decode()
)
if key_file: if key_file:
logger.info('Saving private key to {}'.format(key_file)) logger.info('Saving private key to {}'.format(key_file))
with open(os.path.expanduser(key_file), 'w') as f1, \ with open(os.path.expanduser(key_file), 'w') as f1, open(
open(os.path.expanduser(key_file) + '.pub', 'w') as f2: os.path.expanduser(key_file) + '.pub', 'w'
) as f2:
f1.write(private_key_str) f1.write(private_key_str)
f2.write(public_key_str) f2.write(public_key_str)
os.chmod(key_file, 0o600) os.chmod(key_file, 0o600)
@ -426,8 +452,7 @@ def get_or_generate_jwt_rsa_key_pair():
pub_key_file = priv_key_file + '.pub' pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, \ with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2:
open(priv_key_file, 'r') as f2:
return f1.read(), f2.read() return f1.read(), f2.read()
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
@ -439,7 +464,7 @@ def get_enabled_plugins() -> dict:
from platypush.context import get_plugin from platypush.context import get_plugin
plugins = {} plugins = {}
for name, config in Config.get_plugins().items(): for name in Config.get_plugins():
try: try:
plugin = get_plugin(name) plugin = get_plugin(name)
if plugin: if plugin:
@ -453,11 +478,18 @@ def get_enabled_plugins() -> dict:
def get_redis() -> Redis: def get_redis() -> Redis:
from platypush.config import Config from platypush.config import Config
return Redis(**(Config.get('backend.redis') or {}).get('redis_args', {}))
return Redis(
**(
(Config.get('backend.redis') or {}).get('redis_args', {})
or Config.get('redis')
or {}
)
)
def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime: def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
if isinstance(t, int) or isinstance(t, float): if isinstance(t, (int, float)):
return datetime.datetime.fromtimestamp(t, tz=tz.tzutc()) return datetime.datetime.fromtimestamp(t, tz=tz.tzutc())
if isinstance(t, str): if isinstance(t, str):
return parser.parse(t) return parser.parse(t)