Support for SSL flag on MQTT plugins without having to specify other tls_* options

This commit is contained in:
Fabio Manganiello 2022-01-14 21:39:16 +01:00
parent a6b552504e
commit f3be4a50d8
Signed by: blacklight
GPG key ID: D90FBA7F76362774
3 changed files with 26 additions and 11 deletions

View file

@ -21,7 +21,7 @@ class MqttPlugin(Plugin):
"""
def __init__(self, host=None, port=1883, tls_cafile=None,
def __init__(self, host=None, port=1883, ssl=False, tls_cafile=None,
tls_certfile=None, tls_keyfile=None,
tls_version=None, tls_ciphers=None, tls_insecure=False,
username=None, password=None, client_id=None, timeout=None, **kwargs):
@ -32,6 +32,9 @@ class MqttPlugin(Plugin):
:param port: If a default host is set, specify the listen port (default: 1883)
:type port: int
:param ssl: Set to true if the connection uses TLS/SSL (default: False).
This flag is automatically set if any other tls_* parameters are passed.
:param tls_cafile: If a default host is set and requires TLS/SSL, specify the certificate authority file (default: None)
:type tls_cafile: str
@ -71,6 +74,7 @@ class MqttPlugin(Plugin):
self.port = port
self.username = username
self.password = password
self.timeout = timeout
self.client_id = client_id or Config.get('device_id')
self.tls_cafile = self._expandpath(tls_cafile) if tls_cafile else None
self.tls_certfile = self._expandpath(tls_certfile) if tls_certfile else None
@ -78,7 +82,10 @@ class MqttPlugin(Plugin):
self.tls_version = self.get_tls_version(tls_version)
self.tls_insecure = tls_insecure
self.tls_ciphers = tls_ciphers
self.timeout = timeout
self.ssl = bool(
ssl or tls_cafile or tls_certfile or
tls_keyfile or tls_insecure or tls_ciphers
)
@staticmethod
def get_tls_version(version: Optional[str] = None):
@ -101,13 +108,14 @@ class MqttPlugin(Plugin):
if version == 'tlsv1.2':
return ssl.PROTOCOL_TLSv1_2
assert 'Unrecognized TLS version: {}'.format(version)
assert f'Unrecognized TLS version: {version}'
def _mqtt_args(self, **kwargs):
return {
'host': kwargs.get('host', self.host),
'port': kwargs.get('port', self.port),
'timeout': kwargs.get('timeout', self.timeout),
'ssl': kwargs.get('ssl', self.ssl),
'tls_certfile': kwargs.get('tls_certfile', self.tls_certfile),
'tls_keyfile': kwargs.get('tls_keyfile', self.tls_keyfile),
'tls_version': kwargs.get('tls_version', self.tls_version),
@ -123,27 +131,31 @@ class MqttPlugin(Plugin):
def _get_client(self, tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
tls_keyfile: Optional[str] = None, tls_version: Optional[str] = None,
tls_ciphers: Optional[str] = None, tls_insecure: Optional[bool] = None,
username: Optional[str] = None, password: Optional[str] = None):
username: Optional[str] = None, password: Optional[str] = None,
ssl: Optional[bool] = None):
from paho.mqtt.client import Client
tls_cafile = self._expandpath(tls_cafile or self.tls_cafile)
tls_certfile = self._expandpath(tls_certfile or self.tls_certfile)
tls_keyfile = self._expandpath(tls_keyfile or self.tls_keyfile)
tls_ciphers = tls_ciphers or self.tls_ciphers
tls_version = tls_version or self.tls_version
ssl = ssl if ssl is not None else self.ssl
username = username or self.username
password = password or self.password
tls_version = tls_version or self.tls_version
if tls_version:
tls_version = self.get_tls_version(tls_version)
if tls_insecure is None:
tls_insecure = self.tls_insecure
if ssl or tls_cafile or tls_certfile or tls_keyfile or tls_ciphers or tls_version:
ssl = True
client = Client()
if username and password:
client.username_pw_set(username, password)
if tls_cafile:
if ssl:
client.tls_set(ca_certs=tls_cafile, certfile=tls_certfile, keyfile=tls_keyfile,
tls_version=tls_version, ciphers=tls_ciphers)
@ -153,7 +165,7 @@ class MqttPlugin(Plugin):
@action
def publish(self, topic: str, msg: Any, host: Optional[str] = None, port: Optional[int] = None,
reply_topic: Optional[str] = None, timeout: int = 60,
reply_topic: Optional[str] = None, timeout: int = 60, ssl: Optional[bool] = None,
tls_cafile: Optional[str] = None, tls_certfile: Optional[str] = None,
tls_keyfile: Optional[str] = None, tls_version: Optional[str] = None,
tls_ciphers: Optional[str] = None, tls_insecure: Optional[bool] = None,
@ -170,6 +182,7 @@ class MqttPlugin(Plugin):
wait for a response (default: 60 seconds).
:param tls_cafile: If TLS/SSL is enabled on the MQTT server and the certificate requires a certificate authority
to authenticate it, `ssl_cafile` will point to the provided ca.crt file (default: None).
:param ssl: SSL flag override.
:param tls_certfile: If TLS/SSL is enabled on the MQTT server and a client certificate it required, specify it
here (default: None).
:param tls_keyfile: If TLS/SSL is enabled on the MQTT server and a client certificate key it required, specify
@ -201,7 +214,7 @@ class MqttPlugin(Plugin):
client = self._get_client(tls_cafile=tls_cafile, tls_certfile=tls_certfile, tls_keyfile=tls_keyfile,
tls_version=tls_version, tls_ciphers=tls_ciphers, tls_insecure=tls_insecure,
username=username, password=password)
username=username, password=password, ssl=ssl)
client.connect(host, port, keepalive=timeout)
response_received = threading.Event()

View file

@ -104,7 +104,7 @@ class ZigbeeMqttPlugin(MqttPlugin, SwitchPlugin): # lgtm [py/missing-call-to-i
"""
def __init__(self, host: str = 'localhost', port: int = 1883, base_topic: str = 'zigbee2mqtt', timeout: int = 10,
tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None,
ssl: bool = False, tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None,
tls_version: Optional[str] = None, tls_ciphers: Optional[str] = None,
username: Optional[str] = None, password: Optional[str] = None, **kwargs):
"""
@ -113,6 +113,7 @@ class ZigbeeMqttPlugin(MqttPlugin, SwitchPlugin): # lgtm [py/missing-call-to-i
:param base_topic: Topic prefix, as specified in ``/opt/zigbee2mqtt/data/configuration.yaml``
(default: '``base_topic``').
:param timeout: If the command expects from a response, then this timeout value will be used
:param ssl: Set to true if SSL is enabled on the server.
(default: 60 seconds).
:param tls_cafile: If the connection requires TLS/SSL, specify the certificate authority file
(default: None)
@ -126,7 +127,7 @@ class ZigbeeMqttPlugin(MqttPlugin, SwitchPlugin): # lgtm [py/missing-call-to-i
"""
super().__init__(host=host, port=port, tls_certfile=tls_certfile, tls_keyfile=tls_keyfile,
tls_version=tls_version, tls_ciphers=tls_ciphers, username=username,
password=password, **kwargs)
password=password, ssl=ssl, **kwargs)
self.base_topic = base_topic
self.timeout = timeout

View file

@ -48,7 +48,7 @@ class ZwaveMqttPlugin(MqttPlugin, ZwaveBasePlugin):
def __init__(self, name: str, host: str = 'localhost', port: int = 1883, topic_prefix: str = 'zwave',
timeout: int = 10, tls_certfile: Optional[str] = None, tls_keyfile: Optional[str] = None,
tls_version: Optional[str] = None, tls_ciphers: Optional[str] = None, username: Optional[str] = None,
password: Optional[str] = None, **kwargs):
password: Optional[str] = None, ssl: bool = False, **kwargs):
"""
:param name: Gateway name, as configured from the zwavejs2mqtt web panel from Mqtt -> Name.
:param host: MQTT broker host, as configured from the zwavejs2mqtt web panel from Mqtt -> Host
@ -59,6 +59,7 @@ class ZwaveMqttPlugin(MqttPlugin, ZwaveBasePlugin):
(default: ``zwave``).
:param timeout: If the command expects from a response, then this timeout value will be used
(default: 60 seconds).
:param ssl: Set to True if SSL is enabled on the server.
:param tls_cafile: If the connection requires TLS/SSL, specify the certificate authority file
(default: None)
:param tls_certfile: If the connection requires TLS/SSL, specify the certificate file (default: None)