forked from platypush/platypush
Fixed session/concurrency management on the main SQLite db
- The `declarative_base` instance should be shared - Database `session_locks` should be stored at module, not instance level - Better isolation of scoped sessions - Enclapsulated `get_session` method in `UserManager`
This commit is contained in:
parent
bfeb0a08c4
commit
86edd70d93
7 changed files with 48 additions and 51 deletions
3
platypush/common/db.py
Normal file
3
platypush/common/db.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base()
|
|
@ -1,7 +1,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Collection, Optional
|
from typing import Collection, Optional
|
||||||
|
|
||||||
from ._base import Entity, get_entities_registry, db_url
|
from ._base import Entity, get_entities_registry
|
||||||
from ._engine import EntitiesEngine
|
from ._engine import EntitiesEngine
|
||||||
from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
|
from ._registry import manages, register_entity_plugin, get_plugin_entity_registry
|
||||||
|
|
||||||
|
@ -29,7 +29,6 @@ def publish_entities(entities: Collection[Entity]):
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Entity',
|
'Entity',
|
||||||
'EntitiesEngine',
|
'EntitiesEngine',
|
||||||
'db_url',
|
|
||||||
'init_entities_engine',
|
'init_entities_engine',
|
||||||
'publish_entities',
|
'publish_entities',
|
||||||
'register_entity_plugin',
|
'register_entity_plugin',
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import types
|
import types
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -17,15 +16,13 @@ from sqlalchemy import (
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
inspect as schema_inspect,
|
inspect as schema_inspect,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import declarative_base, ColumnProperty
|
from sqlalchemy.orm import ColumnProperty
|
||||||
|
|
||||||
from platypush.config import Config
|
from platypush.common.db import Base
|
||||||
from platypush.message import JSONAble
|
from platypush.message import JSONAble
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
entities_registry: Mapping[Type['Entity'], Mapping] = {}
|
entities_registry: Mapping[Type['Entity'], Mapping] = {}
|
||||||
entity_types_registry: Dict[str, Type['Entity']] = {}
|
entity_types_registry: Dict[str, Type['Entity']] = {}
|
||||||
db_url = 'sqlite:///' + os.path.join(str(Config.get('workdir') or ''), 'entities.db')
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(Base):
|
class Entity(Base):
|
||||||
|
@ -138,8 +135,4 @@ def init_entities_db():
|
||||||
_discover_entity_types()
|
_discover_entity_types()
|
||||||
db = get_plugin('db')
|
db = get_plugin('db')
|
||||||
assert db
|
assert db
|
||||||
|
db.create_all(db.get_engine(), Base)
|
||||||
engine = db.get_engine(engine=db_url)
|
|
||||||
with db.get_session() as session:
|
|
||||||
db.create_all(engine, Base)
|
|
||||||
session.flush()
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, make_transient
|
||||||
from platypush.context import get_bus, get_plugin
|
from platypush.context import get_bus, get_plugin
|
||||||
from platypush.message.event.entities import EntityUpdateEvent
|
from platypush.message.event.entities import EntityUpdateEvent
|
||||||
|
|
||||||
from ._base import Entity, db_url
|
from ._base import Entity
|
||||||
|
|
||||||
|
|
||||||
class EntitiesEngine(Thread):
|
class EntitiesEngine(Thread):
|
||||||
|
@ -35,7 +35,7 @@ class EntitiesEngine(Thread):
|
||||||
def _get_session(self):
|
def _get_session(self):
|
||||||
db = get_plugin('db')
|
db = get_plugin('db')
|
||||||
assert db
|
assert db
|
||||||
return db.get_session(engine=db_url)
|
return db.get_session()
|
||||||
|
|
||||||
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
|
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
|
||||||
if entity.id:
|
if entity.id:
|
||||||
|
|
|
@ -11,6 +11,8 @@ from sqlalchemy.sql import and_, or_, text
|
||||||
|
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import Plugin, action
|
||||||
|
|
||||||
|
session_locks = {}
|
||||||
|
|
||||||
|
|
||||||
class DbPlugin(Plugin):
|
class DbPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
|
@ -41,7 +43,6 @@ class DbPlugin(Plugin):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.engine_url = engine
|
self.engine_url = engine
|
||||||
self.engine = self.get_engine(engine, *args, **kwargs)
|
self.engine = self.get_engine(engine, *args, **kwargs)
|
||||||
self._session_locks = {}
|
|
||||||
|
|
||||||
def get_engine(
|
def get_engine(
|
||||||
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
||||||
|
@ -60,7 +61,7 @@ class DbPlugin(Plugin):
|
||||||
return create_engine(engine, *args, **kwargs) # type: ignore
|
return create_engine(engine, *args, **kwargs) # type: ignore
|
||||||
|
|
||||||
if not self.engine:
|
if not self.engine:
|
||||||
return create_engine(self.engine_url, *args, **kwargs) # type: ignore
|
self.engine = create_engine(self.engine_url, *args, **kwargs) # type: ignore
|
||||||
|
|
||||||
return self.engine
|
return self.engine
|
||||||
|
|
||||||
|
@ -508,23 +509,26 @@ class DbPlugin(Plugin):
|
||||||
connection.execute(delete)
|
connection.execute(delete)
|
||||||
|
|
||||||
def create_all(self, engine, base):
|
def create_all(self, engine, base):
|
||||||
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
|
with (self.get_session(engine) as session, session.begin()):
|
||||||
with self._session_locks[engine.url]:
|
base.metadata.create_all(session.connection())
|
||||||
base.metadata.create_all(engine)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_session(
|
def get_session(
|
||||||
self, engine=None, *args, **kwargs
|
self, engine=None, *args, **kwargs
|
||||||
) -> Generator[Session, None, None]:
|
) -> Generator[Session, None, None]:
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
|
session_locks[engine.url] = session_locks.get(engine.url, RLock())
|
||||||
with self._session_locks[engine.url]:
|
|
||||||
session = scoped_session(sessionmaker(expire_on_commit=False))
|
with (session_locks[engine.url], engine.connect() as conn, conn.begin()):
|
||||||
session.configure(bind=engine)
|
session = scoped_session(
|
||||||
s = session()
|
sessionmaker(
|
||||||
yield s
|
expire_on_commit=False,
|
||||||
s.commit()
|
autoflush=True,
|
||||||
s.close()
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.configure(bind=conn)
|
||||||
|
yield session()
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -12,7 +12,6 @@ from platypush.entities import (
|
||||||
Entity,
|
Entity,
|
||||||
get_plugin_entity_registry,
|
get_plugin_entity_registry,
|
||||||
get_entities_registry,
|
get_entities_registry,
|
||||||
db_url,
|
|
||||||
)
|
)
|
||||||
from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent
|
from platypush.message.event.entities import EntityUpdateEvent, EntityDeleteEvent
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import Plugin, action
|
||||||
|
@ -30,7 +29,7 @@ class EntitiesPlugin(Plugin):
|
||||||
def _get_session(self):
|
def _get_session(self):
|
||||||
db = get_plugin('db')
|
db = get_plugin('db')
|
||||||
assert db
|
assert db
|
||||||
return db.get_session(engine=db_url)
|
return db.get_session()
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def get(
|
def get(
|
||||||
|
|
|
@ -13,8 +13,9 @@ except ImportError:
|
||||||
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
|
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import make_transient, declarative_base
|
from sqlalchemy.orm import make_transient
|
||||||
|
|
||||||
|
from platypush.common.db import Base
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
from platypush.exceptions.user import (
|
from platypush.exceptions.user import (
|
||||||
InvalidJWTTokenException,
|
InvalidJWTTokenException,
|
||||||
|
@ -22,8 +23,6 @@ from platypush.exceptions.user import (
|
||||||
)
|
)
|
||||||
from platypush.utils import get_or_generate_jwt_rsa_key_pair
|
from platypush.utils import get_or_generate_jwt_rsa_key_pair
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager:
|
class UserManager:
|
||||||
"""
|
"""
|
||||||
|
@ -31,13 +30,10 @@ class UserManager:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = get_plugin('db')
|
db_plugin = get_plugin('db')
|
||||||
assert self.db
|
assert db_plugin, 'Database plugin not configured'
|
||||||
self._engine = self.db.get_engine(engine=self.db.engine_url)
|
self.db = db_plugin
|
||||||
|
self.db.create_all(self.db.get_engine(), Base)
|
||||||
with self.db.get_session() as session:
|
|
||||||
self.db.create_all(self._engine, Base)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mask_password(user):
|
def _mask_password(user):
|
||||||
|
@ -45,8 +41,11 @@ class UserManager:
|
||||||
user.password = None
|
user.password = None
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
def _get_session(self, *args, **kwargs):
|
||||||
|
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
|
||||||
|
|
||||||
def get_user(self, username):
|
def get_user(self, username):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
@ -55,11 +54,11 @@ class UserManager:
|
||||||
return self._mask_password(user)
|
return self._mask_password(user)
|
||||||
|
|
||||||
def get_user_count(self):
|
def get_user_count(self):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
return session.query(User).count()
|
return session.query(User).count()
|
||||||
|
|
||||||
def get_users(self):
|
def get_users(self):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
return session.query(User)
|
return session.query(User)
|
||||||
|
|
||||||
def create_user(self, username, password, **kwargs):
|
def create_user(self, username, password, **kwargs):
|
||||||
|
@ -68,7 +67,7 @@ class UserManager:
|
||||||
if not password:
|
if not password:
|
||||||
raise ValueError('Please provide a password for the user')
|
raise ValueError('Please provide a password for the user')
|
||||||
|
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
if user:
|
if user:
|
||||||
raise NameError('The user {} already exists'.format(username))
|
raise NameError('The user {} already exists'.format(username))
|
||||||
|
@ -87,7 +86,7 @@ class UserManager:
|
||||||
return self._mask_password(user)
|
return self._mask_password(user)
|
||||||
|
|
||||||
def update_password(self, username, old_password, new_password):
|
def update_password(self, username, old_password, new_password):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
if not self._authenticate_user(session, username, old_password):
|
if not self._authenticate_user(session, username, old_password):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -97,11 +96,11 @@ class UserManager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def authenticate_user(self, username, password):
|
def authenticate_user(self, username, password):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
return self._authenticate_user(session, username, password)
|
return self._authenticate_user(session, username, password)
|
||||||
|
|
||||||
def authenticate_user_session(self, session_token):
|
def authenticate_user_session(self, session_token):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user_session = (
|
user_session = (
|
||||||
session.query(UserSession)
|
session.query(UserSession)
|
||||||
.filter_by(session_token=session_token)
|
.filter_by(session_token=session_token)
|
||||||
|
@ -118,7 +117,7 @@ class UserManager:
|
||||||
return self._mask_password(user), user_session
|
return self._mask_password(user), user_session
|
||||||
|
|
||||||
def delete_user(self, username):
|
def delete_user(self, username):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
if not user:
|
if not user:
|
||||||
raise NameError('No such user: {}'.format(username))
|
raise NameError('No such user: {}'.format(username))
|
||||||
|
@ -134,7 +133,7 @@ class UserManager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_user_session(self, session_token):
|
def delete_user_session(self, session_token):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user_session = (
|
user_session = (
|
||||||
session.query(UserSession)
|
session.query(UserSession)
|
||||||
.filter_by(session_token=session_token)
|
.filter_by(session_token=session_token)
|
||||||
|
@ -149,7 +148,7 @@ class UserManager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_user_session(self, username, password, expires_at=None):
|
def create_user_session(self, username, password, expires_at=None):
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
user = self._authenticate_user(session, username, password)
|
user = self._authenticate_user(session, username, password)
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
@ -203,7 +202,7 @@ class UserManager:
|
||||||
|
|
||||||
:param session_token: Session token.
|
:param session_token: Session token.
|
||||||
"""
|
"""
|
||||||
with self.db.get_session() as session:
|
with self._get_session() as session:
|
||||||
return (
|
return (
|
||||||
session.query(User)
|
session.query(User)
|
||||||
.join(UserSession)
|
.join(UserSession)
|
||||||
|
@ -263,7 +262,7 @@ class UserManager:
|
||||||
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
|
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256'])
|
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) # type: ignore[reportGeneralTypeIssues]
|
||||||
except PyJWTError as e:
|
except PyJWTError as e:
|
||||||
raise InvalidJWTTokenException(str(e))
|
raise InvalidJWTTokenException(str(e))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue