diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py index dfc57e16..2780596a 100644 --- a/platypush/entities/_engine.py +++ b/platypush/entities/_engine.py @@ -167,9 +167,9 @@ class EntitiesEngine(Thread): self.logger.info('Stopped entities engine') - def _get_if_exist( + def get_if_exist( self, session: Session, entities: Iterable[Entity] - ) -> Iterable[Entity]: + ) -> List[Entity]: existing_entities = { ( str(entity.external_id) @@ -211,50 +211,13 @@ class EntitiesEngine(Thread): for entity in entities ] - def _merge_entities( - self, entities: List[Entity], existing_entities: List[Entity] - ) -> List[Entity]: - def merge(entity: Entity, existing_entity: Entity) -> Entity: - columns = [col.key for col in entity.columns] - for col in columns: - if col == 'meta': - existing_entity.meta = { # type: ignore - **(existing_entity.meta or {}), # type: ignore - **(entity.meta or {}), # type: ignore - } - elif col not in ('id', 'created_at'): - setattr(existing_entity, col, getattr(entity, col)) - - return existing_entity - - def entity_key(entity: Entity): - return ((entity.external_id or entity.name), entity.plugin) - - new_entities = {} - entities_map = {} - - # Get the latest update for each ((id|name), plugin) record - for e in entities: - entities_map[entity_key(e)] = e - - # Retrieve existing records and merge them - for i, entity in enumerate(entities): - existing_entity = existing_entities[i] - if existing_entity: - entity = merge(entity, existing_entity) - - new_entities[entity_key(entity)] = entity - - return list(new_entities.values()) - - def _process_entities(self, *entities: Entity): + def _process_entities(self, *entities: Entity): # type: ignore with self._get_session(locked=True) as session: # Ensure that the internal IDs are set to null before the merge for e in entities: e.id = None # type: ignore - existing_entities = self._get_if_exist(session, entities) - entities = self._merge_entities(entities, existing_entities) # type: ignore + entities: List[Entity] = self._merge_entities(session, entities) session.add_all(entities) session.commit() @@ -274,3 +237,45 @@ class EntitiesEngine(Thread): if e in entities_awaiting_flush: self._process_event(entity) self._entities_awaiting_flush.remove(e) + + def _merge_entities( + self, session: Session, entities: Iterable[Entity] + ) -> List[Entity]: + existing_entities = self.get_if_exist(session, entities) + new_entities = {} + entities_map = {} + + # Get the latest update for each ((id|name), plugin) record + for e in entities: + entities_map[self.entity_key(e)] = e + + # Retrieve existing records and merge them + for i, entity in enumerate(entities): + existing_entity = existing_entities[i] + if existing_entity: + existing_entity.children = self._merge_entities( + session, entity.children + ) + entity = self._merge_entity_columns(entity, existing_entity) + + new_entities[self.entity_key(entity)] = entity + + return list(new_entities.values()) + + @classmethod + def _merge_entity_columns(cls, entity: Entity, existing_entity: Entity) -> Entity: + columns = [col.key for col in entity.columns] + for col in columns: + if col == 'meta': + existing_entity.meta = { # type: ignore + **(existing_entity.meta or {}), # type: ignore + **(entity.meta or {}), # type: ignore + } + elif col not in ('id', 'created_at'): + setattr(existing_entity, col, getattr(entity, col)) + + return existing_entity + + @staticmethod + def entity_key(entity: Entity): + return ((entity.external_id or entity.name), entity.plugin)