diff --git a/platypush/plugins/procedures/__init__.py b/platypush/plugins/procedures/__init__.py index 1b82befcdb..e536201737 100644 --- a/platypush/plugins/procedures/__init__.py +++ b/platypush/plugins/procedures/__init__.py @@ -1,6 +1,8 @@ -from contextlib import contextmanager import json +from contextlib import contextmanager from dataclasses import dataclass +from multiprocessing import RLock +from random import randint from typing import Callable, Collection, Generator, Iterable, Optional, Union from sqlalchemy.orm import Session @@ -8,6 +10,7 @@ from sqlalchemy.orm import Session from platypush.context import get_plugin from platypush.entities.managers.procedures import ProcedureEntityManager from platypush.entities.procedures import Procedure, ProcedureType +from platypush.message.event.entities import EntityDeleteEvent from platypush.plugins import RunnablePlugin, action from platypush.plugins.db import DbPlugin from platypush.utils import run @@ -26,11 +29,60 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): Utility plugin to run and store procedures as native entities. """ - @action - def exec(self, procedure: str, *args, **kwargs): - return run(f'procedure.{procedure}', *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._status_lock = RLock() + + @action + def exec(self, procedure: Union[str, dict], *args, **kwargs): + """ + Execute a procedure. + + :param procedure: Procedure name or definition. If a string is passed, + then the procedure will be looked up by name in the configured + procedures. If a dictionary is passed, then it should be a valid + procedure definition with at least the ``actions`` key. + :param args: Optional arguments to be passed to the procedure. + :param kwargs: Optional arguments to be passed to the procedure. + """ + if isinstance(procedure, str): + return run(f'procedure.{procedure}', *args, **kwargs) + + assert isinstance(procedure, dict), 'Invalid procedure definition' + procedure_name = procedure.get( + 'name', f'procedure_{f"{randint(0, 1 << 32):08x}"}' + ) + + actions = procedure.get('actions') + assert actions and isinstance( + actions, (list, tuple, set) + ), 'Procedure definition should have at least the "actions" key as a list of actions' + + try: + # Create a temporary procedure definition and execute it + self._all_procedures[procedure_name] = { + 'name': procedure_name, + 'type': ProcedureType.CONFIG.value, + 'actions': list(actions), + 'args': procedure.get('args', []), + '_async': False, + } + + kwargs = { + **procedure.get('args', {}), + **kwargs, + } + + return self.exec(procedure_name, *args, **kwargs) + finally: + self._all_procedures.pop(procedure_name, None) + + def _convert_procedure( + self, name: str, proc: Union[dict, Callable, Procedure] + ) -> Procedure: + if isinstance(proc, Procedure): + return proc - def _convert_procedure(self, name: str, proc: Union[dict, Callable]) -> Procedure: metadata = self._serialize_procedure(proc, name=name) return Procedure( id=name, @@ -45,8 +97,15 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): ) @action - def status(self, *_, **__): + def status(self, *_, publish: bool = True, **__): """ + :param publish: If set to True (default) then the + :class:`platypush.message.event.entities.EntityUpdateEvent` events + will be published to the bus with the current configured procedures. + Usually this should be set to True, unless you're calling this method + from a context where you first want to retrieve the procedures and + then immediately modify them. In such cases, the published events may + result in race conditions on the entities engine. :return: The serialized configured procedures. Format: .. code-block:: json @@ -62,20 +121,11 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): } """ - self.publish_entities(self._get_wrapped_procedures()) - return self._get_serialized_procedures() + with self._status_lock: + if publish: + self.publish_entities(self._get_wrapped_procedures()) - def _update_procedure(self, old: Procedure, new: Procedure, session: Session): - assert old.procedure_type == ProcedureType.DB.value, ( # type: ignore[attr-defined] - f'Procedure {old.name} is not stored in the database, ' - f'it should be removed from the source file: {old.source}' - ) - - old.external_id = new.external_id - old.name = new.name - old.args = new.args - old.actions = new.actions - session.add(old) + return self._get_serialized_procedures() @action def save( @@ -115,61 +165,30 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): ), 'Procedure actions should be dictionaries with an "action" key' args = args or [] - proc_def = { - 'type': ProcedureType.DB.value, + proc_args = { 'name': name, + 'type': ProcedureType.DB.value, 'actions': actions, 'args': args, } - existing_proc = None - old_proc = None - new_proc = Procedure( - external_id=name, - plugin=str(self), - procedure_type=ProcedureType.DB.value, - name=name, - actions=actions, - args=args, - ) + with self._status_lock: + with self._db_session() as session: + if old_name and old_name != name: + try: + self._delete(old_name, session=session) + except AssertionError as e: + self.logger.warning( + 'Error while deleting old procedure: name=%s: %s', + old_name, + e, + ) - with self._db_session() as session: - if old_name and old_name != name: - old_proc = ( - session.query(Procedure).filter(Procedure.name == old_name).first() - ) + self._all_procedures[name] = proc_args - if old_proc: - self._update_procedure(old=old_proc, new=new_proc, session=session) - else: - self.logger.warning( - 'Procedure %s not found, skipping rename', old_name - ) + self.publish_entities([_ProcedureWrapper(name=name, obj=proc_args)]) - existing_proc = ( - session.query(Procedure).filter(Procedure.name == name).first() - ) - - if existing_proc: - if old_proc: - self._delete(str(existing_proc.name), session=session) - else: - self._update_procedure( - old=existing_proc, new=new_proc, session=session - ) - elif not old_proc: - session.add(new_proc) - - if old_proc: - old_name = str(old_proc.name) - self._all_procedures.pop(old_name, None) - - self._all_procedures[name] = { - **self._all_procedures.get(name, {}), # type: ignore[operator] - **proc_def, - } - - self.status() + return self.status() @action def delete(self, name: str): @@ -191,9 +210,7 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): def _db_session(self) -> Generator[Session, None, None]: db: Optional[DbPlugin] = get_plugin(DbPlugin) assert db, 'No database plugin configured' - with db.get_session( - autoflush=False, autocommit=False, expire_on_commit=False - ) as session: + with db.get_session(locked=True) as session: assert isinstance(session, Session) yield session @@ -216,12 +233,17 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): session.delete(proc_row) self._all_procedures.pop(name, None) + self._bus.post(EntityDeleteEvent(plugin=self, entity=proc_row)) def transform_entities( self, entities: Collection[_ProcedureWrapper], **_ ) -> Collection[Procedure]: return [ - self._convert_procedure(name=proc.name, proc=proc.obj) for proc in entities + self._convert_procedure( + name=proc.name, + proc=proc if isinstance(proc, Procedure) else proc.obj, + ) + for proc in entities ] def _get_wrapped_procedures(self) -> Collection[_ProcedureWrapper]: @@ -231,38 +253,40 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): ] def _sync_db_procedures(self): - cur_proc_names = set(self._all_procedures.keys()) - with self._db_session() as session: - saved_procs = { - str(proc.name): proc for proc in session.query(Procedure).all() - } - - procs_to_remove = [ - proc - for name, proc in saved_procs.items() - if name not in cur_proc_names - and proc.procedure_type != ProcedureType.DB.value # type: ignore[attr-defined] - ] - - for proc in procs_to_remove: - self.logger.info('Removing stale procedure record for %s', proc.name) - session.delete(proc) - - procs_to_add = [ - proc - for name, proc in saved_procs.items() - if proc.procedure_type == ProcedureType.DB.value # type: ignore[attr-defined] - and name not in cur_proc_names - ] - - for proc in procs_to_add: - self._all_procedures[str(proc.name)] = { - 'type': proc.procedure_type, - 'name': proc.name, - 'args': proc.args, - 'actions': proc.actions, + with self._status_lock: + cur_proc_names = set(self._all_procedures.keys()) + with self._db_session() as session: + saved_procs = { + str(proc.name): proc for proc in session.query(Procedure).all() } + procs_to_remove = [ + proc + for name, proc in saved_procs.items() + if name not in cur_proc_names + and proc.procedure_type != ProcedureType.DB.value # type: ignore[attr-defined] + ] + + for proc in procs_to_remove: + self.logger.info( + 'Removing stale procedure record for %s', proc.name + ) + session.delete(proc) + + procs_to_add = [ + proc + for proc in saved_procs.values() + if proc.procedure_type == ProcedureType.DB.value # type: ignore[attr-defined] + ] + + for proc in procs_to_add: + self._all_procedures[str(proc.name)] = { + 'type': proc.procedure_type, + 'name': proc.name, + 'args': proc.args, + 'actions': proc.actions, + } + @staticmethod def _serialize_procedure( proc: Union[dict, Callable], name: Optional[str] = None