forked from platypush/platypush
Merge pull request 'Fixed compatibility with SQLAlchemy >= 2.0' (#250) from 239-sqlalchemy-2-compatibility into master
Reviewed-on: platypush/platypush#250 Closes: #239
This commit is contained in:
commit
99382e4505
5 changed files with 97 additions and 93 deletions
|
@ -25,7 +25,7 @@ class DbPlugin(Plugin):
|
|||
_db_error_wait_interval = 5.0
|
||||
_db_error_retries = 3
|
||||
|
||||
def __init__(self, engine=None, *args, **kwargs):
|
||||
def __init__(self, engine=None, **kwargs):
|
||||
"""
|
||||
:param engine: Default SQLAlchemy connection engine string (e.g.
|
||||
``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``)
|
||||
|
@ -42,7 +42,7 @@ class DbPlugin(Plugin):
|
|||
|
||||
super().__init__()
|
||||
self.engine_url = engine
|
||||
self.engine = self.get_engine(engine, *args, **kwargs)
|
||||
self.engine = self.get_engine(engine, **kwargs)
|
||||
|
||||
def get_engine(
|
||||
self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs
|
||||
|
@ -66,16 +66,14 @@ class DbPlugin(Plugin):
|
|||
return self.engine
|
||||
|
||||
@staticmethod
|
||||
def _build_condition(table, column, value): # type: ignore
|
||||
if isinstance(value, str):
|
||||
value = "'{}'".format(value)
|
||||
elif not isinstance(value, int) and not isinstance(value, float):
|
||||
value = "'{}'".format(str(value))
|
||||
def _build_condition(table, column, value): # pylint: disable=unused-argument
|
||||
if isinstance(value, str) or not isinstance(value, (int, float)):
|
||||
value = f"'{value}'"
|
||||
|
||||
return eval('table.c.{}=={}'.format(column, value))
|
||||
return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used
|
||||
|
||||
@action
|
||||
def execute(self, statement, engine=None, *args, **kwargs):
|
||||
def execute(self, statement, *args, engine=None, **kwargs):
|
||||
"""
|
||||
Executes a raw SQL statement.
|
||||
|
||||
|
@ -98,37 +96,37 @@ class DbPlugin(Plugin):
|
|||
(see https:///docs.sqlalchemy.org/en/latest/core/engines.html)
|
||||
"""
|
||||
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
with self.get_engine(engine, *args, **kwargs).connect() as connection:
|
||||
connection.execute(text(statement))
|
||||
|
||||
with engine.connect() as connection:
|
||||
connection.execute(statement)
|
||||
|
||||
def _get_table(self, table, engine=None, *args, **kwargs):
|
||||
def _get_table(self, table: str, *args, engine=None, **kwargs):
|
||||
if not engine:
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
db_ok = False
|
||||
n_tries = 0
|
||||
last_error = None
|
||||
table_ = None
|
||||
|
||||
while not db_ok and n_tries < self._db_error_retries:
|
||||
try:
|
||||
n_tries += 1
|
||||
metadata = MetaData()
|
||||
table = Table(table, metadata, autoload=True, autoload_with=engine)
|
||||
table_ = Table(table, metadata, autoload_with=engine)
|
||||
db_ok = True
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
wait_time = self._db_error_wait_interval * n_tries
|
||||
self.logger.exception(e)
|
||||
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
|
||||
self.logger.info('Waiting %s seconds before retrying', wait_time)
|
||||
time.sleep(wait_time)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
if not db_ok and last_error:
|
||||
raise last_error
|
||||
|
||||
return table, engine
|
||||
assert table_, f'No such table: {table}'
|
||||
return table_, engine
|
||||
|
||||
@action
|
||||
def select(
|
||||
|
@ -324,7 +322,6 @@ class DbPlugin(Plugin):
|
|||
connection, table, records, key_columns
|
||||
)
|
||||
|
||||
with connection.begin():
|
||||
if insert_records:
|
||||
insert = table.insert().values(insert_records)
|
||||
ret = self._execute_try_returning(connection, insert)
|
||||
|
@ -365,8 +362,9 @@ class DbPlugin(Plugin):
|
|||
|
||||
query = table.select().where(
|
||||
or_(
|
||||
and_(
|
||||
self._build_condition(table, k, record.get(k)) for k in key_columns
|
||||
and_( # type: ignore
|
||||
self._build_condition(table, k, record.get(k))
|
||||
for k in key_columns # type: ignore
|
||||
)
|
||||
for record in records
|
||||
)
|
||||
|
@ -498,13 +496,13 @@ class DbPlugin(Plugin):
|
|||
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
with engine.connect() as connection, connection.begin():
|
||||
with engine.connect() as connection:
|
||||
for record in records:
|
||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
delete = table.delete()
|
||||
table_, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
delete = table_.delete()
|
||||
|
||||
for k, v in record.items():
|
||||
delete = delete.where(self._build_condition(table, k, v))
|
||||
delete = delete.where(self._build_condition(table_, k, v))
|
||||
|
||||
connection.execute(delete)
|
||||
|
||||
|
@ -514,7 +512,7 @@ class DbPlugin(Plugin):
|
|||
|
||||
@contextmanager
|
||||
def get_session(
|
||||
self, engine=None, locked=False, autoflush=True, *args, **kwargs
|
||||
self, *args, engine=None, locked=False, autoflush=True, **kwargs
|
||||
) -> Generator[Session, None, None]:
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
if locked:
|
||||
|
@ -523,16 +521,20 @@ class DbPlugin(Plugin):
|
|||
# Mock lock
|
||||
lock = RLock()
|
||||
|
||||
with lock, engine.connect() as conn, conn.begin():
|
||||
session = scoped_session(
|
||||
with lock, engine.connect() as conn:
|
||||
session_maker = scoped_session(
|
||||
sessionmaker(
|
||||
expire_on_commit=False,
|
||||
autoflush=autoflush,
|
||||
)
|
||||
)
|
||||
|
||||
session.configure(bind=conn)
|
||||
yield session()
|
||||
session_maker.configure(bind=conn)
|
||||
session = session_maker()
|
||||
yield session
|
||||
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -3,7 +3,7 @@ from threading import Thread
|
|||
from time import time
|
||||
from typing import Optional, Any, Collection, Mapping
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, text
|
||||
from sqlalchemy.orm import make_transient, Session
|
||||
|
||||
from platypush.config import Config
|
||||
|
@ -198,7 +198,7 @@ class EntitiesPlugin(Plugin):
|
|||
if str(session.connection().engine.url).startswith('sqlite://'):
|
||||
# SQLite requires foreign_keys to be explicitly enabled
|
||||
# in order to proper manage cascade deletions
|
||||
session.execute('PRAGMA foreign_keys = ON')
|
||||
session.execute(text('PRAGMA foreign_keys = ON'))
|
||||
|
||||
entities: Collection[Entity] = (
|
||||
session.query(Entity).filter(Entity.id.in_(entities)).all()
|
||||
|
|
|
@ -1,6 +1,21 @@
|
|||
from platypush.config import Config
|
||||
from sqlalchemy import Column, String
|
||||
|
||||
from platypush.common.db import declarative_base
|
||||
from platypush.context import get_plugin
|
||||
from platypush.plugins import Plugin, action
|
||||
from platypush.plugins.db import DbPlugin
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class Variable(Base):
|
||||
"""Models the variable table"""
|
||||
|
||||
__tablename__ = 'variable'
|
||||
|
||||
name = Column(String, primary_key=True, nullable=False)
|
||||
value = Column(String)
|
||||
|
||||
|
||||
class VariablePlugin(Plugin):
|
||||
|
@ -11,8 +26,6 @@ class VariablePlugin(Plugin):
|
|||
will be stored either persisted on a local database or on the local Redis instance.
|
||||
"""
|
||||
|
||||
_variable_table_name = 'variable'
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
The plugin will create a table named ``variable`` on the database
|
||||
|
@ -21,24 +34,14 @@ class VariablePlugin(Plugin):
|
|||
"""
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.db_plugin = get_plugin('db')
|
||||
self.redis_plugin = get_plugin('redis')
|
||||
db_plugin = get_plugin('db')
|
||||
redis_plugin = get_plugin('redis')
|
||||
assert db_plugin, 'Database plugin not configured'
|
||||
assert redis_plugin, 'Redis plugin not configured'
|
||||
|
||||
db = Config.get('db')
|
||||
self.db_config = {
|
||||
'engine': db.get('engine'),
|
||||
'args': db.get('args', []),
|
||||
'kwargs': db.get('kwargs', {})
|
||||
}
|
||||
|
||||
self._create_tables()
|
||||
# self._variables = {}
|
||||
|
||||
def _create_tables(self):
|
||||
self.db_plugin.execute("""CREATE TABLE IF NOT EXISTS {}(
|
||||
name varchar(255) not null primary key,
|
||||
value text
|
||||
)""".format(self._variable_table_name))
|
||||
self.redis_plugin = redis_plugin
|
||||
self.db_plugin: DbPlugin = db_plugin
|
||||
self.db_plugin.create_all(self.db_plugin.get_engine(), Base)
|
||||
|
||||
@action
|
||||
def get(self, name, default_value=None):
|
||||
|
@ -53,13 +56,10 @@ class VariablePlugin(Plugin):
|
|||
:returns: A map in the format ``{"<name>":"<value>"}``
|
||||
"""
|
||||
|
||||
rows = self.db_plugin.select(table=self._variable_table_name,
|
||||
filter={'name': name},
|
||||
engine=self.db_config['engine'],
|
||||
*self.db_config['args'],
|
||||
**self.db_config['kwargs']).output
|
||||
with self.db_plugin.get_session() as session:
|
||||
var = session.query(Variable).filter_by(name=name).first()
|
||||
|
||||
return {name: rows[0]['value'] if rows else default_value}
|
||||
return {name: (var.value if var is not None else default_value)}
|
||||
|
||||
@action
|
||||
def set(self, **kwargs):
|
||||
|
@ -69,15 +69,24 @@ class VariablePlugin(Plugin):
|
|||
:param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``)
|
||||
"""
|
||||
|
||||
records = [{'name': k, 'value': v}
|
||||
for (k, v) in kwargs.items()]
|
||||
with self.db_plugin.get_session() as session:
|
||||
existing_vars = {
|
||||
var.name: var
|
||||
for var in session.query(Variable)
|
||||
.filter(Variable.name.in_(kwargs.keys()))
|
||||
.all()
|
||||
}
|
||||
|
||||
self.db_plugin.insert(table=self._variable_table_name,
|
||||
records=records, key_columns=['name'],
|
||||
engine=self.db_config['engine'],
|
||||
on_duplicate_update=True,
|
||||
*self.db_config['args'],
|
||||
**self.db_config['kwargs'])
|
||||
new_vars = {
|
||||
name: Variable(name=name, value=value)
|
||||
for name, value in kwargs.items()
|
||||
if name not in existing_vars
|
||||
}
|
||||
|
||||
for name, var in existing_vars.items():
|
||||
var.value = kwargs[name] # type: ignore
|
||||
|
||||
session.add_all([*existing_vars.values(), *new_vars.values()])
|
||||
|
||||
return kwargs
|
||||
|
||||
|
@ -90,12 +99,8 @@ class VariablePlugin(Plugin):
|
|||
:type name: str
|
||||
"""
|
||||
|
||||
records = [{'name': name}]
|
||||
|
||||
self.db_plugin.delete(table=self._variable_table_name,
|
||||
records=records, engine=self.db_config['engine'],
|
||||
*self.db_config['args'],
|
||||
**self.db_config['kwargs'])
|
||||
with self.db_plugin.get_session() as session:
|
||||
session.query(Variable).filter_by(name=name).delete()
|
||||
|
||||
return True
|
||||
|
||||
|
@ -150,4 +155,5 @@ class VariablePlugin(Plugin):
|
|||
|
||||
return self.redis_plugin.expire(name, expire)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -3,11 +3,11 @@ import datetime
|
|||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import rsa
|
||||
import time
|
||||
from typing import Optional, Dict
|
||||
|
||||
import bcrypt
|
||||
import rsa
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import make_transient
|
||||
|
@ -73,7 +73,7 @@ class UserManager:
|
|||
username=username,
|
||||
password=self._encrypt_password(password),
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
session.add(record)
|
||||
|
@ -238,9 +238,7 @@ class UserManager:
|
|||
indent=None,
|
||||
)
|
||||
|
||||
return base64.b64encode(
|
||||
rsa.encrypt(payload.encode('ascii'), pub_key)
|
||||
).decode()
|
||||
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
|
||||
|
||||
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
||||
"""
|
||||
|
@ -263,21 +261,19 @@ class UserManager:
|
|||
|
||||
try:
|
||||
payload = json.loads(
|
||||
rsa.decrypt(
|
||||
base64.b64decode(token.encode('ascii')),
|
||||
priv_key
|
||||
).decode('ascii')
|
||||
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
|
||||
'ascii'
|
||||
)
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
|
||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
|
||||
|
||||
expires_at = payload.get('expires_at')
|
||||
if expires_at and time.time() > expires_at:
|
||||
raise InvalidJWTTokenException('Expired JWT token')
|
||||
|
||||
user = self.authenticate_user(
|
||||
payload.get('username', ''),
|
||||
payload.get('password', '')
|
||||
payload.get('username', ''), payload.get('password', '')
|
||||
)
|
||||
|
||||
if not user:
|
||||
|
|
2
setup.py
2
setup.py
|
@ -57,7 +57,7 @@ setup(
|
|||
'redis',
|
||||
'requests',
|
||||
'croniter',
|
||||
'sqlalchemy<2.0.0',
|
||||
'sqlalchemy',
|
||||
'websockets',
|
||||
'websocket-client',
|
||||
'wheel',
|
||||
|
|
Loading…
Reference in a new issue