Websocket notifications delivery should be thread-safe.

If multiple threads process events and notify the websocket
clients at the same time then we may end up with inconsistent
messages delivered on the websocket (and websockets is not
designed to handle such cases). Protecting the send call with
a per-socket lock makes sure that we only write one message
at the time for a certain client.
This commit is contained in:
Fabio Manganiello 2020-02-06 01:04:36 +01:00
parent a6526a2a2d
commit ca030c9b25
2 changed files with 48 additions and 13 deletions

View File

@ -189,6 +189,10 @@ class HttpBackend(Backend):
self.local_base_url = '{proto}://localhost:{port}'.\ self.local_base_url = '{proto}://localhost:{port}'.\
format(proto=('https' if ssl_cert else 'http'), port=self.port) format(proto=('https' if ssl_cert else 'http'), port=self.port)
self._websocket_lock_timeout = 10
self._websocket_lock = threading.RLock()
self._websocket_locks = {}
def send_message(self, msg, **kwargs): def send_message(self, msg, **kwargs):
self.logger.warning('Use cURL or any HTTP client to query the HTTP backend') self.logger.warning('Use cURL or any HTTP client to query the HTTP backend')
@ -204,25 +208,54 @@ class HttpBackend(Backend):
self.server_proc.terminate() self.server_proc.terminate()
self.server_proc.join() self.server_proc.join()
def _acquire_websocket_lock(self, ws):
try:
acquire_ok = self._websocket_lock.acquire(timeout=self._websocket_lock_timeout)
if not acquire_ok:
raise TimeoutError('Websocket lock acquire timeout')
addr = ws.remote_address
if addr not in self._websocket_locks:
self._websocket_locks[addr] = threading.RLock()
finally:
self._websocket_lock.release()
acquire_ok = self._websocket_locks[addr].acquire(timeout=self._websocket_lock_timeout)
if not acquire_ok:
raise TimeoutError('Websocket on address {} not ready to receive data'.format(addr))
def _release_websocket_lock(self, ws):
addr = ws.local_address
if addr in self._websocket_locks:
try:
self._websocket_locks[addr].release()
except RuntimeError:
pass
def notify_web_clients(self, event): def notify_web_clients(self, event):
""" Notify all the connected web clients (over websocket) of a new event """ """ Notify all the connected web clients (over websocket) of a new event """
import websockets import websockets
async def send_event(ws): async def send_event(ws):
try: try:
self._acquire_websocket_lock(ws)
await ws.send(str(event)) await ws.send(str(event))
except Exception as e: except Exception as e:
self.logger.warning('Error on websocket send_event: {}'.format(e)) self.logger.warning('Error on websocket send_event: {}'.format(e))
finally:
self._release_websocket_lock(ws)
loop = get_or_create_event_loop() loop = get_or_create_event_loop()
wss = self.active_websockets.copy() wss = self.active_websockets.copy()
for websocket in wss: for _ws in wss:
try: try:
loop.run_until_complete(send_event(websocket)) loop.run_until_complete(send_event(_ws))
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
self.logger.info('Client connection lost') self.logger.warning('Websocket client {} connection lost'.format(_ws.remote_address))
self.active_websockets.remove(websocket) self.active_websockets.remove(_ws)
if _ws.remote_address in self._websocket_locks:
del self._websocket_locks[_ws.remote_address]
def websocket(self): def websocket(self):
""" Websocket main server """ """ Websocket main server """
@ -230,7 +263,7 @@ class HttpBackend(Backend):
set_thread_name('WebsocketServer') set_thread_name('WebsocketServer')
async def register_websocket(websocket, path): async def register_websocket(websocket, path):
address = websocket.remote_address[0] if websocket.remote_address \ address = websocket.remote_address if websocket.remote_address \
else '<unknown client>' else '<unknown client>'
self.logger.info('New websocket connection from {} on path {}'.format(address, path)) self.logger.info('New websocket connection from {} on path {}'.format(address, path))
@ -241,6 +274,8 @@ class HttpBackend(Backend):
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
self.logger.info('Websocket client {} closed connection'.format(address)) self.logger.info('Websocket client {} closed connection'.format(address))
self.active_websockets.remove(websocket) self.active_websockets.remove(websocket)
if address in self._websocket_locks:
del self._websocket_locks[address]
websocket_args = {} websocket_args = {}
if self.ssl_context: if self.ssl_context:

View File

@ -1,4 +1,4 @@
var websocket = { const websocket = {
ws: undefined, ws: undefined,
instance: undefined, instance: undefined,
pending: false, pending: false,
@ -10,7 +10,7 @@ var websocket = {
function initEvents() { function initEvents() {
try { try {
url_prefix = window.config.has_ssl ? 'wss://' : 'ws://'; const url_prefix = window.config.has_ssl ? 'wss://' : 'ws://';
websocket.ws = new WebSocket(url_prefix + window.location.hostname + ':' + window.config.websocket_port); websocket.ws = new WebSocket(url_prefix + window.location.hostname + ':' + window.config.websocket_port);
} catch (err) { } catch (err) {
console.error("Websocket initialization error"); console.error("Websocket initialization error");
@ -20,7 +20,7 @@ function initEvents() {
websocket.pending = true; websocket.pending = true;
var onWebsocketTimeout = function(self) { const onWebsocketTimeout = function(self) {
return function() { return function() {
console.log('Websocket reconnection timed out, retrying'); console.log('Websocket reconnection timed out, retrying');
websocket.pending = false; websocket.pending = false;
@ -33,8 +33,7 @@ function initEvents() {
onWebsocketTimeout(websocket.ws), websocket.reconnectMsecs); onWebsocketTimeout(websocket.ws), websocket.reconnectMsecs);
websocket.ws.onmessage = function(event) { websocket.ws.onmessage = function(event) {
console.debug(event); const handlers = [];
handlers = [];
event = event.data; event = event.data;
if (typeof event === 'string') { if (typeof event === 'string') {
@ -46,6 +45,7 @@ function initEvents() {
} }
} }
console.debug(event);
if (event.type !== 'event') { if (event.type !== 'event') {
// Discard non-event messages // Discard non-event messages
return; return;
@ -59,7 +59,7 @@ function initEvents() {
handlers.push(...websocket.handlers[event.args.type]); handlers.push(...websocket.handlers[event.args.type]);
} }
for (var handler of handlers) { for (const handler of handlers) {
handler(event.args); handler(event.args);
} }
}; };
@ -100,12 +100,12 @@ function initEvents() {
initEvents(); initEvents();
} }
}; };
}; }
function registerEventHandler(handler, ...events) { function registerEventHandler(handler, ...events) {
if (events.length) { if (events.length) {
// Event type filter specified // Event type filter specified
for (var event of events) { for (const event of events) {
if (!(event in websocket.handlers)) { if (!(event in websocket.handlers)) {
websocket.handlers[event] = []; websocket.handlers[event] = [];
} }