Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for eventstream input & initial request to RPC-based protocols #1377

Merged
merged 12 commits into from
Mar 15, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,53 @@ final class AWSMessageEncoderStreamTests: XCTestCase {
let read4 = try await sut.readAsync(upToCount: 500)
XCTAssertNil(read4)
}

func testInitialRequestEvent() async throws {
let context = HttpContextBuilder().withSigningRegion(value: region)
.withSigningName(value: serviceName)
.withRequestSignature(value: requestSignature)
.withIdentityResolver(
value: TestCustomAWSCredentialIdentityResolver(credentials: credentials),
schemeID: "aws.auth#sigv4"
)
.withIdentityResolver(
value: TestCustomAWSCredentialIdentityResolver(credentials: credentials),
schemeID: "aws.auth#sigv4a"
)
.build()

let messageSigner = AWSEventStream.AWSMessageSigner(encoder: messageEncoder) {
return AWSSigV4Signer()
} signingConfig: {
return try await context.makeEventStreamSigningConfig()
} requestSignature: {
return context.getRequestSignature()
}

let sut = EventStream.DefaultMessageEncoderStream(
stream: baseStream,
messageEncoder: messageEncoder,
requestEncoder: JSONEncoder(),
messageSigner: messageSigner,
initialRequestMessage: EventStream.Message(
headers: [EventStream.Header(name: ":event-type", value: .string("initial-request"))],
payload: Data()
)
)

let data = try await sut.readToEndAsync()

let messageDecoder = AWSEventStream.AWSMessageDecoder()
try messageDecoder.feed(data: data ?? Data())
let initialRequestMessage = try messageDecoder.message()

let payloadDecoder = AWSEventStream.AWSMessageDecoder()
try payloadDecoder.feed(data: initialRequestMessage?.payload ?? Data())
let initialRequestPayload = try payloadDecoder.message()

XCTAssertEqual(
initialRequestPayload?.headers.first(where: { $0.name == ":event-type" })?.value,
.string("initial-request")
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ software.amazon.smithy.aws.swift.codegen.model.AWSHttpTraitTransformer
software.amazon.smithy.aws.swift.codegen.model.AWSEndpointTraitTransformer
software.amazon.smithy.aws.swift.codegen.customization.MessageEncoderIntegration
software.amazon.smithy.aws.swift.codegen.AWSClientConfigurationIntegration
software.amazon.smithy.swift.codegen.swiftintegrations.InitialRequestIntegration
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package software.amazon.smithy.aws.swift.codegen.awsjson

import io.kotest.matchers.string.shouldContainOnlyOnce
import org.junit.jupiter.api.Test
import software.amazon.smithy.aws.swift.codegen.TestContext
import software.amazon.smithy.aws.swift.codegen.TestContextGenerator
import software.amazon.smithy.aws.swift.codegen.shouldSyntacticSanityCheck
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait

class AWSJsonHttpInitialRequestTests {
@Test
fun `001 Conformance to MessageMarshallable gets generated correctly`() {
val context = setupTests(
"awsjson/initial-request.smithy",
"com.test#InitialRequestTest"
)
val contents = TestContextGenerator.getFileContents(
context.manifest,
"/Example/models/TestStream+MessageMarshallable.swift"
)
contents.shouldSyntacticSanityCheck()
val expectedContents =
"""
extension InitialRequestTestClientTypes.TestStream: ClientRuntime.MessageMarshallable {
public func marshall(encoder: ClientRuntime.RequestEncoder) throws -> ClientRuntime.EventStream.Message {
var headers: [ClientRuntime.EventStream.Header] = [.init(name: ":message-type", value: .string("event"))]
var payload: ClientRuntime.Data? = nil
switch self {
case .messagewithstring(let value):
headers.append(.init(name: ":event-type", value: .string("MessageWithString")))
headers.append(.init(name: ":content-type", value: .string("text/plain")))
payload = value.data?.data(using: .utf8)
case .sdkUnknown(_):
throw ClientRuntime.ClientError.unknownError("cannot serialize the unknown event type!")
}
return ClientRuntime.EventStream.Message(headers: headers, payload: payload ?? .init())
}
}
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun `002 EventStreamBodyMiddleware gets generated into operation stack with initialRequestMessage`() {
val context = setupTests(
"awsjson/initial-request.smithy",
"com.test#InitialRequestTest"
)
val contents = TestContextGenerator.getFileContents(
context.manifest,
"/Example/InitialRequestTestClient.swift"
)
contents.shouldSyntacticSanityCheck()
val expectedContents = """
let initialRequestMessage = try input.makeInitialRequestMessage(encoder: encoder)
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.EventStreamBodyMiddleware<EventStreamOpInput, EventStreamOpOutput, InitialRequestTestClientTypes.TestStream>(keyPath: \.eventStream, defaultBody: "{}", initialRequestMessage: initialRequestMessage))
"""
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun `003 Encodable conformance is generated for input struct with streaming union member with streaming member excluded`() {
val context = setupTests(
"awsjson/initial-request.smithy",
"com.test#InitialRequestTest"
)
val contents = TestContextGenerator.getFileContents(
context.manifest,
"/Example/models/EventStreamOpInput+Encodable.swift"
)
contents.shouldSyntacticSanityCheck()
val expectedContents = """
extension EventStreamOpInput: Swift.Encodable {
enum CodingKeys: Swift.String, Swift.CodingKey {
case inputMember1
case inputMember2
}

public func encode(to encoder: Swift.Encoder) throws {
var encodeContainer = encoder.container(keyedBy: CodingKeys.self)
if let inputMember1 = self.inputMember1 {
try encodeContainer.encode(inputMember1, forKey: .inputMember1)
}
if let inputMember2 = self.inputMember2 {
try encodeContainer.encode(inputMember2, forKey: .inputMember2)
}
}
}
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun `004 makeInitialRequestMessage method gets generated for input struct in extension`() {
val context = setupTests(
"awsjson/initial-request.smithy",
"com.test#InitialRequestTest"
)
val contents = TestContextGenerator.getFileContents(
context.manifest,
"/Example/models/EventStreamOpInput+MakeInitialRequestMessage.swift"
)
contents.shouldSyntacticSanityCheck()
val expectedContents = """
extension EventStreamOpInput {
func makeInitialRequestMessage(encoder: ClientRuntime.RequestEncoder) throws -> EventStream.Message {
let initialRequestPayload = try ClientRuntime.JSONReadWrite.documentWritingClosure(encoder: encoder)(self, JSONReadWrite.writingClosure())
let initialRequestMessage = EventStream.Message(
headers: [
EventStream.Header(name: ":message-type", value: .string("event")),
EventStream.Header(name: ":event-type", value: .string("initial-request")),
EventStream.Header(name: ":content-type", value: .string("application/x-amz-json-1.0"))
],
payload: initialRequestPayload
)
return initialRequestMessage
}
}
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}
private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context = TestContextGenerator.initContextFrom(smithyFile, serviceShapeId, AwsJson1_0Trait.ID)
AwsJson1_0_ProtocolGenerator().run {
generateMessageMarshallable(context.ctx)
generateSerializers(context.ctx)
initializeMiddleware(context.ctx)
generateProtocolClient(context.ctx)
}
context.ctx.delegator.flushWriters()
return context
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
namespace com.test

use aws.protocols#awsJson1_0
use aws.api#service
use aws.auth#sigv4

@awsJson1_0
@sigv4(name: "initial-request-test")
@service(sdkId: "InitialRequestTest")
service InitialRequestTest {
version: "03-04-2024",
operations: [EventStreamOp]
}

operation EventStreamOp {
input: StructureWithStream,
output: FillerStructure,
errors: [SomeError]
}

structure StructureWithStream {
@required
eventStream: TestStream
inputMember1: String
inputMember2: String
}

structure FillerStructure {
fillerMessage: String
}

@error("client")
structure SomeError {
Message: String,
}

structure MessageWithString { @eventPayload data: String }

@streaming
union TestStream {
MessageWithString: MessageWithString,
SomeError: SomeError,
}
Loading