Use Redis instead of an in-process map to store the entity/plugin registry
This is particularly useful when we want to access the registry from another process, like the web server or an external script.
This commit is contained in:
parent
7b1a63e287
commit
26ffc0b0e1
3 changed files with 127 additions and 79 deletions
|
@ -1,15 +1,16 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Collection, Optional
|
from typing import Collection, Optional
|
||||||
|
|
||||||
from ._base import Entity
|
from ._base import Entity, get_entities_registry
|
||||||
from ._engine import EntitiesEngine
|
from ._engine import EntitiesEngine
|
||||||
from ._registry import manages, register_entity_plugin, get_plugin_registry
|
from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
|
||||||
|
|
||||||
_engine: Optional[EntitiesEngine] = None
|
_engine: Optional[EntitiesEngine] = None
|
||||||
|
|
||||||
|
|
||||||
def init_entities_engine() -> EntitiesEngine:
|
def init_entities_engine() -> EntitiesEngine:
|
||||||
from ._base import init_entities_db
|
from ._base import init_entities_db
|
||||||
|
|
||||||
global _engine
|
global _engine
|
||||||
init_entities_db()
|
init_entities_db()
|
||||||
_engine = EntitiesEngine()
|
_engine = EntitiesEngine()
|
||||||
|
@ -24,13 +25,14 @@ def publish_entities(entities: Collection[Entity]):
|
||||||
|
|
||||||
_engine.post(*entities)
|
_engine.post(*entities)
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Entity',
|
'Entity',
|
||||||
'EntitiesEngine',
|
'EntitiesEngine',
|
||||||
'init_entities_engine',
|
'init_entities_engine',
|
||||||
'publish_entities',
|
'publish_entities',
|
||||||
'register_entity_plugin',
|
'register_entity_plugin',
|
||||||
'get_plugin_registry',
|
'get_plugin_entity_registry',
|
||||||
|
'get_entities_registry',
|
||||||
'manages',
|
'manages',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,38 @@
|
||||||
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Mapping, Dict, Collection, Type
|
from typing import Optional, Dict, Collection, Type
|
||||||
|
|
||||||
from platypush.plugins import Plugin
|
from platypush.plugins import Plugin
|
||||||
from platypush.utils import get_plugin_name_by_class
|
from platypush.utils import get_plugin_name_by_class, get_redis
|
||||||
|
|
||||||
from ._base import Entity
|
from ._base import Entity
|
||||||
|
|
||||||
_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {}
|
_entity_registry_varname = '_platypush/plugin_entity_registry'
|
||||||
|
|
||||||
|
|
||||||
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
|
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
|
||||||
plugins = _entity_plugin_registry.get(entity_type, {})
|
plugin_name = get_plugin_name_by_class(plugin.__class__) or ''
|
||||||
plugin_name = get_plugin_name_by_class(plugin.__class__)
|
entity_type_name = entity_type.__name__.lower()
|
||||||
assert plugin_name
|
redis = get_redis()
|
||||||
plugins[plugin_name] = plugin
|
registry = get_plugin_entity_registry()
|
||||||
_entity_plugin_registry[entity_type] = plugins
|
registry_by_plugin = set(registry['by_plugin'].get(plugin_name, []))
|
||||||
|
|
||||||
|
registry_by_entity_type = set(registry['by_entity_type'].get(entity_type_name, []))
|
||||||
|
|
||||||
|
registry_by_plugin.add(entity_type_name)
|
||||||
|
registry_by_entity_type.add(plugin_name)
|
||||||
|
registry['by_plugin'][plugin_name] = list(registry_by_plugin)
|
||||||
|
registry['by_entity_type'][entity_type_name] = list(registry_by_entity_type)
|
||||||
|
redis.mset({_entity_registry_varname: json.dumps(registry)})
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_registry():
|
def get_plugin_entity_registry() -> Dict[str, Dict[str, Collection[str]]]:
|
||||||
return _entity_plugin_registry.copy()
|
redis = get_redis()
|
||||||
|
registry = redis.mget([_entity_registry_varname])[0]
|
||||||
|
try:
|
||||||
|
return json.loads((registry or b'').decode())
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return {'by_plugin': {}, 'by_entity_type': {}}
|
||||||
|
|
||||||
|
|
||||||
class EntityManagerMixin:
|
class EntityManagerMixin:
|
||||||
|
@ -37,6 +51,7 @@ class EntityManagerMixin:
|
||||||
|
|
||||||
def publish_entities(self, entities: Optional[Collection[Entity]]):
|
def publish_entities(self, entities: Optional[Collection[Entity]]):
|
||||||
from . import publish_entities
|
from . import publish_entities
|
||||||
|
|
||||||
entities = self.transform_entities(entities)
|
entities = self.transform_entities(entities)
|
||||||
publish_entities(entities)
|
publish_entities(entities)
|
||||||
|
|
||||||
|
@ -59,4 +74,3 @@ def manages(*entities: Type[Entity]):
|
||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import ast
|
import ast
|
||||||
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
|
@ -43,8 +44,7 @@ def get_message_class_by_type(msgtype):
|
||||||
try:
|
try:
|
||||||
msgclass = getattr(module, cls_name)
|
msgclass = getattr(module, cls_name)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logger.warning('No such class in {}: {}'.format(
|
logger.warning('No such class in {}: {}'.format(module.__name__, cls_name))
|
||||||
module.__name__, cls_name))
|
|
||||||
raise RuntimeError(e)
|
raise RuntimeError(e)
|
||||||
|
|
||||||
return msgclass
|
return msgclass
|
||||||
|
@ -75,9 +75,13 @@ def get_plugin_class_by_name(plugin_name):
|
||||||
if not module:
|
if not module:
|
||||||
return
|
return
|
||||||
|
|
||||||
class_name = getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin')
|
class_name = getattr(
|
||||||
|
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return getattr(module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin')
|
return getattr(
|
||||||
|
module, ''.join([_.capitalize() for _ in plugin_name.split('.')]) + 'Plugin'
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('Cannot import class {}: {}'.format(class_name, str(e)))
|
logger.error('Cannot import class {}: {}'.format(class_name, str(e)))
|
||||||
return None
|
return None
|
||||||
|
@ -93,7 +97,8 @@ def get_plugin_name_by_class(plugin) -> Optional[str]:
|
||||||
|
|
||||||
class_name = plugin.__name__
|
class_name = plugin.__name__
|
||||||
class_tokens = [
|
class_tokens = [
|
||||||
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
token.lower()
|
||||||
|
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
||||||
if token.strip() and token != 'Plugin'
|
if token.strip() and token != 'Plugin'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -110,7 +115,8 @@ def get_backend_name_by_class(backend) -> Optional[str]:
|
||||||
|
|
||||||
class_name = backend.__name__
|
class_name = backend.__name__
|
||||||
class_tokens = [
|
class_tokens = [
|
||||||
token.lower() for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
token.lower()
|
||||||
|
for token in re.sub(r'([A-Z])', r' \1', class_name).split(' ')
|
||||||
if token.strip() and token != 'Backend'
|
if token.strip() and token != 'Backend'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -177,11 +183,8 @@ def get_decorators(cls, climb_class_hierarchy=False):
|
||||||
node_iter.visit_FunctionDef = visit_FunctionDef
|
node_iter.visit_FunctionDef = visit_FunctionDef
|
||||||
|
|
||||||
for target in targets:
|
for target in targets:
|
||||||
try:
|
with contextlib.suppress(TypeError):
|
||||||
node_iter.visit(ast.parse(inspect.getsource(target)))
|
node_iter.visit(ast.parse(inspect.getsource(target)))
|
||||||
except TypeError:
|
|
||||||
# Ignore built-in classes
|
|
||||||
pass
|
|
||||||
|
|
||||||
return decorators
|
return decorators
|
||||||
|
|
||||||
|
@ -195,45 +198,57 @@ def get_redis_queue_name_by_message(msg):
|
||||||
return 'platypush/responses/{}'.format(msg.id) if msg.id else None
|
return 'platypush/responses/{}'.format(msg.id) if msg.id else None
|
||||||
|
|
||||||
|
|
||||||
def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None,
|
def _get_ssl_context(
|
||||||
ssl_cafile=None, ssl_capath=None):
|
context_type=None, ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
|
||||||
|
):
|
||||||
if not context_type:
|
if not context_type:
|
||||||
ssl_context = ssl.create_default_context(cafile=ssl_cafile,
|
ssl_context = ssl.create_default_context(cafile=ssl_cafile, capath=ssl_capath)
|
||||||
capath=ssl_capath)
|
|
||||||
else:
|
else:
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
|
||||||
if ssl_cafile or ssl_capath:
|
if ssl_cafile or ssl_capath:
|
||||||
ssl_context.load_verify_locations(
|
ssl_context.load_verify_locations(cafile=ssl_cafile, capath=ssl_capath)
|
||||||
cafile=ssl_cafile, capath=ssl_capath)
|
|
||||||
|
|
||||||
ssl_context.load_cert_chain(
|
ssl_context.load_cert_chain(
|
||||||
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
||||||
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None
|
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ssl_context
|
return ssl_context
|
||||||
|
|
||||||
|
|
||||||
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None):
|
||||||
ssl_capath=None):
|
return _get_ssl_context(
|
||||||
return _get_ssl_context(context_type=None,
|
context_type=None,
|
||||||
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
ssl_cert=ssl_cert,
|
||||||
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile,
|
||||||
|
ssl_capath=ssl_capath,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
def get_ssl_server_context(
|
||||||
ssl_capath=None):
|
ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
|
||||||
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER,
|
):
|
||||||
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
return _get_ssl_context(
|
||||||
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
context_type=ssl.PROTOCOL_TLS_SERVER,
|
||||||
|
ssl_cert=ssl_cert,
|
||||||
|
ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile,
|
||||||
|
ssl_capath=ssl_capath,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
def get_ssl_client_context(
|
||||||
ssl_capath=None):
|
ssl_cert=None, ssl_key=None, ssl_cafile=None, ssl_capath=None
|
||||||
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT,
|
):
|
||||||
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
return _get_ssl_context(
|
||||||
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
context_type=ssl.PROTOCOL_TLS_CLIENT,
|
||||||
|
ssl_cert=ssl_cert,
|
||||||
|
ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile,
|
||||||
|
ssl_capath=ssl_capath,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_thread_name(name):
|
def set_thread_name(name):
|
||||||
|
@ -241,6 +256,7 @@ def set_thread_name(name):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import prctl
|
import prctl
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
prctl.set_name(name)
|
prctl.set_name(name)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -251,9 +267,9 @@ def find_bins_in_path(bin_name):
|
||||||
return [
|
return [
|
||||||
os.path.join(p, bin_name)
|
os.path.join(p, bin_name)
|
||||||
for p in os.environ.get('PATH', '').split(':')
|
for p in os.environ.get('PATH', '').split(':')
|
||||||
if os.path.isfile(os.path.join(p, bin_name)) and (
|
if os.path.isfile(os.path.join(p, bin_name))
|
||||||
os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK)
|
and (os.name == 'nt' or os.access(os.path.join(p, bin_name), os.X_OK))
|
||||||
)]
|
]
|
||||||
|
|
||||||
|
|
||||||
def find_files_by_ext(directory, *exts):
|
def find_files_by_ext(directory, *exts):
|
||||||
|
@ -271,7 +287,7 @@ def find_files_by_ext(directory, *exts):
|
||||||
max_len = len(max(exts, key=len))
|
max_len = len(max(exts, key=len))
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for root, dirs, files in os.walk(directory):
|
for _, __, files in os.walk(directory):
|
||||||
for i in range(min_len, max_len + 1):
|
for i in range(min_len, max_len + 1):
|
||||||
result += [f for f in files if f[-i:] in exts]
|
result += [f for f in files if f[-i:] in exts]
|
||||||
|
|
||||||
|
@ -302,6 +318,7 @@ def get_ip_or_hostname():
|
||||||
|
|
||||||
def get_mime_type(resource):
|
def get_mime_type(resource):
|
||||||
import magic
|
import magic
|
||||||
|
|
||||||
if resource.startswith('file://'):
|
if resource.startswith('file://'):
|
||||||
resource = resource[len('file://') :]
|
resource = resource[len('file://') :]
|
||||||
|
|
||||||
|
@ -315,7 +332,9 @@ def get_mime_type(resource):
|
||||||
elif hasattr(magic, 'from_file'):
|
elif hasattr(magic, 'from_file'):
|
||||||
mime = magic.from_file(resource, mime=True)
|
mime = magic.from_file(resource, mime=True)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('The installed magic version provides neither detect_from_filename nor from_file')
|
raise RuntimeError(
|
||||||
|
'The installed magic version provides neither detect_from_filename nor from_file'
|
||||||
|
)
|
||||||
|
|
||||||
if mime:
|
if mime:
|
||||||
return mime.mime_type if hasattr(mime, 'mime_type') else mime
|
return mime.mime_type if hasattr(mime, 'mime_type') else mime
|
||||||
|
@ -332,6 +351,7 @@ def grouper(n, iterable, fillvalue=None):
|
||||||
grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
|
grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
|
||||||
"""
|
"""
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
|
||||||
args = [iter(iterable)] * n
|
args = [iter(iterable)] * n
|
||||||
|
|
||||||
if fillvalue:
|
if fillvalue:
|
||||||
|
@ -355,6 +375,7 @@ def is_functional_cron(obj) -> bool:
|
||||||
|
|
||||||
def run(action, *args, **kwargs):
|
def run(action, *args, **kwargs):
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
|
|
||||||
(module_name, method_name) = get_module_and_method_from_action(action)
|
(module_name, method_name) = get_module_and_method_from_action(action)
|
||||||
plugin = get_plugin(module_name)
|
plugin = get_plugin(module_name)
|
||||||
method = getattr(plugin, method_name)
|
method = getattr(plugin, method_name)
|
||||||
|
@ -366,7 +387,9 @@ def run(action, *args, **kwargs):
|
||||||
return response.output
|
return response.output
|
||||||
|
|
||||||
|
|
||||||
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
|
def generate_rsa_key_pair(
|
||||||
|
key_file: Optional[str] = None, size: int = 2048
|
||||||
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Generate an RSA key pair.
|
Generate an RSA key pair.
|
||||||
|
|
||||||
|
@ -390,27 +413,30 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
|
||||||
|
|
||||||
public_exp = 65537
|
public_exp = 65537
|
||||||
private_key = rsa.generate_private_key(
|
private_key = rsa.generate_private_key(
|
||||||
public_exponent=public_exp,
|
public_exponent=public_exp, key_size=size, backend=default_backend()
|
||||||
key_size=size,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info('Generating RSA {} key pair'.format(size))
|
logger.info('Generating RSA {} key pair'.format(size))
|
||||||
private_key_str = private_key.private_bytes(
|
private_key_str = private_key.private_bytes(
|
||||||
encoding=serialization.Encoding.PEM,
|
encoding=serialization.Encoding.PEM,
|
||||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
encryption_algorithm=serialization.NoEncryption()
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
public_key_str = private_key.public_key().public_bytes(
|
public_key_str = (
|
||||||
|
private_key.public_key()
|
||||||
|
.public_bytes(
|
||||||
encoding=serialization.Encoding.PEM,
|
encoding=serialization.Encoding.PEM,
|
||||||
format=serialization.PublicFormat.PKCS1,
|
format=serialization.PublicFormat.PKCS1,
|
||||||
).decode()
|
)
|
||||||
|
.decode()
|
||||||
|
)
|
||||||
|
|
||||||
if key_file:
|
if key_file:
|
||||||
logger.info('Saving private key to {}'.format(key_file))
|
logger.info('Saving private key to {}'.format(key_file))
|
||||||
with open(os.path.expanduser(key_file), 'w') as f1, \
|
with open(os.path.expanduser(key_file), 'w') as f1, open(
|
||||||
open(os.path.expanduser(key_file) + '.pub', 'w') as f2:
|
os.path.expanduser(key_file) + '.pub', 'w'
|
||||||
|
) as f2:
|
||||||
f1.write(private_key_str)
|
f1.write(private_key_str)
|
||||||
f2.write(public_key_str)
|
f2.write(public_key_str)
|
||||||
os.chmod(key_file, 0o600)
|
os.chmod(key_file, 0o600)
|
||||||
|
@ -426,8 +452,7 @@ def get_or_generate_jwt_rsa_key_pair():
|
||||||
pub_key_file = priv_key_file + '.pub'
|
pub_key_file = priv_key_file + '.pub'
|
||||||
|
|
||||||
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
||||||
with open(pub_key_file, 'r') as f1, \
|
with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2:
|
||||||
open(priv_key_file, 'r') as f2:
|
|
||||||
return f1.read(), f2.read()
|
return f1.read(), f2.read()
|
||||||
|
|
||||||
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||||
|
@ -439,7 +464,7 @@ def get_enabled_plugins() -> dict:
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
|
|
||||||
plugins = {}
|
plugins = {}
|
||||||
for name, config in Config.get_plugins().items():
|
for name in Config.get_plugins():
|
||||||
try:
|
try:
|
||||||
plugin = get_plugin(name)
|
plugin = get_plugin(name)
|
||||||
if plugin:
|
if plugin:
|
||||||
|
@ -453,11 +478,18 @@ def get_enabled_plugins() -> dict:
|
||||||
|
|
||||||
def get_redis() -> Redis:
|
def get_redis() -> Redis:
|
||||||
from platypush.config import Config
|
from platypush.config import Config
|
||||||
return Redis(**(Config.get('backend.redis') or {}).get('redis_args', {}))
|
|
||||||
|
return Redis(
|
||||||
|
**(
|
||||||
|
(Config.get('backend.redis') or {}).get('redis_args', {})
|
||||||
|
or Config.get('redis')
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
|
def to_datetime(t: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
|
||||||
if isinstance(t, int) or isinstance(t, float):
|
if isinstance(t, (int, float)):
|
||||||
return datetime.datetime.fromtimestamp(t, tz=tz.tzutc())
|
return datetime.datetime.fromtimestamp(t, tz=tz.tzutc())
|
||||||
if isinstance(t, str):
|
if isinstance(t, str):
|
||||||
return parser.parse(t)
|
return parser.parse(t)
|
||||||
|
|
Loading…
Reference in a new issue