Skip to content

Commit

Permalink
Cooperative signal handling (#1600)
Browse files Browse the repository at this point in the history
* test desired signal behaviour

* capture and restore signal handlers

* ruff

* checks

* test asyncio handlers

* add note on signal handler handling

* remove legacy signal raising

* test SIGBREAK on windows

* remove test guard

* include convered branch

* Update docs/index.md

* Update docs/index.md

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
maxfischer2781 and Kludex authored Mar 19, 2024
1 parent f73b8be commit 9e32e8e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 14 deletions.
67 changes: 67 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

import asyncio
import contextlib
import signal
import sys
from typing import Callable, ContextManager, Generator

import pytest

from uvicorn.config import Config
from uvicorn.server import Server


# asyncio does NOT allow raising in signal handlers, so to detect
# raised signals raised a mutable `witness` receives the signal
@contextlib.contextmanager
def capture_signal_sync(sig: signal.Signals) -> Generator[list[int], None, None]:
"""Replace `sig` handling with a normal exception via `signal"""
witness: list[int] = []
original_handler = signal.signal(sig, lambda signum, frame: witness.append(signum))
yield witness
signal.signal(sig, original_handler)


@contextlib.contextmanager
def capture_signal_async(sig: signal.Signals) -> Generator[list[int], None, None]: # pragma: py-win32
"""Replace `sig` handling with a normal exception via `asyncio"""
witness: list[int] = []
original_handler = signal.getsignal(sig)
asyncio.get_running_loop().add_signal_handler(sig, witness.append, sig)
yield witness
signal.signal(sig, original_handler)


async def dummy_app(scope, receive, send): # pragma: py-win32
pass


if sys.platform == "win32":
signals = [signal.SIGBREAK]
signal_captures = [capture_signal_sync]
else:
signals = [signal.SIGTERM, signal.SIGINT]
signal_captures = [capture_signal_sync, capture_signal_async]


@pytest.mark.anyio
@pytest.mark.parametrize("exception_signal", signals)
@pytest.mark.parametrize("capture_signal", signal_captures)
async def test_server_interrupt(
exception_signal: signal.Signals, capture_signal: Callable[[signal.Signals], ContextManager[None]]
): # pragma: py-win32
"""Test interrupting a Server that is run explicitly inside asyncio"""

async def interrupt_running(srv: Server):
while not srv.started:
await asyncio.sleep(0.01)
signal.raise_signal(exception_signal)

server = Server(Config(app=dummy_app, loop="asyncio"))
asyncio.create_task(interrupt_running(server))
with capture_signal(exception_signal) as witness:
await server.serve()
assert witness
# set by the server's graceful exit handler
assert server.should_exit
39 changes: 25 additions & 14 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import os
import platform
Expand All @@ -11,7 +12,7 @@
import time
from email.utils import formatdate
from types import FrameType
from typing import TYPE_CHECKING, Sequence, Union
from typing import TYPE_CHECKING, Generator, Sequence, Union

import click

Expand Down Expand Up @@ -57,11 +58,17 @@ def __init__(self, config: Config) -> None:
self.force_exit = False
self.last_notified = 0.0

self._captured_signals: list[int] = []

def run(self, sockets: list[socket.socket] | None = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))

async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():
await self._serve(sockets)

async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
process_id = os.getpid()

config = self.config
Expand All @@ -70,8 +77,6 @@ async def serve(self, sockets: list[socket.socket] | None = None) -> None:

self.lifespan = config.lifespan_class(config)

self.install_signal_handlers()

message = "Started server process [%d]"
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})
Expand Down Expand Up @@ -302,22 +307,28 @@ async def _wait_tasks_to_complete(self) -> None:
for server in self.servers:
await server.wait_closed()

def install_signal_handlers(self) -> None:
@contextlib.contextmanager
def capture_signals(self) -> Generator[None, None, None]:
# Signals can only be listened to from the main thread.
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
yield
return

loop = asyncio.get_event_loop()

# always use signal.signal, even if loop.add_signal_handler is available
# this allows to restore previous signal handlers later on
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError: # pragma: no cover
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
yield
finally:
for sig, handler in original_handlers.items():
signal.signal(sig, handler)
# If we did gracefully shut down due to a signal, try to
# trigger the expected behaviour now; multiple signals would be
# done LIFO, see https://stackoverflow.com/questions/48434964
for captured_signal in reversed(self._captured_signals):
signal.raise_signal(captured_signal)

def handle_exit(self, sig: int, frame: FrameType | None) -> None:
self._captured_signals.append(sig)
if self.should_exit and sig == signal.SIGINT:
self.force_exit = True
else:
Expand Down

0 comments on commit 9e32e8e

Please sign in to comment.