Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct passing WebIdentity token to AWS provider #1724

Merged
merged 4 commits into from
Nov 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 100 additions & 24 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
// "config". It is optional.
ConfigEFSVaultName = "efsVaultName"

// ConfigWebIdentityToken represents the key for AWS Web Identity token
ConfigWebIdentityToken = "webIdentityToken"

// AccessKeyID represents AWS Access key ID
AccessKeyID = "AWS_ACCESS_KEY_ID"
// SecretAccessKey represents AWS Secret Access Key
Expand All @@ -60,6 +63,17 @@ const (
AssumeRoleDuration = "assumeRoleDuration"
)

var _ stscreds.TokenFetcher = (*staticToken)(nil)

// staticToken implements stscreds.TokenFetcher interface for retrieval of plaintext web
// identity token
type staticToken string

// FetchToken returns a plaintext web identity token as is.
func (f staticToken) FetchToken(ctx credentials.Context) ([]byte, error) {
leuyentran marked this conversation as resolved.
Show resolved Hide resolved
return []byte(f), nil
}

func durationFromString(config map[string]string) (time.Duration, error) {
d, ok := config[AssumeRoleDuration]
if !ok || d == "" {
Expand All @@ -68,25 +82,78 @@ func durationFromString(config map[string]string) (time.Duration, error) {
return time.ParseDuration(d)
}

func authenticateAWSCredentials(ctx context.Context, config map[string]string, assumeRoleDuration time.Duration) (*credentials.Credentials, string, error) {
var creds *credentials.Credentials
var assumedRole string

switch {
case config[AccessKeyID] != "" && config[SecretAccessKey] != "":
// If AccessKeys were provided - use those
creds = credentials.NewStaticCredentials(config[AccessKeyID], config[SecretAccessKey], "")
case os.Getenv(webIdentityTokenFilePathEnvKey) != "" && os.Getenv(roleARNEnvKey) != "":
sess, err := session.NewSessionWithOptions(session.Options{AssumeRoleDuration: assumeRoleDuration})
if err != nil {
return nil, "", errors.Wrap(err, "Failed to create session to initialize Web Identify credentials")
}
creds = getCredentialsWithDuration(sess, assumeRoleDuration)
assumedRole = os.Getenv(roleARNEnvKey)
default:
return nil, "", errors.New("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY required to initialize AWS credentials")
}
return creds, assumedRole, nil
func authenticateAWSCredentials(
config map[string]string,
assumeRoleDuration time.Duration,
) (*credentials.Credentials, string, error) {
// If AccessKeys were provided - use those
creds := fetchStaticAWSCredentials(config)
if creds != nil {
return creds, "", nil
}

// If Web Identity token and role were provided - use them
var err error
creds, err = fetchWebIdentityTokenFromConfig(config, assumeRoleDuration)
if err != nil {
return nil, "", err
}
if creds != nil {
return creds, config[ConfigRole], nil
}

// Otherwise use Web Identity token file and role provided via ENV
creds, err = fetchWebIdentityTokenFromFile(assumeRoleDuration)
if err != nil {
return nil, "", err
}
if creds != nil {
return creds, os.Getenv(roleARNEnvKey), nil
}

return nil, "", errors.New("Missing AWS credentials, please check that either AWS access keys or web identity token are provided")
}

func fetchStaticAWSCredentials(config map[string]string) *credentials.Credentials {
if config[AccessKeyID] == "" || config[SecretAccessKey] == "" {
return nil
}

return credentials.NewStaticCredentials(config[AccessKeyID], config[SecretAccessKey], "")
}

func fetchWebIdentityTokenFromConfig(config map[string]string, assumeRoleDuration time.Duration) (*credentials.Credentials, error) {
if config[ConfigWebIdentityToken] == "" || config[ConfigRole] == "" {
return nil, nil
}

creds, err := getCredentialsWithDuration(
config[ConfigRole],
staticToken(config[ConfigWebIdentityToken]),
assumeRoleDuration,
)
if err != nil {
return nil, err
}

return creds, nil
}

func fetchWebIdentityTokenFromFile(assumeRoleDuration time.Duration) (*credentials.Credentials, error) {
if os.Getenv(webIdentityTokenFilePathEnvKey) == "" || os.Getenv(roleARNEnvKey) == "" {
return nil, nil
}

creds, err := getCredentialsWithDuration(
os.Getenv(roleARNEnvKey),
stscreds.FetchTokenPath(os.Getenv(webIdentityTokenFilePathEnvKey)),
assumeRoleDuration,
)
if err != nil {
return nil, err
}

return creds, nil
}

// switchAWSRole checks if the caller wants to assume a different role
Expand Down Expand Up @@ -116,7 +183,7 @@ func GetCredentials(ctx context.Context, config map[string]string) (*credentials
log.Debug().Print("Assume Role Duration setup", field.M{"assumeRoleDuration": assumeRoleDuration})

// authenticate AWS creds
creds, assumedRole, err := authenticateAWSCredentials(ctx, config, assumeRoleDuration)
creds, assumedRole, err := authenticateAWSCredentials(config, assumeRoleDuration)
if err != nil {
return nil, err
}
Expand All @@ -128,16 +195,25 @@ func GetCredentials(ctx context.Context, config map[string]string) (*credentials
// In order to set a custom assume role duration, we have to get the
// the provider first and then set it's Duration field before
// getting the credentials from the provider.
func getCredentialsWithDuration(sess *session.Session, duration time.Duration) *credentials.Credentials {
func getCredentialsWithDuration(
roleARN string,
tokenFetcher stscreds.TokenFetcher,
duration time.Duration,
) (*credentials.Credentials, error) {
sess, err := session.NewSessionWithOptions(session.Options{AssumeRoleDuration: duration})
if err != nil {
return nil, errors.Wrap(err, "Failed to create session to initialize Web Identify credentials")
}

svc := sts.New(sess)
p := stscreds.NewWebIdentityRoleProviderWithOptions(
svc,
os.Getenv(roleARNEnvKey),
roleARN,
"",
stscreds.FetchTokenPath(os.Getenv(webIdentityTokenFilePathEnvKey)),
tokenFetcher,
)
p.Duration = duration
return credentials.NewCredentials(p)
return credentials.NewCredentials(p), nil
}

// GetConfig returns a configuration to establish AWS connection and connected region name.
Expand Down