Don't lock read session from the main database

This commit is contained in:
Fabio Manganiello 2022-11-12 16:10:57 +01:00
parent 86edd70d93
commit 69e097707d
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
4 changed files with 20 additions and 16 deletions

View file

@ -32,10 +32,10 @@ class EntitiesEngine(Thread):
'by_name_and_plugin': {},
}
def _get_session(self):
def _get_session(self, *args, **kwargs):
db = get_plugin('db')
assert db
return db.get_session()
return db.get_session(*args, **kwargs)
def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
if entity.id:
@ -247,7 +247,7 @@ class EntitiesEngine(Thread):
return list(new_entities.values())
def _process_entities(self, *entities: Entity):
with self._get_session() as session:
with self._get_session(locked=True) as session:
# Ensure that the internal IDs are set to null before the merge
for e in entities:
e.id = None # type: ignore

View file

@ -509,17 +509,21 @@ class DbPlugin(Plugin):
connection.execute(delete)
def create_all(self, engine, base):
with (self.get_session(engine) as session, session.begin()):
with (self.get_session(engine, locked=True) as session, session.begin()):
base.metadata.create_all(session.connection())
@contextmanager
def get_session(
self, engine=None, *args, **kwargs
self, engine=None, locked=False, *args, **kwargs
) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs)
session_locks[engine.url] = session_locks.get(engine.url, RLock())
if locked:
lock = session_locks[engine.url] = session_locks.get(engine.url, RLock())
else:
# Mock lock
lock = RLock()
with (session_locks[engine.url], engine.connect() as conn, conn.begin()):
with (lock, engine.connect() as conn, conn.begin()):
session = scoped_session(
sessionmaker(
expire_on_commit=False,

View file

@ -26,10 +26,10 @@ class EntitiesPlugin(Plugin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _get_session(self):
def _get_session(self, *args, **kwargs):
db = get_plugin('db')
assert db
return db.get_session()
return db.get_session(*args, **kwargs)
@action
def get(
@ -194,7 +194,7 @@ class EntitiesPlugin(Plugin):
:param entities: IDs of the entities to be removed.
:return: The payload of the deleted entities.
"""
with self._get_session() as session:
with self._get_session(locked=True) as session:
entities: Collection[Entity] = (
session.query(Entity).filter(Entity.id.in_(entities)).all()
)
@ -235,7 +235,7 @@ class EntitiesPlugin(Plugin):
:return: The updated entities.
"""
entities = {str(k): v for k, v in entities.items()}
with self._get_session() as session:
with self._get_session(locked=True) as session:
objs = session.query(Entity).filter(Entity.id.in_(entities.keys())).all()
for obj in objs:
obj.meta = {**(obj.meta or {}), **(entities.get(str(obj.id), {}))}

View file

@ -67,7 +67,7 @@ class UserManager:
if not password:
raise ValueError('Please provide a password for the user')
with self._get_session() as session:
with self._get_session(locked=True) as session:
user = self._get_user(session, username)
if user:
raise NameError('The user {} already exists'.format(username))
@ -86,7 +86,7 @@ class UserManager:
return self._mask_password(user)
def update_password(self, username, old_password, new_password):
with self._get_session() as session:
with self._get_session(locked=True) as session:
if not self._authenticate_user(session, username, old_password):
return False
@ -117,7 +117,7 @@ class UserManager:
return self._mask_password(user), user_session
def delete_user(self, username):
with self._get_session() as session:
with self._get_session(locked=True) as session:
user = self._get_user(session, username)
if not user:
raise NameError('No such user: {}'.format(username))
@ -133,7 +133,7 @@ class UserManager:
return True
def delete_user_session(self, session_token):
with self._get_session() as session:
with self._get_session(locked=True) as session:
user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
@ -148,7 +148,7 @@ class UserManager:
return True
def create_user_session(self, username, password, expires_at=None):
with self._get_session() as session:
with self._get_session(locked=True) as session:
user = self._authenticate_user(session, username, password)
if not user:
return None