Compare commits

...

2 commits

Author SHA1 Message Date
37722d12cd
No need for session.begin in db.create_all. 2023-04-24 23:55:50 +02:00
6fa179e769
LINT fixes 2023-04-24 23:49:31 +02:00

View file

@ -45,7 +45,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0
_db_error_retries = 3
def __init__(self, engine=None, *args, **kwargs):
def __init__(self, *args, engine=None, **kwargs):
"""
:param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -86,16 +86,14 @@ class DbPlugin(Plugin):
return self.engine
@staticmethod
def _build_condition(table, column, value): # type: ignore
if isinstance(value, str):
value = "'{}'".format(value)
elif not isinstance(value, int) and not isinstance(value, float):
value = "'{}'".format(str(value))
def _build_condition(table, column, value): # pylint: disable=unused-argument
if isinstance(value, str) or not isinstance(value, (int, float)):
value = f"'{value}'"
return eval('table.c.{}=={}'.format(column, value))
return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
@action
def execute(self, statement, engine=None, *args, **kwargs):
def execute(self, statement, *args, engine=None, **kwargs):
"""
Executes a raw SQL statement.
@ -123,32 +121,34 @@ class DbPlugin(Plugin):
with engine.connect() as connection:
connection.execute(text(statement))
def _get_table(self, table, engine=None, *args, **kwargs):
def _get_table(self, table: str, engine=None, *args, **kwargs):
if not engine:
engine = self.get_engine(engine, *args, **kwargs)
db_ok = False
n_tries = 0
last_error = None
table_ = None
while not db_ok and n_tries < self._db_error_retries:
try:
n_tries += 1
metadata = MetaData()
table = Table(table, metadata, autoload_with=engine)
table_ = Table(table, metadata, autoload_with=engine)
db_ok = True
except Exception as e:
last_error = e
wait_time = self._db_error_wait_interval * n_tries
self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
self.logger.info('Waiting %s seconds before retrying', wait_time)
time.sleep(wait_time)
engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error:
raise last_error
return table, engine
assert table_, f'No such table: {table}'
return table_, engine
@action
def select(
@ -385,8 +385,9 @@ class DbPlugin(Plugin):
query = table.select().where(
or_(
and_(
self._build_condition(table, k, record.get(k)) for k in key_columns
and_( # type: ignore
self._build_condition(table, k, record.get(k))
for k in key_columns # type: ignore
)
for record in records
)
@ -520,16 +521,16 @@ class DbPlugin(Plugin):
with engine.connect() as connection, conn_begin(connection):
for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table.delete()
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table_.delete()
for k, v in record.items():
delete = delete.where(self._build_condition(table, k, v))
delete = delete.where(self._build_condition(table_, k, v))
connection.execute(delete)
def create_all(self, engine, base):
with self.get_session(engine, locked=True) as session, session.begin():
with self.get_session(engine, locked=True) as session:
base.metadata.create_all(session.connection())
@contextmanager