From f022ca2cc373c02f4f5ca6925965ccfdfe9d36d1 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Mon, 5 Sep 2022 17:24:11 +0100 Subject: [PATCH 01/15] Make Instantiator generate default values for required field on demand --- .../codegen/smithy/generators/Instantiator.kt | 76 +++++++++++++++++-- .../smithy/generators/InstantiatorTest.kt | 39 ++++++++++ 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index db0d4f9f1f..2e068d506d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.smithy.generators +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node @@ -68,13 +69,19 @@ class Instantiator( val streaming: Boolean, // Whether we are instantiating with a Builder, in which case all setters take Option val builder: Boolean, + // If a given required field is missing, try to provide a default value. + val defaultsForRequiredFields: Boolean, ) + companion object { + fun defaultContext() = Ctx(lowercaseMapKeys = false, streaming = false, builder = false, defaultsForRequiredFields = false) + } + fun render( writer: RustWriter, shape: Shape, arg: Node, - ctx: Ctx = Ctx(lowercaseMapKeys = false, streaming = false, builder = false), + ctx: Ctx = defaultContext(), ) { when (shape) { // Compound Shapes @@ -222,14 +229,23 @@ class Instantiator( */ private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) { val unionSymbol = symbolProvider.toSymbol(shape) - check(data.members.size == 1) - val variant = data.members.iterator().next() - val memberName = variant.key.value + + val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) { + val (name, memberShape) = shape.allMembers.entries.first() + val shape = model.expectShape(memberShape.target) + Node.from(name) to fillDefaultValue(shape) + } else { + check(data.members.size == 1) + val entry = data.members.iterator().next() + entry.key to entry.value + } + + val memberName = variant.first.value val member = shape.expectMember(memberName) writer.write("#T::${symbolProvider.toMemberName(member)}", unionSymbol) // unions should specify exactly one member writer.withBlock("(", ")") { - renderMember(this, member, variant.value, ctx) + renderMember(this, member, variant.second, ctx) } } @@ -267,16 +283,60 @@ class Instantiator( * ``` */ private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) { - writer.write("#T::builder()", symbolProvider.toSymbol(shape)) - data.members.forEach { (key, value) -> - val memberShape = shape.expectMember(key.value) + fun renderMemberHelper(memberShape: MemberShape, value: Node) { writer.withBlock(".${memberShape.setterName()}(", ")") { renderMember(this, memberShape, value, ctx) } } + + writer.write("#T::builder()", symbolProvider.toSymbol(shape)) + if (ctx.defaultsForRequiredFields) { + shape.allMembers.entries + .filter { (name, memberShape) -> + memberShape.isRequired && !data.members.containsKey(Node.from(name)) + } + .forEach { (_, memberShape) -> + val shape = model.expectShape(memberShape.target) + renderMemberHelper(memberShape, fillDefaultValue(shape)) + } + } + + data.members.forEach { (key, value) -> + val memberShape = shape.expectMember(key.value) + renderMemberHelper(memberShape, value) + } writer.write(".build()") if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { writer.write(".unwrap()") } } + + /** + * Fill default values for missing required value of a shape. + */ + private fun fillDefaultValue(shape: Shape): Node = when (shape) { + // Compound Shapes + is StructureShape -> Node.objectNode() + is UnionShape -> Node.objectNode() + + // Collections + is ListShape -> Node.arrayNode() + is MapShape -> Node.objectNode() + is SetShape -> Node.arrayNode() + + is MemberShape -> throw CodegenException("Unable to handle member shape `$shape`. Please provide target shape instead") + + // Wrapped Shapes + is TimestampShape -> Node.from(0) // Number node for timestamp + + is BlobShape -> Node.from("") // String node for bytes + + // Simple Shapes + is StringShape -> Node.from("") + is NumberShape -> Node.from(0) + is BooleanShape -> Node.from(false) + // TODO(weihanglo): how to handle document shape properly? + is DocumentShape -> Node.objectNode() + else -> throw CodegenException("Unrecognized shape `$shape`") + } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt index 8d040f81c6..0e8e7dfcfe 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt @@ -66,6 +66,22 @@ class InstantiatorTest { member: WithBox, value: Integer } + + structure MyStructRequired { + @required + foo: String, + @required + bar: PrimitiveInteger, + @required + baz: Integer, + @required + ts: Timestamp, + @required + byteValue: Byte + @required + union: MyUnion, + } + """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } private val symbolProvider = testSymbolProvider(model) @@ -236,4 +252,27 @@ class InstantiatorTest { } writer.compileAndTest() } + + @Test + fun `generate struct with missing required members`() { + val structure = model.lookup("com.test#MyStructRequired") + val union = model.lookup("com.test#MyUnion") + val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) + val data = Node.parse("{}") + val writer = RustWriter.forModule("model") + structure.renderWithModelBuilder(model, symbolProvider, writer) + UnionGenerator(model, symbolProvider, writer, union).render() + writer.test { + writer.withBlock("let result = ", ";") { + sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } + writer.write("assert_eq!(result.foo.unwrap(), \"\");") + writer.write("assert_eq!(result.bar, 0);") + writer.write("assert_eq!(result.baz.unwrap(), 0);") + writer.write("assert_eq!(result.ts.unwrap(), aws_smithy_types::DateTime::from_secs(0));") + writer.write("assert_eq!(result.byte_value.unwrap(), 0);") + writer.write("assert_eq!(result.union.unwrap(), MyUnion::StringVariant(String::new()));") + } + writer.compileAndTest() + } } From 8182c2160220c4aa9a0f0e15a5406355f9fc1b2d Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Tue, 23 Aug 2022 14:42:02 +0100 Subject: [PATCH 02/15] Move looping over operations into ServerProtocolTestGenerator Signed-off-by: Weihang Lo --- .../generators/ServerServiceGenerator.kt | 12 +-- .../protocol/ServerProtocolTestGenerator.kt | 99 +++++++++++++------ 2 files changed, 72 insertions(+), 39 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 634be6c386..1d661f8780 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -37,15 +37,11 @@ open class ServerServiceGenerator( * which assigns a symbol location to each shape. */ fun render() { + rustCrate.withModule(RustModule.public("operation")) { writer -> + ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer) + } + for (operation in operations) { - rustCrate.useShapeWriter(operation) { operationWriter -> - protocolGenerator.serverRenderOperation( - operationWriter, - operation, - ) - ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, operation, operationWriter) - .render() - } if (operation.errors.isNotEmpty()) { rustCrate.withModule(RustModule.Error) { writer -> renderCombinedErrors(writer, operation) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 292165c4aa..b69da8d9e7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -5,7 +5,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape @@ -33,11 +35,13 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.testutil.TokioTest @@ -59,19 +63,16 @@ import kotlin.reflect.KFunction1 class ServerProtocolTestGenerator( private val coreCodegenContext: CoreCodegenContext, private val protocolSupport: ProtocolSupport, - private val operationShape: OperationShape, - private val writer: RustWriter, + private val protocolGenerator: ProtocolGenerator, ) { private val logger = Logger.getLogger(javaClass.name) + private val index = TopDownIndex.of(coreCodegenContext.model) + private val operations = index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } + private val model = coreCodegenContext.model - private val inputShape = operationShape.inputShape(coreCodegenContext.model) - private val outputShape = operationShape.outputShape(coreCodegenContext.model) private val symbolProvider = coreCodegenContext.symbolProvider - private val operationSymbol = symbolProvider.toSymbol(operationShape) private val operationIndex = OperationIndex.of(coreCodegenContext.model) - private val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" - private val operationErrorName = "crate::error::${operationSymbol.name}Error" private val instantiator = with(coreCodegenContext) { Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) @@ -82,8 +83,10 @@ class ServerProtocolTestGenerator( "SmithyHttp" to CargoDependency.SmithyHttp(coreCodegenContext.runtimeConfig).asType(), "Http" to CargoDependency.Http.asType(), "Hyper" to CargoDependency.Hyper.asType(), + "Tower" to CargoDependency.Tower.asType(), "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType(), "AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!"), + "Router" to ServerRuntimeType.Router(coreCodegenContext.runtimeConfig), ) sealed class TestCase { @@ -114,7 +117,29 @@ class ServerProtocolTestGenerator( } } - fun render() { + fun render(writer: RustWriter) { + renderTestHelper(writer) + + for (operation in operations) { + protocolGenerator.serverRenderOperation(writer, operation) + renderOperationTestCases(operation, writer) + } + } + + /** + * Render a test helper module that help + * + * - generate dynamic builder for each handler, and + * - construct a tower service to exercise each test case. + */ + private fun renderTestHelper(writer: RustWriter) { + // Create a tower service to perform protocol test + } + + private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { + val outputShape = operationShape.outputShape(coreCodegenContext.model) + val operationSymbol = symbolProvider.toSymbol(operationShape) + val requestTests = operationShape.getTrait() ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it) } val responseTests = operationShape.getTrait() @@ -141,18 +166,18 @@ class ServerProtocolTestGenerator( visibility = Visibility.PRIVATE, ) writer.withModule(testModuleName, moduleMeta) { - renderAllTestCases(allTests) + renderAllTestCases(operationShape, operationSymbol, allTests) } } } - private fun RustWriter.renderAllTestCases(allTests: List) { + private fun RustWriter.renderAllTestCases(operationShape: OperationShape, operationSymbol: Symbol, allTests: List) { allTests.forEach { renderTestCaseBlock(it, this) { when (it) { - is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) - is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) - is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase) + is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase, operationShape, operationSymbol) + is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape, operationShape, operationSymbol) + is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase, operationSymbol) } } } @@ -236,6 +261,8 @@ class ServerProtocolTestGenerator( */ private fun RustWriter.renderHttpRequestTestCase( httpRequestTestCase: HttpRequestTestCase, + operationShape: OperationShape, + operationSymbol: Symbol, ) { if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") @@ -245,7 +272,7 @@ class ServerProtocolTestGenerator( renderHttpRequest(uri, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { - checkParams(httpRequestTestCase, this) + checkParams(operationShape, operationSymbol, httpRequestTestCase, this) } // Explicitly warn if the test case defined parameters that we aren't doing anything with @@ -272,7 +299,12 @@ class ServerProtocolTestGenerator( private fun RustWriter.renderHttpResponseTestCase( testCase: HttpResponseTestCase, shape: StructureShape, + operationShape: OperationShape, + operationSymbol: Symbol, ) { + val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" + val operationErrorName = "crate::error::${operationSymbol.name}Error" + if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() ) @@ -308,7 +340,7 @@ class ServerProtocolTestGenerator( * We are given a request definition and a response definition, and we have to assert that the request is rejected * with the given response. */ - private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { + private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase, operationSymbol: Symbol) { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. renderHttpRequest(uri.get(), headers, body.orNull(), queryParams, host.orNull()) @@ -372,7 +404,9 @@ class ServerProtocolTestGenerator( } } - private fun checkParams(httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + private fun checkParams(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + val inputShape = operationShape.inputShape(coreCodegenContext.model) + rustWriter.writeInline("let expected = ") instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) rustWriter.write(";") @@ -522,21 +556,24 @@ class ServerProtocolTestGenerator( } } - private fun checkHttpOperationExtension(rustWriter: RustWriter) { - rustWriter.rustTemplate( - """ - let operation_extension = http_response.extensions() - .get::<#{SmithyHttpServer}::extension::OperationExtension>() - .expect("extension `OperationExtension` not found"); - """.trimIndent(), - *codegenScope, - ) - rustWriter.writeWithNoFormatting( - """ - assert_eq!(operation_extension.absolute(), format!("{}.{}", "${operationShape.id.namespace}", "${operationSymbol.name}")); - """.trimIndent(), - ) - } + // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation + // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to + // invoke it with (like in the case of an `httpResponseTest` test case). + // private fun checkHttpOperationExtension(rustWriter: RustWriter) { + // rustWriter.rustTemplate( + // """ + // let operation_extension = http_response.extensions() + // .get::<#{SmithyHttpServer}::extension::OperationExtension>() + // .expect("extension `OperationExtension` not found"); + // """.trimIndent(), + // *codegenScope, + // ) + // rustWriter.writeWithNoFormatting( + // """ + // assert_eq!(operation_extension.absolute(), format!("{}.{}", "${operationShape.id.namespace}", "${operationSymbol.name}")); + // """.trimIndent(), + // ) + // } private fun checkStatusCode(rustWriter: RustWriter, statusCode: Int) { rustWriter.rustTemplate( From 9474fa2e2d4f4b9f93ebb0b8b37c874772e180ab Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 24 Aug 2022 12:35:43 +0100 Subject: [PATCH 03/15] Add protocol test helper functions Signed-off-by: Weihang Lo --- .../generators/ServerServiceGenerator.kt | 3 + .../protocol/ServerProtocolTestGenerator.kt | 134 +++++++++--------- 2 files changed, 70 insertions(+), 67 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 1d661f8780..73185f2750 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate @@ -38,6 +39,8 @@ open class ServerServiceGenerator( */ fun render() { rustCrate.withModule(RustModule.public("operation")) { writer -> + // TODO(weihanglo): remove #![allow(dead_code)] + writer.rust("##![allow(dead_code)]") ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index b69da8d9e7..b6e0af75bc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -9,8 +9,6 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.node.Node -import software.amazon.smithy.model.shapes.DoubleShape -import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape @@ -26,6 +24,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Visibility import software.amazon.smithy.rust.codegen.rustlang.asType @@ -47,10 +46,8 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.testutil.TokioTest import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.getTrait -import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape -import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -134,6 +131,67 @@ class ServerProtocolTestGenerator( */ private fun renderTestHelper(writer: RustWriter) { // Create a tower service to perform protocol test + val crateName = coreCodegenContext.settings.moduleName + val operationNames = operations.map { RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(it).name.toSnakeCase()) } + val operationRegistryName = "OperationRegistry" + val operationRegistryBuilderName = "${operationRegistryName}Builder" + + writer.withModule("protocol_test_helper") { + val operationInputOutputTypes = operations.map { operationShape -> + val inputSymbol = symbolProvider.toSymbol(operationShape.inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) + val operationSymbol = symbolProvider.toSymbol(operationShape) + val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") + + val inputT = inputSymbol.fullName + val t = outputSymbol.fullName + val outputT = if (operationShape.errors.isEmpty()) { + t + } else { + val e = errorSymbol.fullyQualifiedName() + "Result<$t, $e>" + } + + Pair(inputT, outputT) + } + rustTemplate( + """ + use #{Tower}::Service as _; + + type F = fn(Input) -> std::pin::Pin + Send>>; + + type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, ${ + operationInputOutputTypes.map { (inputT, outputT) -> "F<$inputT, $outputT>, ()" }.joinToString(", ") + }>; + + fn create_operation_registry_builder() -> RegistryBuilder { + crate::operation_registry::$operationRegistryBuilderName::default() + ${operationNames.mapIndexed { index, operationName -> + val (inputT, outputT) = operationInputOutputTypes[index] + ".$operationName((|_| Box::pin(async { todo!() })) as F<$inputT, $outputT> )" + }.joinToString("\n")} + } + + pub(crate) async fn validate_request( + http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, + f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, + ) { + let router: #{Router} = f(create_operation_registry_builder()) + .build() + .expect("unable to build operation registry") + .into(); + _ = router.into_make_service() + .call(()) + .await + .expect("unable to get a router") + .call(http_request) + .await + .expect("unable to make an HTTP request"); + } + """, + *codegenScope, + ) + } } private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { @@ -414,72 +472,14 @@ class ServerProtocolTestGenerator( val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" rustWriter.rustTemplate( """ - let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request); - let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0; + crate::operation::protocol_test_helper::validate_request( + http_request, + // |builder| builder, + &std::convert::identity, + ).await """, *codegenScope, ) - - if (inputShape.hasStreamingMember(model)) { - // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members - // and handle the equality assertion separately. - for (member in inputShape.members()) { - val memberName = coreCodegenContext.symbolProvider.toMemberName(member) - if (member.isStreaming(coreCodegenContext.model)) { - rustWriter.rustTemplate( - """ - #{AssertEq}( - parsed.$memberName.collect().await.unwrap().into_bytes(), - expected.$memberName.collect().await.unwrap().into_bytes() - ); - """, - *codegenScope, - ) - } else { - rustWriter.rustTemplate( - """ - #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); - """, - *codegenScope, - ) - } - } - } else { - val hasFloatingPointMembers = inputShape.members().any { - val target = model.expectShape(it.target) - (target is DoubleShape) || (target is FloatShape) - } - - // TODO(https://github.com/awslabs/smithy-rs/issues/1147) Handle the case of nested floating point members. - if (hasFloatingPointMembers) { - for (member in inputShape.members()) { - val memberName = coreCodegenContext.symbolProvider.toMemberName(member) - when (coreCodegenContext.model.expectShape(member.target)) { - is DoubleShape, is FloatShape -> { - rustWriter.addUseImports( - RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals").toSymbol(), - ) - rustWriter.rust( - """ - assert!(parsed.$memberName.float_equals(&expected.$memberName), - "Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, parsed.$memberName); - """, - ) - } - else -> { - rustWriter.rustTemplate( - """ - #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); - """, - *codegenScope, - ) - } - } - } - } else { - rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope) - } - } } private fun checkResponse(rustWriter: RustWriter, testCase: HttpResponseTestCase) { From ca2ea3327bf539158c045fa01f09370baf155b95 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Thu, 1 Sep 2022 23:09:00 +0100 Subject: [PATCH 04/15] Add method param to construct http request --- .../protocol/ServerProtocolTestGenerator.kt | 58 ++++++++++--------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index b6e0af75bc..3b079f3df4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -64,13 +64,31 @@ class ServerProtocolTestGenerator( ) { private val logger = Logger.getLogger(javaClass.name) - private val index = TopDownIndex.of(coreCodegenContext.model) - private val operations = index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } - private val model = coreCodegenContext.model private val symbolProvider = coreCodegenContext.symbolProvider private val operationIndex = OperationIndex.of(coreCodegenContext.model) + private val index = TopDownIndex.of(coreCodegenContext.model) + private val operations = index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } + + private val operationInputOutputTypes = operations.associate { + val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) + val operationSymbol = symbolProvider.toSymbol(it) + val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") + + val inputT = inputSymbol.fullName + val t = outputSymbol.fullName + val outputT = if (it.errors.isEmpty()) { + t + } else { + val e = errorSymbol.fullyQualifiedName() + "Result<$t, $e>" + } + + Pair(it, Pair(inputT, outputT)) + } + private val instantiator = with(coreCodegenContext) { Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) } @@ -137,23 +155,6 @@ class ServerProtocolTestGenerator( val operationRegistryBuilderName = "${operationRegistryName}Builder" writer.withModule("protocol_test_helper") { - val operationInputOutputTypes = operations.map { operationShape -> - val inputSymbol = symbolProvider.toSymbol(operationShape.inputShape(model)) - val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) - val operationSymbol = symbolProvider.toSymbol(operationShape) - val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") - - val inputT = inputSymbol.fullName - val t = outputSymbol.fullName - val outputT = if (operationShape.errors.isEmpty()) { - t - } else { - val e = errorSymbol.fullyQualifiedName() - "Result<$t, $e>" - } - - Pair(inputT, outputT) - } rustTemplate( """ use #{Tower}::Service as _; @@ -161,14 +162,17 @@ class ServerProtocolTestGenerator( type F = fn(Input) -> std::pin::Pin + Send>>; type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, ${ - operationInputOutputTypes.map { (inputT, outputT) -> "F<$inputT, $outputT>, ()" }.joinToString(", ") + operations.map { + val (inputT, outputT) = operationInputOutputTypes[it]!! + "F<$inputT, $outputT>, ()" + }.joinToString(", ") }>; fn create_operation_registry_builder() -> RegistryBuilder { crate::operation_registry::$operationRegistryBuilderName::default() - ${operationNames.mapIndexed { index, operationName -> - val (inputT, outputT) = operationInputOutputTypes[index] - ".$operationName((|_| Box::pin(async { todo!() })) as F<$inputT, $outputT> )" + ${operations.mapIndexed { idx, operationShape -> + val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + ".${operationNames[idx]}((|_| Box::pin(async { todo!() })) as F<$inputT, $outputT> )" }.joinToString("\n")} } @@ -327,7 +331,7 @@ class ServerProtocolTestGenerator( return } with(httpRequestTestCase) { - renderHttpRequest(uri, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { checkParams(operationShape, operationSymbol, httpRequestTestCase, this) @@ -401,7 +405,7 @@ class ServerProtocolTestGenerator( private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase, operationSymbol: Symbol) { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) } val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" @@ -418,6 +422,7 @@ class ServerProtocolTestGenerator( private fun RustWriter.renderHttpRequest( uri: String, + method: String, headers: Map, body: String?, queryParams: List, @@ -428,6 +433,7 @@ class ServerProtocolTestGenerator( ##[allow(unused_mut)] let mut http_request = http::Request::builder() .uri("$uri") + .method("$method") """, *codegenScope, ) From 21c3a6dea84ed24f08986bebffe733773754a888 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Fri, 2 Sep 2022 10:14:22 +0100 Subject: [PATCH 05/15] Put request validation logic inside closure Signed-off-by: Weihang Lo --- .../protocol/ServerProtocolTestGenerator.kt | 109 +++++++++++++++--- 1 file changed, 94 insertions(+), 15 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 3b079f3df4..89bca7394a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -9,6 +9,8 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape @@ -46,8 +48,10 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.testutil.TokioTest import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.getTrait +import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -150,7 +154,7 @@ class ServerProtocolTestGenerator( private fun renderTestHelper(writer: RustWriter) { // Create a tower service to perform protocol test val crateName = coreCodegenContext.settings.moduleName - val operationNames = operations.map { RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(it).name.toSnakeCase()) } + val operationNames = operations.map { it.toName() } val operationRegistryName = "OperationRegistry" val operationRegistryBuilderName = "${operationRegistryName}Builder" @@ -159,7 +163,7 @@ class ServerProtocolTestGenerator( """ use #{Tower}::Service as _; - type F = fn(Input) -> std::pin::Pin + Send>>; + pub(crate) type F = fn(Input) -> std::pin::Pin + Send>>; type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, ${ operations.map { @@ -245,6 +249,8 @@ class ServerProtocolTestGenerator( } } + private fun OperationShape.toName(): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase()) + /** * Filter out test cases that are disabled or don't match the service protocol */ @@ -334,7 +340,7 @@ class ServerProtocolTestGenerator( renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { - checkParams(operationShape, operationSymbol, httpRequestTestCase, this) + checkRequest(operationShape, operationSymbol, httpRequestTestCase, this) } // Explicitly warn if the test case defined parameters that we aren't doing anything with @@ -468,24 +474,97 @@ class ServerProtocolTestGenerator( } } - private fun checkParams(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { val inputShape = operationShape.inputShape(coreCodegenContext.model) - rustWriter.writeInline("let expected = ") - instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) - rustWriter.write(";") - val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - rustWriter.rustTemplate( + val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + val helperNamespace = "crate::operation::protocol_test_helper" + rustWriter.withBlock( """ - crate::operation::protocol_test_helper::validate_request( + $helperNamespace::validate_request( http_request, - // |builder| builder, - &std::convert::identity, - ).await + &|builder| { + builder.${operationShape.toName()}((|input| Box::pin(async move { """, - *codegenScope, - ) + + "})) as $helperNamespace::F<$inputT, $outputT>)}).await", + + ) { + // Construct expected request + rustWriter.writeInline("let expected = ") + instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) + rustWriter.write(";") + + checkRequestParams(inputShape, rustWriter) + // TODO(weihanglo): provide actual response object + rustWriter.rust("todo!()") + } + } + + private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) { + if (inputShape.hasStreamingMember(model)) { + // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members + // and handle the equality assertion separately. + for (member in inputShape.members()) { + val memberName = coreCodegenContext.symbolProvider.toMemberName(member) + if (member.isStreaming(coreCodegenContext.model)) { + rustWriter.rustTemplate( + """ + #{AssertEq}( + input.$memberName.collect().await.unwrap().into_bytes(), + expected.$memberName.collect().await.unwrap().into_bytes() + ); + """, + *codegenScope, + ) + } else { + rustWriter.rustTemplate( + """ + #{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); + """, + *codegenScope, + ) + } + } + } else { + val hasFloatingPointMembers = inputShape.members().any { + val target = model.expectShape(it.target) + (target is DoubleShape) || (target is FloatShape) + } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1147) Handle the case of nested floating point members. + if (hasFloatingPointMembers) { + for (member in inputShape.members()) { + val memberName = coreCodegenContext.symbolProvider.toMemberName(member) + when (coreCodegenContext.model.expectShape(member.target)) { + is DoubleShape, is FloatShape -> { + rustWriter.addUseImports( + RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals") + .toSymbol(), + ) + rustWriter.rust( + """ + assert!(input.$memberName.float_equals(&expected.$memberName), + "Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, input.$memberName); + """, + ) + } + + else -> { + rustWriter.rustTemplate( + """ + #{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); + """, + *codegenScope, + ) + } + } + } + } else { + rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope) + } + } } private fun checkResponse(rustWriter: RustWriter, testCase: HttpResponseTestCase) { From ed129b34a38e53be527d645b05f1b983e10fca7e Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Mon, 5 Sep 2022 17:24:21 +0100 Subject: [PATCH 06/15] Make protocol test response instantiate with default values --- .../smithy/generators/ServerServiceGenerator.kt | 1 - .../protocol/ServerProtocolTestGenerator.kt | 16 +++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 73185f2750..be1edfbac6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -39,7 +39,6 @@ open class ServerServiceGenerator( */ fun render() { rustCrate.withModule(RustModule.public("operation")) { writer -> - // TODO(weihanglo): remove #![allow(dead_code)] writer.rust("##![allow(dead_code)]") ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 89bca7394a..1d5920c5be 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -79,13 +79,13 @@ class ServerProtocolTestGenerator( val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) val operationSymbol = symbolProvider.toSymbol(it) - val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") val inputT = inputSymbol.fullName val t = outputSymbol.fullName val outputT = if (it.errors.isEmpty()) { t } else { + val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") val e = errorSymbol.fullyQualifiedName() "Result<$t, $e>" } @@ -476,6 +476,7 @@ class ServerProtocolTestGenerator( private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { val inputShape = operationShape.inputShape(coreCodegenContext.model) + val outputShape = operationShape.outputShape(coreCodegenContext.model) val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" val (inputT, outputT) = operationInputOutputTypes[operationShape]!! @@ -497,8 +498,17 @@ class ServerProtocolTestGenerator( rustWriter.write(";") checkRequestParams(inputShape, rustWriter) - // TODO(weihanglo): provide actual response object - rustWriter.rust("todo!()") + + // Construct dummy output. + rustWriter.writeInline("let response = ") + instantiator.render(rustWriter, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + rustWriter.write(";") + + if (operationShape.errors.isEmpty()) { + rustWriter.rust("response") + } else { + rustWriter.rust("Ok(response)") + } } } From d003bf6b36f94310318319f47f45452eeb503821 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Mon, 5 Sep 2022 19:09:59 +0100 Subject: [PATCH 07/15] Add module meta for helper module Signed-off-by: Weihang Lo --- .../smithy/generators/ServerServiceGenerator.kt | 4 ++-- .../protocol/ServerProtocolTestGenerator.kt | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index be1edfbac6..d53d3477c8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.smithy.DefaultPublicModules import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport @@ -38,8 +39,7 @@ open class ServerServiceGenerator( * which assigns a symbol location to each shape. */ fun render() { - rustCrate.withModule(RustModule.public("operation")) { writer -> - writer.rust("##![allow(dead_code)]") + rustCrate.withModule(DefaultPublicModules["operation"]!!) { writer -> ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 1d5920c5be..f8193d56a7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -153,12 +153,19 @@ class ServerProtocolTestGenerator( */ private fun renderTestHelper(writer: RustWriter) { // Create a tower service to perform protocol test - val crateName = coreCodegenContext.settings.moduleName val operationNames = operations.map { it.toName() } val operationRegistryName = "OperationRegistry" val operationRegistryBuilderName = "${operationRegistryName}Builder" - writer.withModule("protocol_test_helper") { + val moduleMeta = RustMetadata( + additionalAttributes = listOf( + Attribute.Cfg("test"), + Attribute.Custom("allow(dead_code)"), + ), + + visibility = Visibility.PUBCRATE, + ) + writer.withModule("protocol_test_helper", moduleMeta) { rustTemplate( """ use #{Tower}::Service as _; From 727fc1b5b6e3e53a096050e4a368baf4eb3a0cc3 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 7 Sep 2022 10:57:34 +0100 Subject: [PATCH 08/15] Apply suggestions from code review Co-authored-by: david-perez --- .../rust/codegen/smithy/generators/Instantiator.kt | 6 ++++-- .../protocol/ServerProtocolTestGenerator.kt | 11 +++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index 2e068d506d..9f40fc2646 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -69,7 +69,7 @@ class Instantiator( val streaming: Boolean, // Whether we are instantiating with a Builder, in which case all setters take Option val builder: Boolean, - // If a given required field is missing, try to provide a default value. + // Fill out `required` fields with a default value. val defaultsForRequiredFields: Boolean, ) @@ -312,7 +312,9 @@ class Instantiator( } /** - * Fill default values for missing required value of a shape. + * Returns a default value for a shape. + * + * Warning: this method does not take into account any constraint traits attached to the shape. */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { // Compound Shapes diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index f8193d56a7..542c31f7b5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -146,13 +146,12 @@ class ServerProtocolTestGenerator( } /** - * Render a test helper module that help + * Render a test helper module to: * - * - generate dynamic builder for each handler, and - * - construct a tower service to exercise each test case. + * - generate a dynamic builder for each handler, and + * - construct a Tower service to exercise each test case. */ private fun renderTestHelper(writer: RustWriter) { - // Create a tower service to perform protocol test val operationNames = operations.map { it.toName() } val operationRegistryName = "OperationRegistry" val operationRegistryBuilderName = "${operationRegistryName}Builder" @@ -499,14 +498,14 @@ class ServerProtocolTestGenerator( "})) as $helperNamespace::F<$inputT, $outputT>)}).await", ) { - // Construct expected request + // Construct expected request. rustWriter.writeInline("let expected = ") instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) rustWriter.write(";") checkRequestParams(inputShape, rustWriter) - // Construct dummy output. + // Construct a dummy response. rustWriter.writeInline("let response = ") instantiator.render(rustWriter, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) rustWriter.write(";") From ad43acb909a43a7d39fac0e652c597237958aefe Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 7 Sep 2022 12:04:51 +0100 Subject: [PATCH 09/15] Address most style suggestions --- .../codegen/smithy/generators/Instantiator.kt | 20 +++---- .../protocol/ServerProtocolTestGenerator.kt | 57 +++++++++---------- 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index 9f40fc2646..03df2a720c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -232,8 +232,8 @@ class Instantiator( val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) { val (name, memberShape) = shape.allMembers.entries.first() - val shape = model.expectShape(memberShape.target) - Node.from(name) to fillDefaultValue(shape) + val targetShape = model.expectShape(memberShape.target) + Node.from(name) to fillDefaultValue(targetShape) } else { check(data.members.size == 1) val entry = data.members.iterator().next() @@ -296,8 +296,8 @@ class Instantiator( memberShape.isRequired && !data.members.containsKey(Node.from(name)) } .forEach { (_, memberShape) -> - val shape = model.expectShape(memberShape.target) - renderMemberHelper(memberShape, fillDefaultValue(shape)) + val targetShape = model.expectShape(memberShape.target) + renderMemberHelper(memberShape, fillDefaultValue(targetShape)) } } @@ -317,27 +317,21 @@ class Instantiator( * Warning: this method does not take into account any constraint traits attached to the shape. */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { - // Compound Shapes + // Aggregate shapes. is StructureShape -> Node.objectNode() is UnionShape -> Node.objectNode() - - // Collections - is ListShape -> Node.arrayNode() + is CollectionShape -> Node.arrayNode() is MapShape -> Node.objectNode() - is SetShape -> Node.arrayNode() is MemberShape -> throw CodegenException("Unable to handle member shape `$shape`. Please provide target shape instead") - // Wrapped Shapes + // Simple Shapes is TimestampShape -> Node.from(0) // Number node for timestamp is BlobShape -> Node.from("") // String node for bytes - - // Simple Shapes is StringShape -> Node.from("") is NumberShape -> Node.from(0) is BooleanShape -> Node.from(false) - // TODO(weihanglo): how to handle document shape properly? is DocumentShape -> Node.objectNode() else -> throw CodegenException("Unrecognized shape `$shape`") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 542c31f7b5..8bf43541a1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -58,6 +58,8 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase import java.util.logging.Logger import kotlin.reflect.KFunction1 +private const val PROTOCOL_TEST_HELPER_MODULE_NAME = "protocol_test_helper" + /** * Generate protocol tests for an operation */ @@ -72,10 +74,9 @@ class ServerProtocolTestGenerator( private val symbolProvider = coreCodegenContext.symbolProvider private val operationIndex = OperationIndex.of(coreCodegenContext.model) - private val index = TopDownIndex.of(coreCodegenContext.model) - private val operations = index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } + private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } - private val operationInputOutputTypes = operations.associate { + private val operationInputOutputTypes = operations.associateWith { val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) val operationSymbol = symbolProvider.toSymbol(it) @@ -85,12 +86,12 @@ class ServerProtocolTestGenerator( val outputT = if (it.errors.isEmpty()) { t } else { - val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error") - val e = errorSymbol.fullyQualifiedName() + val errorType = RuntimeType("${operationSymbol.name}Error", null, "crate::error") + val e = errorType.fullyQualifiedName() "Result<$t, $e>" } - Pair(it, Pair(inputT, outputT)) + inputT to outputT } private val instantiator = with(coreCodegenContext) { @@ -164,17 +165,17 @@ class ServerProtocolTestGenerator( visibility = Visibility.PUBCRATE, ) - writer.withModule("protocol_test_helper", moduleMeta) { + writer.withModule(PROTOCOL_TEST_HELPER_MODULE_NAME, moduleMeta) { rustTemplate( """ use #{Tower}::Service as _; - pub(crate) type F = fn(Input) -> std::pin::Pin + Send>>; + pub(crate) type Fun = fn(Input) -> std::pin::Pin + Send>>; type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, ${ operations.map { val (inputT, outputT) = operationInputOutputTypes[it]!! - "F<$inputT, $outputT>, ()" + "Fun<$inputT, $outputT>, ()" }.joinToString(", ") }>; @@ -182,22 +183,19 @@ class ServerProtocolTestGenerator( crate::operation_registry::$operationRegistryBuilderName::default() ${operations.mapIndexed { idx, operationShape -> val (inputT, outputT) = operationInputOutputTypes[operationShape]!! - ".${operationNames[idx]}((|_| Box::pin(async { todo!() })) as F<$inputT, $outputT> )" + ".${operationNames[idx]}((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )" }.joinToString("\n")} } - pub(crate) async fn validate_request( + pub(crate) async fn build_router_and_make_request( http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, ) { - let router: #{Router} = f(create_operation_registry_builder()) + let mut router: #{Router} = f(create_operation_registry_builder()) .build() .expect("unable to build operation registry") .into(); - _ = router.into_make_service() - .call(()) - .await - .expect("unable to get a router") + _ = router .call(http_request) .await .expect("unable to make an HTTP request"); @@ -238,13 +236,14 @@ class ServerProtocolTestGenerator( visibility = Visibility.PRIVATE, ) writer.withModule(testModuleName, moduleMeta) { - renderAllTestCases(operationShape, operationSymbol, allTests) + renderAllTestCases(operationShape, allTests) } } } - private fun RustWriter.renderAllTestCases(operationShape: OperationShape, operationSymbol: Symbol, allTests: List) { + private fun RustWriter.renderAllTestCases(operationShape: OperationShape, allTests: List) { allTests.forEach { + val operationSymbol = symbolProvider.toSymbol(operationShape) renderTestCaseBlock(it, this) { when (it) { is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase, operationShape, operationSymbol) @@ -484,36 +483,34 @@ class ServerProtocolTestGenerator( val inputShape = operationShape.inputShape(coreCodegenContext.model) val outputShape = operationShape.outputShape(coreCodegenContext.model) - val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" val (inputT, outputT) = operationInputOutputTypes[operationShape]!! - val helperNamespace = "crate::operation::protocol_test_helper" rustWriter.withBlock( """ - $helperNamespace::validate_request( + super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( http_request, &|builder| { builder.${operationShape.toName()}((|input| Box::pin(async move { """, - "})) as $helperNamespace::F<$inputT, $outputT>)}).await", + "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await", ) { // Construct expected request. - rustWriter.writeInline("let expected = ") - instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) - rustWriter.write(";") + rustWriter.withBlock("let expected = ", ";") { + instantiator.render(this, inputShape, httpRequestTestCase.params) + } checkRequestParams(inputShape, rustWriter) // Construct a dummy response. - rustWriter.writeInline("let response = ") - instantiator.render(rustWriter, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) - rustWriter.write(";") + rustWriter.withBlock("let response = ", ";") { + instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } if (operationShape.errors.isEmpty()) { - rustWriter.rust("response") + rustWriter.write("response") } else { - rustWriter.rust("Ok(response)") + rustWriter.write("Ok(response)") } } } From 7c2672ed7506eb54057174d6bcfa0ebe0fa827d1 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 7 Sep 2022 14:09:14 +0100 Subject: [PATCH 10/15] add companion object for attribute #[allow(dead_code)] Signed-off-by: Weihang Lo --- .../software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt | 1 + .../smithy/generators/protocol/ServerProtocolTestGenerator.kt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt index cdc6a9a123..685990a792 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt @@ -384,6 +384,7 @@ sealed class Attribute { */ val NonExhaustive = Custom("non_exhaustive") val AllowUnusedMut = Custom("allow(unused_mut)") + val AllowDeadCode = Custom("allow(dead_code)") val DocHidden = Custom("doc(hidden)") val DocInline = Custom("doc(inline)") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 8bf43541a1..f322c7ed20 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -160,7 +160,7 @@ class ServerProtocolTestGenerator( val moduleMeta = RustMetadata( additionalAttributes = listOf( Attribute.Cfg("test"), - Attribute.Custom("allow(dead_code)"), + Attribute.AllowDeadCode, ), visibility = Visibility.PUBCRATE, From 2b058d0b9b75d9e5a12e5073bb617f50863a0aef Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 7 Sep 2022 17:01:56 +0100 Subject: [PATCH 11/15] Use writable to make code readable --- .../protocol/ServerProtocolTestGenerator.kt | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index f322c7ed20..3f208bf5b7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -35,6 +35,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.rustlang.writable import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator @@ -157,6 +158,21 @@ class ServerProtocolTestGenerator( val operationRegistryName = "OperationRegistry" val operationRegistryBuilderName = "${operationRegistryName}Builder" + fun renderRegistryBuilderTypeParams() = writable { + operations.forEach { + val (inputT, outputT) = operationInputOutputTypes[it]!! + writeInline("Fun<$inputT, $outputT>, (), ") + } + } + + fun renderRegistryBuilderMethods() = writable { + operations.withIndex().forEach { + val (inputT, outputT) = operationInputOutputTypes[it.value]!! + val operationName = operationNames[it.index] + write(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") + } + } + val moduleMeta = RustMetadata( additionalAttributes = listOf( Attribute.Cfg("test"), @@ -172,19 +188,11 @@ class ServerProtocolTestGenerator( pub(crate) type Fun = fn(Input) -> std::pin::Pin + Send>>; - type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, ${ - operations.map { - val (inputT, outputT) = operationInputOutputTypes[it]!! - "Fun<$inputT, $outputT>, ()" - }.joinToString(", ") - }>; + type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, #{RegistryBuilderTypeParams:W}>; fn create_operation_registry_builder() -> RegistryBuilder { crate::operation_registry::$operationRegistryBuilderName::default() - ${operations.mapIndexed { idx, operationShape -> - val (inputT, outputT) = operationInputOutputTypes[operationShape]!! - ".${operationNames[idx]}((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )" - }.joinToString("\n")} + #{RegistryBuilderMethods:W} } pub(crate) async fn build_router_and_make_request( @@ -201,6 +209,8 @@ class ServerProtocolTestGenerator( .expect("unable to make an HTTP request"); } """, + "RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(), + "RegistryBuilderMethods" to renderRegistryBuilderMethods(), *codegenScope, ) } From 4212818ad49e69998530cdbd174e017fa0e8d327 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Wed, 7 Sep 2022 23:24:10 +0100 Subject: [PATCH 12/15] recursively call `filldefaultValue` Signed-off-by: Weihang Lo --- .../codegen/smithy/generators/Instantiator.kt | 8 +-- .../smithy/generators/InstantiatorTest.kt | 69 +++++++++++++++---- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index 03df2a720c..7c23ae07ee 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -296,8 +296,7 @@ class Instantiator( memberShape.isRequired && !data.members.containsKey(Node.from(name)) } .forEach { (_, memberShape) -> - val targetShape = model.expectShape(memberShape.target) - renderMemberHelper(memberShape, fillDefaultValue(targetShape)) + renderMemberHelper(memberShape, fillDefaultValue(memberShape)) } } @@ -317,17 +316,16 @@ class Instantiator( * Warning: this method does not take into account any constraint traits attached to the shape. */ private fun fillDefaultValue(shape: Shape): Node = when (shape) { + is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) + // Aggregate shapes. is StructureShape -> Node.objectNode() is UnionShape -> Node.objectNode() is CollectionShape -> Node.arrayNode() is MapShape -> Node.objectNode() - is MemberShape -> throw CodegenException("Unable to handle member shape `$shape`. Please provide target shape instead") - // Simple Shapes is TimestampShape -> Node.from(0) // Number node for timestamp - is BlobShape -> Node.from("") // String node for bytes is StringShape -> Node.from("") is NumberShape -> Node.from(0) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt index 0e8e7dfcfe..392975fe6a 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt @@ -69,19 +69,38 @@ class InstantiatorTest { structure MyStructRequired { @required - foo: String, + str: String, @required - bar: PrimitiveInteger, + primitiveInt: PrimitiveInteger, @required - baz: Integer, + int: Integer, @required ts: Timestamp, @required - byteValue: Byte + byte: Byte @required - union: MyUnion, + union: NestedUnion, + @required + structure: NestedStruct, + @required + list: MyList, + @required + map: NestedMap, + @required + doc: Document } + union NestedUnion { + struct: NestedStruct, + int: Integer + } + + structure NestedStruct { + @required + str: String, + @required + num: Integer + } """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } private val symbolProvider = testSymbolProvider(model) @@ -256,22 +275,46 @@ class InstantiatorTest { @Test fun `generate struct with missing required members`() { val structure = model.lookup("com.test#MyStructRequired") - val union = model.lookup("com.test#MyUnion") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) + val inner = model.lookup("com.test#Inner") + val nestedStruct = model.lookup("com.test#NestedStruct") + val union = model.lookup("com.test#NestedUnion") + val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) val data = Node.parse("{}") val writer = RustWriter.forModule("model") structure.renderWithModelBuilder(model, symbolProvider, writer) + inner.renderWithModelBuilder(model, symbolProvider, writer) + nestedStruct.renderWithModelBuilder(model, symbolProvider, writer) UnionGenerator(model, symbolProvider, writer, union).render() writer.test { writer.withBlock("let result = ", ";") { sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) } - writer.write("assert_eq!(result.foo.unwrap(), \"\");") - writer.write("assert_eq!(result.bar, 0);") - writer.write("assert_eq!(result.baz.unwrap(), 0);") - writer.write("assert_eq!(result.ts.unwrap(), aws_smithy_types::DateTime::from_secs(0));") - writer.write("assert_eq!(result.byte_value.unwrap(), 0);") - writer.write("assert_eq!(result.union.unwrap(), MyUnion::StringVariant(String::new()));") + writer.write( + """ + use std::collections::HashMap; + use aws_smithy_types::{DateTime, Document}; + + let expected = MyStructRequired { + str: Some("".into()), + primitive_int: 0, + int: Some(0), + ts: Some(DateTime::from_secs(0)), + byte: Some(0), + union: Some(NestedUnion::Struct(NestedStruct { + str: Some("".into()), + num: Some(0), + })), + structure: Some(NestedStruct { + str: Some("".into()), + num: Some(0), + }), + list: Some(vec![]), + map: Some(HashMap::new()), + doc: Some(Document::Object(HashMap::new())), + }; + assert_eq!(result, expected); + """, + ) } writer.compileAndTest() } From 440dbae47be6be45119bf028112aa6b42cecd076 Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Thu, 8 Sep 2022 12:34:37 +0100 Subject: [PATCH 13/15] Exercise with `OperationExtension` --- .../protocol/ServerProtocolTestGenerator.kt | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 3f208bf5b7..f5722ec64c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -195,18 +195,24 @@ class ServerProtocolTestGenerator( #{RegistryBuilderMethods:W} } + /// The operation full name is a concatenation of `.`. pub(crate) async fn build_router_and_make_request( http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, + operation_full_name: &str, f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, ) { let mut router: #{Router} = f(create_operation_registry_builder()) .build() .expect("unable to build operation registry") .into(); - _ = router + let http_response = router .call(http_request) .await .expect("unable to make an HTTP request"); + let operation_extension = http_response.extensions() + .get::<#{SmithyHttpServer}::extension::OperationExtension>() + .expect("extension `OperationExtension` not found"); + #{AssertEq}(operation_extension.absolute(), operation_full_name); } """, "RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(), @@ -498,6 +504,7 @@ class ServerProtocolTestGenerator( """ super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( http_request, + "${operationShape.id.namespace}.${operationSymbol.name}", &|builder| { builder.${operationShape.toName()}((|input| Box::pin(async move { """, @@ -599,7 +606,8 @@ class ServerProtocolTestGenerator( // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to // invoke it with (like in the case of an `httpResponseTest` test case). - // checkHttpOperationExtension(rustWriter) + // In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()` + // the operation handler trait implementation instead of directly calling `from_request()`. // If no request body is defined, then no assertions are made about the body of the message. if (testCase.body.isPresent) { @@ -612,11 +620,10 @@ class ServerProtocolTestGenerator( checkHeaders(rustWriter, "&http_response.headers()", testCase.headers) // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation - // of the operation `Handler` trait, a code path that does not get exercised by `httpRequestTest` test cases. - // TODO(https://github.com/awslabs/smithy-rs/issues/1212): We could change test case generation so as to `call()` - // the operation handler trait implementation instead of directly calling `from_request()`, or we could run an - // actual service. - // checkHttpOperationExtension(rustWriter) + // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to + // invoke it with (like in the case of an `httpResponseTest` test case). + // In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()` + // the operation handler trait implementation instead of directly calling `from_request()`. // If no request body is defined, then no assertions are made about the body of the message. if (testCase.body.isEmpty) return @@ -664,25 +671,6 @@ class ServerProtocolTestGenerator( } } - // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation - // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to - // invoke it with (like in the case of an `httpResponseTest` test case). - // private fun checkHttpOperationExtension(rustWriter: RustWriter) { - // rustWriter.rustTemplate( - // """ - // let operation_extension = http_response.extensions() - // .get::<#{SmithyHttpServer}::extension::OperationExtension>() - // .expect("extension `OperationExtension` not found"); - // """.trimIndent(), - // *codegenScope, - // ) - // rustWriter.writeWithNoFormatting( - // """ - // assert_eq!(operation_extension.absolute(), format!("{}.{}", "${operationShape.id.namespace}", "${operationSymbol.name}")); - // """.trimIndent(), - // ) - // } - private fun checkStatusCode(rustWriter: RustWriter, statusCode: Int) { rustWriter.rustTemplate( """ From 2a5768c1b9671c43b911127a1cd1eac4c67f31cd Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Thu, 8 Sep 2022 18:20:39 +0100 Subject: [PATCH 14/15] Temporary protocol tests fix for awslabs/smithy#1391 Missing `X-Amz-Target` in response header --- .../protocol/ServerProtocolTestGenerator.kt | 60 ++++++++++++++++--- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index f5722ec64c..26bbe7da80 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -116,7 +116,7 @@ class ServerProtocolTestGenerator( abstract val protocol: ShapeId abstract val testType: TestType - data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() { + data class RequestTest(val testCase: HttpRequestTestCase, val operationShape: OperationShape) : TestCase() { override val id: String = testCase.id override val documentation: String? = testCase.documentation.orNull() override val protocol: ShapeId = testCase.protocol @@ -227,7 +227,7 @@ class ServerProtocolTestGenerator( val operationSymbol = symbolProvider.toSymbol(operationShape) val requestTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it) } + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) } val responseTests = operationShape.getTrait() ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> @@ -296,8 +296,8 @@ class ServerProtocolTestGenerator( if (howToFixIt == null) { it } else { - val fixed = howToFixIt(it.testCase) - TestCase.RequestTest(fixed) + val fixed = howToFixIt(it.testCase, it.operationShape) + TestCase.RequestTest(fixed, it.operationShape) } } is TestCase.ResponseTest -> { @@ -915,7 +915,7 @@ class ServerProtocolTestGenerator( // or because they are flaky private val DisableTests = setOf() - private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase { + private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase { val params = Node.parse( """ { @@ -931,7 +931,7 @@ class ServerProtocolTestGenerator( return testCase.toBuilder().params(params).build() } - private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -946,7 +946,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -961,7 +961,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -1008,7 +1008,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -1022,6 +1022,9 @@ class ServerProtocolTestGenerator( ).asObjectNode().get(), ).build() + private fun fixAwsJson11MissingHeaderXAmzTarget(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = + testCase.toBuilder().putHeader("x-amz-target", "JsonProtocol.${operationShape.id.name}").build() + // These are tests whose definitions in the `awslabs/smithy` repository are wrong. // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete. // TODO(https://github.com/awslabs/smithy-rs/issues/1288): Contribute a PR to fix them upstream. @@ -1032,6 +1035,45 @@ class ServerProtocolTestGenerator( Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues, Pair(RestJson, "RestJsonAllQueryStringTypes") to ::fixRestJsonAllQueryStringTypes, Pair(RestJson, "RestJsonQueryStringEscaping") to ::fixRestJsonQueryStringEscaping, + + // https://github.com/awslabs/smithy/pull/1392 + // Missing `X-Amz-Target` in response header + Pair(AwsJson11, "AwsJson11Enums") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11ListsSerializeNull") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11MapsSerializeNullValues") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11ServersDontDeserializeNullStructureValues") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "PutAndGetInlineDocumentsInput") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "json_1_1_client_sends_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "json_1_1_service_supports_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "sends_requests_to_slash") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_blob_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_boolean_shapes_false") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_boolean_shapes_true") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_double_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_float_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_integer_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_long_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_string_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_string_shapes_with_jsonvalue_trait") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_members_with_locationname_traits") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_which_have_no_members") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_httpdate_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_iso8601_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_unixtimestamp_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, ) private val BrokenResponseTests: Map, KFunction1> = mapOf() From 334269383fcfdfc42f332f10ae866a20929ab32d Mon Sep 17 00:00:00 2001 From: Weihang Lo Date: Fri, 9 Sep 2022 01:31:20 +0100 Subject: [PATCH 15/15] Add `X-Amz-Target` for common models --- .../naming-obstacle-course-ops.smithy | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy index c1e84b4586..087d99b750 100644 --- a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy +++ b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy @@ -34,7 +34,10 @@ service Config { uri: "/", body: "{\"as\": 5, \"async\": true}", bodyMediaType: "application/json", - headers: {"Content-Type": "application/x-amz-json-1.1"} + headers: { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": "Config.ReservedWordsAsMembers", + }, } ]) operation ReservedWordsAsMembers { @@ -78,7 +81,10 @@ structure Type { uri: "/", body: "{\"regular_string\": \"hello!\"}", bodyMediaType: "application/json", - headers: {"Content-Type": "application/x-amz-json-1.1"} + headers: { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": "Config.StructureNamePunning", + }, } ]) operation StructureNamePunning {