Merge pull request 'Fixed compatibility with SQLAlchemy >= 2.0' (#250) from 239-sqlalchemy-2-compatibility into master
Reviewed-on: #250 Closes: #239
This commit is contained in:
commit
99382e4505
5 changed files with 97 additions and 93 deletions
|
@ -25,7 +25,7 @@ class DbPlugin(Plugin):
|
||||||
_db_error_wait_interval = 5.0
|
_db_error_wait_interval = 5.0
|
||||||
_db_error_retries = 3
|
_db_error_retries = 3
|
||||||
|
|
||||||
def __init__(self, engine=None, *args, **kwargs):
|
def __init__(self, engine=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param engine: Default SQLAlchemy connection engine string (e.g.
|
:param engine: Default SQLAlchemy connection engine string (e.g.
|
||||||
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
||||||
|
@ -42,7 +42,7 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.engine_url = engine
|
self.engine_url = engine
|
||||||
self.engine = self.get_engine(engine, *args, **kwargs)
|
self.engine = self.get_engine(engine, **kwargs)
|
||||||
|
|
||||||
def get_engine(
|
def get_engine(
|
||||||
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
||||||
|
@ -66,16 +66,14 @@ class DbPlugin(Plugin):
|
||||||
return self.engine
|
return self.engine
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_condition(table, column, value): # type: ignore
|
def _build_condition(table, column, value): # pylint: disable=unused-argument
|
||||||
if isinstance(value, str):
|
if isinstance(value, str) or not isinstance(value, (int, float)):
|
||||||
value = "'{}'".format(value)
|
value = f"'{value}'"
|
||||||
elif not isinstance(value, int) and not isinstance(value, float):
|
|
||||||
value = "'{}'".format(str(value))
|
|
||||||
|
|
||||||
return eval('table.c.{}=={}'.format(column, value))
|
return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def execute(self, statement, engine=None, *args, **kwargs):
|
def execute(self, statement, *args, engine=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Executes a raw SQL statement.
|
Executes a raw SQL statement.
|
||||||
|
|
||||||
|
@ -98,37 +96,37 @@ class DbPlugin(Plugin):
|
||||||
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
|
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
with self.get_engine(engine, *args, **kwargs).connect() as connection:
|
||||||
|
connection.execute(text(statement))
|
||||||
|
|
||||||
with engine.connect() as connection:
|
def _get_table(self, table: str, *args, engine=None, **kwargs):
|
||||||
connection.execute(statement)
|
|
||||||
|
|
||||||
def _get_table(self, table, engine=None, *args, **kwargs):
|
|
||||||
if not engine:
|
if not engine:
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
db_ok = False
|
db_ok = False
|
||||||
n_tries = 0
|
n_tries = 0
|
||||||
last_error = None
|
last_error = None
|
||||||
|
table_ = None
|
||||||
|
|
||||||
while not db_ok and n_tries < self._db_error_retries:
|
while not db_ok and n_tries < self._db_error_retries:
|
||||||
try:
|
try:
|
||||||
n_tries += 1
|
n_tries += 1
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
table = Table(table, metadata, autoload=True, autoload_with=engine)
|
table_ = Table(table, metadata, autoload_with=engine)
|
||||||
db_ok = True
|
db_ok = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
wait_time = self._db_error_wait_interval * n_tries
|
wait_time = self._db_error_wait_interval * n_tries
|
||||||
self.logger.exception(e)
|
self.logger.exception(e)
|
||||||
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
|
self.logger.info('Waiting %s seconds before retrying', wait_time)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
if not db_ok and last_error:
|
if not db_ok and last_error:
|
||||||
raise last_error
|
raise last_error
|
||||||
|
|
||||||
return table, engine
|
assert table_, f'No such table: {table}'
|
||||||
|
return table_, engine
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def select(
|
def select(
|
||||||
|
@ -324,17 +322,16 @@ class DbPlugin(Plugin):
|
||||||
connection, table, records, key_columns
|
connection, table, records, key_columns
|
||||||
)
|
)
|
||||||
|
|
||||||
with connection.begin():
|
if insert_records:
|
||||||
if insert_records:
|
insert = table.insert().values(insert_records)
|
||||||
insert = table.insert().values(insert_records)
|
ret = self._execute_try_returning(connection, insert)
|
||||||
ret = self._execute_try_returning(connection, insert)
|
if ret:
|
||||||
if ret:
|
returned_records += ret
|
||||||
returned_records += ret
|
|
||||||
|
|
||||||
if update_records and on_duplicate_update:
|
if update_records and on_duplicate_update:
|
||||||
ret = self._update(connection, table, update_records, key_columns)
|
ret = self._update(connection, table, update_records, key_columns)
|
||||||
if ret:
|
if ret:
|
||||||
returned_records = ret + returned_records
|
returned_records = ret + returned_records
|
||||||
|
|
||||||
if returned_records:
|
if returned_records:
|
||||||
return returned_records
|
return returned_records
|
||||||
|
@ -365,8 +362,9 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
query = table.select().where(
|
query = table.select().where(
|
||||||
or_(
|
or_(
|
||||||
and_(
|
and_( # type: ignore
|
||||||
self._build_condition(table, k, record.get(k)) for k in key_columns
|
self._build_condition(table, k, record.get(k))
|
||||||
|
for k in key_columns # type: ignore
|
||||||
)
|
)
|
||||||
for record in records
|
for record in records
|
||||||
)
|
)
|
||||||
|
@ -498,13 +496,13 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
|
|
||||||
with engine.connect() as connection, connection.begin():
|
with engine.connect() as connection:
|
||||||
for record in records:
|
for record in records:
|
||||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||||
delete = table.delete()
|
delete = table_.delete()
|
||||||
|
|
||||||
for k, v in record.items():
|
for k, v in record.items():
|
||||||
delete = delete.where(self._build_condition(table, k, v))
|
delete = delete.where(self._build_condition(table_, k, v))
|
||||||
|
|
||||||
connection.execute(delete)
|
connection.execute(delete)
|
||||||
|
|
||||||
|
@ -514,7 +512,7 @@ class DbPlugin(Plugin):
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_session(
|
def get_session(
|
||||||
self, engine=None, locked=False, autoflush=True, *args, **kwargs
|
self, *args, engine=None, locked=False, autoflush=True, **kwargs
|
||||||
) -> Generator[Session, None, None]:
|
) -> Generator[Session, None, None]:
|
||||||
engine = self.get_engine(engine, *args, **kwargs)
|
engine = self.get_engine(engine, *args, **kwargs)
|
||||||
if locked:
|
if locked:
|
||||||
|
@ -523,16 +521,20 @@ class DbPlugin(Plugin):
|
||||||
# Mock lock
|
# Mock lock
|
||||||
lock = RLock()
|
lock = RLock()
|
||||||
|
|
||||||
with lock, engine.connect() as conn, conn.begin():
|
with lock, engine.connect() as conn:
|
||||||
session = scoped_session(
|
session_maker = scoped_session(
|
||||||
sessionmaker(
|
sessionmaker(
|
||||||
expire_on_commit=False,
|
expire_on_commit=False,
|
||||||
autoflush=autoflush,
|
autoflush=autoflush,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
session.configure(bind=conn)
|
session_maker.configure(bind=conn)
|
||||||
yield session()
|
session = session_maker()
|
||||||
|
yield session
|
||||||
|
|
||||||
|
session.flush()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -3,7 +3,7 @@ from threading import Thread
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Optional, Any, Collection, Mapping
|
from typing import Optional, Any, Collection, Mapping
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_, text
|
||||||
from sqlalchemy.orm import make_transient, Session
|
from sqlalchemy.orm import make_transient, Session
|
||||||
|
|
||||||
from platypush.config import Config
|
from platypush.config import Config
|
||||||
|
@ -198,7 +198,7 @@ class EntitiesPlugin(Plugin):
|
||||||
if str(session.connection().engine.url).startswith('sqlite://'):
|
if str(session.connection().engine.url).startswith('sqlite://'):
|
||||||
# SQLite requires foreign_keys to be explicitly enabled
|
# SQLite requires foreign_keys to be explicitly enabled
|
||||||
# in order to proper manage cascade deletions
|
# in order to proper manage cascade deletions
|
||||||
session.execute('PRAGMA foreign_keys = ON')
|
session.execute(text('PRAGMA foreign_keys = ON'))
|
||||||
|
|
||||||
entities: Collection[Entity] = (
|
entities: Collection[Entity] = (
|
||||||
session.query(Entity).filter(Entity.id.in_(entities)).all()
|
session.query(Entity).filter(Entity.id.in_(entities)).all()
|
||||||
|
|
|
@ -1,6 +1,21 @@
|
||||||
from platypush.config import Config
|
from sqlalchemy import Column, String
|
||||||
|
|
||||||
|
from platypush.common.db import declarative_base
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import Plugin, action
|
||||||
|
from platypush.plugins.db import DbPlugin
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
|
class Variable(Base):
|
||||||
|
"""Models the variable table"""
|
||||||
|
|
||||||
|
__tablename__ = 'variable'
|
||||||
|
|
||||||
|
name = Column(String, primary_key=True, nullable=False)
|
||||||
|
value = Column(String)
|
||||||
|
|
||||||
|
|
||||||
class VariablePlugin(Plugin):
|
class VariablePlugin(Plugin):
|
||||||
|
@ -11,8 +26,6 @@ class VariablePlugin(Plugin):
|
||||||
will be stored either persisted on a local database or on the local Redis instance.
|
will be stored either persisted on a local database or on the local Redis instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_variable_table_name = 'variable'
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
The plugin will create a table named ``variable`` on the database
|
The plugin will create a table named ``variable`` on the database
|
||||||
|
@ -21,24 +34,14 @@ class VariablePlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.db_plugin = get_plugin('db')
|
db_plugin = get_plugin('db')
|
||||||
self.redis_plugin = get_plugin('redis')
|
redis_plugin = get_plugin('redis')
|
||||||
|
assert db_plugin, 'Database plugin not configured'
|
||||||
|
assert redis_plugin, 'Redis plugin not configured'
|
||||||
|
|
||||||
db = Config.get('db')
|
self.redis_plugin = redis_plugin
|
||||||
self.db_config = {
|
self.db_plugin: DbPlugin = db_plugin
|
||||||
'engine': db.get('engine'),
|
self.db_plugin.create_all(self.db_plugin.get_engine(), Base)
|
||||||
'args': db.get('args', []),
|
|
||||||
'kwargs': db.get('kwargs', {})
|
|
||||||
}
|
|
||||||
|
|
||||||
self._create_tables()
|
|
||||||
# self._variables = {}
|
|
||||||
|
|
||||||
def _create_tables(self):
|
|
||||||
self.db_plugin.execute("""CREATE TABLE IF NOT EXISTS {}(
|
|
||||||
name varchar(255) not null primary key,
|
|
||||||
value text
|
|
||||||
)""".format(self._variable_table_name))
|
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def get(self, name, default_value=None):
|
def get(self, name, default_value=None):
|
||||||
|
@ -53,13 +56,10 @@ class VariablePlugin(Plugin):
|
||||||
:returns: A map in the format ``{"<name>":"<value>"}``
|
:returns: A map in the format ``{"<name>":"<value>"}``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = self.db_plugin.select(table=self._variable_table_name,
|
with self.db_plugin.get_session() as session:
|
||||||
filter={'name': name},
|
var = session.query(Variable).filter_by(name=name).first()
|
||||||
engine=self.db_config['engine'],
|
|
||||||
*self.db_config['args'],
|
|
||||||
**self.db_config['kwargs']).output
|
|
||||||
|
|
||||||
return {name: rows[0]['value'] if rows else default_value}
|
return {name: (var.value if var is not None else default_value)}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def set(self, **kwargs):
|
def set(self, **kwargs):
|
||||||
|
@ -69,15 +69,24 @@ class VariablePlugin(Plugin):
|
||||||
:param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``)
|
:param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
records = [{'name': k, 'value': v}
|
with self.db_plugin.get_session() as session:
|
||||||
for (k, v) in kwargs.items()]
|
existing_vars = {
|
||||||
|
var.name: var
|
||||||
|
for var in session.query(Variable)
|
||||||
|
.filter(Variable.name.in_(kwargs.keys()))
|
||||||
|
.all()
|
||||||
|
}
|
||||||
|
|
||||||
self.db_plugin.insert(table=self._variable_table_name,
|
new_vars = {
|
||||||
records=records, key_columns=['name'],
|
name: Variable(name=name, value=value)
|
||||||
engine=self.db_config['engine'],
|
for name, value in kwargs.items()
|
||||||
on_duplicate_update=True,
|
if name not in existing_vars
|
||||||
*self.db_config['args'],
|
}
|
||||||
**self.db_config['kwargs'])
|
|
||||||
|
for name, var in existing_vars.items():
|
||||||
|
var.value = kwargs[name] # type: ignore
|
||||||
|
|
||||||
|
session.add_all([*existing_vars.values(), *new_vars.values()])
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
@ -90,12 +99,8 @@ class VariablePlugin(Plugin):
|
||||||
:type name: str
|
:type name: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
records = [{'name': name}]
|
with self.db_plugin.get_session() as session:
|
||||||
|
session.query(Variable).filter_by(name=name).delete()
|
||||||
self.db_plugin.delete(table=self._variable_table_name,
|
|
||||||
records=records, engine=self.db_config['engine'],
|
|
||||||
*self.db_config['args'],
|
|
||||||
**self.db_config['kwargs'])
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -150,4 +155,5 @@ class VariablePlugin(Plugin):
|
||||||
|
|
||||||
return self.redis_plugin.expire(name, expire)
|
return self.redis_plugin.expire(name, expire)
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -3,11 +3,11 @@ import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import rsa
|
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import rsa
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import make_transient
|
from sqlalchemy.orm import make_transient
|
||||||
|
@ -73,7 +73,7 @@ class UserManager:
|
||||||
username=username,
|
username=username,
|
||||||
password=self._encrypt_password(password),
|
password=self._encrypt_password(password),
|
||||||
created_at=datetime.datetime.utcnow(),
|
created_at=datetime.datetime.utcnow(),
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(record)
|
session.add(record)
|
||||||
|
@ -238,9 +238,7 @@ class UserManager:
|
||||||
indent=None,
|
indent=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return base64.b64encode(
|
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
|
||||||
rsa.encrypt(payload.encode('ascii'), pub_key)
|
|
||||||
).decode()
|
|
||||||
|
|
||||||
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
|
@ -263,21 +261,19 @@ class UserManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(
|
payload = json.loads(
|
||||||
rsa.decrypt(
|
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
|
||||||
base64.b64decode(token.encode('ascii')),
|
'ascii'
|
||||||
priv_key
|
)
|
||||||
).decode('ascii')
|
|
||||||
)
|
)
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
|
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
|
||||||
|
|
||||||
expires_at = payload.get('expires_at')
|
expires_at = payload.get('expires_at')
|
||||||
if expires_at and time.time() > expires_at:
|
if expires_at and time.time() > expires_at:
|
||||||
raise InvalidJWTTokenException('Expired JWT token')
|
raise InvalidJWTTokenException('Expired JWT token')
|
||||||
|
|
||||||
user = self.authenticate_user(
|
user = self.authenticate_user(
|
||||||
payload.get('username', ''),
|
payload.get('username', ''), payload.get('password', '')
|
||||||
payload.get('password', '')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -57,7 +57,7 @@ setup(
|
||||||
'redis',
|
'redis',
|
||||||
'requests',
|
'requests',
|
||||||
'croniter',
|
'croniter',
|
||||||
'sqlalchemy<2.0.0',
|
'sqlalchemy',
|
||||||
'websockets',
|
'websockets',
|
||||||
'websocket-client',
|
'websocket-client',
|
||||||
'wheel',
|
'wheel',
|
||||||
|
|
Loading…
Reference in a new issue