Skip to content

Commit

Permalink
feat: Add eventstream input support for protocols that use XML for re…
Browse files Browse the repository at this point in the history
…quests (#1404)

* Codegen test setup for restXml x Eventstreaming operation.

* Add integration test for restXml x eventstream output, using S3:SelectObjectContent API. This test will fail until XML deserialization gets fixed to handle empty payload. The test is confirmed to pass after modifying EndEvent's readingClosure to check whether the type has any properties and returning a new instance instead of nil if the struct has no properties.

* Create a directory named message, and move MessageMarshallableGenerators & MessageUnmarshallableGenerators to it for organization.

* Add XMLMessageMarshallableGenerator that is virtually the same as MessageMarshallableGenerator, but uses closures for serialization & generates static closure variable named marshal instead of conformance to MessageMarshallable for the streaming union shapes.

* Add XMLMessageMarshallableGenerator usage to RestXmlProtocolGenerator, replacing the call to MessageMarshallableGenerator.renderNotImplemented that printed error comment that said it's not implemented yet.

* Fix typo on codegen.

* Update codegen tests. Add codegen tests that check codegen for XML eventstream input.

* Update tests for message encoder stream, now that it takes marshalClosure as argument instead of requestEncoder during initialization.

* Fix ktlint.

* Update expected output in AWS JSON protocol codegen test to have both marshalClosure and initialRequest arguments passed to EventStreambodyMiddleware construction.

* Update runtime test to reflect change in DefaultMessageEncoderStream.

---------

Co-authored-by: Sichan Yoo <chanyoo@amazon.com>
  • Loading branch information
sichanyoo and Sichan Yoo authored Mar 15, 2024
1 parent bfc3352 commit 02cc6cb
Show file tree
Hide file tree
Showing 17 changed files with 478 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation
import AWSS3
import XCTest

class S3EventStreamTests: S3XCTestCase {
private let objectKey = "integ-test-json-object"

override func setUp() async throws {
// Create S3 client & unique bucket with UUID.
try await super.setUp()
// Put a JSON object to the bucket.
try await putObject(
body: """
{ "Rules": [ {"id": "1"}, {"expr": "y > x"}, {"id": "2", "expr": "z = DEBUG"} ]}
{ "created": "June 27", "modified": "July 6" }
""",
key: objectKey
)
}

// Tests event stream output in restXml protocol using S3::SelectObjectContent.
func testEventStreamOutput() async throws {
let result = try await client.selectObjectContent(input: SelectObjectContentInput(
bucket: bucketName,
// Gets the two ID objects from the S3 object content.
expression: "SELECT id FROM S3Object[*].Rules[*].id WHERE id IS NOT MISSING",
expressionType: .sql,
inputSerialization: S3ClientTypes.InputSerialization(
json: S3ClientTypes.JSONInput(type: .document)
),
key: objectKey,
outputSerialization: S3ClientTypes.OutputSerialization(json: S3ClientTypes.JSONOutput())
))

let outputStream = result.payload
var actualOutput = ""

for try await event in outputStream! {
switch event {
case .records(let record):
actualOutput = actualOutput + (String(data: record.payload ?? Data(), encoding: .utf8) ?? "")
case .stats, .end:
continue
case .sdkUnknown(let data):
XCTFail(data)
default:
XCTFail("Encountered an unexpected event in output stream.")
}
}

// Check returned record event's payload was successfully received.
let expectedOutput = "{\"id\":\"1\"}\n{\"id\":\"2\"}\n"
XCTAssertEqual(expectedOutput, actualOutput)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ final class AWSMessageEncoderStreamTests: XCTestCase {
let sut = EventStream.DefaultMessageEncoderStream(
stream: baseStream,
messageEncoder: messageEncoder,
requestEncoder: JSONEncoder(),
messageSigner: messageSigner
marshalClosure: jsonMarshalClosure(requestEncoder: JSONEncoder()),
messageSigner: messageSigner,
initialRequestMessage: nil
)

var actual: [Data] = []
Expand Down Expand Up @@ -94,8 +95,9 @@ final class AWSMessageEncoderStreamTests: XCTestCase {
let sut = EventStream.DefaultMessageEncoderStream(
stream: baseStream,
messageEncoder: messageEncoder,
requestEncoder: JSONEncoder(),
messageSigner: messageSigner
marshalClosure: jsonMarshalClosure(requestEncoder: JSONEncoder()),
messageSigner: messageSigner,
initialRequestMessage: nil
)

let read1 = try await sut.readAsync(upToCount: 100)
Expand Down Expand Up @@ -136,7 +138,7 @@ final class AWSMessageEncoderStreamTests: XCTestCase {
let sut = EventStream.DefaultMessageEncoderStream(
stream: baseStream,
messageEncoder: messageEncoder,
requestEncoder: JSONEncoder(),
marshalClosure: jsonMarshalClosure(requestEncoder: JSONEncoder()),
messageSigner: messageSigner,
initialRequestMessage: EventStream.Message(
headers: [EventStream.Header(name: ":event-type", value: .string("initial-request"))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package software.amazon.smithy.aws.swift.codegen

import software.amazon.smithy.aws.swift.codegen.message.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.middleware.OperationEndpointResolverMiddleware
import software.amazon.smithy.aws.swift.codegen.middleware.UserAgentMiddleware
import software.amazon.smithy.codegen.core.Symbol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ package software.amazon.smithy.aws.swift.codegen.awsjson

import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.AWSHttpProtocolClientCustomizableFactory
import software.amazon.smithy.aws.swift.codegen.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.middleware.AWSXAmzTargetMiddleware
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.model.shapes.OperationShape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ package software.amazon.smithy.aws.swift.codegen.awsjson

import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.AWSHttpProtocolClientCustomizableFactory
import software.amazon.smithy.aws.swift.codegen.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.middleware.AWSXAmzTargetMiddleware
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.model.shapes.OperationShape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ package software.amazon.smithy.aws.swift.codegen.awsquery
import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.AWSHttpProtocolClientCustomizableFactory
import software.amazon.smithy.aws.swift.codegen.FormURLHttpBindingResolver
import software.amazon.smithy.aws.swift.codegen.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.ec2query.httpResponse.AWSQueryHttpResponseBindingErrorGenerator
import software.amazon.smithy.aws.swift.codegen.message.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.restxml.AWSXMLHttpResponseBindingErrorInitGeneratorFactory
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
import software.amazon.smithy.model.shapes.MemberShape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ package software.amazon.smithy.aws.swift.codegen.ec2query
import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.AWSHttpProtocolClientCustomizableFactory
import software.amazon.smithy.aws.swift.codegen.FormURLHttpBindingResolver
import software.amazon.smithy.aws.swift.codegen.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.ec2query.httpResponse.AWSEc2QueryHttpResponseBindingErrorGenerator
import software.amazon.smithy.aws.swift.codegen.ec2query.httpResponse.AWSEc2QueryHttpResponseBindingErrorInitGeneratorFactory
import software.amazon.smithy.aws.swift.codegen.message.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package software.amazon.smithy.aws.swift.codegen
package software.amazon.smithy.aws.swift.codegen.message

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package software.amazon.smithy.aws.swift.codegen
package software.amazon.smithy.aws.swift.codegen.message

import software.amazon.smithy.aws.swift.codegen.AWSClientRuntimeTypes
import software.amazon.smithy.aws.swift.codegen.AWSSwiftDependency
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package software.amazon.smithy.aws.swift.codegen.message

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EventHeaderTrait
import software.amazon.smithy.model.traits.EventPayloadTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftDependency
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.DocumentWritingClosureUtils
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.WritingClosureUtils
import software.amazon.smithy.swift.codegen.model.eventStreamEvents
import software.amazon.smithy.swift.codegen.model.hasTrait

class XMLMessageMarshallableGenerator(
private val ctx: ProtocolGenerator.GenerationContext,
private val payloadContentType: String
) {
internal fun render(streamShape: UnionShape) {
val streamSymbol: Symbol = ctx.symbolProvider.toSymbol(streamShape)
val rootNamespace = ctx.settings.moduleName
val streamMember = Symbol.builder()
.definitionFile("./$rootNamespace/models/${streamSymbol.name}+MessageMarshallable.swift")
.name(streamSymbol.name)
.build()
ctx.delegator.useShapeWriter(streamMember) { writer ->
writer.apply {
addImport(SwiftDependency.CLIENT_RUNTIME.target)
openBlock("extension \$L {", "}", streamSymbol.fullName) {
openBlock(
"static var marshal: \$N<\$N> {", "}",
ClientRuntimeTypes.EventStream.MarshalClosure,
streamSymbol
) {
openBlock("{ (self) in", "}") {
write(
"var headers: [\$N] = [.init(name: \":message-type\", value: .string(\"event\"))]",
ClientRuntimeTypes.EventStream.Header
)
write("var payload: \$D", ClientRuntimeTypes.Core.Data)
write("switch self {")
streamShape.eventStreamEvents(ctx.model).forEach { member ->
val memberName = ctx.symbolProvider.toMemberName(member)
write("case \$L(let value):", memberName)
indent()
addStringHeader(":event-type", member.memberName)
val variant = ctx.model.expectShape(member.target)
val eventHeaderBindings = variant.members().filter {
it.hasTrait<EventHeaderTrait>()
}
val eventPayloadBinding = variant.members().firstOrNull {
it.hasTrait<EventPayloadTrait>()
}
val unbound = variant.members().filterNot {
it.hasTrait<EventHeaderTrait>() || it.hasTrait<EventPayloadTrait>()
}

eventHeaderBindings.forEach {
renderSerializeEventHeader(ctx, it, writer)
}

when {
eventPayloadBinding != null -> renderSerializeEventPayload(ctx, eventPayloadBinding, writer)
unbound.isNotEmpty() -> {
writer.addStringHeader(":content-type", payloadContentType)
renderPayloadSerialization(ctx, writer, variant)
}
}
writer.dedent()
}
writer.write("case .sdkUnknown(_):")
writer.indent()
writer.write(
"throw \$N(\"cannot serialize the unknown event type!\")",
ClientRuntimeTypes.Core.UnknownClientError
)
writer.dedent()
writer.write("}")
writer.write(
"return \$N(headers: headers, payload: payload ?? .init())",
ClientRuntimeTypes.EventStream.Message
)
}
}
}
}
}
}

private fun renderSerializeEventPayload(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) {
val target = ctx.model.expectShape(member.target)
val memberName = ctx.symbolProvider.toMemberName(member)
when (target.type) {
ShapeType.BLOB -> {
writer.addStringHeader(":content-type", "application/octet-stream")
writer.write("payload = value.\$L", memberName)
}
ShapeType.STRING -> {
writer.addStringHeader(":content-type", "text/plain")
writer.write("payload = value.\$L?.data(using: .utf8)", memberName)
}
ShapeType.STRUCTURE, ShapeType.UNION -> {
writer.addStringHeader(":content-type", payloadContentType)
renderPayloadSerialization(ctx, writer, target)
}
else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member")
}
}

/**
*
* if let headerValue = value.blob {
* headers.append(.init(name: "blob", value: .byteArray(headerValue)))
* }
* if let headerValue = value.boolean {
* headers.append(.init(name: "boolean", value: .bool(headerValue)))
* }
* if let headerValue = value.byte {
* headers.append(.init(name: "byte", value: .byte(headerValue)))
* }
* if let headerValue = value.int {
* headers.append(.init(name: "int", value: .int32(Int32(headerValue))))
* }
* if let headerValue = value.long {
* headers.append(.init(name: "long", value: .int64(Int64(headerValue))))
* }
* if let headerValue = value.short {
* headers.append(.init(name: "short", value: .int16(headerValue)))
* }
* if let headerValue = value.string {
* headers.append(.init(name: "string", value: .string(headerValue)))
* }
* if let headerValue = value.timestamp {
* headers.append(.init(name: "timestamp", value: .timestamp(headerValue)))
* }
*/
private fun renderSerializeEventHeader(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) {
val target = ctx.model.expectShape(member.target)
val headerValue = when (target.type) {
ShapeType.BOOLEAN -> "bool"
ShapeType.BYTE -> "byte"
ShapeType.SHORT -> "int16"
ShapeType.INTEGER -> "int32"
ShapeType.LONG -> "int64"
ShapeType.BLOB -> "byteArray"
ShapeType.STRING -> "string"
ShapeType.TIMESTAMP -> "timestamp"
else -> throw CodegenException("unsupported shape type `${target.type}` for eventHeader member `$member`; target: $target")
}

val memberName = ctx.symbolProvider.toMemberName(member)
writer.openBlock("if let headerValue = value.\$L {", "}", memberName) {
when (target.type) {
ShapeType.INTEGER -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int32(headerValue))))", headerValue)
}
ShapeType.LONG -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int64(headerValue))))", headerValue)
}
else -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(headerValue)))", headerValue)
}
}
}
}

private fun SwiftWriter.addStringHeader(name: String, value: String) {
write("headers.append(.init(name: \$S, value: .string(\$S)))", name, value)
}

private fun renderPayloadSerialization(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter, shape: Shape) {
// get a payload serializer for the given members of the variant
val documentWritingClosure = DocumentWritingClosureUtils(ctx, writer).closure(shape)
val valueWritingClosure = WritingClosureUtils(ctx, writer).writingClosure(shape)
writer.write("payload = try \$L(value, \$L)", documentWritingClosure, valueWritingClosure)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package software.amazon.smithy.aws.swift.codegen
package software.amazon.smithy.aws.swift.codegen.message

import software.amazon.smithy.aws.swift.codegen.AWSClientRuntimeTypes
import software.amazon.smithy.aws.swift.codegen.AWSSwiftDependency
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ package software.amazon.smithy.aws.swift.codegen.restjson

import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.AWSHttpProtocolClientCustomizableFactory
import software.amazon.smithy.aws.swift.codegen.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.MessageUnmarshallableGenerator
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.TimestampFormatTrait
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
package software.amazon.smithy.aws.swift.codegen.restxml

import software.amazon.smithy.aws.swift.codegen.AWSHttpBindingProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.XMLMessageMarshallableGenerator
import software.amazon.smithy.aws.swift.codegen.message.XMLMessageUnmarshallableGenerator
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
Expand Down Expand Up @@ -62,6 +63,14 @@ class RestXmlProtocolGenerator : AWSHttpBindingProtocolGenerator() {
}
}

override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) {
var streamingShapes = inputStreamingShapes(ctx)
val messageMarshallableGenerator = XMLMessageMarshallableGenerator(ctx, defaultContentType)
streamingShapes.forEach { streamingMember ->
messageMarshallableGenerator.render(streamingMember)
}
}

override fun generateDeserializers(ctx: ProtocolGenerator.GenerationContext) {
super.generateDeserializers(ctx)
val errorShapes = resolveErrorShapes(ctx)
Expand Down
Loading

0 comments on commit 02cc6cb

Please sign in to comment.