from enum import IntEnum
import logging
import os
import threading
from typing import Any, Callable, Dict, Iterable, Optional, Union

import paho.mqtt.client as mqtt

from platypush.config import Config

MqttCallback = Callable[["MqttClient", Any, mqtt.MQTTMessage], Any]
DEFAULT_TIMEOUT: int = 30


class MqttClient(mqtt.Client, threading.Thread):
    """
    Wrapper class for an MQTT client executed in a separate thread.
    """

    def __init__(
        self,
        *args,
        host: str,
        port: int,
        client_id: str,
        topics: Iterable[str] = (),
        on_message: Optional[MqttCallback] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        tls_cafile: Optional[str] = None,
        tls_certfile: Optional[str] = None,
        tls_keyfile: Optional[str] = None,
        tls_version: Optional[Union[str, IntEnum]] = None,
        tls_ciphers: Optional[str] = None,
        tls_insecure: bool = False,
        timeout: int = DEFAULT_TIMEOUT,
        **kwargs,
    ):
        """
        :param args: Extra positional arguments forwarded to
            ``paho.mqtt.client.Client.__init__`` after the callback API
            version.
        :param host: MQTT broker host.
        :param port: MQTT broker port.
        :param client_id: Client ID. If empty, ``Config.get('device_id')``
            is used instead.
        :param topics: Topics to subscribe to once connected. A single
            string is treated as one topic.
        :param on_message: Optional callback assigned as the client's
            ``on_message`` handler.
        :param username: Broker username. Credentials are only set through
            ``username_pw_set`` when both username and password are given.
        :param password: Broker password.
        :param tls_cafile: Path to the CA certificate file. TLS is only
            configured when this is set.
        :param tls_certfile: Path to the client certificate file.
        :param tls_keyfile: Path to the client private key file.
        :param tls_version: TLS version, either as an ``ssl`` protocol enum
            or as one of the strings ``tls``, ``tlsv1``, ``tlsv1.1``,
            ``tlsv1.2`` (case-insensitive).
        :param tls_ciphers: Cipher string passed as-is to ``tls_set``.
        :param tls_insecure: Value passed to ``tls_insecure_set`` (only
            applied when ``tls_cafile`` is set).
        :param timeout: Default ``keepalive`` value used by :meth:`connect`.
        :param kwargs: Extra keyword arguments forwarded to
            ``paho.mqtt.client.Client.__init__``.
        """
        self.client_id = client_id or str(Config.get('device_id'))
        mqtt.Client.__init__(
            self,
            mqtt.CallbackAPIVersion.VERSION1,
            *args,
            client_id=self.client_id,
            **kwargs,
        )
        threading.Thread.__init__(self, name=f'MQTTClient:{self.client_id}')

        self.logger = logging.getLogger(self.__class__.__name__)
        self.host = host
        self.port = port
        self.tls_cafile = self._expandpath(tls_cafile)
        self.tls_certfile = self._expandpath(tls_certfile)
        self.tls_keyfile = self._expandpath(tls_keyfile)
        self.tls_version = self._get_tls_version(tls_version)
        # The cipher list is an OpenSSL cipher string, not a filesystem
        # path: expanding it as a path would corrupt it.
        self.tls_ciphers = tls_ciphers
        self.tls_insecure = tls_insecure
        self.username = username
        self.password = password
        # A bare string is an Iterable[str] too, but set('a/b') would split
        # it into single-character topics.
        if isinstance(topics, str):
            topics = [topics]
        self.topics = set(topics or [])
        self.timeout = timeout
        self.on_connect = self.connect_hndl()
        self.on_disconnect = self.disconnect_hndl()

        if on_message:
            self.on_message = on_message  # type: ignore

        if username and password:
            self.username_pw_set(username, password)

        if tls_cafile:
            self.tls_set(
                ca_certs=self.tls_cafile,
                certfile=self.tls_certfile,
                keyfile=self.tls_keyfile,
                tls_version=self.tls_version,
                ciphers=self.tls_ciphers,
            )

            # Only meaningful (and only accepted by paho) once the TLS
            # context has been configured by tls_set() above.
            self.tls_insecure_set(self.tls_insecure)

        self._running = False
        self._stop_scheduled = False

    @staticmethod
    def _expandpath(path: Optional[str] = None) -> Optional[str]:
        """
        Utility method to expand a path string: ``~`` is expanded and the
        result is made absolute. Empty/``None`` input returns ``None``.
        """
        return os.path.abspath(os.path.expanduser(path)) if path else None

    @staticmethod
    def _get_tls_version(version: Optional[Union[str, IntEnum]] = None):
        """
        A utility method that normalizes an SSL version string or enum to a
        standard ``_SSLMethod`` enum.

        :param version: ``None``/empty, an ``ssl`` protocol enum (returned
            unchanged), or a case-insensitive string among ``tls``,
            ``tlsv1``, ``tlsv1.1``, ``tlsv1.2``.
        :return: The matching ``ssl.PROTOCOL_*`` value, or ``None`` if no
            version was given.
        :raises AssertionError: If the version is not recognized.
        """
        import ssl

        if not version:
            return None

        if isinstance(version, type(ssl.PROTOCOL_TLS)):
            return version

        if isinstance(version, str):
            version = version.lower()

        # The protocol attributes are resolved lazily, only for the
        # matching branch.
        if version == 'tls':
            return ssl.PROTOCOL_TLS
        if version == 'tlsv1':
            return ssl.PROTOCOL_TLSv1
        if version == 'tlsv1.1':
            return ssl.PROTOCOL_TLSv1_1
        if version == 'tlsv1.2':
            return ssl.PROTOCOL_TLSv1_2

        raise AssertionError(f'Unrecognized TLS version: {version}')

    def connect(
        self,
        *args,
        host: Optional[str] = None,
        port: Optional[int] = None,
        keepalive: Optional[int] = None,
        **kwargs,
    ):
        """
        Overrides the default connect method.

        The connection attempt is skipped if the client is already
        connected.

        :param args: Extra positional arguments forwarded to the parent
            ``connect`` after ``host``, ``port`` and ``keepalive``.
        :param host: Broker host override (default: the configured host).
        :param port: Broker port override (default: the configured port).
        :param keepalive: Keepalive override (default: ``self.timeout``).
        :param kwargs: Extra keyword arguments forwarded to the parent
            ``connect``.
        :return: The parent ``connect`` result, or ``None`` if the client
            was already connected.
        """
        if self.is_connected():
            return None

        host = host or self.host
        port = port or self.port
        keepalive = keepalive or self.timeout

        self.logger.debug(
            'Connecting to MQTT broker %s:%d, client_id=%s...',
            host,
            port,
            self.client_id,
        )

        # Pass host/port/keepalive positionally so that any extra *args
        # follow them instead of colliding with the ``host`` parameter.
        return super().connect(host, port, keepalive, *args, **kwargs)

    @property
    def configuration(self) -> Dict[str, Any]:
        """
        :return: The configuration of the client.
        """
        return {
            'host': self.host,
            'port': self.port,
            'topics': self.topics,
            'on_message': self.on_message,
            'username': self.username,
            'password': self.password,
            'client_id': self.client_id,
            'tls_cafile': self.tls_cafile,
            'tls_certfile': self.tls_certfile,
            'tls_keyfile': self.tls_keyfile,
            'tls_version': self.tls_version,
            'tls_ciphers': self.tls_ciphers,
            'tls_insecure': self.tls_insecure,
            'timeout': self.timeout,
        }

    def subscribe(self, *topics, **kwargs):
        """
        Client subscription handler.

        The given topics are added to the registered topics and subscribed
        to. With no topics, every registered topic is (re-)subscribed.

        :param topics: Topics to subscribe to.
        :param kwargs: Extra keyword arguments forwarded to the parent
            ``subscribe`` for each topic.
        """
        if not topics:
            # Snapshot, so the loop below never iterates the live set.
            topics = tuple(self.topics)

        self.topics.update(topics)
        for topic in topics:
            super().subscribe(topic, **kwargs)

    def unsubscribe(self, *topics, **kwargs):
        """
        Client unsubscribe handler.

        Topics that are not registered are logged and skipped. With no
        topics, every registered topic is unsubscribed.

        :param topics: Topics to unsubscribe from.
        :param kwargs: Extra keyword arguments forwarded to the parent
            ``unsubscribe`` for each topic.
        """
        if not topics:
            # Snapshot: self.topics is mutated inside the loop, and
            # iterating a set while removing from it raises RuntimeError.
            topics = tuple(self.topics)

        for topic in topics:
            if topic not in self.topics:
                self.logger.info('The topic %s is not subscribed', topic)
                continue

            super().unsubscribe(topic, **kwargs)
            self.topics.remove(topic)

    def connect_hndl(self):
        """
        When the client connects, subscribe to all the registered topics.

        :return: A callback suitable for ``on_connect``.
        """

        def handler(*_, **__):
            self.logger.debug(
                'Connected to MQTT broker %s:%d, client_id=%s',
                self.host,
                self.port,
                self.client_id,
            )
            self.subscribe()

        return handler

    def disconnect_hndl(self):
        """
        Notifies the client disconnection.

        :return: A callback suitable for ``on_disconnect`` that logs the
            disconnection.
        """

        def handler(*_, **__):
            self.logger.debug(
                'Disconnected from MQTT broker %s:%d, client_id=%s',
                self.host,
                self.port,
                self.client_id,
            )

        return handler

    def run(self):
        """
        Connects to the MQTT server, subscribes to all the registered topics
        and listens for messages.

        Subscription happens in the ``on_connect`` handler. This method
        blocks in ``loop_forever`` until the client is stopped.
        """
        super().run()
        self.connect()
        self._running = True
        self.loop_forever()

    def stop(self):
        """
        The stop method schedules the stop and disconnects the client.

        This is a no-op if the thread is not alive. A failure to stop the
        network loop is logged and does not prevent the disconnection.
        """
        if not self.is_alive():
            return

        try:
            self.loop_stop()
        except Exception as e:
            # Best effort: still disconnect even if the loop can't be stopped.
            self.logger.debug(
                'Could not stop client loop: %s: %s', type(e).__name__, e
            )

        self._stop_scheduled = True
        self.disconnect()
        self._running = False


# vim:sw=4:ts=4:et: