New client ID generation logic (closes #205)

MQTT client IDs are now generated as a function of
`(client_id, host, port, topics, on_message)` to
prevent client ID clashes.
This commit is contained in:
Fabio Manganiello 2021-12-13 20:34:06 +01:00
parent fa708663e1
commit fa0f4925ed
Signed by: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -1,3 +1,4 @@
import hashlib
import json import json
import os import os
import threading import threading
@ -190,7 +191,7 @@ class MqttBackend(Backend):
self.topic = '{}/{}'.format(topic, self.device_id) self.topic = '{}/{}'.format(topic, self.device_id)
self.subscribe_default_topic = subscribe_default_topic self.subscribe_default_topic = subscribe_default_topic
self._listeners = {} # (host, port, msg_handler) -> MqttClient map self._listeners = {} # client_id -> MqttClient map
self.listeners_conf = listeners or [] self.listeners_conf = listeners or []
def send_message(self, msg, topic: Optional[str] = None, **kwargs): def send_message(self, msg, topic: Optional[str] = None, **kwargs):
@ -248,22 +249,34 @@ class MqttBackend(Backend):
if not client.is_alive(): if not client.is_alive():
client.start() client.start()
def _get_client_id(
self, host: str, port: int, topics: Optional[List[str]] = None,
client_id: Optional[str] = None, on_message: Optional[bool] = None,
) -> str:
return '{client_id}-{client_hash}'.format(
client_id=client_id or self.client_id,
client_hash=hashlib.sha1('|'.join([
host, str(port),
json.dumps(sorted(topics or [])),
str(id(on_message))
]).encode()).hexdigest(),
)
def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None, def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None, password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \ tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
-> MqttClient: -> MqttClient:
on_message = on_message or self.on_mqtt_message() client_id = self._get_client_id(host=host, port=port, topics=topics, client_id=client_id, on_message=on_message)
on_message_name = repr(on_message) client = self._listeners.get(client_id)
client = self._listeners.get((host, port, on_message_name))
if not (client and client.is_alive()): if not (client and client.is_alive()):
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password, client = self._listeners[client_id] = MqttClient(
host=host, port=port, topics=topics, username=username, password=password,
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile, client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers, tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
tls_insecure=tls_insecure, on_message=on_message) tls_insecure=tls_insecure, on_message=on_message
)
self._listeners[(host, port, on_message_name)] = client
client.subscribe(*topics) client.subscribe(*topics)
return client return client
@ -343,13 +356,11 @@ class MqttBackend(Backend):
def on_stop(self): def on_stop(self):
self.logger.info('Received STOP event on the MQTT backend') self.logger.info('Received STOP event on the MQTT backend')
for ((host, port, _), listener) in self._listeners.items(): for listener in self._listeners.values():
try: try:
listener.stop() listener.stop()
except Exception as e: except Exception as e:
# noinspection PyProtectedMember self.logger.warning(f'Could not stop MQTT listener: {e}')
self.logger.warning('Could not stop listener {host}:{port}: {error}'.
format(host=host, port=port, error=str(e)))
self.logger.info('MQTT backend terminated') self.logger.info('MQTT backend terminated')