diff --git a/codegen-test/model/rest-json-extras.smithy b/codegen-test/model/rest-json-extras.smithy index 9d23322cb1..35387b2ec8 100644 --- a/codegen-test/model/rest-json-extras.smithy +++ b/codegen-test/model/rest-json-extras.smithy @@ -53,7 +53,7 @@ apply QueryPrecedence @httpRequestTests([ @restJson1 service RestJsonExtras { version: "2019-12-16", - operations: [StringPayload, PrimitiveIntHeader, EnumQuery, StatusResponse] + operations: [StringPayload, PrimitiveIntHeader, EnumQuery, StatusResponse, MapWithEnumKeyOp] } @http(uri: "/StringPayload", method: "POST") @@ -135,3 +135,37 @@ structure StatusOutput { @httpResponseCode field: PrimitiveInt } + +map MapWithEnumKey { + key: StringEnum, + value: String, +} + +structure MapWithEnumKeyInputOutput { + map: MapWithEnumKey, +} + +@http(uri: "/map-with-enum-key", method: "POST") +@httpRequestTests([ + { + id: "MapWithEnumKeyRequest", + uri: "/map-with-enum-key", + method: "POST", + protocol: "aws.protocols#restJson1", + body: "{\"map\":{\"enumvalue\":\"something\"}}", + params: { map: { "enumvalue": "something" } } + }, +]) +@httpResponseTests([ + { + id: "MapWithEnumKeyResponse", + protocol: "aws.protocols#restJson1", + code: 200, + body: "{\"map\":{\"enumvalue\":\"something\"}}", + params: { map: { "enumvalue": "something" } }, + }, +]) +operation MapWithEnumKeyOp { + input: MapWithEnumKeyInputOutput, + output: MapWithEnumKeyInputOutput, +} \ No newline at end of file diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index 29d9b411d9..9bce52d9b1 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -140,9 +140,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n path, dependency = CargoDependency.Serde, namespace = "serde" ) - val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde") val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde") - val Serializer = RuntimeType("Serializer", CargoDependency.Serde, namespace = "serde") val Deserializer = RuntimeType("Deserializer", CargoDependency.Serde, namespace = "serde") fun SerdeJson(path: String) = RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt index 90f0d42b0b..acbfe425fe 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt @@ -176,12 +176,6 @@ class EnumGenerator( private fun renderSerde() { writer.rustTemplate( """ - impl #{serialize} for $enumName { - fn serialize(&self, serializer: S) -> Result<::Ok, ::Error> where S: #{serializer}{ - serializer.serialize_str(self.as_str()) - } - } - impl<'de> #{deserialize}<'de> for $enumName { fn deserialize(deserializer: D) -> Result where D: #{deserializer}<'de> { let data = <&str>::deserialize(deserializer)?; @@ -189,8 +183,6 @@ class EnumGenerator( } } """, - "serializer" to RuntimeType.Serializer, - "serialize" to RuntimeType.Serialize, "deserializer" to RuntimeType.Deserializer, "deserialize" to RuntimeType.Deserialize ) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index e10dd65a40..741842f096 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -44,7 +44,7 @@ import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember -import software.amazon.smithy.rust.codegen.util.getTrait +import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.toPascalCase @@ -126,12 +126,7 @@ class Instantiator( * If the shape is optional: `Some(inner)` or `None` * otherwise: `inner` */ - private fun renderMember( - writer: RustWriter, - shape: MemberShape, - arg: Node, - ctx: Ctx - ) { + private fun renderMember(writer: RustWriter, shape: MemberShape, arg: Node, ctx: Ctx) { val target = model.expectShape(shape.target) val symbol = symbolProvider.toSymbol(shape) if (arg is NullNode) { @@ -176,28 +171,24 @@ class Instantiator( * ret * } */ - private fun renderMap( - writer: RustWriter, - shape: MapShape, - data: ObjectNode, - ctx: Ctx, - ) { - val lowercase = when (ctx.lowercaseMapKeys) { - true -> ".to_ascii_lowercase()" - else -> "" - } - if (data.members.isNotEmpty()) { + private fun renderMap(writer: RustWriter, shape: MapShape, data: ObjectNode, ctx: Ctx) { + if (data.members.isEmpty()) { + writer.write("#T::new()", RustType.HashMap.RuntimeType) + } else { writer.rustBlock("") { write("let mut ret = #T::new();", RustType.HashMap.RuntimeType) - data.members.forEach { (k, v) -> - withBlock("ret.insert(${k.value.dq()}.to_string()$lowercase,", ");") { - renderMember(this, shape.value, v, ctx) + for ((key, value) in data.members) { + withBlock("ret.insert(", ");") { + renderMember(this, shape.key, key, ctx) + when (ctx.lowercaseMapKeys) { + true -> rust(".to_ascii_lowercase(), ") + else -> rust(", ") + } + renderMember(this, shape.value, value, ctx) } } write("ret") } - } else { - writer.write("#T::new()", RustType.HashMap.RuntimeType) } } @@ -206,12 +197,7 @@ class Instantiator( * MyUnion::Variant(...) * ``` */ - private fun renderUnion( - writer: RustWriter, - shape: UnionShape, - data: ObjectNode, - ctx: Ctx - ) { + private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) { val unionSymbol = symbolProvider.toSymbol(shape) check(data.members.size == 1) val variant = data.members.iterator().next() @@ -230,12 +216,7 @@ class Instantiator( * vec![..., ..., ...] * ``` */ - private fun renderList( - writer: RustWriter, - shape: CollectionShape, - data: ArrayNode, - ctx: Ctx - ) { + private fun renderList(writer: RustWriter, shape: CollectionShape, data: ArrayNode, ctx: Ctx) { writer.withBlock("vec![", "]") { data.elements.forEach { v -> renderMember(this, shape.member, v, ctx) @@ -244,14 +225,9 @@ class Instantiator( } } - private fun renderString( - writer: RustWriter, - shape: StringShape, - arg: StringNode - ) { - val enumTrait = shape.getTrait() + private fun renderString(writer: RustWriter, shape: StringShape, arg: StringNode) { val data = writer.escape(arg.value).dq() - if (enumTrait == null) { + if (!shape.hasTrait()) { writer.rust("$data.to_string()") } else { val enumSymbol = symbolProvider.toSymbol(shape) @@ -264,12 +240,7 @@ class Instantiator( * MyStruct::builder().field_1("hello").field_2(5).build() * ``` */ - private fun renderStructure( - writer: RustWriter, - shape: StructureShape, - data: ObjectNode, - ctx: Ctx - ) { + private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) { writer.write("#T::builder()", symbolProvider.toSymbol(shape)) data.members.forEach { (key, value) -> val memberShape = shape.expectMember(key.value) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt similarity index 99% rename from codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt rename to codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt index 0a7de3333e..271aea3e76 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt @@ -36,8 +36,8 @@ import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.locatedIn import software.amazon.smithy.rust.codegen.smithy.meta +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonParserGenerator -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait @@ -198,7 +198,7 @@ class BasicAwsJsonGenerator( } override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { - val generator = SerdeJsonSerializerGenerator(protocolConfig) + val generator = JsonSerializerGenerator(protocolConfig) val serializer = generator.operationSerializer(operationShape) serializer?.also { sym -> rustTemplate( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt index 9444fd2769..7adf5b22ed 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt @@ -18,8 +18,8 @@ import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonParserGenerator -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer @@ -81,7 +81,7 @@ class RestJson(private val protocolConfig: ProtocolConfig) : Protocol { } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { - return SerdeJsonSerializerGenerator(protocolConfig) + return JsonSerializerGenerator(protocolConfig) } override fun parseGenericError(operationShape: OperationShape): RuntimeType { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGenerator.kt index f47ce2a8e5..1d22386972 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGenerator.kt @@ -38,11 +38,9 @@ import software.amazon.smithy.rust.codegen.smithy.rustType */ class CustomSerializerGenerator( private val symbolProvider: RustSymbolProvider, - private val model: Model, + model: Model, private val defaultTimestampFormat: TimestampFormatTrait.Format ) { - private val inp = "_inp" - private val ser = "_serializer" private val httpBindingIndex = HttpBindingIndex.of(model) private val runtimeConfig = symbolProvider.config().runtimeConfig @@ -51,46 +49,6 @@ class CustomSerializerGenerator( private val document = RuntimeType.Document(runtimeConfig).toSymbol().rustType() private val customShapes = setOf(instant, blob, document) - /** - * Generate a custom serialization function for [memberShape], suitable to be used - * in the serde annotation `serialize_with` (See [JsonSerializerSymbolProvider]) - * - * The returned object is a RuntimeType, which generates and creates all necessary dependencies when used. - * - * If this shape does not require custom serialization, this function returns null. - * - * For Example, for `Option` being serialized in Epoch seconds: - * To make it more readable, I've manually removed the fully qualified types. - * ```rust - * pub fn stdoptionoptioninstant_epoch_seconds_ser( - * _inp: &Option, - * _serializer: S, - * ) -> Result<::Ok, ::Error> - * where S: Serializer, { - * use Serialize; - * let el = _inp; - * el.as_ref() - * .map(|el| instant_epoch::InstantEpoch(*el)) - * .serialize(_serializer) - * } - * ``` - * - */ - - fun serializerFor(memberShape: MemberShape): RuntimeType? { - val symbol = symbolProvider.toSymbol(memberShape) - val rustType = symbol.rustType() - if (customShapes.none { rustType.contains(it) }) { - return null - } - val fnName = serializerName(rustType, memberShape, "ser") - return RuntimeType.forInlineFun(fnName, "serde_util") { writer -> - serializeFn(writer, fnName, symbol) { - serializer(rustType, memberShape) - } - } - } - /** * Generate a custom deserialization function for [memberShape], suitable to be used * in the serde annotation `serialize_with` (See [JsonSerializerSymbolProvider]) @@ -123,7 +81,7 @@ class CustomSerializerGenerator( if (customShapes.none { rustType.contains(it) }) { return null } - val fnName = serializerName(rustType, memberShape, "deser") + val fnName = deserializerName(rustType, memberShape) return RuntimeType.forInlineFun(fnName, "serde_util") { writer -> deserializeFn(writer, fnName, symbol) { deserializer(rustType, memberShape) @@ -131,54 +89,6 @@ class CustomSerializerGenerator( } } - private fun rollSer(t: RustType, memberShape: MemberShape): Writable { - return when (t) { - is RustType.Option -> writable { - withBlock("el.as_ref().map(|el|", ")") { - rollSer(t.member, memberShape)(this) - } - } - is RustType.Vec -> writable { - withBlock("el.iter().map(|el|", ").collect::>()") { - rollSer(t.member, memberShape)(this) - } - } - is RustType.HashMap -> writable { - withBlock("el.iter().map(|(k,el)|(k, ", ")).collect::<${t.namespace}::${t.name}<_, _>>()") { - rollSer(t.member, memberShape)(this) - } - } - is RustType.Box -> writable { - // TODO: this only works for exterior boxes. - withBlock("let el = el.as_ref();", "") { - rollSer(t.member, memberShape)(this) - } - } - is RustType.Reference -> writable { - rollSer(t.member, memberShape)(this) - } - else -> if (customShapes.contains(t)) { - writable { - serdeType(t, memberShape, SerdeDirection.Serialize)(this) - if (t == instant) { - write("(*el)") - } else { - write("(el)") - } - } - } else { - TODO("unsupported type $t") - } - } - } - - private fun RustWriter.serializer(t: RustType, memberShape: MemberShape) { - write("use #T;", RuntimeType.Serialize) - write("let el = $inp;") - rollSer(t, memberShape)(this) - write(".serialize($ser)") - } - /** * Generate a deserializer for the given type dynamically, eg: * ```rust @@ -197,7 +107,7 @@ class CustomSerializerGenerator( private fun RustWriter.deserializer(t: RustType, memberShape: MemberShape) { write("use #T;", RuntimeType.Deserialize) withBlock("Ok(", ")") { - serdeType(t, memberShape, SerdeDirection.Deserialize)(this) + serdeType(t, memberShape)(this) write("::deserialize(_deser)?") unrollDeser(t) } @@ -238,13 +148,11 @@ class CustomSerializerGenerator( else -> "${realType.namespace}::${realType.name}::<" } withBlock(prefix, ">") { - serdeType(realType.member, memberShape, SerdeDirection.Deserialize)(this) + serdeType(realType.member, memberShape)(this) } } - enum class SerdeDirection { Serialize, Deserialize } - - private fun serdeType(realType: RustType, memberShape: MemberShape, serdeDirection: SerdeDirection): Writable { + private fun serdeType(realType: RustType, memberShape: MemberShape): Writable { return when (realType) { instant -> writable { val format = tsFormat(memberShape) @@ -259,41 +167,20 @@ class CustomSerializerGenerator( } } blob -> writable { - if (serdeDirection == SerdeDirection.Deserialize) { - write("#T::BlobDeser", RuntimeType.BlobSerde(runtimeConfig)) - } else { - write("#T::BlobSer", RuntimeType.BlobSerde(runtimeConfig)) - } + write("#T::BlobDeser", RuntimeType.BlobSerde(runtimeConfig)) } document -> writable { - when (serdeDirection) { - SerdeDirection.Serialize -> write("#T::SerDoc", RuntimeType.DocJson) - SerdeDirection.Deserialize -> write("#T::DeserDoc", RuntimeType.DocJson) - } + write("#T::DeserDoc", RuntimeType.DocJson) } is RustType.Container -> writable { serdeContainerType(realType, memberShape) } - else -> TODO("$serdeDirection for $realType is not supported") - } - } - - /** correct argument type for the serde custom serializer */ - private fun serializerType(symbol: Symbol): Symbol { - val unref = symbol.rustType().stripOuter() - - // Convert `Vec` to `[T]` when present. This is needed to avoid - // Clippy complaining (and is also better in general). - val outType = when (unref) { - is RustType.Vec -> RustType.Slice(unref.member) - else -> unref + else -> TODO("Deserialize for $realType is not supported") } - val referenced = RustType.Reference(member = outType, lifetime = null) - return symbol.toBuilder().rustType(referenced).build() } private fun tsFormat(memberShape: MemberShape) = httpBindingIndex.determineTimestampFormat(memberShape, HttpBinding.Location.PAYLOAD, defaultTimestampFormat) - private fun serializerName(rustType: RustType, memberShape: MemberShape, suffix: String): String { + private fun deserializerName(rustType: RustType, memberShape: MemberShape): String { val context = when { rustType.contains(instant) -> tsFormat(memberShape).name.replace('-', '_').toLowerCase() else -> null @@ -301,23 +188,7 @@ class CustomSerializerGenerator( val typeToFnName = rustType.stripOuter().render(fullyQualified = true).filter { it.isLetterOrDigit() } .toLowerCase() - return listOfNotNull(typeToFnName, context, suffix).joinToString("_") - } - - private fun serializeFn( - rustWriter: RustWriter, - functionName: String, - symbol: Symbol, - body: RustWriter.() -> Unit - ) { - rustWriter.rustBlock( - "pub fn $functionName(_inp: #1T, _serializer: S) -> " + - "Result<::Ok, ::Error> where S: #2T", - serializerType(symbol), - RuntimeType.Serializer - ) { - body(this) - } + return listOfNotNull(typeToFnName, context, "deser").joinToString("_") } private fun deserializeFn( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt index c9b25b802b..131ffd86d4 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt @@ -15,15 +15,12 @@ import software.amazon.smithy.model.traits.JsonNameTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.rustlang.RustType -import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.letIf -import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait @@ -42,7 +39,7 @@ class JsonSerializerSymbolProvider( ) : SymbolMetadataProvider(base) { - data class SerdeConfig(val serialize: Boolean, val deserialize: Boolean) + data class SerdeConfig(val deserialize: Boolean) private fun MemberShape.serializedName() = this.getTrait()?.value ?: this.memberName @@ -51,17 +48,9 @@ class JsonSerializerSymbolProvider( val currentMeta = base.toSymbol(memberShape).expectRustMetadata() val serdeConfig = serdeRequired(model.expectShape(memberShape.container)) val attribs = mutableListOf() - if (serdeConfig.serialize || serdeConfig.deserialize) { + if (serdeConfig.deserialize) { attribs.add(Attribute.Custom("serde(rename = ${memberShape.serializedName().dq()})")) } - if (serdeConfig.serialize) { - if (base.toSymbol(memberShape).rustType().stripOuter() is RustType.Option) { - attribs.add(Attribute.Custom("serde(skip_serializing_if = \"Option::is_none\")")) - } - serializerBuilder.serializerFor(memberShape)?.also { - attribs.add(Attribute.Custom("serde(serialize_with = ${it.fullyQualifiedName().dq()})", listOf(it))) - } - } if (serdeConfig.deserialize) { serializerBuilder.deserializerFor(memberShape)?.also { attribs.add(Attribute.Custom("serde(deserialize_with = ${it.fullyQualifiedName().dq()})", listOf(it))) @@ -82,20 +71,18 @@ class JsonSerializerSymbolProvider( private fun containerMeta(container: Shape): RustMetadata { val currentMeta = base.toSymbol(container).expectRustMetadata() val requiredSerde = serdeRequired(container) - return currentMeta - .letIf(requiredSerde.serialize) { it.withDerives(RuntimeType.Serialize) } - .letIf(requiredSerde.deserialize) { it.withDerives(RuntimeType.Deserialize) } + return currentMeta.letIf(requiredSerde.deserialize) { it.withDerives(RuntimeType.Deserialize) } } private fun serdeRequired(shape: Shape): SerdeConfig { return when { - shape.hasTrait() -> SerdeConfig(serialize = true, deserialize = false) - shape.hasTrait() -> SerdeConfig(serialize = false, deserialize = true) + shape.hasTrait() -> SerdeConfig(deserialize = false) + shape.hasTrait() -> SerdeConfig(deserialize = true) // The bodies must be serializable. The top level inputs are _not_ - shape.hasTrait() -> SerdeConfig(serialize = false, deserialize = false) - shape.hasTrait() -> SerdeConfig(serialize = false, deserialize = false) - else -> SerdeConfig(serialize = true, deserialize = true) + shape.hasTrait() -> SerdeConfig(deserialize = false) + shape.hasTrait() -> SerdeConfig(deserialize = false) + else -> SerdeConfig(deserialize = true) } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt index 10aa15f44a..e58a2bbca1 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt @@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.serializeFunctionName import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq -import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape @@ -170,11 +169,9 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe // Don't generate an operation JSON serializer if there is no JSON body val httpBindings = httpIndex.getRequestBindings(operationShape) - val hasDocumentHttpBindings = httpBindings - .filter { it.value.location == Location.DOCUMENT } - .keys.map { inputShape.expectMember(it) } - .isNotEmpty() - if (inputShape.members().isEmpty() || httpBindings.isNotEmpty() && !hasDocumentHttpBindings) { + val httpBound = httpBindings.isNotEmpty() + val httpDocumentMembers = httpBindings.filter { it.value.location == Location.DOCUMENT }.keys + if (inputShape.members().isEmpty() || httpBound && httpDocumentMembers.isEmpty()) { return null } @@ -186,7 +183,9 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe ) { rust("let mut out = String::new();") rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) - serializeStructure(StructContext("object", "input", inputShape)) + serializeStructure(StructContext("object", "input", inputShape)) { member -> + !httpBound || httpDocumentMembers.contains(member.memberName) + } rust("object.finish();") rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) } @@ -209,7 +208,10 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe } } - private fun RustWriter.serializeStructure(context: StructContext) { + private fun RustWriter.serializeStructure( + context: StructContext, + includeMember: (MemberShape) -> Boolean = { true } + ) { val fnName = symbolProvider.serializeFunctionName(context.shape) val structureSymbol = symbolProvider.toSymbol(context.shape) val structureSerializer = RuntimeType.forInlineFun(fnName, "json_ser") { writer -> @@ -219,10 +221,11 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe *codegenScope, ) { context.copy(objectName = "object", localName = "input").also { inner -> - if (inner.shape.members().isEmpty()) { + val members = inner.shape.members().filter(includeMember) + if (members.isEmpty()) { rust("let (_, _) = (object, input);") // Suppress unused argument warnings } - for (member in inner.shape.members()) { + for (member in members) { serializeMember(MemberContext.structMember(inner, member, symbolProvider)) } } @@ -265,7 +268,10 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe is RustType.Integer -> "NegInt" else -> throw IllegalStateException("unreachable") } - rust("$writer.number(#T::$numberType((${value.asValue()}).into()));", smithyTypes.member("Number")) + rust( + "$writer.number(##[allow(clippy::useless_conversion)]#T::$numberType((${value.asValue()}).into()));", + smithyTypes.member("Number") + ) } is BlobShape -> rust( "$writer.string_unchecked(&#T(${value.name}));", @@ -321,7 +327,12 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe val keyName = safeName("key") val valueName = safeName("value") rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { - serializeMember(MemberContext.mapMember(context, keyName, valueName)) + val keyTarget = model.expectShape(context.shape.key.target) + val keyExpression = when (keyTarget.hasTrait()) { + true -> "$keyName.as_str()" + else -> keyName + } + serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt deleted file mode 100644 index cc5853fd06..0000000000 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -package software.amazon.smithy.rust.codegen.smithy.protocols.parsers - -import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.rustlang.withBlock -import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait -import software.amazon.smithy.rust.codegen.util.expectTrait -import software.amazon.smithy.rust.codegen.util.inputShape -import software.amazon.smithy.rust.codegen.util.toSnakeCase - -class SerdeJsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { - private val model = protocolConfig.model - private val symbolProvider = protocolConfig.symbolProvider - private val runtimeConfig = protocolConfig.runtimeConfig - private val serializerError = RuntimeType.SerdeJson("error::Error") - private val codegenScope = arrayOf( - "Error" to serializerError, - "serde_json" to RuntimeType.serdeJson, - "SdkBody" to RuntimeType.sdkBody(runtimeConfig) - ) - - override fun payloadSerializer(member: MemberShape): RuntimeType { - val target = model.expectShape(member.target) - val fnName = "serialize_payload_${target.id.name.toSnakeCase()}_${member.container.name.toSnakeCase()}" - return RuntimeType.forInlineFun(fnName, "operation_ser") { - it.rustTemplate( - """ - pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}> { - #{serde_json}::to_vec(&input).map(#{SdkBody}::from) - } - """, - *codegenScope, "target" to symbolProvider.toSymbol(target) - ) - } - } - - override fun operationSerializer(operationShape: OperationShape): RuntimeType? { - // Currently, JSON shapes are serialized via a synthetic body structure that gets generated during model - // transformation - val inputShape = operationShape.inputShape(model) - val inputBody = inputShape.expectTrait().body?.let { - model.expectShape( - it, - StructureShape::class.java - ) - } ?: return null - val fnName = "serialize_operation_${inputBody.id.name.toSnakeCase()}" - return RuntimeType.forInlineFun(fnName, "operation_ser") { - it.rustBlockTemplate( - "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", - *codegenScope, "target" to symbolProvider.toSymbol(inputShape) - ) { - // copy the input (via references) into the synthetic body: - withBlock("let body = ", ";") { - rustBlock("#T", symbolProvider.toSymbol(inputBody)) { - for (member in inputBody.members()) { - val name = symbolProvider.toMemberName(member) - write("$name: &input.$name,") - } - } - } - rustTemplate( - """#{serde_json}::to_vec(&body).map(#{SdkBody}::from)""", - *codegenScope - ) - } - } - } - - override fun documentSerializer(): RuntimeType { - val fnName = "serialize_document" - return RuntimeType.forInlineFun(fnName, "operation_ser") { - it.rustTemplate( - """ - pub fn $fnName(input: &#{Document}) -> Result<#{SdkBody}, #{Error}> { - #{serde_json}::to_vec(&#{doc_json}::SerDoc(&input)).map(#{SdkBody}::from) - } - - """, - "Document" to RuntimeType.Document(runtimeConfig), "doc_json" to RuntimeType.DocJson, *codegenScope - ) - } - } -} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGeneratorTest.kt index acba640e72..29c4b855ef 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/CustomSerializerGeneratorTest.kt @@ -54,8 +54,6 @@ internal class CustomSerializerGeneratorTest { @Test fun `generate correct function names`() { val serializerBuilder = CustomSerializerGenerator(provider, model, TimestampFormatTrait.Format.EPOCH_SECONDS) - serializerBuilder.serializerFor(model.lookup("test#S\$timestamp"))!!.name shouldBe "stdoptionoptionsmithytypesinstant_epoch_seconds_ser" - serializerBuilder.serializerFor(model.lookup("test#S\$blob"))!!.name shouldBe "stdoptionoptionsmithytypesblob_ser" serializerBuilder.deserializerFor(model.lookup("test#S\$blob"))!!.name shouldBe "stdoptionoptionsmithytypesblob_deser" serializerBuilder.deserializerFor(model.lookup("test#S\$string")) shouldBe null } @@ -66,12 +64,6 @@ internal class CustomSerializerGeneratorTest { checkSymbol(symbol) } - private fun checkSerializer(builder: CustomSerializerGenerator, shapeId: String) { - val symbol = builder.serializerFor(model.lookup(shapeId)) - check(symbol != null) { "For $shapeId, expected a custom serializer" } - checkSymbol(symbol) - } - private fun checkSymbol(symbol: RuntimeType) { val writer = TestWorkspace.testProject(provider) writer.lib { @@ -100,14 +92,12 @@ internal class CustomSerializerGeneratorTest { fun `generate basic deserializers that compile`(memberName: String) { val serializerBuilder = CustomSerializerGenerator(provider, model, TimestampFormatTrait.Format.EPOCH_SECONDS) checkDeserializer(serializerBuilder, "test#S\$$memberName") - checkSerializer(serializerBuilder, "test#S\$$memberName") } @Test fun `support deeply nested structures`() { val serializerBuilder = CustomSerializerGenerator(provider, model, TimestampFormatTrait.Format.EPOCH_SECONDS) checkDeserializer(serializerBuilder, "test#TopLevel\$member") - checkSerializer(serializerBuilder, "test#TopLevel\$member") } @Test @@ -118,7 +108,6 @@ internal class CustomSerializerGeneratorTest { } } val serializerBuilder = CustomSerializerGenerator(boxingProvider, model, TimestampFormatTrait.Format.EPOCH_SECONDS) - checkSerializer(serializerBuilder, "test#S\$timestamp") checkDeserializer(serializerBuilder, "test#S\$timestamp") } }