From b0464219d3f9f8a2bbee2723913d0639e5278c0d Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Sat, 17 Dec 2022 21:41:23 +0100 Subject: [PATCH] Large refactor of the entities engine. --- platypush/entities/_base.py | 196 ++++++++------ platypush/entities/_engine.py | 281 -------------------- platypush/entities/_engine/__init__.py | 76 ++++++ platypush/entities/_engine/notifier.py | 47 ++++ platypush/entities/_engine/queue.py | 50 ++++ platypush/entities/_engine/repo/__init__.py | 86 ++++++ platypush/entities/_engine/repo/cache.py | 51 ++++ platypush/entities/_engine/repo/db.py | 187 +++++++++++++ platypush/entities/_engine/repo/merger.py | 143 ++++++++++ 9 files changed, 757 insertions(+), 360 deletions(-) delete mode 100644 platypush/entities/_engine.py create mode 100644 platypush/entities/_engine/__init__.py create mode 100644 platypush/entities/_engine/notifier.py create mode 100644 platypush/entities/_engine/queue.py create mode 100644 platypush/entities/_engine/repo/__init__.py create mode 100644 platypush/entities/_engine/repo/cache.py create mode 100644 platypush/entities/_engine/repo/db.py create mode 100644 platypush/entities/_engine/repo/merger.py diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py index da56b6819..71e231a6d 100644 --- a/platypush/entities/_base.py +++ b/platypush/entities/_base.py @@ -1,19 +1,21 @@ import inspect +import json import pathlib import types from datetime import datetime -from typing import Dict, Mapping, Type, Tuple, Any +from dateutil.tz import tzutc +from typing import Mapping, Type, Tuple, Any import pkgutil from sqlalchemy import ( Boolean, Column, + DateTime, ForeignKey, Index, Integer, - String, - DateTime, JSON, + String, UniqueConstraint, inspect as schema_inspect, ) @@ -23,100 +25,136 @@ from platypush.common.db import Base from platypush.message import JSONAble entities_registry: Mapping[Type['Entity'], Mapping] = {} -entity_types_registry: Dict[str, Type['Entity']] = {} -class Entity(Base): - """ - Model for a general-purpose platform entity. - """ +if 'entity' not in Base.metadata: - __tablename__ = 'entity' + class Entity(Base): + """ + Model for a general-purpose platform entity. + """ - id = Column(Integer, autoincrement=True, primary_key=True) - external_id = Column(String, nullable=True) - name = Column(String, nullable=False, index=True) - description = Column(String) - type = Column(String, nullable=False, index=True) - plugin = Column(String, nullable=False) - parent_id = Column( - Integer, - ForeignKey(f'{__tablename__}.id', ondelete='CASCADE'), - nullable=True, - ) + __tablename__ = 'entity' - data = Column(JSON, default=dict) - meta = Column(JSON, default=dict) - is_read_only = Column(Boolean, default=False) - is_write_only = Column(Boolean, default=False) - is_query_disabled = Column(Boolean, default=False) - created_at = Column( - DateTime(timezone=False), default=datetime.utcnow(), nullable=False - ) - updated_at = Column( - DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.utcnow() - ) + id = Column(Integer, autoincrement=True, primary_key=True) + external_id = Column(String, nullable=False) + name = Column(String, nullable=False, index=True) + description = Column(String) + type = Column(String, nullable=False, index=True) + plugin = Column(String, nullable=False) + parent_id = Column( + Integer, + ForeignKey(f'{__tablename__}.id', ondelete='CASCADE'), + nullable=True, + ) - parent: Mapped['Entity'] = relationship( - 'Entity', - remote_side=[id], - uselist=False, - lazy=True, - backref=backref( - 'children', - remote_side=[parent_id], - uselist=True, - cascade='all, delete-orphan', - ), - ) + data = Column(JSON, default=dict) + meta = Column(JSON, default=dict) + is_read_only = Column(Boolean, default=False) + is_write_only = Column(Boolean, default=False) + is_query_disabled = Column(Boolean, default=False) + created_at = Column( + DateTime(timezone=False), default=datetime.utcnow(), nullable=False + ) + updated_at = Column( + DateTime(timezone=False), + default=datetime.utcnow(), + onupdate=datetime.utcnow(), + ) - UniqueConstraint(external_id, plugin) + parent: Mapped['Entity'] = relationship( + 'Entity', + remote_side=[id], + uselist=False, + lazy=True, + post_update=True, + backref=backref( + 'children', + remote_side=[parent_id], + uselist=True, + cascade='all, delete-orphan', + ), + ) - __table_args__ = ( - Index('name_and_plugin_index', name, plugin), - Index('name_type_and_plugin_index', name, type, plugin), - {'extend_existing': True}, - ) + UniqueConstraint(external_id, plugin) - __mapper_args__ = { - 'polymorphic_identity': __tablename__, - 'polymorphic_on': type, - } + __table_args__ = ( + Index('name_and_plugin_index', name, plugin), + Index('name_type_and_plugin_index', name, type, plugin), + {'extend_existing': True}, + ) - @classmethod - @property - def columns(cls) -> Tuple[ColumnProperty]: - inspector = schema_inspect(cls) - return tuple(inspector.mapper.column_attrs) + __mapper_args__ = { + 'polymorphic_identity': __tablename__, + 'polymorphic_on': type, + } - def _serialize_value(self, col: ColumnProperty) -> Any: - val = getattr(self, col.key) - if isinstance(val, datetime): - # All entity timestamps are in UTC - val = val.isoformat() + '+00:00' + @classmethod + @property + def columns(cls) -> Tuple[ColumnProperty]: + inspector = schema_inspect(cls) + return tuple(inspector.mapper.column_attrs) - return val + @property + def entity_key(self) -> Tuple[str, str]: + """ + This method returns the "external" key of an entity. + """ + return (str(self.external_id), str(self.plugin)) - def to_json(self) -> dict: - return {col.key: self._serialize_value(col) for col in self.columns} + def _serialize_value(self, col: ColumnProperty) -> Any: + val = getattr(self, col.key) + if isinstance(val, datetime): + # All entity timestamps are in UTC + val = val.replace(tzinfo=tzutc()).isoformat() - def get_plugin(self): - from platypush.context import get_plugin + return val - plugin = get_plugin(self.plugin) - assert plugin, f'No such plugin: {plugin}' - return plugin + def copy(self) -> 'Entity': + args = {c.key: getattr(self, c.key) for c in self.columns} + # if self.parent: + # args['parent'] = self.parent.copy() - def run(self, action: str, *args, **kwargs): - plugin = self.get_plugin() - method = getattr(plugin, action, None) - assert method, f'No such action: {self.plugin}.{action}' - return method(self.external_id or self.name, *args, **kwargs) + # args['children'] = [c.copy() for c in self.children] + return self.__class__(**args) + def to_json(self) -> dict: + return {col.key: self._serialize_value(col) for col in self.columns} -# Inject the JSONAble mixin (Python goes nuts if done through -# standard multiple inheritance with an SQLAlchemy ORM class) -Entity.__bases__ = Entity.__bases__ + (JSONAble,) + def __repr__(self): + return str(self) + + def __str__(self): + return json.dumps(self.to_json()) + + def __setattr__(self, key, value): + matching_columns = [c for c in self.columns if c.expression.name == key] + + if ( + matching_columns + and issubclass(type(matching_columns[0].columns[0].type), DateTime) + and isinstance(value, str) + ): + value = datetime.fromisoformat(value) + + return super().__setattr__(key, value) + + def get_plugin(self): + from platypush.context import get_plugin + + plugin = get_plugin(self.plugin) + assert plugin, f'No such plugin: {plugin}' + return plugin + + def run(self, action: str, *args, **kwargs): + plugin = self.get_plugin() + method = getattr(plugin, action, None) + assert method, f'No such action: {self.plugin}.{action}' + return method(self.external_id or self.name, *args, **kwargs) + + # Inject the JSONAble mixin (Python goes nuts if done through + # standard multiple inheritance with an SQLAlchemy ORM class) + Entity.__bases__ = Entity.__bases__ + (JSONAble,) def _discover_entity_types(): diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py deleted file mode 100644 index 2780596a3..000000000 --- a/platypush/entities/_engine.py +++ /dev/null @@ -1,281 +0,0 @@ -import json -from logging import getLogger -from queue import Queue, Empty -from threading import Thread, Event, RLock -from time import time -from typing import Iterable, List, Optional - -from sqlalchemy import and_, or_ -from sqlalchemy.exc import InvalidRequestError -from sqlalchemy.orm import Session, make_transient - -from platypush.context import get_bus, get_plugin -from platypush.message.event.entities import EntityUpdateEvent - -from ._base import Entity - - -class EntitiesEngine(Thread): - # Processing queue timeout in seconds - _queue_timeout = 2.0 - - def __init__(self): - obj_name = self.__class__.__name__ - super().__init__(name=obj_name) - self.logger = getLogger(name=obj_name) - self._queue = Queue() - self._should_stop = Event() - self._entities_awaiting_flush = set() - self._entities_cache_lock = RLock() - self._entities_cache = { - 'by_id': {}, - 'by_external_id_and_plugin': {}, - 'by_name_and_plugin': {}, - } - - def _get_session(self, *args, **kwargs): - db = get_plugin('db') - assert db - return db.get_session(*args, **kwargs) - - def _get_cached_entity(self, entity: Entity) -> Optional[dict]: - if entity.id: - e = self._entities_cache['by_id'].get(entity.id) - if e: - return e - - if entity.external_id and entity.plugin: - e = self._entities_cache['by_external_id_and_plugin'].get( - (entity.external_id, entity.plugin) - ) - if e: - return e - - if entity.name and entity.plugin: - e = self._entities_cache['by_name_and_plugin'].get( - (entity.name, entity.plugin) - ) - if e: - return e - - @staticmethod - def _cache_repr(entity: Entity) -> dict: - repr_ = entity.to_json() - repr_.pop('data', None) - repr_.pop('meta', None) - repr_.pop('created_at', None) - repr_.pop('updated_at', None) - return repr_ - - def _cache_entities(self, *entities: Entity, overwrite_cache=False): - for entity in entities: - e = self._cache_repr(entity) - if not overwrite_cache: - existing_entity = self._entities_cache['by_id'].get(entity.id) - if existing_entity: - for k, v in existing_entity.items(): - if e.get(k) is None: - e[k] = v - - if entity.id: - self._entities_cache['by_id'][entity.id] = e - if entity.external_id and entity.plugin: - self._entities_cache['by_external_id_and_plugin'][ - (entity.external_id, entity.plugin) - ] = e - if entity.name and entity.plugin: - self._entities_cache['by_name_and_plugin'][ - (entity.name, entity.plugin) - ] = e - - def _populate_entity_id_from_cache(self, new_entity: Entity): - with self._entities_cache_lock: - cached_entity = self._get_cached_entity(new_entity) - if cached_entity and cached_entity.get('id'): - new_entity.id = cached_entity['id'] - if new_entity.id: - self._cache_entities(new_entity) - - def _init_entities_cache(self): - with self._get_session() as session: - entities = session.query(Entity).all() - for entity in entities: - make_transient(entity) - - with self._entities_cache_lock: - self._cache_entities(*entities, overwrite_cache=True) - - self.logger.info('Entities cache initialized') - - def _process_event(self, entity: Entity): - self._populate_entity_id_from_cache(entity) - if entity.id: - get_bus().post(EntityUpdateEvent(entity=entity)) - else: - self._entities_awaiting_flush.add(self._to_entity_awaiting_flush(entity)) - - @staticmethod - def _to_entity_awaiting_flush(entity: Entity): - e = entity.to_json() - return json.dumps( - {k: v for k, v in e.items() if k in {'external_id', 'name', 'plugin'}}, - sort_keys=True, - ) - - def post(self, *entities: Entity): - for entity in entities: - self._queue.put(entity) - - @property - def should_stop(self) -> bool: - return self._should_stop.is_set() - - def stop(self): - self._should_stop.set() - - def run(self): - super().run() - self.logger.info('Started entities engine') - self._init_entities_cache() - - while not self.should_stop: - msgs = [] - last_poll_time = time() - - while not self.should_stop and ( - time() - last_poll_time < self._queue_timeout - ): - try: - msg = self._queue.get(block=True, timeout=0.5) - except Empty: - continue - - if msg: - msgs.append(msg) - # Trigger an EntityUpdateEvent if there has - # been a change on the entity state - self._process_event(msg) - - if not msgs or self.should_stop: - continue - - try: - self._process_entities(*msgs) - except Exception as e: - self.logger.error('Error while processing entity updates: ' + str(e)) - self.logger.exception(e) - - self.logger.info('Stopped entities engine') - - def get_if_exist( - self, session: Session, entities: Iterable[Entity] - ) -> List[Entity]: - existing_entities = { - ( - str(entity.external_id) - if entity.external_id is not None - else entity.name, - entity.plugin, - ): entity - for entity in session.query(Entity) - .filter( - or_( - *[ - and_( - Entity.external_id == entity.external_id, - Entity.plugin == entity.plugin, - ) - if entity.external_id is not None - else and_( - Entity.name == entity.name, - Entity.type == entity.type, - Entity.plugin == entity.plugin, - ) - for entity in entities - ] - ) - ) - .all() - } - - return [ - existing_entities.get( - ( - str(entity.external_id) - if entity.external_id is not None - else entity.name, - entity.plugin, - ), - None, - ) - for entity in entities - ] - - def _process_entities(self, *entities: Entity): # type: ignore - with self._get_session(locked=True) as session: - # Ensure that the internal IDs are set to null before the merge - for e in entities: - e.id = None # type: ignore - - entities: List[Entity] = self._merge_entities(session, entities) - session.add_all(entities) - session.commit() - - for e in entities: - try: - session.expunge(e) - except InvalidRequestError: - pass - - with self._entities_cache_lock: - for entity in entities: - self._cache_entities(entity, overwrite_cache=True) - - entities_awaiting_flush = {*self._entities_awaiting_flush} - for entity in entities: - e = self._to_entity_awaiting_flush(entity) - if e in entities_awaiting_flush: - self._process_event(entity) - self._entities_awaiting_flush.remove(e) - - def _merge_entities( - self, session: Session, entities: Iterable[Entity] - ) -> List[Entity]: - existing_entities = self.get_if_exist(session, entities) - new_entities = {} - entities_map = {} - - # Get the latest update for each ((id|name), plugin) record - for e in entities: - entities_map[self.entity_key(e)] = e - - # Retrieve existing records and merge them - for i, entity in enumerate(entities): - existing_entity = existing_entities[i] - if existing_entity: - existing_entity.children = self._merge_entities( - session, entity.children - ) - entity = self._merge_entity_columns(entity, existing_entity) - - new_entities[self.entity_key(entity)] = entity - - return list(new_entities.values()) - - @classmethod - def _merge_entity_columns(cls, entity: Entity, existing_entity: Entity) -> Entity: - columns = [col.key for col in entity.columns] - for col in columns: - if col == 'meta': - existing_entity.meta = { # type: ignore - **(existing_entity.meta or {}), # type: ignore - **(entity.meta or {}), # type: ignore - } - elif col not in ('id', 'created_at'): - setattr(existing_entity, col, getattr(entity, col)) - - return existing_entity - - @staticmethod - def entity_key(entity: Entity): - return ((entity.external_id or entity.name), entity.plugin) diff --git a/platypush/entities/_engine/__init__.py b/platypush/entities/_engine/__init__.py new file mode 100644 index 000000000..db071186d --- /dev/null +++ b/platypush/entities/_engine/__init__.py @@ -0,0 +1,76 @@ +from logging import getLogger +from threading import Thread, Event + +from platypush.entities import Entity +from platypush.utils import set_thread_name + +from platypush.entities._engine.notifier import EntityNotifier +from platypush.entities._engine.queue import EntitiesQueue +from platypush.entities._engine.repo import EntitiesRepository + + +class EntitiesEngine(Thread): + """ + This thread runs the "brain" of the entities data persistence logic. + + Its purpose is to: + + 1. Consume entities from a queue (synchronized with the upstream + integrations that produce/handle them). The producer/consumer model + ensure that only this thread writes to the database, packs events + together (preventing eccessive writes and throttling events), and + prevents race conditions when SQLite is used. + 2. Merge any existing entities with their newer representations. + 3. Update the entities taxonomy. + 4. Persist the new state to the entities database. + 5. Trigger events for the updated entities. + + """ + + def __init__(self): + obj_name = self.__class__.__name__ + super().__init__(name=obj_name) + + self.logger = getLogger(name=obj_name) + self._should_stop = Event() + self._queue = EntitiesQueue(stop_event=self._should_stop) + self._repo = EntitiesRepository() + self._notifier = EntityNotifier(self._repo._cache) + + def post(self, *entities: Entity): + self._queue.put(*entities) + + @property + def should_stop(self) -> bool: + return self._should_stop.is_set() + + def stop(self): + self._should_stop.set() + + def run(self): + super().run() + set_thread_name('entities') + self.logger.info('Started entities engine') + + while not self.should_stop: + # Get a batch of entity updates forwarded by other integrations + entities = self._queue.get() + if not entities or self.should_stop: + continue + + # Trigger/prepare EntityUpdateEvent objects + for entity in entities: + self._notifier.notify(entity) + + # Store the batch of entities + try: + entities = self._repo.save(*entities) + except Exception as e: + self.logger.error('Error while processing entity updates: ' + str(e)) + self.logger.exception(e) + continue + + # Flush any pending notifications + self._notifier.flush(*entities) + + self.logger.info('Stopped entities engine') diff --git a/platypush/entities/_engine/notifier.py b/platypush/entities/_engine/notifier.py new file mode 100644 index 000000000..cc45fece2 --- /dev/null +++ b/platypush/entities/_engine/notifier.py @@ -0,0 +1,47 @@ +from platypush.context import get_bus +from platypush.entities import Entity +from platypush.message.event.entities import EntityUpdateEvent + +from platypush.entities._engine.repo.cache import EntitiesCache + + +class EntityNotifier: + """ + This object is in charge of forwarding EntityUpdateEvent instances on the + application bus when some entities are changed. + """ + + def __init__(self, cache: EntitiesCache): + self._cache = cache + self._entities_awaiting_flush = set() + + def _populate_entity_id_from_cache(self, new_entity: Entity): + cached_entity = self._cache.get(new_entity) + if cached_entity and cached_entity.id: + new_entity.id = cached_entity.id + if new_entity.id: + self._cache.update(new_entity) + + def notify(self, entity: Entity): + """ + Trigger an EntityUpdateEvent if the entity has been persisted, or queue + it to the list of entities whose notifications will be flushed when the + session is committed. + """ + self._populate_entity_id_from_cache(entity) + if entity.id: + get_bus().post(EntityUpdateEvent(entity=entity)) + else: + self._entities_awaiting_flush.add(entity.entity_key) + + def flush(self, *entities: Entity): + """ + Flush and process any entities with pending EntityUpdateEvent + notifications. + """ + entities_awaiting_flush = {*self._entities_awaiting_flush} + for entity in entities: + key = entity.entity_key + if key in entities_awaiting_flush: + self.notify(entity) + self._entities_awaiting_flush.remove(key) diff --git a/platypush/entities/_engine/queue.py b/platypush/entities/_engine/queue.py new file mode 100644 index 000000000..bf8014eb8 --- /dev/null +++ b/platypush/entities/_engine/queue.py @@ -0,0 +1,50 @@ +from queue import Queue, Empty +from threading import Event +from time import time +from typing import List, Optional + +from platypush.entities import Entity + + +class EntitiesQueue(Queue): + """ + Extends the ``Queue`` class to provide an abstraction that allows to + getting and putting multiple entities at once and synchronize with the + upstream caller. + """ + + def __init__(self, stop_event: Optional[Event] = None, timeout: float = 2.0): + super().__init__() + self._timeout = timeout + self._should_stop = stop_event + + @property + def should_stop(self) -> bool: + return self._should_stop.is_set() if self._should_stop else False + + def get(self, block=True, timeout=None) -> List[Entity]: + """ + Returns a batch of entities read from the queue. + """ + timeout = timeout or self._timeout + entities = [] + last_poll_time = time() + + while not self.should_stop and (time() - last_poll_time < timeout): + try: + entity = super().get(block=block, timeout=0.5) + except Empty: + continue + + if entity: + entities.append(entity) + + return entities + + def put(self, *entities: Entity, block=True, timeout=None): + """ + This methood is called by an entity manager to update and persist the + state of some entities. + """ + for entity in entities: + super().put(entity, block=block, timeout=timeout) diff --git a/platypush/entities/_engine/repo/__init__.py b/platypush/entities/_engine/repo/__init__.py new file mode 100644 index 000000000..8fdab4db9 --- /dev/null +++ b/platypush/entities/_engine/repo/__init__.py @@ -0,0 +1,86 @@ +import logging +from typing import Dict, Iterable, Tuple + +from sqlalchemy.orm import Session, make_transient + +from platypush.entities import Entity +from platypush.entities._engine.repo.cache import EntitiesCache +from platypush.entities._engine.repo.db import EntitiesDb +from platypush.entities._engine.repo.merger import EntitiesMerger + +logger = logging.getLogger('entities') + + +class EntitiesRepository: + """ + This object is used to get and save entities, and it wraps database and + cache objects. + """ + + def __init__(self): + self._cache = EntitiesCache() + self._db = EntitiesDb() + self._merger = EntitiesMerger(self) + self._init_entities_cache() + + def _init_entities_cache(self): + """ + Initializes the repository with the existing entities. + """ + logger.info('Initializing entities cache') + with self._db.get_session() as session: + entities = session.query(Entity).all() + for entity in entities: + make_transient(entity) + + self._cache.update(*entities, overwrite=True) + logger.info('Entities cache initialized') + + def get( + self, session: Session, entities: Iterable[Entity] + ) -> Dict[Tuple[str, str], Entity]: + """ + Given a set of entity objects, it returns those that already exist + (or have the same ``entity_key``). It looks up both the cache and the + database. + """ + entities_map: Dict[Tuple[str, str], Entity] = { + e.entity_key: e for e in entities + } + + # Fetch the entities that exist in the cache + existing_entities = {} + # TODO UNCOMMENT THIS CODE TO ACTUALLY USE THE CACHE! + # existing_entities = { + # key: self._entities_cache.by_external_id_and_plugin[key] + # for key in entities_map.keys() + # if key in self._entities_cache.by_external_id_and_plugin + # } + + # Retrieve from the database the entities that miss from the cache + cache_miss_entities = { + key: e for key, e in entities_map.items() if key not in existing_entities + } + + cache_miss_existing_entities = self._db.fetch( + session, cache_miss_entities.values() + ) + + # Update the cache + self._cache.update(*cache_miss_existing_entities.values()) + + # Return the union of the cached + retrieved entities + existing_entities.update(cache_miss_existing_entities) + return existing_entities + + def save(self, *entities: Entity) -> Iterable[Entity]: + """ + Perform an upsert of entities after merging duplicates and rebuilding + the taxonomies. It updates both the database and the cache. + """ + with self._db.get_session(locked=True, autoflush=False) as session: + merged_entities = self._merger.merge(session, entities) + merged_entities = self._db.upsert(session, merged_entities) + self._cache.update(*merged_entities, overwrite=True) + + return merged_entities diff --git a/platypush/entities/_engine/repo/cache.py b/platypush/entities/_engine/repo/cache.py new file mode 100644 index 000000000..201714103 --- /dev/null +++ b/platypush/entities/_engine/repo/cache.py @@ -0,0 +1,51 @@ +from threading import RLock +from typing import Dict, Optional, Tuple + +from platypush.entities import Entity + + +class EntitiesCache: + """ + An auxiliary class to model an entities lookup cache with multiple keys. + """ + + def __init__(self): + self.by_id: Dict[str, Entity] = {} + self.by_external_id_and_plugin: Dict[Tuple[str, str], Entity] = {} + self._lock = RLock() + + def get(self, entity: Entity) -> Optional[Entity]: + """ + Retrieve the cached representation of an entity, if it exists. + """ + if entity.id: + e = self.by_id.get(str(entity.id)) + if e: + return e + + if entity.external_id and entity.plugin: + e = self.by_external_id_and_plugin.get( + (str(entity.external_id), str(entity.plugin)) + ) + if e: + return e + + def update(self, *entities: Entity, overwrite=False): + """ + Update the cache with a list of new entities. + """ + with self._lock: + for entity in entities: + if not overwrite: + existing_entity = self.by_id.get(str(entity.id)) + if existing_entity: + for k, v in existing_entity.to_json().items(): + if getattr(entity, k, None) is None: + setattr(entity, k, v) + + if entity.id: + self.by_id[str(entity.id)] = entity + if entity.external_id and entity.plugin: + self.by_external_id_and_plugin[ + (str(entity.external_id), str(entity.plugin)) + ] = entity diff --git a/platypush/entities/_engine/repo/db.py b/platypush/entities/_engine/repo/db.py new file mode 100644 index 000000000..77a5668c9 --- /dev/null +++ b/platypush/entities/_engine/repo/db.py @@ -0,0 +1,187 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Iterable, List, Tuple + +from sqlalchemy import and_, or_ +from sqlalchemy.exc import InvalidRequestError +from sqlalchemy.orm import Session + +from platypush.context import get_plugin +from platypush.entities import Entity + + +@dataclass +class _TaxonomyAwareEntity: + """ + A support class used to map an entity and its level within a taxonomy to be + flushed to the database. + """ + + entity: Entity + level: int + + +class EntitiesDb: + """ + This object is a facade around the entities database. It shouldn't be used + directly. Instead, it is encapsulated by + :class:`platypush.entities._repo.EntitiesRepository`, which is in charge of + caching as well. + """ + + def get_session(self, *args, **kwargs) -> Session: + db = get_plugin('db') + assert db + return db.get_session(*args, **kwargs) + + def fetch( + self, session: Session, entities: Iterable[Entity] + ) -> Dict[Tuple[str, str], Entity]: + """ + Given a set of entities, it returns those that already exist on the database. + """ + if not entities: + return {} + + entities_filter = or_( + *[ + and_( + Entity.external_id == entity.external_id, + Entity.plugin == entity.plugin, + ) + for entity in entities + ] + ) + + query = session.query(Entity).filter(entities_filter) + existing_entities = {entity.entity_key: entity for entity in query.all()} + + return { + entity.entity_key: existing_entities[entity.entity_key] + for entity in entities + if existing_entities.get(entity.entity_key) + } + + @staticmethod + def _close_batch(batch: List[_TaxonomyAwareEntity], batches: List[List[Entity]]): + if batch: + batches.append([item.entity for item in batch]) + + batch.clear() + + def _split_entity_batches_for_flush( + self, entities: Iterable[Entity] + ) -> List[List[Entity]]: + """ + This method retrieves the root entities given a list of entities and + generates batches of "flushable" entities ready for upsert using a BFS + algorithm. + + This is needed because we want hierarchies of entities to be flushed + starting from the top layer, once their parents have been appropriately + rewired. Otherwise, we may end up with conflicts on entities that have + already been flushed. + """ + # Index childrens by parent_id and by parent_key + children_by_parent_id = defaultdict(list) + children_by_parent_key = defaultdict(list) + for entity in entities: + parent_key = None + parent_id = entity.parent_id + if entity.parent: + parent_id = parent_id or entity.parent.id + parent_key = entity.parent.entity_key + + if parent_id: + children_by_parent_id[parent_id].append(entity) + if parent_key: + children_by_parent_key[parent_key].append(entity) + + # Find the root entities in the hierarchy (i.e. those that have a null + # parent) + root_entities = list( + { + e.entity_key: e + for e in entities + if e.parent is None and e.parent_id is None + }.values() + ) + + # Prepare a list of entities to process through BFS starting with the + # root nodes (i.e. level=0) + entities_to_process = [ + _TaxonomyAwareEntity(entity=e, level=0) for e in root_entities + ] + + batches = [] + current_batch = [] + + while entities_to_process: + # Pop the first element in the list (FIFO implementation) + item = entities_to_process.pop(0) + entity = item.entity + level = item.level + + # If the depth has increased compared to the previous node, flush + # the current batch and open a new one. + if current_batch and current_batch[-1].level < level: + self._close_batch(current_batch, batches) + current_batch.append(item) + + # Index the children nodes by key + children_to_process = { + e.entity_key: e + for e in children_by_parent_key.get(entity.entity_key, []) + } + + # If this entity has already been persisted, add back its children + # that haven't been updated, so we won't lose those connections + if entity.id: + children_to_process.update( + {e.entity_key: e for e in children_by_parent_id.get(entity.id, [])} + ) + + # Add all the updated+inserted+existing children to the next layer + # to be expanded + entities_to_process += [ + _TaxonomyAwareEntity(entity=e, level=level + 1) + for e in children_to_process.values() + ] + + # Close any pending batches + self._close_batch(current_batch, batches) + return batches + + def upsert( + self, + session: Session, + entities: Iterable[Entity], + ) -> Iterable[Entity]: + """ + Persist a set of entities. + """ + # Get the "unwrapped" batches + batches = self._split_entity_batches_for_flush(entities) + + # Flush each batch as we process it + for batch in batches: + session.add_all(batch) + session.flush() + + session.commit() + + all_entities = list( + { + entity.entity_key: entity for batch in batches for entity in batch + }.values() + ) + + # Remove all the entities from the existing session, so they can be + # accessed outside of this context + for e in all_entities: + try: + session.expunge(e) + except InvalidRequestError: + pass + + return all_entities diff --git a/platypush/entities/_engine/repo/merger.py b/platypush/entities/_engine/repo/merger.py new file mode 100644 index 000000000..9b18ec85a --- /dev/null +++ b/platypush/entities/_engine/repo/merger.py @@ -0,0 +1,143 @@ +from typing import Dict, Iterable, List, Optional, Tuple + +from sqlalchemy.orm import Session + +from platypush.entities import Entity + + +class EntitiesMerger: + """ + This object is in charge of detecting and merging entities that already + exist on the database before flushing the session. + """ + + def __init__(self, repository): + from platypush.entities._engine.repo import EntitiesRepository + + self._repo: EntitiesRepository = repository + + def merge( + self, + session: Session, + entities: Iterable[Entity], + ) -> List[Entity]: + """ + Merge a set of entities with their existing representations and update + the parent/child relationships and return a tuple with + ``[new_entities, updated_entities]``. + """ + new_entities = {} + existing_entities = {} + + self._merge( + session, + entities, + new_entities=new_entities, + existing_entities=existing_entities, + ) + + return [*existing_entities.values(), *new_entities.values()] + + def _merge( + self, + session: Session, + entities: Iterable[Entity], + new_entities: Dict[Tuple[str, str], Entity], + existing_entities: Dict[Tuple[str, str], Entity], + ) -> List[Entity]: + """ + (Recursive) inner implementation of the entity merge logic. + """ + processed_entities = [] + existing_entities.update(self._repo.get(session, entities)) + + # Retrieve existing records and merge them + for entity in entities: + key = entity.entity_key + existing_entity = existing_entities.get(key, new_entities.get(key)) + parent_id, parent = self._update_parent(session, entity, new_entities) + + if existing_entity: + # Update the parent + if not parent_id and parent: + existing_entity.parent = parent + else: + existing_entity.parent_id = parent_id # type: ignore + + # Merge the other columns + self._merge_columns(entity, existing_entity) + entity = existing_entity + else: + # Add it to the map of new entities if the entity doesn't exist + # on the repo + new_entities[key] = entity + + processed_entities.append(entity) + + return processed_entities + + def _update_parent( + self, + session: Session, + entity: Entity, + new_entities: Dict[Tuple[str, str], Entity], + ) -> Tuple[Optional[int], Optional[Entity]]: + """ + Recursively update the hierarchy of an entity, moving upwards towards + the parent. + """ + parent_id: Optional[int] = entity.parent_id # type: ignore + parent: Optional[Entity] = entity.parent + + # If the entity has a parent with an ID, use that + if parent and parent.id: + parent_id = parent_id or parent.id + + # If there's no parent_id but there is a parent object, try to fetch + # its stored version + if not parent_id and parent: + batch = list(self._repo.get(session, [parent]).values()) + + # If the parent is already stored, use its ID + if batch: + parent = batch[0] + parent_id = parent.id # type: ignore + + # Otherwise, check if its key is already among those awaiting flush + # and reuse the same objects (prevents SQLAlchemy from generating + # duplicate inserts) + else: + temp_entity = new_entities.get(parent.entity_key) + if temp_entity: + parent = entity.parent = temp_entity + else: + new_entities[parent.entity_key] = parent + + # Recursively apply any changes up in the hierarchy + self._update_parent(session, parent, new_entities=new_entities) + + # If we found a parent_id, populate it on the entity (and remove the + # supporting relationship object so SQLAlchemy doesn't go nuts when + # flushing) + if parent_id: + entity.parent = None + entity.parent_id = parent_id # type: ignore + + return parent_id, parent + + @classmethod + def _merge_columns(cls, entity: Entity, existing_entity: Entity) -> Entity: + """ + Merge two versions of an entity column by column. + """ + columns = [col.key for col in entity.columns] + for col in columns: + if col == 'meta': + existing_entity.meta = { # type: ignore + **(existing_entity.meta or {}), # type: ignore + **(entity.meta or {}), # type: ignore + } + elif col not in ('id', 'created_at'): + setattr(existing_entity, col, getattr(entity, col)) + + return existing_entity