forked from platypush/platypush
Basic support for entities on the local db and implemented support for switch entities on the tplink plugin
This commit is contained in:
parent
b1491b8048
commit
4ee7e4db29
12 changed files with 506 additions and 134 deletions
|
@ -9,11 +9,13 @@ import argparse
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from .bus.redis import RedisBus
|
||||
from .config import Config
|
||||
from .context import register_backends, register_plugins
|
||||
from .cron.scheduler import CronScheduler
|
||||
from .entities import init_entities_engine, EntitiesEngine
|
||||
from .event.processor import EventProcessor
|
||||
from .logger import Logger
|
||||
from .message.event import Event
|
||||
|
@ -86,6 +88,7 @@ class Daemon:
|
|||
self.no_capture_stdout = no_capture_stdout
|
||||
self.no_capture_stderr = no_capture_stderr
|
||||
self.event_processor = EventProcessor()
|
||||
self.entities_engine: Optional[EntitiesEngine] = None
|
||||
self.requests_to_process = requests_to_process
|
||||
self.processed_requests = 0
|
||||
self.cron_scheduler = None
|
||||
|
@ -161,16 +164,25 @@ class Daemon:
|
|||
""" Stops the backends and the bus """
|
||||
from .plugins import RunnablePlugin
|
||||
|
||||
for backend in self.backends.values():
|
||||
backend.stop()
|
||||
if self.backends:
|
||||
for backend in self.backends.values():
|
||||
backend.stop()
|
||||
|
||||
for plugin in get_enabled_plugins().values():
|
||||
if isinstance(plugin, RunnablePlugin):
|
||||
plugin.stop()
|
||||
|
||||
self.bus.stop()
|
||||
if self.bus:
|
||||
self.bus.stop()
|
||||
self.bus = None
|
||||
|
||||
if self.cron_scheduler:
|
||||
self.cron_scheduler.stop()
|
||||
self.cron_scheduler = None
|
||||
|
||||
if self.entities_engine:
|
||||
self.entities_engine.stop()
|
||||
self.entities_engine = None
|
||||
|
||||
def run(self):
|
||||
""" Start the daemon """
|
||||
|
@ -192,6 +204,9 @@ class Daemon:
|
|||
# Initialize the plugins
|
||||
register_plugins(bus=self.bus)
|
||||
|
||||
# Initialize the entities engine
|
||||
self.entities_engine = init_entities_engine()
|
||||
|
||||
# Start the cron scheduler
|
||||
if Config.get_cronjobs():
|
||||
self.cron_scheduler = CronScheduler(jobs=Config.get_cronjobs())
|
||||
|
|
36
platypush/entities/__init__.py
Normal file
36
platypush/entities/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import warnings
|
||||
from typing import Collection, Optional
|
||||
|
||||
from ._base import Entity
|
||||
from ._engine import EntitiesEngine
|
||||
from ._registry import manages, register_entity_plugin, get_plugin_registry
|
||||
|
||||
_engine: Optional[EntitiesEngine] = None
|
||||
|
||||
|
||||
def init_entities_engine() -> EntitiesEngine:
|
||||
from ._base import init_entities_db
|
||||
global _engine
|
||||
init_entities_db()
|
||||
_engine = EntitiesEngine()
|
||||
_engine.start()
|
||||
return _engine
|
||||
|
||||
|
||||
def publish_entities(entities: Collection[Entity]):
|
||||
if not _engine:
|
||||
warnings.warn('No entities engine registered')
|
||||
return
|
||||
|
||||
_engine.post(*entities)
|
||||
|
||||
__all__ = (
|
||||
'Entity',
|
||||
'EntitiesEngine',
|
||||
'init_entities_engine',
|
||||
'publish_entities',
|
||||
'register_entity_plugin',
|
||||
'get_plugin_registry',
|
||||
'manages',
|
||||
)
|
||||
|
73
platypush/entities/_base.py
Normal file
73
platypush/entities/_base.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import inspect
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
from typing import Mapping, Type
|
||||
|
||||
import pkgutil
|
||||
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
entities_registry: Mapping[Type['Entity'], Mapping] = {}
|
||||
|
||||
|
||||
class Entity(Base):
|
||||
"""
|
||||
Model for a general-purpose platform entity
|
||||
"""
|
||||
|
||||
__tablename__ = 'entity'
|
||||
|
||||
id = Column(Integer, autoincrement=True, primary_key=True)
|
||||
external_id = Column(String, nullable=True)
|
||||
name = Column(String, nullable=False, index=True)
|
||||
type = Column(String, nullable=False, index=True)
|
||||
plugin = Column(String, nullable=False)
|
||||
data = Column(JSON, default=dict)
|
||||
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
|
||||
|
||||
UniqueConstraint(external_id, plugin)
|
||||
|
||||
__table_args__ = (
|
||||
Index(name, plugin),
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': __tablename__,
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
|
||||
|
||||
def _discover_entity_types():
|
||||
from platypush.context import get_plugin
|
||||
logger = get_plugin('logger')
|
||||
assert logger
|
||||
|
||||
for loader, modname, _ in pkgutil.walk_packages(
|
||||
path=[str(pathlib.Path(__file__).parent.absolute())],
|
||||
prefix=__package__ + '.',
|
||||
onerror=lambda _: None
|
||||
):
|
||||
try:
|
||||
mod_loader = loader.find_module(modname) # type: ignore
|
||||
assert mod_loader
|
||||
module = mod_loader.load_module() # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not import module {modname}')
|
||||
logger.exception(e)
|
||||
continue
|
||||
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, Entity):
|
||||
entities_registry[obj] = {}
|
||||
|
||||
|
||||
def init_entities_db():
|
||||
from platypush.context import get_plugin
|
||||
_discover_entity_types()
|
||||
db = get_plugin('db')
|
||||
assert db
|
||||
engine = db.get_engine()
|
||||
db.create_all(engine, Base)
|
||||
|
110
platypush/entities/_engine.py
Normal file
110
platypush/entities/_engine.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from logging import getLogger
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread, Event
|
||||
from time import time
|
||||
from typing import Iterable, List
|
||||
|
||||
from sqlalchemy import and_, or_, inspect as schema_inspect
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.elements import Null
|
||||
|
||||
from ._base import Entity
|
||||
|
||||
|
||||
class EntitiesEngine(Thread):
|
||||
# Processing queue timeout in seconds
|
||||
_queue_timeout = 5.
|
||||
|
||||
def __init__(self):
|
||||
obj_name = self.__class__.__name__
|
||||
super().__init__(name=obj_name)
|
||||
self.logger = getLogger(name=obj_name)
|
||||
self._queue = Queue()
|
||||
self._should_stop = Event()
|
||||
|
||||
def post(self, *entities: Entity):
|
||||
for entity in entities:
|
||||
self._queue.put(entity)
|
||||
|
||||
@property
|
||||
def should_stop(self) -> bool:
|
||||
return self._should_stop.is_set()
|
||||
|
||||
def stop(self):
|
||||
self._should_stop.set()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
self.logger.info('Started entities engine')
|
||||
|
||||
while not self.should_stop:
|
||||
msgs = []
|
||||
last_poll_time = time()
|
||||
|
||||
while not self.should_stop and (
|
||||
time() - last_poll_time < self._queue_timeout):
|
||||
try:
|
||||
msg = self._queue.get(block=True, timeout=0.5)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
|
||||
if not msgs or self.should_stop:
|
||||
continue
|
||||
|
||||
self._process_entities(*msgs)
|
||||
|
||||
self.logger.info('Stopped entities engine')
|
||||
|
||||
def _get_if_exist(self, session: Session, entities: Iterable[Entity]) -> Iterable[Entity]:
|
||||
existing_entities = {
|
||||
(entity.external_id or entity.name, entity.plugin): entity
|
||||
for entity in session.query(Entity).filter(
|
||||
or_(*[
|
||||
and_(Entity.external_id == entity.external_id, Entity.plugin == entity.plugin)
|
||||
if entity.external_id is not None else
|
||||
and_(Entity.name == entity.name, Entity.plugin == entity.plugin)
|
||||
for entity in entities
|
||||
])
|
||||
).all()
|
||||
}
|
||||
|
||||
return [
|
||||
existing_entities.get(
|
||||
(entity.external_id or entity.name, entity.plugin), None
|
||||
) for entity in entities
|
||||
]
|
||||
|
||||
def _merge_entities(
|
||||
self, entities: List[Entity],
|
||||
existing_entities: List[Entity]
|
||||
) -> List[Entity]:
|
||||
new_entities = []
|
||||
|
||||
for i, entity in enumerate(entities):
|
||||
existing_entity = existing_entities[i]
|
||||
if existing_entity:
|
||||
inspector = schema_inspect(entity.__class__)
|
||||
columns = [col.key for col in inspector.mapper.column_attrs]
|
||||
for col in columns:
|
||||
new_value = getattr(entity, col)
|
||||
if new_value is not None and new_value.__class__ != Null:
|
||||
setattr(existing_entity, col, getattr(entity, col))
|
||||
|
||||
new_entities.append(existing_entity)
|
||||
else:
|
||||
new_entities.append(entity)
|
||||
|
||||
return new_entities
|
||||
|
||||
def _process_entities(self, *entities: Entity):
|
||||
from platypush.context import get_plugin
|
||||
|
||||
with get_plugin('db').get_session() as session: # type: ignore
|
||||
existing_entities = self._get_if_exist(session, entities)
|
||||
entities = self._merge_entities(entities, existing_entities) # type: ignore
|
||||
session.add_all(entities)
|
||||
session.commit()
|
||||
|
62
platypush/entities/_registry.py
Normal file
62
platypush/entities/_registry.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, Mapping, Dict, Collection, Type
|
||||
|
||||
from platypush.plugins import Plugin
|
||||
from platypush.utils import get_plugin_name_by_class
|
||||
|
||||
from ._base import Entity
|
||||
|
||||
_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {}
|
||||
|
||||
|
||||
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
|
||||
plugins = _entity_plugin_registry.get(entity_type, {})
|
||||
plugin_name = get_plugin_name_by_class(plugin.__class__)
|
||||
assert plugin_name
|
||||
plugins[plugin_name] = plugin
|
||||
_entity_plugin_registry[entity_type] = plugins
|
||||
|
||||
|
||||
def get_plugin_registry():
|
||||
return _entity_plugin_registry.copy()
|
||||
|
||||
|
||||
class EntityManagerMixin:
|
||||
def transform_entities(self, entities):
|
||||
entities = entities or []
|
||||
for entity in entities:
|
||||
if entity.id:
|
||||
# Entity IDs can only refer to the internal primary key
|
||||
entity.external_id = entity.id
|
||||
entity.id = None # type: ignore
|
||||
|
||||
entity.plugin = get_plugin_name_by_class(self.__class__) # type: ignore
|
||||
entity.updated_at = datetime.utcnow()
|
||||
|
||||
return entities
|
||||
|
||||
def publish_entities(self, entities: Optional[Collection[Entity]]):
|
||||
from . import publish_entities
|
||||
entities = self.transform_entities(entities)
|
||||
publish_entities(entities)
|
||||
|
||||
|
||||
def manages(*entities: Type[Entity]):
|
||||
def wrapper(plugin: Type[Plugin]):
|
||||
init = plugin.__init__
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
for entity_type in entities:
|
||||
register_entity_plugin(entity_type, self)
|
||||
|
||||
init(self, *args, **kwargs)
|
||||
|
||||
plugin.__init__ = __init__
|
||||
# Inject the EntityManagerMixin
|
||||
if EntityManagerMixin not in plugin.__bases__:
|
||||
plugin.__bases__ = (EntityManagerMixin,) + plugin.__bases__
|
||||
|
||||
return plugin
|
||||
|
||||
return wrapper
|
||||
|
14
platypush/entities/devices.py
Normal file
14
platypush/entities/devices.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from sqlalchemy import Column, Integer, ForeignKey
|
||||
|
||||
from ._base import Entity
|
||||
|
||||
|
||||
class Device(Entity):
|
||||
__tablename__ = 'device'
|
||||
|
||||
id = Column(Integer, ForeignKey(Entity.id), primary_key=True)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': __tablename__,
|
||||
}
|
||||
|
14
platypush/entities/lights.py
Normal file
14
platypush/entities/lights.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from sqlalchemy import Column, Integer, ForeignKey
|
||||
|
||||
from .devices import Device
|
||||
|
||||
|
||||
class Light(Device):
|
||||
__tablename__ = 'light'
|
||||
|
||||
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': __tablename__,
|
||||
}
|
||||
|
15
platypush/entities/switches.py
Normal file
15
platypush/entities/switches.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from sqlalchemy import Column, Integer, ForeignKey, Boolean
|
||||
|
||||
from .devices import Device
|
||||
|
||||
|
||||
class Switch(Device):
|
||||
__tablename__ = 'switch'
|
||||
|
||||
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
|
||||
state = Column(Boolean)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': __tablename__,
|
||||
}
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
"""
|
||||
.. moduleauthor:: Fabio Manganiello <blacklight86@gmail.com>
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing import RLock
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine, Table, MetaData
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker, scoped_session
|
||||
|
||||
from platypush.plugins import Plugin, action
|
||||
|
||||
|
@ -30,22 +30,23 @@ class DbPlugin(Plugin):
|
|||
"""
|
||||
|
||||
super().__init__()
|
||||
self.engine = self._get_engine(engine, *args, **kwargs)
|
||||
self.engine = self.get_engine(engine, *args, **kwargs)
|
||||
self._session_locks = {}
|
||||
|
||||
def _get_engine(self, engine=None, *args, **kwargs):
|
||||
def get_engine(self, engine=None, *args, **kwargs) -> Engine:
|
||||
if engine:
|
||||
if isinstance(engine, Engine):
|
||||
return engine
|
||||
if engine.startswith('sqlite://'):
|
||||
kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
return create_engine(engine, *args, **kwargs)
|
||||
return create_engine(engine, *args, **kwargs) # type: ignore
|
||||
|
||||
assert self.engine
|
||||
return self.engine
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
@staticmethod
|
||||
def _build_condition(table, column, value):
|
||||
def _build_condition(_, column, value):
|
||||
if isinstance(value, str):
|
||||
value = "'{}'".format(value)
|
||||
elif not isinstance(value, int) and not isinstance(value, float):
|
||||
|
@ -73,14 +74,14 @@ class DbPlugin(Plugin):
|
|||
:param kwargs: Extra kwargs that will be passed to ``sqlalchemy.create_engine`` (seehttps:///docs.sqlalchemy.org/en/latest/core/engines.html)
|
||||
"""
|
||||
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
with engine.connect() as connection:
|
||||
connection.execute(statement)
|
||||
|
||||
def _get_table(self, table, engine=None, *args, **kwargs):
|
||||
if not engine:
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
db_ok = False
|
||||
n_tries = 0
|
||||
|
@ -98,7 +99,7 @@ class DbPlugin(Plugin):
|
|||
self.logger.exception(e)
|
||||
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
|
||||
time.sleep(wait_time)
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
if not db_ok and last_error:
|
||||
raise last_error
|
||||
|
@ -163,7 +164,7 @@ class DbPlugin(Plugin):
|
|||
]
|
||||
"""
|
||||
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
if table:
|
||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
|
@ -234,7 +235,7 @@ class DbPlugin(Plugin):
|
|||
if key_columns is None:
|
||||
key_columns = []
|
||||
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
for record in records:
|
||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
|
@ -293,7 +294,7 @@ class DbPlugin(Plugin):
|
|||
}
|
||||
"""
|
||||
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
for record in records:
|
||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
|
@ -341,7 +342,7 @@ class DbPlugin(Plugin):
|
|||
}
|
||||
"""
|
||||
|
||||
engine = self._get_engine(engine, *args, **kwargs)
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
|
||||
for record in records:
|
||||
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
|
||||
|
@ -352,5 +353,22 @@ class DbPlugin(Plugin):
|
|||
|
||||
engine.execute(delete)
|
||||
|
||||
def create_all(self, engine, base):
|
||||
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
|
||||
with self._session_locks[engine.url]:
|
||||
base.metadata.create_all(engine)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self, engine=None, *args, **kwargs) -> Generator[Session, None, None]:
|
||||
engine = self.get_engine(engine, *args, **kwargs)
|
||||
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
|
||||
with self._session_locks[engine.url]:
|
||||
session = scoped_session(sessionmaker(expire_on_commit=False))
|
||||
session.configure(bind=engine)
|
||||
s = session()
|
||||
yield s
|
||||
s.commit()
|
||||
s.close()
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
from platypush.entities import manages
|
||||
from platypush.entities.switches import Switch
|
||||
from platypush.plugins import Plugin, action
|
||||
|
||||
|
||||
@manages(Switch)
|
||||
class SwitchPlugin(Plugin, ABC):
|
||||
"""
|
||||
Abstract class for interacting with switch devices
|
||||
|
@ -46,7 +49,7 @@ class SwitchPlugin(Plugin, ABC):
|
|||
return devices
|
||||
|
||||
@action
|
||||
def status(self, device=None, *args, **kwargs) -> Union[dict, List[dict]]:
|
||||
def status(self, device=None, *_, **__) -> Union[dict, List[dict]]:
|
||||
"""
|
||||
Get the status of all the devices, or filter by device name or ID (alias for :meth:`.switch_status`).
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union, Dict, List
|
||||
from typing import Union, Mapping, List, Collection, Optional
|
||||
|
||||
from pyHS100 import SmartDevice, SmartPlug, SmartBulb, SmartStrip, Discover, SmartDeviceException
|
||||
|
||||
|
@ -20,8 +20,12 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
_ip_to_dev = {}
|
||||
_alias_to_dev = {}
|
||||
|
||||
def __init__(self, plugs: Union[Dict[str, str], List[str]] = None, bulbs: Union[Dict[str, str], List[str]] = None,
|
||||
strips: Union[Dict[str, str], List[str]] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
plugs: Optional[Union[Mapping[str, str], List[str]]] = None,
|
||||
bulbs: Optional[Union[Mapping[str, str], List[str]]] = None,
|
||||
strips: Optional[Union[Mapping[str, str], List[str]]] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
:param plugs: Optional list of IP addresses or name->address mapping if you have a static list of
|
||||
TpLink plugs and you want to save on the scan time.
|
||||
|
@ -62,7 +66,7 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
|
||||
self._update_devices()
|
||||
|
||||
def _update_devices(self, devices: Dict[str, SmartDevice] = None):
|
||||
def _update_devices(self, devices: Optional[Mapping[str, SmartDevice]] = None):
|
||||
for (addr, info) in self._static_devices.items():
|
||||
try:
|
||||
dev = info['type'](addr)
|
||||
|
@ -75,6 +79,26 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
self._ip_to_dev[ip] = dev
|
||||
self._alias_to_dev[dev.alias] = dev
|
||||
|
||||
if devices:
|
||||
self.publish_entities(devices.values()) # type: ignore
|
||||
|
||||
def transform_entities(self, devices: Collection[SmartDevice]):
|
||||
from platypush.entities.switches import Switch
|
||||
return super().transform_entities([ # type: ignore
|
||||
Switch(
|
||||
id=dev.host,
|
||||
name=dev.alias,
|
||||
state=dev.is_on,
|
||||
data={
|
||||
'current_consumption': dev.current_consumption(),
|
||||
'ip': dev.host,
|
||||
'host': dev.host,
|
||||
'hw_info': dev.hw_info,
|
||||
}
|
||||
)
|
||||
for dev in (devices or [])
|
||||
])
|
||||
|
||||
def _scan(self):
|
||||
devices = Discover.discover()
|
||||
self._update_devices(devices)
|
||||
|
@ -95,8 +119,15 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
else:
|
||||
raise RuntimeError('Device {} not found'.format(device))
|
||||
|
||||
def _set(self, device: SmartDevice, state: bool):
|
||||
action_name = 'turn_on' if state else 'turn_off'
|
||||
action = getattr(device, action_name)
|
||||
action()
|
||||
self.publish_entities([device]) # type: ignore
|
||||
return self._serialize(device)
|
||||
|
||||
@action
|
||||
def on(self, device, **kwargs):
|
||||
def on(self, device, **_):
|
||||
"""
|
||||
Turn on a device
|
||||
|
||||
|
@ -105,11 +136,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
"""
|
||||
|
||||
device = self._get_device(device)
|
||||
device.turn_on()
|
||||
return self.status(device)
|
||||
return self._set(device, True)
|
||||
|
||||
@action
|
||||
def off(self, device, **kwargs):
|
||||
def off(self, device, **_):
|
||||
"""
|
||||
Turn off a device
|
||||
|
||||
|
@ -118,11 +148,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
"""
|
||||
|
||||
device = self._get_device(device)
|
||||
device.turn_off()
|
||||
return self.status(device)
|
||||
return self._set(device, False)
|
||||
|
||||
@action
|
||||
def toggle(self, device, **kwargs):
|
||||
def toggle(self, device, **_):
|
||||
"""
|
||||
Toggle the state of a device (on/off)
|
||||
|
||||
|
@ -131,12 +160,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
"""
|
||||
|
||||
device = self._get_device(device)
|
||||
return self._set(device, not device.is_on)
|
||||
|
||||
if device.is_on:
|
||||
device.turn_off()
|
||||
else:
|
||||
device.turn_on()
|
||||
|
||||
@staticmethod
|
||||
def _serialize(device: SmartDevice) -> dict:
|
||||
return {
|
||||
'current_consumption': device.current_consumption(),
|
||||
'id': device.host,
|
||||
|
@ -150,15 +177,8 @@ class SwitchTplinkPlugin(SwitchPlugin):
|
|||
@property
|
||||
def switches(self) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
'current_consumption': dev.current_consumption(),
|
||||
'id': ip,
|
||||
'ip': ip,
|
||||
'host': dev.host,
|
||||
'hw_info': dev.hw_info,
|
||||
'name': dev.alias,
|
||||
'on': dev.is_on,
|
||||
} for (ip, dev) in self._scan().items()
|
||||
self._serialize(dev)
|
||||
for dev in self._scan().values()
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.orm import make_transient
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from platypush.context import get_plugin
|
||||
|
@ -28,126 +28,124 @@ class UserManager:
|
|||
Main class for managing platform users
|
||||
"""
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def __init__(self):
|
||||
db_plugin = get_plugin('db')
|
||||
if not db_plugin:
|
||||
raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support')
|
||||
self.db = get_plugin('db')
|
||||
assert self.db
|
||||
self._engine = self.db.get_engine()
|
||||
self.db.create_all(self._engine, Base)
|
||||
|
||||
self._engine = db_plugin._get_engine()
|
||||
|
||||
def get_user(self, username):
|
||||
session = self._get_db_session()
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Hide password
|
||||
@staticmethod
|
||||
def _mask_password(user):
|
||||
make_transient(user)
|
||||
user.password = None
|
||||
return user
|
||||
|
||||
def get_user(self, username):
|
||||
with self.db.get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
session.expunge(user)
|
||||
return self._mask_password(user)
|
||||
|
||||
def get_user_count(self):
|
||||
session = self._get_db_session()
|
||||
return session.query(User).count()
|
||||
with self.db.get_session() as session:
|
||||
return session.query(User).count()
|
||||
|
||||
def get_users(self):
|
||||
session = self._get_db_session()
|
||||
return session.query(User)
|
||||
with self.db.get_session() as session:
|
||||
return session.query(User)
|
||||
|
||||
def create_user(self, username, password, **kwargs):
|
||||
session = self._get_db_session()
|
||||
if not username:
|
||||
raise ValueError('Invalid or empty username')
|
||||
if not password:
|
||||
raise ValueError('Please provide a password for the user')
|
||||
|
||||
user = self._get_user(session, username)
|
||||
if user:
|
||||
raise NameError('The user {} already exists'.format(username))
|
||||
with self.db.get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if user:
|
||||
raise NameError('The user {} already exists'.format(username))
|
||||
|
||||
record = User(username=username, password=self._encrypt_password(password),
|
||||
created_at=datetime.datetime.utcnow(), **kwargs)
|
||||
record = User(username=username, password=self._encrypt_password(password),
|
||||
created_at=datetime.datetime.utcnow(), **kwargs)
|
||||
|
||||
session.add(record)
|
||||
session.commit()
|
||||
user = self._get_user(session, username)
|
||||
session.add(record)
|
||||
session.commit()
|
||||
user = self._get_user(session, username)
|
||||
|
||||
# Hide password
|
||||
user.password = None
|
||||
return user
|
||||
return self._mask_password(user)
|
||||
|
||||
def update_password(self, username, old_password, new_password):
|
||||
session = self._get_db_session()
|
||||
if not self._authenticate_user(session, username, old_password):
|
||||
return False
|
||||
with self.db.get_session() as session:
|
||||
if not self._authenticate_user(session, username, old_password):
|
||||
return False
|
||||
|
||||
user = self._get_user(session, username)
|
||||
user.password = self._encrypt_password(new_password)
|
||||
session.commit()
|
||||
return True
|
||||
user = self._get_user(session, username)
|
||||
user.password = self._encrypt_password(new_password)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def authenticate_user(self, username, password):
|
||||
session = self._get_db_session()
|
||||
return self._authenticate_user(session, username, password)
|
||||
with self.db.get_session() as session:
|
||||
return self._authenticate_user(session, username, password)
|
||||
|
||||
def authenticate_user_session(self, session_token):
|
||||
session = self._get_db_session()
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
with self.db.get_session() as session:
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
|
||||
if not user_session or (
|
||||
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
|
||||
return None, None
|
||||
if not user_session or (
|
||||
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
|
||||
return None, None
|
||||
|
||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||
|
||||
# Hide password
|
||||
user.password = None
|
||||
return user, session
|
||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||
return self._mask_password(user), user_session
|
||||
|
||||
def delete_user(self, username):
|
||||
session = self._get_db_session()
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
raise NameError('No such user: {}'.format(username))
|
||||
with self.db.get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
raise NameError('No such user: {}'.format(username))
|
||||
|
||||
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
|
||||
for user_session in user_sessions:
|
||||
session.delete(user_session)
|
||||
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
|
||||
for user_session in user_sessions:
|
||||
session.delete(user_session)
|
||||
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return True
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def delete_user_session(self, session_token):
|
||||
session = self._get_db_session()
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
with self.db.get_session() as session:
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
|
||||
if not user_session:
|
||||
return False
|
||||
if not user_session:
|
||||
return False
|
||||
|
||||
session.delete(user_session)
|
||||
session.commit()
|
||||
return True
|
||||
session.delete(user_session)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def create_user_session(self, username, password, expires_at=None):
|
||||
session = self._get_db_session()
|
||||
user = self._authenticate_user(session, username, password)
|
||||
if not user:
|
||||
return None
|
||||
with self.db.get_session() as session:
|
||||
user = self._authenticate_user(session, username, password)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if expires_at:
|
||||
if isinstance(expires_at, int) or isinstance(expires_at, float):
|
||||
expires_at = datetime.datetime.fromtimestamp(expires_at)
|
||||
elif isinstance(expires_at, str):
|
||||
expires_at = datetime.datetime.fromisoformat(expires_at)
|
||||
if expires_at:
|
||||
if isinstance(expires_at, int) or isinstance(expires_at, float):
|
||||
expires_at = datetime.datetime.fromtimestamp(expires_at)
|
||||
elif isinstance(expires_at, str):
|
||||
expires_at = datetime.datetime.fromisoformat(expires_at)
|
||||
|
||||
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
|
||||
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at)
|
||||
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
|
||||
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at)
|
||||
|
||||
session.add(user_session)
|
||||
session.commit()
|
||||
return user_session
|
||||
session.add(user_session)
|
||||
session.commit()
|
||||
return user_session
|
||||
|
||||
@staticmethod
|
||||
def _get_user(session, username):
|
||||
|
@ -180,8 +178,8 @@ class UserManager:
|
|||
|
||||
:param session_token: Session token.
|
||||
"""
|
||||
session = self._get_db_session()
|
||||
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
|
||||
with self.db.get_session() as session:
|
||||
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
|
||||
|
||||
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
|
||||
"""
|
||||
|
@ -240,12 +238,6 @@ class UserManager:
|
|||
|
||||
return payload
|
||||
|
||||
def _get_db_session(self):
|
||||
Base.metadata.create_all(self._engine)
|
||||
session = scoped_session(sessionmaker(expire_on_commit=False))
|
||||
session.configure(bind=self._engine)
|
||||
return session()
|
||||
|
||||
def _authenticate_user(self, session, username, password):
|
||||
"""
|
||||
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
|
||||
|
|
Loading…
Reference in a new issue