Compare commits
2 commits
91df18f7b5
...
37722d12cd
Author | SHA1 | Date | |
---|---|---|---|
37722d12cd | |||
6fa179e769 |
1 changed files with 19 additions and 18 deletions
|
@ -45,7 +45,7 @@ class DbPlugin(Plugin):
|
||||||
_db_error_wait_interval = 5.0
|
_db_error_wait_interval = 5.0
|
||||||
_db_error_retries = 3
|
_db_error_retries = 3
|
||||||
|
|
||||||
def __init__(self, engine=None, *args, **kwargs):
|
def __init__(self, *args, engine=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param engine: Default SQLAlchemy connection engine string (e.g.
|
:param engine: Default SQLAlchemy connection engine string (e.g.
|
||||||
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
||||||
|
@ -86,16 +86,14 @@ class DbPlugin(Plugin):
|
||||||
return self.engine
|
return self.engine
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_condition(table, column, value): # type: ignore
|
def _build_condition(table, column, value): # pylint: disable=unused-argument
|
||||||
if isinstance(value, str):
|
if isinstance(value, str) or not isinstance(value, (int, float)):
|
||||||
value = "'{}'".format(value)
|
value = f"'{value}'"
|
||||||
elif not isinstance(value, int) and not isinstance(value, float):
|
|
||||||
value = "'{}'".format(str(value))
|
|
||||||
|
|
||||||
return eval('table.c.{}=={}'.format(column, value))
|
return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def execute(self, statement, engine=None, *args, **kwargs):
|
def execute(self, statement, *args, engine=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Executes a raw SQL statement.
|
Executes a raw SQL statement.
|
||||||
|
|
||||||
|
@ -123,32 +121,34 @@ class DbPlugin(Plugin):
|
||||||
with engine.connect() as connection:
|
with engine.connect() as connection:
|
||||||
connection.execute(text(statement))
|
connection.execute(text(statement))
|
||||||
|
|
||||||
def _get_table(self, table, engine=None, *args, **kwargs):
|
def _get_table(self, table: str, engine=None, *args, **kwargs):
|
||||||
if not engine:
|
if not engine:
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
db_ok = False
|
db_ok = False
|
||||||
n_tries = 0
|
n_tries = 0
|
||||||
last_error = None
|
last_error = None
|
||||||
|
table_ = None
|
||||||
|
|
||||||
while not db_ok and n_tries < self._db_error_retries:
|
while not db_ok and n_tries < self._db_error_retries:
|
||||||
try:
|
try:
|
||||||
n_tries += 1
|
n_tries += 1
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
table = Table(table, metadata, autoload_with=engine)
|
table_ = Table(table, metadata, autoload_with=engine)
|
||||||
db_ok = True
|
db_ok = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
wait_time = self._db_error_wait_interval * n_tries
|
wait_time = self._db_error_wait_interval * n_tries
|
||||||
self.logger.exception(e)
|
self.logger.exception(e)
|
||||||
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
|
self.logger.info('Waiting %s seconds before retrying', wait_time)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
if not db_ok and last_error:
|
if not db_ok and last_error:
|
||||||
raise last_error
|
raise last_error
|
||||||
|
|
||||||
return table, engine
|
assert table_, f'No such table: {table}'
|
||||||
|
return table_, engine
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def select(
|
def select(
|
||||||
|
@ -385,8 +385,9 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
query = table.select().where(
|
query = table.select().where(
|
||||||
or_(
|
or_(
|
||||||
and_(
|
and_( # type: ignore
|
||||||
self._build_condition(table, k, record.get(k)) for k in key_columns
|
self._build_condition(table, k, record.get(k))
|
||||||
|
for k in key_columns # type: ignore
|
||||||
)
|
)
|
||||||
for record in records
|
for record in records
|
||||||
)
|
)
|
||||||
|
@ -520,16 +521,16 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
with engine.connect() as connection, conn_begin(connection):
|
with engine.connect() as connection, conn_begin(connection):
|
||||||
for record in records:
|
for record in records:
|
||||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||||
delete = table.delete()
|
delete = table_.delete()
|
||||||
|
|
||||||
for k, v in record.items():
|
for k, v in record.items():
|
||||||
delete = delete.where(self._build_condition(table, k, v))
|
delete = delete.where(self._build_condition(table_, k, v))
|
||||||
|
|
||||||
connection.execute(delete)
|
connection.execute(delete)
|
||||||
|
|
||||||
def create_all(self, engine, base):
|
def create_all(self, engine, base):
|
||||||
with self.get_session(engine, locked=True) as session, session.begin():
|
with self.get_session(engine, locked=True) as session:
|
||||||
base.metadata.create_all(session.connection())
|
base.metadata.create_all(session.connection())
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
Loading…
Reference in a new issue