From be8140ddb5f40b644eb9d8f61d7ae5d7b20f0ccf Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 16 Sep 2024 03:16:53 +0200 Subject: [PATCH] [procedure] Several improvements to the procedure engine. - Add `set` statement, which can be used to set context variables within YAML procedures. Example: ```yaml procedure.test: - set: foo: bar - action: logger.info args: msg: ${bar} ``` - More reliable flow control for nested break/continue/return. - Propagate changes to context variables also to upstream procedures. --- platypush/procedure/__init__.py | 205 ++++++++++++++++++-------------- 1 file changed, 119 insertions(+), 86 deletions(-) diff --git a/platypush/procedure/__init__.py b/platypush/procedure/__init__.py index bb5c01675e..0e28b5f1a9 100644 --- a/platypush/procedure/__init__.py +++ b/platypush/procedure/__init__.py @@ -1,11 +1,12 @@ import enum import logging import re -from dataclasses import dataclass +from copy import deepcopy +from dataclasses import dataclass, field from functools import wraps from queue import LifoQueue -from typing import Any, List, Optional +from typing import Any, Dict, Iterable, List, Optional from ..common import exec_wrapper from ..config import Config @@ -23,6 +24,7 @@ class StatementType(enum.Enum): BREAK = 'break' CONTINUE = 'continue' RETURN = 'return' + SET = 'set' @dataclass @@ -32,7 +34,7 @@ class Statement: """ type: StatementType - argument: Optional[str] = None + argument: Optional[Any] = None @classmethod def build(cls, statement: str): @@ -60,8 +62,30 @@ class ReturnStatement(Statement): type: StatementType = StatementType.RETURN + def run(self, *_, **context) -> Any: + return Response( + output=Request.expand_value_from_context( + self.argument, **_update_context(context) + ) + ) + + +@dataclass +class SetStatement(Statement): + """ + Models a set variable statement in a procedure. + """ + + type: StatementType = StatementType.SET + vars: dict = field(default_factory=dict) + def run(self, *_, **context): - return Request.expand_value_from_context(self.argument, **context) + vars = deepcopy(self.vars) # pylint: disable=redefined-builtin + for k, v in vars.items(): + vars[k] = Request.expand_value_from_context(v, **context) + + context.update(vars) + return Response(output=vars) class Procedure: @@ -117,10 +141,20 @@ class Procedure: # Check if it's a return statement with a value if ( len(request_config.keys()) == 1 - and list(request_config.keys())[0] == 'return' + and list(request_config.keys())[0] == StatementType.RETURN.value ): cls._flush_if_statements(reqs, if_config) - reqs.append(ReturnStatement(argument=request_config['return'])) + reqs.append( + ReturnStatement(argument=request_config[StatementType.RETURN.value]) + ) + continue + + # Check if it's a variable set statement + if (len(request_config.keys()) == 1) and ( + list(request_config.keys())[0] == StatementType.SET.value + ): + cls._flush_if_statements(reqs, if_config) + reqs.append(SetStatement(vars=request_config[StatementType.SET.value])) continue # Check if this request is an if-else @@ -129,6 +163,7 @@ class Procedure: m = re.match(r'\s*(if)\s+\${(.*)}\s*', key) if m: + cls._flush_if_statements(reqs, if_config) if_count += 1 if_name = f'{name}__if_{if_count}' condition = m.group(2) @@ -233,86 +268,96 @@ class Procedure: pending_if = if_config.get() requests.append(IfProcedure.build(**pending_if)) - @staticmethod - def _find_nearest_loop(stack): - for proc in stack[::-1]: - if isinstance(proc, LoopProcedure): - return proc - - raise AssertionError('break/continue statement found outside of a loop') - # pylint: disable=too-many-branches,too-many-statements - def execute(self, n_tries=1, __stack__=None, **context): + def execute( + self, + n_tries: int = 1, + __stack__: Optional[Iterable] = None, + new_context: Optional[Dict[str, Any]] = None, + **context, + ): """ Execute the requests in the procedure. :param n_tries: Number of tries in case of failure before raising a RuntimeError. """ - if not __stack__: - __stack__ = [self] - else: - __stack__.append(self) + __stack__ = (self,) if not __stack__ else (self, *__stack__) + new_context = new_context or {} if self.args: args = self.args.copy() for k, v in args.items(): - v = Request.expand_value_from_context(v, **context) - args[k] = v - context[k] = v + args[k] = context[k] = Request.expand_value_from_context(v, **context) logger.info('Executing procedure %s with arguments %s', self.name, args) else: logger.info('Executing procedure %s', self.name) response = Response() token = Config.get('token') + context = _update_context(context) + locals().update(context) + # pylint: disable=too-many-nested-blocks for request in self.requests: if callable(request): response = request(**context) continue + context['_async'] = self._async + context['n_tries'] = n_tries + context['__stack__'] = __stack__ + context['new_context'] = new_context + if isinstance(request, Statement): if isinstance(request, ReturnStatement): - response = Response(output=request.run(**context)) + response = request.run(**context) self._should_return = True for proc in __stack__: proc._should_return = True # pylint: disable=protected-access + break + if isinstance(request, SetStatement): + rs: dict = request.run(**context).output # type: ignore + context.update(rs) + new_context.update(rs) + locals().update(rs) + continue + if request.type in [StatementType.BREAK, StatementType.CONTINUE]: - loop = self._find_nearest_loop(__stack__) - if request.type == StatementType.BREAK: - loop._should_break = True # pylint: disable=protected-access - else: - loop._should_continue = True # pylint: disable=protected-access + for proc in __stack__: + if isinstance(proc, LoopProcedure): + if request.type == StatementType.BREAK: + setattr(proc, '_should_break', True) # noqa: B010 + else: + setattr(proc, '_should_continue', True) # noqa: B010 + break + + proc._should_return = True # pylint: disable=protected-access + break should_continue = getattr(self, '_should_continue', False) should_break = getattr(self, '_should_break', False) - if isinstance(self, LoopProcedure) and (should_continue or should_break): - if should_continue: - setattr(self, '_should_continue', False) # noqa[B010] - else: - setattr(self, '_should_break', False) # noqa[B010] - + if self._should_return or should_continue or should_break: break if token and not isinstance(request, Statement): request.token = token - context['_async'] = self._async - context['n_tries'] = n_tries exec_ = getattr(request, 'execute', None) if callable(exec_): - response = exec_(__stack__=__stack__, **context) + response = exec_(**context) + context.update(context.get('new_context', {})) if not self._async and response: if isinstance(response.output, dict): - for k, v in response.output.items(): - context[k] = v + context.update(response.output) context['output'] = response.output context['errors'] = response.errors + new_context.update(context) + locals().update(context) if self._should_return: break @@ -333,10 +378,8 @@ class LoopProcedure(Procedure): Base class while and for/fork loops. """ - def __init__(self, name, requests, _async=False, args=None, backend=None): - super().__init__( - name=name, _async=_async, requests=requests, args=args, backend=backend - ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._should_break = False self._should_continue = False @@ -381,6 +424,9 @@ class ForProcedure(LoopProcedure): # pylint: disable=eval-used def execute(self, *_, **context): + ctx = _update_context(context) + locals().update(ctx) + try: iterable = eval(self.iterable) assert hasattr( @@ -388,11 +434,18 @@ class ForProcedure(LoopProcedure): ), f'Object of type {type(iterable)} is not iterable: {iterable}' except Exception as e: logger.debug('Iterable %s expansion error: %s', self.iterable, e) - iterable = Request.expand_value_from_context(self.iterable, **context) + iterable = Request.expand_value_from_context(self.iterable, **ctx) response = Response() for item in iterable: + ctx[self.iterator_name] = item + response = super().execute(**ctx) + ctx.update(ctx.get('new_context', {})) + + if response.output and isinstance(response.output, dict): + ctx = _update_context(ctx, **response.output) + if self._should_return: logger.info('Returning from %s', self.name) break @@ -407,9 +460,6 @@ class ForProcedure(LoopProcedure): logger.info('Breaking loop %s', self.name) break - context[self.iterator_name] = item - response = super().execute(**context) - return response @@ -446,41 +496,23 @@ class WhileProcedure(LoopProcedure): ) self.condition = condition - @staticmethod - def _get_context(**context): - for k, v in context.items(): - try: - context[k] = eval(v) # pylint: disable=eval-used - except Exception as e: - logger.debug('Evaluation error for %s=%s: %s', k, v, e) - if isinstance(v, str): - try: - context[k] = eval( # pylint: disable=eval-used - '"' + re.sub(r'(^|[^\\])"', '\1\\"', v) + '"' - ) - except Exception as ee: - logger.warning( - 'Could not parse value for context variable %s=%s: %s', - k, - v, - ee, - ) - logger.warning('Context: %s', context) - logger.exception(e) - - return context - def execute(self, *_, **context): response = Response() - context = self._get_context(**context) - for k, v in context.items(): - locals()[k] = v + ctx = _update_context(context) + locals().update(ctx) while True: condition_true = eval(self.condition) # pylint: disable=eval-used if not condition_true: break + response = super().execute(**ctx) + ctx.update(ctx.get('new_context', {})) + if response.output and isinstance(response.output, dict): + _update_context(ctx, **response.output) + + locals().update(ctx) + if self._should_return: logger.info('Returning from %s', self.name) break @@ -495,13 +527,6 @@ class WhileProcedure(LoopProcedure): logger.info('Breaking loop %s', self.name) break - response = super().execute(**context) - - if response.output and isinstance(response.output, dict): - new_context = self._get_context(**response.output) - for k, v in new_context.items(): - locals()[k] = v - return response @@ -595,20 +620,28 @@ class IfProcedure(Procedure): ) def execute(self, *_, **context): - for k, v in context.items(): - locals()[k] = v - + ctx = _update_context(context) + locals().update(ctx) condition_true = eval(self.condition) # pylint: disable=eval-used response = Response() if condition_true: - response = super().execute(**context) + response = super().execute(**ctx) elif self.else_branch: - response = self.else_branch.execute(**context) + response = self.else_branch.execute(**ctx) return response +def _update_context(context: Optional[Dict[str, Any]] = None, **kwargs): + ctx = context or {} + ctx = {**ctx.get('context', {}), **ctx, **kwargs} + for k, v in ctx.items(): + ctx[k] = Request.expand_value_from_context(v, **ctx) + + return ctx + + def procedure(name_or_func: Optional[str] = None, *upper_args, **upper_kwargs): name = name_or_func if isinstance(name_or_func, str) else None