Make sure that the accept() in backend.tcp does not block the process

This commit is contained in:
Fabio Manganiello 2021-07-25 11:33:48 +02:00
parent 550fd3abe9
commit 8e2d4d0bce

View file

@ -1,5 +1,8 @@
import multiprocessing
import queue
import socket import socket
import threading import threading
from typing import Optional
from platypush.backend import Backend from platypush.backend import Backend
from platypush.message import Message from platypush.message import Message
@ -31,6 +34,8 @@ class TcpBackend(Backend):
self.port = port self.port = port
self.bind_address = bind_address or '0.0.0.0' self.bind_address = bind_address or '0.0.0.0'
self.listen_queue = listen_queue self.listen_queue = listen_queue
self._accept_queue = multiprocessing.Queue()
self._srv: Optional[multiprocessing.Process] = None
def _process_client(self, sock, address): def _process_client(self, sock, address):
def _f(): def _f():
@ -89,6 +94,14 @@ class TcpBackend(Backend):
threading.Thread(target=_f_wrapper, name='TCPListener').start() threading.Thread(target=_f_wrapper, name='TCPListener').start()
def _accept_process(self, serv_sock: socket.socket):
while not self.should_stop():
try:
(sock, address) = serv_sock.accept()
self._accept_queue.put((sock, address))
except socket.timeout:
continue
def run(self): def run(self):
super().run() super().run()
self.register_service(port=self.port) self.register_service(port=self.port)
@ -102,16 +115,23 @@ class TcpBackend(Backend):
format(self.port, self.bind_address)) format(self.port, self.bind_address))
serv_sock.listen(self.listen_queue) serv_sock.listen(self.listen_queue)
self._srv = multiprocessing.Process(target=self._accept_process, args=(serv_sock,))
self._srv.start()
while not self.should_stop(): while not self.should_stop():
try: try:
(sock, address) = serv_sock.accept() sock, address = self._accept_queue.get(timeout=1)
except socket.timeout: except (socket.timeout, queue.Empty):
continue continue
self.logger.info('Accepted connection from client {}'.format(address[0])) self.logger.info('Accepted connection from client {}'.format(address[0]))
self._process_client(sock, address) self._process_client(sock, address)
if self._srv:
self._srv.kill()
self._srv.join()
self._srv = None
self.logger.info('TCP backend terminated') self.logger.info('TCP backend terminated')