Skip to content

Commit

Permalink
kms: use BaseEndpoint for testing
Browse files Browse the repository at this point in the history
This does the same, but with much less boilerplate.

xref: https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/endpoints/#v2-endpointresolverv2--baseendpoint

Signed-off-by: Hidde Beydals <hidde@hhh.computer>
  • Loading branch information
hiddeco committed Aug 16, 2023
1 parent 4603a0d commit 0fbafd7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
22 changes: 13 additions & 9 deletions kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ type MasterKey struct {
// using CredentialsProvider.ApplyToMasterKey. If nil, the default client is used
// which utilizes runtime environmental values.
credentialsProvider aws.CredentialsProvider
// epResolver can be used to override the endpoint the AWS client resolves
// baseEndpoint can be used to override the endpoint the AWS client resolves
// to by default. This is mostly used for testing purposes as it can not be
// injected using e.g. an environment variable. The field is not publicly
// exposed, nor configurable.
epResolver aws.EndpointResolverWithOptions
baseEndpoint string
}

// NewMasterKey creates a new MasterKey from an ARN, role and context, setting
Expand Down Expand Up @@ -197,7 +197,7 @@ func (key *MasterKey) Encrypt(dataKey []byte) error {
log.WithField("arn", key.Arn).Error("Encryption failed")
return err
}
client := kms.NewFromConfig(*cfg)
client := key.createClient(cfg)
input := &kms.EncryptInput{
KeyId: &key.Arn,
Plaintext: dataKey,
Expand Down Expand Up @@ -245,7 +245,7 @@ func (key *MasterKey) Decrypt() ([]byte, error) {
log.WithField("arn", key.Arn).Error("Decryption failed")
return nil, err
}
client := kms.NewFromConfig(*cfg)
client := key.createClient(cfg)
input := &kms.DecryptInput{
KeyId: &key.Arn,
CiphertextBlob: k,
Expand Down Expand Up @@ -309,11 +309,6 @@ func (key MasterKey) createKMSConfig() (*aws.Config, error) {
lo.SharedConfigProfile = key.AwsProfile
}
lo.Region = region

// Set the epResolver, if present. Used ONLY for tests.
if key.epResolver != nil {
lo.EndpointResolverWithOptions = key.epResolver
}
return nil
})
if err != nil {
Expand All @@ -326,6 +321,15 @@ func (key MasterKey) createKMSConfig() (*aws.Config, error) {
return &cfg, nil
}

// createClient creates a new AWS KMS client with the provided config.
func (key MasterKey) createClient(config *aws.Config) *kms.Client {
return kms.NewFromConfig(*config, func(o *kms.Options) {
if key.baseEndpoint != "" {
o.BaseEndpoint = aws.String(key.baseEndpoint)
}
})
}

// createSTSConfig uses AWS STS to assume a role and returns a config
// configured with that role's credentials. It returns an error if
// it fails to construct a session name, or assume the role.
Expand Down
17 changes: 4 additions & 13 deletions kms/keysource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func createTestMasterKey(arn string) MasterKey {
return MasterKey{
Arn: arn,
credentialsProvider: credentials.NewStaticCredentialsProvider("id", "secret", ""),
epResolver: epResolver{},
baseEndpoint: testKMSServerURL,
}
}

Expand All @@ -560,16 +560,7 @@ func createTestKMSClient(key MasterKey) (*kms.Client, error) {
if err != nil {
return nil, err
}
cfg.EndpointResolverWithOptions = epResolver{}
return kms.NewFromConfig(*cfg), nil
}

// epResolver is a dummy resolver that points to the local test KMS server.
type epResolver struct{}

// ResolveEndpoint always resolves to testKMSServerURL.
func (e epResolver) ResolveEndpoint(_, _ string, _ ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{
URL: testKMSServerURL,
}, nil
return kms.NewFromConfig(*cfg, func(options *kms.Options) {
options.BaseEndpoint = aws.String(testKMSServerURL)
}), nil
}

0 comments on commit 0fbafd7

Please sign in to comment.