Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graceful shutdown on serve #2398

Merged
merged 11 commits into from
Dec 29, 2023
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ default-members = ["axum", "axum-*"]
# Example has been deleted, but README.md remains
exclude = ["examples/async-graphql"]
resolver = "2"

[patch.crates-io]
hyper = { git = "https://github.com/hyperium/hyper", rev = "cf68ea902749e" }
hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "64f896695c0f1" }
12 changes: 12 additions & 0 deletions axum/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,15 @@ macro_rules! all_the_tuples {
$name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], T16);
};
}

#[cfg(feature = "tracing")]
macro_rules! trace {
($($tt:tt)*) => {
tracing::trace!($($tt)*)
};
}

#[cfg(not(feature = "tracing"))]
macro_rules! trace {
($($tt:tt)*) => {};
}
172 changes: 166 additions & 6 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@

use std::{
convert::Infallible,
fmt::Debug,
future::{Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll},
};

use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::future::poll_fn;
use futures_util::{future::poll_fn, FutureExt};
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use pin_project_lite::pin_project;
use tokio::net::{TcpListener, TcpStream};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::util::{Oneshot, ServiceExt};
use tower_service::Service;

Expand Down Expand Up @@ -109,9 +114,25 @@ pub struct Serve<M, S> {
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> std::fmt::Debug for Serve<M, S>
impl<M, S> Serve<M, S> {
/// TODO
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
tcp_listener: self.tcp_listener,
make_service: self.make_service,
signal,
_marker: PhantomData,
}
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
M: std::fmt::Debug,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
Expand Down Expand Up @@ -166,7 +187,7 @@ where
service: tower_service,
};

tokio::task::spawn(async move {
tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
Expand All @@ -187,6 +208,145 @@ where
}
}

/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener,
make_service: M,
signal: F,
_marker: PhantomData<S>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
M: Debug,
S: Debug,
F: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
make_service,
signal,
_marker: _,
} = self;

f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.field("signal", signal)
.finish()
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut make_service,
signal,
_marker: _,
} = self;

let (signal_tx, signal_rx) = watch::channel(());
let signal_tx = Arc::new(signal_tx);
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
drop(signal_rx);
});

let (close_tx, close_rx) = watch::channel(());

private::ServeFuture(Box::pin(async move {
loop {
let (tcp_stream, remote_addr) = tokio::select! {
result = tcp_listener.accept() => {
result?
}
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};
let tcp_stream = TokioIo::new(tcp_stream);

trace!("connection {remote_addr} accepted");

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});

let hyper_service = TowerToHyperService {
service: tower_service,
};

let signal_tx = Arc::clone(&signal_tx);

let close_rx = close_rx.clone();

tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
let mut conn = pin!(conn);

let mut signal_closed = pin!(signal_tx.closed().fuse());

loop {
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
}

trace!("connection {remote_addr} closed");

drop(close_rx);
});
}

drop(close_rx);
drop(tcp_listener);

trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;

Ok(())
}))
}
}

mod private {
use std::{
future::Future,
Expand Down
2 changes: 1 addition & 1 deletion examples/graceful-shutdown/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"
publish = false

[dependencies]
axum = { path = "../../axum" }
axum = { path = "../../axum", features = ["tracing"] }
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
hyper = { version = "1.0", features = [] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
Expand Down
113 changes: 10 additions & 103 deletions examples/graceful-shutdown/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,24 @@

use std::time::Duration;

use axum::{extract::Request, routing::get, Router};
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
use axum::{routing::get, Router};
use tokio::net::TcpListener;
use tokio::signal;
use tokio::sync::watch;
use tokio::time::sleep;
use tower::Service;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing::debug;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
// Enable tracing.
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_graceful_shutdown=debug,tower_http=debug".into()),
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
"example_graceful_shutdown=debug,tower_http=debug,axum=trace".into()
}),
)
.with(tracing_subscriber::fmt::layer())
.with(tracing_subscriber::fmt::layer().without_time())
.init();

// Create a regular axum app.
Expand All @@ -48,100 +44,11 @@ async fn main() {
// Create a `TcpListener` using tokio.
let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();

// Create a watch channel to track tasks that are handling connections and wait for them to
// complete.
let (close_tx, close_rx) = watch::channel(());

// Continuously accept new connections.
loop {
let (socket, remote_addr) = tokio::select! {
// Either accept a new connection...
result = listener.accept() => {
result.unwrap()
}
// ...or wait to receive a shutdown signal and stop the accept loop.
_ = shutdown_signal() => {
debug!("signal received, not accepting new connections");
break;
}
};

debug!("connection {remote_addr} accepted");

// We don't need to call `poll_ready` because `Router` is always ready.
let tower_service = app.clone();

// Clone the watch receiver and move it into the task.
let close_rx = close_rx.clone();

// Spawn a task to handle the connection. That way we can serve multiple connections
// concurrently.
tokio::spawn(async move {
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
// `TokioIo` converts between them.
let socket = TokioIo::new(socket);

// Hyper also has its own `Service` trait and doesn't use tower. We can use
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
// `tower::Service::call`.
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
// tower's `Service` requires `&mut self`.
//
// We don't need to call `poll_ready` since `Router` is always ready.
tower_service.clone().call(request)
});

// `hyper_util::server::conn::auto::Builder` supports both http1 and http2 but doesn't
// support graceful so we have to use hyper directly and unfortunately pick between
// http1 and http2.
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(socket, hyper_service)
// `with_upgrades` is required for websockets.
.with_upgrades();

// `graceful_shutdown` requires a pinned connection.
let mut conn = std::pin::pin!(conn);

loop {
tokio::select! {
// Poll the connection. This completes when the client has closed the
// connection, graceful shutdown has completed, or we encounter a TCP error.
result = conn.as_mut() => {
if let Err(err) = result {
debug!("failed to serve connection: {err:#}");
}
break;
}
// Start graceful shutdown when we receive a shutdown signal.
//
// We use a loop to continue polling the connection to allow requests to finish
// after starting graceful shutdown. Our `Router` has `TimeoutLayer` so
// requests will finish after at most 10 seconds.
_ = shutdown_signal() => {
debug!("signal received, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
}

debug!("connection {remote_addr} closed");

// Drop the watch receiver to signal to `main` that this task is done.
drop(close_rx);
});
}

// We only care about the watch receivers that were moved into the tasks so close the residual
// receiver.
drop(close_rx);

// Close the listener to stop accepting new connections.
drop(listener);

// Wait for all tasks to complete.
debug!("waiting for {} tasks to finish", close_tx.receiver_count());
close_tx.closed().await;
// Run the server with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
}

async fn shutdown_signal() {
Expand Down