- Simplified prototype for EntityManager.set

- Added small documentation/annotations notes to the `Plugin` module.

- Small LINT fixes
This commit is contained in:
Fabio Manganiello 2023-02-11 15:05:59 +01:00
parent 575635fd6b
commit 1d0be5c929
Signed by: blacklight
GPG key ID: D90FBA7F76362774
7 changed files with 35 additions and 16 deletions

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any
from typing_extensions import override from typing_extensions import override
from . import EntityManager from . import EntityManager
@ -11,15 +11,13 @@ class WriteableEntityManager(EntityManager, ABC):
""" """
@abstractmethod @abstractmethod
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs): def set(self, entity: str, value: Any, **kwargs):
""" """
Set the value of an entity. Set the value of an entity.
:param entity: The entity to set the value for. It's usually the ID of :param entity: The entity to set the value for. It's usually the ID of
the entity provided by the plugin. the entity provided by the plugin.
:param value: The value to set the entity to. :param value: The value to set the entity to.
:param attribute: The name of the attribute to set for the entity, if
required by the integration.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -45,7 +43,7 @@ class SwitchEntityManager(WriteableEntityManager, ABC):
raise NotImplementedError() raise NotImplementedError()
@override @override
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs): def set(self, entity: str, value: Any, **kwargs):
method = self.on if value else self.off method = self.on if value else self.off
return method(entity, **kwargs) return method(entity, **kwargs)

View file

@ -5,6 +5,7 @@ import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import wraps from functools import wraps
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from typing_extensions import override
from platypush.bus import Bus from platypush.bus import Bus
from platypush.common import ExtensionWithManifest from platypush.common import ExtensionWithManifest
@ -97,21 +98,33 @@ class RunnablePlugin(Plugin):
self._thread: Optional[threading.Thread] = None self._thread: Optional[threading.Thread] = None
def main(self): def main(self):
"""
Implementation of the main loop of the plugin.
"""
raise NotImplementedError() raise NotImplementedError()
def should_stop(self): def should_stop(self) -> bool:
return self._should_stop.is_set() return self._should_stop.is_set()
def wait_stop(self, timeout=None): def wait_stop(self, timeout=None):
"""
Wait until a stop event is received.
"""
return self._should_stop.wait(timeout=timeout) return self._should_stop.wait(timeout=timeout)
def start(self): def start(self):
"""
Start the plugin.
"""
self._thread = threading.Thread( self._thread = threading.Thread(
target=self._runner, name=self.__class__.__name__ target=self._runner, name=self.__class__.__name__
) )
self._thread.start() self._thread.start()
def stop(self): def stop(self):
"""
Stop the plugin.
"""
self._should_stop.set() self._should_stop.set()
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
self.logger.info('Waiting for the plugin to stop') self.logger.info('Waiting for the plugin to stop')
@ -129,6 +142,9 @@ class RunnablePlugin(Plugin):
self.logger.info('%s stopped', self.__class__.__name__) self.logger.info('%s stopped', self.__class__.__name__)
def _runner(self): def _runner(self):
"""
Implementation of the runner thread.
"""
self.logger.info('Starting %s', self.__class__.__name__) self.logger.info('Starting %s', self.__class__.__name__)
while not self.should_stop(): while not self.should_stop():
@ -185,6 +201,9 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
raise e raise e
def _run_listener(self): def _run_listener(self):
"""
Initialize an event loop and run the listener as a task.
"""
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop) asyncio.set_event_loop(self._loop)
@ -198,6 +217,7 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
self._task.cancel() self._task.cancel()
@override
def main(self): def main(self):
if self.should_stop(): if self.should_stop():
self.logger.info('The plugin is already scheduled to stop') self.logger.info('The plugin is already scheduled to stop')
@ -214,6 +234,7 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
else: else:
self.wait_stop() self.wait_stop()
@override
def stop(self): def stop(self):
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop) self._loop.call_soon_threadsafe(self._loop.stop)

View file

@ -126,7 +126,7 @@ class MqttPlugin(Plugin):
if version == 'tlsv1.2': if version == 'tlsv1.2':
return ssl.PROTOCOL_TLSv1_2 return ssl.PROTOCOL_TLSv1_2
assert 'Unrecognized TLS version: {}'.format(version) assert f'Unrecognized TLS version: {version}'
def _mqtt_args(self, **kwargs): def _mqtt_args(self, **kwargs):
return { return {

View file

@ -356,7 +356,7 @@ class SmartthingsPlugin(
} }
missing_devs = {dev for dev in devices if dev not in found_devs} missing_devs = {dev for dev in devices if dev not in found_devs}
return list(found_devs.values()), list(missing_devs) # type: ignore return list(found_devs.values()), list(missing_devs)
def _get_devices(self, *devices: str) -> List[DeviceEntity]: def _get_devices(self, *devices: str) -> List[DeviceEntity]:
devs, missing_devs = self._get_existing_and_missing_devices(*devices) devs, missing_devs = self._get_existing_and_missing_devices(*devices)
@ -633,7 +633,7 @@ class SmartthingsPlugin(
self._entities_by_id.update({e.id: e for e in compatible_entities}) self._entities_by_id.update({e.id: e for e in compatible_entities})
return super().transform_entities(compatible_entities) # type: ignore return super().transform_entities(compatible_entities)
async def _get_device_status( async def _get_device_status(
self, api, device_id: str, publish_entities: bool self, api, device_id: str, publish_entities: bool
@ -642,7 +642,7 @@ class SmartthingsPlugin(
assert device, f'No such device: {device_id}' assert device, f'No such device: {device_id}'
await device.status.refresh() await device.status.refresh()
if publish_entities: if publish_entities:
self.publish_entities([device]) # type: ignore self.publish_entities([device])
self._devices_by_id[device_id] = device self._devices_by_id[device_id] = device
self._devices_by_name[device.label] = device self._devices_by_name[device.label] = device
@ -863,7 +863,6 @@ class SmartthingsPlugin(
@action @action
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs): def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
super().set(entity, value, attribute, **kwargs)
return self.set_value(entity, property=attribute, value=value, **kwargs) return self.set_value(entity, property=attribute, value=value, **kwargs)
@action @action
@ -994,6 +993,7 @@ class SmartthingsPlugin(
self.logger.exception(e) self.logger.exception(e)
self.logger.error('Could not refresh the status: %s', e) self.logger.error('Could not refresh the status: %s', e)
self.wait_stop(3 * (self.poll_interval or 5)) self.wait_stop(3 * (self.poll_interval or 5))
return None
while not self.should_stop(): while not self.should_stop():
updated_devices = {} updated_devices = {}
@ -1010,7 +1010,7 @@ class SmartthingsPlugin(
if self._has_status_changed(devices.get(device_id, {}), new_status) if self._has_status_changed(devices.get(device_id, {}), new_status)
} }
self.publish_entities(updated_devices.values()) # type: ignore self.publish_entities(updated_devices.values())
devices.update(new_devices) devices.update(new_devices)
self.wait_stop(self.poll_interval) self.wait_stop(self.poll_interval)
refresh_status_safe() refresh_status_safe()

View file

@ -34,7 +34,7 @@ class DeviceMapper:
entity_type: Type[Entity], entity_type: Type[Entity],
capability: str, capability: str,
attribute: str, attribute: str,
value_type: Union[Type, str], value_type: Union[Type, Enum, str],
set_command: Optional[Union[str, Callable[[Any], str]]] = None, set_command: Optional[Union[str, Callable[[Any], str]]] = None,
get_value: Optional[Callable[[DeviceEntity], Any]] = None, get_value: Optional[Callable[[DeviceEntity], Any]] = None,
set_value_args: Optional[Callable[..., Any]] = None, set_value_args: Optional[Callable[..., Any]] = None,
@ -46,7 +46,7 @@ class DeviceMapper:
self.attribute = attribute self.attribute = attribute
self.value_type = value_type self.value_type = value_type
self.get_value = get_value if get_value else self._default_get_value self.get_value = get_value if get_value else self._default_get_value
self.values = [] self.values: List[str] = []
self.entity_args = kwargs self.entity_args = kwargs
if isinstance(value_type, Enum): if isinstance(value_type, Enum):

View file

@ -122,7 +122,7 @@ class SwitchbotBluetoothPlugin(BluetoothBlePlugin, EnumSwitchEntityManager):
self.logger.warning('Unknown command for SwitchBot "%s": "%s"', device, value) self.logger.warning('Unknown command for SwitchBot "%s": "%s"', device, value)
@override @override
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs): def set(self, entity: str, value: Any, **kwargs):
return self.set_value(entity, value, **kwargs) return self.set_value(entity, value, **kwargs)
@override @override

View file

@ -348,7 +348,7 @@ class ZwaveBasePlugin(
raise NotImplementedError raise NotImplementedError
@action @action
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs): def set(self, entity: str, value: Any, **kwargs):
return self.set_value( return self.set_value(
value_id=entity, id_on_network=entity, data=value, **kwargs value_id=entity, id_on_network=entity, data=value, **kwargs
) )