[procedure] Several improvements to the procedure engine.

- Add `set` statement, which can be used to set context variables within
  YAML procedures. Example:

```yaml
procedure.test:
  - set:
      foo: bar

  - action: logger.info
    args:
      msg: ${bar}
```

- More reliable flow control for nested break/continue/return.

- Propagate changes to context variables also to upstream procedures.
This commit is contained in:
Fabio Manganiello 2024-09-16 03:16:53 +02:00
parent 771e32e368
commit be8140ddb5
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -1,11 +1,12 @@
import enum
import logging
import re
from dataclasses import dataclass
from copy import deepcopy
from dataclasses import dataclass, field
from functools import wraps
from queue import LifoQueue
from typing import Any, List, Optional
from typing import Any, Dict, Iterable, List, Optional
from ..common import exec_wrapper
from ..config import Config
@ -23,6 +24,7 @@ class StatementType(enum.Enum):
BREAK = 'break'
CONTINUE = 'continue'
RETURN = 'return'
SET = 'set'
@dataclass
@ -32,7 +34,7 @@ class Statement:
"""
type: StatementType
argument: Optional[str] = None
argument: Optional[Any] = None
@classmethod
def build(cls, statement: str):
@ -60,8 +62,30 @@ class ReturnStatement(Statement):
type: StatementType = StatementType.RETURN
def run(self, *_, **context) -> Any:
return Response(
output=Request.expand_value_from_context(
self.argument, **_update_context(context)
)
)
@dataclass
class SetStatement(Statement):
"""
Models a set variable statement in a procedure.
"""
type: StatementType = StatementType.SET
vars: dict = field(default_factory=dict)
def run(self, *_, **context):
return Request.expand_value_from_context(self.argument, **context)
vars = deepcopy(self.vars) # pylint: disable=redefined-builtin
for k, v in vars.items():
vars[k] = Request.expand_value_from_context(v, **context)
context.update(vars)
return Response(output=vars)
class Procedure:
@ -117,10 +141,20 @@ class Procedure:
# Check if it's a return statement with a value
if (
len(request_config.keys()) == 1
and list(request_config.keys())[0] == 'return'
and list(request_config.keys())[0] == StatementType.RETURN.value
):
cls._flush_if_statements(reqs, if_config)
reqs.append(ReturnStatement(argument=request_config['return']))
reqs.append(
ReturnStatement(argument=request_config[StatementType.RETURN.value])
)
continue
# Check if it's a variable set statement
if (len(request_config.keys()) == 1) and (
list(request_config.keys())[0] == StatementType.SET.value
):
cls._flush_if_statements(reqs, if_config)
reqs.append(SetStatement(vars=request_config[StatementType.SET.value]))
continue
# Check if this request is an if-else
@ -129,6 +163,7 @@ class Procedure:
m = re.match(r'\s*(if)\s+\${(.*)}\s*', key)
if m:
cls._flush_if_statements(reqs, if_config)
if_count += 1
if_name = f'{name}__if_{if_count}'
condition = m.group(2)
@ -233,86 +268,96 @@ class Procedure:
pending_if = if_config.get()
requests.append(IfProcedure.build(**pending_if))
@staticmethod
def _find_nearest_loop(stack):
for proc in stack[::-1]:
if isinstance(proc, LoopProcedure):
return proc
raise AssertionError('break/continue statement found outside of a loop')
# pylint: disable=too-many-branches,too-many-statements
def execute(self, n_tries=1, __stack__=None, **context):
def execute(
self,
n_tries: int = 1,
__stack__: Optional[Iterable] = None,
new_context: Optional[Dict[str, Any]] = None,
**context,
):
"""
Execute the requests in the procedure.
:param n_tries: Number of tries in case of failure before raising a RuntimeError.
"""
if not __stack__:
__stack__ = [self]
else:
__stack__.append(self)
__stack__ = (self,) if not __stack__ else (self, *__stack__)
new_context = new_context or {}
if self.args:
args = self.args.copy()
for k, v in args.items():
v = Request.expand_value_from_context(v, **context)
args[k] = v
context[k] = v
args[k] = context[k] = Request.expand_value_from_context(v, **context)
logger.info('Executing procedure %s with arguments %s', self.name, args)
else:
logger.info('Executing procedure %s', self.name)
response = Response()
token = Config.get('token')
context = _update_context(context)
locals().update(context)
# pylint: disable=too-many-nested-blocks
for request in self.requests:
if callable(request):
response = request(**context)
continue
context['_async'] = self._async
context['n_tries'] = n_tries
context['__stack__'] = __stack__
context['new_context'] = new_context
if isinstance(request, Statement):
if isinstance(request, ReturnStatement):
response = Response(output=request.run(**context))
response = request.run(**context)
self._should_return = True
for proc in __stack__:
proc._should_return = True # pylint: disable=protected-access
break
if isinstance(request, SetStatement):
rs: dict = request.run(**context).output # type: ignore
context.update(rs)
new_context.update(rs)
locals().update(rs)
continue
if request.type in [StatementType.BREAK, StatementType.CONTINUE]:
loop = self._find_nearest_loop(__stack__)
for proc in __stack__:
if isinstance(proc, LoopProcedure):
if request.type == StatementType.BREAK:
loop._should_break = True # pylint: disable=protected-access
setattr(proc, '_should_break', True) # noqa: B010
else:
loop._should_continue = True # pylint: disable=protected-access
setattr(proc, '_should_continue', True) # noqa: B010
break
proc._should_return = True # pylint: disable=protected-access
break
should_continue = getattr(self, '_should_continue', False)
should_break = getattr(self, '_should_break', False)
if isinstance(self, LoopProcedure) and (should_continue or should_break):
if should_continue:
setattr(self, '_should_continue', False) # noqa[B010]
else:
setattr(self, '_should_break', False) # noqa[B010]
if self._should_return or should_continue or should_break:
break
if token and not isinstance(request, Statement):
request.token = token
context['_async'] = self._async
context['n_tries'] = n_tries
exec_ = getattr(request, 'execute', None)
if callable(exec_):
response = exec_(__stack__=__stack__, **context)
response = exec_(**context)
context.update(context.get('new_context', {}))
if not self._async and response:
if isinstance(response.output, dict):
for k, v in response.output.items():
context[k] = v
context.update(response.output)
context['output'] = response.output
context['errors'] = response.errors
new_context.update(context)
locals().update(context)
if self._should_return:
break
@ -333,10 +378,8 @@ class LoopProcedure(Procedure):
Base class while and for/fork loops.
"""
def __init__(self, name, requests, _async=False, args=None, backend=None):
super().__init__(
name=name, _async=_async, requests=requests, args=args, backend=backend
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._should_break = False
self._should_continue = False
@ -381,6 +424,9 @@ class ForProcedure(LoopProcedure):
# pylint: disable=eval-used
def execute(self, *_, **context):
ctx = _update_context(context)
locals().update(ctx)
try:
iterable = eval(self.iterable)
assert hasattr(
@ -388,11 +434,18 @@ class ForProcedure(LoopProcedure):
), f'Object of type {type(iterable)} is not iterable: {iterable}'
except Exception as e:
logger.debug('Iterable %s expansion error: %s', self.iterable, e)
iterable = Request.expand_value_from_context(self.iterable, **context)
iterable = Request.expand_value_from_context(self.iterable, **ctx)
response = Response()
for item in iterable:
ctx[self.iterator_name] = item
response = super().execute(**ctx)
ctx.update(ctx.get('new_context', {}))
if response.output and isinstance(response.output, dict):
ctx = _update_context(ctx, **response.output)
if self._should_return:
logger.info('Returning from %s', self.name)
break
@ -407,9 +460,6 @@ class ForProcedure(LoopProcedure):
logger.info('Breaking loop %s', self.name)
break
context[self.iterator_name] = item
response = super().execute(**context)
return response
@ -446,41 +496,23 @@ class WhileProcedure(LoopProcedure):
)
self.condition = condition
@staticmethod
def _get_context(**context):
for k, v in context.items():
try:
context[k] = eval(v) # pylint: disable=eval-used
except Exception as e:
logger.debug('Evaluation error for %s=%s: %s', k, v, e)
if isinstance(v, str):
try:
context[k] = eval( # pylint: disable=eval-used
'"' + re.sub(r'(^|[^\\])"', '\1\\"', v) + '"'
)
except Exception as ee:
logger.warning(
'Could not parse value for context variable %s=%s: %s',
k,
v,
ee,
)
logger.warning('Context: %s', context)
logger.exception(e)
return context
def execute(self, *_, **context):
response = Response()
context = self._get_context(**context)
for k, v in context.items():
locals()[k] = v
ctx = _update_context(context)
locals().update(ctx)
while True:
condition_true = eval(self.condition) # pylint: disable=eval-used
if not condition_true:
break
response = super().execute(**ctx)
ctx.update(ctx.get('new_context', {}))
if response.output and isinstance(response.output, dict):
_update_context(ctx, **response.output)
locals().update(ctx)
if self._should_return:
logger.info('Returning from %s', self.name)
break
@ -495,13 +527,6 @@ class WhileProcedure(LoopProcedure):
logger.info('Breaking loop %s', self.name)
break
response = super().execute(**context)
if response.output and isinstance(response.output, dict):
new_context = self._get_context(**response.output)
for k, v in new_context.items():
locals()[k] = v
return response
@ -595,20 +620,28 @@ class IfProcedure(Procedure):
)
def execute(self, *_, **context):
for k, v in context.items():
locals()[k] = v
ctx = _update_context(context)
locals().update(ctx)
condition_true = eval(self.condition) # pylint: disable=eval-used
response = Response()
if condition_true:
response = super().execute(**context)
response = super().execute(**ctx)
elif self.else_branch:
response = self.else_branch.execute(**context)
response = self.else_branch.execute(**ctx)
return response
def _update_context(context: Optional[Dict[str, Any]] = None, **kwargs):
ctx = context or {}
ctx = {**ctx.get('context', {}), **ctx, **kwargs}
for k, v in ctx.items():
ctx[k] = Request.expand_value_from_context(v, **ctx)
return ctx
def procedure(name_or_func: Optional[str] = None, *upper_args, **upper_kwargs):
name = name_or_func if isinstance(name_or_func, str) else None