Removed connection.begin() pattern from the db plugin.

SQLAlchemy should automatically begin a transaction on
connection/session creation. Plus, `.begin()` messes up things with
SQLAlchemy 2, which has `autobegin` enabled with no easy way of
disabling it.
This commit is contained in:
Fabio Manganiello 2023-04-25 10:31:49 +02:00
parent 37722d12cd
commit e1cd22121a
Signed by: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -3,12 +3,7 @@ from contextlib import contextmanager
from multiprocessing import RLock
from typing import Optional, Generator, Union
from sqlalchemy import (
create_engine,
Table,
MetaData,
__version__ as sa_version,
)
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.orm import Session, sessionmaker, scoped_session
@ -19,21 +14,6 @@ from platypush.plugins import Plugin, action
session_locks = {}
@contextmanager
def conn_begin(conn):
"""
Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but
not on earlier version.
"""
sa_maj_ver = int(sa_version.split('.')[0])
if sa_maj_ver < 2:
yield conn.begin()
else:
yield conn._transaction
conn.commit()
class DbPlugin(Plugin):
"""
Database plugin. It allows you to programmatically select, insert, update
@ -45,7 +25,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0
_db_error_retries = 3
def __init__(self, *args, engine=None, **kwargs):
def __init__(self, engine=None, **kwargs):
"""
:param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -62,7 +42,7 @@ class DbPlugin(Plugin):
super().__init__()
self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs)
self.engine = self.get_engine(engine, **kwargs)
def get_engine(
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
@ -116,12 +96,10 @@ class DbPlugin(Plugin):
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
"""
engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection:
with self.get_engine(engine, *args, **kwargs).connect() as connection:
connection.execute(text(statement))
def _get_table(self, table: str, engine=None, *args, **kwargs):
def _get_table(self, table: str, *args, engine=None, **kwargs):
if not engine:
engine = self.get_engine(engine, *args, **kwargs)
@ -344,17 +322,16 @@ class DbPlugin(Plugin):
connection, table, records, key_columns
)
with conn_begin(connection):
if insert_records:
insert = table.insert().values(insert_records)
ret = self._execute_try_returning(connection, insert)
if ret:
returned_records += ret
if insert_records:
insert = table.insert().values(insert_records)
ret = self._execute_try_returning(connection, insert)
if ret:
returned_records += ret
if update_records and on_duplicate_update:
ret = self._update(connection, table, update_records, key_columns)
if ret:
returned_records = ret + returned_records
if update_records and on_duplicate_update:
ret = self._update(connection, table, update_records, key_columns)
if ret:
returned_records = ret + returned_records
if returned_records:
return returned_records
@ -519,7 +496,7 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection, conn_begin(connection):
with engine.connect() as connection:
for record in records:
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table_.delete()
@ -535,7 +512,7 @@ class DbPlugin(Plugin):
@contextmanager
def get_session(
self, engine=None, locked=False, autoflush=True, *args, **kwargs
self, *args, engine=None, locked=False, autoflush=True, **kwargs
) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs)
if locked:
@ -544,16 +521,20 @@ class DbPlugin(Plugin):
# Mock lock
lock = RLock()
with lock, engine.connect() as conn, conn_begin(conn):
session = scoped_session(
with lock, engine.connect() as conn:
session_maker = scoped_session(
sessionmaker(
expire_on_commit=False,
autoflush=autoflush,
)
)
session.configure(bind=conn)
yield session()
session_maker.configure(bind=conn)
session = session_maker()
yield session
session.flush()
session.commit()
# vim:sw=4:ts=4:et: