diff --git a/platypush/entities/_engine/repo/__init__.py b/platypush/entities/_engine/repo/__init__.py index 49b0b8a3f..8b7f150fb 100644 --- a/platypush/entities/_engine/repo/__init__.py +++ b/platypush/entities/_engine/repo/__init__.py @@ -4,9 +4,8 @@ from typing import Dict, Iterable, Optional, Tuple from sqlalchemy.orm import Session from platypush.entities._base import Entity, EntityMapping - -# pylint: disable=no-name-in-module from platypush.entities._engine.repo.db import EntitiesDb +from platypush.entities._engine.repo.helpers import get_parent from platypush.entities._engine.repo.merger import EntitiesMerger logger = logging.getLogger('entities') @@ -98,7 +97,7 @@ class EntitiesRepository: """ parent = entity while parent: - parent = self._merge.get_parent(session, entity) + parent = get_parent(session, entity) if parent: entity = parent diff --git a/platypush/entities/_engine/repo/db.py b/platypush/entities/_engine/repo/db.py index aab1187af..0d97d2936 100644 --- a/platypush/entities/_engine/repo/db.py +++ b/platypush/entities/_engine/repo/db.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from platypush.context import get_plugin from platypush.entities._base import Entity +from .helpers import get_parent @dataclass @@ -69,7 +70,7 @@ class EntitiesDb: batch.clear() def _split_entity_batches_for_flush( - self, entities: Iterable[Entity] + self, session: Session, entities: Iterable[Entity] ) -> List[List[Entity]]: """ This method retrieves the root entities given a list of entities and @@ -93,9 +94,10 @@ class EntitiesDb: for entity in entities: parent_key = None parent_id = entity.parent_id - if entity.parent: - parent_id = parent_id or entity.parent.id - parent_key = entity.parent.entity_key + parent = get_parent(session, entity) + if parent: + parent_id = parent_id or parent.id + parent_key = parent.entity_key if parent_id: children_by_parent_id[parent_id][entity.entity_key] = entity @@ -169,7 +171,7 @@ class EntitiesDb: Persist a set of entities. """ # Get the "unwrapped" batches - batches = self._split_entity_batches_for_flush(entities) + batches = self._split_entity_batches_for_flush(session, entities) # Flush each batch as we process it for batch in batches: diff --git a/platypush/entities/_engine/repo/helpers.py b/platypush/entities/_engine/repo/helpers.py new file mode 100644 index 000000000..8ac93247a --- /dev/null +++ b/platypush/entities/_engine/repo/helpers.py @@ -0,0 +1,18 @@ +from typing import Optional + +from sqlalchemy.orm import Session, exc + +from platypush.entities import Entity + + +def get_parent(session: Session, entity: Entity) -> Optional[Entity]: + """ + Gets the parent of an entity, and it fetches if it's not available in + the current session. + """ + try: + return entity.parent + except exc.DetachedInstanceError: + # Dirty fix for `Parent instance <...> is not bound to a Session; + # lazy load operation of attribute 'parent' cannot proceed` + return session.query(Entity).get(entity.parent_id) if entity.parent_id else None diff --git a/platypush/entities/_engine/repo/merger.py b/platypush/entities/_engine/repo/merger.py index 6f3e9b344..3d7a46df4 100644 --- a/platypush/entities/_engine/repo/merger.py +++ b/platypush/entities/_engine/repo/merger.py @@ -1,9 +1,11 @@ from typing import Iterable, List, Optional -from sqlalchemy.orm import Session, exc +from sqlalchemy.orm import Session from platypush.entities._base import Entity, EntityMapping +from .helpers import get_parent + # pylint: disable=too-few-public-methods class EntitiesMerger: @@ -93,7 +95,7 @@ class EntitiesMerger: appropriately rewired and that all the relevant objects are added to this session. """ - parent = cls.get_parent(session, entity) + parent = get_parent(session, entity) if not parent: # No parent -> we can terminate the recursive climbing return entity @@ -140,23 +142,6 @@ class EntitiesMerger: cls._sync_parent(session, existing_parent, new_entities, existing_entities) return entity - @staticmethod - def get_parent(session: Session, entity: Entity) -> Optional[Entity]: - """ - Gets the parent of an entity, and it fetches if it's not available in - the current session. - """ - try: - return entity.parent - except exc.DetachedInstanceError: - # Dirty fix for `Parent instance <...> is not bound to a Session; - # lazy load operation of attribute 'parent' cannot proceed` - return ( - session.query(Entity).get(entity.parent_id) - if entity.parent_id - else None - ) - @staticmethod def _append_children(entity: Entity, *children: Entity): """