diff --git a/platypush/entities/__init__.py b/platypush/entities/__init__.py index 361d67e5..2495df3d 100644 --- a/platypush/entities/__init__.py +++ b/platypush/entities/__init__.py @@ -1,7 +1,7 @@ import warnings from typing import Collection, Optional -from ._base import Entity, get_entities_registry +from ._base import Entity, get_entities_registry, db_url from ._engine import EntitiesEngine from ._registry import manages, register_entity_plugin, get_plugin_entity_registry @@ -29,6 +29,7 @@ def publish_entities(entities: Collection[Entity]): __all__ = ( 'Entity', 'EntitiesEngine', + 'db_url', 'init_entities_engine', 'publish_entities', 'register_entity_plugin', diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py index 170c5a1d..b041a4ff 100644 --- a/platypush/entities/_base.py +++ b/platypush/entities/_base.py @@ -1,4 +1,5 @@ import inspect +import os import pathlib import types from datetime import datetime @@ -18,11 +19,13 @@ from sqlalchemy import ( ) from sqlalchemy.orm import declarative_base, ColumnProperty +from platypush.config import Config from platypush.message import JSONAble Base = declarative_base() entities_registry: Mapping[Type['Entity'], Mapping] = {} entity_types_registry: Dict[str, Type['Entity']] = {} +db_url = 'sqlite:///' + os.path.join(str(Config.get('workdir') or ''), 'entities.db') class Entity(Base): @@ -135,5 +138,8 @@ def init_entities_db(): _discover_entity_types() db = get_plugin('db') assert db - engine = db.get_engine() - db.create_all(engine, Base) + + engine = db.get_engine(engine=db_url) + with db.get_session() as session: + db.create_all(engine, Base) + session.flush() diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py index 187bab17..88090e3b 100644 --- a/platypush/entities/_engine.py +++ b/platypush/entities/_engine.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, make_transient from platypush.context import get_bus from platypush.message.event.entities import EntityUpdateEvent -from ._base import Entity +from ._base import Entity, db_url class EntitiesEngine(Thread): @@ -98,7 +98,7 @@ class EntitiesEngine(Thread): self._cache_entities(new_entity) def _init_entities_cache(self): - with self._get_db().get_session() as session: + with self._get_db().get_session(engine=db_url) as session: entities = session.query(Entity).all() for entity in entities: make_transient(entity) @@ -249,7 +249,7 @@ class EntitiesEngine(Thread): return list(new_entities.values()) def _process_entities(self, *entities: Entity): - with self._get_db().get_session() as session: + with self._get_db().get_session(engine=db_url) as session: # Ensure that the internal IDs are set to null before the merge for e in entities: e.id = None # type: ignore diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index 793ee1b9..91ae2b93 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -1,7 +1,7 @@ import time from contextlib import contextmanager from multiprocessing import RLock -from typing import Optional, Generator +from typing import Optional, Generator, Union from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.engine import Engine @@ -39,19 +39,29 @@ class DbPlugin(Plugin): """ super().__init__() + self.engine_url = engine self.engine = self.get_engine(engine, *args, **kwargs) self._session_locks = {} - def get_engine(self, engine=None, *args, **kwargs) -> Engine: - if engine: + def get_engine( + self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs + ) -> Engine: + if engine == self.engine_url and self.engine: + return self.engine + + if engine or not self.engine: if isinstance(engine, Engine): return engine - if engine.startswith('sqlite://'): + if not engine: + engine = self.engine_url + if isinstance(engine, str) and engine.startswith('sqlite://'): kwargs['connect_args'] = {'check_same_thread': False} return create_engine(engine, *args, **kwargs) # type: ignore - assert self.engine + if not self.engine: + return create_engine(self.engine_url, *args, **kwargs) # type: ignore + return self.engine @staticmethod diff --git a/platypush/plugins/entities/__init__.py b/platypush/plugins/entities/__init__.py index e600a2d4..723f07fa 100644 --- a/platypush/plugins/entities/__init__.py +++ b/platypush/plugins/entities/__init__.py @@ -8,7 +8,12 @@ from sqlalchemy.orm import make_transient from platypush.config import Config from platypush.context import get_plugin, get_bus -from platypush.entities import Entity, get_plugin_entity_registry, get_entities_registry +from platypush.entities import ( + Entity, + get_plugin_entity_registry, + get_entities_registry, + db_url, +) from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent from platypush.plugins import Plugin, action @@ -22,10 +27,10 @@ class EntitiesPlugin(Plugin): def __init__(self, **kwargs): super().__init__(**kwargs) - def _get_db(self): + def _get_session(self): db = get_plugin('db') assert db - return db + return db.get_session(engine=db_url) @action def get( @@ -58,7 +63,6 @@ class EntitiesPlugin(Plugin): selected_types = entity_types.keys() - db = self._get_db() enabled_plugins = list( { *Config.get_plugins().keys(), @@ -66,7 +70,7 @@ class EntitiesPlugin(Plugin): } ) - with db.get_session() as session: + with self._get_session() as session: query = session.query(Entity).filter( or_(Entity.plugin.in_(enabled_plugins), Entity.plugin.is_(None)) ) @@ -173,8 +177,7 @@ class EntitiesPlugin(Plugin): :param args: Action's extra positional arguments. :param kwargs: Action's extra named arguments. """ - db = self._get_db() - with db.get_session() as session: + with self._get_session() as session: entity = session.query(Entity).filter_by(id=id).one_or_none() assert entity, f'No such entity ID: {id}' @@ -192,7 +195,7 @@ class EntitiesPlugin(Plugin): :param entities: IDs of the entities to be removed. :return: The payload of the deleted entities. """ - with self._get_db().get_session() as session: + with self._get_session() as session: entities: Collection[Entity] = ( session.query(Entity).filter(Entity.id.in_(entities)).all() ) @@ -233,7 +236,7 @@ class EntitiesPlugin(Plugin): :return: The updated entities. """ entities = {str(k): v for k, v in entities.items()} - with self._get_db().get_session() as session: + with self._get_session() as session: objs = session.query(Entity).filter(Entity.id.in_(entities.keys())).all() for obj in objs: obj.meta = {**(obj.meta or {}), **(entities.get(str(obj.id), {}))} diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index a9e759e5..a47f860b 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -33,8 +33,11 @@ class UserManager: def __init__(self): self.db = get_plugin('db') assert self.db - self._engine = self.db.get_engine() - self.db.create_all(self._engine, Base) + self._engine = self.db.get_engine(engine=self.db.engine_url) + + with self.db.get_session() as session: + self.db.create_all(self._engine, Base) + session.flush() @staticmethod def _mask_password(user):