Added support for get_plugin(MyPlugin) besides get_plugin('my').

This commit is contained in:
Fabio Manganiello 2023-03-10 11:47:39 +01:00
parent 3fcc9957d1
commit 60da930e4b
Signed by: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -8,7 +8,7 @@ from typing import Optional, Any
from ..bus import Bus from ..bus import Bus
from ..config import Config from ..config import Config
from ..utils import get_enabled_plugins from ..utils import get_enabled_plugins, get_plugin_name_by_class
logger = logging.getLogger('platypush:context') logger = logging.getLogger('platypush:context')
@ -108,32 +108,50 @@ def get_backend(name):
return _ctx.backends.get(name) return _ctx.backends.get(name)
def get_plugin(plugin_name, reload=False): # pylint: disable=too-many-branches
def get_plugin(plugin, plugin_name=None, reload=False):
""" """
Registers a plugin instance by name if not registered already, or returns Registers a plugin instance by name if not registered already, or returns
the registered plugin instance. the registered plugin instance.
:param plugin: Plugin name or class type.
:param plugin_name: Plugin name, kept only for backwards compatibility.
:param reload: If ``True``, the plugin will be reloaded if it's already
been registered.
""" """
if plugin_name not in plugins_init_locks: from ..plugins import Plugin
plugins_init_locks[plugin_name] = RLock()
if plugin_name in _ctx.plugins and not reload: if isinstance(plugin, str):
return _ctx.plugins[plugin_name] name = plugin
elif plugin_name:
name = plugin_name
elif issubclass(plugin, Plugin):
name = get_plugin_name_by_class(plugin) # type: ignore
else:
raise TypeError(f'Invalid plugin type/name: {plugin}')
try: if name not in plugins_init_locks:
plugin = importlib.import_module('platypush.plugins.' + plugin_name) plugins_init_locks[name] = RLock()
except ImportError as e:
logger.warning('No such plugin: %s', plugin_name) if name in _ctx.plugins and not reload:
raise RuntimeError(e) from e return _ctx.plugins[name]
if isinstance(plugin, str):
try:
plugin = importlib.import_module(
'platypush.plugins.' + name
) # type: ignore
except ImportError as e:
logger.warning('No such plugin: %s', name)
raise RuntimeError(e) from e
# e.g. plugins.music.mpd main class: MusicMpdPlugin # e.g. plugins.music.mpd main class: MusicMpdPlugin
cls_name = '' cls_name = ''
for token in plugin_name.split('.'): for token in name.split('.'):
cls_name += token.title() cls_name += token.title()
cls_name += 'Plugin' cls_name += 'Plugin'
plugin_conf = ( plugin_conf = Config.get_plugins()[name] if name in Config.get_plugins() else {}
Config.get_plugins()[plugin_name] if plugin_name in Config.get_plugins() else {}
)
if 'disabled' in plugin_conf: if 'disabled' in plugin_conf:
if plugin_conf['disabled'] is True: if plugin_conf['disabled'] is True:
@ -148,15 +166,15 @@ def get_plugin(plugin_name, reload=False):
try: try:
plugin_class = getattr(plugin, cls_name) plugin_class = getattr(plugin, cls_name)
except AttributeError as e: except AttributeError as e:
logger.warning('No such class in %s: %s [error: %s]', plugin_name, cls_name, e) logger.warning('No such class in %s: %s [error: %s]', name, cls_name, e)
raise RuntimeError(e) from e raise RuntimeError(e) from e
with plugins_init_locks[plugin_name]: with plugins_init_locks[name]:
if _ctx.plugins.get(plugin_name) and not reload: if _ctx.plugins.get(name) and not reload:
return _ctx.plugins[plugin_name] return _ctx.plugins[name]
_ctx.plugins[plugin_name] = plugin_class(**plugin_conf) _ctx.plugins[name] = plugin_class(**plugin_conf)
return _ctx.plugins[plugin_name] return _ctx.plugins[name]
def get_bus() -> Bus: def get_bus() -> Bus: