forked from platypush/platypush
Refactored backend.mqtt to reuse connections whenever possible, as well as programmatically subscribe/unsubscribe topics at runtime
This commit is contained in:
parent
1a70c6ea0b
commit
f9598977db
1 changed files with 127 additions and 84 deletions
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Callable
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
from platypush.backend import Backend
|
||||
from platypush.config import Config
|
||||
|
@ -13,6 +15,73 @@ from platypush.plugins.mqtt import MqttPlugin as MQTTPlugin
|
|||
from platypush.utils import set_thread_name
|
||||
|
||||
|
||||
class MqttClient(mqtt.Client, threading.Thread):
|
||||
def __init__(self, *args, host: str, port: int, topics: Optional[List[str]] = None,
|
||||
on_message: Optional[Callable] = None, username: Optional[str] = None, password: Optional[str] = None,
|
||||
client_id: Optional[str] = None, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
|
||||
tls_keyfile: Optional[str] = None, tls_version: Optional = None, tls_ciphers: Optional = None,
|
||||
tls_insecure: bool = False, keepalive: Optional[int] = 60, **kwargs):
|
||||
mqtt.Client.__init__(self, *args, client_id=client_id, **kwargs)
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.topics = set(topics or [])
|
||||
self.keepalive = keepalive
|
||||
self.on_connect = self.connect_hndl()
|
||||
|
||||
if on_message:
|
||||
self.on_message = on_message
|
||||
|
||||
if username and password:
|
||||
self.username_pw_set(username, password)
|
||||
|
||||
if tls_cafile:
|
||||
self.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile, tls_version=tls_version,
|
||||
ciphers=tls_ciphers)
|
||||
|
||||
self.tls_insecure_set(tls_insecure)
|
||||
|
||||
self._running = False
|
||||
self._stop_scheduled = False
|
||||
|
||||
def subscribe(self, *topics, **kwargs):
|
||||
if not topics:
|
||||
topics = self.topics
|
||||
|
||||
self.topics.update(topics)
|
||||
for topic in topics:
|
||||
super().subscribe(topic, **kwargs)
|
||||
|
||||
def unsubscribe(self, *topics, **kwargs):
|
||||
if not topics:
|
||||
topics = self.topics
|
||||
|
||||
for topic in topics:
|
||||
super().unsubscribe(topic, **kwargs)
|
||||
self.topics.remove(topic)
|
||||
|
||||
def connect_hndl(self):
|
||||
def handler(*_, **__):
|
||||
self.subscribe()
|
||||
|
||||
return handler
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
self.connect(host=self.host, port=self.port, keepalive=self.keepalive)
|
||||
self._running = True
|
||||
self.loop_forever()
|
||||
|
||||
def stop(self):
|
||||
if not self.is_alive():
|
||||
return
|
||||
|
||||
self._stop_scheduled = True
|
||||
self.disconnect()
|
||||
self._running = False
|
||||
|
||||
|
||||
class MqttBackend(Backend):
|
||||
"""
|
||||
Backend that reads messages from a configured MQTT topic (default:
|
||||
|
@ -115,9 +184,7 @@ class MqttBackend(Backend):
|
|||
|
||||
self.topic = '{}/{}'.format(topic, self.device_id)
|
||||
self.subscribe_default_topic = subscribe_default_topic
|
||||
self._client = None
|
||||
self._listeners = []
|
||||
|
||||
self._listeners = {} # (host, port, msg_handler) -> MqttClient map
|
||||
self.listeners_conf = listeners or []
|
||||
|
||||
def send_message(self, msg, topic: Optional[str] = None, **kwargs):
|
||||
|
@ -128,51 +195,20 @@ class MqttBackend(Backend):
|
|||
password=self.password, tls_cafile=self.tls_cafile,
|
||||
tls_certfile=self.tls_certfile, tls_keyfile=self.tls_keyfile,
|
||||
tls_version=self.tls_version, tls_insecure=self.tls_insecure,
|
||||
tls_ciphers=self.tls_ciphers, client_id=self.client_id, **kwargs)
|
||||
tls_ciphers=self.tls_ciphers, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
|
||||
@staticmethod
|
||||
def on_connect(*topics):
|
||||
# noinspection PyUnusedLocal
|
||||
def handler(client, userdata, flags, rc):
|
||||
for topic in topics:
|
||||
client.subscribe(topic)
|
||||
|
||||
return handler
|
||||
|
||||
def on_mqtt_message(self):
|
||||
def handler(client, _, msg):
|
||||
data = msg.payload
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
data = data.decode('utf-8')
|
||||
data = json.loads(data)
|
||||
except:
|
||||
pass
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
|
||||
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
def _expandpath(path: str) -> str:
|
||||
return os.path.abspath(os.path.expanduser(path)) if path else path
|
||||
|
||||
def _initialize_listeners(self, listeners_conf):
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
def listener_thread(client_, host, port):
|
||||
client_.connect(host, port)
|
||||
client_.loop_forever()
|
||||
|
||||
def add_listeners(self, *listeners):
|
||||
# noinspection PyShadowingNames,PyUnusedLocal
|
||||
for i, listener in enumerate(listeners_conf):
|
||||
for i, listener in enumerate(listeners):
|
||||
host = listener.get('host')
|
||||
if host:
|
||||
port = listener.get('port', self._default_mqtt_port)
|
||||
topics = listener.get('topics')
|
||||
username = listener.get('username')
|
||||
password = listener.get('password')
|
||||
tls_cafile = self._expandpath(listener.get('tls_cafile'))
|
||||
|
@ -189,7 +225,7 @@ class MqttBackend(Backend):
|
|||
tls_cafile = self.tls_cafile
|
||||
tls_certfile = self.tls_certfile
|
||||
tls_keyfile = self.tls_keyfile
|
||||
tls_version = self.tls_keyfile
|
||||
tls_version = self.tls_version
|
||||
tls_ciphers = self.tls_ciphers
|
||||
tls_insecure = self.tls_insecure
|
||||
|
||||
|
@ -198,24 +234,48 @@ class MqttBackend(Backend):
|
|||
self.logger.warning('No list of topics specified for listener n.{}'.format(i+1))
|
||||
continue
|
||||
|
||||
client = mqtt.Client()
|
||||
client.on_connect = self.on_connect(*topics)
|
||||
client.on_message = self.on_mqtt_message()
|
||||
client = self._get_client(host=host, port=port, topics=topics, username=username, password=password,
|
||||
client_id=self.client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||
tls_insecure=tls_insecure)
|
||||
|
||||
if username and password:
|
||||
client.username_pw_set(username, password)
|
||||
if not client.is_alive():
|
||||
client.start()
|
||||
|
||||
if tls_cafile:
|
||||
client.tls_set(ca_certs=tls_cafile,
|
||||
certfile=tls_certfile,
|
||||
keyfile=tls_keyfile,
|
||||
tls_version=tls_version,
|
||||
ciphers=tls_ciphers)
|
||||
def _get_client(self, host: str, port: int, topics: Optional[List[str]] = None, username: Optional[str] = None,
|
||||
password: Optional[str] = None, client_id: Optional[str] = None, tls_cafile: Optional[str] = None,
|
||||
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None, tls_version: Optional = None,
|
||||
tls_ciphers: Optional = None, tls_insecure: bool = False, on_message: Optional[Callable] = None) \
|
||||
-> MqttClient:
|
||||
on_message = on_message or self.on_mqtt_message()
|
||||
on_message_name = repr(on_message)
|
||||
client = self._listeners.get((host, port, on_message_name))
|
||||
|
||||
client.tls_insecure_set(tls_insecure)
|
||||
if not (client and client.is_alive()):
|
||||
client = MqttClient(host=host, port=port, topics=topics, username=username, password=password,
|
||||
client_id=client_id, tls_cafile=tls_cafile, tls_certfile=tls_certfile,
|
||||
tls_keyfile=tls_keyfile, tls_version=tls_version, tls_ciphers=tls_ciphers,
|
||||
tls_insecure=tls_insecure, on_message=on_message)
|
||||
|
||||
threading.Thread(target=listener_thread, kwargs={
|
||||
'client_': client, 'host': host, 'port': port}).start()
|
||||
self._listeners[(host, port, on_message_name)] = client
|
||||
|
||||
client.subscribe(*topics)
|
||||
return client
|
||||
|
||||
def on_mqtt_message(self):
|
||||
def handler(client, __, msg):
|
||||
data = msg.payload
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
data = data.decode('utf-8')
|
||||
data = json.loads(data)
|
||||
except:
|
||||
pass
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self.bus.post(MQTTMessageEvent(host=client._host, port=client._port, topic=msg.topic, msg=data))
|
||||
|
||||
return handler
|
||||
|
||||
def on_exec_message(self):
|
||||
def handler(_, __, msg):
|
||||
|
@ -257,51 +317,34 @@ class MqttBackend(Backend):
|
|||
return handler
|
||||
|
||||
def run(self):
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
super().run()
|
||||
self._client = None
|
||||
|
||||
if self.host:
|
||||
self._client = mqtt.Client(self.client_id)
|
||||
if self.subscribe_default_topic:
|
||||
self._client.on_connect = self.on_connect(self.topic)
|
||||
topics = [self.topic] if self.subscribe_default_topic else []
|
||||
client = self._get_client(host=self.host, port=self.port, topics=topics, username=self.username,
|
||||
password=self.password, client_id=self.client_id,
|
||||
tls_cafile=self.tls_cafile, tls_certfile=self.tls_certfile,
|
||||
tls_keyfile=self.tls_keyfile, tls_version=self.tls_version,
|
||||
tls_ciphers=self.tls_ciphers, tls_insecure=self.tls_insecure,
|
||||
on_message=self.on_exec_message())
|
||||
|
||||
self._client.on_message = self.on_exec_message()
|
||||
if self.username and self.password:
|
||||
self._client.username_pw_set(self.username, self.password)
|
||||
|
||||
if self.tls_cafile:
|
||||
self._client.tls_set(ca_certs=self.tls_cafile, certfile=self.tls_certfile,
|
||||
keyfile=self.tls_keyfile,
|
||||
tls_version=self.tls_version,
|
||||
ciphers=self.tls_ciphers)
|
||||
|
||||
self._client.tls_insecure_set(self.tls_insecure)
|
||||
|
||||
self._client.connect(self.host, self.port, 60)
|
||||
client.start()
|
||||
self.logger.info('Initialized MQTT backend on host {}:{}, topic {}'.
|
||||
format(self.host, self.port, self.topic))
|
||||
|
||||
self._initialize_listeners(self.listeners_conf)
|
||||
if self._client:
|
||||
self._client.loop_forever()
|
||||
self.add_listeners(*self.listeners_conf)
|
||||
|
||||
def stop(self):
|
||||
self.logger.info('Received STOP event on MqttBackend')
|
||||
if self._client:
|
||||
self._client.disconnect()
|
||||
self._client.loop_stop()
|
||||
self._client = None
|
||||
self.logger.info('Received STOP event on the MQTT backend')
|
||||
|
||||
for listener in self._listeners:
|
||||
for ((host, port, _), listener) in self._listeners.items():
|
||||
try:
|
||||
listener.loop_stop()
|
||||
listener.disconnect()
|
||||
except Exception as e:
|
||||
# noinspection PyProtectedMember
|
||||
self.logger.warning('Could not stop listener {host}:{port}: {error}'.format(
|
||||
host=listener._host, port=listener._port,
|
||||
error=str(e)))
|
||||
self.logger.warning('Could not stop listener {host}:{port}: {error}'.
|
||||
format(host=host, port=port, error=str(e)))
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Add table
Reference in a new issue