Skip to content

Commit

Permalink
service/sqs: Add support for validating message checksums by default (#…
Browse files Browse the repository at this point in the history
…1748)

Adds support for the SQS client to automatically validate message
checksums for SendMessage, SendMessageBatch, and ReceiveMessage. This
brings the v2 SDK up to speed with the v1 SDK's behavior.

A DisableMessageChecksumValidation parameter has been added to the
Options struct for SQS package. Setting this to true will disable the
checksum validation. This can be set when creating a client, or per
operation call.
  • Loading branch information
jasdel authored Jul 1, 2022
1 parent 7ece169 commit 681ca65
Show file tree
Hide file tree
Showing 9 changed files with 831 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .changelog/131fe156ee0640ff85a5ba09e3cda44c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "131fe156-ee06-40ff-85a5-ba09e3cda44c",
"type": "feature",
"description": "Adds support for the SQS client to automatically validate message checksums for SendMessage, SendMessageBatch, and ReceiveMessage. A DisableMessageChecksumValidation parameter has been added to the Options struct for SQS package. Setting this to true will disable the checksum validation. This can be set when creating a client, or per operation call.",
"modules": [
"service/sqs"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package software.amazon.smithy.aws.go.codegen.customization;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenPlugin;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SetUtils;

public class SQSValidateMessageChecksum implements GoIntegration {
private static final Logger LOGGER = Logger.getLogger(SQSValidateMessageChecksum.class.getName());

/**
* Map of service shape to Set of operation shapes that need to have this
* customization.
*/
public static final Map<ShapeId, Set<ShapeId>> SERVICE_TO_OPERATION_MAP = MapUtils.of(
ShapeId.from("com.amazonaws.sqs#AmazonSQS"), SetUtils.of(
ShapeId.from("com.amazonaws.sqs#SendMessage"),
ShapeId.from("com.amazonaws.sqs#SendMessageBatch"),
ShapeId.from("com.amazonaws.sqs#ReceiveMessage")
)
);
static final String DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME = "DisableMessageChecksumValidation";

private final List<RuntimeClientPlugin> runtimeClientPlugins = new ArrayList<>();

/**
* Builds the set of runtime plugs used by the customization.
*
* @param settings codegen settings
* @param model api model
*/
@Override
public void processFinalizedModel(GoSettings settings, Model model) {
ShapeId serviceId = settings.getService();
if (!SERVICE_TO_OPERATION_MAP.containsKey(serviceId)) {
return;
}

ServiceShape service = settings.getService(model);

// Add option to disable message checksum validation
runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.addConfigField(ConfigField.builder()
.name(DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME)
.type(SymbolUtils.createValueSymbolBuilder("bool")
.putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true).build())
.documentation("Allows you to disable the client's validation of "
+ "response message checksums. Enabled by default. "
+ "Used by SendMessage, SendMessageBatch, and ReceiveMessage.")
.build())
.build());

for (ShapeId operationId : SERVICE_TO_OPERATION_MAP.get(serviceId)) {
final OperationShape operation = model.expectShape(operationId, OperationShape.class);

// Create a symbol provider because one is not available in this call.
SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings);

String helperFuncName = addMiddlewareFuncName(symbolProvider.toSymbol(operation).getName());

runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.operationPredicate((m, s, o) -> o.equals(operation))
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(helperFuncName)
.build())
.useClientOptions()
.build())
.build());
}
}

String addMiddlewareFuncName(String operationName) {
return "addValidate" + operationName + "Checksum";
}

/**
* Returns the list of runtime client plugins added by this customization
*
* @return runtime client plugins
*/
@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return runtimeClientPlugins;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ software.amazon.smithy.aws.go.codegen.RequestResponseLogging
software.amazon.smithy.aws.go.codegen.customization.S3AddPutObjectUnseekableBodyDoc
software.amazon.smithy.aws.go.codegen.customization.BackfillEc2UnboxedToBoxedShapes
software.amazon.smithy.aws.go.codegen.customization.AdjustAwsRestJsonContentType
software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
4 changes: 4 additions & 0 deletions service/sqs/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_ReceiveMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessageBatch.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

234 changes: 234 additions & 0 deletions service/sqs/cust_checksum_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
package sqs

import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/smithy-go/middleware"
)

// addValidateSendMessageChecksum adds the ValidateMessageChecksum middleware
// to the stack configured for the SendMessage Operation.
func addValidateSendMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageChecksum)
}

// validateSendMessageChecksum validates the SendMessage operation's input
// message payload MD5 checksum matches that returned by the API.
//
// The input and output types must match the SendMessage operation.
func validateSendMessageChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

// Nothing to validate if the members aren't populated.
if in.MessageBody == nil || out.MD5OfMessageBody == nil {
return nil
}

if err := validateMessageChecksum(*in.MessageBody, *out.MD5OfMessageBody); err != nil {
return messageChecksumError{
MessageID: aws.ToString(out.MessageId),
Err: err,
}
}
return nil
}

// addValidateSendMessageBatchChecksum adds the ValidateMessagechecksum
// middleware to the stack configured for the SendMessageBatch operation.
func addValidateSendMessageBatchChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageBatchChecksum)
}

// validateSendMessageBatchChecksum validates the SendMessageBatch operation's
// input messages body MD5 checksum matches those returned by the API.
//
// The input and output types must match the SendMessageBatch operation.
func validateSendMessageBatchChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageBatchInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageBatchOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

outEntries := map[string]sqstypes.SendMessageBatchResultEntry{}
for _, e := range out.Successful {
outEntries[*e.Id] = e
}

var failedMessageErrs []messageChecksumError
for _, inEntry := range in.Entries {
outEntry, ok := outEntries[*inEntry.Id]
// Nothing to validate if the members aren't populated.
if !ok || inEntry.MessageBody == nil || outEntry.MD5OfMessageBody == nil {
continue
}

if err := validateMessageChecksum(*inEntry.MessageBody, *outEntry.MD5OfMessageBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(outEntry.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

// addValidateReceiveMessageChecksum adds the ValidateMessagechecksum
// middleware to the stack configured for the ReceiveMessage operation.
func addValidateReceiveMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateReceiveMessageChecksum)
}

// validateReceiveMessageChecksum validates the ReceiveMessage operation's
// input messages body MD5 checksum matches those returned by the API.
//
// The input and output types must match the ReceiveMessage operation.
func validateReceiveMessageChecksum(_, output interface{}) error {
out, ok := output.(*ReceiveMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

var failedMessageErrs []messageChecksumError
for _, msg := range out.Messages {
// Nothing to validate if the members aren't populated.
if msg.Body == nil || msg.MD5OfBody == nil {
continue
}

if err := validateMessageChecksum(*msg.Body, *msg.MD5OfBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(msg.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

// messageChecksumValidator provides the function signature for the operation's
// validator.
type messageChecksumValidator func(input, output interface{}) error

// addValidateMessageChecksum adds the ValidateMessageChecksum middleware to
// the stack with the passed in validator specified.
func addValidateMessageChecksum(stack *middleware.Stack, o Options, validate messageChecksumValidator) error {
if o.DisableMessageChecksumValidation {
return nil
}

m := validateMessageChecksumMiddleware{
validate: validate,
}
err := stack.Initialize.Add(m, middleware.Before)
if err != nil {
return fmt.Errorf("failed to add %s middleware, %w", m.ID(), err)
}

return nil
}

// validateMessageChecksumMiddleware provides the Initialize middleware for
// validating an operation's message checksum is validate. Needs to b
// configured with the operation's validator.
type validateMessageChecksumMiddleware struct {
validate messageChecksumValidator
}

// ID returns the Middleware ID.
func (validateMessageChecksumMiddleware) ID() string { return "SQSValidateMessageChecksum" }

// HandleInitialize implements the InitializeMiddleware interface providing a
// middleware that will validate an operation's message checksum based on
// calling the validate member.
func (m validateMessageChecksumMiddleware) HandleInitialize(
ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
) (
out middleware.InitializeOutput, meta middleware.Metadata, err error,
) {
out, meta, err = next.HandleInitialize(ctx, input)
if err != nil {
return out, meta, err
}

err = m.validate(input.Parameters, out.Result)
if err != nil {
return out, meta, fmt.Errorf("message checksum validation failed, %w", err)
}

return out, meta, nil
}

// validateMessageChecksum compares the MD5 checksums of value parameter with
// the expected MD5 value. Returns an error if the computed checksum does not
// match the expected value.
func validateMessageChecksum(value, expect string) error {
msum := md5.Sum([]byte(value))
sum := hex.EncodeToString(msum[:])
if sum != expect {
return fmt.Errorf("expected MD5 checksum %s, got %s", expect, sum)
}

return nil
}

// messageChecksumError provides an error type for invalid message checksums.
type messageChecksumError struct {
MessageID string
Err error
}

func (e messageChecksumError) Error() string {
prefix := "message"
if e.MessageID != "" {
prefix += " " + e.MessageID
}
return fmt.Sprintf("%s has invalid checksum, %v", prefix, e.Err.Error())
}

// batchMessageChecksumError provides an error type for a collection of invalid
// message checksum errors.
type batchMessageChecksumError struct {
Errs []messageChecksumError
}

func (e batchMessageChecksumError) Error() string {
var w strings.Builder
fmt.Fprintf(&w, "message checksum errors")

for _, err := range e.Errs {
fmt.Fprintf(&w, "\n\t%s", err.Error())
}

return w.String()
}
Loading

0 comments on commit 681ca65

Please sign in to comment.