diff --git a/platypush/backend/http/__init__.py b/platypush/backend/http/__init__.py index 12ce9f2c5..6bda90657 100644 --- a/platypush/backend/http/__init__.py +++ b/platypush/backend/http/__init__.py @@ -2,13 +2,17 @@ import asyncio import os import pathlib import secrets +import signal import threading +from functools import partial from multiprocessing import Process from time import time -from typing import List, Mapping, Optional -from tornado.httpserver import HTTPServer +from typing import Mapping, Optional +import psutil + +from tornado.httpserver import HTTPServer from tornado.netutil import bind_sockets from tornado.process import cpu_count, fork_processes from tornado.wsgi import WSGIContainer @@ -18,9 +22,9 @@ from platypush.backend import Backend from platypush.backend.http.app import application from platypush.backend.http.app.utils import get_streaming_routes, get_ws_routes from platypush.backend.http.app.ws.events import WSEventProxy - from platypush.bus.redis import RedisBus from platypush.config import Config +from platypush.utils import get_remaining_timeout class HttpBackend(Backend): @@ -191,6 +195,9 @@ class HttpBackend(Backend): _DEFAULT_HTTP_PORT = 8008 """The default listen port for the webserver.""" + _STOP_TIMEOUT = 5 + """How long we should wait (in seconds) before killing the worker processes.""" + def __init__( self, port: int = _DEFAULT_HTTP_PORT, @@ -227,7 +234,6 @@ class HttpBackend(Backend): self.port = port self._server_proc: Optional[Process] = None - self._workers: List[Process] = [] self._service_registry_thread = None self.bind_address = bind_address @@ -254,39 +260,37 @@ class HttpBackend(Backend): """On backend stop""" super().on_stop() self.logger.info('Received STOP event on HttpBackend') - - start_time = time() - timeout = 5 - workers = self._workers.copy() - - for i, worker in enumerate(workers[::-1]): - if worker and worker.is_alive(): - worker.terminate() - worker.join(timeout=max(0, start_time + timeout - time())) - - if worker and worker.is_alive(): - worker.kill() - self._workers.pop(i) + start = time() + remaining_time: partial[float] = partial( # type: ignore + get_remaining_timeout, timeout=self._STOP_TIMEOUT, start=start + ) if self._server_proc: - try: - self._server_proc.terminate() - self._server_proc.join(timeout=5) - except AttributeError: - pass + if self._server_proc.pid: + try: + os.kill(self._server_proc.pid, signal.SIGINT) + except OSError: + pass - self._server_proc = None + if self._server_proc and self._server_proc.is_alive(): + self._server_proc.join(timeout=remaining_time() / 2) + try: + self._server_proc.terminate() + self._server_proc.join(timeout=remaining_time() / 2) + except AttributeError: + pass if self._server_proc and self._server_proc.is_alive(): self._server_proc.kill() self._server_proc = None - self.logger.info('HTTP server terminated') if self._service_registry_thread and self._service_registry_thread.is_alive(): - self._service_registry_thread.join(timeout=5) + self._service_registry_thread.join(timeout=remaining_time()) self._service_registry_thread = None + self.logger.info('HTTP server terminated') + def notify_web_clients(self, event): """Notify all the connected web clients (over websocket) of a new event""" WSEventProxy.publish(event) # noqa: E1120 @@ -348,7 +352,10 @@ class HttpBackend(Backend): try: await asyncio.Event().wait() except (asyncio.CancelledError, KeyboardInterrupt): - return + pass + finally: + server.stop() + await server.close_all_connections() def _web_server_proc(self): self.logger.info( @@ -375,7 +382,65 @@ class HttpBackend(Backend): future = self._post_fork_main(sockets) asyncio.run(future) except (asyncio.CancelledError, KeyboardInterrupt): - return + pass + finally: + self._stop_workers() + + def _stop_workers(self): + """ + Stop all the worker processes. + + We have to run this manually on server termination because of a + long-standing issue with Tornado not being able to wind down the forked + workers when the server terminates: + https://github.com/tornadoweb/tornado/issues/1912. + """ + parent_pid = ( + self._server_proc.pid + if self._server_proc and self._server_proc.pid + else None + ) + + if not parent_pid: + return + + try: + cur_proc = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + + # Send a SIGTERM to all the children + children = cur_proc.children() + for child in children: + if child.pid != parent_pid and child.is_running(): + try: + os.kill(child.pid, signal.SIGTERM) + except OSError as e: + self.logger.warning( + 'Could not send SIGTERM to PID %d: %s', child.pid, e + ) + + # Initialize the timeout + start = time() + remaining_time: partial[int] = partial( # type: ignore + get_remaining_timeout, timeout=self._STOP_TIMEOUT, start=start, cls=int + ) + + # Wait for all children to terminate (with timeout) + for child in children: + if child.pid != parent_pid and child.is_running(): + try: + child.wait(timeout=remaining_time()) + except TimeoutError: + pass + + # Send a SIGKILL to any child process that is still running + for child in children: + if child.pid != parent_pid and child.is_running(): + try: + child.kill() + except OSError: + pass def _start_web_server(self): self._server_proc = Process(target=self._web_server_proc)