Removed `connection.begin()` pattern from the `db` plugin.

SQLAlchemy should automatically begin a transaction on
connection/session creation. Plus, `.begin()` messes up things with
SQLAlchemy 2, which has `autobegin` enabled with no easy way of
disabling it.
This commit is contained in:
Fabio Manganiello 2023-04-25 10:31:49 +02:00
parent 37722d12cd
commit e1cd22121a
Signed by: blacklight
GPG Key ID: D90FBA7F76362774
1 changed files with 24 additions and 43 deletions

View File

@ -3,12 +3,7 @@ from contextlib import contextmanager
from multiprocessing import RLock from multiprocessing import RLock
from typing import Optional, Generator, Union from typing import Optional, Generator, Union
from sqlalchemy import ( from sqlalchemy import create_engine, Table, MetaData
create_engine,
Table,
MetaData,
__version__ as sa_version,
)
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.exc import CompileError from sqlalchemy.exc import CompileError
from sqlalchemy.orm import Session, sessionmaker, scoped_session from sqlalchemy.orm import Session, sessionmaker, scoped_session
@ -19,21 +14,6 @@ from platypush.plugins import Plugin, action
session_locks = {} session_locks = {}
@contextmanager
def conn_begin(conn):
"""
Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but
not on earlier version.
"""
sa_maj_ver = int(sa_version.split('.')[0])
if sa_maj_ver < 2:
yield conn.begin()
else:
yield conn._transaction
conn.commit()
class DbPlugin(Plugin): class DbPlugin(Plugin):
""" """
Database plugin. It allows you to programmatically select, insert, update Database plugin. It allows you to programmatically select, insert, update
@ -45,7 +25,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0 _db_error_wait_interval = 5.0
_db_error_retries = 3 _db_error_retries = 3
def __init__(self, *args, engine=None, **kwargs): def __init__(self, engine=None, **kwargs):
""" """
:param engine: Default SQLAlchemy connection engine string (e.g. :param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``) ``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -62,7 +42,7 @@ class DbPlugin(Plugin):
super().__init__() super().__init__()
self.engine_url = engine self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, **kwargs)
def get_engine( def get_engine(
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
@ -116,12 +96,10 @@ class DbPlugin(Plugin):
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html) (see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
""" """
engine = self.get_engine(engine, *args, **kwargs) with self.get_engine(engine, *args, **kwargs).connect() as connection:
with engine.connect() as connection:
connection.execute(text(statement)) connection.execute(text(statement))
def _get_table(self, table: str, engine=None, *args, **kwargs): def _get_table(self, table: str, *args, engine=None, **kwargs):
if not engine: if not engine:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
@ -344,17 +322,16 @@ class DbPlugin(Plugin):
connection, table, records, key_columns connection, table, records, key_columns
) )
with conn_begin(connection): if insert_records:
if insert_records: insert = table.insert().values(insert_records)
insert = table.insert().values(insert_records) ret = self._execute_try_returning(connection, insert)
ret = self._execute_try_returning(connection, insert) if ret:
if ret: returned_records += ret
returned_records += ret
if update_records and on_duplicate_update: if update_records and on_duplicate_update:
ret = self._update(connection, table, update_records, key_columns) ret = self._update(connection, table, update_records, key_columns)
if ret: if ret:
returned_records = ret + returned_records returned_records = ret + returned_records
if returned_records: if returned_records:
return returned_records return returned_records
@ -519,7 +496,7 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection, conn_begin(connection): with engine.connect() as connection:
for record in records: for record in records:
table_, engine = self._get_table(table, engine=engine, *args, **kwargs) table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table_.delete() delete = table_.delete()
@ -535,7 +512,7 @@ class DbPlugin(Plugin):
@contextmanager @contextmanager
def get_session( def get_session(
self, engine=None, locked=False, autoflush=True, *args, **kwargs self, *args, engine=None, locked=False, autoflush=True, **kwargs
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if locked: if locked:
@ -544,16 +521,20 @@ class DbPlugin(Plugin):
# Mock lock # Mock lock
lock = RLock() lock = RLock()
with lock, engine.connect() as conn, conn_begin(conn): with lock, engine.connect() as conn:
session = scoped_session( session_maker = scoped_session(
sessionmaker( sessionmaker(
expire_on_commit=False, expire_on_commit=False,
autoflush=autoflush, autoflush=autoflush,
) )
) )
session.configure(bind=conn) session_maker.configure(bind=conn)
yield session() session = session_maker()
yield session
session.flush()
session.commit()
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et: