Refactored SSL context logic as utils methods
This commit is contained in:
parent
bf52304758
commit
38a8cac9c6
3 changed files with 72 additions and 16 deletions
|
@ -8,6 +8,7 @@ from platypush.backend import Backend
|
||||||
from platypush.context import get_plugin, get_or_create_event_loop
|
from platypush.context import get_plugin, get_or_create_event_loop
|
||||||
from platypush.message import Message
|
from platypush.message import Message
|
||||||
from platypush.message.request import Request
|
from platypush.message.request import Request
|
||||||
|
from platypush.utils import get_ssl_server_context
|
||||||
|
|
||||||
|
|
||||||
class WebsocketBackend(Backend):
|
class WebsocketBackend(Backend):
|
||||||
|
@ -22,8 +23,9 @@ class WebsocketBackend(Backend):
|
||||||
# Websocket client message recv timeout in seconds
|
# Websocket client message recv timeout in seconds
|
||||||
_websocket_client_timeout = 60
|
_websocket_client_timeout = 60
|
||||||
|
|
||||||
def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cert=None,
|
def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cafile=None,
|
||||||
ssl_key=None, client_timeout=_websocket_client_timeout, **kwargs):
|
ssl_capath=None, ssl_cert=None, ssl_key=None,
|
||||||
|
client_timeout=_websocket_client_timeout, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param port: Listen port for the websocket server (default: 8765)
|
:param port: Listen port for the websocket server (default: 8765)
|
||||||
:type port: int
|
:type port: int
|
||||||
|
@ -37,6 +39,12 @@ class WebsocketBackend(Backend):
|
||||||
:param ssl_key: Path to the key file if you want to enable SSL (default: None)
|
:param ssl_key: Path to the key file if you want to enable SSL (default: None)
|
||||||
:type ssl_key: str
|
:type ssl_key: str
|
||||||
|
|
||||||
|
:param ssl_cafile: Path to the certificate authority file if required by the SSL configuration (default: None)
|
||||||
|
:type ssl_cafile: str
|
||||||
|
|
||||||
|
:param ssl_capath: Path to the certificate authority directory if required by the SSL configuration (default: None)
|
||||||
|
:type ssl_capath: str
|
||||||
|
|
||||||
:param client_timeout: Timeout without any messages being received before closing a client connection. A zero timeout keeps the websocket open until an error occurs (default: 60 seconds)
|
:param client_timeout: Timeout without any messages being received before closing a client connection. A zero timeout keeps the websocket open until an error occurs (default: 60 seconds)
|
||||||
:type ping_timeout: int
|
:type ping_timeout: int
|
||||||
"""
|
"""
|
||||||
|
@ -45,15 +53,13 @@ class WebsocketBackend(Backend):
|
||||||
|
|
||||||
self.port = port
|
self.port = port
|
||||||
self.bind_address = bind_address
|
self.bind_address = bind_address
|
||||||
self.ssl_context = None
|
|
||||||
self.client_timeout = client_timeout
|
self.client_timeout = client_timeout
|
||||||
|
|
||||||
if ssl_cert:
|
self.ssl_context = get_ssl_server_context(ssl_cert=ssl_cert,
|
||||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ssl_key=ssl_key,
|
||||||
self.ssl_context.load_cert_chain(
|
ssl_cafile=ssl_cafile,
|
||||||
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
ssl_capath=ssl_capath) \
|
||||||
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None
|
if ssl_cert else None
|
||||||
)
|
|
||||||
|
|
||||||
def send_message(self, msg):
|
def send_message(self, msg):
|
||||||
websocket = get_plugin('websocket')
|
websocket = get_plugin('websocket')
|
||||||
|
|
|
@ -6,6 +6,7 @@ import websockets
|
||||||
from platypush.context import get_or_create_event_loop
|
from platypush.context import get_or_create_event_loop
|
||||||
from platypush.message import Message
|
from platypush.message import Message
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import Plugin, action
|
||||||
|
from platypush.utils import get_ssl_client_context
|
||||||
|
|
||||||
|
|
||||||
class WebsocketPlugin(Plugin):
|
class WebsocketPlugin(Plugin):
|
||||||
|
@ -21,7 +22,8 @@ class WebsocketPlugin(Plugin):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def send(self, url, msg, ssl_cert=None, ssl_key=None, *args, **kwargs):
|
def send(self, url, msg, ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
||||||
|
ssl_capath=None, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Sends a message to a websocket.
|
Sends a message to a websocket.
|
||||||
|
|
||||||
|
@ -35,17 +37,21 @@ class WebsocketPlugin(Plugin):
|
||||||
|
|
||||||
:param ssl_key: Path to the SSL key to be used, if the SSL connection requires client authentication as well (default: None)
|
:param ssl_key: Path to the SSL key to be used, if the SSL connection requires client authentication as well (default: None)
|
||||||
:type ssl_key: str
|
:type ssl_key: str
|
||||||
|
|
||||||
|
:param ssl_cafile: Path to the certificate authority file if required by the SSL configuration (default: None)
|
||||||
|
:type ssl_cafile: str
|
||||||
|
|
||||||
|
:param ssl_capath: Path to the certificate authority directory if required by the SSL configuration (default: None)
|
||||||
|
:type ssl_capath: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def send():
|
async def send():
|
||||||
websocket_args = {}
|
websocket_args = {}
|
||||||
if ssl_cert:
|
if ssl_cert:
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
websocket_args['ssl'] = get_ssl_client_context(ssl_cert=ssl_cert,
|
||||||
ssl_context.load_cert_chain(
|
ssl_key=ssl_key,
|
||||||
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
ssl_cafile=ssl_cafile,
|
||||||
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None
|
ssl_capath=ssl_capath)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async with websockets.connect(url, **websocket_args) as websocket:
|
async with websockets.connect(url, **websocket_args) as websocket:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -6,6 +6,7 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
import ssl
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -120,5 +121,48 @@ def get_redis_queue_name_by_message(msg):
|
||||||
return 'platypush/responses/{}'.format(msg.id) if msg.id else None
|
return 'platypush/responses/{}'.format(msg.id) if msg.id else None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ssl_context(context_type=None, ssl_cert=None, ssl_key=None,
|
||||||
|
ssl_cafile=None, ssl_capath=None):
|
||||||
|
if not context_type:
|
||||||
|
ssl_context = ssl.create_default_context(cafile=ssl_cafile,
|
||||||
|
capath=ssl_capath)
|
||||||
|
else:
|
||||||
|
ssl_context = ssl.SSLContext(context_type)
|
||||||
|
|
||||||
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
|
||||||
|
if ssl_cafile or ssl_capath:
|
||||||
|
ssl_context.load_verify_locations(
|
||||||
|
cafile=ssl_cafile, capath=ssl_capath)
|
||||||
|
|
||||||
|
ssl_context.load_cert_chain(
|
||||||
|
certfile=os.path.abspath(os.path.expanduser(ssl_cert)),
|
||||||
|
keyfile=os.path.abspath(os.path.expanduser(ssl_key)) if ssl_key else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return ssl_context
|
||||||
|
|
||||||
|
|
||||||
|
def get_ssl_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
||||||
|
ssl_capath=None):
|
||||||
|
return _get_ssl_context(context_type=None,
|
||||||
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ssl_server_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
||||||
|
ssl_capath=None):
|
||||||
|
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_SERVER,
|
||||||
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ssl_client_context(ssl_cert=None, ssl_key=None, ssl_cafile=None,
|
||||||
|
ssl_capath=None):
|
||||||
|
return _get_ssl_context(context_type=ssl.PROTOCOL_TLS_CLIENT,
|
||||||
|
ssl_cert=ssl_cert, ssl_key=ssl_key,
|
||||||
|
ssl_cafile=ssl_cafile, ssl_capath=ssl_capath)
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue