diff --git a/platypush/procedure/__init__.py b/platypush/procedure/__init__.py index 50ac5bb49c..36c57c4dfa 100644 --- a/platypush/procedure/__init__.py +++ b/platypush/procedure/__init__.py @@ -1,10 +1,11 @@ import enum import logging import re +from dataclasses import dataclass from functools import wraps from queue import LifoQueue -from typing import Optional +from typing import Any, Optional from ..common import exec_wrapper from ..config import Config @@ -14,7 +15,7 @@ from ..message.response import Response logger = logging.getLogger('platypush') -class Statement(enum.Enum): +class StatementType(enum.Enum): """ Enumerates the possible statements in a procedure. """ @@ -24,6 +25,45 @@ class Statement(enum.Enum): RETURN = 'return' +@dataclass +class Statement: + """ + Models a statement in a procedure. + """ + + type: StatementType + argument: Optional[str] = None + + @classmethod + def build(cls, statement: str): + """ + Builds a statement from a string. + """ + + m = re.match(r'\s*return\s*(.*)\s*', statement, re.IGNORECASE) + if m: + return ReturnStatement(argument=m.group(1)) + + return cls(StatementType(statement.lower())) + + def run(self, *_, **__) -> Optional[Any]: + """ + Executes the statement. + """ + + +@dataclass +class ReturnStatement(Statement): + """ + Models a return statement in a procedure. + """ + + type: StatementType = StatementType.RETURN + + def run(self, *_, **context): + return Request.expand_value_from_context(self.argument, **context) + + class Procedure: """Procedure class. A procedure is a pre-configured list of requests""" @@ -70,7 +110,15 @@ class Procedure: for request_config in requests: # Check if it's a break/continue/return statement if isinstance(request_config, str): - reqs.append(Statement(request_config)) + reqs.append(Statement.build(request_config)) + continue + + # Check if it's a return statement with a value + if ( + len(request_config.keys()) == 1 + and list(request_config.keys())[0] == 'return' + ): + reqs.append(ReturnStatement(argument=request_config['return'])) continue # Check if this request is an if-else @@ -218,15 +266,16 @@ class Procedure: continue if isinstance(request, Statement): - if request == Statement.RETURN: + if isinstance(request, ReturnStatement): + response = Response(output=request.run(**context)) self._should_return = True for proc in __stack__: proc._should_return = True # pylint: disable=protected-access break - if request in [Statement.BREAK, Statement.CONTINUE]: + if request.type in [StatementType.BREAK, StatementType.CONTINUE]: loop = self._find_nearest_loop(__stack__) - if request == Statement.BREAK: + if request == StatementType.BREAK: loop._should_break = True # pylint: disable=protected-access else: loop._should_continue = True # pylint: disable=protected-access