Compare commits

...

8 commits

Author SHA1 Message Date
4223576016
We should always update the entities cache on addEntity.
All checks were successful
continuous-integration/drone/push Build is passing
Updating the entity cache only when we receive an event from a root
entity means that we lose events sent by individual child entities.
2023-09-15 00:45:38 +02:00
1020b63da7
All EntityMixin components should be allowed to emit loading events. 2023-09-15 00:34:29 +02:00
2c93049ee5
Catch all the exceptions in a plugin action wrapper.
The @action decorator should capture all the exceptions,
log them and return them on `Response.errors`.

This ensures that uncaught exceptions from plugin
actions won't unwind out of control, and also that they
are logged and treated consistently across all the
integrations.
2023-09-14 23:08:23 +02:00
ac72b2f7a8
Fixed management of state on zigbee.mqtt.
Before the merge of the plugin and the listener those components
used to have their own separate state, which led to inconsistencies.
2023-09-14 23:05:27 +02:00
5a514fdcce
Only support the run_topic logic on the MQTT plugin.
Plugins that extend `MqttPlugin` shouldn't run messages as
requests, even if the parent MQTT plugin is configured to
do so.
2023-09-14 01:09:03 +02:00
4cb5aa7acb
Prepend the class name to the string used to generate the MQTT client_id hash.
If we include the class name by default then we won't have to
explicitly modify the client_id in the implementation classes
in order to prevent clashes.
2023-09-14 01:06:53 +02:00
3104a59f44
Better processing of configuration file parameters.
- Do `abspath`+`expanduser` on the configuration file path before
  checking if it exists.

- If the path doesn't exist, but the user explicitly passed a
  configuration file, then copy/create the default configuration
  under the specified directory.
2023-09-14 00:24:52 +02:00
ddd8f1afdc
base_topic param in zigbee.mqtt renamed to topic_prefix.
This is for sake of consistency with other integrations (like
`zwave.mqtt`) that also use the same parameter name for the MQTT topic
prefix.
2023-09-07 21:32:56 +02:00
8 changed files with 186 additions and 153 deletions

View file

@ -4,7 +4,7 @@ import Utils from "@/Utils"
export default { export default {
name: "EntityMixin", name: "EntityMixin",
mixins: [Utils], mixins: [Utils],
emits: ['input'], emits: ['input', 'loading'],
props: { props: {
loading: { loading: {
type: Boolean, type: Boolean,

View file

@ -168,10 +168,11 @@ export default {
methods: { methods: {
addEntity(entity) { addEntity(entity) {
this.entities[entity.id] = entity
if (entity.parent_id != null) if (entity.parent_id != null)
return // Only group entities that have no parent return // Only group entities that have no parent
this.entities[entity.id] = entity;
['id', 'type', 'category', 'plugin'].forEach((attr) => { ['id', 'type', 'category', 'plugin'].forEach((attr) => {
if (entity[attr] == null) if (entity[attr] == null)
return return

View file

@ -106,10 +106,11 @@ class Config:
if cfgfile is None: if cfgfile is None:
cfgfile = self._get_default_cfgfile() cfgfile = self._get_default_cfgfile()
cfgfile = os.path.abspath(os.path.expanduser(cfgfile))
if cfgfile is None or not os.path.exists(cfgfile): if cfgfile is None or not os.path.exists(cfgfile):
cfgfile = self._create_default_config() cfgfile = self._create_default_config(cfgfile)
self.config_file = os.path.abspath(os.path.expanduser(cfgfile)) self.config_file = cfgfile
def _init_logging(self): def _init_logging(self):
logging_config = { logging_config = {
@ -211,8 +212,11 @@ class Config:
'variable': {}, 'variable': {},
} }
def _create_default_config(self): @staticmethod
def _create_default_config(cfgfile: Optional[str] = None):
cfg_mod_dir = os.path.dirname(os.path.abspath(__file__)) cfg_mod_dir = os.path.dirname(os.path.abspath(__file__))
if not cfgfile:
# Use /etc/platypush/config.yaml if the user is running as root, # Use /etc/platypush/config.yaml if the user is running as root,
# otherwise ~/.config/platypush/config.yaml # otherwise ~/.config/platypush/config.yaml
cfgfile = ( cfgfile = (
@ -526,6 +530,7 @@ class Config:
Get a config value or the whole configuration object. Get a config value or the whole configuration object.
:param key: Configuration entry to get (default: all entries). :param key: Configuration entry to get (default: all entries).
:param default: Default value to return if the key is missing.
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
config = cls._get_instance()._config.copy() config = cls._get_instance()._config.copy()

View file

@ -25,8 +25,8 @@ class EntitiesEngine(Thread):
together (preventing excessive writes and throttling events), and together (preventing excessive writes and throttling events), and
prevents race conditions when SQLite is used. prevents race conditions when SQLite is used.
2. Merge any existing entities with their newer representations. 2. Merge any existing entities with their newer representations.
3. Update the entities taxonomy. 3. Update the entities' taxonomy.
4. Persist the new state to the entities database. 4. Persist the new state to the entities' database.
5. Trigger events for the updated entities. 5. Trigger events for the updated entities.
""" """

View file

@ -32,7 +32,10 @@ def action(f: Callable[..., Any]) -> Callable[..., Response]:
response = Response() response = Response()
try: try:
result = f(*args, **kwargs) result = f(*args, **kwargs)
except TypeError as e: except Exception as e:
if isinstance(e, KeyboardInterrupt):
return response
_logger.exception(e) _logger.exception(e)
result = Response(errors=[str(e)]) result = Response(errors=[str(e)])

View file

@ -137,12 +137,13 @@ class MqttPlugin(RunnablePlugin):
self.client_id = client_id or str(Config.get('device_id')) self.client_id = client_id or str(Config.get('device_id'))
self.run_topic = ( self.run_topic = (
f'{run_topic_prefix}/{Config.get("device_id")}' f'{run_topic_prefix}/{Config.get("device_id")}'
if run_topic_prefix if type(self) == MqttPlugin and run_topic_prefix
else None else None
) )
self._listeners_lock = defaultdict(threading.RLock) self._listeners_lock = defaultdict(threading.RLock)
self.listeners: Dict[str, MqttClient] = {} # client_id -> MqttClient map self.listeners: Dict[str, MqttClient] = {} # client_id -> MqttClient map
self.timeout = timeout
self.default_listener = ( self.default_listener = (
self._get_client( self._get_client(
host=host, host=host,
@ -188,6 +189,7 @@ class MqttPlugin(RunnablePlugin):
client_hash = hashlib.sha1( client_hash = hashlib.sha1(
'|'.join( '|'.join(
[ [
self.__class__.__name__,
host, host,
str(port), str(port),
json.dumps(sorted(topics)), json.dumps(sorted(topics)),

View file

@ -1,11 +1,9 @@
import contextlib import contextlib
from dataclasses import dataclass, field
from enum import Enum
import json import json
import re import re
import threading import threading
from queue import Queue from queue import Empty, Queue
from typing import ( from typing import (
Any, Any,
Collection, Collection,
@ -75,82 +73,7 @@ from platypush.message.event.zigbee.mqtt import (
) )
from platypush.message.response import Response from platypush.message.response import Response
from platypush.plugins.mqtt import DEFAULT_TIMEOUT, MqttClient, MqttPlugin, action from platypush.plugins.mqtt import DEFAULT_TIMEOUT, MqttClient, MqttPlugin, action
from ._state import BridgeState, ZigbeeState
class BridgeState(Enum):
"""
Known bridge states.
"""
ONLINE = 'online'
OFFLINE = 'offline'
@dataclass
class ZigbeeDevicesInfo:
"""
Cached information about the devices on the network.
"""
by_address: Dict[str, dict] = field(default_factory=dict)
by_name: Dict[str, dict] = field(default_factory=dict)
def __contains__(self, name: str) -> bool:
"""
:return: True if the device with the given name exists in the cache.
"""
return name in self.by_name or name in self.by_address
def get(self, name: str) -> Optional[dict]:
"""
Retrieves a cached device record either by name or by address.
"""
return self.by_address.get(name, self.by_name.get(name))
def add(self, device: dict):
"""
Adds a device record to the cache.
"""
if device.get('ieee_address'):
self.by_address[device['ieee_address']] = device
if device.get('friendly_name'):
self.by_name[device['friendly_name']] = device
def remove(self, device: Union[str, dict]):
"""
Removes a device record from the cache.
"""
if isinstance(device, str):
dev = self.get(device)
if not dev:
return # No such device
else:
dev = device
if dev.get('ieee_address'):
self.by_address.pop(dev['ieee_address'], None)
if dev.get('friendly_name'):
self.by_name.pop(dev['friendly_name'], None)
def reset(self, *keys: str):
"""
Reset the state for the devices with the given keys.
"""
for k in keys:
self.by_address[k] = {}
self.by_name[k] = {}
@dataclass
class ZigbeeInfo:
"""
Cached information about the devices and groups on the network.
"""
devices: ZigbeeDevicesInfo = field(default_factory=ZigbeeDevicesInfo)
groups: Dict[str, dict] = field(default_factory=dict)
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
@ -226,7 +149,7 @@ class ZigbeeMqttPlugin(
# MQTT settings # MQTT settings
mqtt: mqtt:
# MQTT base topic for zigbee2mqtt MQTT messages # MQTT base topic for zigbee2mqtt MQTT messages
base_topic: zigbee2mqtt topic_prefix: zigbee2mqtt
# MQTT server URL # MQTT server URL
server: 'mqtt://localhost' server: 'mqtt://localhost'
# MQTT server authentication, uncomment if required: # MQTT server authentication, uncomment if required:
@ -294,7 +217,8 @@ class ZigbeeMqttPlugin(
self, self,
host: str, host: str,
port: int = 1883, port: int = 1883,
base_topic: str = 'zigbee2mqtt', topic_prefix: str = 'zigbee2mqtt',
base_topic: Optional[str] = None,
timeout: int = 10, timeout: int = 10,
tls_certfile: Optional[str] = None, tls_certfile: Optional[str] = None,
tls_keyfile: Optional[str] = None, tls_keyfile: Optional[str] = None,
@ -307,8 +231,10 @@ class ZigbeeMqttPlugin(
""" """
:param host: Default MQTT broker where ``zigbee2mqtt`` publishes its messages. :param host: Default MQTT broker where ``zigbee2mqtt`` publishes its messages.
:param port: Broker listen port (default: 1883). :param port: Broker listen port (default: 1883).
:param base_topic: Topic prefix, as specified in :param topic_prefix: Prefix for the published topics, as specified in
``/opt/zigbee2mqtt/data/configuration.yaml`` (default: '``zigbee2mqtt``'). ``/opt/zigbee2mqtt/data/configuration.yaml`` (default: '``zigbee2mqtt``').
:param base_topic: Legacy alias for ``topic_prefix`` (default:
'``zigbee2mqtt``').
:param timeout: If the command expects from a response, then this :param timeout: If the command expects from a response, then this
timeout value will be used (default: 60 seconds). timeout value will be used (default: 60 seconds).
:param tls_cafile: If the connection requires TLS/SSL, specify the :param tls_cafile: If the connection requires TLS/SSL, specify the
@ -326,11 +252,17 @@ class ZigbeeMqttPlugin(
:param password: If the connection requires user authentication, specify :param password: If the connection requires user authentication, specify
the password (default: None) the password (default: None)
""" """
if base_topic:
self.logger.warning(
'base_topic is deprecated, please use topic_prefix instead'
)
topic_prefix = base_topic
super().__init__( super().__init__(
host=host, host=host,
port=port, port=port,
topics=[ topics=[
f'{base_topic}/{topic}' f'{topic_prefix}/{topic}'
for topic in [ for topic in [
'bridge/state', 'bridge/state',
'bridge/log', 'bridge/log',
@ -345,17 +277,12 @@ class ZigbeeMqttPlugin(
tls_ciphers=tls_ciphers, tls_ciphers=tls_ciphers,
username=username, username=username,
password=password, password=password,
timeout=timeout,
**kwargs, **kwargs,
) )
# Append a unique suffix to the client ID to avoid client name clashes self.topic_prefix = topic_prefix
# with other MQTT plugins. self._info = ZigbeeState()
self.client_id += '-zigbee-mqtt'
self.base_topic = base_topic
self.timeout = timeout
self._info = ZigbeeInfo()
self._devices_meta: Dict[str, dict] = {}
self._bridge_state = BridgeState.OFFLINE
@staticmethod @staticmethod
def _get_properties(device: dict) -> dict: def _get_properties(device: dict) -> dict:
@ -550,7 +477,7 @@ class ZigbeeMqttPlugin(
) )
client.on_message = msg_callback client.on_message = msg_callback
client.connect() client.connect()
client.subscribe(self.base_topic + '/bridge/#') client.subscribe(self.topic_prefix + '/bridge/#')
client.loop_start() client.loop_start()
for event in info_ready_events.values(): for event in info_ready_events.values():
@ -590,7 +517,7 @@ class ZigbeeMqttPlugin(
Utility method that construct a topic prefixed by the configured base Utility method that construct a topic prefixed by the configured base
topic. topic.
""" """
return f'{self.base_topic}/{topic}' return f'{self.topic_prefix}/{topic}'
@staticmethod @staticmethod
def _parse_response(response: Union[dict, Response]) -> dict: def _parse_response(response: Union[dict, Response]) -> dict:
@ -617,7 +544,7 @@ class ZigbeeMqttPlugin(
**kwargs, **kwargs,
) -> dict: ) -> dict:
""" """
Sends a request/message to the Zigbeebee2MQTT bridge and waits for a Sends a request/message to the Zigbee2MQTT bridge and waits for a
response. response.
""" """
return self._parse_response( return self._parse_response(
@ -668,7 +595,7 @@ class ZigbeeMqttPlugin(
{ {
"date_code": "20180906", "date_code": "20180906",
"friendly_name": "My Lightbulb", "friendly_name": "My Light Bulb",
"ieee_address": "0x00123456789abcdf", "ieee_address": "0x00123456789abcdf",
"network_address": 52715, "network_address": 52715,
"power_source": "Mains (single phase)", "power_source": "Mains (single phase)",
@ -1004,26 +931,25 @@ class ZigbeeMqttPlugin(
converted. converted.
""" """
def extract_value(value: dict, root: dict, depth: int = 0): def extract_value(val: dict, root: dict, depth: int = 0):
for feature in value.get('features', []): for feature in val.get('features', []):
new_root = root new_root = root
if depth > 0: if depth > 0:
new_root = root[value['property']] = root.get(value['property'], {}) new_root = root[val['property']] = root.get(val['property'], {})
extract_value(feature, new_root, depth=depth + 1) extract_value(feature, new_root, depth=depth + 1)
if not value.get('access', 1) & 0x4: if not val.get('access', 1) & 0x4:
# Property not readable/query-able # Property not readable/query-able
return return
if 'features' not in value: if 'features' not in val:
if 'property' in value: if 'property' in val:
root[value['property']] = 0 if value['type'] == 'numeric' else '' root[val['property']] = 0 if val['type'] == 'numeric' else ''
return return
if 'property' in value: if 'property' in val:
root[value['property']] = root.get(value['property'], {}) root[val['property']] = root.get(val['property'], {})
root = root[value['property']]
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
for value in values: for value in values:
@ -1116,10 +1042,7 @@ class ZigbeeMqttPlugin(
# If the device has no queryable properties, don't specify a reply # If the device has no queryable properties, don't specify a reply
# topic to listen on # topic to listen on
req = self._build_device_get_request(exposes) req = self._build_device_get_request(exposes)
reply_topic = self._topic(device) reply_topic = self._topic(device) if req else None
if not req:
reply_topic = None
return self._run_request( return self._run_request(
topic=self._topic(device) + '/get', topic=self._topic(device) + '/get',
reply_topic=reply_topic, reply_topic=reply_topic,
@ -1165,7 +1088,7 @@ class ZigbeeMqttPlugin(
) )
def worker(device: str, q: Queue): def worker(device: str, q: Queue):
q.put(self.device_get(device, **kwargs).output) # type: ignore q.put_nowait(self.device_get(device, **kwargs).output) # type: ignore
queues: Dict[str, Queue] = {} queues: Dict[str, Queue] = {}
workers = {} workers = {}
@ -1179,9 +1102,16 @@ class ZigbeeMqttPlugin(
workers[device].start() workers[device].start()
for device in devices: for device in devices:
timeout = kwargs.get('timeout')
try: try:
response[device] = queues[device].get(timeout=kwargs.get('timeout')) response[device] = queues[device].get(timeout=timeout)
workers[device].join(timeout=kwargs.get('timeout')) workers[device].join(timeout=timeout)
except Empty:
self.logger.warning(
'Could not get the status of the device %s: timeout after %f seconds',
device,
timeout,
)
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
'An error occurred while getting the status of the device %s: %s', 'An error occurred while getting the status of the device %s: %s',
@ -1189,6 +1119,8 @@ class ZigbeeMqttPlugin(
e, e,
) )
self.logger.exception(e)
return response return response
@action @action
@ -1279,6 +1211,7 @@ class ZigbeeMqttPlugin(
:param property: Name of the property to set. If not specified here, it :param property: Name of the property to set. If not specified here, it
should be specified on ``device`` in ``<address>:<property>`` should be specified on ``device`` in ``<address>:<property>``
format. format.
:param data: Value to set for the property.
:param kwargs: Extra arguments to be passed to :param kwargs: Extra arguments to be passed to
:meth:`platypush.plugins.mqtt.MqttPlugin.publish`` (default: query :meth:`platypush.plugins.mqtt.MqttPlugin.publish`` (default: query
the default configured device). the default configured device).
@ -1420,7 +1353,7 @@ class ZigbeeMqttPlugin(
"device_options": {}, "device_options": {},
"devices": { "devices": {
"0x00123456789abcdf": { "0x00123456789abcdf": {
"friendly_name": "My Lightbulb" "friendly_name": "My Light Bulb"
} }
}, },
"experimental": { "experimental": {
@ -1819,7 +1752,7 @@ class ZigbeeMqttPlugin(
@staticmethod @staticmethod
def _is_query_disabled(feature: dict) -> bool: def _is_query_disabled(feature: dict) -> bool:
""" """
Utility method that checks if a feature doesn't support programmating Utility method that checks if a feature doesn't support programmatic
querying (i.e. it will only broadcast its state when available) on the querying (i.e. it will only broadcast its state when available) on the
basis of its access flags. basis of its access flags.
""" """
@ -1845,9 +1778,9 @@ class ZigbeeMqttPlugin(
# IEEE address + property format # IEEE address + property format
if re.search(r'^0x[0-9a-fA-F]{16}:', dev): if re.search(r'^0x[0-9a-fA-F]{16}:', dev):
parts = dev.split(':') parts = dev.split(':')
return (parts[0], parts[1] if len(parts) > 1 else None) return parts[0], parts[1] if len(parts) > 1 else None
return (dev, None) return dev, None
@classmethod @classmethod
def _ieee_address(cls, device: Union[dict, str]) -> str: def _ieee_address(cls, device: Union[dict, str]) -> str:
@ -2211,7 +2144,8 @@ class ZigbeeMqttPlugin(
""" """
def handler(client: MqttClient, _, msg: mqtt.MQTTMessage): def handler(client: MqttClient, _, msg: mqtt.MQTTMessage):
topic = msg.topic[len(self.base_topic) + 1 :] topic_idx = len(self.topic_prefix) + 1
topic = msg.topic[topic_idx:]
data = msg.payload.decode() data = msg.payload.decode()
if not data: if not data:
return return
@ -2235,7 +2169,9 @@ class ZigbeeMqttPlugin(
dev = self._info.devices.get(name) dev = self._info.devices.get(name)
assert dev is not None, f'No such device: {name}' assert dev is not None, f'No such device: {name}'
changed_props = {k: v for k, v in data.items() if v != dev.get(k)} changed_props = {
k: v for k, v in data.items() if v != dev.get('state', {}).get(k)
}
if changed_props: if changed_props:
self._process_property_update(name, data) self._process_property_update(name, data)
@ -2248,11 +2184,11 @@ class ZigbeeMqttPlugin(
) )
) )
device_meta = self._devices_meta.get(name) dev = self._info.devices.get(name)
if device_meta: if dev:
data['friendly_name'] = device_meta.get('friendly_name') self._info.devices.set_state(
data['ieee_address'] = device_meta.get('ieee_address') dev.get('friendly_name') or dev.get('ieee_address'), data
self._info.devices.add(data) )
return handler return handler
@ -2267,15 +2203,15 @@ class ZigbeeMqttPlugin(
""" """
Process a state message. Process a state message.
""" """
if msg == self._bridge_state: if msg == self._info.bridge_state:
return return
if msg == 'online': if msg == 'online':
evt = ZigbeeMqttOnlineEvent evt = ZigbeeMqttOnlineEvent
self._bridge_state = BridgeState.ONLINE self._info.bridge_state = BridgeState.ONLINE
elif msg == 'offline': elif msg == 'offline':
evt = ZigbeeMqttOfflineEvent evt = ZigbeeMqttOfflineEvent
self._bridge_state = BridgeState.OFFLINE self._info.bridge_state = BridgeState.OFFLINE
self.logger.warning('The zigbee2mqtt service is offline') self.logger.warning('The zigbee2mqtt service is offline')
else: else:
return return
@ -2285,7 +2221,7 @@ class ZigbeeMqttPlugin(
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
def _process_log_message(self, client, msg): def _process_log_message(self, client, msg):
""" """
Process a logevent. Process a log event.
""" """
msg_type = msg.get('type') msg_type = msg.get('type')
@ -2296,7 +2232,7 @@ class ZigbeeMqttPlugin(
devices = {} devices = {}
for dev in text or []: for dev in text or []:
devices[dev['friendly_name']] = dev devices[dev['friendly_name']] = dev
client.subscribe(self.base_topic + '/' + dev['friendly_name']) client.subscribe(self.topic_prefix + '/' + dev['friendly_name'])
elif msg_type == 'pairing': elif msg_type == 'pairing':
self._bus.post(ZigbeeMqttDevicePairingEvent(device=text, **args)) self._bus.post(ZigbeeMqttDevicePairingEvent(device=text, **args))
elif msg_type in ['device_ban', 'device_banned']: elif msg_type in ['device_ban', 'device_banned']:
@ -2349,7 +2285,7 @@ class ZigbeeMqttPlugin(
# Subscribe to updates from all the known devices # Subscribe to updates from all the known devices
event_args = {'host': client.host, 'port': client.port} event_args = {'host': client.host, 'port': client.port}
client.subscribe( client.subscribe(
*[self.base_topic + '/' + device for device in devices_info.keys()] *[self.topic_prefix + '/' + device for device in devices_info.keys()]
) )
for name, device in devices_info.items(): for name, device in devices_info.items():
@ -2365,7 +2301,7 @@ class ZigbeeMqttPlugin(
payload = self._build_device_get_request(exposes) payload = self._build_device_get_request(exposes)
if payload: if payload:
client.publish( client.publish(
self.base_topic + '/' + name + '/get', self.topic_prefix + '/' + name + '/get',
json.dumps(payload), json.dumps(payload),
) )
@ -2375,8 +2311,8 @@ class ZigbeeMqttPlugin(
self._bus.post(ZigbeeMqttDeviceRemovedEvent(device=name, **event_args)) self._bus.post(ZigbeeMqttDeviceRemovedEvent(device=name, **event_args))
self._info.devices.remove(name) self._info.devices.remove(name)
self._info.devices.reset(*devices_info) for dev in devices_info.values():
self._devices_meta = devices_info self._info.devices.add(dev)
def _process_groups(self, client: MqttClient, msg): def _process_groups(self, client: MqttClient, msg):
""" """
@ -2408,7 +2344,7 @@ class ZigbeeMqttPlugin(
It will appropriately forward an It will appropriately forward an
:class:`platypush.message.event.entities.EntityUpdateEvent` to the bus. :class:`platypush.message.event.entities.EntityUpdateEvent` to the bus.
""" """
device_info = self._devices_meta.get(device_name) device_info = self._info.devices.get(device_name)
if not (device_info and properties): if not (device_info and properties):
return return

View file

@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
class BridgeState(Enum):
"""
Known bridge states.
"""
ONLINE = 'online'
OFFLINE = 'offline'
@dataclass
class ZigbeeDevices:
"""
Cached information about the devices on the network.
"""
by_address: Dict[str, dict] = field(default_factory=dict)
by_name: Dict[str, dict] = field(default_factory=dict)
def __contains__(self, name: str) -> bool:
"""
:return: True if the device with the given name exists in the cache.
"""
return name in self.by_name or name in self.by_address
def get(self, name: str) -> Optional[dict]:
"""
Retrieves a cached device record either by name or by address.
"""
return self.by_address.get(name, self.by_name.get(name))
def add(self, device: dict):
"""
Adds a device record to the cache.
"""
if device.get('ieee_address'):
self.by_address[device['ieee_address']] = device
if device.get('friendly_name'):
self.by_name[device['friendly_name']] = device
if not device.get('state'):
device['state'] = {}
def remove(self, device: Union[str, dict]):
"""
Removes a device record from the cache.
"""
if isinstance(device, str):
dev = self.get(device)
if not dev:
return # No such device
else:
dev = device
if dev.get('ieee_address'):
self.by_address.pop(dev['ieee_address'], None)
if dev.get('friendly_name'):
self.by_name.pop(dev['friendly_name'], None)
def set_state(self, device: str, state: dict):
"""
Updates the state of a device in the cache.
:param device: Name or address of the device.
:param state: Map containing the new state.
"""
dev = self.get(device)
if not dev:
return
dev['state'] = state
@dataclass
class ZigbeeState:
"""
Cached information about the devices and groups on the network.
"""
devices: ZigbeeDevices = field(default_factory=ZigbeeDevices)
groups: Dict[str, dict] = field(default_factory=dict)
bridge_state: BridgeState = BridgeState.OFFLINE