Merge pull request 'Fixed compatibility with SQLAlchemy >= 2.0' (#250) from 239-sqlalchemy-2-compatibility into master

Reviewed-on: platypush/platypush#250
Closes: #239
This commit is contained in:
Fabio Manganiello 2023-04-25 10:47:27 +02:00
commit 99382e4505
5 changed files with 97 additions and 93 deletions

View file

@ -25,7 +25,7 @@ class DbPlugin(Plugin):
_db_error_wait_interval = 5.0
_db_error_retries = 3
def __init__(self, engine=None, *args, **kwargs):
def __init__(self, engine=None, **kwargs):
"""
:param engine: Default SQLAlchemy connection engine string (e.g.
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
@ -42,7 +42,7 @@ class DbPlugin(Plugin):
super().__init__()
self.engine_url = engine
self.engine = self.get_engine(engine, *args, **kwargs)
self.engine = self.get_engine(engine, **kwargs)
def get_engine(
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
@ -66,16 +66,14 @@ class DbPlugin(Plugin):
return self.engine
@staticmethod
def _build_condition(table, column, value): # type: ignore
if isinstance(value, str):
value = "'{}'".format(value)
elif not isinstance(value, int) and not isinstance(value, float):
value = "'{}'".format(str(value))
def _build_condition(table, column, value): # pylint: disable=unused-argument
if isinstance(value, str) or not isinstance(value, (int, float)):
value = f"'{value}'"
return eval('table.c.{}=={}'.format(column, value))
return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
@action
def execute(self, statement, engine=None, *args, **kwargs):
def execute(self, statement, *args, engine=None, **kwargs):
"""
Executes a raw SQL statement.
@ -98,37 +96,37 @@ class DbPlugin(Plugin):
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
"""
engine = self.get_engine(engine, *args, **kwargs)
with self.get_engine(engine, *args, **kwargs).connect() as connection:
connection.execute(text(statement))
with engine.connect() as connection:
connection.execute(statement)
def _get_table(self, table, engine=None, *args, **kwargs):
def _get_table(self, table: str, *args, engine=None, **kwargs):
if not engine:
engine = self.get_engine(engine, *args, **kwargs)
db_ok = False
n_tries = 0
last_error = None
table_ = None
while not db_ok and n_tries < self._db_error_retries:
try:
n_tries += 1
metadata = MetaData()
table = Table(table, metadata, autoload=True, autoload_with=engine)
table_ = Table(table, metadata, autoload_with=engine)
db_ok = True
except Exception as e:
last_error = e
wait_time = self._db_error_wait_interval * n_tries
self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
self.logger.info('Waiting %s seconds before retrying', wait_time)
time.sleep(wait_time)
engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error:
raise last_error
return table, engine
assert table_, f'No such table: {table}'
return table_, engine
@action
def select(
@ -324,7 +322,6 @@ class DbPlugin(Plugin):
connection, table, records, key_columns
)
with connection.begin():
if insert_records:
insert = table.insert().values(insert_records)
ret = self._execute_try_returning(connection, insert)
@ -365,8 +362,9 @@ class DbPlugin(Plugin):
query = table.select().where(
or_(
and_(
self._build_condition(table, k, record.get(k)) for k in key_columns
and_( # type: ignore
self._build_condition(table, k, record.get(k))
for k in key_columns # type: ignore
)
for record in records
)
@ -498,13 +496,13 @@ class DbPlugin(Plugin):
engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection, connection.begin():
with engine.connect() as connection:
for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table.delete()
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
delete = table_.delete()
for k, v in record.items():
delete = delete.where(self._build_condition(table, k, v))
delete = delete.where(self._build_condition(table_, k, v))
connection.execute(delete)
@ -514,7 +512,7 @@ class DbPlugin(Plugin):
@contextmanager
def get_session(
self, engine=None, locked=False, autoflush=True, *args, **kwargs
self, *args, engine=None, locked=False, autoflush=True, **kwargs
) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs)
if locked:
@ -523,16 +521,20 @@ class DbPlugin(Plugin):
# Mock lock
lock = RLock()
with lock, engine.connect() as conn, conn.begin():
session = scoped_session(
with lock, engine.connect() as conn:
session_maker = scoped_session(
sessionmaker(
expire_on_commit=False,
autoflush=autoflush,
)
)
session.configure(bind=conn)
yield session()
session_maker.configure(bind=conn)
session = session_maker()
yield session
session.flush()
session.commit()
# vim:sw=4:ts=4:et:

View file

@ -3,7 +3,7 @@ from threading import Thread
from time import time
from typing import Optional, Any, Collection, Mapping
from sqlalchemy import or_
from sqlalchemy import or_, text
from sqlalchemy.orm import make_transient, Session
from platypush.config import Config
@ -198,7 +198,7 @@ class EntitiesPlugin(Plugin):
if str(session.connection().engine.url).startswith('sqlite://'):
# SQLite requires foreign_keys to be explicitly enabled
# in order to proper manage cascade deletions
session.execute('PRAGMA foreign_keys = ON')
session.execute(text('PRAGMA foreign_keys = ON'))
entities: Collection[Entity] = (
session.query(Entity).filter(Entity.id.in_(entities)).all()

View file

@ -1,6 +1,21 @@
from platypush.config import Config
from sqlalchemy import Column, String
from platypush.common.db import declarative_base
from platypush.context import get_plugin
from platypush.plugins import Plugin, action
from platypush.plugins.db import DbPlugin
Base = declarative_base()
# pylint: disable=too-few-public-methods
class Variable(Base):
"""Models the variable table"""
__tablename__ = 'variable'
name = Column(String, primary_key=True, nullable=False)
value = Column(String)
class VariablePlugin(Plugin):
@ -11,8 +26,6 @@ class VariablePlugin(Plugin):
will be stored either persisted on a local database or on the local Redis instance.
"""
_variable_table_name = 'variable'
def __init__(self, **kwargs):
"""
The plugin will create a table named ``variable`` on the database
@ -21,24 +34,14 @@ class VariablePlugin(Plugin):
"""
super().__init__(**kwargs)
self.db_plugin = get_plugin('db')
self.redis_plugin = get_plugin('redis')
db_plugin = get_plugin('db')
redis_plugin = get_plugin('redis')
assert db_plugin, 'Database plugin not configured'
assert redis_plugin, 'Redis plugin not configured'
db = Config.get('db')
self.db_config = {
'engine': db.get('engine'),
'args': db.get('args', []),
'kwargs': db.get('kwargs', {})
}
self._create_tables()
# self._variables = {}
def _create_tables(self):
self.db_plugin.execute("""CREATE TABLE IF NOT EXISTS {}(
name varchar(255) not null primary key,
value text
)""".format(self._variable_table_name))
self.redis_plugin = redis_plugin
self.db_plugin: DbPlugin = db_plugin
self.db_plugin.create_all(self.db_plugin.get_engine(), Base)
@action
def get(self, name, default_value=None):
@ -53,13 +56,10 @@ class VariablePlugin(Plugin):
:returns: A map in the format ``{"<name>":"<value>"}``
"""
rows = self.db_plugin.select(table=self._variable_table_name,
filter={'name': name},
engine=self.db_config['engine'],
*self.db_config['args'],
**self.db_config['kwargs']).output
with self.db_plugin.get_session() as session:
var = session.query(Variable).filter_by(name=name).first()
return {name: rows[0]['value'] if rows else default_value}
return {name: (var.value if var is not None else default_value)}
@action
def set(self, **kwargs):
@ -69,15 +69,24 @@ class VariablePlugin(Plugin):
:param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``)
"""
records = [{'name': k, 'value': v}
for (k, v) in kwargs.items()]
with self.db_plugin.get_session() as session:
existing_vars = {
var.name: var
for var in session.query(Variable)
.filter(Variable.name.in_(kwargs.keys()))
.all()
}
self.db_plugin.insert(table=self._variable_table_name,
records=records, key_columns=['name'],
engine=self.db_config['engine'],
on_duplicate_update=True,
*self.db_config['args'],
**self.db_config['kwargs'])
new_vars = {
name: Variable(name=name, value=value)
for name, value in kwargs.items()
if name not in existing_vars
}
for name, var in existing_vars.items():
var.value = kwargs[name] # type: ignore
session.add_all([*existing_vars.values(), *new_vars.values()])
return kwargs
@ -90,12 +99,8 @@ class VariablePlugin(Plugin):
:type name: str
"""
records = [{'name': name}]
self.db_plugin.delete(table=self._variable_table_name,
records=records, engine=self.db_config['engine'],
*self.db_config['args'],
**self.db_config['kwargs'])
with self.db_plugin.get_session() as session:
session.query(Variable).filter_by(name=name).delete()
return True
@ -150,4 +155,5 @@ class VariablePlugin(Plugin):
return self.redis_plugin.expire(name, expire)
# vim:sw=4:ts=4:et:

View file

@ -3,11 +3,11 @@ import datetime
import hashlib
import json
import random
import rsa
import time
from typing import Optional, Dict
import bcrypt
import rsa
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient
@ -73,7 +73,7 @@ class UserManager:
username=username,
password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(),
**kwargs
**kwargs,
)
session.add(record)
@ -238,9 +238,7 @@ class UserManager:
indent=None,
)
return base64.b64encode(
rsa.encrypt(payload.encode('ascii'), pub_key)
).decode()
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
def validate_jwt_token(self, token: str) -> Dict[str, str]:
"""
@ -263,21 +261,19 @@ class UserManager:
try:
payload = json.loads(
rsa.decrypt(
base64.b64decode(token.encode('ascii')),
priv_key
).decode('ascii')
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
'ascii'
)
)
except (TypeError, ValueError) as e:
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at:
raise InvalidJWTTokenException('Expired JWT token')
user = self.authenticate_user(
payload.get('username', ''),
payload.get('password', '')
payload.get('username', ''), payload.get('password', '')
)
if not user:

View file

@ -57,7 +57,7 @@ setup(
'redis',
'requests',
'croniter',
'sqlalchemy<2.0.0',
'sqlalchemy',
'websockets',
'websocket-client',
'wheel',