Compare commits

..

2 commits

Author SHA1 Message Date
37722d12cd
No need for session.begin in db.create_all. 2023-04-24 23:55:50 +02:00
6fa179e769
LINT fixes 2023-04-24 23:49:31 +02:00

View file

@ -45,7 +45,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0 _db_error_wait_interval = 5.0
_db_error_retries = 3 _db_error_retries = 3
def __init__(self, engine=None, *args, **kwargs): def __init__(self, *args, engine=None, **kwargs):
""" """
:param engine: Default SQLAlchemy connection engine string (e.g. :param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``) ``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -86,16 +86,14 @@ class DbPlugin(Plugin):
return self.engine return self.engine
@staticmethod @staticmethod
def _build_condition(table, column, value): # type: ignore def _build_condition(table, column, value): # pylint: disable=unused-argument
if isinstance(value, str): if isinstance(value, str) or not isinstance(value, (int, float)):
value = "'{}'".format(value) value = f"'{value}'"
elif not isinstance(value, int) and not isinstance(value, float):
value = "'{}'".format(str(value))
return eval('table.c.{}=={}'.format(column, value)) return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
@action @action
def execute(self, statement, engine=None, *args, **kwargs): def execute(self, statement, *args, engine=None, **kwargs):
""" """
Executes a raw SQL statement. Executes a raw SQL statement.
@ -123,32 +121,34 @@ class DbPlugin(Plugin):
with engine.connect() as connection: with engine.connect() as connection:
connection.execute(text(statement)) connection.execute(text(statement))
def _get_table(self, table, engine=None, *args, **kwargs): def _get_table(self, table: str, engine=None, *args, **kwargs):
if not engine: if not engine:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
db_ok = False db_ok = False
n_tries = 0 n_tries = 0
last_error = None last_error = None
table_ = None
while not db_ok and n_tries < self._db_error_retries: while not db_ok and n_tries < self._db_error_retries:
try: try:
n_tries += 1 n_tries += 1
metadata = MetaData() metadata = MetaData()
table = Table(table, metadata, autoload_with=engine) table_ = Table(table, metadata, autoload_with=engine)
db_ok = True db_ok = True
except Exception as e: except Exception as e:
last_error = e last_error = e
wait_time = self._db_error_wait_interval * n_tries wait_time = self._db_error_wait_interval * n_tries
self.logger.exception(e) self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time)) self.logger.info('Waiting %s seconds before retrying', wait_time)
time.sleep(wait_time) time.sleep(wait_time)
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error: if not db_ok and last_error:
raise last_error raise last_error
return table, engine assert table_, f'No such table: {table}'
return table_, engine
@action @action
def select( def select(
@ -385,8 +385,9 @@ class DbPlugin(Plugin):
query = table.select().where( query = table.select().where(
or_( or_(
and_( and_( # type: ignore
self._build_condition(table, k, record.get(k)) for k in key_columns self._build_condition(table, k, record.get(k))
for k in key_columns # type: ignore
) )
for record in records for record in records
) )
@ -520,16 +521,16 @@ class DbPlugin(Plugin):
with engine.connect() as connection, conn_begin(connection): with engine.connect() as connection, conn_begin(connection):
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table.delete() delete = table_.delete()
for k, v in record.items(): for k, v in record.items():
delete = delete.where(self._build_condition(table, k, v)) delete = delete.where(self._build_condition(table_, k, v))
connection.execute(delete) connection.execute(delete)
def create_all(self, engine, base): def create_all(self, engine, base):
with self.get_session(engine, locked=True) as session, session.begin(): with self.get_session(engine, locked=True) as session:
base.metadata.create_all(session.connection()) base.metadata.create_all(session.connection())
@contextmanager @contextmanager