[procedure] Several improvements to the procedure engine.

- Add `set` statement, which can be used to set context variables within
  YAML procedures. Example:

```yaml
procedure.test:
  - set:
      foo: bar

  - action: logger.info
    args:
      msg: ${bar}
```

- More reliable flow control for nested break/continue/return.

- Propagate changes to context variables also to upstream procedures.
This commit is contained in:
Fabio Manganiello 2024-09-16 03:16:53 +02:00
parent 771e32e368
commit be8140ddb5
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -1,11 +1,12 @@
import enum import enum
import logging import logging
import re import re
from dataclasses import dataclass from copy import deepcopy
from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from queue import LifoQueue from queue import LifoQueue
from typing import Any, List, Optional from typing import Any, Dict, Iterable, List, Optional
from ..common import exec_wrapper from ..common import exec_wrapper
from ..config import Config from ..config import Config
@ -23,6 +24,7 @@ class StatementType(enum.Enum):
BREAK = 'break' BREAK = 'break'
CONTINUE = 'continue' CONTINUE = 'continue'
RETURN = 'return' RETURN = 'return'
SET = 'set'
@dataclass @dataclass
@ -32,7 +34,7 @@ class Statement:
""" """
type: StatementType type: StatementType
argument: Optional[str] = None argument: Optional[Any] = None
@classmethod @classmethod
def build(cls, statement: str): def build(cls, statement: str):
@ -60,8 +62,30 @@ class ReturnStatement(Statement):
type: StatementType = StatementType.RETURN type: StatementType = StatementType.RETURN
def run(self, *_, **context) -> Any:
return Response(
output=Request.expand_value_from_context(
self.argument, **_update_context(context)
)
)
@dataclass
class SetStatement(Statement):
"""
Models a set variable statement in a procedure.
"""
type: StatementType = StatementType.SET
vars: dict = field(default_factory=dict)
def run(self, *_, **context): def run(self, *_, **context):
return Request.expand_value_from_context(self.argument, **context) vars = deepcopy(self.vars) # pylint: disable=redefined-builtin
for k, v in vars.items():
vars[k] = Request.expand_value_from_context(v, **context)
context.update(vars)
return Response(output=vars)
class Procedure: class Procedure:
@ -117,10 +141,20 @@ class Procedure:
# Check if it's a return statement with a value # Check if it's a return statement with a value
if ( if (
len(request_config.keys()) == 1 len(request_config.keys()) == 1
and list(request_config.keys())[0] == 'return' and list(request_config.keys())[0] == StatementType.RETURN.value
): ):
cls._flush_if_statements(reqs, if_config) cls._flush_if_statements(reqs, if_config)
reqs.append(ReturnStatement(argument=request_config['return'])) reqs.append(
ReturnStatement(argument=request_config[StatementType.RETURN.value])
)
continue
# Check if it's a variable set statement
if (len(request_config.keys()) == 1) and (
list(request_config.keys())[0] == StatementType.SET.value
):
cls._flush_if_statements(reqs, if_config)
reqs.append(SetStatement(vars=request_config[StatementType.SET.value]))
continue continue
# Check if this request is an if-else # Check if this request is an if-else
@ -129,6 +163,7 @@ class Procedure:
m = re.match(r'\s*(if)\s+\${(.*)}\s*', key) m = re.match(r'\s*(if)\s+\${(.*)}\s*', key)
if m: if m:
cls._flush_if_statements(reqs, if_config)
if_count += 1 if_count += 1
if_name = f'{name}__if_{if_count}' if_name = f'{name}__if_{if_count}'
condition = m.group(2) condition = m.group(2)
@ -233,86 +268,96 @@ class Procedure:
pending_if = if_config.get() pending_if = if_config.get()
requests.append(IfProcedure.build(**pending_if)) requests.append(IfProcedure.build(**pending_if))
@staticmethod
def _find_nearest_loop(stack):
for proc in stack[::-1]:
if isinstance(proc, LoopProcedure):
return proc
raise AssertionError('break/continue statement found outside of a loop')
# pylint: disable=too-many-branches,too-many-statements # pylint: disable=too-many-branches,too-many-statements
def execute(self, n_tries=1, __stack__=None, **context): def execute(
self,
n_tries: int = 1,
__stack__: Optional[Iterable] = None,
new_context: Optional[Dict[str, Any]] = None,
**context,
):
""" """
Execute the requests in the procedure. Execute the requests in the procedure.
:param n_tries: Number of tries in case of failure before raising a RuntimeError. :param n_tries: Number of tries in case of failure before raising a RuntimeError.
""" """
if not __stack__: __stack__ = (self,) if not __stack__ else (self, *__stack__)
__stack__ = [self] new_context = new_context or {}
else:
__stack__.append(self)
if self.args: if self.args:
args = self.args.copy() args = self.args.copy()
for k, v in args.items(): for k, v in args.items():
v = Request.expand_value_from_context(v, **context) args[k] = context[k] = Request.expand_value_from_context(v, **context)
args[k] = v
context[k] = v
logger.info('Executing procedure %s with arguments %s', self.name, args) logger.info('Executing procedure %s with arguments %s', self.name, args)
else: else:
logger.info('Executing procedure %s', self.name) logger.info('Executing procedure %s', self.name)
response = Response() response = Response()
token = Config.get('token') token = Config.get('token')
context = _update_context(context)
locals().update(context)
# pylint: disable=too-many-nested-blocks
for request in self.requests: for request in self.requests:
if callable(request): if callable(request):
response = request(**context) response = request(**context)
continue continue
context['_async'] = self._async
context['n_tries'] = n_tries
context['__stack__'] = __stack__
context['new_context'] = new_context
if isinstance(request, Statement): if isinstance(request, Statement):
if isinstance(request, ReturnStatement): if isinstance(request, ReturnStatement):
response = Response(output=request.run(**context)) response = request.run(**context)
self._should_return = True self._should_return = True
for proc in __stack__: for proc in __stack__:
proc._should_return = True # pylint: disable=protected-access proc._should_return = True # pylint: disable=protected-access
break break
if isinstance(request, SetStatement):
rs: dict = request.run(**context).output # type: ignore
context.update(rs)
new_context.update(rs)
locals().update(rs)
continue
if request.type in [StatementType.BREAK, StatementType.CONTINUE]: if request.type in [StatementType.BREAK, StatementType.CONTINUE]:
loop = self._find_nearest_loop(__stack__) for proc in __stack__:
if isinstance(proc, LoopProcedure):
if request.type == StatementType.BREAK: if request.type == StatementType.BREAK:
loop._should_break = True # pylint: disable=protected-access setattr(proc, '_should_break', True) # noqa: B010
else: else:
loop._should_continue = True # pylint: disable=protected-access setattr(proc, '_should_continue', True) # noqa: B010
break
proc._should_return = True # pylint: disable=protected-access
break break
should_continue = getattr(self, '_should_continue', False) should_continue = getattr(self, '_should_continue', False)
should_break = getattr(self, '_should_break', False) should_break = getattr(self, '_should_break', False)
if isinstance(self, LoopProcedure) and (should_continue or should_break): if self._should_return or should_continue or should_break:
if should_continue:
setattr(self, '_should_continue', False) # noqa[B010]
else:
setattr(self, '_should_break', False) # noqa[B010]
break break
if token and not isinstance(request, Statement): if token and not isinstance(request, Statement):
request.token = token request.token = token
context['_async'] = self._async
context['n_tries'] = n_tries
exec_ = getattr(request, 'execute', None) exec_ = getattr(request, 'execute', None)
if callable(exec_): if callable(exec_):
response = exec_(__stack__=__stack__, **context) response = exec_(**context)
context.update(context.get('new_context', {}))
if not self._async and response: if not self._async and response:
if isinstance(response.output, dict): if isinstance(response.output, dict):
for k, v in response.output.items(): context.update(response.output)
context[k] = v
context['output'] = response.output context['output'] = response.output
context['errors'] = response.errors context['errors'] = response.errors
new_context.update(context)
locals().update(context)
if self._should_return: if self._should_return:
break break
@ -333,10 +378,8 @@ class LoopProcedure(Procedure):
Base class while and for/fork loops. Base class while and for/fork loops.
""" """
def __init__(self, name, requests, _async=False, args=None, backend=None): def __init__(self, *args, **kwargs):
super().__init__( super().__init__(*args, **kwargs)
name=name, _async=_async, requests=requests, args=args, backend=backend
)
self._should_break = False self._should_break = False
self._should_continue = False self._should_continue = False
@ -381,6 +424,9 @@ class ForProcedure(LoopProcedure):
# pylint: disable=eval-used # pylint: disable=eval-used
def execute(self, *_, **context): def execute(self, *_, **context):
ctx = _update_context(context)
locals().update(ctx)
try: try:
iterable = eval(self.iterable) iterable = eval(self.iterable)
assert hasattr( assert hasattr(
@ -388,11 +434,18 @@ class ForProcedure(LoopProcedure):
), f'Object of type {type(iterable)} is not iterable: {iterable}' ), f'Object of type {type(iterable)} is not iterable: {iterable}'
except Exception as e: except Exception as e:
logger.debug('Iterable %s expansion error: %s', self.iterable, e) logger.debug('Iterable %s expansion error: %s', self.iterable, e)
iterable = Request.expand_value_from_context(self.iterable, **context) iterable = Request.expand_value_from_context(self.iterable, **ctx)
response = Response() response = Response()
for item in iterable: for item in iterable:
ctx[self.iterator_name] = item
response = super().execute(**ctx)
ctx.update(ctx.get('new_context', {}))
if response.output and isinstance(response.output, dict):
ctx = _update_context(ctx, **response.output)
if self._should_return: if self._should_return:
logger.info('Returning from %s', self.name) logger.info('Returning from %s', self.name)
break break
@ -407,9 +460,6 @@ class ForProcedure(LoopProcedure):
logger.info('Breaking loop %s', self.name) logger.info('Breaking loop %s', self.name)
break break
context[self.iterator_name] = item
response = super().execute(**context)
return response return response
@ -446,41 +496,23 @@ class WhileProcedure(LoopProcedure):
) )
self.condition = condition self.condition = condition
@staticmethod
def _get_context(**context):
for k, v in context.items():
try:
context[k] = eval(v) # pylint: disable=eval-used
except Exception as e:
logger.debug('Evaluation error for %s=%s: %s', k, v, e)
if isinstance(v, str):
try:
context[k] = eval( # pylint: disable=eval-used
'"' + re.sub(r'(^|[^\\])"', '\1\\"', v) + '"'
)
except Exception as ee:
logger.warning(
'Could not parse value for context variable %s=%s: %s',
k,
v,
ee,
)
logger.warning('Context: %s', context)
logger.exception(e)
return context
def execute(self, *_, **context): def execute(self, *_, **context):
response = Response() response = Response()
context = self._get_context(**context) ctx = _update_context(context)
for k, v in context.items(): locals().update(ctx)
locals()[k] = v
while True: while True:
condition_true = eval(self.condition) # pylint: disable=eval-used condition_true = eval(self.condition) # pylint: disable=eval-used
if not condition_true: if not condition_true:
break break
response = super().execute(**ctx)
ctx.update(ctx.get('new_context', {}))
if response.output and isinstance(response.output, dict):
_update_context(ctx, **response.output)
locals().update(ctx)
if self._should_return: if self._should_return:
logger.info('Returning from %s', self.name) logger.info('Returning from %s', self.name)
break break
@ -495,13 +527,6 @@ class WhileProcedure(LoopProcedure):
logger.info('Breaking loop %s', self.name) logger.info('Breaking loop %s', self.name)
break break
response = super().execute(**context)
if response.output and isinstance(response.output, dict):
new_context = self._get_context(**response.output)
for k, v in new_context.items():
locals()[k] = v
return response return response
@ -595,20 +620,28 @@ class IfProcedure(Procedure):
) )
def execute(self, *_, **context): def execute(self, *_, **context):
for k, v in context.items(): ctx = _update_context(context)
locals()[k] = v locals().update(ctx)
condition_true = eval(self.condition) # pylint: disable=eval-used condition_true = eval(self.condition) # pylint: disable=eval-used
response = Response() response = Response()
if condition_true: if condition_true:
response = super().execute(**context) response = super().execute(**ctx)
elif self.else_branch: elif self.else_branch:
response = self.else_branch.execute(**context) response = self.else_branch.execute(**ctx)
return response return response
def _update_context(context: Optional[Dict[str, Any]] = None, **kwargs):
ctx = context or {}
ctx = {**ctx.get('context', {}), **ctx, **kwargs}
for k, v in ctx.items():
ctx[k] = Request.expand_value_from_context(v, **ctx)
return ctx
def procedure(name_or_func: Optional[str] = None, *upper_args, **upper_kwargs): def procedure(name_or_func: Optional[str] = None, *upper_args, **upper_kwargs):
name = name_or_func if isinstance(name_or_func, str) else None name = name_or_func if isinstance(name_or_func, str) else None