Fixed management of state on zigbee.mqtt.

Before the merge of the plugin and the listener those components
used to have their own separate state, which led to inconsistencies.
This commit is contained in:
Fabio Manganiello 2023-09-14 23:05:27 +02:00
parent 5a514fdcce
commit ac72b2f7a8
Signed by: blacklight
GPG key ID: D90FBA7F76362774
3 changed files with 139 additions and 126 deletions

View file

@ -25,8 +25,8 @@ class EntitiesEngine(Thread):
together (preventing excessive writes and throttling events), and together (preventing excessive writes and throttling events), and
prevents race conditions when SQLite is used. prevents race conditions when SQLite is used.
2. Merge any existing entities with their newer representations. 2. Merge any existing entities with their newer representations.
3. Update the entities taxonomy. 3. Update the entities' taxonomy.
4. Persist the new state to the entities database. 4. Persist the new state to the entities' database.
5. Trigger events for the updated entities. 5. Trigger events for the updated entities.
""" """

View file

@ -1,11 +1,9 @@
import contextlib import contextlib
from dataclasses import dataclass, field
from enum import Enum
import json import json
import re import re
import threading import threading
from queue import Queue from queue import Empty, Queue
from typing import ( from typing import (
Any, Any,
Collection, Collection,
@ -75,82 +73,7 @@ from platypush.message.event.zigbee.mqtt import (
) )
from platypush.message.response import Response from platypush.message.response import Response
from platypush.plugins.mqtt import DEFAULT_TIMEOUT, MqttClient, MqttPlugin, action from platypush.plugins.mqtt import DEFAULT_TIMEOUT, MqttClient, MqttPlugin, action
from ._state import BridgeState, ZigbeeState
class BridgeState(Enum):
"""
Known bridge states.
"""
ONLINE = 'online'
OFFLINE = 'offline'
@dataclass
class ZigbeeDevicesInfo:
"""
Cached information about the devices on the network.
"""
by_address: Dict[str, dict] = field(default_factory=dict)
by_name: Dict[str, dict] = field(default_factory=dict)
def __contains__(self, name: str) -> bool:
"""
:return: True if the device with the given name exists in the cache.
"""
return name in self.by_name or name in self.by_address
def get(self, name: str) -> Optional[dict]:
"""
Retrieves a cached device record either by name or by address.
"""
return self.by_address.get(name, self.by_name.get(name))
def add(self, device: dict):
"""
Adds a device record to the cache.
"""
if device.get('ieee_address'):
self.by_address[device['ieee_address']] = device
if device.get('friendly_name'):
self.by_name[device['friendly_name']] = device
def remove(self, device: Union[str, dict]):
"""
Removes a device record from the cache.
"""
if isinstance(device, str):
dev = self.get(device)
if not dev:
return # No such device
else:
dev = device
if dev.get('ieee_address'):
self.by_address.pop(dev['ieee_address'], None)
if dev.get('friendly_name'):
self.by_name.pop(dev['friendly_name'], None)
def reset(self, *keys: str):
"""
Reset the state for the devices with the given keys.
"""
for k in keys:
self.by_address[k] = {}
self.by_name[k] = {}
@dataclass
class ZigbeeInfo:
"""
Cached information about the devices and groups on the network.
"""
devices: ZigbeeDevicesInfo = field(default_factory=ZigbeeDevicesInfo)
groups: Dict[str, dict] = field(default_factory=dict)
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
@ -331,7 +254,7 @@ class ZigbeeMqttPlugin(
""" """
if base_topic: if base_topic:
self.logger.warning( self.logger.warning(
'base_topic is deprprecated, please use topic_prefix instead' 'base_topic is deprecated, please use topic_prefix instead'
) )
topic_prefix = base_topic topic_prefix = base_topic
@ -354,17 +277,12 @@ class ZigbeeMqttPlugin(
tls_ciphers=tls_ciphers, tls_ciphers=tls_ciphers,
username=username, username=username,
password=password, password=password,
timeout=timeout,
**kwargs, **kwargs,
) )
# Append a unique suffix to the client ID to avoid client name clashes
# with other MQTT plugins.
self.client_id += '-zigbee-mqtt'
self.topic_prefix = topic_prefix self.topic_prefix = topic_prefix
self.timeout = timeout self._info = ZigbeeState()
self._info = ZigbeeInfo()
self._devices_meta: Dict[str, dict] = {}
self._bridge_state = BridgeState.OFFLINE
@staticmethod @staticmethod
def _get_properties(device: dict) -> dict: def _get_properties(device: dict) -> dict:
@ -626,7 +544,7 @@ class ZigbeeMqttPlugin(
**kwargs, **kwargs,
) -> dict: ) -> dict:
""" """
Sends a request/message to the Zigbeebee2MQTT bridge and waits for a Sends a request/message to the Zigbee2MQTT bridge and waits for a
response. response.
""" """
return self._parse_response( return self._parse_response(
@ -677,7 +595,7 @@ class ZigbeeMqttPlugin(
{ {
"date_code": "20180906", "date_code": "20180906",
"friendly_name": "My Lightbulb", "friendly_name": "My Light Bulb",
"ieee_address": "0x00123456789abcdf", "ieee_address": "0x00123456789abcdf",
"network_address": 52715, "network_address": 52715,
"power_source": "Mains (single phase)", "power_source": "Mains (single phase)",
@ -1013,26 +931,25 @@ class ZigbeeMqttPlugin(
converted. converted.
""" """
def extract_value(value: dict, root: dict, depth: int = 0): def extract_value(val: dict, root: dict, depth: int = 0):
for feature in value.get('features', []): for feature in val.get('features', []):
new_root = root new_root = root
if depth > 0: if depth > 0:
new_root = root[value['property']] = root.get(value['property'], {}) new_root = root[val['property']] = root.get(val['property'], {})
extract_value(feature, new_root, depth=depth + 1) extract_value(feature, new_root, depth=depth + 1)
if not value.get('access', 1) & 0x4: if not val.get('access', 1) & 0x4:
# Property not readable/query-able # Property not readable/query-able
return return
if 'features' not in value: if 'features' not in val:
if 'property' in value: if 'property' in val:
root[value['property']] = 0 if value['type'] == 'numeric' else '' root[val['property']] = 0 if val['type'] == 'numeric' else ''
return return
if 'property' in value: if 'property' in val:
root[value['property']] = root.get(value['property'], {}) root[val['property']] = root.get(val['property'], {})
root = root[value['property']]
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
for value in values: for value in values:
@ -1125,10 +1042,7 @@ class ZigbeeMqttPlugin(
# If the device has no queryable properties, don't specify a reply # If the device has no queryable properties, don't specify a reply
# topic to listen on # topic to listen on
req = self._build_device_get_request(exposes) req = self._build_device_get_request(exposes)
reply_topic = self._topic(device) reply_topic = self._topic(device) if req else None
if not req:
reply_topic = None
return self._run_request( return self._run_request(
topic=self._topic(device) + '/get', topic=self._topic(device) + '/get',
reply_topic=reply_topic, reply_topic=reply_topic,
@ -1174,7 +1088,7 @@ class ZigbeeMqttPlugin(
) )
def worker(device: str, q: Queue): def worker(device: str, q: Queue):
q.put(self.device_get(device, **kwargs).output) # type: ignore q.put_nowait(self.device_get(device, **kwargs).output) # type: ignore
queues: Dict[str, Queue] = {} queues: Dict[str, Queue] = {}
workers = {} workers = {}
@ -1188,9 +1102,16 @@ class ZigbeeMqttPlugin(
workers[device].start() workers[device].start()
for device in devices: for device in devices:
timeout = kwargs.get('timeout')
try: try:
response[device] = queues[device].get(timeout=kwargs.get('timeout')) response[device] = queues[device].get(timeout=timeout)
workers[device].join(timeout=kwargs.get('timeout')) workers[device].join(timeout=timeout)
except Empty:
self.logger.warning(
'Could not get the status of the device %s: timeout after %f seconds',
device,
timeout,
)
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
'An error occurred while getting the status of the device %s: %s', 'An error occurred while getting the status of the device %s: %s',
@ -1198,6 +1119,8 @@ class ZigbeeMqttPlugin(
e, e,
) )
self.logger.exception(e)
return response return response
@action @action
@ -1288,6 +1211,7 @@ class ZigbeeMqttPlugin(
:param property: Name of the property to set. If not specified here, it :param property: Name of the property to set. If not specified here, it
should be specified on ``device`` in ``<address>:<property>`` should be specified on ``device`` in ``<address>:<property>``
format. format.
:param data: Value to set for the property.
:param kwargs: Extra arguments to be passed to :param kwargs: Extra arguments to be passed to
:meth:`platypush.plugins.mqtt.MqttPlugin.publish`` (default: query :meth:`platypush.plugins.mqtt.MqttPlugin.publish`` (default: query
the default configured device). the default configured device).
@ -1429,7 +1353,7 @@ class ZigbeeMqttPlugin(
"device_options": {}, "device_options": {},
"devices": { "devices": {
"0x00123456789abcdf": { "0x00123456789abcdf": {
"friendly_name": "My Lightbulb" "friendly_name": "My Light Bulb"
} }
}, },
"experimental": { "experimental": {
@ -1828,7 +1752,7 @@ class ZigbeeMqttPlugin(
@staticmethod @staticmethod
def _is_query_disabled(feature: dict) -> bool: def _is_query_disabled(feature: dict) -> bool:
""" """
Utility method that checks if a feature doesn't support programmating Utility method that checks if a feature doesn't support programmatic
querying (i.e. it will only broadcast its state when available) on the querying (i.e. it will only broadcast its state when available) on the
basis of its access flags. basis of its access flags.
""" """
@ -1854,9 +1778,9 @@ class ZigbeeMqttPlugin(
# IEEE address + property format # IEEE address + property format
if re.search(r'^0x[0-9a-fA-F]{16}:', dev): if re.search(r'^0x[0-9a-fA-F]{16}:', dev):
parts = dev.split(':') parts = dev.split(':')
return (parts[0], parts[1] if len(parts) > 1 else None) return parts[0], parts[1] if len(parts) > 1 else None
return (dev, None) return dev, None
@classmethod @classmethod
def _ieee_address(cls, device: Union[dict, str]) -> str: def _ieee_address(cls, device: Union[dict, str]) -> str:
@ -2220,7 +2144,8 @@ class ZigbeeMqttPlugin(
""" """
def handler(client: MqttClient, _, msg: mqtt.MQTTMessage): def handler(client: MqttClient, _, msg: mqtt.MQTTMessage):
topic = msg.topic[len(self.topic_prefix) + 1 :] topic_idx = len(self.topic_prefix) + 1
topic = msg.topic[topic_idx:]
data = msg.payload.decode() data = msg.payload.decode()
if not data: if not data:
return return
@ -2244,7 +2169,9 @@ class ZigbeeMqttPlugin(
dev = self._info.devices.get(name) dev = self._info.devices.get(name)
assert dev is not None, f'No such device: {name}' assert dev is not None, f'No such device: {name}'
changed_props = {k: v for k, v in data.items() if v != dev.get(k)} changed_props = {
k: v for k, v in data.items() if v != dev.get('state', {}).get(k)
}
if changed_props: if changed_props:
self._process_property_update(name, data) self._process_property_update(name, data)
@ -2257,11 +2184,11 @@ class ZigbeeMqttPlugin(
) )
) )
device_meta = self._devices_meta.get(name) dev = self._info.devices.get(name)
if device_meta: if dev:
data['friendly_name'] = device_meta.get('friendly_name') self._info.devices.set_state(
data['ieee_address'] = device_meta.get('ieee_address') dev.get('friendly_name') or dev.get('ieee_address'), data
self._info.devices.add(data) )
return handler return handler
@ -2276,15 +2203,15 @@ class ZigbeeMqttPlugin(
""" """
Process a state message. Process a state message.
""" """
if msg == self._bridge_state: if msg == self._info.bridge_state:
return return
if msg == 'online': if msg == 'online':
evt = ZigbeeMqttOnlineEvent evt = ZigbeeMqttOnlineEvent
self._bridge_state = BridgeState.ONLINE self._info.bridge_state = BridgeState.ONLINE
elif msg == 'offline': elif msg == 'offline':
evt = ZigbeeMqttOfflineEvent evt = ZigbeeMqttOfflineEvent
self._bridge_state = BridgeState.OFFLINE self._info.bridge_state = BridgeState.OFFLINE
self.logger.warning('The zigbee2mqtt service is offline') self.logger.warning('The zigbee2mqtt service is offline')
else: else:
return return
@ -2294,7 +2221,7 @@ class ZigbeeMqttPlugin(
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
def _process_log_message(self, client, msg): def _process_log_message(self, client, msg):
""" """
Process a logevent. Process a log event.
""" """
msg_type = msg.get('type') msg_type = msg.get('type')
@ -2384,8 +2311,8 @@ class ZigbeeMqttPlugin(
self._bus.post(ZigbeeMqttDeviceRemovedEvent(device=name, **event_args)) self._bus.post(ZigbeeMqttDeviceRemovedEvent(device=name, **event_args))
self._info.devices.remove(name) self._info.devices.remove(name)
self._info.devices.reset(*devices_info) for dev in devices_info.values():
self._devices_meta = devices_info self._info.devices.add(dev)
def _process_groups(self, client: MqttClient, msg): def _process_groups(self, client: MqttClient, msg):
""" """
@ -2417,7 +2344,7 @@ class ZigbeeMqttPlugin(
It will appropriately forward an It will appropriately forward an
:class:`platypush.message.event.entities.EntityUpdateEvent` to the bus. :class:`platypush.message.event.entities.EntityUpdateEvent` to the bus.
""" """
device_info = self._devices_meta.get(device_name) device_info = self._info.devices.get(device_name)
if not (device_info and properties): if not (device_info and properties):
return return

View file

@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
class BridgeState(Enum):
"""
Known bridge states.
"""
ONLINE = 'online'
OFFLINE = 'offline'
@dataclass
class ZigbeeDevices:
"""
Cached information about the devices on the network.
"""
by_address: Dict[str, dict] = field(default_factory=dict)
by_name: Dict[str, dict] = field(default_factory=dict)
def __contains__(self, name: str) -> bool:
"""
:return: True if the device with the given name exists in the cache.
"""
return name in self.by_name or name in self.by_address
def get(self, name: str) -> Optional[dict]:
"""
Retrieves a cached device record either by name or by address.
"""
return self.by_address.get(name, self.by_name.get(name))
def add(self, device: dict):
"""
Adds a device record to the cache.
"""
if device.get('ieee_address'):
self.by_address[device['ieee_address']] = device
if device.get('friendly_name'):
self.by_name[device['friendly_name']] = device
if not device.get('state'):
device['state'] = {}
def remove(self, device: Union[str, dict]):
"""
Removes a device record from the cache.
"""
if isinstance(device, str):
dev = self.get(device)
if not dev:
return # No such device
else:
dev = device
if dev.get('ieee_address'):
self.by_address.pop(dev['ieee_address'], None)
if dev.get('friendly_name'):
self.by_name.pop(dev['friendly_name'], None)
def set_state(self, device: str, state: dict):
"""
Updates the state of a device in the cache.
:param device: Name or address of the device.
:param state: Map containing the new state.
"""
dev = self.get(device)
if not dev:
return
dev['state'] = state
@dataclass
class ZigbeeState:
"""
Cached information about the devices and groups on the network.
"""
devices: ZigbeeDevices = field(default_factory=ZigbeeDevices)
groups: Dict[str, dict] = field(default_factory=dict)
bridge_state: BridgeState = BridgeState.OFFLINE