platypush/platypush/entities/_engine/repo/merger.py

195 lines
6.7 KiB
Python

from typing import Dict, Iterable, List, Optional, Tuple
from sqlalchemy.orm import Session, exc
from platypush.entities import Entity
# pylint: disable=too-few-public-methods
class EntitiesMerger:
    """
    This object is in charge of detecting and merging entities that already
    exist on the database before flushing the session.
    """

    def __init__(self, repository):
        # Local import, presumably to avoid a circular import with the
        # repository package that owns this module.
        from . import EntitiesRepository

        self._repo: EntitiesRepository = repository

    def merge(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> List[Entity]:
        """
        Merge a set of entities with their existing representations and update
        the parent/child relationships.

        :param session: Session used to look up the stored entities.
        :param entities: Entities to merge. Any iterable is accepted, including
            one-shot iterators such as generators.
        :return: A flat list containing the entities found in the repository
            (updated with the incoming column values), followed by the
            entities (and parent entities) that are not stored yet.
        """
        new_entities: Dict[Tuple[str, str], Entity] = {}
        existing_entities: Dict[Tuple[str, str], Entity] = {}

        self._merge(
            session,
            entities,
            new_entities=new_entities,
            existing_entities=existing_entities,
        )

        return [*existing_entities.values(), *new_entities.values()]

    def _merge(
        self,
        session: Session,
        entities: Iterable[Entity],
        new_entities: Dict[Tuple[str, str], Entity],
        existing_entities: Dict[Tuple[str, str], Entity],
    ) -> List[Entity]:
        """
        (Recursive) inner implementation of the entity merge logic.

        :param session: Session used to look up the stored entities.
        :param entities: Entities to merge in this pass.
        :param new_entities: Accumulator, updated in place, of the entities that
            are not stored yet, indexed by ``entity_key``.
        :param existing_entities: Accumulator, updated in place, of the entities
            returned by the repository lookup, indexed by ``entity_key``.
        :return: For each (de-duplicated) input entity, the already known
            version (stored or previously queued) with the new values merged
            in, or the entity itself if it is new.
        """
        # Materialize the input once: it is iterated several times below
        # (repository lookup + de-duplication), and a one-shot iterable would
        # otherwise be exhausted by the first pass, silently dropping entities.
        entities = list(entities)
        processed_entities: List[Entity] = []
        existing_entities.update(self._repo.get(session, entities))

        # Retrieve existing records and merge them
        for entity in self._dedup_entities(entities):
            key = entity.entity_key
            existing_entity = existing_entities.get(key, new_entities.get(key))
            parent_id, parent = self._update_parent(session, entity, new_entities)

            if existing_entity:
                # Update the parent: if only a not-yet-stored parent object is
                # available, link it through the relationship; otherwise set
                # (or clear) the foreign key directly.
                if not parent_id and parent:
                    existing_entity.parent = parent
                else:
                    existing_entity.parent_id = parent_id

                # Merge the other columns
                self._merge_columns(entity, existing_entity)
                # Merge the children
                self._merge(session, entity.children, new_entities, existing_entities)
                # Use the updated version of the existing entity.
                entity = existing_entity
            else:
                # Add it to the map of new entities if the entity doesn't exist
                # on the repo
                new_entities[key] = entity

            processed_entities.append(entity)

        return processed_entities

    @staticmethod
    def _dedup_entities(entities: List[Entity]) -> List[Entity]:
        """
        Make sure that there are no duplicate entity keys in a batch.

        Entities are returned in the order in which their ``entity_key`` first
        appears. For each key, the last entity with that key wins, unless an
        entity carrying an ``id`` (itself de-duplicated by ``id``, last one
        wins) maps to the same key, in which case that one is preferred.
        """
        by_key: Dict[Tuple[str, str], Entity] = {e.entity_key: e for e in entities}
        by_id: Dict[str, Entity] = {str(e.id): e for e in entities if e.id}

        # Entities with an ID override the plain by-key choice. Their keys are
        # already in `by_key`, so the ordering is not affected.
        for entity in by_id.values():
            by_key[entity.entity_key] = entity

        return list(by_key.values())

    def _update_parent(
        self,
        session: Session,
        entity: Entity,
        new_entities: Dict[Tuple[str, str], Entity],
    ) -> Tuple[Optional[int], Optional[Entity]]:
        """
        Recursively update the hierarchy of an entity, moving upwards towards
        the parent.

        :param session: Session used to look up the stored parents.
        :param entity: Entity whose parent should be resolved.
        :param new_entities: Accumulator of the entities not stored yet; a
            parent that is neither stored nor queued is added to it.
        :return: ``(parent_id, parent)``. When ``parent_id`` is set, it has
            also been assigned to ``entity.parent_id`` and ``entity.parent``
            has been cleared.
        """
        parent_id: Optional[int] = entity.parent_id
        try:
            parent: Optional[Entity] = entity.parent
        except exc.DetachedInstanceError:
            # Dirty fix for `Parent instance <...> is not bound to a Session;
            # lazy load operation of attribute 'parent' cannot proceed
            parent = session.query(Entity).get(parent_id) if parent_id else None

        # If the entity has a parent with an ID, use that
        if parent and parent.id:
            parent_id = parent_id or parent.id

        # If there's no parent_id but there is a parent object, try to fetch
        # its stored version
        if not parent_id and parent:
            batch = list(self._repo.get(session, [parent]).values())

            # If the parent is already stored, use its ID
            if batch:
                parent = batch[0]
                parent_id = parent.id
            # Otherwise, check if its key is already among those awaiting flush
            # and reuse the same objects (prevents SQLAlchemy from generating
            # duplicate inserts)
            else:
                temp_entity = new_entities.get(parent.entity_key)
                if temp_entity:
                    self._remove_duplicate_children(entity, temp_entity)
                    parent = entity.parent = temp_entity
                else:
                    new_entities[parent.entity_key] = parent

                # Recursively apply any changes up in the hierarchy
                self._update_parent(session, parent, new_entities=new_entities)

        # If we found a parent_id, populate it on the entity (and remove the
        # supporting relationship object so SQLAlchemy doesn't go nuts when
        # flushing)
        if parent_id:
            entity.parent = None
            entity.parent_id = parent_id

        return parent_id, parent

    @staticmethod
    def _remove_duplicate_children(entity: Entity, parent: Optional[Entity] = None):
        """
        Remove from ``parent.children`` the first child sharing the ID of
        ``entity`` (if ``entity`` has one) and the first child sharing its
        entity key, before ``entity`` gets re-parented to ``parent``.

        :param entity: Entity about to be attached to ``parent``.
        :param parent: Prospective parent; nothing is done if it is empty.
        """
        if not parent:
            return

        # Make sure that an entity has no duplicate entity IDs among its
        # children
        if entity.id:
            try:
                index_by_id = [e.id for e in parent.children].index(entity.id)
            except ValueError:
                pass
            else:
                parent.children.pop(index_by_id)

        # Make sure that an entity has no duplicate entity keys among its
        # children
        try:
            index_by_key = [e.entity_key for e in parent.children].index(
                entity.entity_key
            )
        except ValueError:
            pass
        else:
            parent.children.pop(index_by_key)

    @classmethod
    def _merge_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
        """
        Merge two versions of an entity column by column.

        Every column of ``entity`` except ``id`` and ``created_at`` is copied
        onto ``existing_entity``; ``meta`` dictionaries are shallow-merged, with
        the values from ``entity`` taking precedence.

        :return: ``existing_entity``, updated in place.
        """
        columns = [col.key for col in entity.columns]
        for col in columns:
            if col == 'meta':
                existing_entity.meta = {
                    **(existing_entity.meta or {}),
                    **(entity.meta or {}),
                }
            elif col not in ('id', 'created_at'):
                setattr(existing_entity, col, getattr(entity, col))

        return existing_entity