diff --git a/integration/awskms/aws_kms_aead.go b/integration/awskms/aws_kms_aead.go index 070c521..43d5cdb 100644 --- a/integration/awskms/aws_kms_aead.go +++ b/integration/awskms/aws_kms_aead.go @@ -51,23 +51,17 @@ func newAWSAEAD(keyURI string, kms kmsiface.KMSAPI) *AWSAEAD { // Encrypt encrypts the plaintext with associatedData. func (a *AWSAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) { - ad := hex.EncodeToString(associatedData) req := &kms.EncryptInput{ - KeyId: aws.String(a.keyURI), - Plaintext: plaintext, - EncryptionContext: map[string]*string{"additionalData": &ad}, + KeyId: aws.String(a.keyURI), + Plaintext: plaintext, } - if ad == "" { - req = &kms.EncryptInput{ - KeyId: aws.String(a.keyURI), - Plaintext: plaintext, - } + if ad := hex.EncodeToString(associatedData); ad != "" { + req.EncryptionContext = map[string]*string{"additionalData": &ad} } resp, err := a.kms.Encrypt(req) if err != nil { return nil, err } - return resp.CiphertextBlob, nil } @@ -82,16 +76,12 @@ func (a *AWSAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) { // // See https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#key-id. func (a *AWSAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { - ad := hex.EncodeToString(associatedData) req := &kms.DecryptInput{ - KeyId: aws.String(a.keyURI), - CiphertextBlob: ciphertext, - EncryptionContext: map[string]*string{"additionalData": &ad}, + KeyId: aws.String(a.keyURI), + CiphertextBlob: ciphertext, } - if ad == "" { - req = &kms.DecryptInput{ - CiphertextBlob: ciphertext, - } + if ad := hex.EncodeToString(associatedData); ad != "" { + req.EncryptionContext = map[string]*string{"additionalData": &ad} } resp, err := a.kms.Decrypt(req) if err != nil {