forked from platypush/platypush
New client ID generation logic (closes #205)
MQTT client IDs are now generated as a function of `(client_id, host, port, topics, on_message)` to prevent client ID clashes.
This commit is contained in:
parent
fa708663e1
commit
fa0f4925ed
1 changed files with 25 additions and 14 deletions
|
@ -1,3 +1,4 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
|
@ -190,7 +191,7 @@ class MqttBackend(Backend):
|
|||
|
||||
self.topic = '{}/{}'.format(topic, self.device_id)
|
||||
self.subscribe_default_topic = subscribe_default_topic
|
||||
self._listeners = {} # (host, port, msg_handler) -> MqttClient map
|
||||
self._listeners = {} # client_id -> MqttClient map
|
||||
self.listeners_conf = listeners or []
|
||||
|
||||
def send_message(self, msg, topic: Optional[str] = None, **kwargs):
|
||||
|
@ -248,22 +249,34 @@ class MqttBackend(Backend):
|
|||
if not client.is_alive():
|
||||
client.start()
|
||||
|
||||
def _get_client_id(
|
||||
self, host: str, port: int, topics: Optional[List[str]] = None,
|
||||
client_id: Optional[str] = None, on_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
return '{client_id}-{client_hash}'.format(
|
||||
client_id=client_id or self.client_id,
|
||||
client_hash=hashlib.sha1('|'.join([
|
||||
host, str(port),
|
||||
json.dumps(sorted(topics or [])),
|
||||
str(id(on_message))
|
||||
]).encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
|
||||
password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
|
||||
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
|
||||
tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
|
||||
-> MqttClient:
|
||||
on_message = on_message or self.on_mqtt_message()
|
||||
on_message_name = repr(on_message)
|
||||
client = self._listeners.get((host, port, on_message_name))
|
||||
client_id = self._get_client_id(host=host, port=port, topics=topics, client_id=client_id, on_message=on_message)
|
||||
client = self._listeners.get(client_id)
|
||||
|
||||
if not (client and client.is_alive()):
|
||||
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password,
|
||||
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||
tls_insecure=tls_insecure, on_message=on_message)
|
||||
|
||||
self._listeners[(host, port, on_message_name)] = client
|
||||
client = self._listeners[client_id] = MqttClient(
|
||||
host=host, port=port, topics=topics, username=username, password=password,
|
||||
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||
tls_insecure=tls_insecure, on_message=on_message
|
||||
)
|
||||
|
||||
client.subscribe(*topics)
|
||||
return client
|
||||
|
@ -343,13 +356,11 @@ class MqttBackend(Backend):
|
|||
def on_stop(self):
|
||||
self.logger.info('Received STOP event on the MQTT backend')
|
||||
|
||||
for ((host, port, _), listener) in self._listeners.items():
|
||||
for listener in self._listeners.values():
|
||||
try:
|
||||
listener.stop()
|
||||
except Exception as e:
|
||||
# noinspection PyProtectedMember
|
||||
self.logger.warning('Could not stop listener {host}:{port}: {error}'.
|
||||
format(host=host, port=port, error=str(e)))
|
||||
self.logger.warning(f'Could not stop MQTT listener: {e}')
|
||||
|
||||
self.logger.info('MQTT backend terminated')
|
||||
|
||||
|
|
Loading…
Reference in a new issue