Skip to content

Commit

Permalink
feat: Add retry with context
Browse files Browse the repository at this point in the history
Signed-off-by: Xuanwo <github@xuanwo.io>
  • Loading branch information
Xuanwo committed Mar 8, 2024
1 parent 8a6bd0a commit 60fe7a9
Show file tree
Hide file tree
Showing 2 changed files with 373 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,6 @@ pub use retry::Retryable;
mod blocking_retry;
pub use blocking_retry::BlockingRetry;
pub use blocking_retry::BlockingRetryable;

mod retry_with_context;
pub use retry_with_context::RetryableWithContext;
370 changes: 370 additions & 0 deletions src/retry_with_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
use std::future::Future;
use std::pin::{pin, Pin};
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use futures_core::ready;
use pin_project::pin_project;

use crate::backoff::BackoffBuilder;
use crate::Backoff;

/// RetryableWithContext will add retry support for functions that produces a futures with results
/// and context.
///
/// That means all types that implement `FnMut(Ctx) -> impl Future<Output = (Ctx, Result<T, E>)>`
/// will be able to use `retry`.
///
/// This will allow users to pass a context to the function and return it back while retry finish.
///
/// # Example
///
/// Without context, we could meet errors like the following:
///
/// ```shell
/// error: captured variable cannot escape `FnMut` closure body
/// --> src/retry.rs:404:27
/// |
/// 400 | let mut test = Test;
/// | -------- variable defined here
/// ...
/// 404 | let result = { || async { test.hello().await } }
/// | - ^^^^^^^^----^^^^^^^^^^^^^^^^
/// | | | |
/// | | | variable captured here
/// | | returns an `async` block that contains a reference to a captured variable, which then escapes the closure body
/// | inferred to be a `FnMut` closure
/// |
/// = note: `FnMut` closures only have access to their captured variables while they are executing...
/// = note: ...therefore, they cannot allow references to captured variables to escape
/// ```
///
/// But with context support, we can implement in this way:
///
/// ```no_run
/// use anyhow::anyhow;
/// use anyhow::Result;
/// use backon::ExponentialBuilder;
/// use backon::RetryableWithContext;
///
/// struct Test;
///
/// impl Test {
/// async fn hello(&mut self) -> Result<usize> {
/// Err(anyhow!("not retryable"))
/// }
/// }
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() -> Result<()> {
/// let mut test = Test;
///
/// // (Test, Result<usize>)
/// let (_, result) = {
/// |mut v: Test| async {
/// let res = v.hello().await;
/// (v, res)
/// }
/// }
/// .retry(&ExponentialBuilder::default())
/// .context(test)
/// .await;
///
/// Ok(())
/// }
/// ```
pub trait RetryableWithContext<
B: BackoffBuilder,
T,
E,
Ctx,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
>
{
/// Generate a new retry
fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn>;
}

impl<B, T, E, Ctx, Fut, FutureFn> RetryableWithContext<B, T, E, Ctx, Fut, FutureFn> for FutureFn
where
B: BackoffBuilder,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
{
fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn> {
Retry::new(self, builder.build())
}
}

/// Retry struct generated by [`Retryable`].
#[pin_project]
pub struct Retry<
B: Backoff,
T,
E,
Ctx,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
RF = fn(&E) -> bool,
NF = fn(&E, Duration),
> {
backoff: B,
retryable: RF,
notify: NF,
future_fn: FutureFn,

#[pin]
state: State<T, E, Ctx, Fut>,
}

impl<B, T, E, Ctx, Fut, FutureFn> Retry<B, T, E, Ctx, Fut, FutureFn>
where
B: Backoff,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
{
/// Create a new retry.
///
/// # Notes
///
/// `context` must be set by `context` method before calling `await`.
fn new(future_fn: FutureFn, backoff: B) -> Self {
Retry {
backoff,
retryable: |_: &E| true,
notify: |_: &E, _: Duration| {},
future_fn,
state: State::Idle(None),
}
}
}

impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF>
where
B: Backoff,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
RF: FnMut(&E) -> bool,
NF: FnMut(&E, Duration),
{
/// Set the context for retrying.
pub fn context(self, context: Ctx) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF> {
Retry {
backoff: self.backoff,
retryable: self.retryable,
notify: self.notify,
future_fn: self.future_fn,
state: State::Idle(Some(context)),
}
}

/// Set the conditions for retrying.
///
/// If not specified, we treat all errors as retryable.
///
/// # Examples
///
/// ```no_run
/// use anyhow::Result;
/// use backon::ExponentialBuilder;
/// use backon::Retryable;
///
/// async fn fetch() -> Result<String> {
/// Ok(reqwest::get("https://www.rust-lang.org")
/// .await?
/// .text()
/// .await?)
/// }
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() -> Result<()> {
/// let content = fetch
/// .retry(&ExponentialBuilder::default())
/// .when(|e| e.to_string() == "EOF")
/// .await?;
/// println!("fetch succeeded: {}", content);
///
/// Ok(())
/// }
/// ```
pub fn when<RN: FnMut(&E) -> bool>(
self,
retryable: RN,
) -> Retry<B, T, E, Ctx, Fut, FutureFn, RN, NF> {
Retry {
backoff: self.backoff,
retryable,
notify: self.notify,
future_fn: self.future_fn,
state: self.state,
}
}

/// Set to notify for everything retrying.
///
/// If not specified, this is a no-op.
///
/// # Examples
///
/// ```no_run
/// use std::time::Duration;
///
/// use anyhow::Result;
/// use backon::ExponentialBuilder;
/// use backon::Retryable;
///
/// async fn fetch() -> Result<String> {
/// Ok(reqwest::get("https://www.rust-lang.org")
/// .await?
/// .text()
/// .await?)
/// }
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() -> Result<()> {
/// let content = fetch
/// .retry(&ExponentialBuilder::default())
/// .notify(|err: &anyhow::Error, dur: Duration| {
/// println!("retrying error {:?} with sleeping {:?}", err, dur);
/// })
/// .await?;
/// println!("fetch succeeded: {}", content);
///
/// Ok(())
/// }
/// ```
pub fn notify<NN: FnMut(&E, Duration)>(
self,
notify: NN,
) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NN> {
Retry {
backoff: self.backoff,
retryable: self.retryable,
notify,
future_fn: self.future_fn,
state: self.state,
}
}
}

/// State maintains internal state of retry.
///
/// # Notes
///
/// `tokio::time::Sleep` is a very struct that occupy 640B, so we wrap it
/// into a `Pin<Box<_>>` to avoid this enum too large.
#[pin_project(project = StateProject)]
enum State<T, E, Ctx, Fut: Future<Output = (Ctx, Result<T, E>)>> {
Idle(Option<Ctx>),
Polling(#[pin] Fut),
// TODO: we need to support other sleeper
Sleeping((Option<Ctx>, Pin<Box<tokio::time::Sleep>>)),
}

impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Future for Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF>
where
B: Backoff,
Fut: Future<Output = (Ctx, Result<T, E>)>,
FutureFn: FnMut(Ctx) -> Fut,
RF: FnMut(&E) -> bool,
NF: FnMut(&E, Duration),
{
type Output = (Ctx, Result<T, E>);

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
let state = this.state.as_mut().project();
match state {
StateProject::Idle(ctx) => {
let ctx = ctx.take().expect("context must be valid");
let fut = (this.future_fn)(ctx);
this.state.set(State::Polling(fut));
continue;
}
StateProject::Polling(fut) => {
let (ctx, res) = ready!(fut.poll(cx));
match res {
Ok(v) => return Poll::Ready((ctx, Ok(v))),
Err(err) => {
// If input error is not retryable, return error directly.
if !(this.retryable)(&err) {
return Poll::Ready((ctx, Err(err)));
}
match this.backoff.next() {
None => return Poll::Ready((ctx, Err(err))),
Some(dur) => {
(this.notify)(&err, dur);
this.state.set(State::Sleeping((
Some(ctx),
Box::pin(tokio::time::sleep(dur)),
)));
continue;
}
}
}
}
}
StateProject::Sleeping((ctx, sl)) => {
ready!(pin!(sl).poll(cx));
let ctx = ctx.take().expect("context must be valid");
this.state.set(State::Idle(Some(ctx)));
continue;
}
}
}
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;

use anyhow::anyhow;
use tokio::sync::Mutex;

use super::*;
use crate::exponential::ExponentialBuilder;
use anyhow::Result;

struct Test;

impl Test {
async fn hello(&mut self) -> Result<usize> {
Err(anyhow!("not retryable"))
}
}

#[tokio::test]
async fn test_retry_with_not_retryable_error() -> Result<()> {
let error_times = Mutex::new(0);

let test = Test;

let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

let (_, result) = {
|mut v: Test| async {
let mut x = error_times.lock().await;
*x += 1;

let res = v.hello().await;
(v, res)
}
}
.retry(&backoff)
.context(test)
// Only retry If error message is `retryable`
.when(|e| e.to_string() == "retryable")
.await;

assert!(result.is_err());
assert_eq!("not retryable", result.unwrap_err().to_string());
// `f` always returns error "not retryable", so it should be executed
// only once.
assert_eq!(*error_times.lock().await, 1);
Ok(())
}
}

0 comments on commit 60fe7a9

Please sign in to comment.