Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port Shared Config Improvements & Credential Providers #488

Merged
merged 12 commits into from
Mar 2, 2020
Merged
4 changes: 2 additions & 2 deletions aws/ec2metadata/api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
restoreEnv := awstesting.StashEnv()
defer awstesting.PopEnv(restoreEnv)

os.Setenv("AWS_EC2_METADATA_DISABLED", "true")

Expand Down
31 changes: 20 additions & 11 deletions aws/ec2rolecreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ const ProviderName = "EC2RoleProvider"
// A Provider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// The NewProvider function must be used to create the Provider.
// The New function must be used to create the Provider.
//
// p := &ec2rolecreds.NewProvider(ec2metadata.New(cfg))
// p := &ec2rolecreds.New(ec2metadata.New(options))
//
// // Expire the credentials 10 minutes before IAM states they should. Proactivily
// // refreshing the credentials.
Expand All @@ -31,8 +31,13 @@ type Provider struct {
aws.SafeCredentialsProvider

// Required EC2Metadata client to use when connecting to EC2 metadata service.
Client *ec2metadata.Client
client *ec2metadata.Client

options ProviderOptions
}

// ProviderOptions is a list of user settable options for setting the behavior of the Provider.
type ProviderOptions struct {
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
Expand All @@ -45,22 +50,26 @@ type Provider struct {
ExpiryWindow time.Duration
}

// NewProvider returns an initialized Provider value configured to retrieve
// New returns an initialized Provider value configured to retrieve
// credentials from EC2 Instance Metadata service.
func NewProvider(client *ec2metadata.Client) *Provider {
p := &Provider{
Client: client,
}
func New(client *ec2metadata.Client, options ...func(*ProviderOptions)) *Provider {
p := &Provider{}

p.client = client
p.RetrieveFn = p.retrieveFn

for _, option := range options {
option(&p.options)
}

return p
}

// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
credsList, err := requestCredList(ctx, p.Client)
credsList, err := requestCredList(ctx, p.client)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -71,7 +80,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
}
credsName := credsList[0]

roleCreds, err := requestCred(ctx, p.Client, credsName)
roleCreds, err := requestCred(ctx, p.client, credsName)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -83,7 +92,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
Source: ProviderName,

CanExpire: true,
Expires: roleCreds.Expiration.Add(-p.ExpiryWindow),
Expires: roleCreds.Expiration.Add(-p.options.ExpiryWindow),
}

return creds, nil
Expand Down
13 changes: 7 additions & 6 deletions aws/ec2rolecreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestProvider(t *testing.T) {
cfg := unit.Config()
cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL + "/latest")

p := ec2rolecreds.NewProvider(ec2metadata.New(cfg))
p := ec2rolecreds.New(ec2metadata.New(cfg))

creds, err := p.Retrieve(context.Background())
if err != nil {
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestProvider_FailAssume(t *testing.T) {
cfg := unit.Config()
cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL + "/latest")

p := ec2rolecreds.NewProvider(ec2metadata.New(cfg))
p := ec2rolecreds.New(ec2metadata.New(cfg))

creds, err := p.Retrieve(context.Background())
if err == nil {
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestProvider_IsExpired(t *testing.T) {
cfg := unit.Config()
cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL + "/latest")

p := ec2rolecreds.NewProvider(ec2metadata.New(cfg))
p := ec2rolecreds.New(ec2metadata.New(cfg))

sdk.NowTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
Expand Down Expand Up @@ -164,8 +164,9 @@ func TestProvider_ExpiryWindowIsExpired(t *testing.T) {
cfg := unit.Config()
cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL + "/latest")

p := ec2rolecreds.NewProvider(ec2metadata.New(cfg))
p.ExpiryWindow = time.Hour
p := ec2rolecreds.New(ec2metadata.New(cfg), func(options *ec2rolecreds.ProviderOptions) {
options.ExpiryWindow = time.Hour
})

sdk.NowTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 40, 37, 0, time.UTC)
Expand Down Expand Up @@ -195,7 +196,7 @@ func BenchmarkProvider(b *testing.B) {
cfg := unit.Config()
cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL + "/latest")

p := ec2rolecreds.NewProvider(ec2metadata.New(cfg))
p := ec2rolecreds.New(ec2metadata.New(cfg))

if _, err := p.Retrieve(context.Background()); err != nil {
b.Fatal(err)
Expand Down
37 changes: 27 additions & 10 deletions aws/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ type Provider struct {
// The AWS Client to make HTTP requests to the endpoint with. The endpoint
// the request will be made to is provided by the aws.Config's
// EndpointResolver.
Client *aws.Client
client *aws.Client

options ProviderOptions
}

// ProviderOptions is structure of configurable options for Provider
type ProviderOptions struct {
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
Expand All @@ -61,13 +66,17 @@ type Provider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration

// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
}

// New returns a credentials Provider for retrieving AWS credentials
// from arbitrary endpoint.
func New(cfg aws.Config) *Provider {
func New(cfg aws.Config, options ...func(*ProviderOptions)) *Provider {
p := &Provider{
Client: aws.NewClient(
client: aws.NewClient(
cfg,
aws.Metadata{
ServiceName: ProviderName,
Expand All @@ -76,10 +85,14 @@ func New(cfg aws.Config) *Provider {
}
p.RetrieveFn = p.retrieveFn

p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
p.Client.Handlers.Validate.Clear()
p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
p.client.Handlers.Unmarshal.PushBack(unmarshalHandler)
p.client.Handlers.UnmarshalError.PushBack(unmarshalError)
p.client.Handlers.Validate.Clear()
p.client.Handlers.Validate.PushBack(validateEndpointHandler)

for _, option := range options {
option(&p.options)
}

return p
}
Expand All @@ -102,7 +115,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {

if resp.Expiration != nil {
creds.CanExpire = true
creds.Expires = resp.Expiration.Add(-p.ExpiryWindow)
creds.Expires = resp.Expiration.Add(-p.options.ExpiryWindow)
}

return creds, nil
Expand All @@ -127,9 +140,13 @@ func (p *Provider) getCredentials(ctx context.Context) (*getCredentialsOutput, e
}

out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
req := p.client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.options.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}

return out, req.Send()
}

Expand Down
17 changes: 7 additions & 10 deletions aws/external/codegen/main.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 3 additions & 5 deletions aws/external/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@ var DefaultConfigLoaders = []ConfigLoader{
// This will setup the AWS configuration's Region,
var DefaultAWSConfigResolvers = []AWSConfigResolver{
ResolveDefaultAWSConfig,
ResolveHandlersFunc,
ResolveEndpointResolverFunc,
ResolveCustomCABundle,
ResolveEnableEndpointDiscovery,

ResolveRegion,

ResolveFallbackEC2Credentials, // Initial defauilt credentails provider.
ResolveCredentialsValue,
ResolveEndpointCredentials,
ResolveContainerEndpointPathCredentials, // TODO is this order right?
ResolveAssumeRoleCredentials,
ResolveCredentials,
}

// A Config represents a generic configuration value or set of values. This type
Expand Down
12 changes: 7 additions & 5 deletions aws/external/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,17 @@ func TestConfigs_AppendFromLoaders(t *testing.T) {
func TestConfigs_ResolveAWSConfig(t *testing.T) {
configSources := Configs{
WithRegion("mock-region"),
WithCredentialsValue(aws.Credentials{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
Source: "provider",
}),
WithCredentialsProvider{aws.StaticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
Source: "provider",
},
}},
}

cfg, err := configSources.ResolveAWSConfig([]AWSConfigResolver{
ResolveRegion,
ResolveCredentialsValue,
ResolveCredentials,
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
Expand Down
Loading