forked from platypush/platypush
Removed connection.begin()
pattern from the db
plugin.
SQLAlchemy should automatically begin a transaction on connection/session creation. Plus, `.begin()` messes up things with SQLAlchemy 2, which has `autobegin` enabled with no easy way of disabling it.
This commit is contained in:
parent
37722d12cd
commit
e1cd22121a
1 changed files with 24 additions and 43 deletions
|
@ -3,12 +3,7 @@ from contextlib import contextmanager
|
|||
from multiprocessing import RLock
|
||||
from typing import Optional, Generator, Union
|
||||
|
||||
from sqlalchemy import (
|
||||
create_engine,
|
||||
Table,
|
||||
MetaData,
|
||||
__version__ as sa_version,
|
||||
)
|
||||
from sqlalchemy import create_engine, Table, MetaData
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import CompileError
|
||||
from sqlalchemy.orm import Session, sessionmaker, scoped_session
|
||||
|
@ -19,21 +14,6 @@ from platypush.plugins import Plugin, action
|
|||
session_locks = {}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def conn_begin(conn):
|
||||
"""
|
||||
Utility method to deal with `autobegin` being enabled on SQLAlchemy 2.0 but
|
||||
not on earlier version.
|
||||
"""
|
||||
sa_maj_ver = int(sa_version.split('.')[0])
|
||||
if sa_maj_ver < 2:
|
||||
yield conn.begin()
|
||||
else:
|
||||
yield conn._transaction
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
class DbPlugin(Plugin):
|
||||
"""
|
||||
Database plugin. It allows you to programmatically select, insert, update
|
||||
|
@ -45,7 +25,7 @@ class DbPlugin(Plugin):
|
|||
_db_error_wait_interval = 5.0
|
||||
_db_error_retries = 3
|
||||
|
||||
def __init__(self, *args, engine=None, **kwargs):
|
||||
def __init__(self, engine=None, **kwargs):
|
||||
"""
|
||||
:param engine: Default SQLAlchemy connection engine string (e.g.
|
||||
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
||||
|
@ -62,7 +42,7 @@ class DbPlugin(Plugin):
|
|||
|
||||
super().__init__()
|
||||
self.engine_url = engine
|
||||
self.engine = self.get_engine(engine, *args, **kwargs)
|
||||
self.engine = self.get_engine(engine, **kwargs)
|
||||
|
||||
def get_engine(
|
||||
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
||||
|
@ -116,12 +96,10 @@ class DbPlugin(Plugin):
|
|||
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
|
||||
"""
|
||||
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
with engine.connect() as connection:
|
||||
with self.get_engine(engine, *args, **kwargs).connect() as connection:
|
||||
connection.execute(text(statement))
|
||||
|
||||
def _get_table(self, table: str, engine=None, *args, **kwargs):
|
||||
def _get_table(self, table: str, *args, engine=None, **kwargs):
|
||||
if not engine:
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
|
@ -344,17 +322,16 @@ class DbPlugin(Plugin):
|
|||
connection, table, records, key_columns
|
||||
)
|
||||
|
||||
with conn_begin(connection):
|
||||
if insert_records:
|
||||
insert = table.insert().values(insert_records)
|
||||
ret = self._execute_try_returning(connection, insert)
|
||||
if ret:
|
||||
returned_records += ret
|
||||
if insert_records:
|
||||
insert = table.insert().values(insert_records)
|
||||
ret = self._execute_try_returning(connection, insert)
|
||||
if ret:
|
||||
returned_records += ret
|
||||
|
||||
if update_records and on_duplicate_update:
|
||||
ret = self._update(connection, table, update_records, key_columns)
|
||||
if ret:
|
||||
returned_records = ret + returned_records
|
||||
if update_records and on_duplicate_update:
|
||||
ret = self._update(connection, table, update_records, key_columns)
|
||||
if ret:
|
||||
returned_records = ret + returned_records
|
||||
|
||||
if returned_records:
|
||||
return returned_records
|
||||
|
@ -519,7 +496,7 @@ class DbPlugin(Plugin):
|
|||
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
with engine.connect() as connection, conn_begin(connection):
|
||||
with engine.connect() as connection:
|
||||
for record in records:
|
||||
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
delete = table_.delete()
|
||||
|
@ -535,7 +512,7 @@ class DbPlugin(Plugin):
|
|||
|
||||
@contextmanager
|
||||
def get_session(
|
||||
self, engine=None, locked=False, autoflush=True, *args, **kwargs
|
||||
self, *args, engine=None, locked=False, autoflush=True, **kwargs
|
||||
) -> Generator[Session, None, None]:
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
if locked:
|
||||
|
@ -544,16 +521,20 @@ class DbPlugin(Plugin):
|
|||
# Mock lock
|
||||
lock = RLock()
|
||||
|
||||
with lock, engine.connect() as conn, conn_begin(conn):
|
||||
session = scoped_session(
|
||||
with lock, engine.connect() as conn:
|
||||
session_maker = scoped_session(
|
||||
sessionmaker(
|
||||
expire_on_commit=False,
|
||||
autoflush=autoflush,
|
||||
)
|
||||
)
|
||||
|
||||
session.configure(bind=conn)
|
||||
yield session()
|
||||
session_maker.configure(bind=conn)
|
||||
session = session_maker()
|
||||
yield session
|
||||
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Add table
Reference in a new issue