diff --git a/platypush/context/__init__.py b/platypush/context/__init__.py index 982926e7..9b22d9c0 100644 --- a/platypush/context/__init__.py +++ b/platypush/context/__init__.py @@ -8,7 +8,7 @@ from typing import Optional, Any from ..bus import Bus from ..config import Config -from ..utils import get_enabled_plugins +from ..utils import get_enabled_plugins, get_plugin_name_by_class logger = logging.getLogger('platypush:context') @@ -108,32 +108,50 @@ def get_backend(name): return _ctx.backends.get(name) -def get_plugin(plugin_name, reload=False): +# pylint: disable=too-many-branches +def get_plugin(plugin, plugin_name=None, reload=False): """ Registers a plugin instance by name if not registered already, or returns the registered plugin instance. + + :param plugin: Plugin name or class type. + :param plugin_name: Plugin name, kept only for backwards compatibility. + :param reload: If ``True``, the plugin will be reloaded if it's already + been registered. """ - if plugin_name not in plugins_init_locks: - plugins_init_locks[plugin_name] = RLock() + from ..plugins import Plugin - if plugin_name in _ctx.plugins and not reload: - return _ctx.plugins[plugin_name] + if isinstance(plugin, str): + name = plugin + elif plugin_name: + name = plugin_name + elif issubclass(plugin, Plugin): + name = get_plugin_name_by_class(plugin) # type: ignore + else: + raise TypeError(f'Invalid plugin type/name: {plugin}') - try: - plugin = importlib.import_module('platypush.plugins.' + plugin_name) - except ImportError as e: - logger.warning('No such plugin: %s', plugin_name) - raise RuntimeError(e) from e + if name not in plugins_init_locks: + plugins_init_locks[name] = RLock() + + if name in _ctx.plugins and not reload: + return _ctx.plugins[name] + + if isinstance(plugin, str): + try: + plugin = importlib.import_module( + 'platypush.plugins.' + name + ) # type: ignore + except ImportError as e: + logger.warning('No such plugin: %s', name) + raise RuntimeError(e) from e # e.g. plugins.music.mpd main class: MusicMpdPlugin cls_name = '' - for token in plugin_name.split('.'): + for token in name.split('.'): cls_name += token.title() cls_name += 'Plugin' - plugin_conf = ( - Config.get_plugins()[plugin_name] if plugin_name in Config.get_plugins() else {} - ) + plugin_conf = Config.get_plugins()[name] if name in Config.get_plugins() else {} if 'disabled' in plugin_conf: if plugin_conf['disabled'] is True: @@ -148,15 +166,15 @@ def get_plugin(plugin_name, reload=False): try: plugin_class = getattr(plugin, cls_name) except AttributeError as e: - logger.warning('No such class in %s: %s [error: %s]', plugin_name, cls_name, e) + logger.warning('No such class in %s: %s [error: %s]', name, cls_name, e) raise RuntimeError(e) from e - with plugins_init_locks[plugin_name]: - if _ctx.plugins.get(plugin_name) and not reload: - return _ctx.plugins[plugin_name] - _ctx.plugins[plugin_name] = plugin_class(**plugin_conf) + with plugins_init_locks[name]: + if _ctx.plugins.get(name) and not reload: + return _ctx.plugins[name] + _ctx.plugins[name] = plugin_class(**plugin_conf) - return _ctx.plugins[plugin_name] + return _ctx.plugins[name] def get_bus() -> Bus: