forked from platypush/platypush
[procedure] Several improvements to the procedure engine.
- Add `set` statement, which can be used to set context variables within YAML procedures. Example: ```yaml procedure.test: - set: foo: bar - action: logger.info args: msg: ${bar} ``` - More reliable flow control for nested break/continue/return. - Propagate changes to context variables also to upstream procedures.
This commit is contained in:
parent
771e32e368
commit
be8140ddb5
1 changed files with 119 additions and 86 deletions
|
@ -1,11 +1,12 @@
|
|||
import enum
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
|
||||
from queue import LifoQueue
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from ..common import exec_wrapper
|
||||
from ..config import Config
|
||||
|
@ -23,6 +24,7 @@ class StatementType(enum.Enum):
|
|||
BREAK = 'break'
|
||||
CONTINUE = 'continue'
|
||||
RETURN = 'return'
|
||||
SET = 'set'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -32,7 +34,7 @@ class Statement:
|
|||
"""
|
||||
|
||||
type: StatementType
|
||||
argument: Optional[str] = None
|
||||
argument: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def build(cls, statement: str):
|
||||
|
@ -60,8 +62,30 @@ class ReturnStatement(Statement):
|
|||
|
||||
type: StatementType = StatementType.RETURN
|
||||
|
||||
def run(self, *_, **context) -> Any:
|
||||
return Response(
|
||||
output=Request.expand_value_from_context(
|
||||
self.argument, **_update_context(context)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetStatement(Statement):
|
||||
"""
|
||||
Models a set variable statement in a procedure.
|
||||
"""
|
||||
|
||||
type: StatementType = StatementType.SET
|
||||
vars: dict = field(default_factory=dict)
|
||||
|
||||
def run(self, *_, **context):
|
||||
return Request.expand_value_from_context(self.argument, **context)
|
||||
vars = deepcopy(self.vars) # pylint: disable=redefined-builtin
|
||||
for k, v in vars.items():
|
||||
vars[k] = Request.expand_value_from_context(v, **context)
|
||||
|
||||
context.update(vars)
|
||||
return Response(output=vars)
|
||||
|
||||
|
||||
class Procedure:
|
||||
|
@ -117,10 +141,20 @@ class Procedure:
|
|||
# Check if it's a return statement with a value
|
||||
if (
|
||||
len(request_config.keys()) == 1
|
||||
and list(request_config.keys())[0] == 'return'
|
||||
and list(request_config.keys())[0] == StatementType.RETURN.value
|
||||
):
|
||||
cls._flush_if_statements(reqs, if_config)
|
||||
reqs.append(ReturnStatement(argument=request_config['return']))
|
||||
reqs.append(
|
||||
ReturnStatement(argument=request_config[StatementType.RETURN.value])
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if it's a variable set statement
|
||||
if (len(request_config.keys()) == 1) and (
|
||||
list(request_config.keys())[0] == StatementType.SET.value
|
||||
):
|
||||
cls._flush_if_statements(reqs, if_config)
|
||||
reqs.append(SetStatement(vars=request_config[StatementType.SET.value]))
|
||||
continue
|
||||
|
||||
# Check if this request is an if-else
|
||||
|
@ -129,6 +163,7 @@ class Procedure:
|
|||
m = re.match(r'\s*(if)\s+\${(.*)}\s*', key)
|
||||
|
||||
if m:
|
||||
cls._flush_if_statements(reqs, if_config)
|
||||
if_count += 1
|
||||
if_name = f'{name}__if_{if_count}'
|
||||
condition = m.group(2)
|
||||
|
@ -233,86 +268,96 @@ class Procedure:
|
|||
pending_if = if_config.get()
|
||||
requests.append(IfProcedure.build(**pending_if))
|
||||
|
||||
@staticmethod
|
||||
def _find_nearest_loop(stack):
|
||||
for proc in stack[::-1]:
|
||||
if isinstance(proc, LoopProcedure):
|
||||
return proc
|
||||
|
||||
raise AssertionError('break/continue statement found outside of a loop')
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
def execute(self, n_tries=1, __stack__=None, **context):
|
||||
def execute(
|
||||
self,
|
||||
n_tries: int = 1,
|
||||
__stack__: Optional[Iterable] = None,
|
||||
new_context: Optional[Dict[str, Any]] = None,
|
||||
**context,
|
||||
):
|
||||
"""
|
||||
Execute the requests in the procedure.
|
||||
|
||||
:param n_tries: Number of tries in case of failure before raising a RuntimeError.
|
||||
"""
|
||||
if not __stack__:
|
||||
__stack__ = [self]
|
||||
else:
|
||||
__stack__.append(self)
|
||||
__stack__ = (self,) if not __stack__ else (self, *__stack__)
|
||||
new_context = new_context or {}
|
||||
|
||||
if self.args:
|
||||
args = self.args.copy()
|
||||
for k, v in args.items():
|
||||
v = Request.expand_value_from_context(v, **context)
|
||||
args[k] = v
|
||||
context[k] = v
|
||||
args[k] = context[k] = Request.expand_value_from_context(v, **context)
|
||||
logger.info('Executing procedure %s with arguments %s', self.name, args)
|
||||
else:
|
||||
logger.info('Executing procedure %s', self.name)
|
||||
|
||||
response = Response()
|
||||
token = Config.get('token')
|
||||
context = _update_context(context)
|
||||
locals().update(context)
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
for request in self.requests:
|
||||
if callable(request):
|
||||
response = request(**context)
|
||||
continue
|
||||
|
||||
context['_async'] = self._async
|
||||
context['n_tries'] = n_tries
|
||||
context['__stack__'] = __stack__
|
||||
context['new_context'] = new_context
|
||||
|
||||
if isinstance(request, Statement):
|
||||
if isinstance(request, ReturnStatement):
|
||||
response = Response(output=request.run(**context))
|
||||
response = request.run(**context)
|
||||
self._should_return = True
|
||||
for proc in __stack__:
|
||||
proc._should_return = True # pylint: disable=protected-access
|
||||
|
||||
break
|
||||
|
||||
if isinstance(request, SetStatement):
|
||||
rs: dict = request.run(**context).output # type: ignore
|
||||
context.update(rs)
|
||||
new_context.update(rs)
|
||||
locals().update(rs)
|
||||
continue
|
||||
|
||||
if request.type in [StatementType.BREAK, StatementType.CONTINUE]:
|
||||
loop = self._find_nearest_loop(__stack__)
|
||||
if request.type == StatementType.BREAK:
|
||||
loop._should_break = True # pylint: disable=protected-access
|
||||
else:
|
||||
loop._should_continue = True # pylint: disable=protected-access
|
||||
for proc in __stack__:
|
||||
if isinstance(proc, LoopProcedure):
|
||||
if request.type == StatementType.BREAK:
|
||||
setattr(proc, '_should_break', True) # noqa: B010
|
||||
else:
|
||||
setattr(proc, '_should_continue', True) # noqa: B010
|
||||
break
|
||||
|
||||
proc._should_return = True # pylint: disable=protected-access
|
||||
|
||||
break
|
||||
|
||||
should_continue = getattr(self, '_should_continue', False)
|
||||
should_break = getattr(self, '_should_break', False)
|
||||
if isinstance(self, LoopProcedure) and (should_continue or should_break):
|
||||
if should_continue:
|
||||
setattr(self, '_should_continue', False) # noqa[B010]
|
||||
else:
|
||||
setattr(self, '_should_break', False) # noqa[B010]
|
||||
|
||||
if self._should_return or should_continue or should_break:
|
||||
break
|
||||
|
||||
if token and not isinstance(request, Statement):
|
||||
request.token = token
|
||||
|
||||
context['_async'] = self._async
|
||||
context['n_tries'] = n_tries
|
||||
exec_ = getattr(request, 'execute', None)
|
||||
if callable(exec_):
|
||||
response = exec_(__stack__=__stack__, **context)
|
||||
response = exec_(**context)
|
||||
context.update(context.get('new_context', {}))
|
||||
|
||||
if not self._async and response:
|
||||
if isinstance(response.output, dict):
|
||||
for k, v in response.output.items():
|
||||
context[k] = v
|
||||
context.update(response.output)
|
||||
|
||||
context['output'] = response.output
|
||||
context['errors'] = response.errors
|
||||
new_context.update(context)
|
||||
locals().update(context)
|
||||
|
||||
if self._should_return:
|
||||
break
|
||||
|
@ -333,10 +378,8 @@ class LoopProcedure(Procedure):
|
|||
Base class while and for/fork loops.
|
||||
"""
|
||||
|
||||
def __init__(self, name, requests, _async=False, args=None, backend=None):
|
||||
super().__init__(
|
||||
name=name, _async=_async, requests=requests, args=args, backend=backend
|
||||
)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._should_break = False
|
||||
self._should_continue = False
|
||||
|
||||
|
@ -381,6 +424,9 @@ class ForProcedure(LoopProcedure):
|
|||
|
||||
# pylint: disable=eval-used
|
||||
def execute(self, *_, **context):
|
||||
ctx = _update_context(context)
|
||||
locals().update(ctx)
|
||||
|
||||
try:
|
||||
iterable = eval(self.iterable)
|
||||
assert hasattr(
|
||||
|
@ -388,11 +434,18 @@ class ForProcedure(LoopProcedure):
|
|||
), f'Object of type {type(iterable)} is not iterable: {iterable}'
|
||||
except Exception as e:
|
||||
logger.debug('Iterable %s expansion error: %s', self.iterable, e)
|
||||
iterable = Request.expand_value_from_context(self.iterable, **context)
|
||||
iterable = Request.expand_value_from_context(self.iterable, **ctx)
|
||||
|
||||
response = Response()
|
||||
|
||||
for item in iterable:
|
||||
ctx[self.iterator_name] = item
|
||||
response = super().execute(**ctx)
|
||||
ctx.update(ctx.get('new_context', {}))
|
||||
|
||||
if response.output and isinstance(response.output, dict):
|
||||
ctx = _update_context(ctx, **response.output)
|
||||
|
||||
if self._should_return:
|
||||
logger.info('Returning from %s', self.name)
|
||||
break
|
||||
|
@ -407,9 +460,6 @@ class ForProcedure(LoopProcedure):
|
|||
logger.info('Breaking loop %s', self.name)
|
||||
break
|
||||
|
||||
context[self.iterator_name] = item
|
||||
response = super().execute(**context)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
@ -446,41 +496,23 @@ class WhileProcedure(LoopProcedure):
|
|||
)
|
||||
self.condition = condition
|
||||
|
||||
@staticmethod
|
||||
def _get_context(**context):
|
||||
for k, v in context.items():
|
||||
try:
|
||||
context[k] = eval(v) # pylint: disable=eval-used
|
||||
except Exception as e:
|
||||
logger.debug('Evaluation error for %s=%s: %s', k, v, e)
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
context[k] = eval( # pylint: disable=eval-used
|
||||
'"' + re.sub(r'(^|[^\\])"', '\1\\"', v) + '"'
|
||||
)
|
||||
except Exception as ee:
|
||||
logger.warning(
|
||||
'Could not parse value for context variable %s=%s: %s',
|
||||
k,
|
||||
v,
|
||||
ee,
|
||||
)
|
||||
logger.warning('Context: %s', context)
|
||||
logger.exception(e)
|
||||
|
||||
return context
|
||||
|
||||
def execute(self, *_, **context):
|
||||
response = Response()
|
||||
context = self._get_context(**context)
|
||||
for k, v in context.items():
|
||||
locals()[k] = v
|
||||
ctx = _update_context(context)
|
||||
locals().update(ctx)
|
||||
|
||||
while True:
|
||||
condition_true = eval(self.condition) # pylint: disable=eval-used
|
||||
if not condition_true:
|
||||
break
|
||||
|
||||
response = super().execute(**ctx)
|
||||
ctx.update(ctx.get('new_context', {}))
|
||||
if response.output and isinstance(response.output, dict):
|
||||
_update_context(ctx, **response.output)
|
||||
|
||||
locals().update(ctx)
|
||||
|
||||
if self._should_return:
|
||||
logger.info('Returning from %s', self.name)
|
||||
break
|
||||
|
@ -495,13 +527,6 @@ class WhileProcedure(LoopProcedure):
|
|||
logger.info('Breaking loop %s', self.name)
|
||||
break
|
||||
|
||||
response = super().execute(**context)
|
||||
|
||||
if response.output and isinstance(response.output, dict):
|
||||
new_context = self._get_context(**response.output)
|
||||
for k, v in new_context.items():
|
||||
locals()[k] = v
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
@ -595,20 +620,28 @@ class IfProcedure(Procedure):
|
|||
)
|
||||
|
||||
def execute(self, *_, **context):
|
||||
for k, v in context.items():
|
||||
locals()[k] = v
|
||||
|
||||
ctx = _update_context(context)
|
||||
locals().update(ctx)
|
||||
condition_true = eval(self.condition) # pylint: disable=eval-used
|
||||
response = Response()
|
||||
|
||||
if condition_true:
|
||||
response = super().execute(**context)
|
||||
response = super().execute(**ctx)
|
||||
elif self.else_branch:
|
||||
response = self.else_branch.execute(**context)
|
||||
response = self.else_branch.execute(**ctx)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _update_context(context: Optional[Dict[str, Any]] = None, **kwargs):
|
||||
ctx = context or {}
|
||||
ctx = {**ctx.get('context', {}), **ctx, **kwargs}
|
||||
for k, v in ctx.items():
|
||||
ctx[k] = Request.expand_value_from_context(v, **ctx)
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
def procedure(name_or_func: Optional[str] = None, *upper_args, **upper_kwargs):
|
||||
name = name_or_func if isinstance(name_or_func, str) else None
|
||||
|
||||
|
|
Loading…
Reference in a new issue