diff --git a/platypush/backend/websocket.py b/platypush/backend/websocket.py index 99b3dbd138..271a7d66fc 100644 --- a/platypush/backend/websocket.py +++ b/platypush/backend/websocket.py @@ -19,7 +19,11 @@ class WebsocketBackend(Backend): * **websockets** (``pip install websockets``) """ - def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cert=None, **kwargs): + # Websocket client message recv timeout in seconds + _websocket_client_timeout = 60 + + def __init__(self, port=8765, bind_address='0.0.0.0', ssl_cert=None, + client_timeout=_websocket_client_timeout, **kwargs): """ :param port: Listen port for the websocket server (default: 8765) :type port: int @@ -29,6 +33,9 @@ class WebsocketBackend(Backend): :param ssl_cert: Path to the PEM certificate file if you want to enable SSL (default: None) :type ssl_cert: str + + :param client_timeout: Timeout without any messages being received before closing a client connection. A zero timeout keeps the websocket open until an error occurs (default: 60 seconds) + :type ping_timeout: int """ super().__init__(**kwargs) @@ -36,6 +43,7 @@ class WebsocketBackend(Backend): self.port = port self.bind_address = bind_address self.ssl_context = None + self.client_timeout = client_timeout if ssl_cert: self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) @@ -62,29 +70,39 @@ class WebsocketBackend(Backend): super().run() async def serve_client(websocket, path): - self.logger.info('New websocket connection from {}'. + self.logger.debug('New websocket connection from {}'. format(websocket.remote_address[0])) try: - msg = await websocket.recv() - msg = Message.build(msg) - self.logger.info('Received message from {}: {}'. - format(websocket.remote_address[0], msg)) + while True: + if self.client_timeout: + msg = await asyncio.wait_for(websocket.recv(), + timeout=self.client_timeout) + else: + msg = await websocket.recv() - self.on_message(msg) + msg = Message.build(msg) + self.logger.info('Received message from {}: {}'. + format(websocket.remote_address[0], msg)) - if isinstance(msg, Request): - response = self.get_message_response(msg) + self.on_message(msg) + + if isinstance(msg, Request): + response = self.get_message_response(msg) + assert response is not None - if response: self.logger.info('Processing response on the websocket backend: {}'. - format(response)) + format(response)) await websocket.send(str(response)) + except Exception as e: if isinstance(e, websockets.exceptions.ConnectionClosed): - self.logger.info('Websocket client {} closed connection'. - format(websocket.remote_address[0])) + self.logger.debug('Websocket client {} closed connection'. + format(websocket.remote_address[0])) + elif isinstance(e, asyncio.TimeoutError): + self.logger.debug('Websocket connection to {} timed out'. + format(websocket.remote_address[0])) else: self.logger.exception(e) @@ -96,7 +114,9 @@ class WebsocketBackend(Backend): websocket_args['ssl'] = self.ssl_context loop = get_or_create_event_loop() - server = websockets.serve(serve_client, self.bind_address, self.port, **websocket_args) + server = websockets.serve(serve_client, self.bind_address, self.port, + **websocket_args) + loop.run_until_complete(server) loop.run_forever()