From 95d86829aa08ff2598f2d3313de471d1c2f22dbb Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Tue, 17 Jul 2018 01:23:12 +0200 Subject: [PATCH] Plugin action decorators can now be inherited from parent classes --- platypush/plugins/__init__.py | 5 ++++- platypush/utils/__init__.py | 29 ++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/platypush/plugins/__init__.py b/platypush/plugins/__init__.py index e424f67d6..2836cdcc9 100644 --- a/platypush/plugins/__init__.py +++ b/platypush/plugins/__init__.py @@ -36,7 +36,10 @@ class Plugin(object): if 'logging' in kwargs: self.logger.setLevel(getattr(logging, kwargs['logging'].upper())) - self.registered_actions = set(get_decorators(self.__class__).get('action', [])) + self.registered_actions = set( + get_decorators(self.__class__, climb_class_hierarchy=True) + .get('action', []) + ) def run(self, method, *args, **kwargs): if method not in self.registered_actions: diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index 1015ac893..01644a921 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -73,12 +73,16 @@ def get_hash(s): return hashlib.sha256(s.encode('utf-8')).hexdigest() -def get_decorators(cls): - target = cls +def get_decorators(cls, climb_class_hierarchy=False): + """ + Get the decorators of a class as a {"decorator_name": [list of methods]} dictionary + :param climb_class_hierarchy: If set to True (default: False), it will search return the decorators in the parent classes as well + :type climb_class_hierarchy: bool + """ + decorators = {} def visit_FunctionDef(node): - # decorators[node.name] = [] for n in node.decorator_list: name = '' if isinstance(n, ast.Call): @@ -86,13 +90,24 @@ def get_decorators(cls): else: name = n.attr if isinstance(n, ast.Attribute) else n.id - decorators[name] = decorators.get(name, []) - # decorators[node.name].append(name) - decorators[name].append(node.name) + decorators[name] = decorators.get(name, set()) + decorators[name].add(node.name) + + if climb_class_hierarchy: + targets = inspect.getmro(cls) + else: + targets = [cls] node_iter = ast.NodeVisitor() node_iter.visit_FunctionDef = visit_FunctionDef - node_iter.visit(ast.parse(inspect.getsource(target))) + + for target in targets: + try: + node_iter.visit(ast.parse(inspect.getsource(target))) + except TypeError: + # Ignore built-in classes + pass + return decorators