forked from platypush/platypush
Fixed compatibility with SQLAlchemy >= 2.0 in the db
plugin.
This commit is contained in:
parent
8478245cde
commit
87889142e0
2 changed files with 33 additions and 13 deletions
|
@ -3,7 +3,12 @@ from contextlib import contextmanager
|
||||||
from multiprocessing import RLock
|
from multiprocessing import RLock
|
||||||
from typing import Optional, Generator, Union
|
from typing import Optional, Generator, Union
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Table, MetaData
|
from sqlalchemy import (
|
||||||
|
create_engine,
|
||||||
|
Table,
|
||||||
|
MetaData,
|
||||||
|
__version__ as sa_version,
|
||||||
|
)
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.exc import CompileError
|
from sqlalchemy.exc import CompileError
|
||||||
from sqlalchemy.orm import Session, sessionmaker, scoped_session
|
from sqlalchemy.orm import Session, sessionmaker, scoped_session
|
||||||
|
@ -14,6 +19,21 @@ from platypush.plugins import Plugin, action
|
||||||
session_locks = {}
|
session_locks = {}
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def conn_begin(conn):
|
||||||
|
"""
|
||||||
|
Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but
|
||||||
|
not on earlier version.
|
||||||
|
"""
|
||||||
|
sa_maj_ver = int(sa_version.split('.')[0])
|
||||||
|
if sa_maj_ver < 2:
|
||||||
|
yield conn.begin()
|
||||||
|
else:
|
||||||
|
yield conn._transaction
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
class DbPlugin(Plugin):
|
class DbPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
Database plugin. It allows you to programmatically select, insert, update
|
Database plugin. It allows you to programmatically select, insert, update
|
||||||
|
@ -101,7 +121,7 @@ class DbPlugin(Plugin):
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
with engine.connect() as connection:
|
with engine.connect() as connection:
|
||||||
connection.execute(statement)
|
connection.execute(text(statement))
|
||||||
|
|
||||||
def _get_table(self, table, engine=None, *args, **kwargs):
|
def _get_table(self, table, engine=None, *args, **kwargs):
|
||||||
if not engine:
|
if not engine:
|
||||||
|
@ -115,7 +135,7 @@ class DbPlugin(Plugin):
|
||||||
try:
|
try:
|
||||||
n_tries += 1
|
n_tries += 1
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
table = Table(table, metadata, autoload=True, autoload_with=engine)
|
table = Table(table, metadata, autoload_with=engine)
|
||||||
db_ok = True
|
db_ok = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
|
@ -139,7 +159,7 @@ class DbPlugin(Plugin):
|
||||||
engine=None,
|
engine=None,
|
||||||
data: Optional[dict] = None,
|
data: Optional[dict] = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Returns rows (as a list of hashes) given a query.
|
Returns rows (as a list of hashes) given a query.
|
||||||
|
@ -219,7 +239,7 @@ class DbPlugin(Plugin):
|
||||||
query = table.select()
|
query = table.select()
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
for (k, v) in filter.items():
|
for k, v in filter.items():
|
||||||
query = query.where(self._build_condition(table, k, v))
|
query = query.where(self._build_condition(table, k, v))
|
||||||
|
|
||||||
if query is None:
|
if query is None:
|
||||||
|
@ -246,7 +266,7 @@ class DbPlugin(Plugin):
|
||||||
key_columns=None,
|
key_columns=None,
|
||||||
on_duplicate_update=False,
|
on_duplicate_update=False,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Inserts records (as a list of hashes) into a table.
|
Inserts records (as a list of hashes) into a table.
|
||||||
|
@ -324,7 +344,7 @@ class DbPlugin(Plugin):
|
||||||
connection, table, records, key_columns
|
connection, table, records, key_columns
|
||||||
)
|
)
|
||||||
|
|
||||||
with connection.begin():
|
with conn_begin(connection):
|
||||||
if insert_records:
|
if insert_records:
|
||||||
insert = table.insert().values(insert_records)
|
insert = table.insert().values(insert_records)
|
||||||
ret = self._execute_try_returning(connection, insert)
|
ret = self._execute_try_returning(connection, insert)
|
||||||
|
@ -394,7 +414,7 @@ class DbPlugin(Plugin):
|
||||||
values = {k: v for (k, v) in record.items() if k not in key_columns}
|
values = {k: v for (k, v) in record.items() if k not in key_columns}
|
||||||
update = table.update()
|
update = table.update()
|
||||||
|
|
||||||
for (k, v) in key.items():
|
for k, v in key.items():
|
||||||
update = update.where(self._build_condition(table, k, v))
|
update = update.where(self._build_condition(table, k, v))
|
||||||
|
|
||||||
update = update.values(**values)
|
update = update.values(**values)
|
||||||
|
@ -498,18 +518,18 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
with engine.connect() as connection, connection.begin():
|
with engine.connect() as connection, conn_begin(connection):
|
||||||
for record in records:
|
for record in records:
|
||||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||||
delete = table.delete()
|
delete = table.delete()
|
||||||
|
|
||||||
for (k, v) in record.items():
|
for k, v in record.items():
|
||||||
delete = delete.where(self._build_condition(table, k, v))
|
delete = delete.where(self._build_condition(table, k, v))
|
||||||
|
|
||||||
connection.execute(delete)
|
connection.execute(delete)
|
||||||
|
|
||||||
def create_all(self, engine, base):
|
def create_all(self, engine, base):
|
||||||
with (self.get_session(engine, locked=True) as session, session.begin()):
|
with self.get_session(engine, locked=True) as session, session.begin():
|
||||||
base.metadata.create_all(session.connection())
|
base.metadata.create_all(session.connection())
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -523,7 +543,7 @@ class DbPlugin(Plugin):
|
||||||
# Mock lock
|
# Mock lock
|
||||||
lock = RLock()
|
lock = RLock()
|
||||||
|
|
||||||
with (lock, engine.connect() as conn, conn.begin()):
|
with lock, engine.connect() as conn, conn_begin(conn):
|
||||||
session = scoped_session(
|
session = scoped_session(
|
||||||
sessionmaker(
|
sessionmaker(
|
||||||
expire_on_commit=False,
|
expire_on_commit=False,
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -57,7 +57,7 @@ setup(
|
||||||
'redis',
|
'redis',
|
||||||
'requests',
|
'requests',
|
||||||
'croniter',
|
'croniter',
|
||||||
'sqlalchemy<2.0.0',
|
'sqlalchemy',
|
||||||
'websockets',
|
'websockets',
|
||||||
'websocket-client',
|
'websocket-client',
|
||||||
'wheel',
|
'wheel',
|
||||||
|
|
Loading…
Reference in a new issue