from typing import Dict, Iterable, List, Optional, Tuple

from sqlalchemy.orm import Session, exc

from platypush.entities import Entity


# pylint: disable=too-few-public-methods
class EntitiesMerger:
    """
    This object is in charge of detecting and merging entities that already
    exist on the database before flushing the session.
    """

    def __init__(self, repository):
        """
        :param repository: The entities repository used to look up the stored
            version of the entities being merged.
        """
        # Imported locally and only used for the annotation below.
        # NOTE(review): presumably done to avoid a circular import - confirm.
        from . import EntitiesRepository

        self._repo: EntitiesRepository = repository

    def merge(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> List[Entity]:
        """
        Merge a set of entities with their existing representations and update
        the parent/child relationships.

        :param session: Session used to look up the stored entities.
        :param entities: Entities to be merged. Any iterable is accepted,
            including one-shot iterables such as generators.
        :return: A single flat list containing the entities that already
            existed (updated with the new values) followed by the entities
            that are new and still have to be persisted.
        """
        new_entities: Dict[Tuple[str, str], Entity] = {}
        existing_entities: Dict[Tuple[str, str], Entity] = {}

        # The two maps are populated in place by the recursive implementation
        self._merge(
            session,
            entities,
            new_entities=new_entities,
            existing_entities=existing_entities,
        )

        return [*existing_entities.values(), *new_entities.values()]

    def _merge(
        self,
        session: Session,
        entities: Iterable[Entity],
        new_entities: Dict[Tuple[str, str], Entity],
        existing_entities: Dict[Tuple[str, str], Entity],
    ) -> List[Entity]:
        """
        (Recursive) inner implementation of the entity merge logic.

        :param session: Session used to look up the stored entities.
        :param entities: Batch of entities to be merged.
        :param new_entities: Map of the entities that are not stored yet,
            indexed by entity key. It is updated in place.
        :param existing_entities: Map of the entities that are already stored,
            indexed by entity key. It is updated in place.
        :return: The processed entities, where each entity that already
            existed is replaced by its (updated) stored version.
        """
        # The batch is traversed several times below (repository lookup and
        # de-duplication), so a one-shot iterable such as a generator has to
        # be materialized first, or everything after the first pass would see
        # an empty batch.
        entities = list(entities)
        processed_entities = []
        # presumably the repository returns a mapping indexed by entity key,
        # as the lookups below use ``entity_key`` - verify against the repo
        existing_entities.update(self._repo.get(session, entities))

        # Make sure that we have no duplicate entity keys in the current
        # batch. The last entity seen for a given key wins, but entities that
        # have an ID (themselves de-duplicated by ID) take precedence over
        # those that don't have one.
        entities_by_key = {e.entity_key: e for e in entities}
        entities_by_id = {str(e.id): e for e in entities if e.id}
        entities_by_key.update({e.entity_key: e for e in entities_by_id.values()})
        entities = list(entities_by_key.values())

        # Retrieve existing records and merge them
        for entity in entities:
            key = entity.entity_key
            # Entities awaiting flush are treated like stored ones, so that
            # the same object is reused instead of being inserted twice
            existing_entity = existing_entities.get(key, new_entities.get(key))
            parent_id, parent = self._update_parent(session, entity, new_entities)

            if existing_entity:
                # Update the parent: link the object if it has no ID yet,
                # otherwise just set the reference to the stored parent
                if not parent_id and parent:
                    existing_entity.parent = parent
                else:
                    existing_entity.parent_id = parent_id

                # Merge the other columns
                self._merge_columns(entity, existing_entity)
                # Merge the children
                self._merge(session, entity.children, new_entities, existing_entities)
                # Use the updated version of the existing entity.
                entity = existing_entity
            else:
                # Add it to the map of new entities if the entity doesn't
                # exist on the repo
                new_entities[key] = entity

            processed_entities.append(entity)

        return processed_entities

    def _update_parent(
        self,
        session: Session,
        entity: Entity,
        new_entities: Dict[Tuple[str, str], Entity],
    ) -> Tuple[Optional[int], Optional[Entity]]:
        """
        Recursively update the hierarchy of an entity, moving upwards towards
        the parent.

        :param session: Session used to look up the stored parent.
        :param entity: Entity whose parent has to be resolved. Its ``parent``
            and ``parent_id`` attributes may be modified.
        :param new_entities: Map of the entities that are not stored yet,
            indexed by entity key. A parent that isn't stored yet is added to
            it.
        :return: A ``(parent_id, parent)`` tuple, where each of the elements
            is ``None`` if it couldn't be resolved.
        """
        parent_id: Optional[int] = entity.parent_id

        try:
            parent: Optional[Entity] = entity.parent
        except exc.DetachedInstanceError:
            # Dirty fix for `Parent instance <...> is not bound to a Session;
            # lazy load operation of attribute 'parent' cannot proceed
            parent = session.query(Entity).get(parent_id) if parent_id else None

        # If the entity has a parent with an ID, use that
        if parent and parent.id:
            parent_id = parent_id or parent.id

        # If there's no parent_id but there is a parent object, try to fetch
        # its stored version
        if not parent_id and parent:
            batch = list(self._repo.get(session, [parent]).values())

            # If the parent is already stored, use its ID
            if batch:
                parent = batch[0]
                parent_id = parent.id

            # Otherwise, check if its key is already among those awaiting flush
            # and reuse the same objects (prevents SQLAlchemy from generating
            # duplicate inserts)
            else:
                temp_entity = new_entities.get(parent.entity_key)
                if temp_entity:
                    self._remove_duplicate_children(entity, temp_entity)
                    parent = entity.parent = temp_entity
                else:
                    new_entities[parent.entity_key] = parent

                # Recursively apply any changes up in the hierarchy
                self._update_parent(session, parent, new_entities=new_entities)

        # If we found a parent_id, populate it on the entity (and remove the
        # supporting relationship object so SQLAlchemy doesn't go nuts when
        # flushing)
        if parent_id:
            entity.parent = None
            entity.parent_id = parent_id

        return parent_id, parent

    @staticmethod
    def _remove_duplicate_children(entity: Entity, parent: Optional[Entity] = None):
        """
        Remove from the children of ``parent`` the entries that duplicate
        ``entity``, either by ID or by entity key.

        At most one child is removed for each of the two criteria (the first
        match). Nothing is done if no ``parent`` is provided.

        :param entity: Entity that should have no duplicates among the
            children of ``parent``.
        :param parent: Entity whose ``children`` are modified in place.
        """
        if not parent:
            return

        # Make sure that an entity has no duplicate entity IDs among its
        # children
        if entity.id:
            try:
                child_index = [e.id for e in parent.children].index(entity.id)
                parent.children.pop(child_index)
            except ValueError:
                # No child with the same ID
                pass

        # Make sure that an entity has no duplicate entity keys among its
        # children
        try:
            child_index = [e.entity_key for e in parent.children].index(
                entity.entity_key
            )
            parent.children.pop(child_index)
        except ValueError:
            # No child with the same entity key
            pass

    @classmethod
    def _merge_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
        """
        Merge two versions of an entity column by column.

        :param entity: Entity that holds the new values.
        :param existing_entity: Entity that is updated in place.
        :return: ``existing_entity``, after the update.
        """
        columns = [col.key for col in entity.columns]
        for col in columns:
            if col == 'meta':
                # The metadata is merged key by key rather than replaced, and
                # the new values win over the existing ones
                existing_entity.meta = {
                    **(existing_entity.meta or {}),
                    **(entity.meta or {}),
                }
            elif col not in ('id', 'created_at'):
                # ``id`` and ``created_at`` always keep the existing values
                setattr(existing_entity, col, getattr(entity, col))

        return existing_entity