Skip to content

Commit

Permalink
Direct passing WebIdentity token to AWS provider (#1724)
Browse files Browse the repository at this point in the history
* [K10-13548] Direct passing WebIdentity token to AWS provider

* Addressed review comments

* Addressed review comments

* Addressed review comments
  • Loading branch information
ed-shilo committed Nov 10, 2022
1 parent 5109313 commit ee171e8
Showing 1 changed file with 100 additions and 24 deletions.
124 changes: 100 additions & 24 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
// "config". It is optional.
ConfigEFSVaultName = "efsVaultName"

// ConfigWebIdentityToken represents the key for AWS Web Identity token
ConfigWebIdentityToken = "webIdentityToken"

// AccessKeyID represents AWS Access key ID
AccessKeyID = "AWS_ACCESS_KEY_ID"
// SecretAccessKey represents AWS Secret Access Key
Expand All @@ -60,6 +63,17 @@ const (
AssumeRoleDuration = "assumeRoleDuration"
)

var _ stscreds.TokenFetcher = (*staticToken)(nil)

// staticToken implements stscreds.TokenFetcher interface for retrieval of plaintext web
// identity token
type staticToken string

// FetchToken returns a plaintext web identity token as is.
func (f staticToken) FetchToken(ctx credentials.Context) ([]byte, error) {
return []byte(f), nil
}

func durationFromString(config map[string]string) (time.Duration, error) {
d, ok := config[AssumeRoleDuration]
if !ok || d == "" {
Expand All @@ -68,25 +82,78 @@ func durationFromString(config map[string]string) (time.Duration, error) {
return time.ParseDuration(d)
}

func authenticateAWSCredentials(ctx context.Context, config map[string]string, assumeRoleDuration time.Duration) (*credentials.Credentials, string, error) {
var creds *credentials.Credentials
var assumedRole string

switch {
case config[AccessKeyID] != "" && config[SecretAccessKey] != "":
// If AccessKeys were provided - use those
creds = credentials.NewStaticCredentials(config[AccessKeyID], config[SecretAccessKey], "")
case os.Getenv(webIdentityTokenFilePathEnvKey) != "" && os.Getenv(roleARNEnvKey) != "":
sess, err := session.NewSessionWithOptions(session.Options{AssumeRoleDuration: assumeRoleDuration})
if err != nil {
return nil, "", errors.Wrap(err, "Failed to create session to initialize Web Identify credentials")
}
creds = getCredentialsWithDuration(sess, assumeRoleDuration)
assumedRole = os.Getenv(roleARNEnvKey)
default:
return nil, "", errors.New("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY required to initialize AWS credentials")
}
return creds, assumedRole, nil
func authenticateAWSCredentials(
config map[string]string,
assumeRoleDuration time.Duration,
) (*credentials.Credentials, string, error) {
// If AccessKeys were provided - use those
creds := fetchStaticAWSCredentials(config)
if creds != nil {
return creds, "", nil
}

// If Web Identity token and role were provided - use them
var err error
creds, err = fetchWebIdentityTokenFromConfig(config, assumeRoleDuration)
if err != nil {
return nil, "", err
}
if creds != nil {
return creds, config[ConfigRole], nil
}

// Otherwise use Web Identity token file and role provided via ENV
creds, err = fetchWebIdentityTokenFromFile(assumeRoleDuration)
if err != nil {
return nil, "", err
}
if creds != nil {
return creds, os.Getenv(roleARNEnvKey), nil
}

return nil, "", errors.New("Missing AWS credentials, please check that either AWS access keys or web identity token are provided")
}

func fetchStaticAWSCredentials(config map[string]string) *credentials.Credentials {
if config[AccessKeyID] == "" || config[SecretAccessKey] == "" {
return nil
}

return credentials.NewStaticCredentials(config[AccessKeyID], config[SecretAccessKey], "")
}

func fetchWebIdentityTokenFromConfig(config map[string]string, assumeRoleDuration time.Duration) (*credentials.Credentials, error) {
if config[ConfigWebIdentityToken] == "" || config[ConfigRole] == "" {
return nil, nil
}

creds, err := getCredentialsWithDuration(
config[ConfigRole],
staticToken(config[ConfigWebIdentityToken]),
assumeRoleDuration,
)
if err != nil {
return nil, err
}

return creds, nil
}

func fetchWebIdentityTokenFromFile(assumeRoleDuration time.Duration) (*credentials.Credentials, error) {
if os.Getenv(webIdentityTokenFilePathEnvKey) == "" || os.Getenv(roleARNEnvKey) == "" {
return nil, nil
}

creds, err := getCredentialsWithDuration(
os.Getenv(roleARNEnvKey),
stscreds.FetchTokenPath(os.Getenv(webIdentityTokenFilePathEnvKey)),
assumeRoleDuration,
)
if err != nil {
return nil, err
}

return creds, nil
}

// switchAWSRole checks if the caller wants to assume a different role
Expand Down Expand Up @@ -116,7 +183,7 @@ func GetCredentials(ctx context.Context, config map[string]string) (*credentials
log.Debug().Print("Assume Role Duration setup", field.M{"assumeRoleDuration": assumeRoleDuration})

// authenticate AWS creds
creds, assumedRole, err := authenticateAWSCredentials(ctx, config, assumeRoleDuration)
creds, assumedRole, err := authenticateAWSCredentials(config, assumeRoleDuration)
if err != nil {
return nil, err
}
Expand All @@ -128,16 +195,25 @@ func GetCredentials(ctx context.Context, config map[string]string) (*credentials
// In order to set a custom assume role duration, we have to get the
// the provider first and then set it's Duration field before
// getting the credentials from the provider.
func getCredentialsWithDuration(sess *session.Session, duration time.Duration) *credentials.Credentials {
func getCredentialsWithDuration(
roleARN string,
tokenFetcher stscreds.TokenFetcher,
duration time.Duration,
) (*credentials.Credentials, error) {
sess, err := session.NewSessionWithOptions(session.Options{AssumeRoleDuration: duration})
if err != nil {
return nil, errors.Wrap(err, "Failed to create session to initialize Web Identify credentials")
}

svc := sts.New(sess)
p := stscreds.NewWebIdentityRoleProviderWithOptions(
svc,
os.Getenv(roleARNEnvKey),
roleARN,
"",
stscreds.FetchTokenPath(os.Getenv(webIdentityTokenFilePathEnvKey)),
tokenFetcher,
)
p.Duration = duration
return credentials.NewCredentials(p)
return credentials.NewCredentials(p), nil
}

// GetConfig returns a configuration to establish AWS connection and connected region name.
Expand Down

0 comments on commit ee171e8

Please sign in to comment.