diff --git a/platypush/entities/__init__.py b/platypush/entities/__init__.py index d51b8539..8ea774d0 100644 --- a/platypush/entities/__init__.py +++ b/platypush/entities/__init__.py @@ -1,7 +1,12 @@ import logging from typing import Collection, Optional -from ._base import Entity, get_entities_registry, init_entities_db +from ._base import ( + Entity, + EntitySavedCallback, + get_entities_registry, + init_entities_db, +) from ._engine import EntitiesEngine from ._managers import ( EntityManager, @@ -31,7 +36,9 @@ def init_entities_engine() -> EntitiesEngine: return _engine -def publish_entities(entities: Collection[Entity]): +def publish_entities( + entities: Collection[Entity], callback: Optional[EntitySavedCallback] = None +) -> None: """ Publish a collection of entities to be processed by the engine. @@ -47,7 +54,7 @@ def publish_entities(entities: Collection[Entity]): logger.debug('No entities engine registered') return - _engine.post(*entities) + _engine.post(*entities, callback=callback) __all__ = ( @@ -55,6 +62,7 @@ __all__ = ( 'EntitiesEngine', 'Entity', 'EntityManager', + 'EntitySavedCallback', 'EnumSwitchEntityManager', 'LightEntityManager', 'SensorEntityManager', diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py index d3950585..415ffea1 100644 --- a/platypush/entities/_base.py +++ b/platypush/entities/_base.py @@ -4,7 +4,7 @@ import pathlib import types from datetime import datetime from dateutil.tz import tzutc -from typing import Mapping, Type, Tuple, Any +from typing import Callable, Mapping, Type, Tuple, Any import pkgutil from sqlalchemy import ( @@ -70,7 +70,7 @@ if 'entity' not in Base.metadata: 'Entity', remote_side=[id], uselist=False, - lazy=True, + lazy='selectin', post_update=True, backref=backref( 'children', @@ -105,7 +105,7 @@ if 'entity' not in Base.metadata: """ This method returns the "external" key of an entity. """ - return (str(self.external_id or self.id), str(self.plugin)) + return str(self.external_id), str(self.plugin) def _serialize_value(self, col: ColumnProperty) -> Any: val = getattr(self, col.key) @@ -153,6 +153,12 @@ if 'entity' not in Base.metadata: # standard multiple inheritance with an SQLAlchemy ORM class) Entity.__bases__ = Entity.__bases__ + (JSONAble,) + EntitySavedCallback = Callable[[Entity], None] + """ + Type for the callback functions that should be called when an entity is saved + on the database. + """ + def _discover_entity_types(): from platypush.context import get_plugin diff --git a/platypush/entities/_engine/__init__.py b/platypush/entities/_engine/__init__.py index 505d7a70..45a08403 100644 --- a/platypush/entities/_engine/__init__.py +++ b/platypush/entities/_engine/__init__.py @@ -1,12 +1,13 @@ from logging import getLogger from threading import Thread, Event +from typing import Dict, Optional, Tuple from platypush.context import get_bus from platypush.entities import Entity from platypush.message.event.entities import EntityUpdateEvent from platypush.utils import set_thread_name -# pylint: disable=no-name-in-module +from platypush.entities._base import EntitySavedCallback from platypush.entities._engine.queue import EntitiesQueue from platypush.entities._engine.repo import EntitiesRepository @@ -29,16 +30,25 @@ class EntitiesEngine(Thread): """ - def __init__(self): + def __init__(self) -> None: obj_name = self.__class__.__name__ super().__init__(name=obj_name) self.logger = getLogger(name=obj_name) self._should_stop = Event() + """ Event used to synchronize stop events downstream.""" self._queue = EntitiesQueue(stop_event=self._should_stop) + """ Queue where all entity upsert requests are received.""" self._repo = EntitiesRepository() + """ The repository of the processed entities. """ + self._callbacks: Dict[Tuple[str, str], EntitySavedCallback] = {} + """ (external_id, plugin) -> callback mapping""" + + def post(self, *entities: Entity, callback: Optional[EntitySavedCallback] = None): + if callback: + for entity in entities: + self._callbacks[entity.entity_key] = callback - def post(self, *entities: Entity): self._queue.put(*entities) @property @@ -52,10 +62,13 @@ class EntitiesEngine(Thread): """ Trigger an EntityUpdateEvent if the entity has been persisted, or queue it to the list of entities whose notifications will be flushed when the - session is committed. + session is committed. It will also invoke any registered callbacks. """ for entity in entities: get_bus().post(EntityUpdateEvent(entity=entity)) + callback = self._callbacks.pop(entity.entity_key, None) + if callback: + callback(entity) def run(self): super().run() diff --git a/platypush/entities/_engine/repo/__init__.py b/platypush/entities/_engine/repo/__init__.py index 1bf38131..ec4956a1 100644 --- a/platypush/entities/_engine/repo/__init__.py +++ b/platypush/entities/_engine/repo/__init__.py @@ -37,7 +37,9 @@ class EntitiesRepository: the taxonomies. """ - with self._db.get_session(locked=True, autoflush=False) as session: + with self._db.get_session( + locked=True, autoflush=False, expire_on_commit=False + ) as session: merged_entities = self._merger.merge(session, entities) merged_entities = self._db.upsert(session, merged_entities) diff --git a/platypush/entities/_managers/__init__.py b/platypush/entities/_managers/__init__.py index 3949469d..7f55d584 100644 --- a/platypush/entities/_managers/__init__.py +++ b/platypush/entities/_managers/__init__.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional, Dict, Collection, Type from platypush.config import Config -from platypush.entities import Entity +from platypush.entities._base import Entity, EntitySavedCallback from platypush.utils import get_plugin_name_by_class, get_redis _entity_registry_varname = '_platypush/plugin_entity_registry' @@ -68,7 +68,7 @@ class EntityManager(ABC): def _normalize_entities(self, entities: Collection[Entity]) -> Collection[Entity]: for entity in entities: - if entity.id: + if entity.id and not entity.external_id: # Entity IDs can only refer to the internal primary key entity.external_id = entity.id entity.id = None # type: ignore @@ -80,7 +80,9 @@ class EntityManager(ABC): return entities def publish_entities( - self, entities: Optional[Collection[Any]] + self, + entities: Optional[Collection[Any]], + callback: Optional[EntitySavedCallback] = None, ) -> Collection[Entity]: """ Publishes a list of entities. The downstream consumers include: @@ -91,6 +93,9 @@ class EntityManager(ABC): :class:`platypush.message.event.entities.EntityUpdateEvent` events (e.g. web clients) + It also accepts an optional callback that will be called when each of + the entities in the set is flushed to the database. + You usually don't need to override this class (but you may want to extend :meth:`.transform_entities` instead if your extension doesn't natively handle `Entity` objects). @@ -101,7 +106,7 @@ class EntityManager(ABC): self.transform_entities(entities or []) ) - publish_entities(transformed_entities) + publish_entities(transformed_entities, callback=callback) return transformed_entities