diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/bindingTraits/HttpResponseTraitWithoutHttpPayload.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/bindingTraits/HttpResponseTraitWithoutHttpPayload.kt index 8cf4024a0..cd76dd545 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/bindingTraits/HttpResponseTraitWithoutHttpPayload.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/bindingTraits/HttpResponseTraitWithoutHttpPayload.kt @@ -50,15 +50,18 @@ class HttpResponseTraitWithoutHttpPayload( .filter { !it.member.hasTrait(HttpQueryTrait::class.java) } .toMutableSet() val streamingMember = bodyMembers.firstOrNull { it.member.targetOrSelf(ctx.model).hasTrait(StreamingTrait::class.java) } - if (streamingMember != null) { - writeStreamingMember(streamingMember) + val initialResponseMembers = bodyMembers.filter { + val targetShape = it.member.targetOrSelf(ctx.model) + targetShape?.hasTrait(StreamingTrait::class.java) == false + }.toSet() + writeStreamingMember(streamingMember, initialResponseMembers) } else if (bodyMembersWithoutQueryTrait.isNotEmpty()) { writeNonStreamingMembers(bodyMembersWithoutQueryTrait) } } - fun writeStreamingMember(streamingMember: HttpBindingDescriptor) { + fun writeStreamingMember(streamingMember: HttpBindingDescriptor, initialResponseMembers: Set) { val shape = ctx.model.expectShape(streamingMember.member.target) val symbol = ctx.symbolProvider.toSymbol(shape) val memberName = ctx.symbolProvider.toMemberName(streamingMember.member) @@ -74,6 +77,9 @@ class HttpResponseTraitWithoutHttpPayload( symbol ) writer.write("self.\$L = decoderStream.toAsyncStream()", memberName) + if (isRPCService(ctx) && initialResponseMembers.isNotEmpty()) { + writeInitialResponseMembers(initialResponseMembers) + } } writer.indent() writer.write("self.\$L = nil", memberName).closeBlock("}") @@ -133,4 +139,52 @@ class HttpResponseTraitWithoutHttpPayload( } private val path: String = "properties.".takeIf { outputShape.hasTrait() } ?: "" + + private fun writeInitialResponseMembers(initialResponseMembers: Set) { + writer.apply { + write("if let initialDataWithoutHttp = await messageDecoder.awaitInitialResponse() {") + indent() + write("let decoder = JSONDecoder()") + write("do {") + indent() + write("let response = try decoder.decode([String: String].self, from: initialDataWithoutHttp)") + initialResponseMembers.forEach { responseMember -> + val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member) + write("self.$responseMemberName = response[\"$responseMemberName\"].map { value in KinesisClientTypes.Tag(value: value) }") + } + dedent() + write("} catch {") + indent() + write("print(\"Error decoding JSON: \\(error)\")") + initialResponseMembers.forEach { responseMember -> + val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member) + write("self.$responseMemberName = nil") + } + dedent() + write("}") + dedent() + write("} else {") + indent() + initialResponseMembers.forEach { responseMember -> + val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member) + write("self.$responseMemberName = nil") + } + dedent() + write("}") + } + } + + private fun isRPCService(ctx: ProtocolGenerator.GenerationContext): Boolean { + return rpcBoundProtocols.contains(ctx.protocol.name) + } + + /** + * A set of RPC-bound Smithy protocols + */ + private val rpcBoundProtocols = setOf( + "awsJson1_0", + "awsJson1_1", + "awsQuery", + "ec2Query", + ) } diff --git a/smithy-swift-codegen/src/test/kotlin/EventStreamsInitialResponseTests.kt b/smithy-swift-codegen/src/test/kotlin/EventStreamsInitialResponseTests.kt new file mode 100644 index 000000000..ad10072bd --- /dev/null +++ b/smithy-swift-codegen/src/test/kotlin/EventStreamsInitialResponseTests.kt @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +import io.kotest.matchers.string.shouldContainOnlyOnce +import mocks.MockHttpAWSJson11ProtocolGenerator +import org.junit.jupiter.api.Test +import software.amazon.smithy.swift.codegen.integration.HttpBindingProtocolGenerator + +class EventStreamsInitialResponseTests { + @Test + fun `should attempt to decode response if initial-response members are present in RPC (awsJson) smithy model`() { + val context = setupInitialMessageTests( + "event-stream-initial-request-response.smithy", + "com.test#Example", + MockHttpAWSJson11ProtocolGenerator() + ) + val contents = getFileContents( + context.manifest, + "/InitialMessageEventStreams/models/TestStreamOperationWithInitialRequestResponseOutput+HttpResponseBinding.swift" + ) + contents.shouldSyntacticSanityCheck() + contents.shouldContainOnlyOnce( + """ + extension TestStreamOperationWithInitialRequestResponseOutput: ClientRuntime.HttpResponseBinding { + public init(httpResponse: ClientRuntime.HttpResponse, decoder: ClientRuntime.ResponseDecoder? = nil) async throws { + if case let .stream(stream) = httpResponse.body, let responseDecoder = decoder { + let messageDecoder: ClientRuntime.MessageDecoder? = nil + let decoderStream = ClientRuntime.EventStream.DefaultMessageDecoderStream(stream: stream, messageDecoder: messageDecoder, responseDecoder: responseDecoder) + self.value = decoderStream.toAsyncStream() + if let initialDataWithoutHttp = await messageDecoder.awaitInitialResponse() { + let decoder = JSONDecoder() + do { + let response = try decoder.decode([String: String].self, from: initialDataWithoutHttp) + self.initial1 = response["initial1"].map { value in KinesisClientTypes.Tag(value: value) } + self.initial2 = response["initial2"].map { value in KinesisClientTypes.Tag(value: value) } + } catch { + print("Error decoding JSON: \(error)") + self.initial1 = nil + self.initial2 = nil + } + } else { + self.initial1 = nil + self.initial2 = nil + } + } else { + self.value = nil + } + } + } + """.trimIndent() + ) + } + + private fun setupInitialMessageTests( + smithyFile: String, + serviceShapeId: String, + protocolGenerator: HttpBindingProtocolGenerator + ): TestContext { + val context = TestContext.initContextFrom(smithyFile, serviceShapeId, protocolGenerator) { model -> + model.defaultSettings(serviceShapeId, "InitialMessageEventStreams", "123", "InitialMessageEventStreams") + } + context.generator.initializeMiddleware(context.generationCtx) + context.generator.generateSerializers(context.generationCtx) + context.generator.generateProtocolClient(context.generationCtx) + context.generator.generateDeserializers(context.generationCtx) + context.generator.generateCodableConformanceForNestedTypes(context.generationCtx) + context.generationCtx.delegator.flushWriters() + return context + } +} diff --git a/smithy-swift-codegen/src/test/kotlin/mocks/MockHttpAWSJson11ProtocolGenerator.kt b/smithy-swift-codegen/src/test/kotlin/mocks/MockHttpAWSJson11ProtocolGenerator.kt index 30f5c040a..04d70b2f4 100644 --- a/smithy-swift-codegen/src/test/kotlin/mocks/MockHttpAWSJson11ProtocolGenerator.kt +++ b/smithy-swift-codegen/src/test/kotlin/mocks/MockHttpAWSJson11ProtocolGenerator.kt @@ -53,7 +53,8 @@ class MockAWSJson11HttpProtocolCustomizations() : DefaultHttpProtocolCustomizati writer: SwiftWriter, op: OperationShape, ) { - TODO("Not yet implemented") + // Not yet implemented + return } } diff --git a/smithy-swift-codegen/src/test/resources/event-stream-initial-request-response.smithy b/smithy-swift-codegen/src/test/resources/event-stream-initial-request-response.smithy new file mode 100644 index 000000000..417462aa9 --- /dev/null +++ b/smithy-swift-codegen/src/test/resources/event-stream-initial-request-response.smithy @@ -0,0 +1,39 @@ +namespace com.test + +use aws.protocols#awsJson1_1 +use aws.api#service +use aws.auth#sigv4 + +@awsJson1_1 +@sigv4(name: "event-stream-test") +@service(sdkId: "InitialMessageEventStreams") +service Example { + version: "123", + operations: [TestStreamOperationWithInitialRequestResponse] +} + +operation TestStreamOperationWithInitialRequestResponse { + input: TestStreamInputOutputInitialRequestResponse, + output: TestStreamInputOutputInitialRequestResponse, + errors: [SomeError], +} + +structure TestStreamInputOutputInitialRequestResponse { + @required + value: TestStream + initial1: String + initial2: String +} + +@error("client") +structure SomeError { + Message: String, +} + +structure MessageWithString { @eventPayload data: String } + +@streaming +union TestStream { + MessageWithString: MessageWithString, + SomeError: SomeError, +} \ No newline at end of file