diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index b5a83df3..70708bd8 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -3,12 +3,7 @@ from contextlib import contextmanager from multiprocessing import RLock from typing import Optional, Generator, Union -from sqlalchemy import ( - create_engine, - Table, - MetaData, - __version__ as sa_version, -) +from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.engine import Engine from sqlalchemy.exc import CompileError from sqlalchemy.orm import Session, sessionmaker, scoped_session @@ -19,21 +14,6 @@ from platypush.plugins import Plugin, action session_locks = {} -@contextmanager -def conn_begin(conn): - """ - Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but - not on earlier version. - """ - sa_maj_ver = int(sa_version.split('.')[0]) - if sa_maj_ver < 2: - yield conn.begin() - else: - yield conn._transaction - - conn.commit() - - class DbPlugin(Plugin): """ Database plugin. It allows you to programmatically select, insert, update @@ -45,7 +25,7 @@ class DbPlugin(Plugin): _db_error_wait_interval = 5.0 _db_error_retries = 3 - def __init__(self, *args, engine=None, **kwargs): + def __init__(self, engine=None, **kwargs): """ :param engine: Default SQLAlchemy connection engine string (e.g. ``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``) @@ -62,7 +42,7 @@ class DbPlugin(Plugin): super().__init__() self.engine_url = engine - self.engine = self.get_engine(engine, *args, **kwargs) + self.engine = self.get_engine(engine, **kwargs) def get_engine( self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs @@ -116,12 +96,10 @@ class DbPlugin(Plugin): (see https:///docs.sqlalchemy.org/en/latest/core/engines.html) """ - engine = self.get_engine(engine, *args, **kwargs) - - with engine.connect() as connection: + with self.get_engine(engine, *args, **kwargs).connect() as connection: connection.execute(text(statement)) - def _get_table(self, table: str, engine=None, *args, **kwargs): + def _get_table(self, table: str, *args, engine=None, **kwargs): if not engine: engine = self.get_engine(engine, *args, **kwargs) @@ -344,17 +322,16 @@ class DbPlugin(Plugin): connection, table, records, key_columns ) - with conn_begin(connection): - if insert_records: - insert = table.insert().values(insert_records) - ret = self._execute_try_returning(connection, insert) - if ret: - returned_records += ret + if insert_records: + insert = table.insert().values(insert_records) + ret = self._execute_try_returning(connection, insert) + if ret: + returned_records += ret - if update_records and on_duplicate_update: - ret = self._update(connection, table, update_records, key_columns) - if ret: - returned_records = ret + returned_records + if update_records and on_duplicate_update: + ret = self._update(connection, table, update_records, key_columns) + if ret: + returned_records = ret + returned_records if returned_records: return returned_records @@ -519,7 +496,7 @@ class DbPlugin(Plugin): engine = self.get_engine(engine, *args, **kwargs) - with engine.connect() as connection, conn_begin(connection): + with engine.connect() as connection: for record in records: table_, engine = self._get_table(table, engine=engine, *args, **kwargs) delete = table_.delete() @@ -535,7 +512,7 @@ class DbPlugin(Plugin): @contextmanager def get_session( - self, engine=None, locked=False, autoflush=True, *args, **kwargs + self, *args, engine=None, locked=False, autoflush=True, **kwargs ) -> Generator[Session, None, None]: engine = self.get_engine(engine, *args, **kwargs) if locked: @@ -544,16 +521,20 @@ class DbPlugin(Plugin): # Mock lock lock = RLock() - with lock, engine.connect() as conn, conn_begin(conn): - session = scoped_session( + with lock, engine.connect() as conn: + session_maker = scoped_session( sessionmaker( expire_on_commit=False, autoflush=autoflush, ) ) - session.configure(bind=conn) - yield session() + session_maker.configure(bind=conn) + session = session_maker() + yield session + + session.flush() + session.commit() # vim:sw=4:ts=4:et: