diff --git a/baseplate/server/wsgi.py b/baseplate/server/wsgi.py index 448c41ad9..678deebbb 100644 --- a/baseplate/server/wsgi.py +++ b/baseplate/server/wsgi.py @@ -1,10 +1,14 @@ +from __future__ import annotations + import datetime import logging import socket -from typing import Any +from typing import Any, Literal +import gevent +from gevent.event import Event from gevent.pool import Pool -from gevent.pywsgi import LoggingLogAdapter, WSGIServer +from gevent.pywsgi import LoggingLogAdapter, WSGIHandler, WSGIServer from gevent.server import StreamServer from baseplate.lib import config @@ -13,6 +17,92 @@ logger = logging.getLogger(__name__) +class BaseplateWSGIServer(WSGIServer): + """WSGI server which closes existing keepalive connections when shutting down. + + The default gevent WSGIServer prevents new *connections* once the server + enters shutdown, but does not prevent new *requests* over existing + keepalive connections. This results in slow shutdowns and in some cases + requests being killed mid-flight once the server reaches stop_timeout. + + This server may be used with any gevent WSGIHandler, but the keepalive + behavior only works when using BaseplateWSGIHandler. + """ + + shutdown_event: Event + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.shutdown_event = Event() + super().__init__(*args, **kwargs) + + def stop(self, *args: Any, **kwargs: Any) -> None: + self.shutdown_event.set() + super().stop(*args, **kwargs) + + +class BaseplateWSGIHandler(WSGIHandler): + """WSGI handler which avoids processing requests when the server is in shutdown. + + This handler may only be used with BaseplateWSGIServer. + """ + + _shutdown_event: Event + + # Flag representing whether the base class thinks the connection should be + # closed. The base class sets `self.close_connection` based on the HTTP + # version and headers, which we intercept using a property setter into this + # attribute. + _close_connection: bool = False + + def __init__( + self, sock: socket.socket, address: tuple[str, int], server: BaseplateWSGIServer + ) -> None: + self._shutdown_event = server.shutdown_event + super().__init__(sock, address, server) + + @property + def close_connection(self) -> bool: + # This property overrides `close_connection` in the base class which is + # used to control keepalive behavior. + return self._close_connection or self._shutdown_event.is_set() + + @close_connection.setter + def close_connection(self, value: bool) -> None: + # This setter allows the base class to set `self.close_connection` + # directly, while still allowing us to override the value when we know + # the Baseplate server is in shutdown. + self._close_connection = value + + def read_requestline(self) -> str | None: + real_read_requestline = gevent.spawn(super().read_requestline) + ready = gevent.wait([self._shutdown_event, real_read_requestline], count=1) + + if self._shutdown_event in ready: + real_read_requestline.kill() + # None triggers the base class to close the connection. + return None + + ret = real_read_requestline.get() + if isinstance(ret, BaseException): + raise ret + return ret + + def handle_one_request( + self, + ) -> ( + # 'None' is used to indicate that the connection should be closed by the caller. + None + # 'True' is used to indicate that the connection should be kept open for future requests. + | Literal[True] + # Tuple of status line and response body is used for returning an error response. + | tuple[str, bytes] + ): + ret = super().handle_one_request() + if ret is True and self._shutdown_event.is_set(): + return None + return ret + + def make_server(server_config: dict[str, str], listener: socket.socket, app: Any) -> StreamServer: """Make a gevent server for WSGI apps.""" # pylint: disable=maybe-no-member @@ -35,11 +125,19 @@ def make_server(server_config: dict[str, str], listener: socket.socket, app: Any pool = Pool() log = LoggingLogAdapter(logger, level=logging.DEBUG) - kwargs: dict[str, Any] = {} + kwargs: dict[str, Any] = { + "handler_class": BaseplateWSGIHandler, + } if cfg.handler: kwargs["handler_class"] = _load_factory(cfg.handler, default_name=None) + if not issubclass(kwargs["handler_class"], BaseplateWSGIHandler): + logger.warning( + "Custom handler %r is not a subclass of BaseplateWSGIHandler. " + "This may prevent proper shutdown behavior.", + cfg.handler, + ) - server = WSGIServer( + server = BaseplateWSGIServer( listener, application=app, spawn=pool, diff --git a/tests/integration/requests_tests.py b/tests/integration/requests_tests.py index 49fc1f046..bebbf49b6 100644 --- a/tests/integration/requests_tests.py +++ b/tests/integration/requests_tests.py @@ -1,9 +1,17 @@ +from __future__ import annotations + +import contextlib +import dataclasses import importlib import logging +import time +import urllib.parse import gevent import pytest import requests +import urllib3.connection +from gevent.pywsgi import WSGIServer from pyramid.config import Configurator from pyramid.httpexceptions import HTTPNoContent @@ -34,6 +42,8 @@ def gevent_socket(): @pytest.fixture def http_server(gevent_socket): class HttpServer: + server: WSGIServer + def __init__(self, address): self.url = f"http://{address[0]}:{address[1]}/" self.requests = [] @@ -56,8 +66,8 @@ def handle_request(self, request): configurator.add_view(http_server.handle_request, route_name="test_view", renderer="json") wsgi_app = configurator.make_wsgi_app() - server = make_server({"stop_timeout": "1 millisecond"}, listener, wsgi_app) - server_greenlet = gevent.spawn(server.serve_forever) + http_server.server = make_server({"stop_timeout": "1 millisecond"}, listener, wsgi_app) + server_greenlet = gevent.spawn(http_server.server.serve_forever) try: yield http_server finally: @@ -182,3 +192,151 @@ def test_external_client_doesnt_send_headers(http_server): assert "X-Parent" not in http_server.requests[0].headers assert "X-Span" not in http_server.requests[0].headers assert "X-Edge-Request" not in http_server.requests[0].headers + + +def _is_connected(conn: urllib3.connection.HTTPConnection) -> bool: + """Backport of urllib3.connection.HTTPConnection.is_connected(). + + Based on urllib3 v2.2.3: + https://github.com/urllib3/urllib3/blob/f9d37add7983d441b151146db447318dff4186c9/src/urllib3/connection.py#L299 + """ + if conn.sock is None: + return False + return not urllib3.util.wait_for_read(conn.sock, timeout=0.0) + + +@dataclasses.dataclass +class KeepaliveClientResult: + requests_completed: int = 0 + connection_closed_time: float | None = None + + +def _keepalive_client( + url: str, ready_event: gevent.event.Event, wait_time: float +) -> KeepaliveClientResult: + """HTTP client that makes requests forever over a single keepalive connection. + + Returns iff the connection is closed. Otherwise, it must be killed. + """ + parsed = urllib.parse.urlparse(url) + with contextlib.closing( + urllib3.connection.HTTPConnection(parsed.hostname, parsed.port, timeout=1), + ) as conn: + ret = KeepaliveClientResult() + conn.connect() + ready_event.set() + + last_request_time = None + while True: + if not _is_connected(conn): + print("Client lost connection to server, stopping request loop.") + ret.connection_closed_time = time.time() + break + + if last_request_time is None or time.time() - last_request_time >= wait_time: + print("Client making request.") + last_request_time = time.time() + conn.request("GET", "/") + response = conn.getresponse() + response.close() + + assert response.status == 204 + print("Client got expected response.") + ret.requests_completed += 1 + + # Sleeping for a short time rather than the full `wait_time` so we + # can notice if the connection closes. + gevent.sleep(0.01) + + return ret + + +@pytest.mark.parametrize( + ( + "delay_between_requests", + "min_expected_successful_requests", + "max_expected_successful_requests", + ), + ( + # Client that sends a request every 0.1 seconds. + ( + 0.1, + # ~10 requests in 1 second. + 5, + 15, + ), + # Client that sends one request then sleeps forever while keeping the + # connection open. + # + # This is used to test that the server closes keepalive connections + # even if they remain idle for the entire shutdown period. + ( + 999999999, + # The client should make exactly one request. + 1, + 1, + ), + ), +) +def test_shutdown_closes_existing_keepalive_connection( + http_server, + delay_between_requests, + min_expected_successful_requests, + max_expected_successful_requests, +): + """Ensure that the server closes keepalive connections when shutting down. + + By default, calling `stop()` on a gevent WSGIServer prevents new + connections but does not close existing ones. This allows clients to + continue sending new requests over existing connections right up until the + server's stop_timeout, resulting in slow shutdown and connections being + killed mid-flight, which causes user-facing errors. + + We work around this by subclassing WSGIHandler and (a) disabling keepalive + when the server is in shutdown, and (b) closing existing idle connections + when the server enters shutdown. + """ + http_server.server.stop_timeout = 10 + + ready_event = gevent.event.Event() + client_greenlet = gevent.spawn( + _keepalive_client, + http_server.url, + ready_event, + delay_between_requests, + ) + try: + print("Waiting for client to connect...") + ready_event.wait() + + print("Client connected, now waiting while it makes requests.") + gevent.sleep(1) + + print("Triggering server shutdown...") + shutdown_start = time.time() + http_server.server.stop() + finally: + # Server usually exits before the client notices the connection closed, + # so give it a second to finish. + client_greenlet.join(timeout=5) + + print(f"Shutdown completed after {time.time() - shutdown_start:.1f}s.") + + ret = client_greenlet.get() + if isinstance(ret, BaseException): + # This usually happens with GreenletExit. + raise ret + + print("Requests completed:", ret.requests_completed) + connection_closed_delay = ret.connection_closed_time - shutdown_start + print("Connection closed delay:", connection_closed_delay) + + assert ( + min_expected_successful_requests + <= ret.requests_completed + <= max_expected_successful_requests + ) + + # connection_closed_time should be within ~2 seconds after the shutdown + # start time, but not before it. + assert 0 <= connection_closed_delay <= 2