diff --git a/platypush/message/request/__init__.py b/platypush/message/request/__init__.py index 7a55e65429..8dda18eb7d 100644 --- a/platypush/message/request/__init__.py +++ b/platypush/message/request/__init__.py @@ -64,12 +64,13 @@ class Request(Message): msg = super().parse(msg) args = { 'target': msg.get('target', Config.get('device_id')), - 'action': msg['action'], + 'action': msg.get('action', msg.get('name')), 'args': msg.get('args', {}), 'id': msg['id'] if 'id' in msg else cls._generate_id(), 'timestamp': msg['_timestamp'] if '_timestamp' in msg else time.time(), } + assert args.get('action'), 'No action specified in the request' if 'origin' in msg: args['origin'] = msg['origin'] if 'token' in msg: diff --git a/platypush/plugins/procedures/__init__.py b/platypush/plugins/procedures/__init__.py index 3e66f7631e..cffd2f65f1 100644 --- a/platypush/plugins/procedures/__init__.py +++ b/platypush/plugins/procedures/__init__.py @@ -286,9 +286,10 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): @classmethod def _serialize_action(cls, data: Union[Iterable, Dict]) -> Union[Dict, List]: if isinstance(data, dict): - if data.get('action'): + name = data.get('action', data.get('name')) + if name: return { - 'action': data['action'], + 'action': name, **({'args': data['args']} if data.get('args') else {}), } diff --git a/platypush/procedure/__init__.py b/platypush/procedure/__init__.py index 56cdc8f130..36c57c4dfa 100644 --- a/platypush/procedure/__init__.py +++ b/platypush/procedure/__init__.py @@ -1,10 +1,11 @@ import enum import logging import re +from dataclasses import dataclass from functools import wraps from queue import LifoQueue -from typing import Optional +from typing import Any, Optional from ..common import exec_wrapper from ..config import Config @@ -14,7 +15,7 @@ from ..message.response import Response logger = logging.getLogger('platypush') -class Statement(enum.Enum): +class StatementType(enum.Enum): """ Enumerates the possible statements in a procedure. """ @@ -24,6 +25,45 @@ class Statement(enum.Enum): RETURN = 'return' +@dataclass +class Statement: + """ + Models a statement in a procedure. + """ + + type: StatementType + argument: Optional[str] = None + + @classmethod + def build(cls, statement: str): + """ + Builds a statement from a string. + """ + + m = re.match(r'\s*return\s*(.*)\s*', statement, re.IGNORECASE) + if m: + return ReturnStatement(argument=m.group(1)) + + return cls(StatementType(statement.lower())) + + def run(self, *_, **__) -> Optional[Any]: + """ + Executes the statement. + """ + + +@dataclass +class ReturnStatement(Statement): + """ + Models a return statement in a procedure. + """ + + type: StatementType = StatementType.RETURN + + def run(self, *_, **context): + return Request.expand_value_from_context(self.argument, **context) + + class Procedure: """Procedure class. A procedure is a pre-configured list of requests""" @@ -55,7 +95,6 @@ class Procedure: requests, args=None, backend=None, - id=None, # pylint: disable=redefined-builtin procedure_class=None, **kwargs, ): @@ -66,11 +105,20 @@ class Procedure: if_config = LifoQueue() procedure_class = procedure_class or cls key = None + kwargs.pop('id', None) for request_config in requests: # Check if it's a break/continue/return statement if isinstance(request_config, str): - reqs.append(Statement(request_config)) + reqs.append(Statement.build(request_config)) + continue + + # Check if it's a return statement with a value + if ( + len(request_config.keys()) == 1 + and list(request_config.keys())[0] == 'return' + ): + reqs.append(ReturnStatement(argument=request_config['return'])) continue # Check if this request is an if-else @@ -91,7 +139,6 @@ class Procedure: 'condition': condition, 'else_branch': [], 'backend': backend, - 'id': id, } ) @@ -132,7 +179,6 @@ class Procedure: _async=_async, requests=request_config[key], backend=backend, - id=id, iterator_name=iterator_name, iterable=iterable, ) @@ -156,14 +202,12 @@ class Procedure: requests=request_config[key], condition=condition, backend=backend, - id=id, ) reqs.append(loop) continue request_config['origin'] = Config.get('device_id') - request_config['id'] = id if 'target' not in request_config: request_config['target'] = request_config['origin'] @@ -222,15 +266,16 @@ class Procedure: continue if isinstance(request, Statement): - if request == Statement.RETURN: + if isinstance(request, ReturnStatement): + response = Response(output=request.run(**context)) self._should_return = True for proc in __stack__: proc._should_return = True # pylint: disable=protected-access break - if request in [Statement.BREAK, Statement.CONTINUE]: + if request.type in [StatementType.BREAK, StatementType.CONTINUE]: loop = self._find_nearest_loop(__stack__) - if request == Statement.BREAK: + if request == StatementType.BREAK: loop._should_break = True # pylint: disable=protected-access else: loop._should_continue = True # pylint: disable=protected-access