From d33d760361375b8f00c069342f1c7fdeaa3ac6b1 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 24 Apr 2023 23:21:39 +0200 Subject: [PATCH 1/2] Better way to import `declarative_base` from SQLAlchemy. Import `declarative_base` in a way that is compatible with any SQLAlchemy version between 1.3 and 2.x. --- platypush/backend/covid19/__init__.py | 3 ++- platypush/backend/github/__init__.py | 3 ++- platypush/backend/http/request/rss/__init__.py | 3 ++- platypush/backend/mail/__init__.py | 4 +++- platypush/common/db.py | 9 ++++++++- platypush/plugins/media/search/local.py | 3 ++- 6 files changed, 19 insertions(+), 6 deletions(-) diff --git a/platypush/backend/covid19/__init__.py b/platypush/backend/covid19/__init__.py index 1be1db01..5f744497 100644 --- a/platypush/backend/covid19/__init__.py +++ b/platypush/backend/covid19/__init__.py @@ -3,9 +3,10 @@ import os from typing import Optional, Union, List, Dict, Any from sqlalchemy import create_engine, Column, Integer, String, DateTime -from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session from platypush.backend import Backend +from platypush.common.db import declarative_base from platypush.config import Config from platypush.context import get_plugin from platypush.message.event.covid19 import Covid19UpdateEvent diff --git a/platypush/backend/github/__init__.py b/platypush/backend/github/__init__.py index 0a1bc3e6..6922db39 100644 --- a/platypush/backend/github/__init__.py +++ b/platypush/backend/github/__init__.py @@ -6,9 +6,10 @@ from typing import Optional, List import requests from sqlalchemy import create_engine, Column, String, DateTime -from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session from platypush.backend import Backend +from platypush.common.db import declarative_base from platypush.config import Config from platypush.message.event.github import ( GithubPushEvent, diff --git a/platypush/backend/http/request/rss/__init__.py b/platypush/backend/http/request/rss/__init__.py index 7ca6d9c6..6c624ae4 100644 --- a/platypush/backend/http/request/rss/__init__.py +++ b/platypush/backend/http/request/rss/__init__.py @@ -12,10 +12,11 @@ from sqlalchemy import ( ForeignKey, ) -from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.sql.expression import func from platypush.backend.http.request import HttpRequest +from platypush.common.db import declarative_base from platypush.config import Config from platypush.context import get_plugin from platypush.message.event.http.rss import NewFeedEvent diff --git a/platypush/backend/mail/__init__.py b/platypush/backend/mail/__init__.py index 84f1f8ab..5dd4474f 100644 --- a/platypush/backend/mail/__init__.py +++ b/platypush/backend/mail/__init__.py @@ -9,9 +9,10 @@ from threading import Thread, RLock from typing import List, Dict, Any, Optional, Tuple from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime -from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session from platypush.backend import Backend +from platypush.common.db import declarative_base from platypush.config import Config from platypush.context import get_plugin from platypush.message.event.mail import ( @@ -40,6 +41,7 @@ class MailboxStatus(Base): # + # @dataclass class Mailbox: diff --git a/platypush/common/db.py b/platypush/common/db.py index 59be7030..f1979e3b 100644 --- a/platypush/common/db.py +++ b/platypush/common/db.py @@ -1,3 +1,10 @@ -from sqlalchemy.orm import declarative_base +from sqlalchemy import __version__ + +sa_version = tuple(map(int, __version__.split('.'))) + +if sa_version >= (1, 4, 0): + from sqlalchemy.orm import declarative_base +else: + from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() diff --git a/platypush/plugins/media/search/local.py b/platypush/plugins/media/search/local.py index 76868aa4..21603dc1 100644 --- a/platypush/plugins/media/search/local.py +++ b/platypush/plugins/media/search/local.py @@ -12,9 +12,10 @@ from sqlalchemy import ( PrimaryKeyConstraint, ForeignKey, ) -from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.sql.expression import func +from platypush.common.db import declarative_base from platypush.config import Config from platypush.plugins.media import MediaPlugin from platypush.plugins.media.search import MediaSearcher From f4e13d0cb06dc6173c4ef3ed2d8cf29e7b60956b Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 24 Apr 2023 23:55:50 +0200 Subject: [PATCH 2/2] No need for `session.begin` in `db.create_all`. --- platypush/plugins/db/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index e777ca16..a1d4c51a 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -139,7 +139,7 @@ class DbPlugin(Plugin): engine=None, data: Optional[dict] = None, *args, - **kwargs + **kwargs, ): """ Returns rows (as a list of hashes) given a query. @@ -219,7 +219,7 @@ class DbPlugin(Plugin): query = table.select() if filter: - for (k, v) in filter.items(): + for k, v in filter.items(): query = query.where(self._build_condition(table, k, v)) if query is None: @@ -246,7 +246,7 @@ class DbPlugin(Plugin): key_columns=None, on_duplicate_update=False, *args, - **kwargs + **kwargs, ): """ Inserts records (as a list of hashes) into a table. @@ -394,7 +394,7 @@ class DbPlugin(Plugin): values = {k: v for (k, v) in record.items() if k not in key_columns} update = table.update() - for (k, v) in key.items(): + for k, v in key.items(): update = update.where(self._build_condition(table, k, v)) update = update.values(**values) @@ -503,13 +503,13 @@ class DbPlugin(Plugin): table, engine = self._get_table(table, engine=engine, *args, **kwargs) delete = table.delete() - for (k, v) in record.items(): + for k, v in record.items(): delete = delete.where(self._build_condition(table, k, v)) connection.execute(delete) def create_all(self, engine, base): - with (self.get_session(engine, locked=True) as session, session.begin()): + with self.get_session(engine, locked=True) as session: base.metadata.create_all(session.connection()) @contextmanager @@ -523,7 +523,7 @@ class DbPlugin(Plugin): # Mock lock lock = RLock() - with (lock, engine.connect() as conn, conn.begin()): + with lock, engine.connect() as conn, conn.begin(): session = scoped_session( sessionmaker( expire_on_commit=False,