diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml
index 741539891597..cb96f9369951 100644
--- a/object_store/Cargo.toml
+++ b/object_store/Cargo.toml
@@ -48,6 +48,7 @@ quick-xml = { version = "0.23.0", features = ["serialize"], optional = true }
rustls-pemfile = { version = "1.0", default-features = false, optional = true }
ring = { version = "0.16", default-features = false, features = ["std"] }
base64 = { version = "0.13", default-features = false, optional = true }
+rand = { version = "0.8", default-features = false, optional = true, features = ["std", "std_rng"] }
# for rusoto
hyper = { version = "0.14", optional = true, default-features = false }
# for rusoto
@@ -58,7 +59,7 @@ percent-encoding = "2.1"
rusoto_core = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
rusoto_credential = { version = "0.48.0", optional = true, default-features = false }
rusoto_s3 = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
-rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
+rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
snafu = "0.7"
tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] }
tracing = { version = "0.1" }
@@ -71,7 +72,7 @@ walkdir = "2"
[features]
azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest"]
azure_test = ["azure", "azure_core/azurite_workaround", "azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"]
-gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"]
+gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64", "rand"]
aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper", "hyper-rustls"]
[dev-dependencies] # In alphabetical order
diff --git a/object_store/src/client/backoff.rs b/object_store/src/client/backoff.rs
new file mode 100644
index 000000000000..5a6126cc45c6
--- /dev/null
+++ b/object_store/src/client/backoff.rs
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use rand::prelude::*;
+use std::time::Duration;
+
+/// Exponential backoff with jitter
+///
+/// See <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>
+#[allow(missing_copy_implementations)]
+#[derive(Debug, Clone)]
+pub struct BackoffConfig {
+ /// The initial backoff duration
+ pub init_backoff: Duration,
+ /// The maximum backoff duration
+ pub max_backoff: Duration,
+ /// The base of the exponential to use
+ pub base: f64,
+}
+
+impl Default for BackoffConfig {
+ fn default() -> Self {
+ Self {
+ init_backoff: Duration::from_millis(100),
+ max_backoff: Duration::from_secs(15),
+ base: 2.,
+ }
+ }
+}
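+
+// Illustrative example (not part of this patch): a more aggressive schedule than
+// the default, starting at 10ms and capping at 1s.
+//
+//     let config = BackoffConfig {
+//         init_backoff: Duration::from_millis(10),
+//         max_backoff: Duration::from_secs(1),
+//         base: 2.,
+//     };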
+
+/// [`Backoff`] can be created from a [`BackoffConfig`]
+///
+/// Consecutive calls to [`Backoff::next`] will return the next backoff interval
+///
+pub struct Backoff {
+ init_backoff: f64,
+ next_backoff_secs: f64,
+ max_backoff_secs: f64,
+ base: f64,
+ rng: Option<Box<dyn RngCore>>,
+}
+
+impl std::fmt::Debug for Backoff {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Backoff")
+ .field("init_backoff", &self.init_backoff)
+ .field("next_backoff_secs", &self.next_backoff_secs)
+ .field("max_backoff_secs", &self.max_backoff_secs)
+ .field("base", &self.base)
+ .finish()
+ }
+}
+
+impl Backoff {
+ /// Create a new [`Backoff`] from the provided [`BackoffConfig`]
+ pub fn new(config: &BackoffConfig) -> Self {
+ Self::new_with_rng(config, None)
+ }
+
+ /// Creates a new `Backoff` with the optional `rng`
+ ///
+ /// Uses [`rand::thread_rng()`] if no rng is provided
+ pub fn new_with_rng(
+ config: &BackoffConfig,
+ rng: Option<Box<dyn RngCore>>,
+ ) -> Self {
+ let init_backoff = config.init_backoff.as_secs_f64();
+ Self {
+ init_backoff,
+ next_backoff_secs: init_backoff,
+ max_backoff_secs: config.max_backoff.as_secs_f64(),
+ base: config.base,
+ rng,
+ }
+ }
+
+ /// Returns the next backoff duration to wait for
+ pub fn next(&mut self) -> Duration {
+ let range = self.init_backoff..(self.next_backoff_secs * self.base);
+
+ let rand_backoff = match self.rng.as_mut() {
+ Some(rng) => rng.gen_range(range),
+ None => thread_rng().gen_range(range),
+ };
+
+ let next_backoff = self.max_backoff_secs.min(rand_backoff);
+ Duration::from_secs_f64(std::mem::replace(
+ &mut self.next_backoff_secs,
+ next_backoff,
+ ))
+ }
+}
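+
+// Illustrative usage sketch: drive a retry loop by sleeping for each interval.
+// `try_request()` is a hypothetical fallible async operation; the caller decides
+// what counts as retryable.
+//
+//     let mut backoff = Backoff::new(&BackoffConfig::default());
+//     loop {
+//         match try_request().await {
+//             Ok(r) => break r,
+//             Err(_) => tokio::time::sleep(backoff.next()).await,
+//         }
+//     }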
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::rngs::mock::StepRng;
+
+ #[test]
+ fn test_backoff() {
+ let init_backoff_secs = 1.;
+ let max_backoff_secs = 500.;
+ let base = 3.;
+
+ let config = BackoffConfig {
+ init_backoff: Duration::from_secs_f64(init_backoff_secs),
+ max_backoff: Duration::from_secs_f64(max_backoff_secs),
+ base,
+ };
+
+ let assert_fuzzy_eq =
+ |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{} != {}", a, b);
+
+ // Create a static rng that takes the minimum of the range
+ let rng = Box::new(StepRng::new(0, 0));
+ let mut backoff = Backoff::new_with_rng(&config, Some(rng));
+
+ for _ in 0..20 {
+ assert_eq!(backoff.next().as_secs_f64(), init_backoff_secs);
+ }
+
+ // Create a static rng that takes the maximum of the range
+ let rng = Box::new(StepRng::new(u64::MAX, 0));
+ let mut backoff = Backoff::new_with_rng(&config, Some(rng));
+
+ for i in 0..20 {
+ let value = (base.powi(i) * init_backoff_secs).min(max_backoff_secs);
+ assert_fuzzy_eq(backoff.next().as_secs_f64(), value);
+ }
+
+ // Create a static rng that takes the mid point of the range
+ let rng = Box::new(StepRng::new(u64::MAX / 2, 0));
+ let mut backoff = Backoff::new_with_rng(&config, Some(rng));
+
+ let mut value = init_backoff_secs;
+ for _ in 0..20 {
+ assert_fuzzy_eq(backoff.next().as_secs_f64(), value);
+ value = (init_backoff_secs + (value * base - init_backoff_secs) / 2.)
+ .min(max_backoff_secs);
+ }
+ }
+}
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
new file mode 100644
index 000000000000..1166ebe7a525
--- /dev/null
+++ b/object_store/src/client/mod.rs
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Generic utilities for reqwest-based ObjectStore implementations
+
+pub mod backoff;
+pub mod oauth;
+pub mod retry;
+pub mod token;
diff --git a/object_store/src/oauth.rs b/object_store/src/client/oauth.rs
similarity index 96%
rename from object_store/src/oauth.rs
rename to object_store/src/client/oauth.rs
index 273e37b64922..88e7a7b0f9e8 100644
--- a/object_store/src/oauth.rs
+++ b/object_store/src/client/oauth.rs
@@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use crate::token::TemporaryToken;
+use crate::client::retry::RetryExt;
+use crate::client::token::TemporaryToken;
+use crate::RetryConfig;
use reqwest::{Client, Method};
use ring::signature::RsaKeyPair;
use snafu::{ResultExt, Snafu};
@@ -133,7 +135,11 @@ impl OAuthProvider {
}
/// Fetch a fresh token
- pub async fn fetch_token(&self, client: &Client) -> Result<TemporaryToken<String>> {
+ pub async fn fetch_token(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<String>> {
let now = seconds_since_epoch();
let exp = now + 3600;
@@ -168,7 +174,7 @@ impl OAuthProvider {
let response: TokenResponse = client
.request(Method::POST, &self.audience)
.form(&body)
- .send()
+ .send_retry(retry)
.await
.context(TokenRequestSnafu)?
.error_for_status()
diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs
new file mode 100644
index 000000000000..c4dd6ee934cb
--- /dev/null
+++ b/object_store/src/client/retry.rs
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A shared HTTP client implementation incorporating retries
+
+use crate::client::backoff::{Backoff, BackoffConfig};
+use futures::future::BoxFuture;
+use futures::FutureExt;
+use reqwest::{Response, Result};
+use std::time::{Duration, Instant};
+use tracing::info;
+
+/// Contains the configuration for how to respond to server errors
+///
+/// By default they will be retried up to some limit, using exponential
+/// backoff with jitter. See [`BackoffConfig`] for more information
+///
+#[derive(Debug, Clone)]
+pub struct RetryConfig {
+ /// The backoff configuration
+ pub backoff: BackoffConfig,
+
+ /// The maximum number of times to retry a request
+ ///
+ /// Set to 0 to disable retries
+ pub max_retries: usize,
+
+ /// The maximum length of time from the initial request
+ /// after which no further retries will be attempted
+ ///
+ /// This not only bounds the length of time before a server
+ /// error will be surfaced to the application, but also bounds
+ /// the length of time a request's credentials must remain valid.
+ ///
+ /// As requests are retried without renewing credentials or
+ /// regenerating request payloads, this number should be kept
+ /// below 5 minutes to avoid errors due to expired credentials
+ /// and/or request payloads
+ pub retry_timeout: Duration,
+}
+
+impl Default for RetryConfig {
+ fn default() -> Self {
+ Self {
+ backoff: Default::default(),
+ max_retries: 10,
+ retry_timeout: Duration::from_secs(3 * 60),
+ }
+ }
+}
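+
+// Illustrative example: override individual fields while keeping the remaining
+// defaults, using struct update syntax.
+//
+//     let retry = RetryConfig {
+//         max_retries: 3,
+//         retry_timeout: Duration::from_secs(60),
+//         ..Default::default()
+//     };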
+
+pub trait RetryExt {
+ /// Dispatch a request with the given retry configuration
+ ///
+ /// # Panic
+ ///
+ /// This will panic if the request body is a stream
+ fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>>;
+}
+
+impl RetryExt for reqwest::RequestBuilder {
+ fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>> {
+ let mut backoff = Backoff::new(&config.backoff);
+ let max_retries = config.max_retries;
+ let retry_timeout = config.retry_timeout;
+
+ async move {
+ let mut retries = 0;
+ let now = Instant::now();
+
+ loop {
+ let s = self.try_clone().expect("request body must be cloneable");
+ match s.send().await {
+ Err(e)
+ if retries < max_retries
+ && now.elapsed() < retry_timeout
+ && e.status()
+ .map(|s| s.is_server_error())
+ .unwrap_or(false) =>
+ {
+ let sleep = backoff.next();
+ retries += 1;
+ info!("Encountered server error, backing off for {} seconds, retry {} of {}", sleep.as_secs_f32(), retries, max_retries);
+ tokio::time::sleep(sleep).await;
+ }
+ r => return r,
+ }
+ }
+ }
+ .boxed()
+ }
+}
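+
+// Illustrative usage: any cloneable `reqwest::RequestBuilder` can opt in to
+// retries. (`client` and `url` are assumed to exist in the caller's scope.)
+//
+//     let response = client
+//         .get(url)
+//         .send_retry(&RetryConfig::default())
+//         .await?;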
diff --git a/object_store/src/token.rs b/object_store/src/client/token.rs
similarity index 100%
rename from object_store/src/token.rs
rename to object_store/src/client/token.rs
diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs
index dea8769a736b..1c33c52cbf3c 100644
--- a/object_store/src/gcp.rs
+++ b/object_store/src/gcp.rs
@@ -46,14 +46,13 @@ use reqwest::{header, Client, Method, Response, StatusCode};
use snafu::{ResultExt, Snafu};
use tokio::io::AsyncWrite;
-use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
-use crate::util::format_http_range;
+use crate::client::retry::RetryExt;
use crate::{
- oauth::OAuthProvider,
+ client::{oauth::OAuthProvider, token::TokenCache},
+ multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
path::{Path, DELIMITER},
- token::TokenCache,
- util::format_prefix,
- GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
+ util::{format_http_range, format_prefix},
+ GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig,
};
#[derive(Debug, Snafu)]
@@ -215,6 +214,8 @@ struct GoogleCloudStorageClient {
bucket_name: String,
bucket_name_encoded: String,
+ retry_config: RetryConfig,
+
// TODO: Hook this up in tests
max_list_results: Option<String>,
}
@@ -224,7 +225,9 @@ impl GoogleCloudStorageClient {
if let Some(oauth_provider) = &self.oauth_provider {
Ok(self
.token_cache
- .get_or_insert_with(|| oauth_provider.fetch_token(&self.client))
+ .get_or_insert_with(|| {
+ oauth_provider.fetch_token(&self.client, &self.retry_config)
+ })
.await?)
} else {
Ok("".to_owned())
@@ -264,7 +267,7 @@ impl GoogleCloudStorageClient {
let response = builder
.bearer_auth(token)
.query(&[("alt", alt)])
- .send()
+ .send_retry(&self.retry_config)
.await
.context(GetRequestSnafu {
path: path.as_ref(),
@@ -292,7 +295,7 @@ impl GoogleCloudStorageClient {
.header(header::CONTENT_LENGTH, payload.len())
.query(&[("uploadType", "media"), ("name", path.as_ref())])
.body(payload)
- .send()
+ .send_retry(&self.retry_config)
.await
.context(PutRequestSnafu)?
.error_for_status()
@@ -313,7 +316,7 @@ impl GoogleCloudStorageClient {
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, "0")
.query(&[("uploads", "")])
- .send()
+ .send_retry(&self.retry_config)
.await
.context(PutRequestSnafu)?
.error_for_status()
@@ -347,7 +350,7 @@ impl GoogleCloudStorageClient {
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, "0")
.query(&[("uploadId", multipart_id)])
- .send()
+ .send_retry(&self.retry_config)
.await
.context(PutRequestSnafu)?
.error_for_status()
@@ -364,7 +367,7 @@ impl GoogleCloudStorageClient {
let builder = self.client.request(Method::DELETE, url);
builder
.bearer_auth(token)
- .send()
+ .send_retry(&self.retry_config)
.await
.context(DeleteRequestSnafu {
path: path.as_ref(),
@@ -407,7 +410,7 @@ impl GoogleCloudStorageClient {
builder
.bearer_auth(token)
- .send()
+ .send_retry(&self.retry_config)
.await
.context(CopyRequestSnafu {
path: from.as_ref(),
@@ -456,7 +459,7 @@ impl GoogleCloudStorageClient {
.request(Method::GET, url)
.query(&query)
.bearer_auth(token)
- .send()
+ .send_retry(&self.retry_config)
.await
.context(ListRequestSnafu)?
.error_for_status()
@@ -572,7 +575,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, format!("{}", buf.len()))
.body(buf)
- .send()
+ .send_retry(&client.retry_config)
.await
.map_err(reqwest_error_as_io)?
.error_for_status()
@@ -643,7 +646,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
.bearer_auth(token)
.query(&[("uploadId", upload_id)])
.body(data)
- .send()
+ .send_retry(&client.retry_config)
.await
.map_err(reqwest_error_as_io)?
.error_for_status()
@@ -802,6 +805,7 @@ pub struct GoogleCloudStorageBuilder {
bucket_name: Option<String>,
service_account_path: Option<String>,
client: Option<Client>,
+ retry_config: RetryConfig,
}
impl GoogleCloudStorageBuilder {
@@ -837,6 +841,12 @@ impl GoogleCloudStorageBuilder {
self
}
+ /// Set the retry configuration
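+ ///
+ /// For example (illustrative): `builder.with_retry(RetryConfig::default())`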
+ pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
+ self.retry_config = retry_config;
+ self
+ }
+
/// Use the specified http [`Client`] (defaults to [`Client::new`])
///
/// This allows you to set custom client options such as allowing
@@ -858,6 +868,7 @@ impl GoogleCloudStorageBuilder {
bucket_name,
service_account_path,
client,
+ retry_config,
} = self;
let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?;
@@ -896,6 +907,7 @@ impl GoogleCloudStorageBuilder {
token_cache: Default::default(),
bucket_name,
bucket_name_encoded: encoded_bucket_name,
+ retry_config,
max_list_results: None,
}),
})
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index 54d28273fa97..caa3f1d3e926 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -51,10 +51,10 @@ pub mod path;
pub mod throttle;
#[cfg(feature = "gcp")]
-mod oauth;
+mod client;
#[cfg(feature = "gcp")]
-mod token;
+pub use client::{backoff::BackoffConfig, retry::RetryConfig};
#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
mod multipart;
@@ -336,7 +336,7 @@ pub enum Error {
#[cfg(feature = "gcp")]
#[snafu(display("OAuth error: {}", source), context(false))]
- OAuth { source: oauth::Error },
+ OAuth { source: client::oauth::Error },
}
#[cfg(test)]