diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index a1d4c51a..70708bd8 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -25,7 +25,7 @@ class DbPlugin(Plugin): _db_error_wait_interval = 5.0 _db_error_retries = 3 - def __init__(self, engine=None, *args, **kwargs): + def __init__(self, engine=None, **kwargs): """ :param engine: Default SQLAlchemy connection engine string (e.g. ``sqlite:///:memory:`` or ``mysql://user:pass@localhost/test``) @@ -42,7 +42,7 @@ class DbPlugin(Plugin): super().__init__() self.engine_url = engine - self.engine = self.get_engine(engine, *args, **kwargs) + self.engine = self.get_engine(engine, **kwargs) def get_engine( self, engine: Optional[Union[str, Engine]] = None, *args, **kwargs @@ -66,16 +66,14 @@ class DbPlugin(Plugin): return self.engine @staticmethod - def _build_condition(table, column, value): # type: ignore - if isinstance(value, str): - value = "'{}'".format(value) - elif not isinstance(value, int) and not isinstance(value, float): - value = "'{}'".format(str(value)) + def _build_condition(table, column, value): # pylint: disable=unused-argument + if isinstance(value, str) or not isinstance(value, (int, float)): + value = f"'{value}'" - return eval('table.c.{}=={}'.format(column, value)) + return eval(f'table.c.{column}=={value}') # pylint: disable=eval-used @action - def execute(self, statement, engine=None, *args, **kwargs): + def execute(self, statement, *args, engine=None, **kwargs): """ Executes a raw SQL statement. @@ -98,37 +96,37 @@ class DbPlugin(Plugin): (see https:///docs.sqlalchemy.org/en/latest/core/engines.html) """ - engine = self.get_engine(engine, *args, **kwargs) + with self.get_engine(engine, *args, **kwargs).connect() as connection: + connection.execute(text(statement)) - with engine.connect() as connection: - connection.execute(statement) - - def _get_table(self, table, engine=None, *args, **kwargs): + def _get_table(self, table: str, *args, engine=None, **kwargs): if not engine: engine = self.get_engine(engine, *args, **kwargs) db_ok = False n_tries = 0 last_error = None + table_ = None while not db_ok and n_tries < self._db_error_retries: try: n_tries += 1 metadata = MetaData() - table = Table(table, metadata, autoload=True, autoload_with=engine) + table_ = Table(table, metadata, autoload_with=engine) db_ok = True except Exception as e: last_error = e wait_time = self._db_error_wait_interval * n_tries self.logger.exception(e) - self.logger.info('Waiting {} seconds before retrying'.format(wait_time)) + self.logger.info('Waiting %s seconds before retrying', wait_time) time.sleep(wait_time) engine = self.get_engine(engine, *args, **kwargs) if not db_ok and last_error: raise last_error - return table, engine + assert table_, f'No such table: {table}' + return table_, engine @action def select( @@ -324,17 +322,16 @@ class DbPlugin(Plugin): connection, table, records, key_columns ) - with connection.begin(): - if insert_records: - insert = table.insert().values(insert_records) - ret = self._execute_try_returning(connection, insert) - if ret: - returned_records += ret + if insert_records: + insert = table.insert().values(insert_records) + ret = self._execute_try_returning(connection, insert) + if ret: + returned_records += ret - if update_records and on_duplicate_update: - ret = self._update(connection, table, update_records, key_columns) - if ret: - returned_records = ret + returned_records + if update_records and on_duplicate_update: + ret = self._update(connection, table, update_records, key_columns) + if ret: + returned_records = ret + returned_records if returned_records: return returned_records @@ -365,8 +362,9 @@ class DbPlugin(Plugin): query = table.select().where( or_( - and_( - self._build_condition(table, k, record.get(k)) for k in key_columns + and_( # type: ignore + self._build_condition(table, k, record.get(k)) + for k in key_columns # type: ignore ) for record in records ) @@ -498,13 +496,13 @@ class DbPlugin(Plugin): engine = self.get_engine(engine, *args, **kwargs) - with engine.connect() as connection, connection.begin(): + with engine.connect() as connection: for record in records: - table, engine = self._get_table(table, engine=engine, *args, **kwargs) - delete = table.delete() + table_, engine = self._get_table(table, engine=engine, *args, **kwargs) + delete = table_.delete() for k, v in record.items(): - delete = delete.where(self._build_condition(table, k, v)) + delete = delete.where(self._build_condition(table_, k, v)) connection.execute(delete) @@ -514,7 +512,7 @@ class DbPlugin(Plugin): @contextmanager def get_session( - self, engine=None, locked=False, autoflush=True, *args, **kwargs + self, *args, engine=None, locked=False, autoflush=True, **kwargs ) -> Generator[Session, None, None]: engine = self.get_engine(engine, *args, **kwargs) if locked: @@ -523,16 +521,20 @@ class DbPlugin(Plugin): # Mock lock lock = RLock() - with lock, engine.connect() as conn, conn.begin(): - session = scoped_session( + with lock, engine.connect() as conn: + session_maker = scoped_session( sessionmaker( expire_on_commit=False, autoflush=autoflush, ) ) - session.configure(bind=conn) - yield session() + session_maker.configure(bind=conn) + session = session_maker() + yield session + + session.flush() + session.commit() # vim:sw=4:ts=4:et: diff --git a/platypush/plugins/entities/__init__.py b/platypush/plugins/entities/__init__.py index 443cf51a..74caf3cd 100644 --- a/platypush/plugins/entities/__init__.py +++ b/platypush/plugins/entities/__init__.py @@ -3,7 +3,7 @@ from threading import Thread from time import time from typing import Optional, Any, Collection, Mapping -from sqlalchemy import or_ +from sqlalchemy import or_, text from sqlalchemy.orm import make_transient, Session from platypush.config import Config @@ -198,7 +198,7 @@ class EntitiesPlugin(Plugin): if str(session.connection().engine.url).startswith('sqlite://'): # SQLite requires foreign_keys to be explicitly enabled # in order to proper manage cascade deletions - session.execute('PRAGMA foreign_keys = ON') + session.execute(text('PRAGMA foreign_keys = ON')) entities: Collection[Entity] = ( session.query(Entity).filter(Entity.id.in_(entities)).all() diff --git a/platypush/plugins/variable/__init__.py b/platypush/plugins/variable/__init__.py index 0346ba88..9556b07f 100644 --- a/platypush/plugins/variable/__init__.py +++ b/platypush/plugins/variable/__init__.py @@ -1,6 +1,21 @@ -from platypush.config import Config +from sqlalchemy import Column, String + +from platypush.common.db import declarative_base from platypush.context import get_plugin from platypush.plugins import Plugin, action +from platypush.plugins.db import DbPlugin + +Base = declarative_base() + + +# pylint: disable=too-few-public-methods +class Variable(Base): + """Models the variable table""" + + __tablename__ = 'variable' + + name = Column(String, primary_key=True, nullable=False) + value = Column(String) class VariablePlugin(Plugin): @@ -11,8 +26,6 @@ class VariablePlugin(Plugin): will be stored either persisted on a local database or on the local Redis instance. """ - _variable_table_name = 'variable' - def __init__(self, **kwargs): """ The plugin will create a table named ``variable`` on the database @@ -21,24 +34,14 @@ class VariablePlugin(Plugin): """ super().__init__(**kwargs) - self.db_plugin = get_plugin('db') - self.redis_plugin = get_plugin('redis') + db_plugin = get_plugin('db') + redis_plugin = get_plugin('redis') + assert db_plugin, 'Database plugin not configured' + assert redis_plugin, 'Redis plugin not configured' - db = Config.get('db') - self.db_config = { - 'engine': db.get('engine'), - 'args': db.get('args', []), - 'kwargs': db.get('kwargs', {}) - } - - self._create_tables() - # self._variables = {} - - def _create_tables(self): - self.db_plugin.execute("""CREATE TABLE IF NOT EXISTS {}( - name varchar(255) not null primary key, - value text - )""".format(self._variable_table_name)) + self.redis_plugin = redis_plugin + self.db_plugin: DbPlugin = db_plugin + self.db_plugin.create_all(self.db_plugin.get_engine(), Base) @action def get(self, name, default_value=None): @@ -53,13 +56,10 @@ class VariablePlugin(Plugin): :returns: A map in the format ``{"":""}`` """ - rows = self.db_plugin.select(table=self._variable_table_name, - filter={'name': name}, - engine=self.db_config['engine'], - *self.db_config['args'], - **self.db_config['kwargs']).output + with self.db_plugin.get_session() as session: + var = session.query(Variable).filter_by(name=name).first() - return {name: rows[0]['value'] if rows else default_value} + return {name: (var.value if var is not None else default_value)} @action def set(self, **kwargs): @@ -69,15 +69,24 @@ class VariablePlugin(Plugin): :param kwargs: Key-value list of variables to set (e.g. ``foo='bar', answer=42``) """ - records = [{'name': k, 'value': v} - for (k, v) in kwargs.items()] + with self.db_plugin.get_session() as session: + existing_vars = { + var.name: var + for var in session.query(Variable) + .filter(Variable.name.in_(kwargs.keys())) + .all() + } - self.db_plugin.insert(table=self._variable_table_name, - records=records, key_columns=['name'], - engine=self.db_config['engine'], - on_duplicate_update=True, - *self.db_config['args'], - **self.db_config['kwargs']) + new_vars = { + name: Variable(name=name, value=value) + for name, value in kwargs.items() + if name not in existing_vars + } + + for name, var in existing_vars.items(): + var.value = kwargs[name] # type: ignore + + session.add_all([*existing_vars.values(), *new_vars.values()]) return kwargs @@ -90,12 +99,8 @@ class VariablePlugin(Plugin): :type name: str """ - records = [{'name': name}] - - self.db_plugin.delete(table=self._variable_table_name, - records=records, engine=self.db_config['engine'], - *self.db_config['args'], - **self.db_config['kwargs']) + with self.db_plugin.get_session() as session: + session.query(Variable).filter_by(name=name).delete() return True @@ -150,4 +155,5 @@ class VariablePlugin(Plugin): return self.redis_plugin.expire(name, expire) + # vim:sw=4:ts=4:et: diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index 7d268b59..432774a4 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -3,11 +3,11 @@ import datetime import hashlib import json import random -import rsa import time from typing import Optional, Dict import bcrypt +import rsa from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy.orm import make_transient @@ -73,7 +73,7 @@ class UserManager: username=username, password=self._encrypt_password(password), created_at=datetime.datetime.utcnow(), - **kwargs + **kwargs, ) session.add(record) @@ -238,9 +238,7 @@ class UserManager: indent=None, ) - return base64.b64encode( - rsa.encrypt(payload.encode('ascii'), pub_key) - ).decode() + return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode() def validate_jwt_token(self, token: str) -> Dict[str, str]: """ @@ -263,21 +261,19 @@ class UserManager: try: payload = json.loads( - rsa.decrypt( - base64.b64decode(token.encode('ascii')), - priv_key - ).decode('ascii') + rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode( + 'ascii' + ) ) except (TypeError, ValueError) as e: - raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') + raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e expires_at = payload.get('expires_at') if expires_at and time.time() > expires_at: raise InvalidJWTTokenException('Expired JWT token') user = self.authenticate_user( - payload.get('username', ''), - payload.get('password', '') + payload.get('username', ''), payload.get('password', '') ) if not user: diff --git a/setup.py b/setup.py index eb77d9ec..48e64f1b 100755 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ setup( 'redis', 'requests', 'croniter', - 'sqlalchemy<2.0.0', + 'sqlalchemy', 'websockets', 'websocket-client', 'wheel',