diff --git a/platypush/config/__init__.py b/platypush/config/__init__.py index 945a47af85..3eaa7ab5a6 100644 --- a/platypush/config/__init__.py +++ b/platypush/config/__init__.py @@ -179,6 +179,8 @@ class Config: self._config['logging'] = logging_config def _init_db(self, db: Optional[str] = None): + self._config['_db'] = self._config.get('db', {}) + # If the db connection string is passed as an argument, use it if db: self._config['db'] = { diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index 5f32258c20..e8ac877665 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -40,8 +40,13 @@ class DbPlugin(Plugin): (see https:///docs.sqlalchemy.org/en/latest/core/engines.html) """ - super().__init__() - self.engine_url = engine + from platypush.config import Config + + kwargs.update(Config.get('_db', {})) + super().__init__(*args, **kwargs) + self.engine_url = engine or kwargs.pop('engine', None) + self.args = args + self.kwargs = kwargs self.engine = self.get_engine(engine, *args, **kwargs) def get_engine( @@ -50,6 +55,10 @@ class DbPlugin(Plugin): if engine == self.engine_url and self.engine: return self.engine + if not args: + args = self.args + kwargs = {**self.kwargs, **kwargs} + if engine or not self.engine: if isinstance(engine, Engine): return engine @@ -213,7 +222,7 @@ class DbPlugin(Plugin): query = text(query) if table: - table, engine = self._get_table(table, engine=engine, *args, **kwargs) + table, engine = self._get_table(table, *args, engine=engine, **kwargs) query = table.select() if filter: @@ -240,10 +249,10 @@ class DbPlugin(Plugin): self, table, records, + *args, engine=None, key_columns=None, on_duplicate_update=False, - *args, **kwargs, ): """ @@ -310,7 +319,7 @@ class DbPlugin(Plugin): key_columns = [] engine = self.get_engine(engine, *args, **kwargs) - table, engine = self._get_table(table, engine=engine, *args, **kwargs) + table, engine = self._get_table(table, *args, engine=engine, **kwargs) insert_records = records update_records = [] returned_records = [] @@ -454,7 +463,7 @@ class DbPlugin(Plugin): """ engine = self.get_engine(engine, *args, **kwargs) with engine.connect() as connection: - table, engine = self._get_table(table, engine=engine, *args, **kwargs) + table, engine = self._get_table(table, *args, engine=engine, **kwargs) return self._update(connection, table, records, key_columns) @action @@ -498,7 +507,7 @@ class DbPlugin(Plugin): with engine.connect() as connection: for record in records: - table_, engine = self._get_table(table, engine=engine, *args, **kwargs) + table_, engine = self._get_table(table, *args, engine=engine, **kwargs) delete = table_.delete() for k, v in record.items():