Skip to content

Commit

Permalink
fix(aws provider): Pass configured region to credentials provider (#1…
Browse files Browse the repository at this point in the history
…2315)

* fix(aws provider): Pass configured region to credentials provider

Fixes: #12314

Relevant: #12313

Signed-off-by: Jesse Szwedko <jesse@szwedko.me>
  • Loading branch information
jszwedko committed Apr 21, 2022
1 parent 0a4ed84 commit 8a9f37a
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 14 deletions.
38 changes: 28 additions & 10 deletions src/aws/auth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use aws_config::{default_provider::credentials::default_provider, sts::AssumeRoleProviderBuilder};
use aws_types::{credentials::SharedCredentialsProvider, Credentials};
use aws_config::{
default_provider::credentials::DefaultCredentialsChain, sts::AssumeRoleProviderBuilder,
};
use aws_types::{credentials::SharedCredentialsProvider, region::Region, Credentials};
use serde::{Deserialize, Serialize};

/// Configuration for configuring authentication strategy for AWS.
Expand Down Expand Up @@ -28,7 +30,10 @@ pub enum AwsAuthentication {
}

impl AwsAuthentication {
pub async fn credentials_provider(&self) -> crate::Result<SharedCredentialsProvider> {
pub async fn credentials_provider(
&self,
region: Option<Region>,
) -> crate::Result<SharedCredentialsProvider> {
match self {
Self::Static {
access_key_id,
Expand All @@ -41,12 +46,19 @@ impl AwsAuthentication {
AwsAuthentication::File { .. } => {
Err("Overriding the credentials file is not supported.".into())
}
AwsAuthentication::Role { assume_role } => Ok(SharedCredentialsProvider::new(
AssumeRoleProviderBuilder::new(assume_role)
.build(default_credentials_provider().await),
)),
AwsAuthentication::Role { assume_role } => {
let mut credentials = AssumeRoleProviderBuilder::new(assume_role);

if let Some(ref region) = region {
credentials = credentials.region(region.clone())
}

Ok(SharedCredentialsProvider::new(
credentials.build(default_credentials_provider(region).await),
))
}
AwsAuthentication::Default {} => Ok(SharedCredentialsProvider::new(
default_credentials_provider().await,
default_credentials_provider(region).await,
)),
}
}
Expand All @@ -60,8 +72,14 @@ impl AwsAuthentication {
}
}

async fn default_credentials_provider() -> SharedCredentialsProvider {
SharedCredentialsProvider::new(default_provider().await)
async fn default_credentials_provider(region: Option<Region>) -> SharedCredentialsProvider {
let mut credentials = DefaultCredentialsChain::builder();

if let Some(region) = region {
credentials = credentials.region(region)
}

SharedCredentialsProvider::new(credentials.build().await)
}

#[cfg(test)]
Expand Down
3 changes: 2 additions & 1 deletion src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ pub async fn create_client<T: ClientBuilder>(
proxy: &ProxyConfig,
tls_options: &Option<TlsOptions>,
) -> crate::Result<T::Client> {
let mut config_builder = T::create_config_builder(auth.credentials_provider().await?);
let mut config_builder =
T::create_config_builder(auth.credentials_provider(region.clone()).await?);

if let Some(endpoint_override) = endpoint {
config_builder = T::with_endpoint_resolver(config_builder, endpoint_override);
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_kinesis_firehose/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async fn ensure_elasticsearch_domain(domain_name: String) -> String {
aws_sdk_elasticsearch::config::Builder::new()
.credentials_provider(
AwsAuthentication::test_auth()
.credentials_provider()
.credentials_provider(test_region_endpoint().region())
.await
.unwrap(),
)
Expand Down
4 changes: 3 additions & 1 deletion src/sinks/elasticsearch/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ impl ElasticsearchCommon {

let aws_auth = match &config.auth {
Some(ElasticsearchAuth::Basic { .. }) | None => None,
Some(ElasticsearchAuth::Aws(aws)) => Some(aws.credentials_provider().await?),
Some(ElasticsearchAuth::Aws(aws)) => {
Some(aws.credentials_provider(region.clone()).await?)
}
};

let compression = config.compression;
Expand Down
2 changes: 1 addition & 1 deletion src/sources/aws_sqs/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn get_sqs_client() -> aws_sdk_sqs::Client {
let config = aws_sdk_sqs::config::Builder::new()
.credentials_provider(
AwsAuthentication::test_auth()
.credentials_provider()
.credentials_provider(Some(Region::new("custom")))
.await
.unwrap(),
)
Expand Down

0 comments on commit 8a9f37a

Please sign in to comment.