Refactored backend.mqtt to reuse connections whenever possible, as well as programmatically subscribe/unsubscribe topics at runtime

This commit is contained in:
Fabio Manganiello 2021-02-10 22:26:51 +01:00
parent 1a70c6ea0b
commit f9598977db

View file

@ -1,7 +1,9 @@
import json import json
import os import os
import threading import threading
from typing import Optional from typing import Optional, List, Callable
import paho.mqtt.client as mqtt
from platypush.backend import Backend from platypush.backend import Backend
from platypush.config import Config from platypush.config import Config
@ -13,6 +15,73 @@ from platypush.plugins.mqtt import MqttPlugin as MQTTPlugin
from platypush.utils import set_thread_name from platypush.utils import set_thread_name
class MqttClient(mqtt.Client, threading.Thread):
def __init__(self, *args, host: str, port: int, topics: Optional[List[str]] = None,
on_message: Optional[Callable] = None, username: Optional[str] = None, password: Optional[str] = None,
client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None,
tls_insecure: bool = False, keepalive: Optional[int] = 60, **kwargs):
mqtt.Client.__init__(self, *args, client_id=client_id, **kwargs)
threading.Thread.__init__(self)
self.host = host
self.port = port
self.topics = set(topics or [])
self.keepalive = keepalive
self.on_connect = self.connect_hndl()
if on_message:
self.on_message = on_message
if username and password:
self.username_pw_set(username, password)
if tls_cafile:
self.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile, tls_version=tls_version,
ciphers=tls_ciphers)
self.tls_insecure_set(tls_insecure)
self._running = False
self._stop_scheduled = False
def subscribe(self, *topics, **kwargs):
if not topics:
topics = self.topics
self.topics.update(topics)
for topic in topics:
super().subscribe(topic, **kwargs)
def unsubscribe(self, *topics, **kwargs):
if not topics:
topics = self.topics
for topic in topics:
super().unsubscribe(topic, **kwargs)
self.topics.remove(topic)
def connect_hndl(self):
def handler(*_, **__):
self.subscribe()
return handler
def run(self):
super().run()
self.connect(host=self.host, port=self.port, keepalive=self.keepalive)
self._running = True
self.loop_forever()
def stop(self):
if not self.is_alive():
return
self._stop_scheduled = True
self.disconnect()
self._running = False
class MqttBackend(Backend): class MqttBackend(Backend):
""" """
Backend that reads messages from a configured MQTT topic (default: Backend that reads messages from a configured MQTT topic (default:
@ -115,9 +184,7 @@ class MqttBackend(Backend):
self.topic = '{}/{}'.format(topic, self.device_id) self.topic = '{}/{}'.format(topic, self.device_id)
self.subscribe_default_topic = subscribe_default_topic self.subscribe_default_topic = subscribe_default_topic
self._client = None self._listeners = {} # (host, port, msg_handler) -> MqttClient map
self._listeners = []
self.listeners_conf = listeners or [] self.listeners_conf = listeners or []
def send_message(self, msg, topic: Optional[str] = None, **kwargs): def send_message(self, msg, topic: Optional[str] = None, **kwargs):
@ -128,51 +195,20 @@ class MqttBackend(Backend):
password=self.password, tls_cafile=self.tls_cafile, password=self.password, tls_cafile=self.tls_cafile,
tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile, tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile,
tls_version=self.tls_version, tls_insecure=self.tls_insecure, tls_version=self.tls_version, tls_insecure=self.tls_insecure,
tls_ciphers=self.tls_ciphers, client_id=self.client_id, **kwargs) tls_ciphers=self.tls_ciphers, **kwargs)
except Exception as e: except Exception as e:
self.logger.exception(e) self.logger.exception(e)
@staticmethod
def on_connect(*topics):
# noinspection PyUnusedLocal
def handler(client, userdata, flags, rc):
for topic in topics:
client.subscribe(topic)
return handler
def on_mqtt_message(self):
def handler(client, _, msg):
data = msg.payload
# noinspection PyBroadException
try:
data = data.decode('utf-8')
data = json.loads(data)
except:
pass
# noinspection PyProtectedMember
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
return handler
@staticmethod @staticmethod
def _expandpath(path: str) -> str: def _expandpath(path: str) -> str:
return os.path.abspath(os.path.expanduser(path)) if path else path return os.path.abspath(os.path.expanduser(path)) if path else path
def _initialize_listeners(self, listeners_conf): def add_listeners(self, *listeners):
import paho.mqtt.client as mqtt
def listener_thread(client_, host, port):
client_.connect(host, port)
client_.loop_forever()
# noinspection PyShadowingNames,PyUnusedLocal # noinspection PyShadowingNames,PyUnusedLocal
for i, listener in enumerate(listeners_conf): for i, listener in enumerate(listeners):
host = listener.get('host') host = listener.get('host')
if host: if host:
port = listener.get('port', self._default_mqtt_port) port = listener.get('port', self._default_mqtt_port)
topics = listener.get('topics')
username = listener.get('username') username = listener.get('username')
password = listener.get('password') password = listener.get('password')
tls_cafile = self._expandpath(listener.get('tls_cafile')) tls_cafile = self._expandpath(listener.get('tls_cafile'))
@ -189,7 +225,7 @@ class MqttBackend(Backend):
tls_cafile = self.tls_cafile tls_cafile = self.tls_cafile
tls_certfile = self.tls_certfile tls_certfile = self.tls_certfile
tls_keyfile = self.tls_keyfile tls_keyfile = self.tls_keyfile
tls_version = self.tls_keyfile tls_version = self.tls_version
tls_ciphers = self.tls_ciphers tls_ciphers = self.tls_ciphers
tls_insecure = self.tls_insecure tls_insecure = self.tls_insecure
@ -198,24 +234,48 @@ class MqttBackend(Backend):
self.logger.warning('No list of topics specified for listener n.{}'.format(i+1)) self.logger.warning('No list of topics specified for listener n.{}'.format(i+1))
continue continue
client = mqtt.Client() client = self._get_client(host=host, port=port, topics=topics, username=username, password=password,
client.on_connect = self.on_connect(*topics) client_id=self.client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
client.on_message = self.on_mqtt_message() tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
tls_insecure=tls_insecure)
if username and password: if not client.is_alive():
client.username_pw_set(username, password) client.start()
if tls_cafile: def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
client.tls_set(ca_certs=tls_cafile, password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
certfile=tls_certfile, tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
keyfile=tls_keyfile, tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
tls_version=tls_version, -> MqttClient:
ciphers=tls_ciphers) on_message = on_message or self.on_mqtt_message()
on_message_name = repr(on_message)
client = self._listeners.get((host, port, on_message_name))
client.tls_insecure_set(tls_insecure) if not (client and client.is_alive()):
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password,
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
tls_insecure=tls_insecure, on_message=on_message)
threading.Thread(target=listener_thread, kwargs={ self._listeners[(host, port, on_message_name)] = client
'client_': client, 'host': host, 'port': port}).start()
client.subscribe(*topics)
return client
def on_mqtt_message(self):
def handler(client, __, msg):
data = msg.payload
# noinspection PyBroadException
try:
data = data.decode('utf-8')
data = json.loads(data)
except:
pass
# noinspection PyProtectedMember
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
return handler
def on_exec_message(self): def on_exec_message(self):
def handler(_, __, msg): def handler(_, __, msg):
@ -257,51 +317,34 @@ class MqttBackend(Backend):
return handler return handler
def run(self): def run(self):
import paho.mqtt.client as mqtt
super().run() super().run()
self._client = None
if self.host: if self.host:
self._client = mqtt.Client(self.client_id) topics = [self.topic] if self.subscribe_default_topic else []
if self.subscribe_default_topic: client = self._get_client(host=self.host, port=self.port, topics=topics, username=self.username,
self._client.on_connect = self.on_connect(self.topic) password=self.password, client_id=self.client_id,
tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile,
tls_keyfile=self.tls_keyfile, tls_version=self.tls_version,
tls_ciphers=self.tls_ciphers, tls_insecure=self.tls_insecure,
on_message=self.on_exec_message())
self._client.on_message = self.on_exec_message() client.start()
if self.username and self.password:
self._client.username_pw_set(self.username, self.password)
if self.tls_cafile:
self._client.tls_set(ca_certs=self.tls_cafile, certfile=self.tls_certfile,
keyfile=self.tls_keyfile,
tls_version=self.tls_version,
ciphers=self.tls_ciphers)
self._client.tls_insecure_set(self.tls_insecure)
self._client.connect(self.host, self.port, 60)
self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'. self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'.
format(self.host, self.port, self.topic)) format(self.host, self.port, self.topic))
self._initialize_listeners(self.listeners_conf) self.add_listeners(*self.listeners_conf)
if self._client:
self._client.loop_forever()
def stop(self): def stop(self):
self.logger.info('Received STOP event on MqttBackend') self.logger.info('Received STOP event on the MQTT backend')
if self._client:
self._client.disconnect()
self._client.loop_stop()
self._client = None
for listener in self._listeners: for ((host, port, _), listener) in self._listeners.items():
try: try:
listener.loop_stop() listener.loop_stop()
listener.disconnect()
except Exception as e: except Exception as e:
# noinspection PyProtectedMember # noinspection PyProtectedMember
self.logger.warning('Could not stop listener {host}:{port}: {error}'.format( self.logger.warning('Could not stop listener {host}:{port}: {error}'.
host=listener._host, port=listener._port, format(host=host, port=port, error=str(e)))
error=str(e)))
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et: