-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
foo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package stscreds | ||
|
||
import ( | ||
"fmt" | ||
"io/ioutil" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
) | ||
|
||
const ( | ||
// ErrCodeWebIdentityRetrievalErr will be used as an error code when constructing | ||
// a new error to be returned during Retrieve. | ||
ErrCodeWebIdentityRetrievalErr = "WebIdentityRetrievalErr" | ||
) | ||
|
||
// now is used to return a time.Time object representing | ||
// the current time. This can be used to easily test and | ||
// compare test values. | ||
var now = func() time.Time { | ||
return time.Now() | ||
} | ||
|
||
// WebIdentityRoleProvider is used to retrieve credentials using | ||
// an OIDC token. | ||
type WebIdentityRoleProvider struct { | ||
credentials.Expiry | ||
|
||
client stsiface.STSAPI | ||
ExpiryWindow time.Duration | ||
|
||
tokenFilePath string | ||
roleARN string | ||
} | ||
|
||
// NewWebIdentityCredentials will return a new set of credentials with a given | ||
// configuration, role arn, and token file path. | ||
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, path string) *credentials.Credentials { | ||
svc := sts.New(c) | ||
p := NewWebIdentityRoleProvider(svc, roleARN, path) | ||
return credentials.NewCredentials(p) | ||
} | ||
|
||
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the | ||
// provided stsiface.STSAPI | ||
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, path string) *WebIdentityRoleProvider { | ||
return &WebIdentityRoleProvider{ | ||
client: svc, | ||
tokenFilePath: path, | ||
roleARN: roleARN, | ||
} | ||
} | ||
|
||
var emptyTokenFilePathErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "'WebIdentityTokenFilePath' environment variable is empty", nil) | ||
var emptyRoleARNErr = awserr.New(ErrCodeWebIdentityRetrievalErr, "'WebIdentityRoleARN' environment variable is empty", nil) | ||
|
||
// Retrieve will attempt to assume a role from a token which is located at | ||
// 'WebIdentityTokenFilePath' specified destination and if that is empty an | ||
// error will be returned. | ||
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { | ||
if len(p.tokenFilePath) == 0 { | ||
return credentials.Value{}, emptyTokenFilePathErr | ||
} | ||
|
||
if len(p.roleARN) == 0 { | ||
return credentials.Value{}, emptyRoleARNErr | ||
} | ||
|
||
b, err := ioutil.ReadFile(p.tokenFilePath) | ||
if err != nil { | ||
errMsg := fmt.Sprintf("unabled to read file at %s", p.tokenFilePath) | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, errMsg, err) | ||
} | ||
|
||
// session name is used to uniquely identify a session. This simply | ||
// uses unix time in nanoseconds to uniquely identify sessions. | ||
sessionName := strconv.FormatInt(now().UTC().UnixNano(), 10) | ||
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{ | ||
RoleArn: &p.roleARN, | ||
RoleSessionName: &sessionName, | ||
WebIdentityToken: aws.String(string(b)), | ||
}) | ||
if err != nil { | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentityRetrievalErr, "failed to retrieve credentials", err) | ||
} | ||
|
||
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) | ||
|
||
value := credentials.Value{ | ||
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), | ||
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), | ||
} | ||
return value, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package stscreds | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
) | ||
|
||
type mockSTS struct { | ||
*sts.STS | ||
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) | ||
} | ||
|
||
func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if m.AssumeRoleWithWebIdentityFn != nil { | ||
return m.AssumeRoleWithWebIdentityFn(input) | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func TestWebIdentityProviderRetrieve(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
mockSTS *mockSTS | ||
roleARN string | ||
tokenFilepath string | ||
expectedError error | ||
expectedCredValue credentials.Value | ||
}{ | ||
{ | ||
name: "no role arn", | ||
tokenFilepath: "foo/bar", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
return &sts.AssumeRoleWithWebIdentityOutput{}, nil | ||
}, | ||
}, | ||
expectedError: emptyRoleARNErr, | ||
}, | ||
{ | ||
name: "no token file path", | ||
roleARN: "arn", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
return &sts.AssumeRoleWithWebIdentityOutput{}, nil | ||
}, | ||
}, | ||
expectedError: emptyTokenFilePathErr, | ||
}, | ||
{ | ||
name: "valid case", | ||
roleARN: "arn", | ||
tokenFilepath: "testdata/token.jwt", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
return &sts.AssumeRoleWithWebIdentityOutput{ | ||
Credentials: &sts.Credentials{ | ||
Expiration: aws.Time(time.Now()), | ||
AccessKeyId: aws.String("access-key-id"), | ||
SecretAccessKey: aws.String("secret-access-key"), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
expectedCredValue: credentials.Value{ | ||
AccessKeyID: "access-key-id", | ||
SecretAccessKey: "secret-access-key", | ||
}, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.tokenFilepath) | ||
credValue, err := p.Retrieve() | ||
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters