import logging
from typing import Dict, Iterable, Optional, Tuple

from sqlalchemy.orm import Session

from platypush.entities._base import Entity, EntityMapping

# pylint: disable=no-name-in-module
from platypush.entities._engine.repo.db import EntitiesDb
from platypush.entities._engine.repo.merger import EntitiesMerger

logger = logging.getLogger('entities')


class EntitiesRepository:
    """
    This object is used to get and save entities. It wraps the database
    connection.
    """

    def __init__(self):
        self._db = EntitiesDb()
        self._merge = EntitiesMerger()

    def get(
        self, session: Session, entities: Iterable[Entity]
    ) -> Dict[Tuple[str, str], Entity]:
        """
        Given a set of entity objects, it returns those that already exist (or
        have the same ``entity_key``).

        :param session: Session used for the lookup.
        :param entities: Entities to look up.
        :return: An ``entity_key -> entity`` mapping, as returned by
            ``EntitiesDb.fetch``.
        """
        return self._db.fetch(session, entities)

    def save(self, *entities: Entity) -> Iterable[Entity]:
        """
        Perform an upsert of entities after merging duplicates and rebuilding
        the taxonomies.

        :param entities: Entities to save.
        :return: The entities returned by ``EntitiesDb.upsert``.
        """
        with self._db.get_session(
            locked=True,
            autoflush=False,
            autocommit=False,
            expire_on_commit=False,
        ) as session:
            # Persisted versions of every object in the hierarchies of the
            # entities being saved; handed to the merger as the baseline.
            existing_entities = self._fetch_all_and_flatten(session, entities)
            merged_entities = self._merge(
                session,
                entities,
                existing_entities=existing_entities,
            )
            merged_entities = self._db.upsert(session, merged_entities)

        return merged_entities

    def _fetch_all_and_flatten(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> EntityMapping:
        """
        Given a collection of entities, retrieves their persisted instances
        (lookup is performed by ``entity_key``), and it also recursively
        expands their relationships, so the session is updated with the latest
        persisted versions of all the objects in the hierarchy.

        :param session: Session used for the parent lookups and the fetch.
        :param entities: Entities whose hierarchies should be fetched.
        :return: An ``entity_key -> entity`` mapping.
        """
        expanded_entities = {}

        for entity in entities:
            root_entity = self._get_root_entity(session, entity)
            # Flatten the whole tree the entity belongs to...
            self._expand_children([root_entity], expanded_entities)
            # ...then the entity's own subtree, so that for keys reachable
            # both ways the objects from the entity's subtree win. When the
            # entity is its own root this pass would rewrite identical
            # values, so it is skipped.
            if root_entity is not entity:
                self._expand_children([entity], expanded_entities)

        return self.get(session, expanded_entities.values())

    @classmethod
    def _expand_children(
        cls,
        entities: Iterable[Entity],
        all_entities: Optional[EntityMapping] = None,
    ) -> EntityMapping:
        """
        Recursively expands and flattens all the children of a set of entities
        into an ``entity_key -> entity`` mapping.

        :param entities: Entities to expand (each one plus its descendants).
        :param all_entities: Optional accumulator mapping that is updated in
            place. If omitted, a new mapping is created.
        :return: The accumulator mapping. If an ``entity_key`` is seen more
            than once, the last entity visited wins.
        """
        # Compare against None explicitly: an empty mapping passed by the
        # caller must still be filled in place, not replaced by a new dict.
        if all_entities is None:
            all_entities = {}

        for entity in entities:
            all_entities[entity.entity_key] = entity
            cls._expand_children(entity.children, all_entities)

        return all_entities

    def _get_root_entity(self, session: Session, entity: Entity) -> Entity:
        """
        Retrieve the root entity (i.e. the one with a null parent) of an
        entity, by repeatedly asking the merger for the parent.

        :param session: Session used for the parent lookups.
        :param entity: Entity to start from.
        :return: The topmost ancestor, or ``entity`` itself if it has no
            parent.
        """
        # NOTE(review): there is no cycle detection; a cyclic parent chain
        # would make this loop forever.
        while entity:
            parent = self._merge.get_parent(session, entity)
            if not parent:
                break
            entity = parent

        return entity