From fa0f4925edae41ef8cb55b36b809d804603a5f17 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 13 Dec 2021 20:34:06 +0100 Subject: [PATCH] New client ID generation logic (closes #205) MQTT client IDs are now generated as a function of `(client_id, host, port, topics, on_message)` to prevent client ID clashes. --- platypush/backend/mqtt/__init__.py | 39 +++++++++++++++++++----------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/platypush/backend/mqtt/__init__.py b/platypush/backend/mqtt/__init__.py index da2aead34..df854a571 100644 --- a/platypush/backend/mqtt/__init__.py +++ b/platypush/backend/mqtt/__init__.py @@ -1,3 +1,4 @@ +import hashlib import json import os import threading @@ -190,7 +191,7 @@ class MqttBackend(Backend): self.topic = '{}/{}'.format(topic, self.device_id) self.subscribe_default_topic = subscribe_default_topic - self._listeners = {} # (host, port, msg_handler) -> MqttClient map + self._listeners = {} # client_id -> MqttClient map self.listeners_conf = listeners or [] def send_message(self, msg, topic: Optional[str] = None, **kwargs): @@ -248,22 +249,34 @@ class MqttBackend(Backend): if not client.is_alive(): client.start() + def _get_client_id( + self, host: str, port: int, topics: Optional[List[str]] = None, + client_id: Optional[str] = None, on_message: Optional[bool] = None, + ) -> str: + return '{client_id}-{client_hash}'.format( + client_id=client_id or self.client_id, + client_hash=hashlib.sha1('|'.join([ + host, str(port), + json.dumps(sorted(topics or [])), + str(id(on_message)) + ]).encode()).hexdigest(), + ) + def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None, password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \ -> MqttClient: - on_message = on_message or self.on_mqtt_message() - on_message_name = repr(on_message) - client = self._listeners.get((host, port, on_message_name)) + client_id = self._get_client_id(host=host, port=port, topics=topics, client_id=client_id, on_message=on_message) + client = self._listeners.get(client_id) if not (client and client.is_alive()): - client = MqttClient(host=host, port=port, topics=topics, username=username, password=password, - client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile, - tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers, - tls_insecure=tls_insecure, on_message=on_message) - - self._listeners[(host, port, on_message_name)] = client + client = self._listeners[client_id] = MqttClient( + host=host, port=port, topics=topics, username=username, password=password, + client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile, + tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers, + tls_insecure=tls_insecure, on_message=on_message + ) client.subscribe(*topics) return client @@ -343,13 +356,11 @@ class MqttBackend(Backend): def on_stop(self): self.logger.info('Received STOP event on the MQTT backend') - for ((host, port, _), listener) in self._listeners.items(): + for listener in self._listeners.values(): try: listener.stop() except Exception as e: - # noinspection PyProtectedMember - self.logger.warning('Could not stop listener {host}:{port}: {error}'. - format(host=host, port=port, error=str(e))) + self.logger.warning(f'Could not stop MQTT listener: {e}') self.logger.info('MQTT backend terminated')