Skip to content

Commit

Permalink
refactor(api): use tokio_util CancellationToken instead of mpsc channel
Browse files Browse the repository at this point in the history
  • Loading branch information
ymgyt committed Jun 10, 2024
1 parent 8c1b174 commit 3ca15bf
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 33 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ serde_json = { version = "1.0.111" }
tempfile = { version = "3" }
thiserror = { version = "1.0.61" }
tokio = { version = "1.35", default-features = false }
tokio-util = { version = "0.7.11" }
tracing = { version = "0.1.40" }
tracing-subscriber = { version = "0.3.18", features = ["smallvec", "fmt", "ansi", "std", "env-filter", "time"], default-features = false }
url = { version = "2.5.0" }
Expand Down
3 changes: 2 additions & 1 deletion crates/synd_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ serde = { workspace = true }
serde_json = "1.0.111"
supports-color = { version = "3.0.0" }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }
tokio-metrics = { version = "0.3.1", default-features = false, features = ["rt"] }
tokio-util = { workspace = true }
tower = { version = "0.4.13", default_features = false, features = ["limit", "timeout"] }
tower-http = { version = "0.5.1", default_features = false, features = ["trace", "sensitive-headers", "cors", "limit"] }
tracing = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion crates/synd_api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async fn main() {
Err(err) => err.exit(),
};
let _guard = init_tracing(&args.o11y);
let shutdown = Shutdown::watch_signal();
let shutdown = Shutdown::watch_signal(tokio::signal::ctrl_c(), || {});

init_file_descriptor_limit();

Expand Down
141 changes: 114 additions & 27 deletions crates/synd_api/src/shutdown.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,139 @@
use std::time::Duration;
use std::{future::Future, io, time::Duration};

use axum_server::Handle;
use tokio::sync::broadcast::{self, Receiver, Sender};
use tokio_util::sync::CancellationToken;

/// `CancellationToken` wrapper
pub struct Shutdown {
tx: Sender<()>,
rx: Receiver<()>,
root: CancellationToken,
handle: Handle,
}

impl Shutdown {
pub fn watch_signal() -> Self {
let (tx, rx) = broadcast::channel(2);
let handle = Handle::new();

let tx2 = tx.clone();
let handle2 = handle.clone();
/// When the given signal Future is resolved, call the `cancel` method of the held `CancellationToken`.
pub fn watch_signal<Fut, F>(signal: Fut, on_graceful_shutdown: F) -> Self
where
F: FnOnce() + Send + 'static,
Fut: Future<Output = io::Result<()>> + Send + 'static,
{
// Root cancellation token which is cancelled when signal received
let root = CancellationToken::new();
let notify = root.clone();
tokio::spawn(async move {
match tokio::signal::ctrl_c().await {
Ok(()) => tracing::info!("Received ctrl-c signal"),
match signal.await {
Ok(()) => tracing::info!("Received signal"),

Err(err) => tracing::error!("Failed to handle signal {err}"),
}
// Signal graceful shutdown to axum_server
handle2.graceful_shutdown(Some(Duration::from_secs(3)));
tx2.send(()).ok();
notify.cancel();
});

// Notify graceful shutdown to axum server
let ct = root.clone();
let handle = axum_server::Handle::new();
let notify = handle.clone();
tokio::spawn(async move {
ct.cancelled().await;
on_graceful_shutdown();
tracing::info!("Notify axum handler to shutdown");
notify.graceful_shutdown(Some(Duration::from_secs(3)));
});

Self { tx, rx, handle }
Self { root, handle }
}

pub fn shutdown(&self) {
self.root.cancel();
}

pub fn into_handle(self) -> Handle {
self.handle
}

pub async fn notify(mut self) {
self.rx.recv().await.ok();
pub fn cancellation_token(&self) -> CancellationToken {
self.root.clone()
}
}

pub fn shutdown(&self) {
self.handle.shutdown();
#[cfg(test)]
mod tests {
use std::{
io::ErrorKind,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};

use futures_util::future;

use super::*;

#[tokio::test(flavor = "multi_thread")]
async fn signal_trigger_graceful_shutdown() {
for signal_result in [Ok(()), Err(io::Error::from(ErrorKind::Other))] {
let called = Arc::new(AtomicBool::new(false));
let called_cloned = Arc::clone(&called);
let on_graceful_shutdown = move || {
called_cloned.store(true, Ordering::Relaxed);
};
let (tx, rx) = tokio::sync::oneshot::channel::<io::Result<()>>();
let s = Shutdown::watch_signal(
async move {
rx.await.unwrap().ok();
signal_result
},
on_graceful_shutdown,
);
let ct = s.cancellation_token();

// Mock signal triggered
tx.send(Ok(())).unwrap();

// Check cancellation token is cancelled and axum handler called
let mut ok = false;
let mut count = 0;
loop {
count += 1;
if count >= 10 {
break;
}
if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) {
ok = true;
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
assert!(ok, "cancelation does not work");
}
}
}

impl Clone for Shutdown {
fn clone(&self) -> Self {
let rx = self.tx.subscribe();
let tx = self.tx.clone();
let handle = self.handle.clone();
Self { tx, rx, handle }
#[tokio::test(flavor = "multi_thread")]
async fn shutdown_trigger_graceful_shutdown() {
let called = Arc::new(AtomicBool::new(false));
let called_cloned = Arc::clone(&called);
let on_graceful_shutdown = move || {
called_cloned.store(true, Ordering::Relaxed);
};
let s = Shutdown::watch_signal(future::pending(), on_graceful_shutdown);
let ct = s.cancellation_token();

s.shutdown();

// Check cancellation token is cancelled and axum handler called
let mut ok = false;
let mut count = 0;
loop {
count += 1;
if count >= 10 {
break;
}
if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) {
ok = true;
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
assert!(ok, "cancelation does not work");
}
}
3 changes: 2 additions & 1 deletion crates/synd_term/tests/test/helper.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{io, path::PathBuf, sync::Once, time::Duration};

use futures_util::future;
use ratatui::backend::TestBackend;
use synd_api::{
args::{CacheOptions, KvsdOptions, ServeOptions, TlsOptions},
Expand Down Expand Up @@ -265,7 +266,7 @@ pub async fn serve_api(
tokio::spawn(synd_api::serve::serve(
listener,
dep,
Shutdown::watch_signal(),
Shutdown::watch_signal(future::pending(), || {}),
));

Ok(())
Expand Down

0 comments on commit 3ca15bf

Please sign in to comment.