diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py index f99a4725..f747027e 100644 --- a/platypush/entities/_engine.py +++ b/platypush/entities/_engine.py @@ -6,14 +6,13 @@ from typing import Iterable, List from sqlalchemy import and_, or_, inspect as schema_inspect from sqlalchemy.orm import Session -from sqlalchemy.sql.elements import Null from ._base import Entity class EntitiesEngine(Thread): # Processing queue timeout in seconds - _queue_timeout = 5. + _queue_timeout = 5.0 def __init__(self): obj_name = self.__class__.__name__ @@ -42,7 +41,8 @@ class EntitiesEngine(Thread): last_poll_time = time() while not self.should_stop and ( - time() - last_poll_time < self._queue_timeout): + time() - last_poll_time < self._queue_timeout + ): try: msg = self._queue.get(block=True, timeout=0.5) except Empty: @@ -58,44 +58,64 @@ class EntitiesEngine(Thread): self.logger.info('Stopped entities engine') - def _get_if_exist(self, session: Session, entities: Iterable[Entity]) -> Iterable[Entity]: + def _get_if_exist( + self, session: Session, entities: Iterable[Entity] + ) -> Iterable[Entity]: existing_entities = { (entity.external_id or entity.name, entity.plugin): entity - for entity in session.query(Entity).filter( - or_(*[ - and_(Entity.external_id == entity.external_id, Entity.plugin == entity.plugin) - if entity.external_id is not None else - and_(Entity.name == entity.name, Entity.plugin == entity.plugin) - for entity in entities - ]) - ).all() + for entity in session.query(Entity) + .filter( + or_( + *[ + and_( + Entity.external_id == entity.external_id, + Entity.plugin == entity.plugin, + ) + if entity.external_id is not None + else and_( + Entity.name == entity.name, Entity.plugin == entity.plugin + ) + for entity in entities + ] + ) + ) + .all() } return [ existing_entities.get( (entity.external_id or entity.name, entity.plugin), None - ) for entity in entities + ) + for entity in entities ] def _merge_entities( - self, entities: List[Entity], - existing_entities: List[Entity] + self, entities: List[Entity], existing_entities: List[Entity] ) -> List[Entity]: - new_entities = [] + def merge(entity: Entity, existing_entity: Entity) -> Entity: + inspector = schema_inspect(entity.__class__) + columns = [col.key for col in inspector.mapper.column_attrs] + for col in columns: + if col not in ('id', 'created_at'): + setattr(existing_entity, col, getattr(entity, col)) + return existing_entity + + new_entities = [] + entities_map = {} + + # Get the latest update for each ((id|name), plugin) record + for e in entities: + key = ((e.external_id or e.name), e.plugin) + entities_map[key] = e + + # Retrieve existing records and merge them for i, entity in enumerate(entities): existing_entity = existing_entities[i] if existing_entity: - inspector = schema_inspect(entity.__class__) - columns = [col.key for col in inspector.mapper.column_attrs] - for col in columns: - new_value = getattr(entity, col) - if new_value is not None and new_value.__class__ != Null: - setattr(existing_entity, col, getattr(entity, col)) + entity = merge(entity, existing_entity) - new_entities.append(existing_entity) - else: - new_entities.append(entity) + new_entities.append(entity) return new_entities @@ -107,4 +127,3 @@ class EntitiesEngine(Thread): entities = self._merge_entities(entities, existing_entities) # type: ignore session.add_all(entities) session.commit() -