From f9598977db5f2824ceed5058831211f3b59ac7ef Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Wed, 10 Feb 2021 22:26:51 +0100 Subject: [PATCH] Refactored backend.mqtt to reuse connections whenever possible, as well as programmatically subscribe/unsubscribe topics at runtime --- platypush/backend/mqtt.py | 211 +++++++++++++++++++++++--------------- 1 file changed, 127 insertions(+), 84 deletions(-) diff --git a/platypush/backend/mqtt.py b/platypush/backend/mqtt.py index 80daeaa5..53a6e104 100644 --- a/platypush/backend/mqtt.py +++ b/platypush/backend/mqtt.py @@ -1,7 +1,9 @@ import json import os import threading -from typing import Optional +from typing import Optional, List, Callable + +import paho.mqtt.client as mqtt from platypush.backend import Backend from platypush.config import Config @@ -13,6 +15,73 @@ from platypush.plugins.mqtt import MqttPlugin as MQTTPlugin from platypush.utils import set_thread_name +class MqttClient(mqtt.Client, threading.Thread): + def __init__(self, *args, host: str, port: int, topics: Optional[List[str]] = None, + on_message: Optional[Callable] = None, username: Optional[str] = None, password: Optional[str] = None, + client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None, + tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None, + tls_insecure: bool = False, keepalive: Optional[int] = 60, **kwargs): + mqtt.Client.__init__(self, *args, client_id=client_id, **kwargs) + threading.Thread.__init__(self) + + self.host = host + self.port = port + self.topics = set(topics or []) + self.keepalive = keepalive + self.on_connect = self.connect_hndl() + + if on_message: + self.on_message = on_message + + if username and password: + self.username_pw_set(username, password) + + if tls_cafile: + self.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile, tls_version=tls_version, + ciphers=tls_ciphers) + + self.tls_insecure_set(tls_insecure) + + self._running = False + self._stop_scheduled = False + + def subscribe(self, *topics, **kwargs): + if not topics: + topics = self.topics + + self.topics.update(topics) + for topic in topics: + super().subscribe(topic, **kwargs) + + def unsubscribe(self, *topics, **kwargs): + if not topics: + topics = self.topics + + for topic in topics: + super().unsubscribe(topic, **kwargs) + self.topics.remove(topic) + + def connect_hndl(self): + def handler(*_, **__): + self.subscribe() + + return handler + + def run(self): + super().run() + self.connect(host=self.host, port=self.port, keepalive=self.keepalive) + self._running = True + self.loop_forever() + + def stop(self): + if not self.is_alive(): + return + + self._stop_scheduled = True + self.disconnect() + self._running = False + + class MqttBackend(Backend): """ Backend that reads messages from a configured MQTT topic (default: @@ -115,9 +184,7 @@ class MqttBackend(Backend): self.topic = '{}/{}'.format(topic, self.device_id) self.subscribe_default_topic = subscribe_default_topic - self._client = None - self._listeners = [] - + self._listeners = {} # (host, port, msg_handler) -> MqttClient map self.listeners_conf = listeners or [] def send_message(self, msg, topic: Optional[str] = None, **kwargs): @@ -128,51 +195,20 @@ class MqttBackend(Backend): password=self.password, tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile, tls_version=self.tls_version, tls_insecure=self.tls_insecure, - tls_ciphers=self.tls_ciphers, client_id=self.client_id, **kwargs) + tls_ciphers=self.tls_ciphers, **kwargs) except Exception as e: self.logger.exception(e) - @staticmethod - def on_connect(*topics): - # noinspection PyUnusedLocal - def handler(client, userdata, flags, rc): - for topic in topics: - client.subscribe(topic) - - return handler - - def on_mqtt_message(self): - def handler(client, _, msg): - data = msg.payload - # noinspection PyBroadException - try: - data = data.decode('utf-8') - data = json.loads(data) - except: - pass - - # noinspection PyProtectedMember - self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data)) - - return handler - @staticmethod def _expandpath(path: str) -> str: return os.path.abspath(os.path.expanduser(path)) if path else path - def _initialize_listeners(self, listeners_conf): - import paho.mqtt.client as mqtt - - def listener_thread(client_, host, port): - client_.connect(host, port) - client_.loop_forever() - + def add_listeners(self, *listeners): # noinspection PyShadowingNames,PyUnusedLocal - for i, listener in enumerate(listeners_conf): + for i, listener in enumerate(listeners): host = listener.get('host') if host: port = listener.get('port', self._default_mqtt_port) - topics = listener.get('topics') username = listener.get('username') password = listener.get('password') tls_cafile = self._expandpath(listener.get('tls_cafile')) @@ -189,7 +225,7 @@ class MqttBackend(Backend): tls_cafile = self.tls_cafile tls_certfile = self.tls_certfile tls_keyfile = self.tls_keyfile - tls_version = self.tls_keyfile + tls_version = self.tls_version tls_ciphers = self.tls_ciphers tls_insecure = self.tls_insecure @@ -198,24 +234,48 @@ class MqttBackend(Backend): self.logger.warning('No list of topics specified for listener n.{}'.format(i+1)) continue - client = mqtt.Client() - client.on_connect = self.on_connect(*topics) - client.on_message = self.on_mqtt_message() + client = self._get_client(host=host, port=port, topics=topics, username=username, password=password, + client_id=self.client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile, + tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers, + tls_insecure=tls_insecure) - if username and password: - client.username_pw_set(username, password) + if not client.is_alive(): + client.start() - if tls_cafile: - client.tls_set(ca_certs=tls_cafile, - certfile=tls_certfile, - keyfile=tls_keyfile, - tls_version=tls_version, - ciphers=tls_ciphers) + def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None, + password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None, + tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None, + tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \ + -> MqttClient: + on_message = on_message or self.on_mqtt_message() + on_message_name = repr(on_message) + client = self._listeners.get((host, port, on_message_name)) - client.tls_insecure_set(tls_insecure) + if not (client and client.is_alive()): + client = MqttClient(host=host, port=port, topics=topics, username=username, password=password, + client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile, + tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers, + tls_insecure=tls_insecure, on_message=on_message) - threading.Thread(target=listener_thread, kwargs={ - 'client_': client, 'host': host, 'port': port}).start() + self._listeners[(host, port, on_message_name)] = client + + client.subscribe(*topics) + return client + + def on_mqtt_message(self): + def handler(client, __, msg): + data = msg.payload + # noinspection PyBroadException + try: + data = data.decode('utf-8') + data = json.loads(data) + except: + pass + + # noinspection PyProtectedMember + self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data)) + + return handler def on_exec_message(self): def handler(_, __, msg): @@ -257,51 +317,34 @@ class MqttBackend(Backend): return handler def run(self): - import paho.mqtt.client as mqtt - super().run() - self._client = None if self.host: - self._client = mqtt.Client(self.client_id) - if self.subscribe_default_topic: - self._client.on_connect = self.on_connect(self.topic) + topics = [self.topic] if self.subscribe_default_topic else [] + client = self._get_client(host=self.host, port=self.port, topics=topics, username=self.username, + password=self.password, client_id=self.client_id, + tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile, + tls_keyfile=self.tls_keyfile, tls_version=self.tls_version, + tls_ciphers=self.tls_ciphers, tls_insecure=self.tls_insecure, + on_message=self.on_exec_message()) - self._client.on_message = self.on_exec_message() - if self.username and self.password: - self._client.username_pw_set(self.username, self.password) - - if self.tls_cafile: - self._client.tls_set(ca_certs=self.tls_cafile, certfile=self.tls_certfile, - keyfile=self.tls_keyfile, - tls_version=self.tls_version, - ciphers=self.tls_ciphers) - - self._client.tls_insecure_set(self.tls_insecure) - - self._client.connect(self.host, self.port, 60) + client.start() self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'. format(self.host, self.port, self.topic)) - self._initialize_listeners(self.listeners_conf) - if self._client: - self._client.loop_forever() + self.add_listeners(*self.listeners_conf) def stop(self): - self.logger.info('Received STOP event on MqttBackend') - if self._client: - self._client.disconnect() - self._client.loop_stop() - self._client = None + self.logger.info('Received STOP event on the MQTT backend') - for listener in self._listeners: + for ((host, port, _), listener) in self._listeners.items(): try: listener.loop_stop() + listener.disconnect() except Exception as e: # noinspection PyProtectedMember - self.logger.warning('Could not stop listener {host}:{port}: {error}'.format( - host=listener._host, port=listener._port, - error=str(e))) + self.logger.warning('Could not stop listener {host}:{port}: {error}'. + format(host=host, port=port, error=str(e))) # vim:sw=4:ts=4:et: