From 9ca78b7d701ea5cf9a882a670d079a41144570f8 Mon Sep 17 00:00:00 2001 From: xibz Date: Fri, 5 Oct 2018 16:51:28 -0700 Subject: [PATCH 1/2] Adding web identity role provider --- aws/credentials/stscreds/testdata/token.jwt | 1 + .../stscreds/web_identity_provider.go | 106 ++++++++++++++ .../stscreds/web_identity_provider_test.go | 133 ++++++++++++++++++ aws/session/env_config.go | 18 +++ aws/session/session.go | 13 ++ 5 files changed, 271 insertions(+) create mode 100644 aws/credentials/stscreds/testdata/token.jwt create mode 100644 aws/credentials/stscreds/web_identity_provider.go create mode 100644 aws/credentials/stscreds/web_identity_provider_test.go diff --git a/aws/credentials/stscreds/testdata/token.jwt b/aws/credentials/stscreds/testdata/token.jwt new file mode 100644 index 00000000000..257cc5642cb --- /dev/null +++ b/aws/credentials/stscreds/testdata/token.jwt @@ -0,0 +1 @@ +foo diff --git a/aws/credentials/stscreds/web_identity_provider.go b/aws/credentials/stscreds/web_identity_provider.go new file mode 100644 index 00000000000..e45b96331de --- /dev/null +++ b/aws/credentials/stscreds/web_identity_provider.go @@ -0,0 +1,106 @@ +package stscreds + +import ( + "fmt" + "io/ioutil" + "strconv" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" +) + +const ( + // ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing + // a new error to be returned during Retrieve. + ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr" +) + +// now is used to return a time.Time object representing +// the current time. This can be used to easily test and +// compare test values. +var now = func() time.Time { + return time.Now() +} + +// WebIdentityRoleProvider is used to retrieve credentials using +// an OIDC token. +type WebIdentityRoleProvider struct { + credentials.Expiry + + client stsiface.STSAPI + ExpiryWindow time.Duration + + tokenFilePath string + roleARN string + roleSessionName string +} + +// NewWebIdentityCredentials will return a new set of credentials with a given +// configuration, role arn, and token file path. +func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { + svc := sts.New(c) + p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) + return credentials.NewCredentials(p) +} + +// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the +// provided stsiface.STSAPI +func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { + return &WebIdentityRoleProvider{ + client: svc, + tokenFilePath: path, + roleARN: roleARN, + roleSessionName: roleSessionName, + } +} + +var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "token file path is not set", nil) +var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "role ARN is not set", nil) + +// Retrieve will attempt to assume a role from a token which is located at +// 'WebIdentityTokenFilePath' specified destination and if that is empty an +// error will be returned. +func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { + if len(p.tokenFilePath) == 0 { + return credentials.Value{}, emptyTokenFilePathErr + } + + if len(p.roleARN) == 0 { + return credentials.Value{}, emptyRoleARNErr + } + + b, err := ioutil.ReadFile(p.tokenFilePath) + if err != nil { + errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath) + return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err) + } + + sessionName := p.roleSessionName + if len(sessionName) == 0 { + // session name is used to uniquely identify a session. This simply + // uses unix time in nanoseconds to uniquely identify sessions. + sessionName = strconv.FormatInt(now().UnixNano(), 10) + } + resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{ + RoleArn: &p.roleARN, + RoleSessionName: &sessionName, + WebIdentityToken: aws.String(string(b)), + }) + if err != nil { + return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err) + } + + p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) + + value := credentials.Value{ + AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), + SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), + SessionToken: aws.StringValue(resp.Credentials.SessionToken), + } + return value, nil +} diff --git a/aws/credentials/stscreds/web_identity_provider_test.go b/aws/credentials/stscreds/web_identity_provider_test.go new file mode 100644 index 00000000000..e37d1bf240f --- /dev/null +++ b/aws/credentials/stscreds/web_identity_provider_test.go @@ -0,0 +1,133 @@ +package stscreds + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/service/sts" +) + +type mockSTS struct { + *sts.STS + AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) +} + +func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if m.AssumeRoleWithWebIdentityFn != nil { + return m.AssumeRoleWithWebIdentityFn(input) + } + + return nil, nil +} + +func TestWebIdentityProviderRetrieve(t *testing.T) { + now = func() time.Time { + return time.Time{} + } + + cases := []struct { + name string + mockSTS *mockSTS + roleARN string + tokenFilepath string + sessionName string + expectedError error + expectedCredValue credentials.Value + }{ + { + name: "no role arn", + tokenFilepath: "foo/bar", + mockSTS: &mockSTS{ + AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + return &sts.AssumeRoleWithWebIdentityOutput{}, nil + }, + }, + expectedError: emptyRoleARNErr, + }, + { + name: "no token file path", + roleARN: "arn", + mockSTS: &mockSTS{ + AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + return &sts.AssumeRoleWithWebIdentityOutput{}, nil + }, + }, + expectedError: emptyTokenFilePathErr, + }, + { + name: "session name case", + roleARN: "arn", + tokenFilepath: "testdata/token.jwt", + sessionName: "foo", + mockSTS: &mockSTS{ + AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + Expiration: aws.Time(time.Now()), + AccessKeyId: aws.String("access-key-id"), + SecretAccessKey: aws.String("secret-access-key"), + }, + }, nil + }, + }, + expectedCredValue: credentials.Value{ + AccessKeyID: "access-key-id", + SecretAccessKey: "secret-access-key", + }, + }, + { + name: "valid case", + roleARN: "arn", + tokenFilepath: "testdata/token.jwt", + mockSTS: &mockSTS{ + AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + Expiration: aws.Time(time.Now()), + AccessKeyId: aws.String("access-key-id"), + SecretAccessKey: aws.String("secret-access-key"), + }, + }, nil + }, + }, + expectedCredValue: credentials.Value{ + AccessKeyID: "access-key-id", + SecretAccessKey: "secret-access-key", + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath) + credValue, err := p.Retrieve() + if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + }) + } +} diff --git a/aws/session/env_config.go b/aws/session/env_config.go index 82e04d76cde..11ea87e237e 100644 --- a/aws/session/env_config.go +++ b/aws/session/env_config.go @@ -101,6 +101,10 @@ type envConfig struct { CSMEnabled bool CSMPort string CSMClientID string + + WebIdentityRoleARN string + IAMRoleSessionName string + WebIdentityTokenFilePath string } var ( @@ -139,6 +143,15 @@ var ( sharedConfigFileEnvKey = []string{ "AWS_CONFIG_FILE", } + webIdentityRoleARNEnvKey = []string{ + "AWS_IAM_ROLE_ARN", + } + webIdentityTokenFilePathEnvKey = []string{ + "AWS_WEB_IDENTITY_TOKEN_FILE", + } + roleSessionNameEnvKey = []string{ + "AWS_IAM_ROLE_SESSION_NAME", + } ) // loadEnvConfig retrieves the SDK's environment configuration. @@ -170,6 +183,7 @@ func envConfigLoad(enableSharedConfig bool) envConfig { setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey) setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey) setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey) + setFromEnvVal(&cfg.IAMRoleSessionName, roleSessionNameEnvKey) // CSM environment variables setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey) @@ -177,6 +191,10 @@ func envConfigLoad(enableSharedConfig bool) envConfig { setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey) cfg.CSMEnabled = len(cfg.csmEnabled) > 0 + // Web identity environment variables + setFromEnvVal(&cfg.WebIdentityRoleARN, webIdentityRoleARNEnvKey) + setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey) + // Require logical grouping of credentials if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 { cfg.Creds = credentials.Value{} diff --git a/aws/session/session.go b/aws/session/session.go index 51f30556301..109591057a4 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -440,6 +440,19 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share cfg.Credentials = credentials.NewStaticCredentialsFromCreds( envCfg.Creds, ) + + } else if len(envCfg.WebIdentityTokenFilePath) > 0 { + // handles assume role via OIDC token. This should happen before any other + // assume role call. + sessionName := envCfg.IAMRoleSessionName + if len(sessionName) == 0 { + sessionName = sharedCfg.AssumeRole.RoleSessionName + } + + cfg.Credentials = stscreds.NewWebIdentityCredentials(&Session{ + Config: cfg, + Handlers: handlers.Copy(), + }, envCfg.WebIdentityRoleARN, sessionName, envCfg.WebIdentityTokenFilePath) } else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil { cfgCp := *cfg cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds( From 1168ec5a01ad74d7b8bbb4dbe1cabe97ddcb9de8 Mon Sep 17 00:00:00 2001 From: xibz Date: Tue, 9 Oct 2018 13:42:55 -0700 Subject: [PATCH 2/2] adding session error if values aren't properly configured for web identity --- .../stscreds/web_identity_provider.go | 25 +++++-------- .../stscreds/web_identity_provider_test.go | 36 +++++-------------- aws/session/session.go | 18 +++++++++- 3 files changed, 34 insertions(+), 45 deletions(-) diff --git a/aws/credentials/stscreds/web_identity_provider.go b/aws/credentials/stscreds/web_identity_provider.go index e45b96331de..723ed6c6a0d 100644 --- a/aws/credentials/stscreds/web_identity_provider.go +++ b/aws/credentials/stscreds/web_identity_provider.go @@ -15,9 +15,12 @@ import ( ) const ( - // ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing - // a new error to be returned during Retrieve. - ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr" + // ErrCodeWebIdentity will be used as an error code when constructing + // a new error to be returned during session creation or retrieval. + ErrCodeWebIdentity = "WebIdentityErr" + + // WebIdentityProviderName is the web identity provider name + WebIdentityProviderName = "WebIdentityCredentials" ) // now is used to return a time.Time object representing @@ -59,25 +62,14 @@ func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, p } } -var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "token file path is not set", nil) -var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "role ARN is not set", nil) - // Retrieve will attempt to assume a role from a token which is located at // 'WebIdentityTokenFilePath' specified destination and if that is empty an // error will be returned. func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { - if len(p.tokenFilePath) == 0 { - return credentials.Value{}, emptyTokenFilePathErr - } - - if len(p.roleARN) == 0 { - return credentials.Value{}, emptyRoleARNErr - } - b, err := ioutil.ReadFile(p.tokenFilePath) if err != nil { errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath) - return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err) + return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err) } sessionName := p.roleSessionName @@ -92,7 +84,7 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { WebIdentityToken: aws.String(string(b)), }) if err != nil { - return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err) + return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err) } p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) @@ -101,6 +93,7 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), SessionToken: aws.StringValue(resp.Credentials.SessionToken), + ProviderName: WebIdentityProviderName, } return value, nil } diff --git a/aws/credentials/stscreds/web_identity_provider_test.go b/aws/credentials/stscreds/web_identity_provider_test.go index e37d1bf240f..ae00a71b6fe 100644 --- a/aws/credentials/stscreds/web_identity_provider_test.go +++ b/aws/credentials/stscreds/web_identity_provider_test.go @@ -1,3 +1,5 @@ +// +build go1.7 + package stscreds import ( @@ -38,34 +40,6 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { expectedError error expectedCredValue credentials.Value }{ - { - name: "no role arn", - tokenFilepath: "foo/bar", - mockSTS: &mockSTS{ - AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { - if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, but received %v", e, a) - } - - return &sts.AssumeRoleWithWebIdentityOutput{}, nil - }, - }, - expectedError: emptyRoleARNErr, - }, - { - name: "no token file path", - roleARN: "arn", - mockSTS: &mockSTS{ - AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { - if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, but received %v", e, a) - } - - return &sts.AssumeRoleWithWebIdentityOutput{}, nil - }, - }, - expectedError: emptyTokenFilePathErr, - }, { name: "session name case", roleARN: "arn", @@ -82,6 +56,7 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { Expiration: aws.Time(time.Now()), AccessKeyId: aws.String("access-key-id"), SecretAccessKey: aws.String("secret-access-key"), + SessionToken: aws.String("session-token"), }, }, nil }, @@ -89,6 +64,8 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { expectedCredValue: credentials.Value{ AccessKeyID: "access-key-id", SecretAccessKey: "secret-access-key", + SessionToken: "session-token", + ProviderName: WebIdentityProviderName, }, }, { @@ -106,6 +83,7 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { Expiration: aws.Time(time.Now()), AccessKeyId: aws.String("access-key-id"), SecretAccessKey: aws.String("secret-access-key"), + SessionToken: aws.String("session-token"), }, }, nil }, @@ -113,6 +91,8 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { expectedCredValue: credentials.Value{ AccessKeyID: "access-key-id", SecretAccessKey: "secret-access-key", + SessionToken: "session-token", + ProviderName: WebIdentityProviderName, }, }, } diff --git a/aws/session/session.go b/aws/session/session.go index 109591057a4..216752a4e7b 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -284,6 +284,14 @@ func Must(sess *Session, err error) *Session { return sess } +// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but +// 'AWS_IAM_ROLE_ARN' was not set. +var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil) + +// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_IAM_ROLE_ARN' was set but +// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set. +var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil) + func deprecatedNewSession(cfgs ...*aws.Config) *Session { cfg := defaults.Config() handlers := defaults.Handlers() @@ -441,9 +449,17 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share envCfg.Creds, ) - } else if len(envCfg.WebIdentityTokenFilePath) > 0 { + } else if len(envCfg.WebIdentityTokenFilePath) > 0 || len(envCfg.WebIdentityRoleARN) > 0 { // handles assume role via OIDC token. This should happen before any other // assume role call. + if len(envCfg.WebIdentityTokenFilePath) == 0 { + return WebIdentityEmptyTokenFilePathErr + } + + if len(envCfg.WebIdentityRoleARN) == 0 { + return WebIdentityEmptyRoleARNErr + } + sessionName := envCfg.IAMRoleSessionName if len(sessionName) == 0 { sessionName = sharedCfg.AssumeRole.RoleSessionName