diff --git a/Cargo.lock b/Cargo.lock index 7f85f322..bc9b45d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3277,6 +3277,7 @@ dependencies = [ "thiserror", "tokio", "tokio-metrics", + "tokio-util", "tower", "tower-http", "tracing", @@ -3656,9 +3657,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" dependencies = [ "bytes", "futures-core", @@ -3666,7 +3667,6 @@ dependencies = [ "futures-sink", "pin-project-lite", "tokio", - "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 732de819..f7adba74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ serde_json = { version = "1.0.111" } tempfile = { version = "3" } thiserror = { version = "1.0.61" } tokio = { version = "1.35", default-features = false } +tokio-util = { version = "0.7.11" } tracing = { version = "0.1.40" } tracing-subscriber = { version = "0.3.18", features = ["smallvec", "fmt", "ansi", "std", "env-filter", "time"], default-features = false } url = { version = "2.5.0" } diff --git a/crates/synd_api/Cargo.toml b/crates/synd_api/Cargo.toml index 1177edc3..e9c942c3 100644 --- a/crates/synd_api/Cargo.toml +++ b/crates/synd_api/Cargo.toml @@ -42,8 +42,9 @@ serde = { workspace = true } serde_json = "1.0.111" supports-color = { version = "3.0.0" } thiserror = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] } tokio-metrics = { version = "0.3.1", default-features = false, features = ["rt"] } +tokio-util = { workspace = true } tower = { version = "0.4.13", default_features = false, features = ["limit", "timeout"] } tower-http = { version = "0.5.1", default_features = false, features = ["trace", "sensitive-headers", "cors", "limit"] } tracing = { workspace = true } diff --git a/crates/synd_api/src/main.rs b/crates/synd_api/src/main.rs index 63b2cb86..2f7523b1 100644 --- a/crates/synd_api/src/main.rs +++ b/crates/synd_api/src/main.rs @@ -121,7 +121,7 @@ async fn main() { Err(err) => err.exit(), }; let _guard = init_tracing(&args.o11y); - let shutdown = Shutdown::watch_signal(); + let shutdown = Shutdown::watch_signal(tokio::signal::ctrl_c(), || {}); init_file_descriptor_limit(); diff --git a/crates/synd_api/src/shutdown.rs b/crates/synd_api/src/shutdown.rs index 2e03e995..18de7a96 100644 --- a/crates/synd_api/src/shutdown.rs +++ b/crates/synd_api/src/shutdown.rs @@ -1,52 +1,139 @@ -use std::time::Duration; +use std::{future::Future, io, time::Duration}; use axum_server::Handle; -use tokio::sync::broadcast::{self, Receiver, Sender}; +use tokio_util::sync::CancellationToken; +/// `CancellationToken` wrapper pub struct Shutdown { - tx: Sender<()>, - rx: Receiver<()>, + root: CancellationToken, handle: Handle, } impl Shutdown { - pub fn watch_signal() -> Self { - let (tx, rx) = broadcast::channel(2); - let handle = Handle::new(); - - let tx2 = tx.clone(); - let handle2 = handle.clone(); + /// When the given signal Future is resolved, call the `cancel` method of the held `CancellationToken`. + pub fn watch_signal(signal: Fut, on_graceful_shutdown: F) -> Self + where + F: FnOnce() + Send + 'static, + Fut: Future> + Send + 'static, + { + // Root cancellation token which is cancelled when signal received + let root = CancellationToken::new(); + let notify = root.clone(); tokio::spawn(async move { - match tokio::signal::ctrl_c().await { - Ok(()) => tracing::info!("Received ctrl-c signal"), + match signal.await { + Ok(()) => tracing::info!("Received signal"), + Err(err) => tracing::error!("Failed to handle signal {err}"), } - // Signal graceful shutdown to axum_server - handle2.graceful_shutdown(Some(Duration::from_secs(3))); - tx2.send(()).ok(); + notify.cancel(); + }); + + // Notify graceful shutdown to axum server + let ct = root.clone(); + let handle = axum_server::Handle::new(); + let notify = handle.clone(); + tokio::spawn(async move { + ct.cancelled().await; + on_graceful_shutdown(); + tracing::info!("Notify axum handler to shutdown"); + notify.graceful_shutdown(Some(Duration::from_secs(3))); }); - Self { tx, rx, handle } + Self { root, handle } + } + + pub fn shutdown(&self) { + self.root.cancel(); } pub fn into_handle(self) -> Handle { self.handle } - pub async fn notify(mut self) { - self.rx.recv().await.ok(); + pub fn cancellation_token(&self) -> CancellationToken { + self.root.clone() } +} - pub fn shutdown(&self) { - self.handle.shutdown(); +#[cfg(test)] +mod tests { + use std::{ + io::ErrorKind, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + }; + + use futures_util::future; + + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn signal_trigger_graceful_shutdown() { + for signal_result in [Ok(()), Err(io::Error::from(ErrorKind::Other))] { + let called = Arc::new(AtomicBool::new(false)); + let called_cloned = Arc::clone(&called); + let on_graceful_shutdown = move || { + called_cloned.store(true, Ordering::Relaxed); + }; + let (tx, rx) = tokio::sync::oneshot::channel::>(); + let s = Shutdown::watch_signal( + async move { + rx.await.unwrap().ok(); + signal_result + }, + on_graceful_shutdown, + ); + let ct = s.cancellation_token(); + + // Mock signal triggered + tx.send(Ok(())).unwrap(); + + // Check cancellation token is cancelled and axum handler called + let mut ok = false; + let mut count = 0; + loop { + count += 1; + if count >= 10 { + break; + } + if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) { + ok = true; + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + assert!(ok, "cancelation does not work"); + } } -} -impl Clone for Shutdown { - fn clone(&self) -> Self { - let rx = self.tx.subscribe(); - let tx = self.tx.clone(); - let handle = self.handle.clone(); - Self { tx, rx, handle } + #[tokio::test(flavor = "multi_thread")] + async fn shutdown_trigger_graceful_shutdown() { + let called = Arc::new(AtomicBool::new(false)); + let called_cloned = Arc::clone(&called); + let on_graceful_shutdown = move || { + called_cloned.store(true, Ordering::Relaxed); + }; + let s = Shutdown::watch_signal(future::pending(), on_graceful_shutdown); + let ct = s.cancellation_token(); + + s.shutdown(); + + // Check cancellation token is cancelled and axum handler called + let mut ok = false; + let mut count = 0; + loop { + count += 1; + if count >= 10 { + break; + } + if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) { + ok = true; + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + assert!(ok, "cancelation does not work"); } } diff --git a/crates/synd_term/tests/test/helper.rs b/crates/synd_term/tests/test/helper.rs index e18109e7..54a03205 100644 --- a/crates/synd_term/tests/test/helper.rs +++ b/crates/synd_term/tests/test/helper.rs @@ -1,5 +1,6 @@ use std::{io, path::PathBuf, sync::Once, time::Duration}; +use futures_util::future; use ratatui::backend::TestBackend; use synd_api::{ args::{CacheOptions, KvsdOptions, ServeOptions, TlsOptions}, @@ -265,7 +266,7 @@ pub async fn serve_api( tokio::spawn(synd_api::serve::serve( listener, dep, - Shutdown::watch_signal(), + Shutdown::watch_signal(future::pending(), || {}), )); Ok(())