- Simplified prototype for EntityManager.set

- Added small documentation/annotations notes to the `Plugin` module.

- Small LINT fixes
This commit is contained in:
Fabio Manganiello 2023-02-11 15:05:59 +01:00
parent 575635fd6b
commit 1d0be5c929
Signed by: blacklight
GPG key ID: D90FBA7F76362774
7 changed files with 35 additions and 16 deletions

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
from typing_extensions import override
from . import EntityManager
@ -11,15 +11,13 @@ class WriteableEntityManager(EntityManager, ABC):
"""
@abstractmethod
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
def set(self, entity: str, value: Any, **kwargs):
"""
Set the value of an entity.
:param entity: The entity to set the value for. It's usually the ID of
the entity provided by the plugin.
:param value: The value to set the entity to.
:param attribute: The name of the attribute to set for the entity, if
required by the integration.
"""
raise NotImplementedError()
@ -45,7 +43,7 @@ class SwitchEntityManager(WriteableEntityManager, ABC):
raise NotImplementedError()
@override
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
def set(self, entity: str, value: Any, **kwargs):
method = self.on if value else self.off
return method(entity, **kwargs)

View file

@ -5,6 +5,7 @@ import threading
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Callable, Optional
from typing_extensions import override
from platypush.bus import Bus
from platypush.common import ExtensionWithManifest
@ -97,21 +98,33 @@ class RunnablePlugin(Plugin):
self._thread: Optional[threading.Thread] = None
def main(self):
"""
Implementation of the main loop of the plugin.
"""
raise NotImplementedError()
def should_stop(self):
def should_stop(self) -> bool:
return self._should_stop.is_set()
def wait_stop(self, timeout=None):
"""
Wait until a stop event is received.
"""
return self._should_stop.wait(timeout=timeout)
def start(self):
"""
Start the plugin.
"""
self._thread = threading.Thread(
target=self._runner, name=self.__class__.__name__
)
self._thread.start()
def stop(self):
"""
Stop the plugin.
"""
self._should_stop.set()
if self._thread and self._thread.is_alive():
self.logger.info('Waiting for the plugin to stop')
@ -129,6 +142,9 @@ class RunnablePlugin(Plugin):
self.logger.info('%s stopped', self.__class__.__name__)
def _runner(self):
"""
Implementation of the runner thread.
"""
self.logger.info('Starting %s', self.__class__.__name__)
while not self.should_stop():
@ -185,6 +201,9 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
raise e
def _run_listener(self):
"""
Initialize an event loop and run the listener as a task.
"""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
@ -198,6 +217,7 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
self._task.cancel()
@override
def main(self):
if self.should_stop():
self.logger.info('The plugin is already scheduled to stop')
@ -214,6 +234,7 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
else:
self.wait_stop()
@override
def stop(self):
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)

View file

@ -126,7 +126,7 @@ class MqttPlugin(Plugin):
if version == 'tlsv1.2':
return ssl.PROTOCOL_TLSv1_2
assert 'Unrecognized TLS version: {}'.format(version)
assert f'Unrecognized TLS version: {version}'
def _mqtt_args(self, **kwargs):
return {

View file

@ -356,7 +356,7 @@ class SmartthingsPlugin(
}
missing_devs = {dev for dev in devices if dev not in found_devs}
return list(found_devs.values()), list(missing_devs) # type: ignore
return list(found_devs.values()), list(missing_devs)
def _get_devices(self, *devices: str) -> List[DeviceEntity]:
devs, missing_devs = self._get_existing_and_missing_devices(*devices)
@ -633,7 +633,7 @@ class SmartthingsPlugin(
self._entities_by_id.update({e.id: e for e in compatible_entities})
return super().transform_entities(compatible_entities) # type: ignore
return super().transform_entities(compatible_entities)
async def _get_device_status(
self, api, device_id: str, publish_entities: bool
@ -642,7 +642,7 @@ class SmartthingsPlugin(
assert device, f'No such device: {device_id}'
await device.status.refresh()
if publish_entities:
self.publish_entities([device]) # type: ignore
self.publish_entities([device])
self._devices_by_id[device_id] = device
self._devices_by_name[device.label] = device
@ -863,7 +863,6 @@ class SmartthingsPlugin(
@action
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
super().set(entity, value, attribute, **kwargs)
return self.set_value(entity, property=attribute, value=value, **kwargs)
@action
@ -994,6 +993,7 @@ class SmartthingsPlugin(
self.logger.exception(e)
self.logger.error('Could not refresh the status: %s', e)
self.wait_stop(3 * (self.poll_interval or 5))
return None
while not self.should_stop():
updated_devices = {}
@ -1010,7 +1010,7 @@ class SmartthingsPlugin(
if self._has_status_changed(devices.get(device_id, {}), new_status)
}
self.publish_entities(updated_devices.values()) # type: ignore
self.publish_entities(updated_devices.values())
devices.update(new_devices)
self.wait_stop(self.poll_interval)
refresh_status_safe()

View file

@ -34,7 +34,7 @@ class DeviceMapper:
entity_type: Type[Entity],
capability: str,
attribute: str,
value_type: Union[Type, str],
value_type: Union[Type, Enum, str],
set_command: Optional[Union[str, Callable[[Any], str]]] = None,
get_value: Optional[Callable[[DeviceEntity], Any]] = None,
set_value_args: Optional[Callable[..., Any]] = None,
@ -46,7 +46,7 @@ class DeviceMapper:
self.attribute = attribute
self.value_type = value_type
self.get_value = get_value if get_value else self._default_get_value
self.values = []
self.values: List[str] = []
self.entity_args = kwargs
if isinstance(value_type, Enum):

View file

@ -122,7 +122,7 @@ class SwitchbotBluetoothPlugin(BluetoothBlePlugin, EnumSwitchEntityManager):
self.logger.warning('Unknown command for SwitchBot "%s": "%s"', device, value)
@override
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
def set(self, entity: str, value: Any, **kwargs):
return self.set_value(entity, value, **kwargs)
@override

View file

@ -348,7 +348,7 @@ class ZwaveBasePlugin(
raise NotImplementedError
@action
def set(self, entity: str, value: Any, attribute: Optional[str] = None, **kwargs):
def set(self, entity: str, value: Any, **kwargs):
return self.set_value(
value_id=entity, id_on_network=entity, data=value, **kwargs
)