[core] Fix support for custom SQLAlchemy engine options on db conf.

Earlier any extra parameters passed to the `db` configuration other than
`engine` where ignored.

This enables engine-level configurations such as:

```yaml
db:
  # Display all SQL queries
  echo: true
```
This commit is contained in:
Fabio Manganiello 2024-08-31 21:55:19 +02:00
parent 740e35bd5e
commit a3eedc6adc
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 18 additions and 7 deletions

View file

@ -179,6 +179,8 @@ class Config:
self._config['logging'] = logging_config self._config['logging'] = logging_config
def _init_db(self, db: Optional[str] = None): def _init_db(self, db: Optional[str] = None):
self._config['_db'] = self._config.get('db', {})
# If the db connection string is passed as an argument, use it # If the db connection string is passed as an argument, use it
if db: if db:
self._config['db'] = { self._config['db'] = {

View file

@ -40,8 +40,13 @@ class DbPlugin(Plugin):
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html) (see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
""" """
super().__init__() from platypush.config import Config
self.engine_url = engine
kwargs.update(Config.get('_db', {}))
super().__init__(*args, **kwargs)
self.engine_url = engine or kwargs.pop('engine', None)
self.args = args
self.kwargs = kwargs
self.engine = self.get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, *args, **kwargs)
def get_engine( def get_engine(
@ -50,6 +55,10 @@ class DbPlugin(Plugin):
if engine == self.engine_url and self.engine: if engine == self.engine_url and self.engine:
return self.engine return self.engine
if not args:
args = self.args
kwargs = {**self.kwargs, **kwargs}
if engine or not self.engine: if engine or not self.engine:
if isinstance(engine, Engine): if isinstance(engine, Engine):
return engine return engine
@ -213,7 +222,7 @@ class DbPlugin(Plugin):
query = text(query) query = text(query)
if table: if table:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, *args, engine=engine, **kwargs)
query = table.select() query = table.select()
if filter: if filter:
@ -240,10 +249,10 @@ class DbPlugin(Plugin):
self, self,
table, table,
records, records,
*args,
engine=None, engine=None,
key_columns=None, key_columns=None,
on_duplicate_update=False, on_duplicate_update=False,
*args,
**kwargs, **kwargs,
): ):
""" """
@ -310,7 +319,7 @@ class DbPlugin(Plugin):
key_columns = [] key_columns = []
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, *args, engine=engine, **kwargs)
insert_records = records insert_records = records
update_records = [] update_records = []
returned_records = [] returned_records = []
@ -454,7 +463,7 @@ class DbPlugin(Plugin):
""" """
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection: with engine.connect() as connection:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, *args, engine=engine, **kwargs)
return self._update(connection, table, records, key_columns) return self._update(connection, table, records, key_columns)
@action @action
@ -498,7 +507,7 @@ class DbPlugin(Plugin):
with engine.connect() as connection: with engine.connect() as connection:
for record in records: for record in records:
table_, engine = self._get_table(table, engine=engine, *args, **kwargs) table_, engine = self._get_table(table, *args, engine=engine, **kwargs)
delete = table_.delete() delete = table_.delete()
for k, v in record.items(): for k, v in record.items():