diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 1cdf785e5f4d..8ce743b31be9 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -16,8 +16,8 @@ // under the License. use crate::aws::checksum::Checksum; -use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; -use crate::aws::{STORE, STRICT_PATH_ENCODE_SET}; +use crate::aws::credential::{AwsCredential, CredentialExt}; +use crate::aws::{AwsCredentialProvider, STORE, STRICT_PATH_ENCODE_SET}; use crate::client::list::ListResponse; use crate::client::pagination::stream_paginated; use crate::client::retry::RetryExt; @@ -135,7 +135,7 @@ pub struct S3Config { pub endpoint: String, pub bucket: String, pub bucket_endpoint: String, - pub credentials: Box, + pub credentials: AwsCredentialProvider, pub retry_config: RetryConfig, pub client_options: ClientOptions, pub sign_payload: bool, diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index 9e047941a3c2..47d681c631c7 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -18,12 +18,12 @@ use crate::aws::{STORE, STRICT_ENCODE_SET}; use crate::client::retry::RetryExt; use crate::client::token::{TemporaryToken, TokenCache}; +use crate::client::TokenProvider; use crate::util::hmac_sha256; use crate::{Result, RetryConfig}; +use async_trait::async_trait; use bytes::Buf; use chrono::{DateTime, Utc}; -use futures::future::BoxFuture; -use futures::TryFutureExt; use percent_encoding::utf8_percent_encode; use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; @@ -41,10 +41,14 @@ static EMPTY_SHA256_HASH: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; static UNSIGNED_PAYLOAD_LITERAL: &str = "UNSIGNED-PAYLOAD"; -#[derive(Debug)] +/// A set of AWS security credentials +#[derive(Debug, Eq, PartialEq)] pub struct AwsCredential { + /// AWS_ACCESS_KEY_ID pub key_id: String, + /// AWS_SECRET_ACCESS_KEY pub secret_key: String, + /// AWS_SESSION_TOKEN pub token: Option, } @@ -291,49 +295,31 @@ fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) { (signed_headers, canonical_headers) } -/// Provides credentials for use when signing requests -pub trait CredentialProvider: std::fmt::Debug + Send + Sync { - fn get_credential(&self) -> BoxFuture<'_, Result>>; -} - -/// A static set of credentials -#[derive(Debug)] -pub struct StaticCredentialProvider { - pub credential: Arc, -} - -impl CredentialProvider for StaticCredentialProvider { - fn get_credential(&self) -> BoxFuture<'_, Result>> { - Box::pin(futures::future::ready(Ok(Arc::clone(&self.credential)))) - } -} - /// Credentials sourced from the instance metadata service /// /// #[derive(Debug)] pub struct InstanceCredentialProvider { pub cache: TokenCache>, - pub client: Client, - pub retry_config: RetryConfig, pub imdsv1_fallback: bool, pub metadata_endpoint: String, } -impl CredentialProvider for InstanceCredentialProvider { - fn get_credential(&self) -> BoxFuture<'_, Result>> { - Box::pin(self.cache.get_or_insert_with(|| { - instance_creds( - &self.client, - &self.retry_config, - &self.metadata_endpoint, - self.imdsv1_fallback, - ) +#[async_trait] +impl TokenProvider for InstanceCredentialProvider { + type Credential = AwsCredential; + + async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result>> { + instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback) + .await .map_err(|source| crate::Error::Generic { store: STORE, source, }) - })) } } @@ -342,31 +328,34 @@ impl CredentialProvider for InstanceCredentialProvider { /// #[derive(Debug)] pub struct WebIdentityProvider { - pub cache: TokenCache>, pub token_path: String, pub role_arn: String, pub session_name: String, pub endpoint: String, - pub client: Client, - pub retry_config: RetryConfig, } -impl CredentialProvider for WebIdentityProvider { - fn get_credential(&self) -> BoxFuture<'_, Result>> { - Box::pin(self.cache.get_or_insert_with(|| { - web_identity( - &self.client, - &self.retry_config, - &self.token_path, - &self.role_arn, - &self.session_name, - &self.endpoint, - ) - .map_err(|source| crate::Error::Generic { - store: STORE, - source, - }) - })) +#[async_trait] +impl TokenProvider for WebIdentityProvider { + type Credential = AwsCredential; + + async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result>> { + web_identity( + client, + retry, + &self.token_path, + &self.role_arn, + &self.session_name, + &self.endpoint, + ) + .await + .map_err(|source| crate::Error::Generic { + store: STORE, + source, + }) } } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 428e013f4478..ddb9dc799501 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -48,11 +48,13 @@ use url::Url; pub use crate::aws::checksum::Checksum; use crate::aws::client::{S3Client, S3Config}; use crate::aws::credential::{ - AwsCredential, CredentialProvider, InstanceCredentialProvider, - StaticCredentialProvider, WebIdentityProvider, + AwsCredential, InstanceCredentialProvider, WebIdentityProvider, }; use crate::client::header::header_meta; -use crate::client::ClientConfigKey; +use crate::client::{ + ClientConfigKey, CredentialProvider, StaticCredentialProvider, + TokenCredentialProvider, +}; use crate::config::ConfigValue; use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; use crate::{ @@ -83,6 +85,8 @@ const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.rem const STORE: &str = "S3"; +type AwsCredentialProvider = Arc>; + /// Default metadata endpoint static METADATA_ENDPOINT: &str = "http://169.254.169.254"; @@ -1001,13 +1005,12 @@ impl AmazonS3Builder { let credentials = match (self.access_key_id, self.secret_access_key, self.token) { (Some(key_id), Some(secret_key), token) => { info!("Using Static credential provider"); - Box::new(StaticCredentialProvider { - credential: Arc::new(AwsCredential { - key_id, - secret_key, - token, - }), - }) as _ + let credential = AwsCredential { + key_id, + secret_key, + token, + }; + Arc::new(StaticCredentialProvider::new(credential)) as _ } (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), @@ -1031,15 +1034,18 @@ impl AmazonS3Builder { .with_allow_http(false) .client()?; - Box::new(WebIdentityProvider { - cache: Default::default(), + let token = WebIdentityProvider { token_path, session_name, role_arn, endpoint, + }; + + Arc::new(TokenCredentialProvider::new( + token, client, - retry_config: self.retry_config.clone(), - }) as _ + self.retry_config.clone(), + )) as _ } _ => match self.profile { Some(profile) => { @@ -1049,19 +1055,20 @@ impl AmazonS3Builder { None => { info!("Using Instance credential provider"); - // The instance metadata endpoint is access over HTTP - let client_options = - self.client_options.clone().with_allow_http(true); - - Box::new(InstanceCredentialProvider { + let token = InstanceCredentialProvider { cache: Default::default(), - client: client_options.client()?, - retry_config: self.retry_config.clone(), imdsv1_fallback: self.imdsv1_fallback.get()?, metadata_endpoint: self .metadata_endpoint .unwrap_or_else(|| METADATA_ENDPOINT.into()), - }) as _ + }; + + Arc::new(TokenCredentialProvider::new( + token, + // The instance metadata endpoint is access over HTTP + self.client_options.clone().with_allow_http(true).client()?, + self.retry_config.clone(), + )) as _ } }, }, @@ -1114,11 +1121,8 @@ fn profile_region(profile: String) -> Option { } #[cfg(feature = "aws_profile")] -fn profile_credentials( - profile: String, - region: String, -) -> Result> { - Ok(Box::new(profile::ProfileProvider::new( +fn profile_credentials(profile: String, region: String) -> Result { + Ok(Arc::new(profile::ProfileProvider::new( profile, Some(region), ))) @@ -1133,7 +1137,7 @@ fn profile_region(_profile: String) -> Option { fn profile_credentials( _profile: String, _region: String, -) -> Result> { +) -> Result { Err(Error::MissingProfileFeature.into()) } diff --git a/object_store/src/aws/profile.rs b/object_store/src/aws/profile.rs index a88824c79f93..3fc08056444e 100644 --- a/object_store/src/aws/profile.rs +++ b/object_store/src/aws/profile.rs @@ -17,6 +17,7 @@ #![cfg(feature = "aws_profile")] +use async_trait::async_trait; use aws_config::meta::region::ProvideRegion; use aws_config::profile::profile_file::ProfileFiles; use aws_config::profile::ProfileFileCredentialsProvider; @@ -24,14 +25,13 @@ use aws_config::profile::ProfileFileRegionProvider; use aws_config::provider_config::ProviderConfig; use aws_credential_types::provider::ProvideCredentials; use aws_types::region::Region; -use futures::future::BoxFuture; use std::sync::Arc; use std::time::Instant; use std::time::SystemTime; -use crate::aws::credential::CredentialProvider; use crate::aws::AwsCredential; use crate::client::token::{TemporaryToken, TokenCache}; +use crate::client::CredentialProvider; use crate::Result; #[cfg(test)] @@ -91,38 +91,43 @@ impl ProfileProvider { } } +#[async_trait] impl CredentialProvider for ProfileProvider { - fn get_credential(&self) -> BoxFuture<'_, Result>> { - Box::pin(self.cache.get_or_insert_with(move || async move { - let region = self.region.clone().map(Region::new); - - let config = ProviderConfig::default().with_region(region); - - let credentials = ProfileFileCredentialsProvider::builder() - .configure(&config) - .profile_name(&self.name) - .build(); - - let c = credentials.provide_credentials().await.map_err(|source| { - crate::Error::Generic { - store: "S3", - source: Box::new(source), - } - })?; - let t_now = SystemTime::now(); - let expiry = c - .expiry() - .and_then(|e| e.duration_since(t_now).ok()) - .map(|ttl| Instant::now() + ttl); - - Ok(TemporaryToken { - token: Arc::new(AwsCredential { - key_id: c.access_key_id().to_string(), - secret_key: c.secret_access_key().to_string(), - token: c.session_token().map(ToString::to_string), - }), - expiry, + type Credential = AwsCredential; + + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(move || async move { + let region = self.region.clone().map(Region::new); + + let config = ProviderConfig::default().with_region(region); + + let credentials = ProfileFileCredentialsProvider::builder() + .configure(&config) + .profile_name(&self.name) + .build(); + + let c = credentials.provide_credentials().await.map_err(|source| { + crate::Error::Generic { + store: "S3", + source: Box::new(source), + } + })?; + let t_now = SystemTime::now(); + let expiry = c + .expiry() + .and_then(|e| e.duration_since(t_now).ok()) + .map(|ttl| Instant::now() + ttl); + + Ok(TemporaryToken { + token: Arc::new(AwsCredential { + key_id: c.access_key_id().to_string(), + secret_key: c.secret_access_key().to_string(), + token: c.session_token().map(ToString::to_string), + }), + expiry, + }) }) - })) + .await } } diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 4611986e30d2..a8273c2ffb73 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use super::credential::{AzureCredential, CredentialProvider}; +use super::credential::AzureCredential; use crate::azure::credential::*; -use crate::azure::STORE; +use crate::azure::{AzureCredentialProvider, STORE}; use crate::client::pagination::stream_paginated; use crate::client::retry::RetryExt; use crate::client::GetOptionsExt; @@ -40,6 +40,7 @@ use reqwest::{ use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; use std::collections::HashMap; +use std::sync::Arc; use url::Url; /// A specialized `Error` for object store-related errors @@ -101,10 +102,10 @@ impl From for crate::Error { /// Configuration for [AzureClient] #[derive(Debug)] -pub struct AzureConfig { +pub(crate) struct AzureConfig { pub account: String, pub container: String, - pub credentials: CredentialProvider, + pub credentials: AzureCredentialProvider, pub retry_config: RetryConfig, pub service: Url, pub is_emulator: bool, @@ -143,45 +144,8 @@ impl AzureClient { &self.config } - async fn get_credential(&self) -> Result { - match &self.config.credentials { - CredentialProvider::AccessKey(key) => { - Ok(AzureCredential::AccessKey(key.to_owned())) - } - CredentialProvider::BearerToken(token) => { - Ok(AzureCredential::AuthorizationToken( - // we do the conversion to a HeaderValue here, since it is fallible - // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - crate::Error::Generic { - store: STORE, - source: Box::new(err), - } - })?, - )) - } - CredentialProvider::TokenCredential(cache, cred) => { - let token = cache - .get_or_insert_with(|| { - cred.fetch_token(&self.client, &self.config.retry_config) - }) - .await - .context(AuthorizationSnafu)?; - Ok(AzureCredential::AuthorizationToken( - // we do the conversion to a HeaderValue here, since it is fallible - // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - crate::Error::Generic { - store: STORE, - source: Box::new(err), - } - })?, - )) - } - CredentialProvider::SASToken(sas) => { - Ok(AzureCredential::SASToken(sas.clone())) - } - } + async fn get_credential(&self) -> Result> { + self.config.credentials.get_credential().await } /// Make an Azure PUT request @@ -296,7 +260,7 @@ impl AzureClient { // If using SAS authorization must include the headers in the URL // - if let AzureCredential::SASToken(pairs) = &credential { + if let AzureCredential::SASToken(pairs) = credential.as_ref() { source.query_pairs_mut().extend_pairs(pairs); } diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 8130df6361fd..fd75389249b0 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::azure::STORE; use crate::client::retry::RetryExt; use crate::client::token::{TemporaryToken, TokenCache}; +use crate::client::{CredentialProvider, TokenProvider}; use crate::util::hmac_sha256; use crate::RetryConfig; +use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use chrono::{DateTime, Utc}; @@ -36,6 +39,7 @@ use snafu::{ResultExt, Snafu}; use std::borrow::Cow; use std::process::Command; use std::str; +use std::sync::Arc; use std::time::{Duration, Instant}; use url::Url; @@ -81,19 +85,30 @@ pub enum Error { pub type Result = std::result::Result; -/// Provides credentials for use when signing requests -#[derive(Debug)] -pub enum CredentialProvider { - AccessKey(String), - BearerToken(String), - SASToken(Vec<(String, String)>), - TokenCredential(TokenCache, Box), +impl From for crate::Error { + fn from(value: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(value), + } + } } -pub(crate) enum AzureCredential { +/// An Azure storage credential +#[derive(Debug, Eq, PartialEq)] +pub enum AzureCredential { + /// A shared access key + /// + /// AccessKey(String), + /// A shared access signature + /// + /// SASToken(Vec<(String, String)>), - AuthorizationToken(HeaderValue), + /// An authorization token + /// + /// + BearerToken(String), } /// A list of known Azure authority hosts @@ -155,9 +170,7 @@ impl CredentialExt for RequestBuilder { Self::from_parts(client, request) } - AzureCredential::AuthorizationToken(token) => { - self.header(AUTHORIZATION, token) - } + AzureCredential::BearerToken(token) => self.bearer_auth(token), AzureCredential::SASToken(query_pairs) => self.query(&query_pairs), } } @@ -291,15 +304,6 @@ fn lexy_sort<'a>( values } -#[async_trait::async_trait] -pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static { - async fn fetch_token( - &self, - client: &Client, - retry: &RetryConfig, - ) -> Result>; -} - #[derive(Deserialize, Debug)] struct TokenResponse { access_token: String, @@ -338,13 +342,15 @@ impl ClientSecretOAuthProvider { } #[async_trait::async_trait] -impl TokenCredential for ClientSecretOAuthProvider { +impl TokenProvider for ClientSecretOAuthProvider { + type Credential = AzureCredential; + /// Fetch a token async fn fetch_token( &self, client: &Client, retry: &RetryConfig, - ) -> Result> { + ) -> crate::Result>> { let response: TokenResponse = client .request(Method::POST, &self.token_url) .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) @@ -361,12 +367,10 @@ impl TokenCredential for ClientSecretOAuthProvider { .await .context(TokenResponseBodySnafu)?; - let token = TemporaryToken { - token: response.access_token, + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), - }; - - Ok(token) + }) } } @@ -397,7 +401,6 @@ pub struct ImdsManagedIdentityProvider { client_id: Option, object_id: Option, msi_res_id: Option, - client: Client, } impl ImdsManagedIdentityProvider { @@ -407,7 +410,6 @@ impl ImdsManagedIdentityProvider { object_id: Option, msi_res_id: Option, msi_endpoint: Option, - client: Client, ) -> Self { let msi_endpoint = msi_endpoint.unwrap_or_else(|| { "http://169.254.169.254/metadata/identity/oauth2/token".to_owned() @@ -418,19 +420,20 @@ impl ImdsManagedIdentityProvider { client_id, object_id, msi_res_id, - client, } } } #[async_trait::async_trait] -impl TokenCredential for ImdsManagedIdentityProvider { +impl TokenProvider for ImdsManagedIdentityProvider { + type Credential = AzureCredential; + /// Fetch a token async fn fetch_token( &self, - _client: &Client, + client: &Client, retry: &RetryConfig, - ) -> Result> { + ) -> crate::Result>> { let mut query_items = vec![ ("api-version", MSI_API_VERSION), ("resource", AZURE_STORAGE_RESOURCE), @@ -450,8 +453,7 @@ impl TokenCredential for ImdsManagedIdentityProvider { query_items.push((key, value)); } - let mut builder = self - .client + let mut builder = client .request(Method::GET, &self.msi_endpoint) .header("metadata", "true") .query(&query_items); @@ -468,12 +470,10 @@ impl TokenCredential for ImdsManagedIdentityProvider { .await .context(TokenResponseBodySnafu)?; - let token = TemporaryToken { - token: response.access_token, + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), - }; - - Ok(token) + }) } } @@ -511,13 +511,15 @@ impl WorkloadIdentityOAuthProvider { } #[async_trait::async_trait] -impl TokenCredential for WorkloadIdentityOAuthProvider { +impl TokenProvider for WorkloadIdentityOAuthProvider { + type Credential = AzureCredential; + /// Fetch a token async fn fetch_token( &self, client: &Client, retry: &RetryConfig, - ) -> Result> { + ) -> crate::Result>> { let token_str = std::fs::read_to_string(&self.federated_token_file) .map_err(|_| Error::FederatedTokenFile)?; @@ -542,12 +544,10 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { .await .context(TokenResponseBodySnafu)?; - let token = TemporaryToken { - token: response.access_token, + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), - }; - - Ok(token) + }) } } @@ -585,23 +585,16 @@ struct AzureCliTokenResponse { #[derive(Default, Debug)] pub struct AzureCliCredential { - _private: (), + cache: TokenCache>, } impl AzureCliCredential { pub fn new() -> Self { Self::default() } -} -#[async_trait::async_trait] -impl TokenCredential for AzureCliCredential { /// Fetch a token - async fn fetch_token( - &self, - _client: &Client, - _retry: &RetryConfig, - ) -> Result> { + async fn fetch_token(&self) -> Result>> { // on window az is a cmd and it should be called like this // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html let program = if cfg!(target_os = "windows") { @@ -642,7 +635,9 @@ impl TokenCredential for AzureCliCredential { let duration = token_response.expires_on.naive_local() - chrono::Local::now().naive_local(); Ok(TemporaryToken { - token: token_response.access_token, + token: Arc::new(AzureCredential::BearerToken( + token_response.access_token, + )), expiry: Some( Instant::now() + duration.to_std().map_err(|_| Error::AzureCli { @@ -669,6 +664,15 @@ impl TokenCredential for AzureCliCredential { } } +#[async_trait] +impl CredentialProvider for AzureCliCredential { + type Credential = AzureCredential; + + async fn get_credential(&self) -> crate::Result> { + Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?) + } +} + #[cfg(test)] mod tests { use super::*; @@ -723,7 +727,6 @@ mod tests { None, None, Some(format!("{endpoint}/metadata/identity/oauth2/token")), - client.clone(), ); let token = credential @@ -731,7 +734,10 @@ mod tests { .await .unwrap(); - assert_eq!(&token.token, "TOKEN"); + assert_eq!( + token.token.as_ref(), + &AzureCredential::BearerToken("TOKEN".into()) + ); } #[tokio::test] @@ -779,6 +785,9 @@ mod tests { .await .unwrap(); - assert_eq!(&token.token, "TOKEN"); + assert_eq!( + token.token.as_ref(), + &AzureCredential::BearerToken("TOKEN".into()) + ); } } diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 0f8dae00c6c0..6dc14cfb54e9 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -27,7 +27,6 @@ //! a way to drop old blocks. Instead unused blocks are automatically cleaned up //! after 7 days. use self::client::{BlockId, BlockList}; -use crate::client::token::TokenCache; use crate::{ multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, path::Path, @@ -49,14 +48,20 @@ use std::{collections::BTreeSet, str::FromStr}; use tokio::io::AsyncWrite; use url::Url; +use crate::azure::credential::AzureCredential; use crate::client::header::header_meta; -use crate::client::ClientConfigKey; +use crate::client::{ + ClientConfigKey, CredentialProvider, StaticCredentialProvider, + TokenCredentialProvider, +}; use crate::config::ConfigValue; pub use credential::authority_hosts; mod client; mod credential; +type AzureCredentialProvider = Arc>; + const STORE: &str = "MicrosoftAzure"; /// The well-known account used by Azurite and the legacy Azure Storage Emulator. @@ -101,12 +106,6 @@ enum Error { #[snafu(display("Container name must be specified"))] MissingContainerName {}, - #[snafu(display("At least one authorization option must be specified"))] - MissingCredentials {}, - - #[snafu(display("Azure credential error: {}", source), context(false))] - Credential { source: credential::Error }, - #[snafu(display( "Unknown url scheme cannot be parsed into storage location: {}", scheme @@ -913,6 +912,9 @@ impl MicrosoftAzureBuilder { } let container = self.container_name.ok_or(Error::MissingContainerName {})?; + let static_creds = |credential: AzureCredential| -> AzureCredentialProvider { + Arc::new(StaticCredentialProvider::new(credential)) + }; let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? { let account_name = self @@ -924,7 +926,8 @@ impl MicrosoftAzureBuilder { let account_key = self .access_key .unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string()); - let credential = credential::CredentialProvider::AccessKey(account_key); + + let credential = static_creds(AzureCredential::AccessKey(account_key)); self.client_options = self.client_options.with_allow_http(true); (true, url, credential, account_name) @@ -933,10 +936,11 @@ impl MicrosoftAzureBuilder { let account_url = format!("https://{}.blob.core.windows.net", &account_name); let url = Url::parse(&account_url) .context(UnableToParseUrlSnafu { url: account_url })?; + let credential = if let Some(bearer_token) = self.bearer_token { - credential::CredentialProvider::BearerToken(bearer_token) + static_creds(AzureCredential::BearerToken(bearer_token)) } else if let Some(access_key) = self.access_key { - credential::CredentialProvider::AccessKey(access_key) + static_creds(AzureCredential::AccessKey(access_key)) } else if let (Some(client_id), Some(tenant_id), Some(federated_token_file)) = (&self.client_id, &self.tenant_id, self.federated_token_file) { @@ -946,10 +950,11 @@ impl MicrosoftAzureBuilder { tenant_id, self.authority_host, ); - credential::CredentialProvider::TokenCredential( - TokenCache::default(), - Box::new(client_credential), - ) + Arc::new(TokenCredentialProvider::new( + client_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ } else if let (Some(client_id), Some(client_secret), Some(tenant_id)) = (&self.client_id, self.client_secret, &self.tenant_id) { @@ -959,33 +964,29 @@ impl MicrosoftAzureBuilder { tenant_id, self.authority_host, ); - credential::CredentialProvider::TokenCredential( - TokenCache::default(), - Box::new(client_credential), - ) + Arc::new(TokenCredentialProvider::new( + client_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ } else if let Some(query_pairs) = self.sas_query_pairs { - credential::CredentialProvider::SASToken(query_pairs) + static_creds(AzureCredential::SASToken(query_pairs)) } else if let Some(sas) = self.sas_key { - credential::CredentialProvider::SASToken(split_sas(&sas)?) + static_creds(AzureCredential::SASToken(split_sas(&sas)?)) } else if self.use_azure_cli.get()? { - credential::CredentialProvider::TokenCredential( - TokenCache::default(), - Box::new(credential::AzureCliCredential::new()), - ) + Arc::new(credential::AzureCliCredential::new()) as _ } else { - let client = - self.client_options.clone().with_allow_http(true).client()?; let msi_credential = credential::ImdsManagedIdentityProvider::new( self.client_id, self.object_id, self.msi_resource_id, self.msi_endpoint, - client, ); - credential::CredentialProvider::TokenCredential( - TokenCache::default(), - Box::new(msi_credential), - ) + Arc::new(TokenCredentialProvider::new( + msi_credential, + self.client_options.clone().with_allow_http(true).client()?, + self.retry_config.clone(), + )) as _ }; (false, url, credential, account_name) }; diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index c6a73fe7a618..292e4678fd69 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -32,17 +32,20 @@ pub mod header; #[cfg(any(feature = "aws", feature = "gcp"))] pub mod list; +use async_trait::async_trait; use std::collections::HashMap; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, ClientBuilder, Proxy, RequestBuilder}; use serde::{Deserialize, Serialize}; +use crate::client::token::{TemporaryToken, TokenCache}; use crate::config::{fmt_duration, ConfigValue}; use crate::path::Path; -use crate::GetOptions; +use crate::{GetOptions, Result, RetryConfig}; fn map_client_error(e: reqwest::Error) -> super::Error { super::Error::Generic { @@ -503,6 +506,90 @@ impl GetOptionsExt for RequestBuilder { } } +/// Provides credentials for use when signing requests +#[async_trait] +pub trait CredentialProvider: std::fmt::Debug + Send + Sync { + type Credential; + + async fn get_credential(&self) -> Result>; +} + +/// A static set of credentials +#[derive(Debug)] +pub struct StaticCredentialProvider { + credential: Arc, +} + +impl StaticCredentialProvider { + pub fn new(credential: T) -> Self { + Self { + credential: Arc::new(credential), + } + } +} + +#[async_trait] +impl CredentialProvider for StaticCredentialProvider +where + T: std::fmt::Debug + Send + Sync, +{ + type Credential = T; + + async fn get_credential(&self) -> Result> { + Ok(Arc::clone(&self.credential)) + } +} + +#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))] +mod cloud { + use super::*; + + /// A [`CredentialProvider`] that uses [`Client`] to fetch temporary tokens + #[derive(Debug)] + pub struct TokenCredentialProvider { + inner: T, + client: Client, + retry: RetryConfig, + cache: TokenCache>, + } + + impl TokenCredentialProvider { + pub fn new(inner: T, client: Client, retry: RetryConfig) -> Self { + Self { + inner, + client, + retry, + cache: Default::default(), + } + } + } + + #[async_trait] + impl CredentialProvider for TokenCredentialProvider { + type Credential = T::Credential; + + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| self.inner.fetch_token(&self.client, &self.retry)) + .await + } + } + + #[async_trait] + pub trait TokenProvider: std::fmt::Debug + Send + Sync { + type Credential: std::fmt::Debug + Send + Sync; + + async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result>>; + } +} + +#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))] +pub use cloud::*; + #[cfg(test)] mod tests { use super::*; diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index 057e013334ed..ad12855e19ef 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -17,6 +17,9 @@ use crate::client::retry::RetryExt; use crate::client::token::TemporaryToken; +use crate::client::{TokenCredentialProvider, TokenProvider}; +use crate::gcp::credential::Error::UnsupportedCredentialsType; +use crate::gcp::{GcpCredentialProvider, STORE}; use crate::ClientOptions; use crate::RetryConfig; use async_trait::async_trait; @@ -30,6 +33,7 @@ use std::env; use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::info; @@ -67,9 +71,21 @@ pub enum Error { #[snafu(display("Unsupported ApplicationCredentials type: {}", type_))] UnsupportedCredentialsType { type_: String }, +} + +impl From for crate::Error { + fn from(value: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(value), + } + } +} - #[snafu(display("Error creating client: {}", source))] - Client { source: crate::Error }, +#[derive(Debug, Eq, PartialEq)] +pub struct GcpCredential { + /// An HTTP bearer token + pub bearer: String, } pub type Result = std::result::Result; @@ -127,15 +143,6 @@ struct TokenResponse { expires_in: u64, } -#[async_trait] -pub trait TokenProvider: std::fmt::Debug + Send + Sync { - async fn fetch_token( - &self, - client: &Client, - retry: &RetryConfig, - ) -> Result>; -} - /// Encapsulates the logic to perform an OAuth token challenge #[derive(Debug)] pub struct OAuthProvider { @@ -174,12 +181,14 @@ impl OAuthProvider { #[async_trait] impl TokenProvider for OAuthProvider { + type Credential = GcpCredential; + /// Fetch a fresh token async fn fetch_token( &self, client: &Client, retry: &RetryConfig, - ) -> Result> { + ) -> crate::Result>> { let now = seconds_since_epoch(); let exp = now + 3600; @@ -221,12 +230,12 @@ impl TokenProvider for OAuthProvider { .await .context(TokenResponseBodySnafu)?; - let token = TemporaryToken { - token: response.access_token, + Ok(TemporaryToken { + token: Arc::new(GcpCredential { + bearer: response.access_token, + }), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), - }; - - Ok(token) + }) } } @@ -281,17 +290,17 @@ impl ServiceAccountCredentials { } /// Create an [`OAuthProvider`] from this credentials struct. - pub fn token_provider( + pub fn oauth_provider( self, scope: &str, audience: &str, - ) -> Result> { - Ok(Box::new(OAuthProvider::new( + ) -> crate::Result { + Ok(OAuthProvider::new( self.client_email, self.private_key, scope.to_string(), audience.to_string(), - )?) as Box) + )?) } } @@ -329,23 +338,14 @@ fn b64_encode_obj(obj: &T) -> Result { #[derive(Debug, Default)] pub struct InstanceCredentialProvider { audience: String, - client: Client, } impl InstanceCredentialProvider { /// Create a new [`InstanceCredentialProvider`], we need to control the client in order to enable http access so save the options. - pub fn new>( - audience: T, - client_options: ClientOptions, - ) -> Result { - client_options - .with_allow_http(true) - .client() - .map(|client| Self { - audience: audience.into(), - client, - }) - .context(ClientSnafu) + pub fn new>(audience: T) -> Self { + Self { + audience: audience.into(), + } } } @@ -355,7 +355,7 @@ async fn make_metadata_request( hostname: &str, retry: &RetryConfig, audience: &str, -) -> Result { +) -> crate::Result { let url = format!( "http://{hostname}/computeMetadata/v1/instance/service-accounts/default/token" ); @@ -374,30 +374,29 @@ async fn make_metadata_request( #[async_trait] impl TokenProvider for InstanceCredentialProvider { + type Credential = GcpCredential; + /// Fetch a token from the metadata server. /// Since the connection is local we need to enable http access and don't actually use the client object passed in. async fn fetch_token( &self, - _client: &Client, + client: &Client, retry: &RetryConfig, - ) -> Result> { + ) -> crate::Result>> { const METADATA_IP: &str = "169.254.169.254"; const METADATA_HOST: &str = "metadata"; info!("fetching token from metadata server"); let response = - make_metadata_request(&self.client, METADATA_HOST, retry, &self.audience) + make_metadata_request(client, METADATA_HOST, retry, &self.audience) .or_else(|_| { - make_metadata_request( - &self.client, - METADATA_IP, - retry, - &self.audience, - ) + make_metadata_request(client, METADATA_IP, retry, &self.audience) }) .await?; let token = TemporaryToken { - token: response.access_token, + token: Arc::new(GcpCredential { + bearer: response.access_token, + }), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), }; Ok(token) @@ -406,31 +405,35 @@ impl TokenProvider for InstanceCredentialProvider { /// ApplicationDefaultCredentials /// -#[derive(Debug)] -pub enum ApplicationDefaultCredentials { - /// - AuthorizedUser { - client_id: String, - client_secret: String, - refresh_token: String, - }, -} - -impl ApplicationDefaultCredentials { - pub fn new(path: Option<&str>) -> Result, Error> { - let file = match ApplicationDefaultCredentialsFile::read(path)? { - Some(f) => f, - None => return Ok(None), - }; - - Ok(Some(match file.type_.as_str() { - "authorized_user" => Self::AuthorizedUser { +pub fn application_default_credentials( + path: Option<&str>, + client: &ClientOptions, + retry: &RetryConfig, +) -> crate::Result> { + let file = match ApplicationDefaultCredentialsFile::read(path)? { + Some(x) => x, + None => return Ok(None), + }; + + match file.type_.as_str() { + // + "authorized_user" => { + let token = AuthorizedUserCredentials { client_id: file.client_id, client_secret: file.client_secret, refresh_token: file.refresh_token, - }, - type_ => return UnsupportedCredentialsTypeSnafu { type_ }.fail(), - })) + }; + + Ok(Some(Arc::new(TokenCredentialProvider::new( + token, + client.client()?, + retry.clone(), + )))) + } + type_ => Err(UnsupportedCredentialsType { + type_: type_.to_string(), + } + .into()), } } @@ -473,41 +476,43 @@ impl ApplicationDefaultCredentialsFile { const DEFAULT_TOKEN_GCP_URI: &str = "https://accounts.google.com/o/oauth2/token"; +/// +#[derive(Debug)] +struct AuthorizedUserCredentials { + client_id: String, + client_secret: String, + refresh_token: String, +} + #[async_trait] -impl TokenProvider for ApplicationDefaultCredentials { +impl TokenProvider for AuthorizedUserCredentials { + type Credential = GcpCredential; + async fn fetch_token( &self, client: &Client, retry: &RetryConfig, - ) -> Result, Error> { - let builder = client.request(Method::POST, DEFAULT_TOKEN_GCP_URI); - let builder = match self { - Self::AuthorizedUser { - client_id, - client_secret, - refresh_token, - } => { - let body = [ - ("grant_type", "refresh_token"), - ("client_id", client_id), - ("client_secret", client_secret), - ("refresh_token", refresh_token), - ]; - builder.form(&body) - } - }; - - let response = builder + ) -> crate::Result>> { + let response = client + .request(Method::POST, DEFAULT_TOKEN_GCP_URI) + .form(&[ + ("grant_type", "refresh_token"), + ("client_id", &self.client_id), + ("client_secret", &self.client_secret), + ("refresh_token", &self.refresh_token), + ]) .send_retry(retry) .await .context(TokenRequestSnafu)? .json::() .await .context(TokenResponseBodySnafu)?; - let token = TemporaryToken { - token: response.access_token, + + Ok(TemporaryToken { + token: Arc::new(GcpCredential { + bearer: response.access_token, + }), expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), - }; - Ok(token) + }) } } diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 32f4055f1178..6813bbf6ecf7 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -48,9 +48,12 @@ use crate::client::header::header_meta; use crate::client::list::ListResponse; use crate::client::pagination::stream_paginated; use crate::client::retry::RetryExt; -use crate::client::{ClientConfigKey, GetOptionsExt}; +use crate::client::{ + ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider, + TokenCredentialProvider, +}; +use crate::gcp::credential::{application_default_credentials, GcpCredential}; use crate::{ - client::token::TokenCache, multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, path::{Path, DELIMITER}, util::format_prefix, @@ -59,14 +62,15 @@ use crate::{ }; use self::credential::{ - default_gcs_base_url, ApplicationDefaultCredentials, InstanceCredentialProvider, - ServiceAccountCredentials, TokenProvider, + default_gcs_base_url, InstanceCredentialProvider, ServiceAccountCredentials, }; mod credential; const STORE: &str = "GCS"; +type GcpCredentialProvider = Arc>; + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Got invalid XML response for {} {}: {}", method, url, source))] @@ -119,9 +123,6 @@ enum Error { #[snafu(display("Missing bucket name"))] MissingBucketName {}, - #[snafu(display("Could not find either metadata credentials or configuration properties to initialize GCS credentials."))] - MissingCredentials, - #[snafu(display( "One of service account path or service account key may be provided." ))] @@ -209,8 +210,7 @@ struct GoogleCloudStorageClient { client: Client, base_url: String, - token_provider: Option>>, - token_cache: TokenCache, + credentials: GcpCredentialProvider, bucket_name: String, bucket_name_encoded: String, @@ -223,18 +223,8 @@ struct GoogleCloudStorageClient { } impl GoogleCloudStorageClient { - async fn get_token(&self) -> Result { - if let Some(token_provider) = &self.token_provider { - Ok(self - .token_cache - .get_or_insert_with(|| { - token_provider.fetch_token(&self.client, &self.retry_config) - }) - .await - .context(CredentialSnafu)?) - } else { - Ok("".to_owned()) - } + async fn get_credential(&self) -> Result> { + self.credentials.get_credential().await } fn object_url(&self, path: &Path) -> String { @@ -249,7 +239,7 @@ impl GoogleCloudStorageClient { options: GetOptions, head: bool, ) -> Result { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = self.object_url(path); let method = match head { @@ -260,7 +250,7 @@ impl GoogleCloudStorageClient { let response = self .client .request(method, url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .with_get_options(options) .send_retry(&self.retry_config) .await @@ -273,7 +263,7 @@ impl GoogleCloudStorageClient { /// Perform a put request async fn put_request(&self, path: &Path, payload: Bytes) -> Result<()> { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = self.object_url(path); let content_type = self @@ -283,7 +273,7 @@ impl GoogleCloudStorageClient { self.client .request(Method::PUT, url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, payload.len()) .body(payload) @@ -298,7 +288,7 @@ impl GoogleCloudStorageClient { /// Initiate a multi-part upload async fn multipart_initiate(&self, path: &Path) -> Result { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); let content_type = self @@ -309,7 +299,7 @@ impl GoogleCloudStorageClient { let response = self .client .request(Method::POST, &url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, "0") .query(&[("uploads", "")]) @@ -338,12 +328,12 @@ impl GoogleCloudStorageClient { path: &str, multipart_id: &MultipartId, ) -> Result<()> { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); self.client .request(Method::DELETE, &url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .header(header::CONTENT_TYPE, "application/octet-stream") .header(header::CONTENT_LENGTH, "0") .query(&[("uploadId", multipart_id)]) @@ -356,12 +346,12 @@ impl GoogleCloudStorageClient { /// Perform a delete request async fn delete_request(&self, path: &Path) -> Result<()> { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = self.object_url(path); let builder = self.client.request(Method::DELETE, url); builder - .bearer_auth(token) + .bearer_auth(&credential.bearer) .send_retry(&self.retry_config) .await .context(DeleteRequestSnafu { @@ -378,7 +368,7 @@ impl GoogleCloudStorageClient { to: &Path, if_not_exists: bool, ) -> Result<()> { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = self.object_url(to); let from = utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC); @@ -394,7 +384,7 @@ impl GoogleCloudStorageClient { } builder - .bearer_auth(token) + .bearer_auth(&credential.bearer) // Needed if reqwest is compiled with native-tls instead of rustls-tls // See https://github.com/apache/arrow-rs/pull/3921 .header(header::CONTENT_LENGTH, 0) @@ -418,7 +408,7 @@ impl GoogleCloudStorageClient { delimiter: bool, page_token: Option<&str>, ) -> Result { - let token = self.get_token().await?; + let credential = self.get_credential().await?; let url = format!("{}/{}", self.base_url, self.bucket_name_encoded); let mut query = Vec::with_capacity(5); @@ -443,7 +433,7 @@ impl GoogleCloudStorageClient { .client .request(Method::GET, url) .query(&query) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .send_retry(&self.retry_config) .await .context(ListRequestSnafu)? @@ -495,9 +485,9 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { self.client.base_url, self.client.bucket_name_encoded, self.encoded_path ); - let token = self + let credential = self .client - .get_token() + .get_credential() .await .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; @@ -505,7 +495,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .client .client .request(Method::PUT, &url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .query(&[ ("partNumber", format!("{}", part_idx + 1)), ("uploadId", upload_id), @@ -549,9 +539,9 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { }) .collect(); - let token = self + let credential = self .client - .get_token() + .get_credential() .await .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; @@ -567,7 +557,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { self.client .client .request(Method::POST, &url) - .bearer_auth(token) + .bearer_auth(&credential.bearer) .query(&[("uploadId", upload_id)]) .body(data) .send_retry(&self.client.retry_config) @@ -1062,10 +1052,11 @@ impl GoogleCloudStorageBuilder { }; // Then try to initialize from the application credentials file, or the environment. - let application_default_credentials = ApplicationDefaultCredentials::new( + let application_default_credentials = application_default_credentials( self.application_credentials_path.as_deref(), - ) - .context(CredentialSnafu)?; + &self.client_options, + &self.retry_config, + )?; let disable_oauth = service_account_credentials .as_ref() @@ -1081,29 +1072,24 @@ impl GoogleCloudStorageBuilder { let scope = "https://www.googleapis.com/auth/devstorage.full_control"; let audience = "https://www.googleapis.com/oauth2/v4/token"; - let token_provider = if disable_oauth { - None + let credentials = if disable_oauth { + Arc::new(StaticCredentialProvider::new(GcpCredential { + bearer: "".to_string(), + })) as _ + } else if let Some(credentials) = service_account_credentials { + Arc::new(TokenCredentialProvider::new( + credentials.oauth_provider(scope, audience)?, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } else if let Some(credentials) = application_default_credentials { + credentials } else { - let best_provider = if let Some(credentials) = service_account_credentials { - Some( - credentials - .token_provider(scope, audience) - .context(CredentialSnafu)?, - ) - } else if let Some(credentials) = application_default_credentials { - Some(Box::new(credentials) as Box) - } else { - Some(Box::new( - InstanceCredentialProvider::new( - audience, - self.client_options.clone(), - ) - .context(CredentialSnafu)?, - ) as Box) - }; - - // A provider is required at this point, bail out if we don't have one. - Some(best_provider.ok_or(Error::MissingCredentials)?) + Arc::new(TokenCredentialProvider::new( + InstanceCredentialProvider::new(audience), + self.client_options.clone().with_allow_http(true).client()?, + self.retry_config.clone(), + )) as _ }; let encoded_bucket_name = @@ -1113,8 +1099,7 @@ impl GoogleCloudStorageBuilder { client: Arc::new(GoogleCloudStorageClient { client, base_url: gcs_base_url, - token_provider: token_provider.map(Arc::new), - token_cache: Default::default(), + credentials, bucket_name, bucket_name_encoded: encoded_bucket_name, retry_config: self.retry_config,