From 86edd70d93b1787fa45f36d6358d5054f0ec7e93 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Sat, 12 Nov 2022 15:36:17 +0100 Subject: [PATCH] Fixed session/concurrency management on the main SQLite db - The `declarative_base` instance should be shared - Database `session_locks` should be stored at module, not instance level - Better isolation of scoped sessions - Enclapsulated `get_session` method in `UserManager` --- platypush/common/db.py | 3 ++ platypush/entities/__init__.py | 3 +- platypush/entities/_base.py | 13 ++------ platypush/entities/_engine.py | 4 +-- platypush/plugins/db/__init__.py | 30 ++++++++++-------- platypush/plugins/entities/__init__.py | 3 +- platypush/user/__init__.py | 43 +++++++++++++------------- 7 files changed, 48 insertions(+), 51 deletions(-) create mode 100644 platypush/common/db.py diff --git a/platypush/common/db.py b/platypush/common/db.py new file mode 100644 index 00000000..59be7030 --- /dev/null +++ b/platypush/common/db.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/platypush/entities/__init__.py b/platypush/entities/__init__.py index 2495df3d..361d67e5 100644 --- a/platypush/entities/__init__.py +++ b/platypush/entities/__init__.py @@ -1,7 +1,7 @@ import warnings from typing import Collection, Optional -from ._base import Entity, get_entities_registry, db_url +from ._base import Entity, get_entities_registry from ._engine import EntitiesEngine from ._registry import manages, register_entity_plugin, get_plugin_entity_registry @@ -29,7 +29,6 @@ def publish_entities(entities: Collection[Entity]): __all__ = ( 'Entity', 'EntitiesEngine', - 'db_url', 'init_entities_engine', 'publish_entities', 'register_entity_plugin', diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py index b041a4ff..e69f7c77 100644 --- a/platypush/entities/_base.py +++ b/platypush/entities/_base.py @@ -1,5 +1,4 @@ import inspect -import os import pathlib import types from datetime import datetime @@ -17,15 +16,13 @@ from sqlalchemy import ( UniqueConstraint, inspect as schema_inspect, ) -from sqlalchemy.orm import declarative_base, ColumnProperty +from sqlalchemy.orm import ColumnProperty -from platypush.config import Config +from platypush.common.db import Base from platypush.message import JSONAble -Base = declarative_base() entities_registry: Mapping[Type['Entity'], Mapping] = {} entity_types_registry: Dict[str, Type['Entity']] = {} -db_url = 'sqlite:///' + os.path.join(str(Config.get('workdir') or ''), 'entities.db') class Entity(Base): @@ -138,8 +135,4 @@ def init_entities_db(): _discover_entity_types() db = get_plugin('db') assert db - - engine = db.get_engine(engine=db_url) - with db.get_session() as session: - db.create_all(engine, Base) - session.flush() + db.create_all(db.get_engine(), Base) diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py index cc023638..192dbb9e 100644 --- a/platypush/entities/_engine.py +++ b/platypush/entities/_engine.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, make_transient from platypush.context import get_bus, get_plugin from platypush.message.event.entities import EntityUpdateEvent -from ._base import Entity, db_url +from ._base import Entity class EntitiesEngine(Thread): @@ -35,7 +35,7 @@ class EntitiesEngine(Thread): def _get_session(self): db = get_plugin('db') assert db - return db.get_session(engine=db_url) + return db.get_session() def _get_cached_entity(self, entity: Entity) -> Optional[dict]: if entity.id: diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index 91ae2b93..175c6d01 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -11,6 +11,8 @@ from sqlalchemy.sql import and_, or_, text from platypush.plugins import Plugin, action +session_locks = {} + class DbPlugin(Plugin): """ @@ -41,7 +43,6 @@ class DbPlugin(Plugin): super().__init__() self.engine_url = engine self.engine = self.get_engine(engine, *args, **kwargs) - self._session_locks = {} def get_engine( self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs @@ -60,7 +61,7 @@ class DbPlugin(Plugin): return create_engine(engine, *args, **kwargs) # type: ignore if not self.engine: - return create_engine(self.engine_url, *args, **kwargs) # type: ignore + self.engine = create_engine(self.engine_url, *args, **kwargs) # type: ignore return self.engine @@ -508,23 +509,26 @@ class DbPlugin(Plugin): connection.execute(delete) def create_all(self, engine, base): - self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) - with self._session_locks[engine.url]: - base.metadata.create_all(engine) + with (self.get_session(engine) as session, session.begin()): + base.metadata.create_all(session.connection()) @contextmanager def get_session( self, engine=None, *args, **kwargs ) -> Generator[Session, None, None]: engine = self.get_engine(engine, *args, **kwargs) - self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) - with self._session_locks[engine.url]: - session = scoped_session(sessionmaker(expire_on_commit=False)) - session.configure(bind=engine) - s = session() - yield s - s.commit() - s.close() + session_locks[engine.url] = session_locks.get(engine.url, RLock()) + + with (session_locks[engine.url], engine.connect() as conn, conn.begin()): + session = scoped_session( + sessionmaker( + expire_on_commit=False, + autoflush=True, + ) + ) + + session.configure(bind=conn) + yield session() # vim:sw=4:ts=4:et: diff --git a/platypush/plugins/entities/__init__.py b/platypush/plugins/entities/__init__.py index 723f07fa..c9eb7c5f 100644 --- a/platypush/plugins/entities/__init__.py +++ b/platypush/plugins/entities/__init__.py @@ -12,7 +12,6 @@ from platypush.entities import ( Entity, get_plugin_entity_registry, get_entities_registry, - db_url, ) from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent from platypush.plugins import Plugin, action @@ -30,7 +29,7 @@ class EntitiesPlugin(Plugin): def _get_session(self): db = get_plugin('db') assert db - return db.get_session(engine=db_url) + return db.get_session() @action def get( diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index a47f860b..e821d333 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -13,8 +13,9 @@ except ImportError: from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from sqlalchemy import Column, Integer, String, DateTime, ForeignKey -from sqlalchemy.orm import make_transient, declarative_base +from sqlalchemy.orm import make_transient +from platypush.common.db import Base from platypush.context import get_plugin from platypush.exceptions.user import ( InvalidJWTTokenException, @@ -22,8 +23,6 @@ from platypush.exceptions.user import ( ) from platypush.utils import get_or_generate_jwt_rsa_key_pair -Base = declarative_base() - class UserManager: """ @@ -31,13 +30,10 @@ class UserManager: """ def __init__(self): - self.db = get_plugin('db') - assert self.db - self._engine = self.db.get_engine(engine=self.db.engine_url) - - with self.db.get_session() as session: - self.db.create_all(self._engine, Base) - session.flush() + db_plugin = get_plugin('db') + assert db_plugin, 'Database plugin not configured' + self.db = db_plugin + self.db.create_all(self.db.get_engine(), Base) @staticmethod def _mask_password(user): @@ -45,8 +41,11 @@ class UserManager: user.password = None return user + def _get_session(self, *args, **kwargs): + return self.db.get_session(self.db.get_engine(), *args, **kwargs) + def get_user(self, username): - with self.db.get_session() as session: + with self._get_session() as session: user = self._get_user(session, username) if not user: return None @@ -55,11 +54,11 @@ class UserManager: return self._mask_password(user) def get_user_count(self): - with self.db.get_session() as session: + with self._get_session() as session: return session.query(User).count() def get_users(self): - with self.db.get_session() as session: + with self._get_session() as session: return session.query(User) def create_user(self, username, password, **kwargs): @@ -68,7 +67,7 @@ class UserManager: if not password: raise ValueError('Please provide a password for the user') - with self.db.get_session() as session: + with self._get_session() as session: user = self._get_user(session, username) if user: raise NameError('The user {} already exists'.format(username)) @@ -87,7 +86,7 @@ class UserManager: return self._mask_password(user) def update_password(self, username, old_password, new_password): - with self.db.get_session() as session: + with self._get_session() as session: if not self._authenticate_user(session, username, old_password): return False @@ -97,11 +96,11 @@ class UserManager: return True def authenticate_user(self, username, password): - with self.db.get_session() as session: + with self._get_session() as session: return self._authenticate_user(session, username, password) def authenticate_user_session(self, session_token): - with self.db.get_session() as session: + with self._get_session() as session: user_session = ( session.query(UserSession) .filter_by(session_token=session_token) @@ -118,7 +117,7 @@ class UserManager: return self._mask_password(user), user_session def delete_user(self, username): - with self.db.get_session() as session: + with self._get_session() as session: user = self._get_user(session, username) if not user: raise NameError('No such user: {}'.format(username)) @@ -134,7 +133,7 @@ class UserManager: return True def delete_user_session(self, session_token): - with self.db.get_session() as session: + with self._get_session() as session: user_session = ( session.query(UserSession) .filter_by(session_token=session_token) @@ -149,7 +148,7 @@ class UserManager: return True def create_user_session(self, username, password, expires_at=None): - with self.db.get_session() as session: + with self._get_session() as session: user = self._authenticate_user(session, username, password) if not user: return None @@ -203,7 +202,7 @@ class UserManager: :param session_token: Session token. """ - with self.db.get_session() as session: + with self._get_session() as session: return ( session.query(User) .join(UserSession) @@ -263,7 +262,7 @@ class UserManager: pub_key, _ = get_or_generate_jwt_rsa_key_pair() try: - payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) + payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) # type: ignore[reportGeneralTypeIssues] except PyJWTError as e: raise InvalidJWTTokenException(str(e))