Large refactor of the entities engine.

This commit is contained in:
Fabio Manganiello 2022-12-17 21:41:23 +01:00
parent 9ddebb920f
commit b0464219d3
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
9 changed files with 757 additions and 360 deletions

View file

@ -1,19 +1,21 @@
import inspect import inspect
import json
import pathlib import pathlib
import types import types
from datetime import datetime from datetime import datetime
from typing import Dict, Mapping, Type, Tuple, Any from dateutil.tz import tzutc
from typing import Mapping, Type, Tuple, Any
import pkgutil import pkgutil
from sqlalchemy import ( from sqlalchemy import (
Boolean, Boolean,
Column, Column,
DateTime,
ForeignKey, ForeignKey,
Index, Index,
Integer, Integer,
String,
DateTime,
JSON, JSON,
String,
UniqueConstraint, UniqueConstraint,
inspect as schema_inspect, inspect as schema_inspect,
) )
@ -23,9 +25,10 @@ from platypush.common.db import Base
from platypush.message import JSONAble from platypush.message import JSONAble
entities_registry: Mapping[Type['Entity'], Mapping] = {} entities_registry: Mapping[Type['Entity'], Mapping] = {}
entity_types_registry: Dict[str, Type['Entity']] = {}
if 'entity' not in Base.metadata:
class Entity(Base): class Entity(Base):
""" """
Model for a general-purpose platform entity. Model for a general-purpose platform entity.
@ -34,7 +37,7 @@ class Entity(Base):
__tablename__ = 'entity' __tablename__ = 'entity'
id = Column(Integer, autoincrement=True, primary_key=True) id = Column(Integer, autoincrement=True, primary_key=True)
external_id = Column(String, nullable=True) external_id = Column(String, nullable=False)
name = Column(String, nullable=False, index=True) name = Column(String, nullable=False, index=True)
description = Column(String) description = Column(String)
type = Column(String, nullable=False, index=True) type = Column(String, nullable=False, index=True)
@ -54,7 +57,9 @@ class Entity(Base):
DateTime(timezone=False), default=datetime.utcnow(), nullable=False DateTime(timezone=False), default=datetime.utcnow(), nullable=False
) )
updated_at = Column( updated_at = Column(
DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.utcnow() DateTime(timezone=False),
default=datetime.utcnow(),
onupdate=datetime.utcnow(),
) )
parent: Mapped['Entity'] = relationship( parent: Mapped['Entity'] = relationship(
@ -62,6 +67,7 @@ class Entity(Base):
remote_side=[id], remote_side=[id],
uselist=False, uselist=False,
lazy=True, lazy=True,
post_update=True,
backref=backref( backref=backref(
'children', 'children',
remote_side=[parent_id], remote_side=[parent_id],
@ -89,17 +95,50 @@ class Entity(Base):
inspector = schema_inspect(cls) inspector = schema_inspect(cls)
return tuple(inspector.mapper.column_attrs) return tuple(inspector.mapper.column_attrs)
@property
def entity_key(self) -> Tuple[str, str]:
"""
This method returns the "external" key of an entity.
"""
return (str(self.external_id), str(self.plugin))
def _serialize_value(self, col: ColumnProperty) -> Any: def _serialize_value(self, col: ColumnProperty) -> Any:
val = getattr(self, col.key) val = getattr(self, col.key)
if isinstance(val, datetime): if isinstance(val, datetime):
# All entity timestamps are in UTC # All entity timestamps are in UTC
val = val.isoformat() + '+00:00' val = val.replace(tzinfo=tzutc()).isoformat()
return val return val
def copy(self) -> 'Entity':
args = {c.key: getattr(self, c.key) for c in self.columns}
# if self.parent:
# args['parent'] = self.parent.copy()
# args['children'] = [c.copy() for c in self.children]
return self.__class__(**args)
def to_json(self) -> dict: def to_json(self) -> dict:
return {col.key: self._serialize_value(col) for col in self.columns} return {col.key: self._serialize_value(col) for col in self.columns}
def __repr__(self):
return str(self)
def __str__(self):
return json.dumps(self.to_json())
def __setattr__(self, key, value):
matching_columns = [c for c in self.columns if c.expression.name == key]
if (
matching_columns
and issubclass(type(matching_columns[0].columns[0].type), DateTime)
and isinstance(value, str)
):
value = datetime.fromisoformat(value)
return super().__setattr__(key, value)
def get_plugin(self): def get_plugin(self):
from platypush.context import get_plugin from platypush.context import get_plugin
@ -113,7 +152,6 @@ class Entity(Base):
assert method, f'No such action: {self.plugin}.{action}' assert method, f'No such action: {self.plugin}.{action}'
return method(self.external_id or self.name, *args, **kwargs) return method(self.external_id or self.name, *args, **kwargs)
# Inject the JSONAble mixin (Python goes nuts if done through # Inject the JSONAble mixin (Python goes nuts if done through
# standard multiple inheritance with an SQLAlchemy ORM class) # standard multiple inheritance with an SQLAlchemy ORM class)
Entity.__bases__ = Entity.__bases__ + (JSONAble,) Entity.__bases__ = Entity.__bases__ + (JSONAble,)

View file

@ -1,281 +0,0 @@
import json
from logging import getLogger
from queue import Queue, Empty
from threading import Thread, Event, RLock
from time import time
from typing import Iterable, List, Optional
from sqlalchemy import and_, or_
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, make_transient
from platypush.context import get_bus, get_plugin
from platypush.message.event.entities import EntityUpdateEvent
from ._base import Entity
class EntitiesEngine(Thread):
# Processing queue timeout in seconds
_queue_timeout = 2.0
def __init__(self):
obj_name = self.__class__.__name__
super().__init__(name=obj_name)
self.logger = getLogger(name=obj_name)
self._queue = Queue()
self._should_stop = Event()
self._entities_awaiting_flush = set()
self._entities_cache_lock = RLock()
self._entities_cache = {
'by_id': {},
'by_external_id_and_plugin': {},
'by_name_and_plugin': {},
}
def _get_session(self, *args, **kwargs):
db = get_plugin('db')
assert db
return db.get_session(*args, **kwargs)
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
if entity.id:
e = self._entities_cache['by_id'].get(entity.id)
if e:
return e
if entity.external_id and entity.plugin:
e = self._entities_cache['by_external_id_and_plugin'].get(
(entity.external_id, entity.plugin)
)
if e:
return e
if entity.name and entity.plugin:
e = self._entities_cache['by_name_and_plugin'].get(
(entity.name, entity.plugin)
)
if e:
return e
@staticmethod
def _cache_repr(entity: Entity) -> dict:
repr_ = entity.to_json()
repr_.pop('data', None)
repr_.pop('meta', None)
repr_.pop('created_at', None)
repr_.pop('updated_at', None)
return repr_
def _cache_entities(self, *entities: Entity, overwrite_cache=False):
for entity in entities:
e = self._cache_repr(entity)
if not overwrite_cache:
existing_entity = self._entities_cache['by_id'].get(entity.id)
if existing_entity:
for k, v in existing_entity.items():
if e.get(k) is None:
e[k] = v
if entity.id:
self._entities_cache['by_id'][entity.id] = e
if entity.external_id and entity.plugin:
self._entities_cache['by_external_id_and_plugin'][
(entity.external_id, entity.plugin)
] = e
if entity.name and entity.plugin:
self._entities_cache['by_name_and_plugin'][
(entity.name, entity.plugin)
] = e
def _populate_entity_id_from_cache(self, new_entity: Entity):
with self._entities_cache_lock:
cached_entity = self._get_cached_entity(new_entity)
if cached_entity and cached_entity.get('id'):
new_entity.id = cached_entity['id']
if new_entity.id:
self._cache_entities(new_entity)
def _init_entities_cache(self):
with self._get_session() as session:
entities = session.query(Entity).all()
for entity in entities:
make_transient(entity)
with self._entities_cache_lock:
self._cache_entities(*entities, overwrite_cache=True)
self.logger.info('Entities cache initialized')
def _process_event(self, entity: Entity):
self._populate_entity_id_from_cache(entity)
if entity.id:
get_bus().post(EntityUpdateEvent(entity=entity))
else:
self._entities_awaiting_flush.add(self._to_entity_awaiting_flush(entity))
@staticmethod
def _to_entity_awaiting_flush(entity: Entity):
e = entity.to_json()
return json.dumps(
{k: v for k, v in e.items() if k in {'external_id', 'name', 'plugin'}},
sort_keys=True,
)
def post(self, *entities: Entity):
for entity in entities:
self._queue.put(entity)
@property
def should_stop(self) -> bool:
return self._should_stop.is_set()
def stop(self):
self._should_stop.set()
def run(self):
super().run()
self.logger.info('Started entities engine')
self._init_entities_cache()
while not self.should_stop:
msgs = []
last_poll_time = time()
while not self.should_stop and (
time() - last_poll_time < self._queue_timeout
):
try:
msg = self._queue.get(block=True, timeout=0.5)
except Empty:
continue
if msg:
msgs.append(msg)
# Trigger an EntityUpdateEvent if there has
# been a change on the entity state
self._process_event(msg)
if not msgs or self.should_stop:
continue
try:
self._process_entities(*msgs)
except Exception as e:
self.logger.error('Error while processing entity updates: ' + str(e))
self.logger.exception(e)
self.logger.info('Stopped entities engine')
def get_if_exist(
self, session: Session, entities: Iterable[Entity]
) -> List[Entity]:
existing_entities = {
(
str(entity.external_id)
if entity.external_id is not None
else entity.name,
entity.plugin,
): entity
for entity in session.query(Entity)
.filter(
or_(
*[
and_(
Entity.external_id == entity.external_id,
Entity.plugin == entity.plugin,
)
if entity.external_id is not None
else and_(
Entity.name == entity.name,
Entity.type == entity.type,
Entity.plugin == entity.plugin,
)
for entity in entities
]
)
)
.all()
}
return [
existing_entities.get(
(
str(entity.external_id)
if entity.external_id is not None
else entity.name,
entity.plugin,
),
None,
)
for entity in entities
]
def _process_entities(self, *entities: Entity): # type: ignore
with self._get_session(locked=True) as session:
# Ensure that the internal IDs are set to null before the merge
for e in entities:
e.id = None # type: ignore
entities: List[Entity] = self._merge_entities(session, entities)
session.add_all(entities)
session.commit()
for e in entities:
try:
session.expunge(e)
except InvalidRequestError:
pass
with self._entities_cache_lock:
for entity in entities:
self._cache_entities(entity, overwrite_cache=True)
entities_awaiting_flush = {*self._entities_awaiting_flush}
for entity in entities:
e = self._to_entity_awaiting_flush(entity)
if e in entities_awaiting_flush:
self._process_event(entity)
self._entities_awaiting_flush.remove(e)
def _merge_entities(
self, session: Session, entities: Iterable[Entity]
) -> List[Entity]:
existing_entities = self.get_if_exist(session, entities)
new_entities = {}
entities_map = {}
# Get the latest update for each ((id|name), plugin) record
for e in entities:
entities_map[self.entity_key(e)] = e
# Retrieve existing records and merge them
for i, entity in enumerate(entities):
existing_entity = existing_entities[i]
if existing_entity:
existing_entity.children = self._merge_entities(
session, entity.children
)
entity = self._merge_entity_columns(entity, existing_entity)
new_entities[self.entity_key(entity)] = entity
return list(new_entities.values())
@classmethod
def _merge_entity_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
columns = [col.key for col in entity.columns]
for col in columns:
if col == 'meta':
existing_entity.meta = { # type: ignore
**(existing_entity.meta or {}), # type: ignore
**(entity.meta or {}), # type: ignore
}
elif col not in ('id', 'created_at'):
setattr(existing_entity, col, getattr(entity, col))
return existing_entity
@staticmethod
def entity_key(entity: Entity):
return ((entity.external_id or entity.name), entity.plugin)

View file

@ -0,0 +1,76 @@
from logging import getLogger
from threading import Thread, Event
from platypush.entities import Entity
from platypush.utils import set_thread_name
from platypush.entities._engine.notifier import EntityNotifier
from platypush.entities._engine.queue import EntitiesQueue
from platypush.entities._engine.repo import EntitiesRepository
class EntitiesEngine(Thread):
"""
This thread runs the "brain" of the entities data persistence logic.
Its purpose is to:
1. Consume entities from a queue (synchronized with the upstream
integrations that produce/handle them). The producer/consumer model
ensure that only this thread writes to the database, packs events
together (preventing eccessive writes and throttling events), and
prevents race conditions when SQLite is used.
2. Merge any existing entities with their newer representations.
3. Update the entities taxonomy.
4. Persist the new state to the entities database.
5. Trigger events for the updated entities.
"""
def __init__(self):
obj_name = self.__class__.__name__
super().__init__(name=obj_name)
self.logger = getLogger(name=obj_name)
self._should_stop = Event()
self._queue = EntitiesQueue(stop_event=self._should_stop)
self._repo = EntitiesRepository()
self._notifier = EntityNotifier(self._repo._cache)
def post(self, *entities: Entity):
self._queue.put(*entities)
@property
def should_stop(self) -> bool:
return self._should_stop.is_set()
def stop(self):
self._should_stop.set()
def run(self):
super().run()
set_thread_name('entities')
self.logger.info('Started entities engine')
while not self.should_stop:
# Get a batch of entity updates forwarded by other integrations
entities = self._queue.get()
if not entities or self.should_stop:
continue
# Trigger/prepare EntityUpdateEvent objects
for entity in entities:
self._notifier.notify(entity)
# Store the batch of entities
try:
entities = self._repo.save(*entities)
except Exception as e:
self.logger.error('Error while processing entity updates: ' + str(e))
self.logger.exception(e)
continue
# Flush any pending notifications
self._notifier.flush(*entities)
self.logger.info('Stopped entities engine')

View file

@ -0,0 +1,47 @@
from platypush.context import get_bus
from platypush.entities import Entity
from platypush.message.event.entities import EntityUpdateEvent
from platypush.entities._engine.repo.cache import EntitiesCache
class EntityNotifier:
"""
This object is in charge of forwarding EntityUpdateEvent instances on the
application bus when some entities are changed.
"""
def __init__(self, cache: EntitiesCache):
self._cache = cache
self._entities_awaiting_flush = set()
def _populate_entity_id_from_cache(self, new_entity: Entity):
cached_entity = self._cache.get(new_entity)
if cached_entity and cached_entity.id:
new_entity.id = cached_entity.id
if new_entity.id:
self._cache.update(new_entity)
def notify(self, entity: Entity):
"""
Trigger an EntityUpdateEvent if the entity has been persisted, or queue
it to the list of entities whose notifications will be flushed when the
session is committed.
"""
self._populate_entity_id_from_cache(entity)
if entity.id:
get_bus().post(EntityUpdateEvent(entity=entity))
else:
self._entities_awaiting_flush.add(entity.entity_key)
def flush(self, *entities: Entity):
"""
Flush and process any entities with pending EntityUpdateEvent
notifications.
"""
entities_awaiting_flush = {*self._entities_awaiting_flush}
for entity in entities:
key = entity.entity_key
if key in entities_awaiting_flush:
self.notify(entity)
self._entities_awaiting_flush.remove(key)

View file

@ -0,0 +1,50 @@
from queue import Queue, Empty
from threading import Event
from time import time
from typing import List, Optional
from platypush.entities import Entity
class EntitiesQueue(Queue):
"""
Extends the ``Queue`` class to provide an abstraction that allows to
getting and putting multiple entities at once and synchronize with the
upstream caller.
"""
def __init__(self, stop_event: Optional[Event] = None, timeout: float = 2.0):
super().__init__()
self._timeout = timeout
self._should_stop = stop_event
@property
def should_stop(self) -> bool:
return self._should_stop.is_set() if self._should_stop else False
def get(self, block=True, timeout=None) -> List[Entity]:
"""
Returns a batch of entities read from the queue.
"""
timeout = timeout or self._timeout
entities = []
last_poll_time = time()
while not self.should_stop and (time() - last_poll_time < timeout):
try:
entity = super().get(block=block, timeout=0.5)
except Empty:
continue
if entity:
entities.append(entity)
return entities
def put(self, *entities: Entity, block=True, timeout=None):
"""
This methood is called by an entity manager to update and persist the
state of some entities.
"""
for entity in entities:
super().put(entity, block=block, timeout=timeout)

View file

@ -0,0 +1,86 @@
import logging
from typing import Dict, Iterable, Tuple
from sqlalchemy.orm import Session, make_transient
from platypush.entities import Entity
from platypush.entities._engine.repo.cache import EntitiesCache
from platypush.entities._engine.repo.db import EntitiesDb
from platypush.entities._engine.repo.merger import EntitiesMerger
logger = logging.getLogger('entities')
class EntitiesRepository:
"""
This object is used to get and save entities, and it wraps database and
cache objects.
"""
def __init__(self):
self._cache = EntitiesCache()
self._db = EntitiesDb()
self._merger = EntitiesMerger(self)
self._init_entities_cache()
def _init_entities_cache(self):
"""
Initializes the repository with the existing entities.
"""
logger.info('Initializing entities cache')
with self._db.get_session() as session:
entities = session.query(Entity).all()
for entity in entities:
make_transient(entity)
self._cache.update(*entities, overwrite=True)
logger.info('Entities cache initialized')
def get(
self, session: Session, entities: Iterable[Entity]
) -> Dict[Tuple[str, str], Entity]:
"""
Given a set of entity objects, it returns those that already exist
(or have the same ``entity_key``). It looks up both the cache and the
database.
"""
entities_map: Dict[Tuple[str, str], Entity] = {
e.entity_key: e for e in entities
}
# Fetch the entities that exist in the cache
existing_entities = {}
# TODO UNCOMMENT THIS CODE TO ACTUALLY USE THE CACHE!
# existing_entities = {
# key: self._entities_cache.by_external_id_and_plugin[key]
# for key in entities_map.keys()
# if key in self._entities_cache.by_external_id_and_plugin
# }
# Retrieve from the database the entities that miss from the cache
cache_miss_entities = {
key: e for key, e in entities_map.items() if key not in existing_entities
}
cache_miss_existing_entities = self._db.fetch(
session, cache_miss_entities.values()
)
# Update the cache
self._cache.update(*cache_miss_existing_entities.values())
# Return the union of the cached + retrieved entities
existing_entities.update(cache_miss_existing_entities)
return existing_entities
def save(self, *entities: Entity) -> Iterable[Entity]:
"""
Perform an upsert of entities after merging duplicates and rebuilding
the taxonomies. It updates both the database and the cache.
"""
with self._db.get_session(locked=True, autoflush=False) as session:
merged_entities = self._merger.merge(session, entities)
merged_entities = self._db.upsert(session, merged_entities)
self._cache.update(*merged_entities, overwrite=True)
return merged_entities

View file

@ -0,0 +1,51 @@
from threading import RLock
from typing import Dict, Optional, Tuple
from platypush.entities import Entity
class EntitiesCache:
"""
An auxiliary class to model an entities lookup cache with multiple keys.
"""
def __init__(self):
self.by_id: Dict[str, Entity] = {}
self.by_external_id_and_plugin: Dict[Tuple[str, str], Entity] = {}
self._lock = RLock()
def get(self, entity: Entity) -> Optional[Entity]:
"""
Retrieve the cached representation of an entity, if it exists.
"""
if entity.id:
e = self.by_id.get(str(entity.id))
if e:
return e
if entity.external_id and entity.plugin:
e = self.by_external_id_and_plugin.get(
(str(entity.external_id), str(entity.plugin))
)
if e:
return e
def update(self, *entities: Entity, overwrite=False):
"""
Update the cache with a list of new entities.
"""
with self._lock:
for entity in entities:
if not overwrite:
existing_entity = self.by_id.get(str(entity.id))
if existing_entity:
for k, v in existing_entity.to_json().items():
if getattr(entity, k, None) is None:
setattr(entity, k, v)
if entity.id:
self.by_id[str(entity.id)] = entity
if entity.external_id and entity.plugin:
self.by_external_id_and_plugin[
(str(entity.external_id), str(entity.plugin))
] = entity

View file

@ -0,0 +1,187 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple
from sqlalchemy import and_, or_
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session
from platypush.context import get_plugin
from platypush.entities import Entity
@dataclass
class _TaxonomyAwareEntity:
"""
A support class used to map an entity and its level within a taxonomy to be
flushed to the database.
"""
entity: Entity
level: int
class EntitiesDb:
"""
This object is a facade around the entities database. It shouldn't be used
directly. Instead, it is encapsulated by
:class:`platypush.entities._repo.EntitiesRepository`, which is in charge of
caching as well.
"""
def get_session(self, *args, **kwargs) -> Session:
db = get_plugin('db')
assert db
return db.get_session(*args, **kwargs)
def fetch(
self, session: Session, entities: Iterable[Entity]
) -> Dict[Tuple[str, str], Entity]:
"""
Given a set of entities, it returns those that already exist on the database.
"""
if not entities:
return {}
entities_filter = or_(
*[
and_(
Entity.external_id == entity.external_id,
Entity.plugin == entity.plugin,
)
for entity in entities
]
)
query = session.query(Entity).filter(entities_filter)
existing_entities = {entity.entity_key: entity for entity in query.all()}
return {
entity.entity_key: existing_entities[entity.entity_key]
for entity in entities
if existing_entities.get(entity.entity_key)
}
@staticmethod
def _close_batch(batch: List[_TaxonomyAwareEntity], batches: List[List[Entity]]):
if batch:
batches.append([item.entity for item in batch])
batch.clear()
def _split_entity_batches_for_flush(
self, entities: Iterable[Entity]
) -> List[List[Entity]]:
"""
This method retrieves the root entities given a list of entities and
generates batches of "flushable" entities ready for upsert using a BFS
algorithm.
This is needed because we want hierarchies of entities to be flushed
starting from the top layer, once their parents have been appropriately
rewired. Otherwise, we may end up with conflicts on entities that have
already been flushed.
"""
# Index childrens by parent_id and by parent_key
children_by_parent_id = defaultdict(list)
children_by_parent_key = defaultdict(list)
for entity in entities:
parent_key = None
parent_id = entity.parent_id
if entity.parent:
parent_id = parent_id or entity.parent.id
parent_key = entity.parent.entity_key
if parent_id:
children_by_parent_id[parent_id].append(entity)
if parent_key:
children_by_parent_key[parent_key].append(entity)
# Find the root entities in the hierarchy (i.e. those that have a null
# parent)
root_entities = list(
{
e.entity_key: e
for e in entities
if e.parent is None and e.parent_id is None
}.values()
)
# Prepare a list of entities to process through BFS starting with the
# root nodes (i.e. level=0)
entities_to_process = [
_TaxonomyAwareEntity(entity=e, level=0) for e in root_entities
]
batches = []
current_batch = []
while entities_to_process:
# Pop the first element in the list (FIFO implementation)
item = entities_to_process.pop(0)
entity = item.entity
level = item.level
# If the depth has increased compared to the previous node, flush
# the current batch and open a new one.
if current_batch and current_batch[-1].level < level:
self._close_batch(current_batch, batches)
current_batch.append(item)
# Index the children nodes by key
children_to_process = {
e.entity_key: e
for e in children_by_parent_key.get(entity.entity_key, [])
}
# If this entity has already been persisted, add back its children
# that haven't been updated, so we won't lose those connections
if entity.id:
children_to_process.update(
{e.entity_key: e for e in children_by_parent_id.get(entity.id, [])}
)
# Add all the updated+inserted+existing children to the next layer
# to be expanded
entities_to_process += [
_TaxonomyAwareEntity(entity=e, level=level + 1)
for e in children_to_process.values()
]
# Close any pending batches
self._close_batch(current_batch, batches)
return batches
def upsert(
self,
session: Session,
entities: Iterable[Entity],
) -> Iterable[Entity]:
"""
Persist a set of entities.
"""
# Get the "unwrapped" batches
batches = self._split_entity_batches_for_flush(entities)
# Flush each batch as we process it
for batch in batches:
session.add_all(batch)
session.flush()
session.commit()
all_entities = list(
{
entity.entity_key: entity for batch in batches for entity in batch
}.values()
)
# Remove all the entities from the existing session, so they can be
# accessed outside of this context
for e in all_entities:
try:
session.expunge(e)
except InvalidRequestError:
pass
return all_entities

View file

@ -0,0 +1,143 @@
from typing import Dict, Iterable, List, Optional, Tuple
from sqlalchemy.orm import Session
from platypush.entities import Entity
class EntitiesMerger:
"""
This object is in charge of detecting and merging entities that already
exist on the database before flushing the session.
"""
def __init__(self, repository):
from platypush.entities._engine.repo import EntitiesRepository
self._repo: EntitiesRepository = repository
def merge(
self,
session: Session,
entities: Iterable[Entity],
) -> List[Entity]:
"""
Merge a set of entities with their existing representations and update
the parent/child relationships and return a tuple with
``[new_entities, updated_entities]``.
"""
new_entities = {}
existing_entities = {}
self._merge(
session,
entities,
new_entities=new_entities,
existing_entities=existing_entities,
)
return [*existing_entities.values(), *new_entities.values()]
def _merge(
self,
session: Session,
entities: Iterable[Entity],
new_entities: Dict[Tuple[str, str], Entity],
existing_entities: Dict[Tuple[str, str], Entity],
) -> List[Entity]:
"""
(Recursive) inner implementation of the entity merge logic.
"""
processed_entities = []
existing_entities.update(self._repo.get(session, entities))
# Retrieve existing records and merge them
for entity in entities:
key = entity.entity_key
existing_entity = existing_entities.get(key, new_entities.get(key))
parent_id, parent = self._update_parent(session, entity, new_entities)
if existing_entity:
# Update the parent
if not parent_id and parent:
existing_entity.parent = parent
else:
existing_entity.parent_id = parent_id # type: ignore
# Merge the other columns
self._merge_columns(entity, existing_entity)
entity = existing_entity
else:
# Add it to the map of new entities if the entity doesn't exist
# on the repo
new_entities[key] = entity
processed_entities.append(entity)
return processed_entities
def _update_parent(
self,
session: Session,
entity: Entity,
new_entities: Dict[Tuple[str, str], Entity],
) -> Tuple[Optional[int], Optional[Entity]]:
"""
Recursively update the hierarchy of an entity, moving upwards towards
the parent.
"""
parent_id: Optional[int] = entity.parent_id # type: ignore
parent: Optional[Entity] = entity.parent
# If the entity has a parent with an ID, use that
if parent and parent.id:
parent_id = parent_id or parent.id
# If there's no parent_id but there is a parent object, try to fetch
# its stored version
if not parent_id and parent:
batch = list(self._repo.get(session, [parent]).values())
# If the parent is already stored, use its ID
if batch:
parent = batch[0]
parent_id = parent.id # type: ignore
# Otherwise, check if its key is already among those awaiting flush
# and reuse the same objects (prevents SQLAlchemy from generating
# duplicate inserts)
else:
temp_entity = new_entities.get(parent.entity_key)
if temp_entity:
parent = entity.parent = temp_entity
else:
new_entities[parent.entity_key] = parent
# Recursively apply any changes up in the hierarchy
self._update_parent(session, parent, new_entities=new_entities)
# If we found a parent_id, populate it on the entity (and remove the
# supporting relationship object so SQLAlchemy doesn't go nuts when
# flushing)
if parent_id:
entity.parent = None
entity.parent_id = parent_id # type: ignore
return parent_id, parent
@classmethod
def _merge_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
"""
Merge two versions of an entity column by column.
"""
columns = [col.key for col in entity.columns]
for col in columns:
if col == 'meta':
existing_entity.meta = { # type: ignore
**(existing_entity.meta or {}), # type: ignore
**(entity.meta or {}), # type: ignore
}
elif col not in ('id', 'created_at'):
setattr(existing_entity, col, getattr(entity, col))
return existing_entity