diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index e777ca16..e92c2fe7 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -3,7 +3,12 @@ from contextlib import contextmanager from multiprocessing import RLock from typing import Optional, Generator, Union -from sqlalchemy import create_engine, Table, MetaData +from sqlalchemy import ( + create_engine, + Table, + MetaData, + __version__ as sa_version, +) from sqlalchemy.engine import Engine from sqlalchemy.exc import CompileError from sqlalchemy.orm import Session, sessionmaker, scoped_session @@ -14,6 +19,21 @@ from platypush.plugins import Plugin, action session_locks = {} +@contextmanager +def conn_begin(conn): + """ + Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but + not on earlier version. + """ + sa_maj_ver = int(sa_version.split('.')[0]) + if sa_maj_ver < 2: + yield conn.begin() + else: + yield conn._transaction + + conn.commit() + + class DbPlugin(Plugin): """ Database plugin. It allows you to programmatically select, insert, update @@ -101,7 +121,7 @@ class DbPlugin(Plugin): engine = self.get_engine(engine, *args, **kwargs) with engine.connect() as connection: - connection.execute(statement) + connection.execute(text(statement)) def _get_table(self, table, engine=None, *args, **kwargs): if not engine: @@ -115,7 +135,7 @@ class DbPlugin(Plugin): try: n_tries += 1 metadata = MetaData() - table = Table(table, metadata, autoload=True, autoload_with=engine) + table = Table(table, metadata, autoload_with=engine) db_ok = True except Exception as e: last_error = e @@ -139,7 +159,7 @@ class DbPlugin(Plugin): engine=None, data: Optional[dict] = None, *args, - **kwargs + **kwargs, ): """ Returns rows (as a list of hashes) given a query. @@ -219,7 +239,7 @@ class DbPlugin(Plugin): query = table.select() if filter: - for (k, v) in filter.items(): + for k, v in filter.items(): query = query.where(self._build_condition(table, k, v)) if query is None: @@ -246,7 +266,7 @@ class DbPlugin(Plugin): key_columns=None, on_duplicate_update=False, *args, - **kwargs + **kwargs, ): """ Inserts records (as a list of hashes) into a table. @@ -324,7 +344,7 @@ class DbPlugin(Plugin): connection, table, records, key_columns ) - with connection.begin(): + with conn_begin(connection): if insert_records: insert = table.insert().values(insert_records) ret = self._execute_try_returning(connection, insert) @@ -394,7 +414,7 @@ class DbPlugin(Plugin): values = {k: v for (k, v) in record.items() if k not in key_columns} update = table.update() - for (k, v) in key.items(): + for k, v in key.items(): update = update.where(self._build_condition(table, k, v)) update = update.values(**values) @@ -498,18 +518,18 @@ class DbPlugin(Plugin): engine = self.get_engine(engine, *args, **kwargs) - with engine.connect() as connection, connection.begin(): + with engine.connect() as connection, conn_begin(connection): for record in records: table, engine = self._get_table(table, engine=engine, *args, **kwargs) delete = table.delete() - for (k, v) in record.items(): + for k, v in record.items(): delete = delete.where(self._build_condition(table, k, v)) connection.execute(delete) def create_all(self, engine, base): - with (self.get_session(engine, locked=True) as session, session.begin()): + with self.get_session(engine, locked=True) as session, session.begin(): base.metadata.create_all(session.connection()) @contextmanager @@ -523,7 +543,7 @@ class DbPlugin(Plugin): # Mock lock lock = RLock() - with (lock, engine.connect() as conn, conn.begin()): + with lock, engine.connect() as conn, conn_begin(conn): session = scoped_session( sessionmaker( expire_on_commit=False, diff --git a/setup.py b/setup.py index eb77d9ec..48e64f1b 100755 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ setup( 'redis', 'requests', 'croniter', - 'sqlalchemy<2.0.0', + 'sqlalchemy', 'websockets', 'websocket-client', 'wheel',