Refactored backend.mqtt to reuse connections whenever possible, as well as programmatically subscribe/unsubscribe topics at runtime

This commit is contained in:
Fabio Manganiello 2021-02-10 22:26:51 +01:00
parent 7325c87068
commit ade04a6ea1

View file

@ -1,7 +1,9 @@
import json
import os
import threading
from typing import Optional
from typing import Optional, List, Callable
import paho.mqtt.client as mqtt
from platypush.backend import Backend
from platypush.config import Config
@ -13,6 +15,73 @@ from platypush.plugins.mqtt import MqttPlugin as MQTTPlugin
from platypush.utils import set_thread_name
class MqttClient(mqtt.Client, threading.Thread):
def __init__(self, *args, host: str, port: int, topics: Optional[List[str]] = None,
on_message: Optional[Callable] = None, username: Optional[str] = None, password: Optional[str] = None,
client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None,
tls_insecure: bool = False, keepalive: Optional[int] = 60, **kwargs):
mqtt.Client.__init__(self, *args, client_id=client_id, **kwargs)
threading.Thread.__init__(self)
self.host = host
self.port = port
self.topics = set(topics or [])
self.keepalive = keepalive
self.on_connect = self.connect_hndl()
if on_message:
self.on_message = on_message
if username and password:
self.username_pw_set(username, password)
if tls_cafile:
self.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile, tls_version=tls_version,
ciphers=tls_ciphers)
self.tls_insecure_set(tls_insecure)
self._running = False
self._stop_scheduled = False
def subscribe(self, *topics, **kwargs):
if not topics:
topics = self.topics
self.topics.update(topics)
for topic in topics:
super().subscribe(topic, **kwargs)
def unsubscribe(self, *topics, **kwargs):
if not topics:
topics = self.topics
for topic in topics:
super().unsubscribe(topic, **kwargs)
self.topics.remove(topic)
def connect_hndl(self):
def handler(*_, **__):
self.subscribe()
return handler
def run(self):
super().run()
self.connect(host=self.host, port=self.port, keepalive=self.keepalive)
self._running = True
self.loop_forever()
def stop(self):
if not self.is_alive():
return
self._stop_scheduled = True
self.disconnect()
self._running = False
class MqttBackend(Backend):
"""
Backend that reads messages from a configured MQTT topic (default:
@ -115,9 +184,7 @@ class MqttBackend(Backend):
self.topic = '{}/{}'.format(topic, self.device_id)
self.subscribe_default_topic = subscribe_default_topic
self._client = None
self._listeners = []
self._listeners = {} # (host, port, msg_handler) -> MqttClient map
self.listeners_conf = listeners or []
def send_message(self, msg, topic: Optional[str] = None, **kwargs):
@ -128,51 +195,20 @@ class MqttBackend(Backend):
password=self.password, tls_cafile=self.tls_cafile,
tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile,
tls_version=self.tls_version, tls_insecure=self.tls_insecure,
tls_ciphers=self.tls_ciphers, client_id=self.client_id, **kwargs)
tls_ciphers=self.tls_ciphers, **kwargs)
except Exception as e:
self.logger.exception(e)
@staticmethod
def on_connect(*topics):
# noinspection PyUnusedLocal
def handler(client, userdata, flags, rc):
for topic in topics:
client.subscribe(topic)
return handler
def on_mqtt_message(self):
def handler(client, _, msg):
data = msg.payload
# noinspection PyBroadException
try:
data = data.decode('utf-8')
data = json.loads(data)
except:
pass
# noinspection PyProtectedMember
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
return handler
@staticmethod
def _expandpath(path: str) -> str:
return os.path.abspath(os.path.expanduser(path)) if path else path
def _initialize_listeners(self, listeners_conf):
import paho.mqtt.client as mqtt
def listener_thread(client_, host, port):
client_.connect(host, port)
client_.loop_forever()
def add_listeners(self, *listeners):
# noinspection PyShadowingNames,PyUnusedLocal
for i, listener in enumerate(listeners_conf):
for i, listener in enumerate(listeners):
host = listener.get('host')
if host:
port = listener.get('port', self._default_mqtt_port)
topics = listener.get('topics')
username = listener.get('username')
password = listener.get('password')
tls_cafile = self._expandpath(listener.get('tls_cafile'))
@ -189,7 +225,7 @@ class MqttBackend(Backend):
tls_cafile = self.tls_cafile
tls_certfile = self.tls_certfile
tls_keyfile = self.tls_keyfile
tls_version = self.tls_keyfile
tls_version = self.tls_version
tls_ciphers = self.tls_ciphers
tls_insecure = self.tls_insecure
@ -198,24 +234,48 @@ class MqttBackend(Backend):
self.logger.warning('No list of topics specified for listener n.{}'.format(i+1))
continue
client = mqtt.Client()
client.on_connect = self.on_connect(*topics)
client.on_message = self.on_mqtt_message()
client = self._get_client(host=host, port=port, topics=topics, username=username, password=password,
client_id=self.client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
tls_insecure=tls_insecure)
if username and password:
client.username_pw_set(username, password)
if not client.is_alive():
client.start()
if tls_cafile:
client.tls_set(ca_certs=tls_cafile,
certfile=tls_certfile,
keyfile=tls_keyfile,
tls_version=tls_version,
ciphers=tls_ciphers)
def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
-> MqttClient:
on_message = on_message or self.on_mqtt_message()
on_message_name = repr(on_message)
client = self._listeners.get((host, port, on_message_name))
client.tls_insecure_set(tls_insecure)
if not (client and client.is_alive()):
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password,
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
tls_insecure=tls_insecure, on_message=on_message)
threading.Thread(target=listener_thread, kwargs={
'client_': client, 'host': host, 'port': port}).start()
self._listeners[(host, port, on_message_name)] = client
client.subscribe(*topics)
return client
def on_mqtt_message(self):
def handler(client, __, msg):
data = msg.payload
# noinspection PyBroadException
try:
data = data.decode('utf-8')
data = json.loads(data)
except:
pass
# noinspection PyProtectedMember
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
return handler
def on_exec_message(self):
def handler(_, __, msg):
@ -257,51 +317,34 @@ class MqttBackend(Backend):
return handler
def run(self):
import paho.mqtt.client as mqtt
super().run()
self._client = None
if self.host:
self._client = mqtt.Client(self.client_id)
if self.subscribe_default_topic:
self._client.on_connect = self.on_connect(self.topic)
topics = [self.topic] if self.subscribe_default_topic else []
client = self._get_client(host=self.host, port=self.port, topics=topics, username=self.username,
password=self.password, client_id=self.client_id,
tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile,
tls_keyfile=self.tls_keyfile, tls_version=self.tls_version,
tls_ciphers=self.tls_ciphers, tls_insecure=self.tls_insecure,
on_message=self.on_exec_message())
self._client.on_message = self.on_exec_message()
if self.username and self.password:
self._client.username_pw_set(self.username, self.password)
if self.tls_cafile:
self._client.tls_set(ca_certs=self.tls_cafile, certfile=self.tls_certfile,
keyfile=self.tls_keyfile,
tls_version=self.tls_version,
ciphers=self.tls_ciphers)
self._client.tls_insecure_set(self.tls_insecure)
self._client.connect(self.host, self.port, 60)
client.start()
self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'.
format(self.host, self.port, self.topic))
self._initialize_listeners(self.listeners_conf)
if self._client:
self._client.loop_forever()
self.add_listeners(*self.listeners_conf)
def stop(self):
self.logger.info('Received STOP event on MqttBackend')
if self._client:
self._client.disconnect()
self._client.loop_stop()
self._client = None
self.logger.info('Received STOP event on the MQTT backend')
for listener in self._listeners:
for ((host, port, _), listener) in self._listeners.items():
try:
listener.loop_stop()
listener.disconnect()
except Exception as e:
# noinspection PyProtectedMember
self.logger.warning('Could not stop listener {host}:{port}: {error}'.format(
host=listener._host, port=listener._port,
error=str(e)))
self.logger.warning('Could not stop listener {host}:{port}: {error}'.
format(host=host, port=port, error=str(e)))
# vim:sw=4:ts=4:et: