Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: S3 200 error customization should only not apply to streaming + blob shapes #1633

Merged
merged 16 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@ import AwsCommonRuntimeKit
import ClientRuntime

public class MockHttpClientEngine: HTTPClient {
private let errorResponsePayload: String

// Public initializer
public init() {}
public init(response: String) {
self.errorResponsePayload = response
}

func successHttpResponse(request: SmithyHTTPAPI.HTTPRequest) -> HTTPResponse {
let errorResponsePayload = """
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"""
request.withHeader(name: "Date", value: "Wed, 21 Oct 2015 07:28:00 GMT")
return HTTPResponse(
headers: request.headers,
body: ByteStream.data(errorResponsePayload.data(using: .utf8)),
body: ByteStream.data(self.errorResponsePayload.data(using: .utf8)),
statusCode: .ok
)
}
Expand All @@ -44,25 +39,101 @@ public class MockHttpClientEngine: HTTPClient {

class S3ErrorIn200Test: XCTestCase {

let errorInternalErrorResponsePayload = """
<Error>
<Code>InternalError</Code>
<Message>We encountered an internal error. Please try again.</Message>
<RequestId>656c76696e6727732072657175657374</RequestId>
<HostId>Uuag1LuByRx9e6j5Onimru9pO4ZVKnJ2Qz7/C1NPcfTWAtRPfTaOFg==</HostId>
</Error>
"""

let errorSlowDownResponsePayload = """
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"""

let shouldNotApplyResponsePayload = """
<DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Deleted>
<Key>sample1.txt</Key>
</Deleted>
<Error>
<Key>sample2.txt</Key>
<Code>AccessDenied</Code>
<Message>Access Denied</Message>
</Error>
</DeleteResult>
"""

override class func setUp() {
AwsCommonRuntimeKit.CommonRuntimeKit.initialize()
}

/// S3Client throws expected error in response (200) with <Error> tag
func test_foundExpectedError() async throws {
/// S3Client throws expected InternalError error in response (200) with <Error> tag
func test_foundInternalErrorExpectedError() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine()
config.httpClientEngine = MockHttpClientEngine(response: errorInternalErrorResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a stream
// any method on S3Client where the output shape doesnt have a blob stream
_ = try await client.listBuckets(input: .init())
XCTFail("Expected an error to be thrown, but it was not.")
} catch let error as UnknownAWSHTTPServiceError {
// check for the error we added in our mock client
XCTAssertEqual("InternalError", error.typeName)
XCTAssertEqual("We encountered an internal error. Please try again.", error.message)
} catch {
XCTFail("Unexpected error: \(error)")
}
}

/// S3Client throws expected SlowDown error in response (200) with <Error> tag
func test_foundSlowDownExpectedError() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine(response: errorSlowDownResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a blob stream
_ = try await client.listBuckets(input: .init())
XCTFail("Expected an error to be thrown, but it was not.")
} catch let error as UnknownAWSHTTPServiceError {
// check for the error we added in our mock client
XCTAssertEqual("SlowDown", error.typeName)
XCTAssertEqual("Please reduce your request rate.", error.message)
} catch {
XCTFail("Unexpected error: \(error)")
}
}

/// S3Client does not throw error when <Error> is not at the root
func test_noErrorExpected() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine(response: shouldNotApplyResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a stream
let result = try await client.deleteObjects(input: .init(delete: .init(objects: [.init(key: "test")])))

// Check results
XCTAssertEqual(result.deleted?.count, 1)
XCTAssertEqual(result.errors?.count, 1)

let actualDeleted = result.deleted?.first
XCTAssertEqual(actualDeleted?.key, "sample1.txt")

let actualError = result.errors?.first
XCTAssertEqual(actualError?.code, "AccessDenied")
XCTAssertEqual(actualError?.key, "sample2.txt")
} catch let error {
XCTFail("Expected success, but received \(error).")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ class S3EventStreamTests: S3XCTestCase {
outputSerialization: S3ClientTypes.OutputSerialization(json: S3ClientTypes.JSONOutput())
))

let outputStream = result.payload
guard let outputStream = result.payload else {
XCTFail("result.payload is nil")
return
}

var actualOutput = ""

for try await event in outputStream! {
for try await event in outputStream {
switch event {
case .records(let record):
actualOutput = actualOutput + (String(data: record.payload ?? Data(), encoding: .utf8) ?? "")
Expand Down
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ extension Target.Dependency {
static var smithyWaitersAPI: Self { .product(name: "SmithyWaitersAPI", package: "smithy-swift") }
static var smithyTestUtils: Self { .product(name: "SmithyTestUtil", package: "smithy-swift") }
static var smithyStreams: Self { .product(name: "SmithyStreams", package: "smithy-swift") }
static var smithyXML: Self { .product(name: "SmithyXML", package: "smithy-swift") }
}

// MARK: - Base Package
Expand Down Expand Up @@ -79,6 +80,7 @@ let package = Package(
.smithyRetries,
.smithyEventStreamsAPI,
.smithyEventStreamsAuthAPI,
.smithyXML,
.awsSDKCommon,
.awsSDKHTTPAuth,
.awsSDKIdentity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,25 @@
// SPDX-License-Identifier: Apache-2.0
//

import enum Smithy.ByteStream
import class Smithy.Context
import ClientRuntime
import SmithyHTTPAPI
import SmithyXML
import struct Foundation.Data
import SmithyStreams

public struct AWSS3ErrorWith200StatusXMLMiddleware<OperationStackInput, OperationStackOutput> {
public let id: String = "AWSS3ErrorWith200StatusXMLMiddleware"
private let errorStatusCode: HTTPStatusCode = .internalServerError

public init() {}

private func isErrorWith200Status(response: HTTPResponse) async throws -> Bool {
// Check if the status code is OK (200)
guard response.statusCode == .ok else {
return false
}

// Check if the response body contains an XML Error
guard let data = try await response.body.readData() else {
return false
}
private func isRootErrorElement(data: Data) throws -> Bool {
let reader = try Reader.from(data: data)

response.body = .data(data)
let xmlString = String(decoding: data, as: UTF8.self)
return xmlString.contains("<Error>")
// Check if there's an "Error" node at the root of the XML response
return reader.nodeInfo.name == "Error"
jbelkins marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand All @@ -40,9 +35,41 @@ extension AWSS3ErrorWith200StatusXMLMiddleware: HttpInterceptor {
context: some MutableResponse<Self.InputType, Self.RequestType, Self.ResponseType>
) async throws {
let response = context.getResponse()
if try await isErrorWith200Status(response: response) {
response.statusCode = errorStatusCode
context.updateResponse(updated: response)

// Check if the status code is OK (200)
guard response.statusCode == .ok else {
return
}

guard let data = try await response.body.readData() else {
return
}

let statusCode = try isRootErrorElement(data: data) ? errorStatusCode : response.statusCode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know we will have enough data here to determine if the body returns an error?
readData() could return zero bytes if the data is not yet available, and the stream could be filled with error data after this step is complete.

Copy link
Contributor Author

@dayaffe dayaffe Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't get all the messages back all at once, each response is a separate action so it should be checking each streamed message for Error at root


// For event streams the body needs to be copied as buffered streams are non-seekable
let updatedBody = response.body.copy(data: data)

let updatedResponse = response.copy(
body: updatedBody,
statusCode: statusCode
)
sichanyoo marked this conversation as resolved.
Show resolved Hide resolved

context.updateResponse(updated: updatedResponse)
}
}

extension ByteStream {

// Copy an existing ByteStream, optionally with new data
public func copy(data: Data?) -> ByteStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be public? Is it used anywhere other than in the middleware above?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can this param be non-nil Data? In the code above, data is safe-unwrapped before this function is called

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure I can make it private and non-nil unless we want to make it generally available and move it to a place like SmithyStreams or ClientRuntime. Only issue with that is neither has both ByteStream and the various types of streams

switch self {
case .data(let existingData):
return .data(data ?? existingData)
case .stream(let existingStream):
return .stream(data != nil ? BufferedStream(data: data, isClosed: true) : existingStream)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the copied stream be created un-closed? At this point we don't know if there will be more data in the stream in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If its created unclosed then the AsyncThrowingStream will never complete when looping through it. It yields each event until an end event and then hangs forever.

case .noStream:
return .noStream
}
}
}
1 change: 1 addition & 0 deletions Sources/Services/AWSS3/Sources/AWSS3/S3Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7604,6 +7604,7 @@ extension S3Client {
builder.selectAuthScheme(ClientRuntime.AuthSchemeMiddleware<SelectObjectContentOutput>())
builder.interceptors.add(AWSClientRuntime.AmzSdkInvocationIdMiddleware<SelectObjectContentInput, SelectObjectContentOutput>())
builder.interceptors.add(AWSClientRuntime.AmzSdkRequestMiddleware<SelectObjectContentInput, SelectObjectContentOutput>(maxRetries: config.retryStrategyOptions.maxRetriesBase))
builder.interceptors.add(AWSClientRuntime.AWSS3ErrorWith200StatusXMLMiddleware<SelectObjectContentInput, SelectObjectContentOutput>())
var metricsAttributes = Smithy.Attributes()
metricsAttributes.set(key: ClientRuntime.OrchestratorMetricsAttributesKeys.service, value: "S3")
metricsAttributes.set(key: ClientRuntime.OrchestratorMetricsAttributesKeys.method, value: "SelectObjectContent")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import software.amazon.smithy.aws.swift.codegen.swiftmodules.AWSClientRuntimeTyp
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.swift.codegen.SwiftSettings
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
Expand All @@ -17,7 +16,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.Mid
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.isStreaming

/**
* Register interceptor to handle S3 error responses returned with an HTTP 200 status code.
Expand All @@ -39,10 +38,14 @@ class S3ErrorWith200StatusIntegration : SwiftIntegration {
// Instead of playing whack-a-mole broadly apply this interceptor to everything but streaming responses
// which adds a small amount of overhead to response processing.
val output = ctx.model.expectShape(operationShape.output.get())
val outputIsNotStreaming = output.members().none {
it.hasTrait<StreamingTrait>() || ctx.model.expectShape(it.target).hasTrait<StreamingTrait>()
val outputIsNotAStreamingBlobShape = output.members().none {
val targetShape = ctx.model.expectShape(it.target)
val isBlob = it.isBlobShape || targetShape.isBlobShape
val isStreaming = it.isStreaming || targetShape.isStreaming
isBlob && isStreaming
}
if (outputIsNotStreaming) {

if (outputIsNotAStreamingBlobShape) {
operationMiddleware.appendMiddleware(operationShape, S3HandleError200ResponseMiddleware)
}
}
Expand Down
Loading