-
-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Xuanwo <github@xuanwo.io>
- Loading branch information
Showing
2 changed files
with
373 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,370 @@ | ||
use std::future::Future; | ||
use std::pin::{pin, Pin}; | ||
use std::task::Context; | ||
use std::task::Poll; | ||
use std::time::Duration; | ||
|
||
use futures_core::ready; | ||
use pin_project::pin_project; | ||
|
||
use crate::backoff::BackoffBuilder; | ||
use crate::Backoff; | ||
|
||
/// RetryableWithContext will add retry support for functions that produces a futures with results | ||
/// and context. | ||
/// | ||
/// That means all types that implement `FnMut(Ctx) -> impl Future<Output = (Ctx, Result<T, E>)>` | ||
/// will be able to use `retry`. | ||
/// | ||
/// This will allow users to pass a context to the function and return it back while retry finish. | ||
/// | ||
/// # Example | ||
/// | ||
/// Without context, we could meet errors like the following: | ||
/// | ||
/// ```shell | ||
/// error: captured variable cannot escape `FnMut` closure body | ||
/// --> src/retry.rs:404:27 | ||
/// | | ||
/// 400 | let mut test = Test; | ||
/// | -------- variable defined here | ||
/// ... | ||
/// 404 | let result = { || async { test.hello().await } } | ||
/// | - ^^^^^^^^----^^^^^^^^^^^^^^^^ | ||
/// | | | | | ||
/// | | | variable captured here | ||
/// | | returns an `async` block that contains a reference to a captured variable, which then escapes the closure body | ||
/// | inferred to be a `FnMut` closure | ||
/// | | ||
/// = note: `FnMut` closures only have access to their captured variables while they are executing... | ||
/// = note: ...therefore, they cannot allow references to captured variables to escape | ||
/// ``` | ||
/// | ||
/// But with context support, we can implement in this way: | ||
/// | ||
/// ```no_run | ||
/// use anyhow::anyhow; | ||
/// use anyhow::Result; | ||
/// use backon::ExponentialBuilder; | ||
/// use backon::RetryableWithContext; | ||
/// | ||
/// struct Test; | ||
/// | ||
/// impl Test { | ||
/// async fn hello(&mut self) -> Result<usize> { | ||
/// Err(anyhow!("not retryable")) | ||
/// } | ||
/// } | ||
/// | ||
/// #[tokio::main(flavor = "current_thread")] | ||
/// async fn main() -> Result<()> { | ||
/// let mut test = Test; | ||
/// | ||
/// // (Test, Result<usize>) | ||
/// let (_, result) = { | ||
/// |mut v: Test| async { | ||
/// let res = v.hello().await; | ||
/// (v, res) | ||
/// } | ||
/// } | ||
/// .retry(&ExponentialBuilder::default()) | ||
/// .context(test) | ||
/// .await; | ||
/// | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub trait RetryableWithContext< | ||
B: BackoffBuilder, | ||
T, | ||
E, | ||
Ctx, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
> | ||
{ | ||
/// Generate a new retry | ||
fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn>; | ||
} | ||
|
||
impl<B, T, E, Ctx, Fut, FutureFn> RetryableWithContext<B, T, E, Ctx, Fut, FutureFn> for FutureFn | ||
where | ||
B: BackoffBuilder, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
{ | ||
fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn> { | ||
Retry::new(self, builder.build()) | ||
} | ||
} | ||
|
||
/// Retry struct generated by [`Retryable`]. | ||
#[pin_project] | ||
pub struct Retry< | ||
B: Backoff, | ||
T, | ||
E, | ||
Ctx, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
RF = fn(&E) -> bool, | ||
NF = fn(&E, Duration), | ||
> { | ||
backoff: B, | ||
retryable: RF, | ||
notify: NF, | ||
future_fn: FutureFn, | ||
|
||
#[pin] | ||
state: State<T, E, Ctx, Fut>, | ||
} | ||
|
||
impl<B, T, E, Ctx, Fut, FutureFn> Retry<B, T, E, Ctx, Fut, FutureFn> | ||
where | ||
B: Backoff, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
{ | ||
/// Create a new retry. | ||
/// | ||
/// # Notes | ||
/// | ||
/// `context` must be set by `context` method before calling `await`. | ||
fn new(future_fn: FutureFn, backoff: B) -> Self { | ||
Retry { | ||
backoff, | ||
retryable: |_: &E| true, | ||
notify: |_: &E, _: Duration| {}, | ||
future_fn, | ||
state: State::Idle(None), | ||
} | ||
} | ||
} | ||
|
||
impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF> | ||
where | ||
B: Backoff, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
RF: FnMut(&E) -> bool, | ||
NF: FnMut(&E, Duration), | ||
{ | ||
/// Set the context for retrying. | ||
pub fn context(self, context: Ctx) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF> { | ||
Retry { | ||
backoff: self.backoff, | ||
retryable: self.retryable, | ||
notify: self.notify, | ||
future_fn: self.future_fn, | ||
state: State::Idle(Some(context)), | ||
} | ||
} | ||
|
||
/// Set the conditions for retrying. | ||
/// | ||
/// If not specified, we treat all errors as retryable. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```no_run | ||
/// use anyhow::Result; | ||
/// use backon::ExponentialBuilder; | ||
/// use backon::Retryable; | ||
/// | ||
/// async fn fetch() -> Result<String> { | ||
/// Ok(reqwest::get("https://www.rust-lang.org") | ||
/// .await? | ||
/// .text() | ||
/// .await?) | ||
/// } | ||
/// | ||
/// #[tokio::main(flavor = "current_thread")] | ||
/// async fn main() -> Result<()> { | ||
/// let content = fetch | ||
/// .retry(&ExponentialBuilder::default()) | ||
/// .when(|e| e.to_string() == "EOF") | ||
/// .await?; | ||
/// println!("fetch succeeded: {}", content); | ||
/// | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub fn when<RN: FnMut(&E) -> bool>( | ||
self, | ||
retryable: RN, | ||
) -> Retry<B, T, E, Ctx, Fut, FutureFn, RN, NF> { | ||
Retry { | ||
backoff: self.backoff, | ||
retryable, | ||
notify: self.notify, | ||
future_fn: self.future_fn, | ||
state: self.state, | ||
} | ||
} | ||
|
||
/// Set to notify for everything retrying. | ||
/// | ||
/// If not specified, this is a no-op. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```no_run | ||
/// use std::time::Duration; | ||
/// | ||
/// use anyhow::Result; | ||
/// use backon::ExponentialBuilder; | ||
/// use backon::Retryable; | ||
/// | ||
/// async fn fetch() -> Result<String> { | ||
/// Ok(reqwest::get("https://www.rust-lang.org") | ||
/// .await? | ||
/// .text() | ||
/// .await?) | ||
/// } | ||
/// | ||
/// #[tokio::main(flavor = "current_thread")] | ||
/// async fn main() -> Result<()> { | ||
/// let content = fetch | ||
/// .retry(&ExponentialBuilder::default()) | ||
/// .notify(|err: &anyhow::Error, dur: Duration| { | ||
/// println!("retrying error {:?} with sleeping {:?}", err, dur); | ||
/// }) | ||
/// .await?; | ||
/// println!("fetch succeeded: {}", content); | ||
/// | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub fn notify<NN: FnMut(&E, Duration)>( | ||
self, | ||
notify: NN, | ||
) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NN> { | ||
Retry { | ||
backoff: self.backoff, | ||
retryable: self.retryable, | ||
notify, | ||
future_fn: self.future_fn, | ||
state: self.state, | ||
} | ||
} | ||
} | ||
|
||
/// State maintains internal state of retry. | ||
/// | ||
/// # Notes | ||
/// | ||
/// `tokio::time::Sleep` is a very struct that occupy 640B, so we wrap it | ||
/// into a `Pin<Box<_>>` to avoid this enum too large. | ||
#[pin_project(project = StateProject)] | ||
enum State<T, E, Ctx, Fut: Future<Output = (Ctx, Result<T, E>)>> { | ||
Idle(Option<Ctx>), | ||
Polling(#[pin] Fut), | ||
// TODO: we need to support other sleeper | ||
Sleeping((Option<Ctx>, Pin<Box<tokio::time::Sleep>>)), | ||
} | ||
|
||
impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Future for Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF> | ||
where | ||
B: Backoff, | ||
Fut: Future<Output = (Ctx, Result<T, E>)>, | ||
FutureFn: FnMut(Ctx) -> Fut, | ||
RF: FnMut(&E) -> bool, | ||
NF: FnMut(&E, Duration), | ||
{ | ||
type Output = (Ctx, Result<T, E>); | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
let mut this = self.project(); | ||
loop { | ||
let state = this.state.as_mut().project(); | ||
match state { | ||
StateProject::Idle(ctx) => { | ||
let ctx = ctx.take().expect("context must be valid"); | ||
let fut = (this.future_fn)(ctx); | ||
this.state.set(State::Polling(fut)); | ||
continue; | ||
} | ||
StateProject::Polling(fut) => { | ||
let (ctx, res) = ready!(fut.poll(cx)); | ||
match res { | ||
Ok(v) => return Poll::Ready((ctx, Ok(v))), | ||
Err(err) => { | ||
// If input error is not retryable, return error directly. | ||
if !(this.retryable)(&err) { | ||
return Poll::Ready((ctx, Err(err))); | ||
} | ||
match this.backoff.next() { | ||
None => return Poll::Ready((ctx, Err(err))), | ||
Some(dur) => { | ||
(this.notify)(&err, dur); | ||
this.state.set(State::Sleeping(( | ||
Some(ctx), | ||
Box::pin(tokio::time::sleep(dur)), | ||
))); | ||
continue; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
StateProject::Sleeping((ctx, sl)) => { | ||
ready!(pin!(sl).poll(cx)); | ||
let ctx = ctx.take().expect("context must be valid"); | ||
this.state.set(State::Idle(Some(ctx))); | ||
continue; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::time::Duration; | ||
|
||
use anyhow::anyhow; | ||
use tokio::sync::Mutex; | ||
|
||
use super::*; | ||
use crate::exponential::ExponentialBuilder; | ||
use anyhow::Result; | ||
|
||
struct Test; | ||
|
||
impl Test { | ||
async fn hello(&mut self) -> Result<usize> { | ||
Err(anyhow!("not retryable")) | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_retry_with_not_retryable_error() -> Result<()> { | ||
let error_times = Mutex::new(0); | ||
|
||
let test = Test; | ||
|
||
let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)); | ||
|
||
let (_, result) = { | ||
|mut v: Test| async { | ||
let mut x = error_times.lock().await; | ||
*x += 1; | ||
|
||
let res = v.hello().await; | ||
(v, res) | ||
} | ||
} | ||
.retry(&backoff) | ||
.context(test) | ||
// Only retry If error message is `retryable` | ||
.when(|e| e.to_string() == "retryable") | ||
.await; | ||
|
||
assert!(result.is_err()); | ||
assert_eq!("not retryable", result.unwrap_err().to_string()); | ||
// `f` always returns error "not retryable", so it should be executed | ||
// only once. | ||
assert_eq!(*error_times.lock().await, 1); | ||
Ok(()) | ||
} | ||
} |