Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent GreenletExit exceptions (and improve shutdown time) when shutting down with active keepalive connections #1009

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 102 additions & 4 deletions baseplate/server/wsgi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import datetime
import logging
import socket
from typing import Any
from typing import Any, Literal

import gevent
from gevent.event import Event
from gevent.pool import Pool
from gevent.pywsgi import LoggingLogAdapter, WSGIServer
from gevent.pywsgi import LoggingLogAdapter, WSGIHandler, WSGIServer
from gevent.server import StreamServer

from baseplate.lib import config
Expand All @@ -13,6 +17,92 @@
logger = logging.getLogger(__name__)


class BaseplateWSGIServer(WSGIServer):
    """gevent WSGI server that announces shutdown to its connection handlers.

    A stock gevent WSGIServer refuses new *connections* once it starts
    shutting down, but it keeps accepting new *requests* on keepalive
    connections that are already open. That makes shutdown slow and can lead
    to requests being killed mid-flight when stop_timeout is reached.

    Any gevent WSGIHandler can be used with this server; however, only
    BaseplateWSGIHandler reacts to the shutdown signal, so the keepalive
    behavior is only effective with that handler.
    """

    # Set the moment stop() is invoked, before the base class begins stopping.
    shutdown_event: Event

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the shutdown event, then run the regular gevent constructor.

        All positional and keyword arguments are forwarded untouched.
        """
        self.shutdown_event = Event()
        super().__init__(*args, **kwargs)

    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Signal shutdown, then delegate to the base class ``stop``.

        The event is set first so that code watching it sees the shutdown
        before the base class starts stopping the server.
        """
        self.shutdown_event.set()
        super().stop(*args, **kwargs)


class BaseplateWSGIHandler(WSGIHandler):
    """gevent WSGI handler that stops serving requests once shutdown begins.

    Only usable together with BaseplateWSGIServer, whose ``shutdown_event``
    this handler reads.
    """

    _shutdown_event: Event

    # Records the base class's own opinion on whether the connection should be
    # closed. The base class assigns to `self.close_connection` (driven by the
    # HTTP version and headers); the property setter below redirects that
    # assignment into this attribute.
    _close_connection: bool = False

    def __init__(
        self, sock: socket.socket, address: tuple[str, int], server: BaseplateWSGIServer
    ) -> None:
        """Remember the server's shutdown event, then initialise the base handler."""
        self._shutdown_event = server.shutdown_event
        super().__init__(sock, address, server)

    @property
    def close_connection(self) -> bool:
        """Whether the connection should be closed after the current request.

        Shadows the base class attribute of the same name, which governs
        keepalive. True if the base class asked for a close, or if the server
        is shutting down.
        """
        return self._close_connection or self._shutdown_event.is_set()

    @close_connection.setter
    def close_connection(self, value: bool) -> None:
        # Lets the base class keep writing `self.close_connection = ...` as
        # usual; the getter still forces True once the Baseplate server is in
        # shutdown.
        self._close_connection = value

    def read_requestline(self) -> str | None:
        """Read the next request line, giving up as soon as shutdown starts.

        The base class read runs in its own greenlet so that it can be raced
        against the shutdown event.

        Returns:
            The value produced by the base class read, or None if the
            shutdown event fired first. None makes the base class close the
            connection.
        """
        shutdown = self._shutdown_event
        reader = gevent.spawn(super().read_requestline)
        finished = gevent.wait([shutdown, reader], count=1)

        if shutdown not in finished:
            line = reader.get()
            if isinstance(line, BaseException):
                raise line
            return line

        # Shutdown won the race: abandon the pending read.
        reader.kill()
        return None

    def handle_one_request(self) -> None | Literal[True] | tuple[str, bytes]:
        """Handle one request, refusing to keep the connection alive in shutdown.

        Returns:
            None: the caller should close the connection.
            True: the connection should stay open for further requests.
            tuple[str, bytes]: a status line and response body to send back
                as an error response.
        """
        outcome = super().handle_one_request()
        wants_keepalive = outcome is True
        if wants_keepalive and self._shutdown_event.is_set():
            # Downgrade "keep open" to "close" once the server is stopping.
            return None
        return outcome


def make_server(server_config: dict[str, str], listener: socket.socket, app: Any) -> StreamServer:
"""Make a gevent server for WSGI apps."""
# pylint: disable=maybe-no-member
Expand All @@ -35,11 +125,19 @@ def make_server(server_config: dict[str, str], listener: socket.socket, app: Any
pool = Pool()
log = LoggingLogAdapter(logger, level=logging.DEBUG)

kwargs: dict[str, Any] = {}
kwargs: dict[str, Any] = {
"handler_class": BaseplateWSGIHandler,
}
if cfg.handler:
kwargs["handler_class"] = _load_factory(cfg.handler, default_name=None)
if not issubclass(kwargs["handler_class"], BaseplateWSGIHandler):
logger.warning(
"Custom handler %r is not a subclass of BaseplateWSGIHandler. "
"This may prevent proper shutdown behavior.",
cfg.handler,
)

server = WSGIServer(
server = BaseplateWSGIServer(
listener,
application=app,
spawn=pool,
Expand Down
162 changes: 160 additions & 2 deletions tests/integration/requests_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

import contextlib
import dataclasses
import importlib
import logging
import time
import urllib.parse

import gevent
import pytest
import requests
import urllib3.connection
from gevent.pywsgi import WSGIServer
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPNoContent

Expand Down Expand Up @@ -34,6 +42,8 @@ def gevent_socket():
@pytest.fixture
def http_server(gevent_socket):
class HttpServer:
server: WSGIServer

def __init__(self, address):
self.url = f"http://{address[0]}:{address[1]}/"
self.requests = []
Expand All @@ -56,8 +66,8 @@ def handle_request(self, request):
configurator.add_view(http_server.handle_request, route_name="test_view", renderer="json")
wsgi_app = configurator.make_wsgi_app()

server = make_server({"stop_timeout": "1 millisecond"}, listener, wsgi_app)
server_greenlet = gevent.spawn(server.serve_forever)
http_server.server = make_server({"stop_timeout": "1 millisecond"}, listener, wsgi_app)
server_greenlet = gevent.spawn(http_server.server.serve_forever)
try:
yield http_server
finally:
Expand Down Expand Up @@ -182,3 +192,151 @@ def test_external_client_doesnt_send_headers(http_server):
assert "X-Parent" not in http_server.requests[0].headers
assert "X-Span" not in http_server.requests[0].headers
assert "X-Edge-Request" not in http_server.requests[0].headers


def _is_connected(conn: urllib3.connection.HTTPConnection) -> bool:
"""Backport of urllib3.connection.HTTPConnection.is_connected().

Based on urllib3 v2.2.3:
https://github.com/urllib3/urllib3/blob/f9d37add7983d441b151146db447318dff4186c9/src/urllib3/connection.py#L299
"""
if conn.sock is None:
return False
return not urllib3.util.wait_for_read(conn.sock, timeout=0.0)


@dataclasses.dataclass
class KeepaliveClientResult:
    """Outcome reported by the keepalive test client."""

    # Number of requests that received the expected response.
    requests_completed: int = 0
    # `time.time()` reading taken when the client noticed the connection was
    # closed; None until then.
    connection_closed_time: float | None = None


def _keepalive_client(
    url: str, ready_event: gevent.event.Event, wait_time: float
) -> KeepaliveClientResult:
    """Issue requests over one keepalive connection until it is closed.

    Args:
        url: Server URL; only its host and port are used.
        ready_event: Set once the connection has been established.
        wait_time: Minimum number of seconds between consecutive requests.

    Returns:
        The number of completed requests and the time at which the closed
        connection was noticed. The function returns only if the connection
        is closed; otherwise the caller has to kill it.
    """
    target = urllib.parse.urlparse(url)
    connection = urllib3.connection.HTTPConnection(target.hostname, target.port, timeout=1)
    with contextlib.closing(connection) as conn:
        result = KeepaliveClientResult()
        conn.connect()
        ready_event.set()

        last_sent_at = None
        while _is_connected(conn):
            if last_sent_at is None or time.time() - last_sent_at >= wait_time:
                print("Client making request.")
                last_sent_at = time.time()
                conn.request("GET", "/")
                response = conn.getresponse()
                response.close()

                assert response.status == 204
                print("Client got expected response.")
                result.requests_completed += 1

            # Poll in short steps instead of sleeping for all of `wait_time`,
            # so a closed connection is noticed promptly.
            gevent.sleep(0.01)

        print("Client lost connection to server, stopping request loop.")
        result.connection_closed_time = time.time()

    return result


@pytest.mark.parametrize(
    (
        "delay_between_requests",
        "min_expected_successful_requests",
        "max_expected_successful_requests",
    ),
    (
        # Client that sends a request every 0.1 seconds.
        (
            0.1,
            # ~10 requests in 1 second.
            5,
            15,
        ),
        # Client that sends one request then sleeps forever while keeping the
        # connection open.
        #
        # This is used to test that the server closes keepalive connections
        # even if they remain idle for the entire shutdown period.
        (
            999999999,
            # The client should make exactly one request.
            1,
            1,
        ),
    ),
)
def test_shutdown_closes_existing_keepalive_connection(
    http_server,
    delay_between_requests,
    min_expected_successful_requests,
    max_expected_successful_requests,
):
    """Ensure that the server closes keepalive connections when shutting down.

    By default, calling `stop()` on a gevent WSGIServer prevents new
    connections but does not close existing ones. This allows clients to
    continue sending new requests over existing connections right up until the
    server's stop_timeout, resulting in slow shutdown and connections being
    killed mid-flight, which causes user-facing errors.

    We work around this by subclassing WSGIHandler and (a) disabling keepalive
    when the server is in shutdown, and (b) closing existing idle connections
    when the server enters shutdown.

    Args:
        http_server: Fixture exposing the running server and its URL.
        delay_between_requests: Seconds the client waits between requests.
        min_expected_successful_requests: Lower bound on completed requests.
        max_expected_successful_requests: Upper bound on completed requests.
    """
    http_server.server.stop_timeout = 10

    ready_event = gevent.event.Event()
    client_greenlet = gevent.spawn(
        _keepalive_client,
        http_server.url,
        ready_event,
        delay_between_requests,
    )
    try:
        print("Waiting for client to connect...")
        # Bounded wait: if the client dies before connecting, fail here
        # instead of blocking on an event that will never be set.
        assert ready_event.wait(timeout=5), "Client failed to connect to the server."

        print("Client connected, now waiting while it makes requests.")
        gevent.sleep(1)

        print("Triggering server shutdown...")
        shutdown_start = time.time()
        http_server.server.stop()
    finally:
        # Server usually exits before the client notices the connection closed,
        # so give it up to 5 seconds to finish.
        client_greenlet.join(timeout=5)
        # The client only returns on its own once its connection is closed.
        # If it is still running, kill it so that the `get()` below cannot
        # block forever and the greenlet does not outlive the test.
        if not client_greenlet.ready():
            client_greenlet.kill()

    print(f"Shutdown completed after {time.time() - shutdown_start:.1f}s.")

    ret = client_greenlet.get()
    if isinstance(ret, BaseException):
        # This usually happens with GreenletExit, i.e. the client had to be
        # killed because the server never closed its connection.
        raise ret

    print("Requests completed:", ret.requests_completed)
    assert ret.connection_closed_time is not None
    connection_closed_delay = ret.connection_closed_time - shutdown_start
    print("Connection closed delay:", connection_closed_delay)

    assert (
        min_expected_successful_requests
        <= ret.requests_completed
        <= max_expected_successful_requests
    )

    # connection_closed_time should be within ~2 seconds after the shutdown
    # start time, but not before it.
    assert 0 <= connection_closed_delay <= 2