Merge pull request 'Fixed compatibility with SQLAlchemy >= 2.0' (#250) from 239-sqlalchemy-2-compatibility into master

Reviewed-on: platypush/platypush#250
Closes: #239
This commit is contained in:
Fabio Manganiello 2023-04-25 10:47:27 +02:00
commit 99382e4505
5 changed files with 97 additions and 93 deletions

View file

@ -25,7 +25,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0 _db_error_wait_interval = 5.0
_db_error_retries = 3 _db_error_retries = 3
def __init__(self, engine=None, *args, **kwargs): def __init__(self, engine=None, **kwargs):
""" """
:param engine: Default SQLAlchemy connection engine string (e.g. :param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``) ``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -42,7 +42,7 @@ class DbPlugin(Plugin):
super().__init__() super().__init__()
self.engine_url = engine self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, **kwargs)
def get_engine( def get_engine(
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
@ -66,16 +66,14 @@ class DbPlugin(Plugin):
return self.engine return self.engine
@staticmethod @staticmethod
def _build_condition(table, column, value): # type: ignore def _build_condition(table, column, value): # pylint: disable=unused-argument
if isinstance(value, str): if isinstance(value, str) or not isinstance(value, (int, float)):
value = "'{}'".format(value) value = f"'{value}'"
elif not isinstance(value, int) and not isinstance(value, float):
value = "'{}'".format(str(value))
return eval('table.c.{}=={}'.format(column, value)) return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
@action @action
def execute(self, statement, engine=None, *args, **kwargs): def execute(self, statement, *args, engine=None, **kwargs):
""" """
Executes a raw SQL statement. Executes a raw SQL statement.
@ -98,37 +96,37 @@ class DbPlugin(Plugin):
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html) (see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
""" """
engine = self.get_engine(engine, *args, **kwargs) with self.get_engine(engine, *args, **kwargs).connect() as connection:
connection.execute(text(statement))
with engine.connect() as connection: def _get_table(self, table: str, *args, engine=None, **kwargs):
connection.execute(statement)
def _get_table(self, table, engine=None, *args, **kwargs):
if not engine: if not engine:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
db_ok = False db_ok = False
n_tries = 0 n_tries = 0
last_error = None last_error = None
table_ = None
while not db_ok and n_tries < self._db_error_retries: while not db_ok and n_tries < self._db_error_retries:
try: try:
n_tries += 1 n_tries += 1
metadata = MetaData() metadata = MetaData()
table = Table(table, metadata, autoload=True, autoload_with=engine) table_ = Table(table, metadata, autoload_with=engine)
db_ok = True db_ok = True
except Exception as e: except Exception as e:
last_error = e last_error = e
wait_time = self._db_error_wait_interval * n_tries wait_time = self._db_error_wait_interval * n_tries
self.logger.exception(e) self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time)) self.logger.info('Waiting %s seconds before retrying', wait_time)
time.sleep(wait_time) time.sleep(wait_time)
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error: if not db_ok and last_error:
raise last_error raise last_error
return table, engine assert table_, f'No such table: {table}'
return table_, engine
@action @action
def select( def select(
@ -324,7 +322,6 @@ class DbPlugin(Plugin):
connection, table, records, key_columns connection, table, records, key_columns
) )
with connection.begin():
if insert_records: if insert_records:
insert = table.insert().values(insert_records) insert = table.insert().values(insert_records)
ret = self._execute_try_returning(connection, insert) ret = self._execute_try_returning(connection, insert)
@ -365,8 +362,9 @@ class DbPlugin(Plugin):
query = table.select().where( query = table.select().where(
or_( or_(
and_( and_( # type: ignore
self._build_condition(table, k, record.get(k)) for k in key_columns self._build_condition(table, k, record.get(k))
for k in key_columns # type: ignore
) )
for record in records for record in records
) )
@ -498,13 +496,13 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection, connection.begin(): with engine.connect() as connection:
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table.delete() delete = table_.delete()
for k, v in record.items(): for k, v in record.items():
delete = delete.where(self._build_condition(table, k, v)) delete = delete.where(self._build_condition(table_, k, v))
connection.execute(delete) connection.execute(delete)
@ -514,7 +512,7 @@ class DbPlugin(Plugin):
@contextmanager @contextmanager
def get_session( def get_session(
self, engine=None, locked=False, autoflush=True, *args, **kwargs self, *args, engine=None, locked=False, autoflush=True, **kwargs
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if locked: if locked:
@ -523,16 +521,20 @@ class DbPlugin(Plugin):
# Mock lock # Mock lock
lock = RLock() lock = RLock()
with lock, engine.connect() as conn, conn.begin(): with lock, engine.connect() as conn:
session = scoped_session( session_maker = scoped_session(
sessionmaker( sessionmaker(
expire_on_commit=False, expire_on_commit=False,
autoflush=autoflush, autoflush=autoflush,
) )
) )
session.configure(bind=conn) session_maker.configure(bind=conn)
yield session() session = session_maker()
yield session
session.flush()
session.commit()
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -3,7 +3,7 @@ from threading import Thread
from time import time from time import time
from typing import Optional, Any, Collection, Mapping from typing import Optional, Any, Collection, Mapping
from sqlalchemy import or_ from sqlalchemy import or_, text
from sqlalchemy.orm import make_transient, Session from sqlalchemy.orm import make_transient, Session
from platypush.config import Config from platypush.config import Config
@ -198,7 +198,7 @@ class EntitiesPlugin(Plugin):
if str(session.connection().engine.url).startswith('sqlite://'): if str(session.connection().engine.url).startswith('sqlite://'):
# SQLite requires foreign_keys to be explicitly enabled # SQLite requires foreign_keys to be explicitly enabled
# in order to proper manage cascade deletions # in order to proper manage cascade deletions
session.execute('PRAGMA foreign_keys = ON') session.execute(text('PRAGMA foreign_keys = ON'))
entities: Collection[Entity] = ( entities: Collection[Entity] = (
session.query(Entity).filter(Entity.id.in_(entities)).all() session.query(Entity).filter(Entity.id.in_(entities)).all()

View file

@ -1,6 +1,21 @@
from platypush.config import Config from sqlalchemy import Column, String
from platypush.common.db import declarative_base
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
from platypush.plugins.db import DbPlugin
Base = declarative_base()
# pylint: disable=too-few-public-methods
class Variable(Base):
"""Models the variable table"""
__tablename__ = 'variable'
name = Column(String, primary_key=True, nullable=False)
value = Column(String)
class VariablePlugin(Plugin): class VariablePlugin(Plugin):
@ -11,8 +26,6 @@ class VariablePlugin(Plugin):
will be stored either persisted on a local database or on the local Redis instance. will be stored either persisted on a local database or on the local Redis instance.
""" """
_variable_table_name = 'variable'
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
The plugin will create a table named ``variable`` on the database The plugin will create a table named ``variable`` on the database
@ -21,24 +34,14 @@ class VariablePlugin(Plugin):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self.db_plugin = get_plugin('db') db_plugin = get_plugin('db')
self.redis_plugin = get_plugin('redis') redis_plugin = get_plugin('redis')
assert db_plugin, 'Database plugin not configured'
assert redis_plugin, 'Redis plugin not configured'
db = Config.get('db') self.redis_plugin = redis_plugin
self.db_config = { self.db_plugin: DbPlugin = db_plugin
'engine': db.get('engine'), self.db_plugin.create_all(self.db_plugin.get_engine(), Base)
'args': db.get('args', []),
'kwargs': db.get('kwargs', {})
}
self._create_tables()
# self._variables = {}
def _create_tables(self):
self.db_plugin.execute("""CREATE TABLE IF NOT EXISTS {}(
name varchar(255) not null primary key,
value text
)""".format(self._variable_table_name))
@action @action
def get(self, name, default_value=None): def get(self, name, default_value=None):
@ -53,13 +56,10 @@ class VariablePlugin(Plugin):
:returns: A map in the format ``{"<name>":"<value>"}`` :returns: A map in the format ``{"<name>":"<value>"}``
""" """
rows = self.db_plugin.select(table=self._variable_table_name, with self.db_plugin.get_session() as session:
filter={'name': name}, var = session.query(Variable).filter_by(name=name).first()
engine=self.db_config['engine'],
*self.db_config['args'],
**self.db_config['kwargs']).output
return {name: rows[0]['value'] if rows else default_value} return {name: (var.value if var is not None else default_value)}
@action @action
def set(self, **kwargs): def set(self, **kwargs):
@ -69,15 +69,24 @@ class VariablePlugin(Plugin):
:param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``) :param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``)
""" """
records = [{'name': k, 'value': v} with self.db_plugin.get_session() as session:
for (k, v) in kwargs.items()] existing_vars = {
var.name: var
for var in session.query(Variable)
.filter(Variable.name.in_(kwargs.keys()))
.all()
}
self.db_plugin.insert(table=self._variable_table_name, new_vars = {
records=records, key_columns=['name'], name: Variable(name=name, value=value)
engine=self.db_config['engine'], for name, value in kwargs.items()
on_duplicate_update=True, if name not in existing_vars
*self.db_config['args'], }
**self.db_config['kwargs'])
for name, var in existing_vars.items():
var.value = kwargs[name] # type: ignore
session.add_all([*existing_vars.values(), *new_vars.values()])
return kwargs return kwargs
@ -90,12 +99,8 @@ class VariablePlugin(Plugin):
:type name: str :type name: str
""" """
records = [{'name': name}] with self.db_plugin.get_session() as session:
session.query(Variable).filter_by(name=name).delete()
self.db_plugin.delete(table=self._variable_table_name,
records=records, engine=self.db_config['engine'],
*self.db_config['args'],
**self.db_config['kwargs'])
return True return True
@ -150,4 +155,5 @@ class VariablePlugin(Plugin):
return self.redis_plugin.expire(name, expire) return self.redis_plugin.expire(name, expire)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -3,11 +3,11 @@ import datetime
import hashlib import hashlib
import json import json
import random import random
import rsa
import time import time
from typing import Optional, Dict from typing import Optional, Dict
import bcrypt import bcrypt
import rsa
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient from sqlalchemy.orm import make_transient
@ -73,7 +73,7 @@ class UserManager:
username=username, username=username,
password=self._encrypt_password(password), password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(), created_at=datetime.datetime.utcnow(),
**kwargs **kwargs,
) )
session.add(record) session.add(record)
@ -238,9 +238,7 @@ class UserManager:
indent=None, indent=None,
) )
return base64.b64encode( return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
rsa.encrypt(payload.encode('ascii'), pub_key)
).decode()
def validate_jwt_token(self, token: str) -> Dict[str, str]: def validate_jwt_token(self, token: str) -> Dict[str, str]:
""" """
@ -263,21 +261,19 @@ class UserManager:
try: try:
payload = json.loads( payload = json.loads(
rsa.decrypt( rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
base64.b64decode(token.encode('ascii')), 'ascii'
priv_key )
).decode('ascii')
) )
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
expires_at = payload.get('expires_at') expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at: if expires_at and time.time() > expires_at:
raise InvalidJWTTokenException('Expired JWT token') raise InvalidJWTTokenException('Expired JWT token')
user = self.authenticate_user( user = self.authenticate_user(
payload.get('username', ''), payload.get('username', ''), payload.get('password', '')
payload.get('password', '')
) )
if not user: if not user:

View file

@ -57,7 +57,7 @@ setup(
'redis', 'redis',
'requests', 'requests',
'croniter', 'croniter',
'sqlalchemy<2.0.0', 'sqlalchemy',
'websockets', 'websockets',
'websocket-client', 'websocket-client',
'wheel', 'wheel',