Skip to content

Commit

Permalink
Refactor smithy-json and fix several protocol tests for the new JsonS…
Browse files Browse the repository at this point in the history
…erializerGenerator (#418)

* Split out a JsonValueWriter from JsonObjectWriter/JsonArrayWriter

* Add document support to JsonSerializerGenerator

* Add operation support to JsonSerializerGenerator

* Fix some bugs found by protocol tests

* Fix struct serializer function naming bug

* Fix handling of sparse lists and maps

* CR feedback
  • Loading branch information
jdisanti committed May 26, 2021
1 parent 70a3526 commit d79e80c
Show file tree
Hide file tree
Showing 7 changed files with 451 additions and 392 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ fun <T : CodeWriter> T.rust(
this.write(contents, *args)
}

/**
* Convenience wrapper that tells Intellij that the contents of this block are Rust
*/
fun <T : CodeWriter> T.rustInline(
@Language("Rust", prefix = "macro_rules! foo { () => {{ ", suffix = "}}}") contents: String,
vararg args: Any
) {
this.writeInline(contents, *args)
}

/**
* Sibling method to [rustBlock] that enables `#{variablename}` style templating
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ class HttpTraitProtocolGenerator(
payloadName: String,
serializer: StructuredDataSerializerGenerator
): BodyMetadata {
val targetShape = model.expectShape(member.target)
return when (targetShape) {
return when (val targetShape = model.expectShape(member.target)) {
// Write the raw string to the payload
is StringShape -> {
if (targetShape.hasTrait<EnumTrait>()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
* Creates a unique name for a serialization function.
*
* The prefixes will look like the following (for grep):
* - serialize_operation
* - serialize_structure
* - serialize_union
* - serialize_payload
*/
fun RustSymbolProvider.serializeFunctionName(shape: Shape): String = shapeFunctionName("serialize", shape)

private fun RustSymbolProvider.shapeFunctionName(prefix: String, shape: Shape): String {
val symbolNameSnakeCase = toSymbol(shape).name.toSnakeCase()
return prefix + "_" + when (shape) {
is OperationShape -> "operation_$symbolNameSnakeCase"
is StructureShape -> "structure_$symbolNameSnakeCase"
is UnionShape -> "union_$symbolNameSnakeCase"
is MemberShape -> "payload_${shape.target.name.toSnakeCase()}_${shape.container.name.toSnakeCase()}"
else -> TODO("SerializerFunctionNamer.name: $shape")
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.serializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.toPascalCase
import software.amazon.smithy.rust.codegen.util.toSnakeCase

class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator {
private val symbolProvider = protocolConfig.symbolProvider
Expand Down Expand Up @@ -95,7 +95,7 @@ class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : Struc
this.copy(input = "$input.${symbolProvider.toMemberName(member)}")

override fun operationSerializer(operationShape: OperationShape): RuntimeType? {
val fnName = "serialize_operation_${operationShape.id.name.toSnakeCase()}"
val fnName = symbolProvider.serializeFunctionName(operationShape)
val inputShape = operationShape.inputShape(model)
val xmlMembers = operationShape.operationXmlMembers()
if (!xmlMembers.isNotEmpty()) {
Expand Down Expand Up @@ -132,8 +132,8 @@ class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : Struc
}

override fun payloadSerializer(member: MemberShape): RuntimeType {
val fnName = symbolProvider.serializeFunctionName(member)
val target = model.expectShape(member.target, StructureShape::class.java)
val fnName = "serialize_payload_${target.id.name.toSnakeCase()}_${member.container.name.toSnakeCase()}"
return RuntimeType.forInlineFun(fnName, "xml_ser") {
val t = symbolProvider.toSymbol(member).rustType().stripOuter<RustType.Option>().render(true)
it.rustBlock(
Expand Down Expand Up @@ -274,12 +274,12 @@ class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : Struc
members: XmlMemberIndex,
ctx: Ctx.Element
) {
val fnName = "serialize_structure_${structureShape.id.name.toSnakeCase()}"
val structureSymbol = symbolProvider.toSymbol(structureShape)
val fnName = symbolProvider.serializeFunctionName(structureShape)
val structureSerializer = RuntimeType.forInlineFun(fnName, "xml_ser") {
it.rustBlockTemplate(
"pub fn $fnName(input: &#{Shape}, writer: #{ElementWriter})",
"Shape" to structureSymbol,
"pub fn $fnName(input: &#{Input}, writer: #{ElementWriter})",
"Input" to structureSymbol,
*codegenScope
) {
if (!members.isNotEmpty()) {
Expand All @@ -293,12 +293,12 @@ class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : Struc
}

private fun RustWriter.serializeUnion(unionShape: UnionShape, ctx: Ctx.Element) {
val fnName = "serialize_union_${unionShape.id.name.toSnakeCase()}"
val fnName = symbolProvider.serializeFunctionName(unionShape)
val unionSymbol = symbolProvider.toSymbol(unionShape)
val structureSerializer = RuntimeType.forInlineFun(fnName, "xml_ser") {
it.rustBlockTemplate(
"pub fn $fnName(input: &#{Shape}, writer: #{ElementWriter})",
"Shape" to unionSymbol,
"pub fn $fnName(input: &#{Input}, writer: #{ElementWriter})",
"Input" to unionSymbol,
*codegenScope
) {
rust("let mut scope_writer = writer.finish();")
Expand Down Expand Up @@ -369,10 +369,10 @@ class XmlBindingTraitSerializerGenerator(protocolConfig: ProtocolConfig) : Struc
}

private fun OperationShape.operationXmlMembers(): XmlMemberIndex {
val outputShape = this.inputShape(model)
val inputShape = this.inputShape(model)
val documentMembers =
httpIndex.getRequestBindings(this).filter { it.value.location == HttpBinding.Location.DOCUMENT }
.keys.map { outputShape.expectMember(it) }
.keys.map { inputShape.expectMember(it) }
return XmlMemberIndex.fromMembers(documentMembers)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@ class JsonSerializerGeneratorTest {
use aws.protocols#restJson1
union Choice {
map: MyMap,
list: SomeList,
s: String,
enum: FooEnum,
blob: Blob,
boolean: Boolean,
date: Timestamp,
document: Document,
enum: FooEnum,
int: Integer,
list: SomeList,
listSparse: SomeSparseList,
long: Long,
map: MyMap,
mapSparse: MySparseMap,
number: Double,
s: String,
top: Top,
blob: Blob,
document: Document,
}
@enum([{name: "FOO", value: "FOO"}])
Expand All @@ -45,10 +50,21 @@ class JsonSerializerGeneratorTest {
value: Choice,
}
@sparse
map MySparseMap {
key: String,
value: Choice,
}
list SomeList {
member: Choice
}
@sparse
list SomeSparseList {
member: Choice
}
structure Top {
choice: Choice,
field: String,
Expand All @@ -64,8 +80,8 @@ class JsonSerializerGeneratorTest {
structure OpInput {
@httpHeader("x-test")
someHeader: String,
@httpPayload
payload: Top
top: Top
}
@http(uri: "/top", method: "POST")
Expand All @@ -84,7 +100,6 @@ class JsonSerializerGeneratorTest {
)
val symbolProvider = testSymbolProvider(model)
val parserGenerator = JsonSerializerGenerator(testProtocolConfig(model))
val payloadGenerator = parserGenerator.payloadSerializer(model.lookup("test#OpInput\$payload"))
val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op"))
val documentGenerator = parserGenerator.documentSerializer()

Expand All @@ -94,20 +109,19 @@ class JsonSerializerGeneratorTest {
"""
use model::Top;
// Generate the operation/document serializers even if they're not directly tested
// ${writer.format(operationGenerator!!)}
// Generate the document serializer even though it's not tested directly
// ${writer.format(documentGenerator)}
let inp = crate::input::OpInput::builder().payload(
let input = crate::input::OpInput::builder().top(
Top::builder()
.field("hello!")
.extra(45)
.recursive(Top::builder().extra(55).build())
.build()
).build().unwrap();
let serialized = ${writer.format(payloadGenerator)}(&inp.payload.unwrap()).unwrap();
let serialized = ${writer.format(operationGenerator!!)}(&input).unwrap();
let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap();
assert_eq!(output, r#"{"field":"hello!","extra":45,"rec":[{"extra":55}]}"#);
assert_eq!(output, r#"{"top":{"field":"hello!","extra":45,"rec":[{"extra":55}]}}"#);
"""
)
}
Expand Down
Loading

0 comments on commit d79e80c

Please sign in to comment.