diff --git a/platypush/backend/websocket.py b/platypush/backend/websocket.py index 1b45d538d..376747b42 100644 --- a/platypush/backend/websocket.py +++ b/platypush/backend/websocket.py @@ -8,6 +8,7 @@ from platypush.backend import Backend from platypush.context import get_plugin, get_or_create_event_loop from platypush.message import Message from platypush.message.request import Request +from platypush.utils import get_ssl_server_context class WebsocketBackend(Backend): @@ -22,8 +23,9 @@ class WebsocketBackend(Backend): # Websocket client message recv timeout in seconds _websocket_client_timeout = 60 - def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cert=None, - ssl_key=None, client_timeout=_websocket_client_timeout, **kwargs): + def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cafile=None, + ssl_capath=None, ssl_cert=None, ssl_key=None, + client_timeout=_websocket_client_timeout, **kwargs): """ :param port: Listen port for the websocket server (default: 8765) :type port: int @@ -37,6 +39,12 @@ class WebsocketBackend(Backend): :param ssl_key: Path to the key file if you want to enable SSL (default: None) :type ssl_key: str + :param ssl_cafile: Path to the certificate authority file if required by the SSL configuration (default: None) + :type ssl_cafile: str + + :param ssl_capath: Path to the certificate authority directory if required by the SSL configuration (default: None) + :type ssl_capath: str + :param client_timeout: Timeout without any messages being received before closing a client connection. A zero timeout keeps the websocket open until an error occurs (default: 60 seconds) :type ping_timeout: int """ @@ -45,15 +53,13 @@ class WebsocketBackend(Backend): self.port = port self.bind_address = bind_address - self.ssl_context = None self.client_timeout = client_timeout - if ssl_cert: - self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - self.ssl_context.load_cert_chain( - certfile=os.path.abspath(os.path.expanduser(ssl_cert)), - keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None - ) + self.ssl_context = get_ssl_server_context(ssl_cert=ssl_cert, + ssl_key=ssl_key, + ssl_cafile=ssl_cafile, + ssl_capath=ssl_capath) \ + if ssl_cert else None def send_message(self, msg): websocket = get_plugin('websocket') diff --git a/platypush/plugins/websocket.py b/platypush/plugins/websocket.py index 44a1b4089..561e740ca 100644 --- a/platypush/plugins/websocket.py +++ b/platypush/plugins/websocket.py @@ -6,6 +6,7 @@ import websockets from platypush.context import get_or_create_event_loop from platypush.message import Message from platypush.plugins import Plugin, action +from platypush.utils import get_ssl_client_context class WebsocketPlugin(Plugin): @@ -21,7 +22,8 @@ class WebsocketPlugin(Plugin): super().__init__(*args, **kwargs) @action - def send(self, url, msg, ssl_cert=None, ssl_key=None, *args, **kwargs): + def send(self, url, msg, ssl_cert=None, ssl_key=None, ssl_cafile=None, + ssl_capath=None, *args, **kwargs): """ Sends a message to a websocket. @@ -35,17 +37,21 @@ class WebsocketPlugin(Plugin): :param ssl_key: Path to the SSL key to be used, if the SSL connection requires client authentication as well (default: None) :type ssl_key: str + + :param ssl_cafile: Path to the certificate authority file if required by the SSL configuration (default: None) + :type ssl_cafile: str + + :param ssl_capath: Path to the certificate authority directory if required by the SSL configuration (default: None) + :type ssl_capath: str """ async def send(): websocket_args = {} if ssl_cert: - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_cert_chain( - certfile=os.path.abspath(os.path.expanduser(ssl_cert)), - keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None - ) - + websocket_args['ssl'] = get_ssl_client_context(ssl_cert=ssl_cert, + ssl_key=ssl_key, + ssl_cafile=ssl_cafile, + ssl_capath=ssl_capath) async with websockets.connect(url, **websocket_args) as websocket: try: diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index 32cd68f0c..989f0bd0b 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -6,6 +6,7 @@ import inspect import logging import os import signal +import ssl logger = logging.getLogger(__name__) @@ -120,5 +121,48 @@ def get_redis_queue_name_by_message(msg): return 'platypush/responses/{}'.format(msg.id) if msg.id else None +def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None, + ssl_cafile=None, ssl_capath=None): + if not context_type: + ssl_context = ssl.create_default_context(cafile=ssl_cafile, + capath=ssl_capath) + else: + ssl_context = ssl.SSLContext(context_type) + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + if ssl_cafile or ssl_capath: + ssl_context.load_verify_locations( + cafile=ssl_cafile, capath=ssl_capath) + + ssl_context.load_cert_chain( + certfile=os.path.abspath(os.path.expanduser(ssl_cert)), + keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None + ) + + return ssl_context + + +def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, + ssl_capath=None): + return _get_ssl_context(context_type=None, + ssl_cert=ssl_cert, ssl_key=ssl_key, + ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) + + +def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, + ssl_capath=None): + return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER, + ssl_cert=ssl_cert, ssl_key=ssl_key, + ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) + + +def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None, + ssl_capath=None): + return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT, + ssl_cert=ssl_cert, ssl_key=ssl_key, + ssl_cafile=ssl_cafile, ssl_capath=ssl_capath) + + # vim:sw=4:ts=4:et: