Using a different SQLite database for entities

This prevents multiprocessing/concurrency issues when modifying the same
database file both from the main process and from the web server process
This commit is contained in:
Fabio Manganiello 2022-11-12 02:00:55 +01:00
parent 3fc94181b7
commit 6b7933cd33
Signed by: blacklight
GPG key ID: D90FBA7F76362774
6 changed files with 45 additions and 22 deletions

View file

@ -1,7 +1,7 @@
import warnings import warnings
from typing import Collection, Optional from typing import Collection, Optional
from ._base import Entity, get_entities_registry from ._base import Entity, get_entities_registry, db_url
from ._engine import EntitiesEngine from ._engine import EntitiesEngine
from ._registry import manages, register_entity_plugin, get_plugin_entity_registry from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
@ -29,6 +29,7 @@ def publish_entities(entities: Collection[Entity]):
__all__ = ( __all__ = (
'Entity', 'Entity',
'EntitiesEngine', 'EntitiesEngine',
'db_url',
'init_entities_engine', 'init_entities_engine',
'publish_entities', 'publish_entities',
'register_entity_plugin', 'register_entity_plugin',

View file

@ -1,4 +1,5 @@
import inspect import inspect
import os
import pathlib import pathlib
import types import types
from datetime import datetime from datetime import datetime
@ -18,11 +19,13 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import declarative_base, ColumnProperty from sqlalchemy.orm import declarative_base, ColumnProperty
from platypush.config import Config
from platypush.message import JSONAble from platypush.message import JSONAble
Base = declarative_base() Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {} entities_registry: Mapping[Type['Entity'], Mapping] = {}
entity_types_registry: Dict[str, Type['Entity']] = {} entity_types_registry: Dict[str, Type['Entity']] = {}
db_url = 'sqlite:///' + os.path.join(str(Config.get('workdir') or ''), 'entities.db')
class Entity(Base): class Entity(Base):
@ -135,5 +138,8 @@ def init_entities_db():
_discover_entity_types() _discover_entity_types()
db = get_plugin('db') db = get_plugin('db')
assert db assert db
engine = db.get_engine()
db.create_all(engine, Base) engine = db.get_engine(engine=db_url)
with db.get_session() as session:
db.create_all(engine, Base)
session.flush()

View file

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, make_transient
from platypush.context import get_bus from platypush.context import get_bus
from platypush.message.event.entities import EntityUpdateEvent from platypush.message.event.entities import EntityUpdateEvent
from ._base import Entity from ._base import Entity, db_url
class EntitiesEngine(Thread): class EntitiesEngine(Thread):
@ -98,7 +98,7 @@ class EntitiesEngine(Thread):
self._cache_entities(new_entity) self._cache_entities(new_entity)
def _init_entities_cache(self): def _init_entities_cache(self):
with self._get_db().get_session() as session: with self._get_db().get_session(engine=db_url) as session:
entities = session.query(Entity).all() entities = session.query(Entity).all()
for entity in entities: for entity in entities:
make_transient(entity) make_transient(entity)
@ -249,7 +249,7 @@ class EntitiesEngine(Thread):
return list(new_entities.values()) return list(new_entities.values())
def _process_entities(self, *entities: Entity): def _process_entities(self, *entities: Entity):
with self._get_db().get_session() as session: with self._get_db().get_session(engine=db_url) as session:
# Ensure that the internal IDs are set to null before the merge # Ensure that the internal IDs are set to null before the merge
for e in entities: for e in entities:
e.id = None # type: ignore e.id = None # type: ignore

View file

@ -1,7 +1,7 @@
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from multiprocessing import RLock from multiprocessing import RLock
from typing import Optional, Generator from typing import Optional, Generator, Union
from sqlalchemy import create_engine, Table, MetaData from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -39,19 +39,29 @@ class DbPlugin(Plugin):
""" """
super().__init__() super().__init__()
self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, *args, **kwargs)
self._session_locks = {} self._session_locks = {}
def get_engine(self, engine=None, *args, **kwargs) -> Engine: def get_engine(
if engine: self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
) -> Engine:
if engine == self.engine_url and self.engine:
return self.engine
if engine or not self.engine:
if isinstance(engine, Engine): if isinstance(engine, Engine):
return engine return engine
if engine.startswith('sqlite://'): if not engine:
engine = self.engine_url
if isinstance(engine, str) and engine.startswith('sqlite://'):
kwargs['connect_args'] = {'check_same_thread': False} kwargs['connect_args'] = {'check_same_thread': False}
return create_engine(engine, *args, **kwargs) # type: ignore return create_engine(engine, *args, **kwargs) # type: ignore
assert self.engine if not self.engine:
return create_engine(self.engine_url, *args, **kwargs) # type: ignore
return self.engine return self.engine
@staticmethod @staticmethod

View file

@ -8,7 +8,12 @@ from sqlalchemy.orm import make_transient
from platypush.config import Config from platypush.config import Config
from platypush.context import get_plugin, get_bus from platypush.context import get_plugin, get_bus
from platypush.entities import Entity, get_plugin_entity_registry, get_entities_registry from platypush.entities import (
Entity,
get_plugin_entity_registry,
get_entities_registry,
db_url,
)
from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
@ -22,10 +27,10 @@ class EntitiesPlugin(Plugin):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def _get_db(self): def _get_session(self):
db = get_plugin('db') db = get_plugin('db')
assert db assert db
return db return db.get_session(engine=db_url)
@action @action
def get( def get(
@ -58,7 +63,6 @@ class EntitiesPlugin(Plugin):
selected_types = entity_types.keys() selected_types = entity_types.keys()
db = self._get_db()
enabled_plugins = list( enabled_plugins = list(
{ {
*Config.get_plugins().keys(), *Config.get_plugins().keys(),
@ -66,7 +70,7 @@ class EntitiesPlugin(Plugin):
} }
) )
with db.get_session() as session: with self._get_session() as session:
query = session.query(Entity).filter( query = session.query(Entity).filter(
or_(Entity.plugin.in_(enabled_plugins), Entity.plugin.is_(None)) or_(Entity.plugin.in_(enabled_plugins), Entity.plugin.is_(None))
) )
@ -173,8 +177,7 @@ class EntitiesPlugin(Plugin):
:param args: Action's extra positional arguments. :param args: Action's extra positional arguments.
:param kwargs: Action's extra named arguments. :param kwargs: Action's extra named arguments.
""" """
db = self._get_db() with self._get_session() as session:
with db.get_session() as session:
entity = session.query(Entity).filter_by(id=id).one_or_none() entity = session.query(Entity).filter_by(id=id).one_or_none()
assert entity, f'No such entity ID: {id}' assert entity, f'No such entity ID: {id}'
@ -192,7 +195,7 @@ class EntitiesPlugin(Plugin):
:param entities: IDs of the entities to be removed. :param entities: IDs of the entities to be removed.
:return: The payload of the deleted entities. :return: The payload of the deleted entities.
""" """
with self._get_db().get_session() as session: with self._get_session() as session:
entities: Collection[Entity] = ( entities: Collection[Entity] = (
session.query(Entity).filter(Entity.id.in_(entities)).all() session.query(Entity).filter(Entity.id.in_(entities)).all()
) )
@ -233,7 +236,7 @@ class EntitiesPlugin(Plugin):
:return: The updated entities. :return: The updated entities.
""" """
entities = {str(k): v for k, v in entities.items()} entities = {str(k): v for k, v in entities.items()}
with self._get_db().get_session() as session: with self._get_session() as session:
objs = session.query(Entity).filter(Entity.id.in_(entities.keys())).all() objs = session.query(Entity).filter(Entity.id.in_(entities.keys())).all()
for obj in objs: for obj in objs:
obj.meta = {**(obj.meta or {}), **(entities.get(str(obj.id), {}))} obj.meta = {**(obj.meta or {}), **(entities.get(str(obj.id), {}))}

View file

@ -33,8 +33,11 @@ class UserManager:
def __init__(self): def __init__(self):
self.db = get_plugin('db') self.db = get_plugin('db')
assert self.db assert self.db
self._engine = self.db.get_engine() self._engine = self.db.get_engine(engine=self.db.engine_url)
self.db.create_all(self._engine, Base)
with self.db.get_session() as session:
self.db.create_all(self._engine, Base)
session.flush()
@staticmethod @staticmethod
def _mask_password(user): def _mask_password(user):