Basic support for entities on the local db and implemented support for switch entities on the tplink plugin

This commit is contained in:
Fabio Manganiello 2022-04-04 16:50:17 +02:00
parent b1491b8048
commit 4ee7e4db29
Signed by: blacklight
GPG key ID: D90FBA7F76362774
12 changed files with 506 additions and 134 deletions

View file

@ -9,11 +9,13 @@ import argparse
import logging
import os
import sys
from typing import Optional
from .bus.redis import RedisBus
from .config import Config
from .context import register_backends, register_plugins
from .cron.scheduler import CronScheduler
from .entities import init_entities_engine, EntitiesEngine
from .event.processor import EventProcessor
from .logger import Logger
from .message.event import Event
@ -86,6 +88,7 @@ class Daemon:
self.no_capture_stdout = no_capture_stdout
self.no_capture_stderr = no_capture_stderr
self.event_processor = EventProcessor()
self.entities_engine: Optional[EntitiesEngine] = None
self.requests_to_process = requests_to_process
self.processed_requests = 0
self.cron_scheduler = None
@ -161,6 +164,7 @@ class Daemon:
""" Stops the backends and the bus """
from .plugins import RunnablePlugin
if self.backends:
for backend in self.backends.values():
backend.stop()
@ -168,9 +172,17 @@ class Daemon:
if isinstance(plugin, RunnablePlugin):
plugin.stop()
if self.bus:
self.bus.stop()
self.bus = None
if self.cron_scheduler:
self.cron_scheduler.stop()
self.cron_scheduler = None
if self.entities_engine:
self.entities_engine.stop()
self.entities_engine = None
def run(self):
""" Start the daemon """
@ -192,6 +204,9 @@ class Daemon:
# Initialize the plugins
register_plugins(bus=self.bus)
# Initialize the entities engine
self.entities_engine = init_entities_engine()
# Start the cron scheduler
if Config.get_cronjobs():
self.cron_scheduler = CronScheduler(jobs=Config.get_cronjobs())

View file

@ -0,0 +1,36 @@
import warnings
from typing import Collection, Optional
from ._base import Entity
from ._engine import EntitiesEngine
from ._registry import manages, register_entity_plugin, get_plugin_registry
_engine: Optional[EntitiesEngine] = None
def init_entities_engine() -> EntitiesEngine:
from ._base import init_entities_db
global _engine
init_entities_db()
_engine = EntitiesEngine()
_engine.start()
return _engine
def publish_entities(entities: Collection[Entity]):
if not _engine:
warnings.warn('No entities engine registered')
return
_engine.post(*entities)
__all__ = (
'Entity',
'EntitiesEngine',
'init_entities_engine',
'publish_entities',
'register_entity_plugin',
'get_plugin_registry',
'manages',
)

View file

@ -0,0 +1,73 @@
import inspect
import pathlib
from datetime import datetime
from typing import Mapping, Type
import pkgutil
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {}
class Entity(Base):
"""
Model for a general-purpose platform entity
"""
__tablename__ = 'entity'
id = Column(Integer, autoincrement=True, primary_key=True)
external_id = Column(String, nullable=True)
name = Column(String, nullable=False, index=True)
type = Column(String, nullable=False, index=True)
plugin = Column(String, nullable=False)
data = Column(JSON, default=dict)
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
UniqueConstraint(external_id, plugin)
__table_args__ = (
Index(name, plugin),
)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
'polymorphic_on': type,
}
def _discover_entity_types():
from platypush.context import get_plugin
logger = get_plugin('logger')
assert logger
for loader, modname, _ in pkgutil.walk_packages(
path=[str(pathlib.Path(__file__).parent.absolute())],
prefix=__package__ + '.',
onerror=lambda _: None
):
try:
mod_loader = loader.find_module(modname) # type: ignore
assert mod_loader
module = mod_loader.load_module() # type: ignore
except Exception as e:
logger.warning(f'Could not import module {modname}')
logger.exception(e)
continue
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Entity):
entities_registry[obj] = {}
def init_entities_db():
from platypush.context import get_plugin
_discover_entity_types()
db = get_plugin('db')
assert db
engine = db.get_engine()
db.create_all(engine, Base)

View file

@ -0,0 +1,110 @@
from logging import getLogger
from queue import Queue, Empty
from threading import Thread, Event
from time import time
from typing import Iterable, List
from sqlalchemy import and_, or_, inspect as schema_inspect
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import Null
from ._base import Entity
class EntitiesEngine(Thread):
# Processing queue timeout in seconds
_queue_timeout = 5.
def __init__(self):
obj_name = self.__class__.__name__
super().__init__(name=obj_name)
self.logger = getLogger(name=obj_name)
self._queue = Queue()
self._should_stop = Event()
def post(self, *entities: Entity):
for entity in entities:
self._queue.put(entity)
@property
def should_stop(self) -> bool:
return self._should_stop.is_set()
def stop(self):
self._should_stop.set()
def run(self):
super().run()
self.logger.info('Started entities engine')
while not self.should_stop:
msgs = []
last_poll_time = time()
while not self.should_stop and (
time() - last_poll_time < self._queue_timeout):
try:
msg = self._queue.get(block=True, timeout=0.5)
except Empty:
continue
if msg:
msgs.append(msg)
if not msgs or self.should_stop:
continue
self._process_entities(*msgs)
self.logger.info('Stopped entities engine')
def _get_if_exist(self, session: Session, entities: Iterable[Entity]) -> Iterable[Entity]:
existing_entities = {
(entity.external_id or entity.name, entity.plugin): entity
for entity in session.query(Entity).filter(
or_(*[
and_(Entity.external_id == entity.external_id, Entity.plugin == entity.plugin)
if entity.external_id is not None else
and_(Entity.name == entity.name, Entity.plugin == entity.plugin)
for entity in entities
])
).all()
}
return [
existing_entities.get(
(entity.external_id or entity.name, entity.plugin), None
) for entity in entities
]
def _merge_entities(
self, entities: List[Entity],
existing_entities: List[Entity]
) -> List[Entity]:
new_entities = []
for i, entity in enumerate(entities):
existing_entity = existing_entities[i]
if existing_entity:
inspector = schema_inspect(entity.__class__)
columns = [col.key for col in inspector.mapper.column_attrs]
for col in columns:
new_value = getattr(entity, col)
if new_value is not None and new_value.__class__ != Null:
setattr(existing_entity, col, getattr(entity, col))
new_entities.append(existing_entity)
else:
new_entities.append(entity)
return new_entities
def _process_entities(self, *entities: Entity):
from platypush.context import get_plugin
with get_plugin('db').get_session() as session: # type: ignore
existing_entities = self._get_if_exist(session, entities)
entities = self._merge_entities(entities, existing_entities) # type: ignore
session.add_all(entities)
session.commit()

View file

@ -0,0 +1,62 @@
from datetime import datetime
from typing import Optional, Mapping, Dict, Collection, Type
from platypush.plugins import Plugin
from platypush.utils import get_plugin_name_by_class
from ._base import Entity
_entity_plugin_registry: Mapping[Type[Entity], Dict[str, Plugin]] = {}
def register_entity_plugin(entity_type: Type[Entity], plugin: Plugin):
plugins = _entity_plugin_registry.get(entity_type, {})
plugin_name = get_plugin_name_by_class(plugin.__class__)
assert plugin_name
plugins[plugin_name] = plugin
_entity_plugin_registry[entity_type] = plugins
def get_plugin_registry():
return _entity_plugin_registry.copy()
class EntityManagerMixin:
def transform_entities(self, entities):
entities = entities or []
for entity in entities:
if entity.id:
# Entity IDs can only refer to the internal primary key
entity.external_id = entity.id
entity.id = None # type: ignore
entity.plugin = get_plugin_name_by_class(self.__class__) # type: ignore
entity.updated_at = datetime.utcnow()
return entities
def publish_entities(self, entities: Optional[Collection[Entity]]):
from . import publish_entities
entities = self.transform_entities(entities)
publish_entities(entities)
def manages(*entities: Type[Entity]):
def wrapper(plugin: Type[Plugin]):
init = plugin.__init__
def __init__(self, *args, **kwargs):
for entity_type in entities:
register_entity_plugin(entity_type, self)
init(self, *args, **kwargs)
plugin.__init__ = __init__
# Inject the EntityManagerMixin
if EntityManagerMixin not in plugin.__bases__:
plugin.__bases__ = (EntityManagerMixin,) + plugin.__bases__
return plugin
return wrapper

View file

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, ForeignKey
from ._base import Entity
class Device(Entity):
__tablename__ = 'device'
id = Column(Integer, ForeignKey(Entity.id), primary_key=True)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View file

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, ForeignKey
from .devices import Device
class Light(Device):
__tablename__ = 'light'
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View file

@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, ForeignKey, Boolean
from .devices import Device
class Switch(Device):
__tablename__ = 'switch'
id = Column(Integer, ForeignKey(Device.id), primary_key=True)
state = Column(Boolean)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
}

View file

@ -1,11 +1,11 @@
"""
.. moduleauthor:: Fabio Manganiello <blacklight86@gmail.com>
"""
import time
from contextlib import contextmanager
from multiprocessing import RLock
from typing import Generator
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker, scoped_session
from platypush.plugins import Plugin, action
@ -30,22 +30,23 @@ class DbPlugin(Plugin):
"""
super().__init__()
self.engine = self._get_engine(engine, *args, **kwargs)
self.engine = self.get_engine(engine, *args, **kwargs)
self._session_locks = {}
def _get_engine(self, engine=None, *args, **kwargs):
def get_engine(self, engine=None, *args, **kwargs) -> Engine:
if engine:
if isinstance(engine, Engine):
return engine
if engine.startswith('sqlite://'):
kwargs['connect_args'] = {'check_same_thread': False}
return create_engine(engine, *args, **kwargs)
return create_engine(engine, *args, **kwargs) # type: ignore
assert self.engine
return self.engine
# noinspection PyUnusedLocal
@staticmethod
def _build_condition(table, column, value):
def _build_condition(_, column, value):
if isinstance(value, str):
value = "'{}'".format(value)
elif not isinstance(value, int) and not isinstance(value, float):
@ -73,14 +74,14 @@ class DbPlugin(Plugin):
:param kwargs: Extra kwargs that will be passed to ``sqlalchemy.create_engine`` (seehttps:///docs.sqlalchemy.org/en/latest/core/engines.html)
"""
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
with engine.connect() as connection:
connection.execute(statement)
def _get_table(self, table, engine=None, *args, **kwargs):
if not engine:
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
db_ok = False
n_tries = 0
@ -98,7 +99,7 @@ class DbPlugin(Plugin):
self.logger.exception(e)
self.logger.info('Waiting {} seconds before retrying'.format(wait_time))
time.sleep(wait_time)
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
if not db_ok and last_error:
raise last_error
@ -163,7 +164,7 @@ class DbPlugin(Plugin):
]
"""
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
if table:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -234,7 +235,7 @@ class DbPlugin(Plugin):
if key_columns is None:
key_columns = []
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -293,7 +294,7 @@ class DbPlugin(Plugin):
}
"""
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -341,7 +342,7 @@ class DbPlugin(Plugin):
}
"""
engine = self._get_engine(engine, *args, **kwargs)
engine = self.get_engine(engine, *args, **kwargs)
for record in records:
table, engine = self._get_table(table, engine=engine, *args, **kwargs)
@ -352,5 +353,22 @@ class DbPlugin(Plugin):
engine.execute(delete)
def create_all(self, engine, base):
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
with self._session_locks[engine.url]:
base.metadata.create_all(engine)
@contextmanager
def get_session(self, engine=None, *args, **kwargs) -> Generator[Session, None, None]:
engine = self.get_engine(engine, *args, **kwargs)
self._session_locks[engine.url] = self._session_locks.get(engine.url, RLock())
with self._session_locks[engine.url]:
session = scoped_session(sessionmaker(expire_on_commit=False))
session.configure(bind=engine)
s = session()
yield s
s.commit()
s.close()
# vim:sw=4:ts=4:et:

View file

@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import List, Union
from platypush.entities import manages
from platypush.entities.switches import Switch
from platypush.plugins import Plugin, action
@manages(Switch)
class SwitchPlugin(Plugin, ABC):
"""
Abstract class for interacting with switch devices
@ -46,7 +49,7 @@ class SwitchPlugin(Plugin, ABC):
return devices
@action
def status(self, device=None, *args, **kwargs) -> Union[dict, List[dict]]:
def status(self, device=None, *_, **__) -> Union[dict, List[dict]]:
"""
Get the status of all the devices, or filter by device name or ID (alias for :meth:`.switch_status`).

View file

@ -1,4 +1,4 @@
from typing import Union, Dict, List
from typing import Union, Mapping, List, Collection, Optional
from pyHS100 import SmartDevice, SmartPlug, SmartBulb, SmartStrip, Discover, SmartDeviceException
@ -20,8 +20,12 @@ class SwitchTplinkPlugin(SwitchPlugin):
_ip_to_dev = {}
_alias_to_dev = {}
def __init__(self, plugs: Union[Dict[str, str], List[str]] = None, bulbs: Union[Dict[str, str], List[str]] = None,
strips: Union[Dict[str, str], List[str]] = None, **kwargs):
def __init__(
self,
plugs: Optional[Union[Mapping[str, str], List[str]]] = None,
bulbs: Optional[Union[Mapping[str, str], List[str]]] = None,
strips: Optional[Union[Mapping[str, str], List[str]]] = None, **kwargs
):
"""
:param plugs: Optional list of IP addresses or name->address mapping if you have a static list of
TpLink plugs and you want to save on the scan time.
@ -62,7 +66,7 @@ class SwitchTplinkPlugin(SwitchPlugin):
self._update_devices()
def _update_devices(self, devices: Dict[str, SmartDevice] = None):
def _update_devices(self, devices: Optional[Mapping[str, SmartDevice]] = None):
for (addr, info) in self._static_devices.items():
try:
dev = info['type'](addr)
@ -75,6 +79,26 @@ class SwitchTplinkPlugin(SwitchPlugin):
self._ip_to_dev[ip] = dev
self._alias_to_dev[dev.alias] = dev
if devices:
self.publish_entities(devices.values()) # type: ignore
def transform_entities(self, devices: Collection[SmartDevice]):
from platypush.entities.switches import Switch
return super().transform_entities([ # type: ignore
Switch(
id=dev.host,
name=dev.alias,
state=dev.is_on,
data={
'current_consumption': dev.current_consumption(),
'ip': dev.host,
'host': dev.host,
'hw_info': dev.hw_info,
}
)
for dev in (devices or [])
])
def _scan(self):
devices = Discover.discover()
self._update_devices(devices)
@ -95,8 +119,15 @@ class SwitchTplinkPlugin(SwitchPlugin):
else:
raise RuntimeError('Device {} not found'.format(device))
def _set(self, device: SmartDevice, state: bool):
action_name = 'turn_on' if state else 'turn_off'
action = getattr(device, action_name)
action()
self.publish_entities([device]) # type: ignore
return self._serialize(device)
@action
def on(self, device, **kwargs):
def on(self, device, **_):
"""
Turn on a device
@ -105,11 +136,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
"""
device = self._get_device(device)
device.turn_on()
return self.status(device)
return self._set(device, True)
@action
def off(self, device, **kwargs):
def off(self, device, **_):
"""
Turn off a device
@ -118,11 +148,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
"""
device = self._get_device(device)
device.turn_off()
return self.status(device)
return self._set(device, False)
@action
def toggle(self, device, **kwargs):
def toggle(self, device, **_):
"""
Toggle the state of a device (on/off)
@ -131,12 +160,10 @@ class SwitchTplinkPlugin(SwitchPlugin):
"""
device = self._get_device(device)
return self._set(device, not device.is_on)
if device.is_on:
device.turn_off()
else:
device.turn_on()
@staticmethod
def _serialize(device: SmartDevice) -> dict:
return {
'current_consumption': device.current_consumption(),
'id': device.host,
@ -150,15 +177,8 @@ class SwitchTplinkPlugin(SwitchPlugin):
@property
def switches(self) -> List[dict]:
return [
{
'current_consumption': dev.current_consumption(),
'id': ip,
'ip': ip,
'host': dev.host,
'hw_info': dev.hw_info,
'name': dev.alias,
'on': dev.is_on,
} for (ip, dev) in self._scan().items()
self._serialize(dev)
for dev in self._scan().values()
]

View file

@ -13,7 +13,7 @@ except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import make_transient
from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin
@ -28,39 +28,42 @@ class UserManager:
Main class for managing platform users
"""
# noinspection PyProtectedMember
def __init__(self):
db_plugin = get_plugin('db')
if not db_plugin:
raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support')
self.db = get_plugin('db')
assert self.db
self._engine = self.db.get_engine()
self.db.create_all(self._engine, Base)
self._engine = db_plugin._get_engine()
@staticmethod
def _mask_password(user):
make_transient(user)
user.password = None
return user
def get_user(self, username):
session = self._get_db_session()
with self.db.get_session() as session:
user = self._get_user(session, username)
if not user:
return None
# Hide password
user.password = None
return user
session.expunge(user)
return self._mask_password(user)
def get_user_count(self):
session = self._get_db_session()
with self.db.get_session() as session:
return session.query(User).count()
def get_users(self):
session = self._get_db_session()
with self.db.get_session() as session:
return session.query(User)
def create_user(self, username, password, **kwargs):
session = self._get_db_session()
if not username:
raise ValueError('Invalid or empty username')
if not password:
raise ValueError('Please provide a password for the user')
with self.db.get_session() as session:
user = self._get_user(session, username)
if user:
raise NameError('The user {} already exists'.format(username))
@ -72,12 +75,10 @@ class UserManager:
session.commit()
user = self._get_user(session, username)
# Hide password
user.password = None
return user
return self._mask_password(user)
def update_password(self, username, old_password, new_password):
session = self._get_db_session()
with self.db.get_session() as session:
if not self._authenticate_user(session, username, old_password):
return False
@ -87,11 +88,11 @@ class UserManager:
return True
def authenticate_user(self, username, password):
session = self._get_db_session()
with self.db.get_session() as session:
return self._authenticate_user(session, username, password)
def authenticate_user_session(self, session_token):
session = self._get_db_session()
with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session or (
@ -99,13 +100,10 @@ class UserManager:
return None, None
user = session.query(User).filter_by(user_id=user_session.user_id).first()
# Hide password
user.password = None
return user, session
return self._mask_password(user), user_session
def delete_user(self, username):
session = self._get_db_session()
with self.db.get_session() as session:
user = self._get_user(session, username)
if not user:
raise NameError('No such user: {}'.format(username))
@ -119,7 +117,7 @@ class UserManager:
return True
def delete_user_session(self, session_token):
session = self._get_db_session()
with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session:
@ -130,7 +128,7 @@ class UserManager:
return True
def create_user_session(self, username, password, expires_at=None):
session = self._get_db_session()
with self.db.get_session() as session:
user = self._authenticate_user(session, username, password)
if not user:
return None
@ -180,7 +178,7 @@ class UserManager:
:param session_token: Session token.
"""
session = self._get_db_session()
with self.db.get_session() as session:
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
@ -240,12 +238,6 @@ class UserManager:
return payload
def _get_db_session(self):
Base.metadata.create_all(self._engine)
session = scoped_session(sessionmaker(expire_on_commit=False))
session.configure(bind=self._engine)
return session()
def _authenticate_user(self, session, username, password):
"""
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.