-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
aws/credentials/stscreds: adding web identity role provider #2193
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
foo |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package stscreds | ||
|
||
import ( | ||
"fmt" | ||
"io/ioutil" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
) | ||
|
||
const ( | ||
// ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing | ||
// a new error to be returned during Retrieve. | ||
ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr" | ||
) | ||
|
||
// now is used to return a time.Time object representing | ||
// the current time. This can be used to easily test and | ||
// compare test values. | ||
var now = func() time.Time { | ||
return time.Now() | ||
} | ||
|
||
// WebIdentityRoleProvider is used to retrieve credentials using | ||
// an OIDC token. | ||
type WebIdentityRoleProvider struct { | ||
credentials.Expiry | ||
|
||
client stsiface.STSAPI | ||
ExpiryWindow time.Duration | ||
|
||
tokenFilePath string | ||
roleARN string | ||
roleSessionName string | ||
} | ||
|
||
// NewWebIdentityCredentials will return a new set of credentials with a given | ||
// configuration, role arn, and token file path. | ||
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { | ||
svc := sts.New(c) | ||
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) | ||
return credentials.NewCredentials(p) | ||
} | ||
|
||
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the | ||
// provided stsiface.STSAPI | ||
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { | ||
return &WebIdentityRoleProvider{ | ||
client: svc, | ||
tokenFilePath: path, | ||
roleARN: roleARN, | ||
roleSessionName: roleSessionName, | ||
} | ||
} | ||
|
||
var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "token file path is not set", nil) | ||
var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "role ARN is not set", nil) | ||
|
||
// Retrieve will attempt to assume a role from a token which is located at | ||
// 'WebIdentityTokenFilePath' specified destination and if that is empty an | ||
// error will be returned. | ||
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { | ||
if len(p.tokenFilePath) == 0 { | ||
return credentials.Value{}, emptyTokenFilePathErr | ||
} | ||
|
||
if len(p.roleARN) == 0 { | ||
return credentials.Value{}, emptyRoleARNErr | ||
} | ||
|
||
b, err := ioutil.ReadFile(p.tokenFilePath) | ||
if err != nil { | ||
errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath) | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err) | ||
} | ||
|
||
sessionName := p.roleSessionName | ||
if len(sessionName) == 0 { | ||
// session name is used to uniquely identify a session. This simply | ||
// uses unix time in nanoseconds to uniquely identify sessions. | ||
sessionName = strconv.FormatInt(now().UnixNano(), 10) | ||
} | ||
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{ | ||
RoleArn: &p.roleARN, | ||
RoleSessionName: &sessionName, | ||
WebIdentityToken: aws.String(string(b)), | ||
}) | ||
if err != nil { | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err) | ||
} | ||
|
||
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) | ||
|
||
value := credentials.Value{ | ||
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), | ||
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably want to set the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As stated above, a SessionToken is a required field and should be added |
||
SessionToken: aws.StringValue(resp.Credentials.SessionToken), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should you add the |
||
} | ||
return value, nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package stscreds | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
) | ||
|
||
type mockSTS struct { | ||
*sts.STS | ||
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) | ||
} | ||
|
||
func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if m.AssumeRoleWithWebIdentityFn != nil { | ||
return m.AssumeRoleWithWebIdentityFn(input) | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func TestWebIdentityProviderRetrieve(t *testing.T) { | ||
now = func() time.Time { | ||
return time.Time{} | ||
} | ||
|
||
cases := []struct { | ||
name string | ||
mockSTS *mockSTS | ||
roleARN string | ||
tokenFilepath string | ||
sessionName string | ||
expectedError error | ||
expectedCredValue credentials.Value | ||
}{ | ||
{ | ||
name: "no role arn", | ||
tokenFilepath: "foo/bar", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{}, nil | ||
}, | ||
}, | ||
expectedError: emptyRoleARNErr, | ||
}, | ||
{ | ||
name: "no token file path", | ||
roleARN: "arn", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{}, nil | ||
}, | ||
}, | ||
expectedError: emptyTokenFilePathErr, | ||
}, | ||
{ | ||
name: "session name case", | ||
roleARN: "arn", | ||
tokenFilepath: "testdata/token.jwt", | ||
sessionName: "foo", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{ | ||
Credentials: &sts.Credentials{ | ||
Expiration: aws.Time(time.Now()), | ||
AccessKeyId: aws.String("access-key-id"), | ||
SecretAccessKey: aws.String("secret-access-key"), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
expectedCredValue: credentials.Value{ | ||
AccessKeyID: "access-key-id", | ||
SecretAccessKey: "secret-access-key", | ||
}, | ||
}, | ||
{ | ||
name: "valid case", | ||
roleARN: "arn", | ||
tokenFilepath: "testdata/token.jwt", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{ | ||
Credentials: &sts.Credentials{ | ||
Expiration: aws.Time(time.Now()), | ||
AccessKeyId: aws.String("access-key-id"), | ||
SecretAccessKey: aws.String("secret-access-key"), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
expectedCredValue: credentials.Value{ | ||
AccessKeyID: "access-key-id", | ||
SecretAccessKey: "secret-access-key", | ||
}, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath) | ||
credValue, err := p.Retrieve() | ||
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
}) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -440,6 +440,19 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share | |
cfg.Credentials = credentials.NewStaticCredentialsFromCreds( | ||
envCfg.Creds, | ||
) | ||
|
||
} else if len(envCfg.WebIdentityTokenFilePath) > 0 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this also validate that the RoleArn env var is provided, or just let the creds fail? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The role ARN is a required field, I'd say it should be required to use this provider There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will return an error, if role ARN is empty, when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am okay with changing it to behaving the other way, if it seems more preferable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about this, return an error if only one is set at the session create level rather than at the credential retrieval. That way it fails quickly and let's the user know that they have misconfigured the application There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah ok that makes sense. I like an error being reported if only one is present |
||
// handles assume role via OIDC token. This should happen before any other | ||
// assume role call. | ||
sessionName := envCfg.IAMRoleSessionName | ||
if len(sessionName) == 0 { | ||
sessionName = sharedCfg.AssumeRole.RoleSessionName | ||
} | ||
|
||
cfg.Credentials = stscreds.NewWebIdentityCredentials(&Session{ | ||
Config: cfg, | ||
Handlers: handlers.Copy(), | ||
}, envCfg.WebIdentityRoleARN, sessionName, envCfg.WebIdentityTokenFilePath) | ||
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil { | ||
cfgCp := *cfg | ||
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible for the
Expiration
to be nil? If so these credentials will expire immediately since the TimeValue will be zero time. If no Expiration is a valid use case might need to only SetExpiration if Expiration isn't nil.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should never be nil, an assumed role credential will always have an expiration and a token. https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html