From ca030c9b25519285f9335f363a917951bc0a14f6 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Thu, 6 Feb 2020 01:04:36 +0100 Subject: [PATCH] Websocket notifications delivery should be thread-safe. If multiple threads process events and notify the websocket clients at the same time then we may end up with inconsistent messages delivered on the websocket (and websockets is not designed to handle such cases). Protecting the send call with a per-socket lock makes sure that we only write one message at the time for a certain client. --- platypush/backend/http/__init__.py | 45 +++++++++++++++++++--- platypush/backend/http/static/js/events.js | 16 ++++---- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/platypush/backend/http/__init__.py b/platypush/backend/http/__init__.py index 4948b320..a92cfcc9 100644 --- a/platypush/backend/http/__init__.py +++ b/platypush/backend/http/__init__.py @@ -189,6 +189,10 @@ class HttpBackend(Backend): self.local_base_url = '{proto}://localhost:{port}'.\ format(proto=('https' if ssl_cert else 'http'), port=self.port) + self._websocket_lock_timeout = 10 + self._websocket_lock = threading.RLock() + self._websocket_locks = {} + def send_message(self, msg, **kwargs): self.logger.warning('Use cURL or any HTTP client to query the HTTP backend') @@ -204,25 +208,54 @@ class HttpBackend(Backend): self.server_proc.terminate() self.server_proc.join() + def _acquire_websocket_lock(self, ws): + try: + acquire_ok = self._websocket_lock.acquire(timeout=self._websocket_lock_timeout) + if not acquire_ok: + raise TimeoutError('Websocket lock acquire timeout') + + addr = ws.remote_address + if addr not in self._websocket_locks: + self._websocket_locks[addr] = threading.RLock() + finally: + self._websocket_lock.release() + + acquire_ok = self._websocket_locks[addr].acquire(timeout=self._websocket_lock_timeout) + if not acquire_ok: + raise TimeoutError('Websocket on address {} not ready to receive data'.format(addr)) + + def _release_websocket_lock(self, ws): + addr = ws.local_address + if addr in self._websocket_locks: + try: + self._websocket_locks[addr].release() + except RuntimeError: + pass + def notify_web_clients(self, event): """ Notify all the connected web clients (over websocket) of a new event """ import websockets async def send_event(ws): try: + self._acquire_websocket_lock(ws) await ws.send(str(event)) except Exception as e: self.logger.warning('Error on websocket send_event: {}'.format(e)) + finally: + self._release_websocket_lock(ws) loop = get_or_create_event_loop() wss = self.active_websockets.copy() - for websocket in wss: + for _ws in wss: try: - loop.run_until_complete(send_event(websocket)) + loop.run_until_complete(send_event(_ws)) except websockets.exceptions.ConnectionClosed: - self.logger.info('Client connection lost') - self.active_websockets.remove(websocket) + self.logger.warning('Websocket client {} connection lost'.format(_ws.remote_address)) + self.active_websockets.remove(_ws) + if _ws.remote_address in self._websocket_locks: + del self._websocket_locks[_ws.remote_address] def websocket(self): """ Websocket main server """ @@ -230,7 +263,7 @@ class HttpBackend(Backend): set_thread_name('WebsocketServer') async def register_websocket(websocket, path): - address = websocket.remote_address[0] if websocket.remote_address \ + address = websocket.remote_address if websocket.remote_address \ else '' self.logger.info('New websocket connection from {} on path {}'.format(address, path)) @@ -241,6 +274,8 @@ class HttpBackend(Backend): except websockets.exceptions.ConnectionClosed: self.logger.info('Websocket client {} closed connection'.format(address)) self.active_websockets.remove(websocket) + if address in self._websocket_locks: + del self._websocket_locks[address] websocket_args = {} if self.ssl_context: diff --git a/platypush/backend/http/static/js/events.js b/platypush/backend/http/static/js/events.js index 266ecad4..13732aa2 100644 --- a/platypush/backend/http/static/js/events.js +++ b/platypush/backend/http/static/js/events.js @@ -1,4 +1,4 @@ -var websocket = { +const websocket = { ws: undefined, instance: undefined, pending: false, @@ -10,7 +10,7 @@ var websocket = { function initEvents() { try { - url_prefix = window.config.has_ssl ? 'wss://' : 'ws://'; + const url_prefix = window.config.has_ssl ? 'wss://' : 'ws://'; websocket.ws = new WebSocket(url_prefix + window.location.hostname + ':' + window.config.websocket_port); } catch (err) { console.error("Websocket initialization error"); @@ -20,7 +20,7 @@ function initEvents() { websocket.pending = true; - var onWebsocketTimeout = function(self) { + const onWebsocketTimeout = function(self) { return function() { console.log('Websocket reconnection timed out, retrying'); websocket.pending = false; @@ -33,8 +33,7 @@ function initEvents() { onWebsocketTimeout(websocket.ws), websocket.reconnectMsecs); websocket.ws.onmessage = function(event) { - console.debug(event); - handlers = []; + const handlers = []; event = event.data; if (typeof event === 'string') { @@ -46,6 +45,7 @@ function initEvents() { } } + console.debug(event); if (event.type !== 'event') { // Discard non-event messages return; @@ -59,7 +59,7 @@ function initEvents() { handlers.push(...websocket.handlers[event.args.type]); } - for (var handler of handlers) { + for (const handler of handlers) { handler(event.args); } }; @@ -100,12 +100,12 @@ function initEvents() { initEvents(); } }; -}; +} function registerEventHandler(handler, ...events) { if (events.length) { // Event type filter specified - for (var event of events) { + for (const event of events) { if (!(event in websocket.handlers)) { websocket.handlers[event] = []; }