Implemented HttpBackend._stop_workers.

The Tornado WSGI container won't guarantee the termination of the
spawned workers upon termination, so the code of the backend has to take
care of it and terminate all the children processes of the server
process when it terminates.

This also means that `psutil` is now a required base dependency, as we
need to expand the process subtree under the webserver launcher.
This commit is contained in:
Fabio Manganiello 2023-08-15 00:13:34 +02:00
parent 04921c759f
commit a8a7ceb2ac
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -2,13 +2,17 @@ import asyncio
import os import os
import pathlib import pathlib
import secrets import secrets
import signal
import threading import threading
from functools import partial
from multiprocessing import Process from multiprocessing import Process
from time import time from time import time
from typing import List, Mapping, Optional from typing import Mapping, Optional
from tornado.httpserver import HTTPServer
import psutil
from tornado.httpserver import HTTPServer
from tornado.netutil import bind_sockets from tornado.netutil import bind_sockets
from tornado.process import cpu_count, fork_processes from tornado.process import cpu_count, fork_processes
from tornado.wsgi import WSGIContainer from tornado.wsgi import WSGIContainer
@ -18,9 +22,9 @@ from platypush.backend import Backend
from platypush.backend.http.app import application from platypush.backend.http.app import application
from platypush.backend.http.app.utils import get_streaming_routes, get_ws_routes from platypush.backend.http.app.utils import get_streaming_routes, get_ws_routes
from platypush.backend.http.app.ws.events import WSEventProxy from platypush.backend.http.app.ws.events import WSEventProxy
from platypush.bus.redis import RedisBus from platypush.bus.redis import RedisBus
from platypush.config import Config from platypush.config import Config
from platypush.utils import get_remaining_timeout
class HttpBackend(Backend): class HttpBackend(Backend):
@ -191,6 +195,9 @@ class HttpBackend(Backend):
_DEFAULT_HTTP_PORT = 8008 _DEFAULT_HTTP_PORT = 8008
"""The default listen port for the webserver.""" """The default listen port for the webserver."""
_STOP_TIMEOUT = 5
"""How long we should wait (in seconds) before killing the worker processes."""
def __init__( def __init__(
self, self,
port: int = _DEFAULT_HTTP_PORT, port: int = _DEFAULT_HTTP_PORT,
@ -227,7 +234,6 @@ class HttpBackend(Backend):
self.port = port self.port = port
self._server_proc: Optional[Process] = None self._server_proc: Optional[Process] = None
self._workers: List[Process] = []
self._service_registry_thread = None self._service_registry_thread = None
self.bind_address = bind_address self.bind_address = bind_address
@ -254,39 +260,37 @@ class HttpBackend(Backend):
"""On backend stop""" """On backend stop"""
super().on_stop() super().on_stop()
self.logger.info('Received STOP event on HttpBackend') self.logger.info('Received STOP event on HttpBackend')
start = time()
start_time = time() remaining_time: partial[float] = partial( # type: ignore
timeout = 5 get_remaining_timeout, timeout=self._STOP_TIMEOUT, start=start
workers = self._workers.copy() )
for i, worker in enumerate(workers[::-1]):
if worker and worker.is_alive():
worker.terminate()
worker.join(timeout=max(0, start_time + timeout - time()))
if worker and worker.is_alive():
worker.kill()
self._workers.pop(i)
if self._server_proc: if self._server_proc:
try: if self._server_proc.pid:
self._server_proc.terminate() try:
self._server_proc.join(timeout=5) os.kill(self._server_proc.pid, signal.SIGINT)
except AttributeError: except OSError:
pass pass
self._server_proc = None if self._server_proc and self._server_proc.is_alive():
self._server_proc.join(timeout=remaining_time() / 2)
try:
self._server_proc.terminate()
self._server_proc.join(timeout=remaining_time() / 2)
except AttributeError:
pass
if self._server_proc and self._server_proc.is_alive(): if self._server_proc and self._server_proc.is_alive():
self._server_proc.kill() self._server_proc.kill()
self._server_proc = None self._server_proc = None
self.logger.info('HTTP server terminated')
if self._service_registry_thread and self._service_registry_thread.is_alive(): if self._service_registry_thread and self._service_registry_thread.is_alive():
self._service_registry_thread.join(timeout=5) self._service_registry_thread.join(timeout=remaining_time())
self._service_registry_thread = None self._service_registry_thread = None
self.logger.info('HTTP server terminated')
def notify_web_clients(self, event): def notify_web_clients(self, event):
"""Notify all the connected web clients (over websocket) of a new event""" """Notify all the connected web clients (over websocket) of a new event"""
WSEventProxy.publish(event) # noqa: E1120 WSEventProxy.publish(event) # noqa: E1120
@ -348,7 +352,10 @@ class HttpBackend(Backend):
try: try:
await asyncio.Event().wait() await asyncio.Event().wait()
except (asyncio.CancelledError, KeyboardInterrupt): except (asyncio.CancelledError, KeyboardInterrupt):
return pass
finally:
server.stop()
await server.close_all_connections()
def _web_server_proc(self): def _web_server_proc(self):
self.logger.info( self.logger.info(
@ -375,7 +382,65 @@ class HttpBackend(Backend):
future = self._post_fork_main(sockets) future = self._post_fork_main(sockets)
asyncio.run(future) asyncio.run(future)
except (asyncio.CancelledError, KeyboardInterrupt): except (asyncio.CancelledError, KeyboardInterrupt):
return pass
finally:
self._stop_workers()
def _stop_workers(self):
"""
Stop all the worker processes.
We have to run this manually on server termination because of a
long-standing issue with Tornado not being able to wind down the forked
workers when the server terminates:
https://github.com/tornadoweb/tornado/issues/1912.
"""
parent_pid = (
self._server_proc.pid
if self._server_proc and self._server_proc.pid
else None
)
if not parent_pid:
return
try:
cur_proc = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
# Send a SIGTERM to all the children
children = cur_proc.children()
for child in children:
if child.pid != parent_pid and child.is_running():
try:
os.kill(child.pid, signal.SIGTERM)
except OSError as e:
self.logger.warning(
'Could not send SIGTERM to PID %d: %s', child.pid, e
)
# Initialize the timeout
start = time()
remaining_time: partial[int] = partial( # type: ignore
get_remaining_timeout, timeout=self._STOP_TIMEOUT, start=start, cls=int
)
# Wait for all children to terminate (with timeout)
for child in children:
if child.pid != parent_pid and child.is_running():
try:
child.wait(timeout=remaining_time())
except TimeoutError:
pass
# Send a SIGKILL to any child process that is still running
for child in children:
if child.pid != parent_pid and child.is_running():
try:
child.kill()
except OSError:
pass
def _start_web_server(self): def _start_web_server(self):
self._server_proc = Process(target=self._web_server_proc) self._server_proc = Process(target=self._web_server_proc)