diff --git a/platypush/backend/http/__init__.py b/platypush/backend/http/__init__.py index 3e03ba6b41..67caf49a2a 100644 --- a/platypush/backend/http/__init__.py +++ b/platypush/backend/http/__init__.py @@ -1,14 +1,18 @@ +import asyncio import os import pathlib import secrets import threading from multiprocessing import Process -from typing import Mapping, Optional +from time import time +from typing import List, Mapping, Optional +from tornado.httpserver import HTTPServer +from tornado.netutil import bind_sockets +from tornado.process import cpu_count, fork_processes from tornado.wsgi import WSGIContainer from tornado.web import Application, FallbackHandler -from tornado.ioloop import IOLoop from platypush.backend import Backend from platypush.backend.http.app import application @@ -186,6 +190,7 @@ class HttpBackend(Backend): """ _DEFAULT_HTTP_PORT = 8008 + """The default listen port for the webserver.""" def __init__( self, @@ -193,6 +198,7 @@ class HttpBackend(Backend): bind_address: str = '0.0.0.0', resource_dirs: Optional[Mapping[str, str]] = None, secret_key_file: Optional[str] = None, + num_workers: Optional[int] = None, **kwargs, ): """ @@ -204,15 +210,16 @@ class HttpBackend(Backend): the value is the absolute path to expose. :param secret_key_file: Path to the file containing the secret key that will be used by Flask (default: ``~/.local/share/platypush/flask.secret.key``). + :param num_workers: Number of worker processes to use (default: ``(cpu_count * 2) + 1``). """ super().__init__(**kwargs) self.port = port - self.server_proc = None + self._server_proc: Optional[Process] = None + self._workers: List[Process] = [] self._service_registry_thread = None self.bind_address = bind_address - self._io_loop: Optional[IOLoop] = None if resource_dirs: self.resource_dirs = { @@ -227,6 +234,7 @@ class HttpBackend(Backend): or os.path.join(Config.get('workdir'), 'flask.secret.key') # type: ignore ) self.local_base_url = f'http://localhost:{self.port}' + self.num_workers = num_workers or (cpu_count() * 2) + 1 def send_message(self, *_, **__): self.logger.warning('Use cURL or any HTTP client to query the HTTP backend') @@ -236,28 +244,34 @@ class HttpBackend(Backend): super().on_stop() self.logger.info('Received STOP event on HttpBackend') - if self._io_loop: - self._io_loop.stop() - self._io_loop.close() + start_time = time() + timeout = 5 + workers = self._workers.copy() - if self.server_proc: - self.server_proc.terminate() - self.server_proc.join(timeout=10) - if self.server_proc.is_alive(): - self.server_proc.kill() - if self.server_proc.is_alive(): - self.logger.info( - 'HTTP server process may be still alive at termination' - ) - else: - self.logger.info('HTTP server process terminated') + for i, worker in enumerate(workers[::-1]): + if worker and worker.is_alive(): + worker.terminate() + worker.join(timeout=max(0, start_time + timeout - time())) + + if worker and worker.is_alive(): + worker.kill() + self._workers.pop(i) + + if self._server_proc: + self._server_proc.terminate() + self._server_proc.join(timeout=5) + self._server_proc = None + + if self._server_proc and self._server_proc.is_alive(): + self._server_proc.kill() + + self._server_proc = None + self.logger.info('HTTP server terminated') if self._service_registry_thread and self._service_registry_thread.is_alive(): self._service_registry_thread.join(timeout=5) self._service_registry_thread = None - self._io_loop = None - def notify_web_clients(self, event): """Notify all the connected web clients (over websocket) of a new event""" get_redis().publish(events_redis_topic, str(event)) @@ -281,35 +295,6 @@ class HttpBackend(Backend): raise e - def _web_server_proc(self): - def proc(): - self.logger.info('Starting local web server on port %s', self.port) - assert isinstance( - self.bus, RedisBus - ), 'The HTTP backend only works if backed by a Redis bus' - - application.config['redis_queue'] = self.bus.redis_queue - application.secret_key = self._get_secret_key() - - container = WSGIContainer(application) - server = Application( - [ - *[(route.path(), route) for route in get_ws_routes()], - (r'.*', FallbackHandler, {'fallback': container}), - ] - ) - - server.listen(address=self.bind_address, port=self.port) - self._io_loop = IOLoop.instance() - - try: - self._io_loop.start() - except Exception as e: - if not self.should_stop(): - raise e - - return proc - def _register_service(self): try: self.register_service(port=self.port) @@ -324,16 +309,55 @@ class HttpBackend(Backend): ) self._service_registry_thread.start() - def _run_web_server(self): - self.server_proc = Process(target=self._web_server_proc(), name='WebServer') - self.server_proc.start() - self.server_proc.join() + async def _post_fork_main(self, sockets): + assert isinstance( + self.bus, RedisBus + ), 'The HTTP backend only works if backed by a Redis bus' + + application.config['redis_queue'] = self.bus.redis_queue + application.secret_key = self._get_secret_key() + container = WSGIContainer(application) + tornado_app = Application( + [ + *[(route.path(), route) for route in get_ws_routes()], + (r'.*', FallbackHandler, {'fallback': container}), + ] + ) + + server = HTTPServer(tornado_app) + server.add_sockets(sockets) + + try: + await asyncio.Event().wait() + except (asyncio.CancelledError, KeyboardInterrupt): + return + + def _web_server_proc(self): + self.logger.info( + 'Starting local web server on port %s with %d service workers', + self.port, + self.num_workers, + ) + + sockets = bind_sockets(self.port, address=self.bind_address, reuse_port=True) + + try: + fork_processes(self.num_workers) + future = self._post_fork_main(sockets) + asyncio.run(future) + except (asyncio.CancelledError, KeyboardInterrupt): + return + + def _start_web_server(self): + self._server_proc = Process(target=self._web_server_proc) + self._server_proc.start() + self._server_proc.join() def run(self): super().run() self._start_zeroconf_service() - self._run_web_server() + self._start_web_server() # vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/ws/_base.py b/platypush/backend/http/app/ws/_base.py index 47e4498852..3ce53378f8 100644 --- a/platypush/backend/http/app/ws/_base.py +++ b/platypush/backend/http/app/ws/_base.py @@ -89,7 +89,7 @@ class WSRoute(WebSocketHandler, Thread, ABC): continue yield msg.get('data') - except RedisConnectionError: + except (AttributeError, RedisConnectionError): return def send(self, msg: Union[str, bytes, dict, list, tuple, set]) -> None: