Fixed session/concurrency management on the main SQLite db

- The `declarative_base` instance should be shared
- Database `session_locks` should be stored at module, not instance
  level
- Better isolation of scoped sessions
- Enclapsulated `get_session` method in `UserManager`
This commit is contained in:
Fabio Manganiello 2022-11-12 15:36:17 +01:00
parent bfeb0a08c4
commit 86edd70d93
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
7 changed files with 48 additions and 51 deletions

3
platypush/common/db.py Normal file
View file

@ -0,0 +1,3 @@
from sqlalchemy.orm import declarative_base
Base = declarative_base()

View file

@ -1,7 +1,7 @@
import warnings import warnings
from typing import Collection, Optional from typing import Collection, Optional
from ._base import Entity, get_entities_registry, db_url from ._base import Entity, get_entities_registry
from ._engine import EntitiesEngine from ._engine import EntitiesEngine
from ._registry import manages, register_entity_plugin, get_plugin_entity_registry from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
@ -29,7 +29,6 @@ def publish_entities(entities: Collection[Entity]):
__all__ = ( __all__ = (
'Entity', 'Entity',
'EntitiesEngine', 'EntitiesEngine',
'db_url',
'init_entities_engine', 'init_entities_engine',
'publish_entities', 'publish_entities',
'register_entity_plugin', 'register_entity_plugin',

View file

@ -1,5 +1,4 @@
import inspect import inspect
import os
import pathlib import pathlib
import types import types
from datetime import datetime from datetime import datetime
@ -17,15 +16,13 @@ from sqlalchemy import (
UniqueConstraint, UniqueConstraint,
inspect as schema_inspect, inspect as schema_inspect,
) )
from sqlalchemy.orm import declarative_base, ColumnProperty from sqlalchemy.orm import ColumnProperty
from platypush.config import Config from platypush.common.db import Base
from platypush.message import JSONAble from platypush.message import JSONAble
Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {} entities_registry: Mapping[Type['Entity'], Mapping] = {}
entity_types_registry: Dict[str, Type['Entity']] = {} entity_types_registry: Dict[str, Type['Entity']] = {}
db_url = 'sqlite:///' + os.path.join(str(Config.get('workdir') or ''), 'entities.db')
class Entity(Base): class Entity(Base):
@ -138,8 +135,4 @@ def init_entities_db():
_discover_entity_types() _discover_entity_types()
db = get_plugin('db') db = get_plugin('db')
assert db assert db
db.create_all(db.get_engine(), Base)
engine = db.get_engine(engine=db_url)
with db.get_session() as session:
db.create_all(engine, Base)
session.flush()

View file

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, make_transient
from platypush.context import get_bus, get_plugin from platypush.context import get_bus, get_plugin
from platypush.message.event.entities import EntityUpdateEvent from platypush.message.event.entities import EntityUpdateEvent
from ._base import Entity, db_url from ._base import Entity
class EntitiesEngine(Thread): class EntitiesEngine(Thread):
@ -35,7 +35,7 @@ class EntitiesEngine(Thread):
def _get_session(self): def _get_session(self):
db = get_plugin('db') db = get_plugin('db')
assert db assert db
return db.get_session(engine=db_url) return db.get_session()
def _get_cached_entity(self, entity: Entity) -> Optional[dict]: def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
if entity.id: if entity.id:

View file

@ -11,6 +11,8 @@ from sqlalchemy.sql import and_, or_, text
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
session_locks = {}
class DbPlugin(Plugin): class DbPlugin(Plugin):
""" """
@ -41,7 +43,6 @@ class DbPlugin(Plugin):
super().__init__() super().__init__()
self.engine_url = engine self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, *args, **kwargs)
self._session_locks = {}
def get_engine( def get_engine(
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
@ -60,7 +61,7 @@ class DbPlugin(Plugin):
return create_engine(engine, *args, **kwargs) # type: ignore return create_engine(engine, *args, **kwargs) # type: ignore
if not self.engine: if not self.engine:
return create_engine(self.engine_url, *args, **kwargs) # type: ignore self.engine = create_engine(self.engine_url, *args, **kwargs) # type: ignore
return self.engine return self.engine
@ -508,23 +509,26 @@ class DbPlugin(Plugin):
connection.execute(delete) connection.execute(delete)
def create_all(self, engine, base): def create_all(self, engine, base):
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) with (self.get_session(engine) as session, session.begin()):
with self._session_locks[engine.url]: base.metadata.create_all(session.connection())
base.metadata.create_all(engine)
@contextmanager @contextmanager
def get_session( def get_session(
self, engine=None, *args, **kwargs self, engine=None, *args, **kwargs
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) session_locks[engine.url] = session_locks.get(engine.url, RLock())
with self._session_locks[engine.url]:
session = scoped_session(sessionmaker(expire_on_commit=False)) with (session_locks[engine.url], engine.connect() as conn, conn.begin()):
session.configure(bind=engine) session = scoped_session(
s = session() sessionmaker(
yield s expire_on_commit=False,
s.commit() autoflush=True,
s.close() )
)
session.configure(bind=conn)
yield session()
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -12,7 +12,6 @@ from platypush.entities import (
Entity, Entity,
get_plugin_entity_registry, get_plugin_entity_registry,
get_entities_registry, get_entities_registry,
db_url,
) )
from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
@ -30,7 +29,7 @@ class EntitiesPlugin(Plugin):
def _get_session(self): def _get_session(self):
db = get_plugin('db') db = get_plugin('db')
assert db assert db
return db.get_session(engine=db_url) return db.get_session()
@action @action
def get( def get(

View file

@ -13,8 +13,9 @@ except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient, declarative_base from sqlalchemy.orm import make_transient
from platypush.common.db import Base
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.exceptions.user import ( from platypush.exceptions.user import (
InvalidJWTTokenException, InvalidJWTTokenException,
@ -22,8 +23,6 @@ from platypush.exceptions.user import (
) )
from platypush.utils import get_or_generate_jwt_rsa_key_pair from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base()
class UserManager: class UserManager:
""" """
@ -31,13 +30,10 @@ class UserManager:
""" """
def __init__(self): def __init__(self):
self.db = get_plugin('db') db_plugin = get_plugin('db')
assert self.db assert db_plugin, 'Database plugin not configured'
self._engine = self.db.get_engine(engine=self.db.engine_url) self.db = db_plugin
self.db.create_all(self.db.get_engine(), Base)
with self.db.get_session() as session:
self.db.create_all(self._engine, Base)
session.flush()
@staticmethod @staticmethod
def _mask_password(user): def _mask_password(user):
@ -45,8 +41,11 @@ class UserManager:
user.password = None user.password = None
return user return user
def _get_session(self, *args, **kwargs):
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
def get_user(self, username): def get_user(self, username):
with self.db.get_session() as session: with self._get_session() as session:
user = self._get_user(session, username) user = self._get_user(session, username)
if not user: if not user:
return None return None
@ -55,11 +54,11 @@ class UserManager:
return self._mask_password(user) return self._mask_password(user)
def get_user_count(self): def get_user_count(self):
with self.db.get_session() as session: with self._get_session() as session:
return session.query(User).count() return session.query(User).count()
def get_users(self): def get_users(self):
with self.db.get_session() as session: with self._get_session() as session:
return session.query(User) return session.query(User)
def create_user(self, username, password, **kwargs): def create_user(self, username, password, **kwargs):
@ -68,7 +67,7 @@ class UserManager:
if not password: if not password:
raise ValueError('Please provide a password for the user') raise ValueError('Please provide a password for the user')
with self.db.get_session() as session: with self._get_session() as session:
user = self._get_user(session, username) user = self._get_user(session, username)
if user: if user:
raise NameError('The user {} already exists'.format(username)) raise NameError('The user {} already exists'.format(username))
@ -87,7 +86,7 @@ class UserManager:
return self._mask_password(user) return self._mask_password(user)
def update_password(self, username, old_password, new_password): def update_password(self, username, old_password, new_password):
with self.db.get_session() as session: with self._get_session() as session:
if not self._authenticate_user(session, username, old_password): if not self._authenticate_user(session, username, old_password):
return False return False
@ -97,11 +96,11 @@ class UserManager:
return True return True
def authenticate_user(self, username, password): def authenticate_user(self, username, password):
with self.db.get_session() as session: with self._get_session() as session:
return self._authenticate_user(session, username, password) return self._authenticate_user(session, username, password)
def authenticate_user_session(self, session_token): def authenticate_user_session(self, session_token):
with self.db.get_session() as session: with self._get_session() as session:
user_session = ( user_session = (
session.query(UserSession) session.query(UserSession)
.filter_by(session_token=session_token) .filter_by(session_token=session_token)
@ -118,7 +117,7 @@ class UserManager:
return self._mask_password(user), user_session return self._mask_password(user), user_session
def delete_user(self, username): def delete_user(self, username):
with self.db.get_session() as session: with self._get_session() as session:
user = self._get_user(session, username) user = self._get_user(session, username)
if not user: if not user:
raise NameError('No such user: {}'.format(username)) raise NameError('No such user: {}'.format(username))
@ -134,7 +133,7 @@ class UserManager:
return True return True
def delete_user_session(self, session_token): def delete_user_session(self, session_token):
with self.db.get_session() as session: with self._get_session() as session:
user_session = ( user_session = (
session.query(UserSession) session.query(UserSession)
.filter_by(session_token=session_token) .filter_by(session_token=session_token)
@ -149,7 +148,7 @@ class UserManager:
return True return True
def create_user_session(self, username, password, expires_at=None): def create_user_session(self, username, password, expires_at=None):
with self.db.get_session() as session: with self._get_session() as session:
user = self._authenticate_user(session, username, password) user = self._authenticate_user(session, username, password)
if not user: if not user:
return None return None
@ -203,7 +202,7 @@ class UserManager:
:param session_token: Session token. :param session_token: Session token.
""" """
with self.db.get_session() as session: with self._get_session() as session:
return ( return (
session.query(User) session.query(User)
.join(UserSession) .join(UserSession)
@ -263,7 +262,7 @@ class UserManager:
pub_key, _ = get_or_generate_jwt_rsa_key_pair() pub_key, _ = get_or_generate_jwt_rsa_key_pair()
try: try:
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) # type: ignore[reportGeneralTypeIssues]
except PyJWTError as e: except PyJWTError as e:
raise InvalidJWTTokenException(str(e)) raise InvalidJWTTokenException(str(e))