Skip to content

Commit

Permalink
Adding web identity role provider
Browse files Browse the repository at this point in the history
  • Loading branch information
xibz committed Oct 8, 2018
1 parent a423839 commit 9531f6f
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 0 deletions.
1 change: 1 addition & 0 deletions aws/credentials/stscreds/testdata/token.jwt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo
100 changes: 100 additions & 0 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package stscreds

import (
"fmt"
"io/ioutil"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

const (
// ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing
// a new error to be returned during Retrieve.
ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr"
)

// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = func() time.Time {
return time.Now()
}

// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry

client stsiface.STSAPI
ExpiryWindow time.Duration

tokenFilePath string
roleARN string
}

// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, path)
return credentials.NewCredentials(p)
}

// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
}
}

var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "'WebIdentityTokenFilePath' environment variable is empty", nil)
var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "'WebIdentityRoleARN' environment variable is empty", nil)

// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
if len(p.tokenFilePath) == 0 {
return credentials.Value{}, emptyTokenFilePathErr
}

if len(p.roleARN) == 0 {
return credentials.Value{}, emptyRoleARNErr
}

b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err)
}

// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName := strconv.FormatInt(now().UTC().UnixNano(), 10)
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err)
}

p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)

value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
}
return value, nil
}
90 changes: 90 additions & 0 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package stscreds

import (
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)

type mockSTS struct {
*sts.STS
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if m.AssumeRoleWithWebIdentityFn != nil {
return m.AssumeRoleWithWebIdentityFn(input)
}

return nil, nil
}

func TestWebIdentityProviderRetrieve(t *testing.T) {
cases := []struct {
name string
mockSTS *mockSTS
roleARN string
tokenFilepath string
expectedError error
expectedCredValue credentials.Value
}{
{
name: "no role arn",
tokenFilepath: "foo/bar",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return &sts.AssumeRoleWithWebIdentityOutput{}, nil
},
},
expectedError: emptyRoleARNErr,
},
{
name: "no token file path",
roleARN: "arn",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return &sts.AssumeRoleWithWebIdentityOutput{}, nil
},
},
expectedError: emptyTokenFilePathErr,
},
{
name: "valid case",
roleARN: "arn",
tokenFilepath: "testdata/token.jwt",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
},
}, nil
},
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.tokenFilepath)
credValue, err := p.Retrieve()
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
})
}
}
13 changes: 13 additions & 0 deletions aws/session/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ type envConfig struct {
CSMEnabled bool
CSMPort string
CSMClientID string

WebIdentityRoleARN string
WebIdentityTokenFilePath string
}

var (
Expand Down Expand Up @@ -139,6 +142,12 @@ var (
sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE",
}
webIdentityRoleARNEnvKey = []string{
"AWS_IAM_ROLE_ARN",
}
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
)

// loadEnvConfig retrieves the SDK's environment configuration.
Expand Down Expand Up @@ -177,6 +186,10 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0

// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityRoleARN, webIdentityRoleARNEnvKey)
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)

// Require logical grouping of credentials
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
cfg.Creds = credentials.Value{}
Expand Down
9 changes: 9 additions & 0 deletions aws/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,15 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)

} else if len(envCfg.WebIdentityTokenFilePath) > 0 {
// handles assume role via OIDC token. This should happen before any other
// assume role call.

cfg.Credentials = stscreds.NewWebIdentityCredentials(&Session{
Config: cfg,
Handlers: handlers.Copy(),
}, envCfg.WebIdentityRoleARN, envCfg.WebIdentityTokenFilePath)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
Expand Down

0 comments on commit 9531f6f

Please sign in to comment.