Basic support for entities on the local db and implemented support for switch entities on the tplink plugin

This commit is contained in:
Fabio Manganiello 2022-04-04 16:50:17 +02:00
parent b1491b8048
commit 4ee7e4db29
Signed by: blacklight
GPG Key ID: D90FBA7F76362774
12 changed files with 506 additions and 134 deletions

View File

@ -9,11 +9,13 @@ import argparse
import logging import logging
import os import os
import sys import sys
from typing import Optional
from .bus.redis import RedisBus from .bus.redis import RedisBus
from .config import Config from .config import Config
from .context import register_backends, register_plugins from .context import register_backends, register_plugins
from .cron.scheduler import CronScheduler from .cron.scheduler import CronScheduler
from .entities import init_entities_engine, EntitiesEngine
from .event.processor import EventProcessor from .event.processor import EventProcessor
from .logger import Logger from .logger import Logger
from .message.event import Event from .message.event import Event
@ -86,6 +88,7 @@ class Daemon:
self.no_capture_stdout = no_capture_stdout self.no_capture_stdout = no_capture_stdout
self.no_capture_stderr = no_capture_stderr self.no_capture_stderr = no_capture_stderr
self.event_processor = EventProcessor() self.event_processor = EventProcessor()
self.entities_engine: Optional[EntitiesEngine] = None
self.requests_to_process = requests_to_process self.requests_to_process = requests_to_process
self.processed_requests = 0 self.processed_requests = 0
self.cron_scheduler = None self.cron_scheduler = None
@ -161,16 +164,25 @@ class Daemon:
""" Stops the backends and the bus """ """ Stops the backends and the bus """
from .plugins import RunnablePlugin from .plugins import RunnablePlugin
for backend in self.backends.values(): if self.backends:
backend.stop() for backend in self.backends.values():
backend.stop()
for plugin in get_enabled_plugins().values(): for plugin in get_enabled_plugins().values():
if isinstance(plugin, RunnablePlugin): if isinstance(plugin, RunnablePlugin):
plugin.stop() plugin.stop()
self.bus.stop() if self.bus:
self.bus.stop()
self.bus = None
if self.cron_scheduler: if self.cron_scheduler:
self.cron_scheduler.stop() self.cron_scheduler.stop()
self.cron_scheduler = None
if self.entities_engine:
self.entities_engine.stop()
self.entities_engine = None
def run(self): def run(self):
""" Start the daemon """ """ Start the daemon """
@ -192,6 +204,9 @@ class Daemon:
# Initialize the plugins # Initialize the plugins
register_plugins(bus=self.bus) register_plugins(bus=self.bus)
# Initialize the entities engine
self.entities_engine = init_entities_engine()
# Start the cron scheduler # Start the cron scheduler
if Config.get_cronjobs(): if Config.get_cronjobs():
self.cron_scheduler = CronScheduler(jobs=Config.get_cronjobs()) self.cron_scheduler = CronScheduler(jobs=Config.get_cronjobs())

View File

@ -0,0 +1,36 @@
import warnings
from typing import Collection, Optional
from ._base import Entity
from ._engine import EntitiesEngine
from ._registry import manages, register_entity_plugin, get_plugin_registry
_engine: Optional[EntitiesEngine] = None
def init_entities_engine() -> EntitiesEngine:
from ._base import init_entities_db
global _engine
init_entities_db()
_engine = EntitiesEngine()
_engine.start()
return _engine
def publish_entities(entities: Collection[Entity]):
if not _engine:
warnings.warn('No entities engine registered')
return
_engine.post(*entities)
__all__ = (
'Entity',
'EntitiesEngine',
'init_entities_engine',
'publish_entities',
'register_entity_plugin',
'get_plugin_registry',
'manages',
)

View File

@ -0,0 +1,73 @@
import inspect
import pathlib
from datetime import datetime
from typing import Mapping, Type
import pkgutil
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {}
class Entity(Base):
"""
Model for a general-purpose platform entity
"""
__tablename__ = 'entity'
id = Column(Integer, autoincrement=True, primary_key=True)
external_id = Column(String, nullable=True)
name = Column(String, nullable=False, index=True)
type = Column(String, nullable=False, index=True)
plugin = Column(String, nullable=False)
data = Column(JSON, default=dict)
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
UniqueConstraint(external_id, plugin)
__table_args__ = (
Index(name, plugin),
)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
'polymorphic_on': type,
}
def _discover_entity_types():
from platypush.context import get_plugin
logger = get_plugin('logger')
assert logger
for loader, modname, _ in pkgutil.walk_packages(
path=[str(pathlib.Path(__file__).parent.absolute())],
prefix=__package__ + '.',
onerror=lambda _: None
):
try:
mod_loader = loader.find_module(modname) # type: ignore
assert mod_loader
module = mod_loader.load_module() # type: ignore
except Exception as e:
logger.warning(f'Could not import module {modname}')
logger.exception(e)
continue
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Entity):
entities_registry[obj] = {}
def init_entities_db():
from platypush.context import get_plugin
_discover_entity_types()
db = get_plugin('db')
assert db
engine = db.get_engine()
db.create_all(engine, Base)

View File

@ -0,0 +1,110 @@
from logging import getLogger
from queue import Queue, Empty
from threading import Thread, Event
from time import time
from typing import Iterable, List
from sqlalchemy import and_, or_, inspect as schema_inspect
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import Null
from ._base import Entity
class EntitiesEngine(Thread):
# Processing queue timeout in seconds
_queue_timeout = 5.
def __init__(self):
obj_name = self.__class__.__name__
super().__init__(name=obj_name)
self.logger = getLogger(name=obj_name)
self._queue = Queue()
self._should_stop = Event()
def post(self, *entities: Entity):
for entity in entities:
self._queue.put(entity)
@property
def should_stop(self) -> bool:
return self._should_stop.is_set()
def stop(self):
self._should_stop.set()
def run(self):
super().run()
self.logger.info('Started entities engine')
while not self.should_stop:
msgs = []
last_poll_time = time()
while not self.should_stop and (
time() - last_poll_time < self._queue_timeout):
try:
msg = self._queue.get(block=True, timeout=0.5)
except Empty:
continue
if msg:
msgs.append(msg)
if not msgs or self.should_stop:
continue
self._process_entities(*msgs)
self.logger.info('Stopped entities engine')
def _get_if_exist(self, session: Session, entities: Iterable[Entity]) -> Iterable[Entity]:
existing_entities = {
(entity.external_id or entity.name, entity.plugin): entity
for entity in session.query(Entity).filter(
or_(*[
and_(Entity.external_id == entity.external_id, Entity.plugin == entity.plugin)
if entity.external_id is not None else
and_(Entity.name == entity.name, Entity.plugin == entity.plugin)
for entity in entities
])
).all()
}
return [
existing_entities.get(
(entity.external_id or entity.name, entity.plugin), None
) for entity in entities
]
def _merge_entities(
self, entities: List[Entity],
existing_entities: List[Entity]
) -> List[Entity]:
new_entities = []
for i, entity in enumerate(entities):
existing_entity = existing_entities[i]
if existing_entity:
inspector = schema_inspect(entity.__class__)
columns = [col.key for col in inspector.mapper.column_attrs]
for col in columns:
new_value = getattr(entity, col)
if new_value is not None and new_value.__class__ != Null:
setattr(existing_entity, col, getattr(entity, col))
new_entities.append(existing_entity)
else:
new_entities.append(entity)
return new_entities
def _process_entities(self, *entities: Entity):
from platypush.context import get_plugin
with get_plugin('db').get_session() as session: # type: ignore
existing_entities = self._get_if_exist(session, entities)
entities = self._merge_entities(entities, existing_entities) # type: ignore
session.add_all(entities)
session.commit()

View File

@ -0,0 +1,62 @@
from datetime import datetime
from typing import Optional, Mapping, Dict, Collection, Type
from platypush.plugins import Plugin
from platypush.utils import get_plugin_name_by_class
from ._base import Entity
_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {}
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
plugins = _entity_plugin_registry.get(entity_type, {})
plugin_name = get_plugin_name_by_class(plugin.__class__)
assert plugin_name
plugins[plugin_name] = plugin
_entity_plugin_registry[entity_type] = plugins
def get_plugin_registry():
return _entity_plugin_registry.copy()
class EntityManagerMixin:
def transform_entities(self, entities):
entities = entities or []
for entity in entities:
if entity.id:
# Entity IDs can only refer to the internal primary key
entity.external_id = entity.id
entity.id = None # type: ignore
entity.plugin = get_plugin_name_by_class(self.__class__) # type: ignore
entity.updated_at = datetime.utcnow()
return entities
def publish_entities(self, entities: Optional[Collection[Entity]]):
from . import publish_entities
entities = self.transform_entities(entities)
publish_entities(entities)
def manages(*entities: Type[Entity]):
def wrapper(plugin: Type[Plugin]):
init = plugin.__init__
def __init__(self, *args, **kwargs):
for entity_type in entities:
register_entity_plugin(entity_type, self)
init(self, *args, **kwargs)
plugin.__init__ = __init__
# Inject the EntityManagerMixin
if EntityManagerMixin not in plugin.__bases__:
plugin.__bases__ = (EntityManagerMixin,) + plugin.__bases__
return plugin
return wrapper

View File

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, ForeignKey
from ._base import Entity
class Device(Entity):
__tablename__ = 'device'
id = Column(Integer, ForeignKey(Entity.id), primary_key=True)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View File

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, ForeignKey
from .devices import Device
class Light(Device):
__tablename__ = 'light'
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View File

@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, ForeignKey, Boolean
from .devices import Device
class Switch(Device):
__tablename__ = 'switch'
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
state = Column(Boolean)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View File

@ -1,11 +1,11 @@
"""
.. moduleauthor:: Fabio Manganiello <blacklight86@gmail.com>
"""
import time import time
from contextlib import contextmanager
from multiprocessing import RLock
from typing import Generator
from sqlalchemy import create_engine, Table, MetaData from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker, scoped_session
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
@ -30,22 +30,23 @@ class DbPlugin(Plugin):
""" """
super().__init__() super().__init__()
self.engine = self._get_engine(engine, *args, **kwargs) self.engine = self.get_engine(engine, *args, **kwargs)
self._session_locks = {}
def _get_engine(self, engine=None, *args, **kwargs): def get_engine(self, engine=None, *args, **kwargs) -> Engine:
if engine: if engine:
if isinstance(engine, Engine): if isinstance(engine, Engine):
return engine return engine
if engine.startswith('sqlite://'): if engine.startswith('sqlite://'):
kwargs['connect_args'] = {'check_same_thread': False} kwargs['connect_args'] = {'check_same_thread': False}
return create_engine(engine, *args, **kwargs) return create_engine(engine, *args, **kwargs) # type: ignore
assert self.engine
return self.engine return self.engine
# noinspection PyUnusedLocal
@staticmethod @staticmethod
def _build_condition(table, column, value): def _build_condition(_, column, value):
if isinstance(value, str): if isinstance(value, str):
value = "'{}'".format(value) value = "'{}'".format(value)
elif not isinstance(value, int) and not isinstance(value, float): elif not isinstance(value, int) and not isinstance(value, float):
@ -73,14 +74,14 @@ class DbPlugin(Plugin):
:param kwargs: Extra kwargs that will be passed to ``sqlalchemy.create_engine`` (seehttps:///docs.sqlalchemy.org/en/latest/core/engines.html) :param kwargs: Extra kwargs that will be passed to ``sqlalchemy.create_engine`` (seehttps:///docs.sqlalchemy.org/en/latest/core/engines.html)
""" """
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection: with engine.connect() as connection:
connection.execute(statement) connection.execute(statement)
def _get_table(self, table, engine=None, *args, **kwargs): def _get_table(self, table, engine=None, *args, **kwargs):
if not engine: if not engine:
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
db_ok = False db_ok = False
n_tries = 0 n_tries = 0
@ -98,7 +99,7 @@ class DbPlugin(Plugin):
self.logger.exception(e) self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time)) self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
time.sleep(wait_time) time.sleep(wait_time)
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error: if not db_ok and last_error:
raise last_error raise last_error
@ -163,7 +164,7 @@ class DbPlugin(Plugin):
] ]
""" """
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
if table: if table:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -234,7 +235,7 @@ class DbPlugin(Plugin):
if key_columns is None: if key_columns is None:
key_columns = [] key_columns = []
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -293,7 +294,7 @@ class DbPlugin(Plugin):
} }
""" """
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -341,7 +342,7 @@ class DbPlugin(Plugin):
} }
""" """
engine = self._get_engine(engine, *args, **kwargs) engine = self.get_engine(engine, *args, **kwargs)
for record in records: for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs) table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -352,5 +353,22 @@ class DbPlugin(Plugin):
engine.execute(delete) engine.execute(delete)
def create_all(self, engine, base):
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
with self._session_locks[engine.url]:
base.metadata.create_all(engine)
@contextmanager
def get_session(self, engine=None, *args, **kwargs) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs)
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
with self._session_locks[engine.url]:
session = scoped_session(sessionmaker(expire_on_commit=False))
session.configure(bind=engine)
s = session()
yield s
s.commit()
s.close()
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View File

@ -1,9 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Union from typing import List, Union
from platypush.entities import manages
from platypush.entities.switches import Switch
from platypush.plugins import Plugin, action from platypush.plugins import Plugin, action
@manages(Switch)
class SwitchPlugin(Plugin, ABC): class SwitchPlugin(Plugin, ABC):
""" """
Abstract class for interacting with switch devices Abstract class for interacting with switch devices
@ -46,7 +49,7 @@ class SwitchPlugin(Plugin, ABC):
return devices return devices
@action @action
def status(self, device=None, *args, **kwargs) -> Union[dict, List[dict]]: def status(self, device=None, *_, **__) -> Union[dict, List[dict]]:
""" """
Get the status of all the devices, or filter by device name or ID (alias for :meth:`.switch_status`). Get the status of all the devices, or filter by device name or ID (alias for :meth:`.switch_status`).

View File

@ -1,4 +1,4 @@
from typing import Union, Dict, List from typing import Union, Mapping, List, Collection, Optional
from pyHS100 import SmartDevice, SmartPlug, SmartBulb, SmartStrip, Discover, SmartDeviceException from pyHS100 import SmartDevice, SmartPlug, SmartBulb, SmartStrip, Discover, SmartDeviceException
@ -20,8 +20,12 @@ class SwitchTplinkPlugin(SwitchPlugin):
_ip_to_dev = {} _ip_to_dev = {}
_alias_to_dev = {} _alias_to_dev = {}
def __init__(self, plugs: Union[Dict[str, str], List[str]] = None, bulbs: Union[Dict[str, str], List[str]] = None, def __init__(
strips: Union[Dict[str, str], List[str]] = None, **kwargs): self,
plugs: Optional[Union[Mapping[str, str], List[str]]] = None,
bulbs: Optional[Union[Mapping[str, str], List[str]]] = None,
strips: Optional[Union[Mapping[str, str], List[str]]] = None, **kwargs
):
""" """
:param plugs: Optional list of IP addresses or name->address mapping if you have a static list of :param plugs: Optional list of IP addresses or name->address mapping if you have a static list of
TpLink plugs and you want to save on the scan time. TpLink plugs and you want to save on the scan time.
@ -62,7 +66,7 @@ class SwitchTplinkPlugin(SwitchPlugin):
self._update_devices() self._update_devices()
def _update_devices(self, devices: Dict[str, SmartDevice] = None): def _update_devices(self, devices: Optional[Mapping[str, SmartDevice]] = None):
for (addr, info) in self._static_devices.items(): for (addr, info) in self._static_devices.items():
try: try:
dev = info['type'](addr) dev = info['type'](addr)
@ -75,6 +79,26 @@ class SwitchTplinkPlugin(SwitchPlugin):
self._ip_to_dev[ip] = dev self._ip_to_dev[ip] = dev
self._alias_to_dev[dev.alias] = dev self._alias_to_dev[dev.alias] = dev
if devices:
self.publish_entities(devices.values()) # type: ignore
def transform_entities(self, devices: Collection[SmartDevice]):
from platypush.entities.switches import Switch
return super().transform_entities([ # type: ignore
Switch(
id=dev.host,
name=dev.alias,
state=dev.is_on,
data={
'current_consumption': dev.current_consumption(),
'ip': dev.host,
'host': dev.host,
'hw_info': dev.hw_info,
}
)
for dev in (devices or [])
])
def _scan(self): def _scan(self):
devices = Discover.discover() devices = Discover.discover()
self._update_devices(devices) self._update_devices(devices)
@ -95,8 +119,15 @@ class SwitchTplinkPlugin(SwitchPlugin):
else: else:
raise RuntimeError('Device {} not found'.format(device)) raise RuntimeError('Device {} not found'.format(device))
def _set(self, device: SmartDevice, state: bool):
action_name = 'turn_on' if state else 'turn_off'
action = getattr(device, action_name)
action()
self.publish_entities([device]) # type: ignore
return self._serialize(device)
@action @action
def on(self, device, **kwargs): def on(self, device, **_):
""" """
Turn on a device Turn on a device
@ -105,11 +136,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
""" """
device = self._get_device(device) device = self._get_device(device)
device.turn_on() return self._set(device, True)
return self.status(device)
@action @action
def off(self, device, **kwargs): def off(self, device, **_):
""" """
Turn off a device Turn off a device
@ -118,11 +148,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
""" """
device = self._get_device(device) device = self._get_device(device)
device.turn_off() return self._set(device, False)
return self.status(device)
@action @action
def toggle(self, device, **kwargs): def toggle(self, device, **_):
""" """
Toggle the state of a device (on/off) Toggle the state of a device (on/off)
@ -131,12 +160,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
""" """
device = self._get_device(device) device = self._get_device(device)
return self._set(device, not device.is_on)
if device.is_on: @staticmethod
device.turn_off() def _serialize(device: SmartDevice) -> dict:
else:
device.turn_on()
return { return {
'current_consumption': device.current_consumption(), 'current_consumption': device.current_consumption(),
'id': device.host, 'id': device.host,
@ -150,15 +177,8 @@ class SwitchTplinkPlugin(SwitchPlugin):
@property @property
def switches(self) -> List[dict]: def switches(self) -> List[dict]:
return [ return [
{ self._serialize(dev)
'current_consumption': dev.current_consumption(), for dev in self._scan().values()
'id': ip,
'ip': ip,
'host': dev.host,
'hw_info': dev.hw_info,
'name': dev.alias,
'on': dev.is_on,
} for (ip, dev) in self._scan().items()
] ]

View File

@ -13,7 +13,7 @@ except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import make_transient
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin from platypush.context import get_plugin
@ -28,126 +28,124 @@ class UserManager:
Main class for managing platform users Main class for managing platform users
""" """
# noinspection PyProtectedMember
def __init__(self): def __init__(self):
db_plugin = get_plugin('db') self.db = get_plugin('db')
if not db_plugin: assert self.db
raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support') self._engine = self.db.get_engine()
self.db.create_all(self._engine, Base)
self._engine = db_plugin._get_engine() @staticmethod
def _mask_password(user):
def get_user(self, username): make_transient(user)
session = self._get_db_session()
user = self._get_user(session, username)
if not user:
return None
# Hide password
user.password = None user.password = None
return user return user
def get_user(self, username):
with self.db.get_session() as session:
user = self._get_user(session, username)
if not user:
return None
session.expunge(user)
return self._mask_password(user)
def get_user_count(self): def get_user_count(self):
session = self._get_db_session() with self.db.get_session() as session:
return session.query(User).count() return session.query(User).count()
def get_users(self): def get_users(self):
session = self._get_db_session() with self.db.get_session() as session:
return session.query(User) return session.query(User)
def create_user(self, username, password, **kwargs): def create_user(self, username, password, **kwargs):
session = self._get_db_session()
if not username: if not username:
raise ValueError('Invalid or empty username') raise ValueError('Invalid or empty username')
if not password: if not password:
raise ValueError('Please provide a password for the user') raise ValueError('Please provide a password for the user')
user = self._get_user(session, username) with self.db.get_session() as session:
if user: user = self._get_user(session, username)
raise NameError('The user {} already exists'.format(username)) if user:
raise NameError('The user {} already exists'.format(username))
record = User(username=username, password=self._encrypt_password(password), record = User(username=username, password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(), **kwargs) created_at=datetime.datetime.utcnow(), **kwargs)
session.add(record) session.add(record)
session.commit() session.commit()
user = self._get_user(session, username) user = self._get_user(session, username)
# Hide password return self._mask_password(user)
user.password = None
return user
def update_password(self, username, old_password, new_password): def update_password(self, username, old_password, new_password):
session = self._get_db_session() with self.db.get_session() as session:
if not self._authenticate_user(session, username, old_password): if not self._authenticate_user(session, username, old_password):
return False return False
user = self._get_user(session, username) user = self._get_user(session, username)
user.password = self._encrypt_password(new_password) user.password = self._encrypt_password(new_password)
session.commit() session.commit()
return True return True
def authenticate_user(self, username, password): def authenticate_user(self, username, password):
session = self._get_db_session() with self.db.get_session() as session:
return self._authenticate_user(session, username, password) return self._authenticate_user(session, username, password)
def authenticate_user_session(self, session_token): def authenticate_user_session(self, session_token):
session = self._get_db_session() with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first() user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session or ( if not user_session or (
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()): user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
return None, None return None, None
user = session.query(User).filter_by(user_id=user_session.user_id).first() user = session.query(User).filter_by(user_id=user_session.user_id).first()
return self._mask_password(user), user_session
# Hide password
user.password = None
return user, session
def delete_user(self, username): def delete_user(self, username):
session = self._get_db_session() with self.db.get_session() as session:
user = self._get_user(session, username) user = self._get_user(session, username)
if not user: if not user:
raise NameError('No such user: {}'.format(username)) raise NameError('No such user: {}'.format(username))
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all() user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
for user_session in user_sessions: for user_session in user_sessions:
session.delete(user_session) session.delete(user_session)
session.delete(user) session.delete(user)
session.commit() session.commit()
return True return True
def delete_user_session(self, session_token): def delete_user_session(self, session_token):
session = self._get_db_session() with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first() user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session: if not user_session:
return False return False
session.delete(user_session) session.delete(user_session)
session.commit() session.commit()
return True return True
def create_user_session(self, username, password, expires_at=None): def create_user_session(self, username, password, expires_at=None):
session = self._get_db_session() with self.db.get_session() as session:
user = self._authenticate_user(session, username, password) user = self._authenticate_user(session, username, password)
if not user: if not user:
return None return None
if expires_at: if expires_at:
if isinstance(expires_at, int) or isinstance(expires_at, float): if isinstance(expires_at, int) or isinstance(expires_at, float):
expires_at = datetime.datetime.fromtimestamp(expires_at) expires_at = datetime.datetime.fromtimestamp(expires_at)
elif isinstance(expires_at, str): elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at) expires_at = datetime.datetime.fromisoformat(expires_at)
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(), user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(), csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
expires_at=expires_at) expires_at=expires_at)
session.add(user_session) session.add(user_session)
session.commit() session.commit()
return user_session return user_session
@staticmethod @staticmethod
def _get_user(session, username): def _get_user(session, username):
@ -180,8 +178,8 @@ class UserManager:
:param session_token: Session token. :param session_token: Session token.
""" """
session = self._get_db_session() with self.db.get_session() as session:
return session.query(User).join(UserSession).filter_by(session_token=session_token).first() return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str: def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
""" """
@ -240,12 +238,6 @@ class UserManager:
return payload return payload
def _get_db_session(self):
Base.metadata.create_all(self._engine)
session = scoped_session(sessionmaker(expire_on_commit=False))
session.configure(bind=self._engine)
return session()
def _authenticate_user(self, session, username, password): def _authenticate_user(self, session, username, password):
""" """
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise. :return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.