Fixed compatibility with SQLAlchemy >= 2.0 in the db plugin.

This commit is contained in:
Fabio Manganiello 2023-04-24 22:52:17 +02:00
parent 8478245cde
commit 87889142e0
Signed by: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 33 additions and 13 deletions

View file

@ -3,7 +3,12 @@ from contextlib import contextmanager
from multiprocessing import RLock from multiprocessing import RLock
from typing import Optional, Generator, Union from typing import Optional, Generator, Union
from sqlalchemy import create_engine, Table, MetaData from sqlalchemy import (
create_engine,
Table,
MetaData,
__version__ as sa_version,
)
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.exc import CompileError from sqlalchemy.exc import CompileError
from sqlalchemy.orm import Session, sessionmaker, scoped_session from sqlalchemy.orm import Session, sessionmaker, scoped_session
@ -14,6 +19,21 @@ from platypush.plugins import Plugin, action
session_locks = {} session_locks = {}
@contextmanager
def conn_begin(conn):
"""
Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but
not on earlier version.
"""
sa_maj_ver = int(sa_version.split('.')[0])
if sa_maj_ver < 2:
yield conn.begin()
else:
yield conn._transaction
conn.commit()
class DbPlugin(Plugin): class DbPlugin(Plugin):
""" """
Database plugin. It allows you to programmatically select, insert, update Database plugin. It allows you to programmatically select, insert, update
@ -101,7 +121,7 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection: with engine.connect() as connection:
connection.execute(statement) connection.execute(text(statement))
def _get_table(self, table, engine=None, *args, **kwargs): def _get_table(self, table, engine=None, *args, **kwargs):
if not engine: if not engine:
@ -115,7 +135,7 @@ class DbPlugin(Plugin):
try: try:
n_tries += 1 n_tries += 1
metadata = MetaData() metadata = MetaData()
table = Table(table, metadata, autoload=True, autoload_with=engine) table = Table(table, metadata, autoload_with=engine)
db_ok = True db_ok = True
except Exception as e: except Exception as e:
last_error = e last_error = e
@ -139,7 +159,7 @@ class DbPlugin(Plugin):
engine=None, engine=None,
data: Optional[dict] = None, data: Optional[dict] = None,
*args, *args,
**kwargs **kwargs,
): ):
""" """
Returns rows (as a list of hashes) given a query. Returns rows (as a list of hashes) given a query.
@ -219,7 +239,7 @@ class DbPlugin(Plugin):
query = table.select() query = table.select()
if filter: if filter:
for (k, v) in filter.items(): for k, v in filter.items():
query = query.where(self._build_condition(table, k, v)) query = query.where(self._build_condition(table, k, v))
if query is None: if query is None:
@ -246,7 +266,7 @@ class DbPlugin(Plugin):
key_columns=None, key_columns=None,
on_duplicate_update=False, on_duplicate_update=False,
*args, *args,
**kwargs **kwargs,
): ):
""" """
Inserts records (as a list of hashes) into a table. Inserts records (as a list of hashes) into a table.
@ -324,7 +344,7 @@ class DbPlugin(Plugin):
connection, table, records, key_columns connection, table, records, key_columns
) )
with connection.begin(): with conn_begin(connection):
if insert_records: if insert_records:
insert = table.insert().values(insert_records) insert = table.insert().values(insert_records)
ret = self._execute_try_returning(connection, insert) ret = self._execute_try_returning(connection, insert)
@ -394,7 +414,7 @@ class DbPlugin(Plugin):
values = {k: v for (k, v) in record.items() if k not in key_columns} values = {k: v for (k, v) in record.items() if k not in key_columns}
update = table.update() update = table.update()
for (k, v) in key.items(): for k, v in key.items():
update = update.where(self._build_condition(table, k, v)) update = update.where(self._build_condition(table, k, v))
update = update.values(**values) update = update.values(**values)
@ -498,18 +518,18 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection, connection.begin(): with engine.connect() as connection, conn_begin(connection):
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table.delete() delete = table.delete()
for (k, v) in record.items(): for k, v in record.items():
delete = delete.where(self._build_condition(table, k, v)) delete = delete.where(self._build_condition(table, k, v))
connection.execute(delete) connection.execute(delete)
def create_all(self, engine, base): def create_all(self, engine, base):
with (self.get_session(engine, locked=True) as session, session.begin()): with self.get_session(engine, locked=True) as session, session.begin():
base.metadata.create_all(session.connection()) base.metadata.create_all(session.connection())
@contextmanager @contextmanager
@ -523,7 +543,7 @@ class DbPlugin(Plugin):
# Mock lock # Mock lock
lock = RLock() lock = RLock()
with (lock, engine.connect() as conn, conn.begin()): with lock, engine.connect() as conn, conn_begin(conn):
session = scoped_session( session = scoped_session(
sessionmaker( sessionmaker(
expire_on_commit=False, expire_on_commit=False,

View file

@ -57,7 +57,7 @@ setup(
'redis', 'redis',
'requests', 'requests',
'croniter', 'croniter',
'sqlalchemy<2.0.0', 'sqlalchemy',
'websockets', 'websockets',
'websocket-client', 'websocket-client',
'wheel', 'wheel',