forked from platypush/platypush
Refactored backend.mqtt to reuse connections whenever possible, as well as programmatically subscribe/unsubscribe topics at runtime
This commit is contained in:
parent
1a70c6ea0b
commit
f9598977db
1 changed files with 127 additions and 84 deletions
|
@ -1,7 +1,9 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional, List, Callable
|
||||||
|
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
from platypush.backend import Backend
|
from platypush.backend import Backend
|
||||||
from platypush.config import Config
|
from platypush.config import Config
|
||||||
|
@ -13,6 +15,73 @@ from platypush.plugins.mqtt import MqttPlugin as MQTTPlugin
|
||||||
from platypush.utils import set_thread_name
|
from platypush.utils import set_thread_name
|
||||||
|
|
||||||
|
|
||||||
|
class MqttClient(mqtt.Client, threading.Thread):
|
||||||
|
def __init__(self, *args, host: str, port: int, topics: Optional[List[str]] = None,
|
||||||
|
on_message: Optional[Callable] = None, username: Optional[str] = None, password: Optional[str] = None,
|
||||||
|
client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
|
||||||
|
tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None,
|
||||||
|
tls_insecure: bool = False, keepalive: Optional[int] = 60, **kwargs):
|
||||||
|
mqtt.Client.__init__(self, *args, client_id=client_id, **kwargs)
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.topics = set(topics or [])
|
||||||
|
self.keepalive = keepalive
|
||||||
|
self.on_connect = self.connect_hndl()
|
||||||
|
|
||||||
|
if on_message:
|
||||||
|
self.on_message = on_message
|
||||||
|
|
||||||
|
if username and password:
|
||||||
|
self.username_pw_set(username, password)
|
||||||
|
|
||||||
|
if tls_cafile:
|
||||||
|
self.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile, tls_version=tls_version,
|
||||||
|
ciphers=tls_ciphers)
|
||||||
|
|
||||||
|
self.tls_insecure_set(tls_insecure)
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
self._stop_scheduled = False
|
||||||
|
|
||||||
|
def subscribe(self, *topics, **kwargs):
|
||||||
|
if not topics:
|
||||||
|
topics = self.topics
|
||||||
|
|
||||||
|
self.topics.update(topics)
|
||||||
|
for topic in topics:
|
||||||
|
super().subscribe(topic, **kwargs)
|
||||||
|
|
||||||
|
def unsubscribe(self, *topics, **kwargs):
|
||||||
|
if not topics:
|
||||||
|
topics = self.topics
|
||||||
|
|
||||||
|
for topic in topics:
|
||||||
|
super().unsubscribe(topic, **kwargs)
|
||||||
|
self.topics.remove(topic)
|
||||||
|
|
||||||
|
def connect_hndl(self):
|
||||||
|
def handler(*_, **__):
|
||||||
|
self.subscribe()
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
super().run()
|
||||||
|
self.connect(host=self.host, port=self.port, keepalive=self.keepalive)
|
||||||
|
self._running = True
|
||||||
|
self.loop_forever()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if not self.is_alive():
|
||||||
|
return
|
||||||
|
|
||||||
|
self._stop_scheduled = True
|
||||||
|
self.disconnect()
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
|
||||||
class MqttBackend(Backend):
|
class MqttBackend(Backend):
|
||||||
"""
|
"""
|
||||||
Backend that reads messages from a configured MQTT topic (default:
|
Backend that reads messages from a configured MQTT topic (default:
|
||||||
|
@ -115,9 +184,7 @@ class MqttBackend(Backend):
|
||||||
|
|
||||||
self.topic = '{}/{}'.format(topic, self.device_id)
|
self.topic = '{}/{}'.format(topic, self.device_id)
|
||||||
self.subscribe_default_topic = subscribe_default_topic
|
self.subscribe_default_topic = subscribe_default_topic
|
||||||
self._client = None
|
self._listeners = {} # (host, port, msg_handler) -> MqttClient map
|
||||||
self._listeners = []
|
|
||||||
|
|
||||||
self.listeners_conf = listeners or []
|
self.listeners_conf = listeners or []
|
||||||
|
|
||||||
def send_message(self, msg, topic: Optional[str] = None, **kwargs):
|
def send_message(self, msg, topic: Optional[str] = None, **kwargs):
|
||||||
|
@ -128,51 +195,20 @@ class MqttBackend(Backend):
|
||||||
password=self.password, tls_cafile=self.tls_cafile,
|
password=self.password, tls_cafile=self.tls_cafile,
|
||||||
tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile,
|
tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile,
|
||||||
tls_version=self.tls_version, tls_insecure=self.tls_insecure,
|
tls_version=self.tls_version, tls_insecure=self.tls_insecure,
|
||||||
tls_ciphers=self.tls_ciphers, client_id=self.client_id, **kwargs)
|
tls_ciphers=self.tls_ciphers, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception(e)
|
self.logger.exception(e)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def on_connect(*topics):
|
|
||||||
# noinspection PyUnusedLocal
|
|
||||||
def handler(client, userdata, flags, rc):
|
|
||||||
for topic in topics:
|
|
||||||
client.subscribe(topic)
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
def on_mqtt_message(self):
|
|
||||||
def handler(client, _, msg):
|
|
||||||
data = msg.payload
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
data = data.decode('utf-8')
|
|
||||||
data = json.loads(data)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expandpath(path: str) -> str:
|
def _expandpath(path: str) -> str:
|
||||||
return os.path.abspath(os.path.expanduser(path)) if path else path
|
return os.path.abspath(os.path.expanduser(path)) if path else path
|
||||||
|
|
||||||
def _initialize_listeners(self, listeners_conf):
|
def add_listeners(self, *listeners):
|
||||||
import paho.mqtt.client as mqtt
|
|
||||||
|
|
||||||
def listener_thread(client_, host, port):
|
|
||||||
client_.connect(host, port)
|
|
||||||
client_.loop_forever()
|
|
||||||
|
|
||||||
# noinspection PyShadowingNames,PyUnusedLocal
|
# noinspection PyShadowingNames,PyUnusedLocal
|
||||||
for i, listener in enumerate(listeners_conf):
|
for i, listener in enumerate(listeners):
|
||||||
host = listener.get('host')
|
host = listener.get('host')
|
||||||
if host:
|
if host:
|
||||||
port = listener.get('port', self._default_mqtt_port)
|
port = listener.get('port', self._default_mqtt_port)
|
||||||
topics = listener.get('topics')
|
|
||||||
username = listener.get('username')
|
username = listener.get('username')
|
||||||
password = listener.get('password')
|
password = listener.get('password')
|
||||||
tls_cafile = self._expandpath(listener.get('tls_cafile'))
|
tls_cafile = self._expandpath(listener.get('tls_cafile'))
|
||||||
|
@ -189,7 +225,7 @@ class MqttBackend(Backend):
|
||||||
tls_cafile = self.tls_cafile
|
tls_cafile = self.tls_cafile
|
||||||
tls_certfile = self.tls_certfile
|
tls_certfile = self.tls_certfile
|
||||||
tls_keyfile = self.tls_keyfile
|
tls_keyfile = self.tls_keyfile
|
||||||
tls_version = self.tls_keyfile
|
tls_version = self.tls_version
|
||||||
tls_ciphers = self.tls_ciphers
|
tls_ciphers = self.tls_ciphers
|
||||||
tls_insecure = self.tls_insecure
|
tls_insecure = self.tls_insecure
|
||||||
|
|
||||||
|
@ -198,24 +234,48 @@ class MqttBackend(Backend):
|
||||||
self.logger.warning('No list of topics specified for listener n.{}'.format(i+1))
|
self.logger.warning('No list of topics specified for listener n.{}'.format(i+1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
client = mqtt.Client()
|
client = self._get_client(host=host, port=port, topics=topics, username=username, password=password,
|
||||||
client.on_connect = self.on_connect(*topics)
|
client_id=self.client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||||
client.on_message = self.on_mqtt_message()
|
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||||
|
tls_insecure=tls_insecure)
|
||||||
|
|
||||||
if username and password:
|
if not client.is_alive():
|
||||||
client.username_pw_set(username, password)
|
client.start()
|
||||||
|
|
||||||
if tls_cafile:
|
def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
|
||||||
client.tls_set(ca_certs=tls_cafile,
|
password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
|
||||||
certfile=tls_certfile,
|
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
|
||||||
keyfile=tls_keyfile,
|
tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
|
||||||
tls_version=tls_version,
|
-> MqttClient:
|
||||||
ciphers=tls_ciphers)
|
on_message = on_message or self.on_mqtt_message()
|
||||||
|
on_message_name = repr(on_message)
|
||||||
|
client = self._listeners.get((host, port, on_message_name))
|
||||||
|
|
||||||
client.tls_insecure_set(tls_insecure)
|
if not (client and client.is_alive()):
|
||||||
|
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password,
|
||||||
|
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||||
|
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||||
|
tls_insecure=tls_insecure, on_message=on_message)
|
||||||
|
|
||||||
threading.Thread(target=listener_thread, kwargs={
|
self._listeners[(host, port, on_message_name)] = client
|
||||||
'client_': client, 'host': host, 'port': port}).start()
|
|
||||||
|
client.subscribe(*topics)
|
||||||
|
return client
|
||||||
|
|
||||||
|
def on_mqtt_message(self):
|
||||||
|
def handler(client, __, msg):
|
||||||
|
data = msg.payload
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
data = data.decode('utf-8')
|
||||||
|
data = json.loads(data)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
def on_exec_message(self):
|
def on_exec_message(self):
|
||||||
def handler(_, __, msg):
|
def handler(_, __, msg):
|
||||||
|
@ -257,51 +317,34 @@ class MqttBackend(Backend):
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
import paho.mqtt.client as mqtt
|
|
||||||
|
|
||||||
super().run()
|
super().run()
|
||||||
self._client = None
|
|
||||||
|
|
||||||
if self.host:
|
if self.host:
|
||||||
self._client = mqtt.Client(self.client_id)
|
topics = [self.topic] if self.subscribe_default_topic else []
|
||||||
if self.subscribe_default_topic:
|
client = self._get_client(host=self.host, port=self.port, topics=topics, username=self.username,
|
||||||
self._client.on_connect = self.on_connect(self.topic)
|
password=self.password, client_id=self.client_id,
|
||||||
|
tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile,
|
||||||
|
tls_keyfile=self.tls_keyfile, tls_version=self.tls_version,
|
||||||
|
tls_ciphers=self.tls_ciphers, tls_insecure=self.tls_insecure,
|
||||||
|
on_message=self.on_exec_message())
|
||||||
|
|
||||||
self._client.on_message = self.on_exec_message()
|
client.start()
|
||||||
if self.username and self.password:
|
|
||||||
self._client.username_pw_set(self.username, self.password)
|
|
||||||
|
|
||||||
if self.tls_cafile:
|
|
||||||
self._client.tls_set(ca_certs=self.tls_cafile, certfile=self.tls_certfile,
|
|
||||||
keyfile=self.tls_keyfile,
|
|
||||||
tls_version=self.tls_version,
|
|
||||||
ciphers=self.tls_ciphers)
|
|
||||||
|
|
||||||
self._client.tls_insecure_set(self.tls_insecure)
|
|
||||||
|
|
||||||
self._client.connect(self.host, self.port, 60)
|
|
||||||
self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'.
|
self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'.
|
||||||
format(self.host, self.port, self.topic))
|
format(self.host, self.port, self.topic))
|
||||||
|
|
||||||
self._initialize_listeners(self.listeners_conf)
|
self.add_listeners(*self.listeners_conf)
|
||||||
if self._client:
|
|
||||||
self._client.loop_forever()
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.logger.info('Received STOP event on MqttBackend')
|
self.logger.info('Received STOP event on the MQTT backend')
|
||||||
if self._client:
|
|
||||||
self._client.disconnect()
|
|
||||||
self._client.loop_stop()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
for listener in self._listeners:
|
for ((host, port, _), listener) in self._listeners.items():
|
||||||
try:
|
try:
|
||||||
listener.loop_stop()
|
listener.loop_stop()
|
||||||
|
listener.disconnect()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
self.logger.warning('Could not stop listener {host}:{port}: {error}'.format(
|
self.logger.warning('Could not stop listener {host}:{port}: {error}'.
|
||||||
host=listener._host, port=listener._port,
|
format(host=host, port=port, error=str(e)))
|
||||||
error=str(e)))
|
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
Loading…
Reference in a new issue