-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(api): use tokio_util CancellationToken instead of mpsc channel
- Loading branch information
Showing
6 changed files
with
123 additions
and
33 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,139 @@ | ||
use std::time::Duration; | ||
use std::{future::Future, io, time::Duration}; | ||
|
||
use axum_server::Handle; | ||
use tokio::sync::broadcast::{self, Receiver, Sender}; | ||
use tokio_util::sync::CancellationToken; | ||
|
||
/// `CancellationToken` wrapper | ||
pub struct Shutdown { | ||
tx: Sender<()>, | ||
rx: Receiver<()>, | ||
root: CancellationToken, | ||
handle: Handle, | ||
} | ||
|
||
impl Shutdown { | ||
pub fn watch_signal() -> Self { | ||
let (tx, rx) = broadcast::channel(2); | ||
let handle = Handle::new(); | ||
|
||
let tx2 = tx.clone(); | ||
let handle2 = handle.clone(); | ||
/// When the given signal Future is resolved, call the `cancel` method of the held `CancellationToken`. | ||
pub fn watch_signal<Fut, F>(signal: Fut, on_graceful_shutdown: F) -> Self | ||
where | ||
F: FnOnce() + Send + 'static, | ||
Fut: Future<Output = io::Result<()>> + Send + 'static, | ||
{ | ||
// Root cancellation token which is cancelled when signal received | ||
let root = CancellationToken::new(); | ||
let notify = root.clone(); | ||
tokio::spawn(async move { | ||
match tokio::signal::ctrl_c().await { | ||
Ok(()) => tracing::info!("Received ctrl-c signal"), | ||
match signal.await { | ||
Ok(()) => tracing::info!("Received signal"), | ||
|
||
Err(err) => tracing::error!("Failed to handle signal {err}"), | ||
} | ||
// Signal graceful shutdown to axum_server | ||
handle2.graceful_shutdown(Some(Duration::from_secs(3))); | ||
tx2.send(()).ok(); | ||
notify.cancel(); | ||
}); | ||
|
||
// Notify graceful shutdown to axum server | ||
let ct = root.clone(); | ||
let handle = axum_server::Handle::new(); | ||
let notify = handle.clone(); | ||
tokio::spawn(async move { | ||
ct.cancelled().await; | ||
on_graceful_shutdown(); | ||
tracing::info!("Notify axum handler to shutdown"); | ||
notify.graceful_shutdown(Some(Duration::from_secs(3))); | ||
}); | ||
|
||
Self { tx, rx, handle } | ||
Self { root, handle } | ||
} | ||
|
||
pub fn shutdown(&self) { | ||
self.root.cancel(); | ||
} | ||
|
||
pub fn into_handle(self) -> Handle { | ||
self.handle | ||
} | ||
|
||
pub async fn notify(mut self) { | ||
self.rx.recv().await.ok(); | ||
pub fn cancellation_token(&self) -> CancellationToken { | ||
self.root.clone() | ||
} | ||
} | ||
|
||
pub fn shutdown(&self) { | ||
self.handle.shutdown(); | ||
#[cfg(test)] | ||
mod tests { | ||
use std::{ | ||
io::ErrorKind, | ||
sync::{ | ||
atomic::{AtomicBool, Ordering}, | ||
Arc, | ||
}, | ||
}; | ||
|
||
use futures_util::future; | ||
|
||
use super::*; | ||
|
||
#[tokio::test(flavor = "multi_thread")] | ||
async fn signal_trigger_graceful_shutdown() { | ||
for signal_result in [Ok(()), Err(io::Error::from(ErrorKind::Other))] { | ||
let called = Arc::new(AtomicBool::new(false)); | ||
let called_cloned = Arc::clone(&called); | ||
let on_graceful_shutdown = move || { | ||
called_cloned.store(true, Ordering::Relaxed); | ||
}; | ||
let (tx, rx) = tokio::sync::oneshot::channel::<io::Result<()>>(); | ||
let s = Shutdown::watch_signal( | ||
async move { | ||
rx.await.unwrap().ok(); | ||
signal_result | ||
}, | ||
on_graceful_shutdown, | ||
); | ||
let ct = s.cancellation_token(); | ||
|
||
// Mock signal triggered | ||
tx.send(Ok(())).unwrap(); | ||
|
||
// Check cancellation token is cancelled and axum handler called | ||
let mut ok = false; | ||
let mut count = 0; | ||
loop { | ||
count += 1; | ||
if count >= 10 { | ||
break; | ||
} | ||
if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) { | ||
ok = true; | ||
break; | ||
} | ||
tokio::time::sleep(Duration::from_millis(100)).await; | ||
} | ||
assert!(ok, "cancelation does not work"); | ||
} | ||
} | ||
} | ||
|
||
impl Clone for Shutdown { | ||
fn clone(&self) -> Self { | ||
let rx = self.tx.subscribe(); | ||
let tx = self.tx.clone(); | ||
let handle = self.handle.clone(); | ||
Self { tx, rx, handle } | ||
#[tokio::test(flavor = "multi_thread")] | ||
async fn shutdown_trigger_graceful_shutdown() { | ||
let called = Arc::new(AtomicBool::new(false)); | ||
let called_cloned = Arc::clone(&called); | ||
let on_graceful_shutdown = move || { | ||
called_cloned.store(true, Ordering::Relaxed); | ||
}; | ||
let s = Shutdown::watch_signal(future::pending(), on_graceful_shutdown); | ||
let ct = s.cancellation_token(); | ||
|
||
s.shutdown(); | ||
|
||
// Check cancellation token is cancelled and axum handler called | ||
let mut ok = false; | ||
let mut count = 0; | ||
loop { | ||
count += 1; | ||
if count >= 10 { | ||
break; | ||
} | ||
if s.root.is_cancelled() && ct.is_cancelled() && called.load(Ordering::Relaxed) { | ||
ok = true; | ||
break; | ||
} | ||
tokio::time::sleep(Duration::from_millis(100)).await; | ||
} | ||
assert!(ok, "cancelation does not work"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters