Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws/credentials/stscreds: adding web identity role provider #2193

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aws/credentials/stscreds/testdata/token.jwt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo
106 changes: 106 additions & 0 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package stscreds

import (
"fmt"
"io/ioutil"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

const (
// ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing
// a new error to be returned during Retrieve.
ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr"
)

// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = func() time.Time {
return time.Now()
}

// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry

client stsiface.STSAPI
ExpiryWindow time.Duration

tokenFilePath string
roleARN string
roleSessionName string
}

// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}

// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}

var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "token file path is not set", nil)
var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "role ARN is not set", nil)

// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
if len(p.tokenFilePath) == 0 {
return credentials.Value{}, emptyTokenFilePathErr
}

if len(p.roleARN) == 0 {
return credentials.Value{}, emptyRoleARNErr
}

b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err)
}

sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err)
}

p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for the Expiration to be nil? If so these credentials will expire immediately since the TimeValue will be zero time. If no Expiration is a valid use case might need to only SetExpiration if Expiration isn't nil.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never be nil, an assumed role credential will always have an expiration and a token. https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html


value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably want to set the SessionToken value as well incase the AssumeRoleWithWebIdentity call returns it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As stated above, a SessionToken is a required field and should be added

SessionToken: aws.StringValue(resp.Credentials.SessionToken),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should you add the Value.ProviderName to the credentials?

}
return value, nil
}
133 changes: 133 additions & 0 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package stscreds

import (
"fmt"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)

type mockSTS struct {
*sts.STS
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if m.AssumeRoleWithWebIdentityFn != nil {
return m.AssumeRoleWithWebIdentityFn(input)
}

return nil, nil
}

func TestWebIdentityProviderRetrieve(t *testing.T) {
now = func() time.Time {
return time.Time{}
}

cases := []struct {
name string
mockSTS *mockSTS
roleARN string
tokenFilepath string
sessionName string
expectedError error
expectedCredValue credentials.Value
}{
{
name: "no role arn",
tokenFilepath: "foo/bar",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{}, nil
},
},
expectedError: emptyRoleARNErr,
},
{
name: "no token file path",
roleARN: "arn",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{}, nil
},
},
expectedError: emptyTokenFilePathErr,
},
{
name: "session name case",
roleARN: "arn",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
},
}, nil
},
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
},
},
{
name: "valid case",
roleARN: "arn",
tokenFilepath: "testdata/token.jwt",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
},
}, nil
},
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath)
credValue, err := p.Retrieve()
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
})
}
}
18 changes: 18 additions & 0 deletions aws/session/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ type envConfig struct {
CSMEnabled bool
CSMPort string
CSMClientID string

WebIdentityRoleARN string
IAMRoleSessionName string
WebIdentityTokenFilePath string
}

var (
Expand Down Expand Up @@ -139,6 +143,15 @@ var (
sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE",
}
webIdentityRoleARNEnvKey = []string{
"AWS_IAM_ROLE_ARN",
}
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleSessionNameEnvKey = []string{
"AWS_IAM_ROLE_SESSION_NAME",
}
)

// loadEnvConfig retrieves the SDK's environment configuration.
Expand Down Expand Up @@ -170,13 +183,18 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey)
setFromEnvVal(&cfg.IAMRoleSessionName, roleSessionNameEnvKey)

// CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0

// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityRoleARN, webIdentityRoleARNEnvKey)
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)

// Require logical grouping of credentials
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
cfg.Creds = credentials.Value{}
Expand Down
13 changes: 13 additions & 0 deletions aws/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,19 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)

} else if len(envCfg.WebIdentityTokenFilePath) > 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also validate that the RoleArn env var is provided, or just let the creds fail?

Copy link
Member

@micahhausler micahhausler Oct 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The role ARN is a required field, I'd say it should be required to use this provider

Copy link
Contributor Author

@xibz xibz Oct 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will return an error, if role ARN is empty, when Retrieve is called. I figure it'd be best to add it to the credential chain and if it is empty return an error signaling an issue back to the user rather than ignoring it silently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am okay with changing it to behaving the other way, if it seems more preferable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this, return an error if only one is set at the session create level rather than at the credential retrieval. That way it fails quickly and let's the user know that they have misconfigured the application

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok that makes sense. I like an error being reported if only one is present

// handles assume role via OIDC token. This should happen before any other
// assume role call.
sessionName := envCfg.IAMRoleSessionName
if len(sessionName) == 0 {
sessionName = sharedCfg.AssumeRole.RoleSessionName
}

cfg.Credentials = stscreds.NewWebIdentityCredentials(&Session{
Config: cfg,
Handlers: handlers.Copy(),
}, envCfg.WebIdentityRoleARN, sessionName, envCfg.WebIdentityTokenFilePath)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
Expand Down