# platypush/platypush/plugins/mqtt/_client.py
#
# 250 lines
# 7.0 KiB
# Python

from enum import IntEnum
import logging
import os
import threading
from typing import Any, Callable, Dict, Iterable, Optional, Union
import paho.mqtt.client as mqtt
from platypush.config import Config
# Default value of MqttClient's ``timeout``, which MqttClient.connect uses
# as the keepalive interval when none is passed explicitly.
DEFAULT_TIMEOUT: int = 30

# Signature of a message callback: (client, userdata, message) -> anything.
MqttCallback = Callable[["MqttClient", Any, mqtt.MQTTMessage], Any]
class MqttClient(mqtt.Client, threading.Thread):
    """
    Wrapper class for an MQTT client executed in a separate thread.

    :param host: MQTT broker host.
    :param port: MQTT broker port.
    :param client_id: Client ID. If empty, the configured ``device_id`` is
        used instead.
    :param topics: Topics to subscribe to once the client connects.
    :param on_message: Optional callback installed as the client's
        ``on_message`` handler.
    :param username: Broker username. Credentials are only set if both
        ``username`` and ``password`` are provided.
    :param password: Broker password.
    :param tls_cafile: Path to the CA certificate file. TLS is only
        configured when this is set.
    :param tls_certfile: Path to the client certificate file.
    :param tls_keyfile: Path to the client private key file.
    :param tls_version: TLS version, either as a string (``tls``, ``tlsv1``,
        ``tlsv1.1``, ``tlsv1.2``) or as an ``ssl`` protocol constant.
    :param tls_ciphers: Allowed TLS ciphers, passed as-is to ``tls_set``.
    :param tls_insecure: If set, the verification of the server hostname in
        the server certificate is disabled.
    :param timeout: Default keepalive interval used by :meth:`connect`.
    :param args: Extra positional arguments for the paho ``Client``.
    :param kwargs: Extra keyword arguments for the paho ``Client``.
    """

    def __init__(
        self,
        *args,
        host: str,
        port: int,
        client_id: str,
        topics: Iterable[str] = (),
        on_message: Optional[MqttCallback] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        tls_cafile: Optional[str] = None,
        tls_certfile: Optional[str] = None,
        tls_keyfile: Optional[str] = None,
        tls_version: Optional[Union[str, IntEnum]] = None,
        tls_ciphers: Optional[str] = None,
        tls_insecure: bool = False,
        timeout: int = DEFAULT_TIMEOUT,
        **kwargs,
    ):
        self.client_id = client_id or str(Config.get('device_id'))
        mqtt.Client.__init__(self, *args, client_id=self.client_id, **kwargs)
        threading.Thread.__init__(self, name=f'MQTTClient:{self.client_id}')

        self.logger = logging.getLogger(self.__class__.__name__)
        self.host = host
        self.port = port
        self.tls_cafile = self._expandpath(tls_cafile)
        self.tls_certfile = self._expandpath(tls_certfile)
        self.tls_keyfile = self._expandpath(tls_keyfile)
        self.tls_version = self._get_tls_version(tls_version)
        # A cipher specification, not a filesystem path: it must NOT be
        # path-expanded, or it would be turned into an invalid cipher string.
        self.tls_ciphers = tls_ciphers
        self.tls_insecure = tls_insecure
        self.username = username
        self.password = password
        self.topics = set(topics or [])
        self.timeout = timeout
        self.on_connect = self.connect_hndl()
        self.on_disconnect = self.disconnect_hndl()

        if on_message:
            self.on_message = on_message  # type: ignore

        if username and password:
            self.username_pw_set(username, password)

        if tls_cafile:
            self.tls_set(
                ca_certs=self.tls_cafile,
                certfile=self.tls_certfile,
                keyfile=self.tls_keyfile,
                tls_version=self.tls_version,
                ciphers=self.tls_ciphers,
            )

            # Must be called after tls_set, which creates the SSL context.
            self.tls_insecure_set(self.tls_insecure)

        self._running = False
        self._stop_scheduled = False

    @staticmethod
    def _expandpath(path: Optional[str] = None) -> Optional[str]:
        """
        Utility method to expand a path string.

        :param path: Path to expand (``~`` is resolved to the user's home).
        :return: The absolute path, or ``None`` if ``path`` is empty.
        """
        return os.path.abspath(os.path.expanduser(path)) if path else None

    @staticmethod
    def _get_tls_version(version: Optional[Union[str, IntEnum]] = None):
        """
        A utility method that normalizes an SSL version string or enum to a
        standard ``_SSLMethod`` enum.

        :param version: A version string (``tls``, ``tlsv1``, ``tlsv1.1``,
            ``tlsv1.2``; case-insensitive) or an ``ssl`` protocol constant.
        :return: The matching ``ssl`` protocol constant, or ``None`` if no
            version was specified.
        :raises AssertionError: If the version is not recognized.
        """
        import ssl

        if not version:
            return None

        if isinstance(version, type(ssl.PROTOCOL_TLS)):
            return version

        if isinstance(version, str):
            version = version.lower()

        if version == 'tls':
            return ssl.PROTOCOL_TLS
        if version == 'tlsv1':
            return ssl.PROTOCOL_TLSv1
        if version == 'tlsv1.1':
            return ssl.PROTOCOL_TLSv1_1
        if version == 'tlsv1.2':
            return ssl.PROTOCOL_TLSv1_2

        raise AssertionError(f'Unrecognized TLS version: {version}')

    def connect(
        self,
        *args,
        host: Optional[str] = None,
        port: Optional[int] = None,
        keepalive: Optional[int] = None,
        **kwargs,
    ):
        """
        Overrides the default connect method.

        The call is a no-op if the client is already connected.

        :param host: Broker host (default: ``self.host``).
        :param port: Broker port (default: ``self.port``).
        :param keepalive: Keepalive interval (default: ``self.timeout``).
        :return: The result of the parent ``connect``, or ``None`` if the
            client was already connected.
        """
        if self.is_connected():
            return None

        host = host or self.host
        port = port or self.port
        # Log the broker actually being contacted, including overrides.
        self.logger.debug(
            'Connecting to MQTT broker %s:%d, client_id=%s...',
            host,
            port,
            self.client_id,
        )

        return super().connect(
            *args,
            host=host,
            port=port,
            keepalive=keepalive or self.timeout,
            **kwargs,
        )

    @property
    def configuration(self) -> Dict[str, Any]:
        """
        :return: The configuration of the client.
        """
        return {
            'host': self.host,
            'port': self.port,
            'topics': self.topics,
            'on_message': self.on_message,
            'username': self.username,
            'password': self.password,
            'client_id': self.client_id,
            'tls_cafile': self.tls_cafile,
            'tls_certfile': self.tls_certfile,
            'tls_keyfile': self.tls_keyfile,
            'tls_version': self.tls_version,
            'tls_ciphers': self.tls_ciphers,
            'tls_insecure': self.tls_insecure,
            'timeout': self.timeout,
        }

    def subscribe(self, *topics, **kwargs):
        """
        Client subscription handler.

        Subscribes to the given topics and adds them to the registered
        topics. If no topics are given, it (re-)subscribes to all the
        registered topics.

        :param topics: Topics to subscribe to.
        :param kwargs: Extra arguments forwarded to the parent ``subscribe``.
        """
        # Work on a snapshot, so that the loop never iterates over the live
        # ``self.topics`` set.
        topics = tuple(topics or self.topics)
        self.topics.update(topics)

        for topic in topics:
            super().subscribe(topic, **kwargs)

    def unsubscribe(self, *topics, **kwargs):
        """
        Client unsubscribe handler.

        Unsubscribes from the given topics and removes them from the
        registered topics. If no topics are given, it unsubscribes from all
        the registered topics. Topics that are not registered are skipped.

        :param topics: Topics to unsubscribe from.
        :param kwargs: Extra arguments forwarded to the parent
            ``unsubscribe``.
        """
        # Iterate over a snapshot: ``self.topics`` is modified inside the
        # loop, and iterating over it directly (the no-arguments case) would
        # raise ``RuntimeError: Set changed size during iteration``.
        for topic in tuple(topics or self.topics):
            if topic not in self.topics:
                self.logger.info('The topic %s is not subscribed', topic)
                continue

            super().unsubscribe(topic, **kwargs)
            self.topics.remove(topic)

    def connect_hndl(self):
        """
        When the client connects, subscribe to all the registered topics.

        :return: A handler to be installed as ``on_connect``.
        """

        def handler(*_, **__):
            self.logger.debug(
                'Connected to MQTT broker %s:%d, client_id=%s',
                self.host,
                self.port,
                self.client_id,
            )
            self.subscribe()

        return handler

    def disconnect_hndl(self):
        """
        Notifies the client disconnection.

        :return: A handler to be installed as ``on_disconnect``.
        """

        def handler(*_, **__):
            self.logger.debug(
                'Disconnected from MQTT broker %s:%d, client_id=%s',
                self.host,
                self.port,
                self.client_id,
            )

        return handler

    def run(self):
        """
        Connects to the MQTT server, subscribes to all the registered topics
        and listens for messages.

        Subscription happens in the ``on_connect`` handler built by
        :meth:`connect_hndl`. The thread then blocks in ``loop_forever``,
        which :meth:`stop` interrupts through ``disconnect``.
        """
        super().run()
        self.connect()
        self._running = True
        self.loop_forever()

    def stop(self):
        """
        The stop method schedules the stop and disconnects the client.

        It is a no-op if the thread is not alive.
        """
        if not self.is_alive():
            return

        try:
            self.loop_stop()
        except Exception as e:
            # Best effort: ``run`` drives the client through ``loop_forever``
            # rather than ``loop_start``, so there may be no loop to stop.
            self.logger.debug(
                'Could not stop client loop: %s: %s', type(e).__name__, e
            )

        self._stop_scheduled = True
        self.disconnect()
        self._running = False
# vim:sw=4:ts=4:et: