diff --git a/platypush/plugins/procedures/__init__.py b/platypush/plugins/procedures/__init__.py index 5c5813e724..dfb98042bd 100644 --- a/platypush/plugins/procedures/__init__.py +++ b/platypush/plugins/procedures/__init__.py @@ -2,9 +2,11 @@ import json from dataclasses import dataclass from typing import Callable, Collection, Optional, Union +from platypush.context import get_plugin from platypush.entities.managers.procedures import ProcedureEntityManager from platypush.entities.procedures import Procedure from platypush.plugins import RunnablePlugin, action +from platypush.plugins.db import DbPlugin from platypush.utils import run from ._serialize import ProcedureEncoder @@ -72,6 +74,24 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): for name, proc in self._all_procedures.items() ] + def _sync_db_procedures(self): + cur_proc_names = set(self._all_procedures.keys()) + db: Optional[DbPlugin] = get_plugin(DbPlugin) + assert db, 'No database plugin configured' + + with db.get_session( + autoflush=False, autocommit=False, expire_on_commit=False + ) as session: + procs_to_remove = ( + session.query(Procedure) + .filter(Procedure.name.not_in(cur_proc_names)) + .all() + ) + + for proc in procs_to_remove: + self.logger.info('Removing stale procedure record for %s', proc.name) + session.delete(proc) + @staticmethod def _serialize_procedure( proc: Union[dict, Callable], name: Optional[str] = None @@ -89,6 +109,8 @@ class ProceduresPlugin(RunnablePlugin, ProcedureEntityManager): } def main(self, *_, **__): + self._sync_db_procedures() + while not self.should_stop(): self.publish_entities(self._get_wrapped_procedures()) self.wait_stop()