forked from platypush/platypush
Encapsulate _get_session
in EntityManager
This commit is contained in:
parent
8450129858
commit
bfeb0a08c4
1 changed files with 5 additions and 7 deletions
|
@ -8,7 +8,7 @@ from typing import Iterable, List, Optional
|
||||||
from sqlalchemy import and_, or_
|
from sqlalchemy import and_, or_
|
||||||
from sqlalchemy.orm import Session, make_transient
|
from sqlalchemy.orm import Session, make_transient
|
||||||
|
|
||||||
from platypush.context import get_bus
|
from platypush.context import get_bus, get_plugin
|
||||||
from platypush.message.event.entities import EntityUpdateEvent
|
from platypush.message.event.entities import EntityUpdateEvent
|
||||||
|
|
||||||
from ._base import Entity, db_url
|
from ._base import Entity, db_url
|
||||||
|
@ -32,12 +32,10 @@ class EntitiesEngine(Thread):
|
||||||
'by_name_and_plugin': {},
|
'by_name_and_plugin': {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_db(self):
|
def _get_session(self):
|
||||||
from platypush.context import get_plugin
|
|
||||||
|
|
||||||
db = get_plugin('db')
|
db = get_plugin('db')
|
||||||
assert db
|
assert db
|
||||||
return db
|
return db.get_session(engine=db_url)
|
||||||
|
|
||||||
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
|
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
|
||||||
if entity.id:
|
if entity.id:
|
||||||
|
@ -98,7 +96,7 @@ class EntitiesEngine(Thread):
|
||||||
self._cache_entities(new_entity)
|
self._cache_entities(new_entity)
|
||||||
|
|
||||||
def _init_entities_cache(self):
|
def _init_entities_cache(self):
|
||||||
with self._get_db().get_session(engine=db_url) as session:
|
with self._get_session() as session:
|
||||||
entities = session.query(Entity).all()
|
entities = session.query(Entity).all()
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
make_transient(entity)
|
make_transient(entity)
|
||||||
|
@ -249,7 +247,7 @@ class EntitiesEngine(Thread):
|
||||||
return list(new_entities.values())
|
return list(new_entities.values())
|
||||||
|
|
||||||
def _process_entities(self, *entities: Entity):
|
def _process_entities(self, *entities: Entity):
|
||||||
with self._get_db().get_session(engine=db_url) as session:
|
with self._get_session() as session:
|
||||||
# Ensure that the internal IDs are set to null before the merge
|
# Ensure that the internal IDs are set to null before the merge
|
||||||
for e in entities:
|
for e in entities:
|
||||||
e.id = None # type: ignore
|
e.id = None # type: ignore
|
||||||
|
|
Loading…
Reference in a new issue