From 4ee7e4db296b9397f859559017c7e3f3aab25021 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 4 Apr 2022 16:50:17 +0200 Subject: [PATCH] Basic support for entities on the local db and implemented support for switch entities on the tplink plugin --- platypush/__init__.py | 21 ++- platypush/entities/__init__.py | 36 +++++ platypush/entities/_base.py | 73 +++++++++ platypush/entities/_engine.py | 110 +++++++++++++ platypush/entities/_registry.py | 62 +++++++ platypush/entities/devices.py | 14 ++ platypush/entities/lights.py | 14 ++ platypush/entities/switches.py | 15 ++ platypush/plugins/db/__init__.py | 50 ++++-- platypush/plugins/switch/__init__.py | 5 +- platypush/plugins/switch/tplink/__init__.py | 70 +++++--- platypush/user/__init__.py | 170 ++++++++++---------- 12 files changed, 506 insertions(+), 134 deletions(-) create mode 100644 platypush/entities/__init__.py create mode 100644 platypush/entities/_base.py create mode 100644 platypush/entities/_engine.py create mode 100644 platypush/entities/_registry.py create mode 100644 platypush/entities/devices.py create mode 100644 platypush/entities/lights.py create mode 100644 platypush/entities/switches.py diff --git a/platypush/__init__.py b/platypush/__init__.py index 4c1d2cf18..e9f6640c0 100644 --- a/platypush/__init__.py +++ b/platypush/__init__.py @@ -9,11 +9,13 @@ import argparse import logging import os import sys +from typing import Optional from .bus.redis import RedisBus from .config import Config from .context import register_backends, register_plugins from .cron.scheduler import CronScheduler +from .entities import init_entities_engine, EntitiesEngine from .event.processor import EventProcessor from .logger import Logger from .message.event import Event @@ -86,6 +88,7 @@ class Daemon: self.no_capture_stdout = no_capture_stdout self.no_capture_stderr = no_capture_stderr self.event_processor = EventProcessor() + self.entities_engine: Optional[EntitiesEngine] = None self.requests_to_process = requests_to_process self.processed_requests = 0 self.cron_scheduler = None @@ -161,16 +164,25 @@ class Daemon: """ Stops the backends and the bus """ from .plugins import RunnablePlugin - for backend in self.backends.values(): - backend.stop() + if self.backends: + for backend in self.backends.values(): + backend.stop() for plugin in get_enabled_plugins().values(): if isinstance(plugin, RunnablePlugin): plugin.stop() - self.bus.stop() + if self.bus: + self.bus.stop() + self.bus = None + if self.cron_scheduler: self.cron_scheduler.stop() + self.cron_scheduler = None + + if self.entities_engine: + self.entities_engine.stop() + self.entities_engine = None def run(self): """ Start the daemon """ @@ -192,6 +204,9 @@ class Daemon: # Initialize the plugins register_plugins(bus=self.bus) + # Initialize the entities engine + self.entities_engine = init_entities_engine() + # Start the cron scheduler if Config.get_cronjobs(): self.cron_scheduler = CronScheduler(jobs=Config.get_cronjobs()) diff --git a/platypush/entities/__init__.py b/platypush/entities/__init__.py new file mode 100644 index 000000000..f59ab240b --- /dev/null +++ b/platypush/entities/__init__.py @@ -0,0 +1,36 @@ +import warnings +from typing import Collection, Optional + +from ._base import Entity +from ._engine import EntitiesEngine +from ._registry import manages, register_entity_plugin, get_plugin_registry + +_engine: Optional[EntitiesEngine] = None + + +def init_entities_engine() -> EntitiesEngine: + from ._base import init_entities_db + global _engine + init_entities_db() + _engine = EntitiesEngine() + _engine.start() + return _engine + + +def publish_entities(entities: Collection[Entity]): + if not _engine: + warnings.warn('No entities engine registered') + return + + _engine.post(*entities) + +__all__ = ( + 'Entity', + 'EntitiesEngine', + 'init_entities_engine', + 'publish_entities', + 'register_entity_plugin', + 'get_plugin_registry', + 'manages', +) + diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py new file mode 100644 index 000000000..80be23b83 --- /dev/null +++ b/platypush/entities/_base.py @@ -0,0 +1,73 @@ +import inspect +import pathlib +from datetime import datetime +from typing import Mapping, Type + +import pkgutil +from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() +entities_registry: Mapping[Type['Entity'], Mapping] = {} + + +class Entity(Base): + """ + Model for a general-purpose platform entity + """ + + __tablename__ = 'entity' + + id = Column(Integer, autoincrement=True, primary_key=True) + external_id = Column(String, nullable=True) + name = Column(String, nullable=False, index=True) + type = Column(String, nullable=False, index=True) + plugin = Column(String, nullable=False) + data = Column(JSON, default=dict) + created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False) + updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()) + + UniqueConstraint(external_id, plugin) + + __table_args__ = ( + Index(name, plugin), + ) + + __mapper_args__ = { + 'polymorphic_identity': __tablename__, + 'polymorphic_on': type, + } + + +def _discover_entity_types(): + from platypush.context import get_plugin + logger = get_plugin('logger') + assert logger + + for loader, modname, _ in pkgutil.walk_packages( + path=[str(pathlib.Path(__file__).parent.absolute())], + prefix=__package__ + '.', + onerror=lambda _: None + ): + try: + mod_loader = loader.find_module(modname) # type: ignore + assert mod_loader + module = mod_loader.load_module() # type: ignore + except Exception as e: + logger.warning(f'Could not import module {modname}') + logger.exception(e) + continue + + for _, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, Entity): + entities_registry[obj] = {} + + +def init_entities_db(): + from platypush.context import get_plugin + _discover_entity_types() + db = get_plugin('db') + assert db + engine = db.get_engine() + db.create_all(engine, Base) + diff --git a/platypush/entities/_engine.py b/platypush/entities/_engine.py new file mode 100644 index 000000000..f99a4725c --- /dev/null +++ b/platypush/entities/_engine.py @@ -0,0 +1,110 @@ +from logging import getLogger +from queue import Queue, Empty +from threading import Thread, Event +from time import time +from typing import Iterable, List + +from sqlalchemy import and_, or_, inspect as schema_inspect +from sqlalchemy.orm import Session +from sqlalchemy.sql.elements import Null + +from ._base import Entity + + +class EntitiesEngine(Thread): + # Processing queue timeout in seconds + _queue_timeout = 5. + + def __init__(self): + obj_name = self.__class__.__name__ + super().__init__(name=obj_name) + self.logger = getLogger(name=obj_name) + self._queue = Queue() + self._should_stop = Event() + + def post(self, *entities: Entity): + for entity in entities: + self._queue.put(entity) + + @property + def should_stop(self) -> bool: + return self._should_stop.is_set() + + def stop(self): + self._should_stop.set() + + def run(self): + super().run() + self.logger.info('Started entities engine') + + while not self.should_stop: + msgs = [] + last_poll_time = time() + + while not self.should_stop and ( + time() - last_poll_time < self._queue_timeout): + try: + msg = self._queue.get(block=True, timeout=0.5) + except Empty: + continue + + if msg: + msgs.append(msg) + + if not msgs or self.should_stop: + continue + + self._process_entities(*msgs) + + self.logger.info('Stopped entities engine') + + def _get_if_exist(self, session: Session, entities: Iterable[Entity]) -> Iterable[Entity]: + existing_entities = { + (entity.external_id or entity.name, entity.plugin): entity + for entity in session.query(Entity).filter( + or_(*[ + and_(Entity.external_id == entity.external_id, Entity.plugin == entity.plugin) + if entity.external_id is not None else + and_(Entity.name == entity.name, Entity.plugin == entity.plugin) + for entity in entities + ]) + ).all() + } + + return [ + existing_entities.get( + (entity.external_id or entity.name, entity.plugin), None + ) for entity in entities + ] + + def _merge_entities( + self, entities: List[Entity], + existing_entities: List[Entity] + ) -> List[Entity]: + new_entities = [] + + for i, entity in enumerate(entities): + existing_entity = existing_entities[i] + if existing_entity: + inspector = schema_inspect(entity.__class__) + columns = [col.key for col in inspector.mapper.column_attrs] + for col in columns: + new_value = getattr(entity, col) + if new_value is not None and new_value.__class__ != Null: + setattr(existing_entity, col, getattr(entity, col)) + + new_entities.append(existing_entity) + else: + new_entities.append(entity) + + return new_entities + + def _process_entities(self, *entities: Entity): + from platypush.context import get_plugin + + with get_plugin('db').get_session() as session: # type: ignore + existing_entities = self._get_if_exist(session, entities) + entities = self._merge_entities(entities, existing_entities) # type: ignore + session.add_all(entities) + session.commit() + diff --git a/platypush/entities/_registry.py b/platypush/entities/_registry.py new file mode 100644 index 000000000..b8644808f --- /dev/null +++ b/platypush/entities/_registry.py @@ -0,0 +1,62 @@ +from datetime import datetime +from typing import Optional, Mapping, Dict, Collection, Type + +from platypush.plugins import Plugin +from platypush.utils import get_plugin_name_by_class + +from ._base import Entity + +_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {} + + +def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin): + plugins = _entity_plugin_registry.get(entity_type, {}) + plugin_name = get_plugin_name_by_class(plugin.__class__) + assert plugin_name + plugins[plugin_name] = plugin + _entity_plugin_registry[entity_type] = plugins + + +def get_plugin_registry(): + return _entity_plugin_registry.copy() + + +class EntityManagerMixin: + def transform_entities(self, entities): + entities = entities or [] + for entity in entities: + if entity.id: + # Entity IDs can only refer to the internal primary key + entity.external_id = entity.id + entity.id = None # type: ignore + + entity.plugin = get_plugin_name_by_class(self.__class__) # type: ignore + entity.updated_at = datetime.utcnow() + + return entities + + def publish_entities(self, entities: Optional[Collection[Entity]]): + from . import publish_entities + entities = self.transform_entities(entities) + publish_entities(entities) + + +def manages(*entities: Type[Entity]): + def wrapper(plugin: Type[Plugin]): + init = plugin.__init__ + + def __init__(self, *args, **kwargs): + for entity_type in entities: + register_entity_plugin(entity_type, self) + + init(self, *args, **kwargs) + + plugin.__init__ = __init__ + # Inject the EntityManagerMixin + if EntityManagerMixin not in plugin.__bases__: + plugin.__bases__ = (EntityManagerMixin,) + plugin.__bases__ + + return plugin + + return wrapper + diff --git a/platypush/entities/devices.py b/platypush/entities/devices.py new file mode 100644 index 000000000..dfc64f01f --- /dev/null +++ b/platypush/entities/devices.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, ForeignKey + +from ._base import Entity + + +class Device(Entity): + __tablename__ = 'device' + + id = Column(Integer, ForeignKey(Entity.id), primary_key=True) + + __mapper_args__ = { + 'polymorphic_identity': __tablename__, + } + diff --git a/platypush/entities/lights.py b/platypush/entities/lights.py new file mode 100644 index 000000000..95f303f90 --- /dev/null +++ b/platypush/entities/lights.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, ForeignKey + +from .devices import Device + + +class Light(Device): + __tablename__ = 'light' + + id = Column(Integer, ForeignKey(Device.id), primary_key=True) + + __mapper_args__ = { + 'polymorphic_identity': __tablename__, + } + diff --git a/platypush/entities/switches.py b/platypush/entities/switches.py new file mode 100644 index 000000000..4af4ba189 --- /dev/null +++ b/platypush/entities/switches.py @@ -0,0 +1,15 @@ +from sqlalchemy import Column, Integer, ForeignKey, Boolean + +from .devices import Device + + +class Switch(Device): + __tablename__ = 'switch' + + id = Column(Integer, ForeignKey(Device.id), primary_key=True) + state = Column(Boolean) + + __mapper_args__ = { + 'polymorphic_identity': __tablename__, + } + diff --git a/platypush/plugins/db/__init__.py b/platypush/plugins/db/__init__.py index f4594f089..ab901b2e8 100644 --- a/platypush/plugins/db/__init__.py +++ b/platypush/plugins/db/__init__.py @@ -1,11 +1,11 @@ -""" -.. moduleauthor:: Fabio Manganiello -""" - import time +from contextlib import contextmanager +from multiprocessing import RLock +from typing import Generator from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, sessionmaker, scoped_session from platypush.plugins import Plugin, action @@ -30,22 +30,23 @@ class DbPlugin(Plugin): """ super().__init__() - self.engine = self._get_engine(engine, *args, **kwargs) + self.engine = self.get_engine(engine, *args, **kwargs) + self._session_locks = {} - def _get_engine(self, engine=None, *args, **kwargs): + def get_engine(self, engine=None, *args, **kwargs) -> Engine: if engine: if isinstance(engine, Engine): return engine if engine.startswith('sqlite://'): kwargs['connect_args'] = {'check_same_thread': False} - return create_engine(engine, *args, **kwargs) + return create_engine(engine, *args, **kwargs) # type: ignore + assert self.engine return self.engine - # noinspection PyUnusedLocal @staticmethod - def _build_condition(table, column, value): + def _build_condition(_, column, value): if isinstance(value, str): value = "'{}'".format(value) elif not isinstance(value, int) and not isinstance(value, float): @@ -73,14 +74,14 @@ class DbPlugin(Plugin): :param kwargs: Extra kwargs that will be passed to ``sqlalchemy.create_engine`` (seehttps:///docs.sqlalchemy.org/en/latest/core/engines.html) """ - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) with engine.connect() as connection: connection.execute(statement) def _get_table(self, table, engine=None, *args, **kwargs): if not engine: - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) db_ok = False n_tries = 0 @@ -98,7 +99,7 @@ class DbPlugin(Plugin): self.logger.exception(e) self.logger.info('Waiting {} seconds before retrying'.format(wait_time)) time.sleep(wait_time) - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) if not db_ok and last_error: raise last_error @@ -163,7 +164,7 @@ class DbPlugin(Plugin): ] """ - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) if table: table, engine = self._get_table(table, engine=engine, *args, **kwargs) @@ -234,7 +235,7 @@ class DbPlugin(Plugin): if key_columns is None: key_columns = [] - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) for record in records: table, engine = self._get_table(table, engine=engine, *args, **kwargs) @@ -293,7 +294,7 @@ class DbPlugin(Plugin): } """ - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) for record in records: table, engine = self._get_table(table, engine=engine, *args, **kwargs) @@ -341,7 +342,7 @@ class DbPlugin(Plugin): } """ - engine = self._get_engine(engine, *args, **kwargs) + engine = self.get_engine(engine, *args, **kwargs) for record in records: table, engine = self._get_table(table, engine=engine, *args, **kwargs) @@ -352,5 +353,22 @@ class DbPlugin(Plugin): engine.execute(delete) + def create_all(self, engine, base): + self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) + with self._session_locks[engine.url]: + base.metadata.create_all(engine) + + @contextmanager + def get_session(self, engine=None, *args, **kwargs) -> Generator[Session, None, None]: + engine = self.get_engine(engine, *args, **kwargs) + self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock()) + with self._session_locks[engine.url]: + session = scoped_session(sessionmaker(expire_on_commit=False)) + session.configure(bind=engine) + s = session() + yield s + s.commit() + s.close() + # vim:sw=4:ts=4:et: diff --git a/platypush/plugins/switch/__init__.py b/platypush/plugins/switch/__init__.py index e73513d90..4b33f7e75 100644 --- a/platypush/plugins/switch/__init__.py +++ b/platypush/plugins/switch/__init__.py @@ -1,9 +1,12 @@ from abc import ABC, abstractmethod from typing import List, Union +from platypush.entities import manages +from platypush.entities.switches import Switch from platypush.plugins import Plugin, action +@manages(Switch) class SwitchPlugin(Plugin, ABC): """ Abstract class for interacting with switch devices @@ -46,7 +49,7 @@ class SwitchPlugin(Plugin, ABC): return devices @action - def status(self, device=None, *args, **kwargs) -> Union[dict, List[dict]]: + def status(self, device=None, *_, **__) -> Union[dict, List[dict]]: """ Get the status of all the devices, or filter by device name or ID (alias for :meth:`.switch_status`). diff --git a/platypush/plugins/switch/tplink/__init__.py b/platypush/plugins/switch/tplink/__init__.py index 47470d357..cab5a6c13 100644 --- a/platypush/plugins/switch/tplink/__init__.py +++ b/platypush/plugins/switch/tplink/__init__.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, List +from typing import Union, Mapping, List, Collection, Optional from pyHS100 import SmartDevice, SmartPlug, SmartBulb, SmartStrip, Discover, SmartDeviceException @@ -20,8 +20,12 @@ class SwitchTplinkPlugin(SwitchPlugin): _ip_to_dev = {} _alias_to_dev = {} - def __init__(self, plugs: Union[Dict[str, str], List[str]] = None, bulbs: Union[Dict[str, str], List[str]] = None, - strips: Union[Dict[str, str], List[str]] = None, **kwargs): + def __init__( + self, + plugs: Optional[Union[Mapping[str, str], List[str]]] = None, + bulbs: Optional[Union[Mapping[str, str], List[str]]] = None, + strips: Optional[Union[Mapping[str, str], List[str]]] = None, **kwargs + ): """ :param plugs: Optional list of IP addresses or name->address mapping if you have a static list of TpLink plugs and you want to save on the scan time. @@ -62,7 +66,7 @@ class SwitchTplinkPlugin(SwitchPlugin): self._update_devices() - def _update_devices(self, devices: Dict[str, SmartDevice] = None): + def _update_devices(self, devices: Optional[Mapping[str, SmartDevice]] = None): for (addr, info) in self._static_devices.items(): try: dev = info['type'](addr) @@ -75,6 +79,26 @@ class SwitchTplinkPlugin(SwitchPlugin): self._ip_to_dev[ip] = dev self._alias_to_dev[dev.alias] = dev + if devices: + self.publish_entities(devices.values()) # type: ignore + + def transform_entities(self, devices: Collection[SmartDevice]): + from platypush.entities.switches import Switch + return super().transform_entities([ # type: ignore + Switch( + id=dev.host, + name=dev.alias, + state=dev.is_on, + data={ + 'current_consumption': dev.current_consumption(), + 'ip': dev.host, + 'host': dev.host, + 'hw_info': dev.hw_info, + } + ) + for dev in (devices or []) + ]) + def _scan(self): devices = Discover.discover() self._update_devices(devices) @@ -95,8 +119,15 @@ class SwitchTplinkPlugin(SwitchPlugin): else: raise RuntimeError('Device {} not found'.format(device)) + def _set(self, device: SmartDevice, state: bool): + action_name = 'turn_on' if state else 'turn_off' + action = getattr(device, action_name) + action() + self.publish_entities([device]) # type: ignore + return self._serialize(device) + @action - def on(self, device, **kwargs): + def on(self, device, **_): """ Turn on a device @@ -105,11 +136,10 @@ class SwitchTplinkPlugin(SwitchPlugin): """ device = self._get_device(device) - device.turn_on() - return self.status(device) + return self._set(device, True) @action - def off(self, device, **kwargs): + def off(self, device, **_): """ Turn off a device @@ -118,11 +148,10 @@ class SwitchTplinkPlugin(SwitchPlugin): """ device = self._get_device(device) - device.turn_off() - return self.status(device) + return self._set(device, False) @action - def toggle(self, device, **kwargs): + def toggle(self, device, **_): """ Toggle the state of a device (on/off) @@ -131,12 +160,10 @@ class SwitchTplinkPlugin(SwitchPlugin): """ device = self._get_device(device) + return self._set(device, not device.is_on) - if device.is_on: - device.turn_off() - else: - device.turn_on() - + @staticmethod + def _serialize(device: SmartDevice) -> dict: return { 'current_consumption': device.current_consumption(), 'id': device.host, @@ -150,15 +177,8 @@ class SwitchTplinkPlugin(SwitchPlugin): @property def switches(self) -> List[dict]: return [ - { - 'current_consumption': dev.current_consumption(), - 'id': ip, - 'ip': ip, - 'host': dev.host, - 'hw_info': dev.hw_info, - 'name': dev.alias, - 'on': dev.is_on, - } for (ip, dev) in self._scan().items() + self._serialize(dev) + for dev in self._scan().values() ] diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index ccc2040aa..673497b7f 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -13,7 +13,7 @@ except ImportError: from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from sqlalchemy import Column, Integer, String, DateTime, ForeignKey -from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy.orm import make_transient from sqlalchemy.ext.declarative import declarative_base from platypush.context import get_plugin @@ -28,126 +28,124 @@ class UserManager: Main class for managing platform users """ - # noinspection PyProtectedMember def __init__(self): - db_plugin = get_plugin('db') - if not db_plugin: - raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support') + self.db = get_plugin('db') + assert self.db + self._engine = self.db.get_engine() + self.db.create_all(self._engine, Base) - self._engine = db_plugin._get_engine() - - def get_user(self, username): - session = self._get_db_session() - user = self._get_user(session, username) - if not user: - return None - - # Hide password + @staticmethod + def _mask_password(user): + make_transient(user) user.password = None return user + def get_user(self, username): + with self.db.get_session() as session: + user = self._get_user(session, username) + if not user: + return None + + session.expunge(user) + return self._mask_password(user) + def get_user_count(self): - session = self._get_db_session() - return session.query(User).count() + with self.db.get_session() as session: + return session.query(User).count() def get_users(self): - session = self._get_db_session() - return session.query(User) + with self.db.get_session() as session: + return session.query(User) def create_user(self, username, password, **kwargs): - session = self._get_db_session() if not username: raise ValueError('Invalid or empty username') if not password: raise ValueError('Please provide a password for the user') - user = self._get_user(session, username) - if user: - raise NameError('The user {} already exists'.format(username)) + with self.db.get_session() as session: + user = self._get_user(session, username) + if user: + raise NameError('The user {} already exists'.format(username)) - record = User(username=username, password=self._encrypt_password(password), - created_at=datetime.datetime.utcnow(), **kwargs) + record = User(username=username, password=self._encrypt_password(password), + created_at=datetime.datetime.utcnow(), **kwargs) - session.add(record) - session.commit() - user = self._get_user(session, username) + session.add(record) + session.commit() + user = self._get_user(session, username) - # Hide password - user.password = None - return user + return self._mask_password(user) def update_password(self, username, old_password, new_password): - session = self._get_db_session() - if not self._authenticate_user(session, username, old_password): - return False + with self.db.get_session() as session: + if not self._authenticate_user(session, username, old_password): + return False - user = self._get_user(session, username) - user.password = self._encrypt_password(new_password) - session.commit() - return True + user = self._get_user(session, username) + user.password = self._encrypt_password(new_password) + session.commit() + return True def authenticate_user(self, username, password): - session = self._get_db_session() - return self._authenticate_user(session, username, password) + with self.db.get_session() as session: + return self._authenticate_user(session, username, password) def authenticate_user_session(self, session_token): - session = self._get_db_session() - user_session = session.query(UserSession).filter_by(session_token=session_token).first() + with self.db.get_session() as session: + user_session = session.query(UserSession).filter_by(session_token=session_token).first() - if not user_session or ( - user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()): - return None, None + if not user_session or ( + user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()): + return None, None - user = session.query(User).filter_by(user_id=user_session.user_id).first() - - # Hide password - user.password = None - return user, session + user = session.query(User).filter_by(user_id=user_session.user_id).first() + return self._mask_password(user), user_session def delete_user(self, username): - session = self._get_db_session() - user = self._get_user(session, username) - if not user: - raise NameError('No such user: {}'.format(username)) + with self.db.get_session() as session: + user = self._get_user(session, username) + if not user: + raise NameError('No such user: {}'.format(username)) - user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all() - for user_session in user_sessions: - session.delete(user_session) + user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all() + for user_session in user_sessions: + session.delete(user_session) - session.delete(user) - session.commit() - return True + session.delete(user) + session.commit() + return True def delete_user_session(self, session_token): - session = self._get_db_session() - user_session = session.query(UserSession).filter_by(session_token=session_token).first() + with self.db.get_session() as session: + user_session = session.query(UserSession).filter_by(session_token=session_token).first() - if not user_session: - return False + if not user_session: + return False - session.delete(user_session) - session.commit() - return True + session.delete(user_session) + session.commit() + return True def create_user_session(self, username, password, expires_at=None): - session = self._get_db_session() - user = self._authenticate_user(session, username, password) - if not user: - return None + with self.db.get_session() as session: + user = self._authenticate_user(session, username, password) + if not user: + return None - if expires_at: - if isinstance(expires_at, int) or isinstance(expires_at, float): - expires_at = datetime.datetime.fromtimestamp(expires_at) - elif isinstance(expires_at, str): - expires_at = datetime.datetime.fromisoformat(expires_at) + if expires_at: + if isinstance(expires_at, int) or isinstance(expires_at, float): + expires_at = datetime.datetime.fromtimestamp(expires_at) + elif isinstance(expires_at, str): + expires_at = datetime.datetime.fromisoformat(expires_at) - user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(), - csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(), - expires_at=expires_at) + user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(), + csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(), + expires_at=expires_at) - session.add(user_session) - session.commit() - return user_session + session.add(user_session) + session.commit() + return user_session @staticmethod def _get_user(session, username): @@ -180,8 +178,8 @@ class UserManager: :param session_token: Session token. """ - session = self._get_db_session() - return session.query(User).join(UserSession).filter_by(session_token=session_token).first() + with self.db.get_session() as session: + return session.query(User).join(UserSession).filter_by(session_token=session_token).first() def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str: """ @@ -240,12 +238,6 @@ class UserManager: return payload - def _get_db_session(self): - Base.metadata.create_all(self._engine) - session = scoped_session(sessionmaker(expire_on_commit=False)) - session.configure(bind=self._engine) - return session() - def _authenticate_user(self, session, username, password): """ :return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.