Skip to content

Commit

Permalink
pkg/aws: use regional endpoint for STS
Browse files Browse the repository at this point in the history
Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
  • Loading branch information
gyuho committed Jul 13, 2020
1 parent c9140a0 commit 6f1a3f8
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"go.uber.org/zap"
"k8s.io/client-go/util/homedir"
Expand Down Expand Up @@ -51,6 +52,9 @@ func New(cfg *Config) (ss *session.Session, stsOutput *sts.GetCallerIdentityOutp
if cfg.Region == "" {
return nil, nil, "", fmt.Errorf("missing region")
}
if cfg.ResolverURL != "" && cfg.SigningName == "" {
return nil, nil, "", fmt.Errorf("got empty signing name for resolver %q", cfg.ResolverURL)
}

awsConfig := aws.Config{
Region: aws.String(cfg.Region),
Expand Down Expand Up @@ -93,32 +97,60 @@ func New(cfg *Config) (ss *session.Session, stsOutput *sts.GetCallerIdentityOutp
awsConfig.LogLevel = &lvl
}

var partition endpoints.Partition
switch cfg.Partition {
case endpoints.AwsPartitionID:
partition = endpoints.AwsPartition()
case endpoints.AwsCnPartitionID:
partition = endpoints.AwsCnPartition()
case endpoints.AwsUsGovPartitionID:
partition = endpoints.AwsUsGovPartition()
case endpoints.AwsIsoPartitionID:
partition = endpoints.AwsIsoPartition()
case endpoints.AwsIsoBPartitionID:
partition = endpoints.AwsIsoBPartition()
default:
return nil, nil, "", fmt.Errorf("unknown partition %q", cfg.Partition)
}
regions := partition.Regions()
region, ok := regions[cfg.Region]
if !ok {
return nil, nil, "", fmt.Errorf("region %q for partition %q not found in %+v", cfg.Region, cfg.Partition, regions)
}
stsEndpoint, err := region.ResolveEndpoint(endpoints.StsServiceID)
if err != nil {
return nil, nil, "", fmt.Errorf("failed to resolve endpoints for sts %q (%v)", cfg.Region, err)
}
stsConfig := awsConfig
stsConfig.STSRegionalEndpoint = endpoints.RegionalSTSEndpoint
var stsSession *session.Session
stsSession, err = session.NewSession(&awsConfig)
stsSession, err = session.NewSession(&stsConfig)
if err != nil {
return nil, nil, "", err
}
iamSvc := iam.New(stsSession)
if _, err = iamSvc.SetSecurityTokenServicePreferences(&iam.SetSecurityTokenServicePreferencesInput{
GlobalEndpointTokenVersion: aws.String("v2Token"),
}); err != nil {
cfg.Logger.Warn("failed to enable v2 security token", zap.Error(err))
}
stsSvc := sts.New(stsSession)
stsOutput, err = stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return nil, nil, "", err
}

cfg.Logger.Info(
"creating AWS session",
zap.String("partition", cfg.Partition),
zap.String("region", cfg.Region),
zap.String("region-resolved-sts-endpoint", stsEndpoint.URL),
zap.String("account-id", *stsOutput.Account),
zap.String("user-id", *stsOutput.UserId),
zap.String("arn", *stsOutput.Arn),
)

resolver := endpoints.DefaultResolver()

if cfg.ResolverURL != "" && cfg.SigningName == "" {
return nil, nil, "", fmt.Errorf("got empty signing name for resolver %q", cfg.ResolverURL)
}

// support test endpoint (e.g. https://api.beta.us-west-2.wesley.amazonaws.com)
if cfg.ResolverURL != "" {
cfg.Logger.Info(
"setting custom resolver",
Expand Down

0 comments on commit 6f1a3f8

Please sign in to comment.