platypush/platypush/entities/_engine/repo/__init__.py

106 lines
3.2 KiB
Python

import logging
from typing import Dict, Iterable, Optional, Tuple
from sqlalchemy.orm import Session
from platypush.entities._base import Entity, EntityMapping
# pylint: disable=no-name-in-module
from platypush.entities._engine.repo.db import EntitiesDb
from platypush.entities._engine.repo.merger import EntitiesMerger
# Module-level logger, registered under the 'entities' logger name.
logger = logging.getLogger('entities')
class EntitiesRepository:
    """
    This object is used to get and save entities. It wraps the database
    connection.

    Persistence is delegated to an ``EntitiesDb`` instance, while the merge of
    incoming entities with the existing ones is delegated to an
    ``EntitiesMerger`` instance.
    """

    def __init__(self):
        self._db = EntitiesDb()
        self._merge = EntitiesMerger()

    def get(
        self, session: Session, entities: Iterable[Entity]
    ) -> Dict[Tuple[str, str], Entity]:
        """
        Given a set of entity objects, it returns those that already exist
        (or have the same ``entity_key``).

        :param session: Database session used for the lookup.
        :param entities: Entities to look up.
        :return: The mapping returned by ``EntitiesDb.fetch``.
        """
        return self._db.fetch(session, entities)

    def save(self, *entities: Entity) -> Iterable[Entity]:
        """
        Perform an upsert of entities after merging duplicates and rebuilding
        the taxonomies.

        :param entities: Entities to be saved.
        :return: The entities as returned by ``EntitiesDb.upsert``.
        """
        # NOTE(review): ``locked=True`` presumably serializes concurrent
        # saves - confirm against ``EntitiesDb.get_session``.
        with self._db.get_session(
            locked=True,
            autoflush=False,
            autocommit=False,
            expire_on_commit=False,
        ) as session:
            # Merge the incoming entities with their persisted counterparts
            # (and the rest of their hierarchy) before writing them.
            merged_entities = self._merge(
                session,
                entities,
                existing_entities=self._fetch_all_and_flatten(session, entities),
            )
            merged_entities = self._db.upsert(session, merged_entities)

        return merged_entities

    def _fetch_all_and_flatten(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> EntityMapping:
        """
        Given a collection of entities, retrieves their persisted instances
        (lookup is performed by ``entity_key``), and it also recursively
        expands their relationships, so the session is updated with the latest
        persisted versions of all the objects in the hierarchy.

        :param session: Database session used for the lookups.
        :param entities: Entities whose hierarchies should be expanded.
        :return: An ``entity_key -> entity`` mapping.
        """
        expanded_entities = {}
        for entity in entities:
            root_entity = self._get_root_entity(session, entity)
            # The root's subtree is added first, so the objects reachable from
            # the passed entity take precedence when the same key shows up in
            # both expansions.
            expanded_entities.update(self._expand_children([root_entity]))
            expanded_entities.update(self._expand_children([entity]))

        return self.get(session, expanded_entities.values())

    @classmethod
    def _expand_children(
        cls,
        entities: Iterable[Entity],
        all_entities: Optional[EntityMapping] = None,
    ) -> EntityMapping:
        """
        Recursively expands and flattens all the children of a set of entities
        into an ``entity_key -> entity`` mapping.

        :param entities: Entities to be expanded.
        :param all_entities: Optional mapping used as the accumulator. When
            provided (even if empty) it is updated in place and returned.
        :return: The ``entity_key -> entity`` mapping of all visited entities.
        """
        # Compare against None explicitly: ``all_entities or {}`` would
        # replace a caller-provided *empty* mapping with a new one, and the
        # caller's accumulator would never be populated.
        if all_entities is None:
            all_entities = {}

        for entity in entities:
            all_entities[entity.entity_key] = entity
            cls._expand_children(entity.children, all_entities)

        return all_entities

    def _get_root_entity(self, session: Session, entity: Entity) -> Entity:
        """
        Retrieve the root entity (i.e. the one with a null parent) of an
        entity.

        :param session: Database session passed to the parent lookup.
        :param entity: Entity from which the walk starts.
        :return: The last entity reached for which no parent was found (the
            passed entity itself if it has no parent).
        """
        current = entity
        # Walk up the hierarchy until the parent lookup comes back empty.
        while current:
            parent = self._merge.get_parent(session, current)
            if not parent:
                break
            current = parent

        return current