diff --git a/aws/retry/middleware.go b/aws/retry/middleware.go index 722ca34c6a0..dc703d482d2 100644 --- a/aws/retry/middleware.go +++ b/aws/retry/middleware.go @@ -328,10 +328,12 @@ func AddRetryMiddlewares(stack *smithymiddle.Stack, options AddRetryMiddlewaresO middleware.LogAttempts = options.LogRetryAttempts }) - if err := stack.Finalize.Add(attempt, smithymiddle.After); err != nil { + // index retry to before signing, if signing exists + if err := stack.Finalize.Insert(attempt, "Signing", smithymiddle.Before); err != nil { return err } - if err := stack.Finalize.Add(&MetricsHeader{}, smithymiddle.After); err != nil { + + if err := stack.Finalize.Insert(&MetricsHeader{}, attempt.ID(), smithymiddle.After); err != nil { return err } return nil diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpPresignURLClientGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpPresignURLClientGenerator.java index 1d9589a98bf..2d578673647 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpPresignURLClientGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpPresignURLClientGenerator.java @@ -375,6 +375,14 @@ private void writeConvertToPresignMiddleware( if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok { stack.Finalize.Remove(($1P)(nil).ID()) }""", SdkGoTypes.ServiceInternal.AcceptEncoding.DisableGzip); + writer.write(""" + if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok { + stack.Finalize.Remove(($1P)(nil).ID()) + }""", SdkGoTypes.Aws.Retry.Attempt); + writer.write(""" + if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok { + stack.Finalize.Remove(($1P)(nil).ID()) + }""", SdkGoTypes.Aws.Retry.MetricsHeader); writer.write("stack.Deserialize.Clear()"); writer.write("stack.Build.Remove(($P)(nil).ID())", requestInvocationID); writer.write("stack.Build.Remove($S)", "UserAgent"); diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java index 437372a1ec0..f960ae88d9e 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java @@ -34,6 +34,7 @@ public static final class Aws { public static final Symbol IsCredentialsProvider = AwsGoDependency.AWS_CORE.valueSymbol("IsCredentialsProvider"); public static final Symbol AnonymousCredentials = AwsGoDependency.AWS_CORE.pointableSymbol("AnonymousCredentials"); + public static final class Middleware { public static final Symbol GetRequiresLegacyEndpoints = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetRequiresLegacyEndpoints"); public static final Symbol GetSigningName = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetSigningName"); @@ -41,6 +42,12 @@ public static final class Middleware { public static final Symbol SetSigningName = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("SetSigningName"); public static final Symbol SetSigningRegion = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("SetSigningRegion"); } + + + public static final class Retry { + public static final Symbol Attempt = AwsGoDependency.AWS_RETRY.pointableSymbol("Attempt"); + public static final Symbol MetricsHeader = AwsGoDependency.AWS_RETRY.pointableSymbol("MetricsHeader"); + } } public static final class Internal { diff --git a/credentials/endpointcreds/internal/client/auth.go b/credentials/endpointcreds/internal/client/auth.go new file mode 100644 index 00000000000..c3f5dadcec9 --- /dev/null +++ b/credentials/endpointcreds/internal/client/auth.go @@ -0,0 +1,48 @@ +package client + +import ( + "context" + "github.com/aws/smithy-go/middleware" +) + +type getIdentityMiddleware struct { + options Options +} + +func (*getIdentityMiddleware) ID() string { + return "GetIdentity" +} + +func (m *getIdentityMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} + +type signRequestMiddleware struct { +} + +func (*signRequestMiddleware) ID() string { + return "Signing" +} + +func (m *signRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} + +type resolveAuthSchemeMiddleware struct { + operation string + options Options +} + +func (*resolveAuthSchemeMiddleware) ID() string { + return "ResolveAuthScheme" +} + +func (m *resolveAuthSchemeMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} diff --git a/credentials/endpointcreds/internal/client/client.go b/credentials/endpointcreds/internal/client/client.go index df0e7575c44..9a869f89547 100644 --- a/credentials/endpointcreds/internal/client/client.go +++ b/credentials/endpointcreds/internal/client/client.go @@ -101,6 +101,7 @@ func (c *Client) GetCredentials(ctx context.Context, params *GetCredentialsInput stack.Serialize.Add(&serializeOpGetCredential{}, smithymiddleware.After) stack.Build.Add(&buildEndpoint{Endpoint: options.Endpoint}, smithymiddleware.After) stack.Deserialize.Add(&deserializeOpGetCredential{}, smithymiddleware.After) + addProtocolFinalizerMiddlewares(stack, options, "GetCredentials") retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{Retryer: options.Retryer}) middleware.AddSDKAgentKey(middleware.FeatureMetadata, ServiceID) smithyhttp.AddErrorCloseResponseBodyMiddleware(stack) diff --git a/credentials/endpointcreds/internal/client/endpoints.go b/credentials/endpointcreds/internal/client/endpoints.go new file mode 100644 index 00000000000..748ee67244e --- /dev/null +++ b/credentials/endpointcreds/internal/client/endpoints.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + "github.com/aws/smithy-go/middleware" +) + +type resolveEndpointV2Middleware struct { + options Options +} + +func (*resolveEndpointV2Middleware) ID() string { + return "ResolveEndpointV2" +} + +func (m *resolveEndpointV2Middleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} diff --git a/credentials/endpointcreds/internal/client/middleware.go b/credentials/endpointcreds/internal/client/middleware.go index ddb28a66d1c..f2820d20eac 100644 --- a/credentials/endpointcreds/internal/client/middleware.go +++ b/credentials/endpointcreds/internal/client/middleware.go @@ -146,3 +146,19 @@ func stof(code int) smithy.ErrorFault { } return smithy.FaultClient } + +func addProtocolFinalizerMiddlewares(stack *smithymiddleware.Stack, options Options, operation string) error { + if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, smithymiddleware.Before); err != nil { + return fmt.Errorf("add ResolveAuthScheme: %w", err) + } + if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", smithymiddleware.After); err != nil { + return fmt.Errorf("add GetIdentity: %w", err) + } + if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", smithymiddleware.After); err != nil { + return fmt.Errorf("add ResolveEndpointV2: %w", err) + } + if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", smithymiddleware.After); err != nil { + return fmt.Errorf("add Signing: %w", err) + } + return nil +} diff --git a/feature/ec2/imds/api_op_GetDynamicData.go b/feature/ec2/imds/api_op_GetDynamicData.go index 9e3bdb0e66e..af58b6bb102 100644 --- a/feature/ec2/imds/api_op_GetDynamicData.go +++ b/feature/ec2/imds/api_op_GetDynamicData.go @@ -56,6 +56,7 @@ type GetDynamicDataOutput struct { func addGetDynamicDataMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetDynamicData", buildGetDynamicDataPath, buildGetDynamicDataOutput) } diff --git a/feature/ec2/imds/api_op_GetIAMInfo.go b/feature/ec2/imds/api_op_GetIAMInfo.go index 24845dccd6d..5111cc90cac 100644 --- a/feature/ec2/imds/api_op_GetIAMInfo.go +++ b/feature/ec2/imds/api_op_GetIAMInfo.go @@ -53,6 +53,7 @@ type GetIAMInfoOutput struct { func addGetIAMInfoMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetIAMInfo", buildGetIAMInfoPath, buildGetIAMInfoOutput, ) diff --git a/feature/ec2/imds/api_op_GetInstanceIdentityDocument.go b/feature/ec2/imds/api_op_GetInstanceIdentityDocument.go index a87758ed302..dc8c09edf03 100644 --- a/feature/ec2/imds/api_op_GetInstanceIdentityDocument.go +++ b/feature/ec2/imds/api_op_GetInstanceIdentityDocument.go @@ -54,6 +54,7 @@ type GetInstanceIdentityDocumentOutput struct { func addGetInstanceIdentityDocumentMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetInstanceIdentityDocument", buildGetInstanceIdentityDocumentPath, buildGetInstanceIdentityDocumentOutput, ) diff --git a/feature/ec2/imds/api_op_GetMetadata.go b/feature/ec2/imds/api_op_GetMetadata.go index cb0ce4c0004..869bfc9feb9 100644 --- a/feature/ec2/imds/api_op_GetMetadata.go +++ b/feature/ec2/imds/api_op_GetMetadata.go @@ -56,6 +56,7 @@ type GetMetadataOutput struct { func addGetMetadataMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetMetadata", buildGetMetadataPath, buildGetMetadataOutput) } diff --git a/feature/ec2/imds/api_op_GetRegion.go b/feature/ec2/imds/api_op_GetRegion.go index 7b9b48912af..8c0572bb5c8 100644 --- a/feature/ec2/imds/api_op_GetRegion.go +++ b/feature/ec2/imds/api_op_GetRegion.go @@ -45,6 +45,7 @@ type GetRegionOutput struct { func addGetRegionMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetRegion", buildGetInstanceIdentityDocumentPath, buildGetRegionOutput, ) diff --git a/feature/ec2/imds/api_op_GetToken.go b/feature/ec2/imds/api_op_GetToken.go index 841f802c1a3..1f9ee97a5b7 100644 --- a/feature/ec2/imds/api_op_GetToken.go +++ b/feature/ec2/imds/api_op_GetToken.go @@ -49,6 +49,7 @@ func addGetTokenMiddleware(stack *middleware.Stack, options Options) error { err := addRequestMiddleware(stack, options, "PUT", + "GetToken", buildGetTokenPath, buildGetTokenOutput) if err != nil { diff --git a/feature/ec2/imds/api_op_GetUserData.go b/feature/ec2/imds/api_op_GetUserData.go index 88aa61e9ad9..8903697244a 100644 --- a/feature/ec2/imds/api_op_GetUserData.go +++ b/feature/ec2/imds/api_op_GetUserData.go @@ -45,6 +45,7 @@ type GetUserDataOutput struct { func addGetUserDataMiddleware(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "GetUserData", buildGetUserDataPath, buildGetUserDataOutput) } diff --git a/feature/ec2/imds/auth.go b/feature/ec2/imds/auth.go new file mode 100644 index 00000000000..ad283cf825f --- /dev/null +++ b/feature/ec2/imds/auth.go @@ -0,0 +1,48 @@ +package imds + +import ( + "context" + "github.com/aws/smithy-go/middleware" +) + +type getIdentityMiddleware struct { + options Options +} + +func (*getIdentityMiddleware) ID() string { + return "GetIdentity" +} + +func (m *getIdentityMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} + +type signRequestMiddleware struct { +} + +func (*signRequestMiddleware) ID() string { + return "Signing" +} + +func (m *signRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} + +type resolveAuthSchemeMiddleware struct { + operation string + options Options +} + +func (*resolveAuthSchemeMiddleware) ID() string { + return "ResolveAuthScheme" +} + +func (m *resolveAuthSchemeMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} diff --git a/feature/ec2/imds/endpoints.go b/feature/ec2/imds/endpoints.go new file mode 100644 index 00000000000..d7540da3481 --- /dev/null +++ b/feature/ec2/imds/endpoints.go @@ -0,0 +1,20 @@ +package imds + +import ( + "context" + "github.com/aws/smithy-go/middleware" +) + +type resolveEndpointV2Middleware struct { + options Options +} + +func (*resolveEndpointV2Middleware) ID() string { + return "ResolveEndpointV2" +} + +func (m *resolveEndpointV2Middleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleFinalize(ctx, in) +} diff --git a/feature/ec2/imds/request_middleware.go b/feature/ec2/imds/request_middleware.go index c8abd64916c..fc948c27d89 100644 --- a/feature/ec2/imds/request_middleware.go +++ b/feature/ec2/imds/request_middleware.go @@ -17,10 +17,11 @@ import ( func addAPIRequestMiddleware(stack *middleware.Stack, options Options, + operation string, getPath func(interface{}) (string, error), getOutput func(*smithyhttp.Response) (interface{}, error), ) (err error) { - err = addRequestMiddleware(stack, options, "GET", getPath, getOutput) + err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput) if err != nil { return err } @@ -44,6 +45,7 @@ func addAPIRequestMiddleware(stack *middleware.Stack, func addRequestMiddleware(stack *middleware.Stack, options Options, method string, + operation string, getPath func(interface{}) (string, error), getOutput func(*smithyhttp.Response) (interface{}, error), ) (err error) { @@ -101,6 +103,10 @@ func addRequestMiddleware(stack *middleware.Stack, return err } + if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil { + return fmt.Errorf("add protocol finalizers: %w", err) + } + // Retry support return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{ Retryer: options.Retryer, @@ -283,3 +289,19 @@ func appendURIPath(base, add string) string { } return reqPath } + +func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error { + if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil { + return fmt.Errorf("add ResolveAuthScheme: %w", err) + } + if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil { + return fmt.Errorf("add GetIdentity: %w", err) + } + if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil { + return fmt.Errorf("add ResolveEndpointV2: %w", err) + } + if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil { + return fmt.Errorf("add Signing: %w", err) + } + return nil +} diff --git a/feature/ec2/imds/request_middleware_test.go b/feature/ec2/imds/request_middleware_test.go index 53fd8b6ed73..04f00f44e64 100644 --- a/feature/ec2/imds/request_middleware_test.go +++ b/feature/ec2/imds/request_middleware_test.go @@ -33,6 +33,7 @@ func TestAddRequestMiddleware(t *testing.T) { "api request": { AddMiddleware: func(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, options, + "TestRequest", func(interface{}) (string, error) { return "/mockPath", nil }, @@ -53,9 +54,13 @@ func TestAddRequestMiddleware(t *testing.T) { "UserAgent", }, ExpectFinalize: []string{ + "ResolveAuthScheme", + "GetIdentity", + "ResolveEndpointV2", "Retry", "APITokenProvider", "RetryMetricsHeader", + "Signing", }, ExpectDeserialize: []string{ "APITokenProvider", @@ -66,7 +71,7 @@ func TestAddRequestMiddleware(t *testing.T) { "base request": { AddMiddleware: func(stack *middleware.Stack, options Options) error { - return addRequestMiddleware(stack, options, "POST", + return addRequestMiddleware(stack, options, "POST", "TestRequest", func(interface{}) (string, error) { return "/mockPath", nil }, @@ -87,8 +92,12 @@ func TestAddRequestMiddleware(t *testing.T) { "UserAgent", }, ExpectFinalize: []string{ + "ResolveAuthScheme", + "GetIdentity", + "ResolveEndpointV2", "Retry", "RetryMetricsHeader", + "Signing", }, ExpectDeserialize: []string{ "OperationDeserializer", @@ -590,6 +599,7 @@ func TestRequestGetToken(t *testing.T) { func(stack *middleware.Stack, options Options) error { return addAPIRequestMiddleware(stack, client.options.Copy(), + "TestRequest", func(interface{}) (string, error) { return "/latest/foo", nil }, diff --git a/service/docdb/api_client.go b/service/docdb/api_client.go index b7f8b686a03..72e8e156e3a 100644 --- a/service/docdb/api_client.go +++ b/service/docdb/api_client.go @@ -552,6 +552,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/ec2/api_client.go b/service/ec2/api_client.go index d8456843cab..84365b955b0 100644 --- a/service/ec2/api_client.go +++ b/service/ec2/api_client.go @@ -568,6 +568,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/internal/checksum/middleware_add.go b/service/internal/checksum/middleware_add.go index 610e7ca80eb..1b727acbe17 100644 --- a/service/internal/checksum/middleware_add.go +++ b/service/internal/checksum/middleware_add.go @@ -90,6 +90,19 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) return err } + // If trailing checksum is not supported no need for finalize handler to be added. + if options.EnableTrailingChecksum { + trailerMiddleware := &addInputChecksumTrailer{ + EnableTrailingChecksum: inputChecksum.EnableTrailingChecksum, + RequireChecksum: inputChecksum.RequireChecksum, + EnableComputePayloadHash: inputChecksum.EnableComputePayloadHash, + EnableDecodedContentLengthHeader: inputChecksum.EnableDecodedContentLengthHeader, + } + if err := stack.Finalize.Insert(trailerMiddleware, "Retry", middleware.After); err != nil { + return err + } + } + return nil } diff --git a/service/internal/checksum/middleware_add_test.go b/service/internal/checksum/middleware_add_test.go index 74f4cdfe78b..2e4a64f115b 100644 --- a/service/internal/checksum/middleware_add_test.go +++ b/service/internal/checksum/middleware_add_test.go @@ -39,6 +39,7 @@ func TestAddInputMiddleware(t *testing.T) { "ComputePayloadHash", "Finalize stack step", "Retry", + "addInputChecksumTrailer", "ResolveEndpointV2", "AWSChecksum:ComputeInputPayloadChecksum", "Signing", @@ -73,6 +74,7 @@ func TestAddInputMiddleware(t *testing.T) { "ComputePayloadHash", "Finalize stack step", "Retry", + "addInputChecksumTrailer", "ResolveEndpointV2", "AWSChecksum:ComputeInputPayloadChecksum", "Signing", @@ -209,6 +211,7 @@ func TestRemoveInputMiddleware(t *testing.T) { "ComputePayloadHash", "Finalize stack step", "Retry", + "addInputChecksumTrailer", "ResolveEndpointV2", "Signing", "Deserialize stack step", diff --git a/service/internal/checksum/middleware_compute_input_checksum.go b/service/internal/checksum/middleware_compute_input_checksum.go index c7740658aea..7ffca33f0ef 100644 --- a/service/internal/checksum/middleware_compute_input_checksum.go +++ b/service/internal/checksum/middleware_compute_input_checksum.go @@ -75,6 +75,8 @@ type computeInputPayloadChecksum struct { useTrailer bool } +type useTrailer struct{} + // ID provides the middleware's identifier. func (m *computeInputPayloadChecksum) ID() string { return "AWSChecksum:ComputeInputPayloadChecksum" @@ -178,15 +180,9 @@ func (m *computeInputPayloadChecksum) HandleFinalize( // ContentSHA256Header middleware handles the header ctx = v4.SetPayloadHash(ctx, streamingUnsignedPayloadTrailerPayloadHash) } - m.useTrailer = true - mw := &addInputChecksumTrailer{ - EnableTrailingChecksum: m.EnableTrailingChecksum, - RequireChecksum: m.RequireChecksum, - EnableComputePayloadHash: m.EnableComputePayloadHash, - EnableDecodedContentLengthHeader: m.EnableDecodedContentLengthHeader, - } - return mw.HandleFinalize(ctx, in, next) + ctx = middleware.WithStackValue(ctx, useTrailer{}, true) + return next.HandleFinalize(ctx, in) } // If trailing checksums are not enabled but protocol is still HTTPS @@ -268,6 +264,9 @@ func (m *addInputChecksumTrailer) HandleFinalize( ) ( out middleware.FinalizeOutput, metadata middleware.Metadata, err error, ) { + if enabled, _ := middleware.GetStackValue(ctx, useTrailer{}).(bool); !enabled { + return next.HandleFinalize(ctx, in) + } req, ok := in.Request.(*smithyhttp.Request) if !ok { return out, metadata, computeInputTrailingChecksumError{ diff --git a/service/internal/checksum/middleware_compute_input_checksum_test.go b/service/internal/checksum/middleware_compute_input_checksum_test.go index 3505663e5e1..c3362bf07e0 100644 --- a/service/internal/checksum/middleware_compute_input_checksum_test.go +++ b/service/internal/checksum/middleware_compute_input_checksum_test.go @@ -773,9 +773,16 @@ func TestComputeInputPayloadChecksum(t *testing.T) { EnableComputePayloadHash: true, EnableDecodedContentLengthHeader: true, } + if c.optionsFn != nil { c.optionsFn(m) } + trailerMiddleware := &addInputChecksumTrailer{ + EnableTrailingChecksum: m.EnableTrailingChecksum, + RequireChecksum: m.RequireChecksum, + EnableComputePayloadHash: m.EnableComputePayloadHash, + EnableDecodedContentLengthHeader: m.EnableDecodedContentLengthHeader, + } ctx := context.Background() var logged bytes.Buffer @@ -809,6 +816,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) { // Build middleware stack.Finalize.Add(m, middleware.After) + stack.Finalize.Add(trailerMiddleware, middleware.After) // Validate defer to finalize was performed as expected stack.Finalize.Add(middleware.FinalizeMiddlewareFunc( diff --git a/service/internal/integrationtest/s3/checksum_test.go b/service/internal/integrationtest/s3/checksum_test.go index 9da73ac8d78..8ffd08d1544 100644 --- a/service/internal/integrationtest/s3/checksum_test.go +++ b/service/internal/integrationtest/s3/checksum_test.go @@ -7,21 +7,47 @@ import ( "bytes" "context" "fmt" - "io/ioutil" - "strings" - "testing" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/smithy-go/logging" "github.com/google/go-cmp/cmp" + "io/ioutil" + "net/http" + "strings" + "testing" ) +type retryClient struct { + isRetriedCall bool + baseClient aws.HTTPClient +} + +type mockConnectionError struct{ err error } + +func (m mockConnectionError) ConnectionError() bool { + return true +} +func (m mockConnectionError) Error() string { + return fmt.Sprintf("request error: %v", m.err) +} +func (m mockConnectionError) Unwrap() error { + return m.err +} + +func (c *retryClient) Do(req *http.Request) (*http.Response, error) { + if !c.isRetriedCall { + c.isRetriedCall = true + return nil, mockConnectionError{} + } + return c.baseClient.Do(req) +} + func TestInteg_ObjectChecksums(t *testing.T) { cases := map[string]map[string]struct { disableHTTPS bool + retry bool params *s3.PutObjectInput expectErr string @@ -86,6 +112,7 @@ func TestInteg_ObjectChecksums(t *testing.T) { }, }, "autofill trailing checksum": { + retry: true, params: &s3.PutObjectInput{ Body: strings.NewReader("hello world"), ChecksumAlgorithm: s3types.ChecksumAlgorithmCrc32c, @@ -343,6 +370,14 @@ func TestInteg_ObjectChecksums(t *testing.T) { o.EndpointOptions.DisableHTTPS = c.disableHTTPS } + if c.retry { + opts := s3client.Options() + opts.HTTPClient = &retryClient{ + baseClient: opts.HTTPClient, + } + s3client = s3.New(opts) + } + t.Logf("putting bucket: %q, object: %q", *c.params.Bucket, *c.params.Key) putResult, err := s3client.PutObject(ctx, c.params, s3Options) if err == nil && len(c.expectErr) != 0 { diff --git a/service/neptune/api_client.go b/service/neptune/api_client.go index 37c3747798d..a56c87a9fac 100644 --- a/service/neptune/api_client.go +++ b/service/neptune/api_client.go @@ -551,6 +551,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/polly/api_client.go b/service/polly/api_client.go index 24a8dbf0742..f86584d8794 100644 --- a/service/polly/api_client.go +++ b/service/polly/api_client.go @@ -551,6 +551,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/rds/api_client.go b/service/rds/api_client.go index 4ceb38aabe6..55dd540d55f 100644 --- a/service/rds/api_client.go +++ b/service/rds/api_client.go @@ -552,6 +552,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/s3/api_client.go b/service/s3/api_client.go index a3fe93b7f20..5e5f27b2d72 100644 --- a/service/s3/api_client.go +++ b/service/s3/api_client.go @@ -777,6 +777,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent") diff --git a/service/sts/api_client.go b/service/sts/api_client.go index 59cc4c70a38..369de83b8bc 100644 --- a/service/sts/api_client.go +++ b/service/sts/api_client.go @@ -552,6 +552,12 @@ func (c presignConverter) convertToPresignMiddleware(stack *middleware.Stack, op if _, ok := stack.Finalize.Get((*acceptencodingcust.DisableGzip)(nil).ID()); ok { stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) } + if _, ok := stack.Finalize.Get((*retry.Attempt)(nil).ID()); ok { + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + } + if _, ok := stack.Finalize.Get((*retry.MetricsHeader)(nil).ID()); ok { + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + } stack.Deserialize.Clear() stack.Build.Remove((*awsmiddleware.ClientRequestID)(nil).ID()) stack.Build.Remove("UserAgent")