diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index 8cb0f89bb..e419f1573 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -7,7 +7,8 @@ from typing import Optional from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.engine import Engine -from sqlalchemy.sql import text +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import and_, or_, text from platypush.plugins import Plugin, action @@ -251,7 +252,9 @@ class DbPlugin(Plugin): :type key_columns: list :param on_duplicate_update: If set, update the records in case of duplicate rows (default: False). If set, you'll need to specify - ``key_columns`` as well. + ``key_columns`` as well. If ``key_columns`` is set, existing + records are found but ``on_duplicate_update`` is false, then + existing records will be ignored. :type on_duplicate_update: bool :param args: Extra arguments that will be passed to ``sqlalchemy.create_engine`` (see @@ -260,6 +263,9 @@ class DbPlugin(Plugin): ``sqlalchemy.create_engine`` (see https:///docs.sqlalchemy.org/en/latest/core/engines.html) + :return: The inserted records, if the underlying engine supports the + ``RETURNING`` statement, otherwise nothing. + Example: Request:: @@ -290,26 +296,98 @@ class DbPlugin(Plugin): key_columns = [] engine = self._get_engine(engine, *args, **kwargs) + table, engine = self._get_table(table, engine=engine, *args, **kwargs) + insert_records = records + update_records = [] + returned_records = [] + with engine.connect() as connection: + # Upsert case + if key_columns: + insert_records, update_records = self._get_new_and_existing_records( + connection, table, records, key_columns + ) + + with connection.begin(): + if insert_records: + insert = table.insert().values(insert_records) + ret = self._execute_try_returning(connection, insert) + if ret: + returned_records += ret + + if update_records and on_duplicate_update: + ret = self._update(connection, table, update_records, key_columns) + if ret: + returned_records = ret + returned_records + + if returned_records: + return returned_records + + @staticmethod + def _execute_try_returning(connection, stmt): + ret = None + stmt_with_ret = stmt.returning('*') + + try: + ret = connection.execute(stmt_with_ret) + except CompileError as e: + if str(e).startswith('RETURNING is not supported'): + connection.execute(stmt) + else: + raise e + + if ret: + return [ + {col.name: getattr(row, col.name, None) for col in stmt.table.c} + for row in ret + ] + + def _get_new_and_existing_records(self, connection, table, records, key_columns): + records_by_key = { + tuple(record.get(k) for k in key_columns): record for record in records + } + + query = table.select().where( + or_( + and_( + self._build_condition(table, k, record.get(k)) for k in key_columns + ) + for record in records + ) + ) + + existing_records = { + tuple(getattr(record, k, None) for k in key_columns): record + for record in connection.execute(query).all() + } + + update_records = [ + record for k, record in records_by_key.items() if k in existing_records + ] + + insert_records = [ + record for k, record in records_by_key.items() if k not in existing_records + ] + + return insert_records, update_records + + def _update(self, connection, table, records, key_columns): + updated_records = [] for record in records: - table, engine = self._get_table(table, engine=engine, *args, **kwargs) + key = {k: v for (k, v) in record.items() if k in key_columns} + values = {k: v for (k, v) in record.items() if k not in key_columns} + update = table.update() - insert = table.insert().values(**record) + for (k, v) in key.items(): + update = update.where(self._build_condition(table, k, v)) - try: - engine.execute(insert) - except Exception as e: - if on_duplicate_update and key_columns: - self.update( - table=table, - records=records, - key_columns=key_columns, - engine=engine, - *args, - **kwargs - ) - else: - raise e + update = update.values(**values) + ret = self._execute_try_returning(connection, update) + if ret: + updated_records += ret + + if updated_records: + return updated_records @action def update(self, table, records, key_columns, engine=None, *args, **kwargs): @@ -331,6 +409,9 @@ class DbPlugin(Plugin): ``sqlalchemy.create_engine`` (see https:///docs.sqlalchemy.org/en/latest/core/engines.html) + :return: The inserted records, if the underlying engine supports the + ``RETURNING`` statement, otherwise nothing. + Example: Request:: @@ -357,21 +438,10 @@ class DbPlugin(Plugin): } } """ - engine = self._get_engine(engine, *args, **kwargs) - - for record in records: + with engine.connect() as connection: table, engine = self._get_table(table, engine=engine, *args, **kwargs) - key = {k: v for (k, v) in record.items() if k in key_columns} - values = {k: v for (k, v) in record.items() if k not in key_columns} - - update = table.update() - - for (k, v) in key.items(): - update = update.where(self._build_condition(table, k, v)) - - update = update.values(**values) - engine.execute(update) + return self._update(connection, table, records, key_columns) @action def delete(self, table, records, engine=None, *args, **kwargs): @@ -412,14 +482,15 @@ class DbPlugin(Plugin): engine = self._get_engine(engine, *args, **kwargs) - for record in records: - table, engine = self._get_table(table, engine=engine, *args, **kwargs) - delete = table.delete() + with engine.connect() as connection, connection.begin(): + for record in records: + table, engine = self._get_table(table, engine=engine, *args, **kwargs) + delete = table.delete() - for (k, v) in record.items(): - delete = delete.where(self._build_condition(table, k, v)) + for (k, v) in record.items(): + delete = delete.where(self._build_condition(table, k, v)) - engine.execute(delete) + connection.execute(delete) # vim:sw=4:ts=4:et: