diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/RustTypes.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/RustTypes.kt index c63eb2805e..6d20af877f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/RustTypes.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/RustTypes.kt @@ -384,6 +384,7 @@ sealed class Attribute { */ val NonExhaustive = Custom("non_exhaustive") val AllowUnusedMut = Custom("allow(unused_mut)") + val AllowDeadCode = Custom("allow(dead_code)") val DocHidden = Custom("doc(hidden)") val DocInline = Custom("doc(inline)") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/Instantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/Instantiator.kt index a9c57aad2c..a54ab771d8 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/Instantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/Instantiator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node @@ -68,13 +69,19 @@ class Instantiator( val streaming: Boolean, // Whether we are instantiating with a Builder, in which case all setters take Option val builder: Boolean, + // Fill out `required` fields with a default value. + val defaultsForRequiredFields: Boolean, ) + companion object { + fun defaultContext() = Ctx(lowercaseMapKeys = false, streaming = false, builder = false, defaultsForRequiredFields = false) + } + fun render( writer: RustWriter, shape: Shape, arg: Node, - ctx: Ctx = Ctx(lowercaseMapKeys = false, streaming = false, builder = false), + ctx: Ctx = defaultContext(), ) { when (shape) { // Compound Shapes @@ -222,14 +229,23 @@ class Instantiator( */ private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) { val unionSymbol = symbolProvider.toSymbol(shape) - check(data.members.size == 1) - val variant = data.members.iterator().next() - val memberName = variant.key.value + + val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) { + val (name, memberShape) = shape.allMembers.entries.first() + val targetShape = model.expectShape(memberShape.target) + Node.from(name) to fillDefaultValue(targetShape) + } else { + check(data.members.size == 1) + val entry = data.members.iterator().next() + entry.key to entry.value + } + + val memberName = variant.first.value val member = shape.expectMember(memberName) writer.write("#T::${symbolProvider.toMemberName(member)}", unionSymbol) // unions should specify exactly one member writer.withBlock("(", ")") { - renderMember(this, member, variant.value, ctx) + renderMember(this, member, variant.second, ctx) } } @@ -267,16 +283,54 @@ class Instantiator( * ``` */ private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) { - writer.write("#T::builder()", symbolProvider.toSymbol(shape)) - data.members.forEach { (key, value) -> - val memberShape = shape.expectMember(key.value) + fun renderMemberHelper(memberShape: MemberShape, value: Node) { writer.withBlock(".${memberShape.setterName()}(", ")") { renderMember(this, memberShape, value, ctx) } } + + writer.write("#T::builder()", symbolProvider.toSymbol(shape)) + if (ctx.defaultsForRequiredFields) { + shape.allMembers.entries + .filter { (name, memberShape) -> + memberShape.isRequired && !data.members.containsKey(Node.from(name)) + } + .forEach { (_, memberShape) -> + renderMemberHelper(memberShape, fillDefaultValue(memberShape)) + } + } + + data.members.forEach { (key, value) -> + val memberShape = shape.expectMember(key.value) + renderMemberHelper(memberShape, value) + } writer.write(".build()") if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { writer.write(".unwrap()") } } + + /** + * Returns a default value for a shape. + * + * Warning: this method does not take into account any constraint traits attached to the shape. + */ + private fun fillDefaultValue(shape: Shape): Node = when (shape) { + is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) + + // Aggregate shapes. + is StructureShape -> Node.objectNode() + is UnionShape -> Node.objectNode() + is CollectionShape -> Node.arrayNode() + is MapShape -> Node.objectNode() + + // Simple Shapes + is TimestampShape -> Node.from(0) // Number node for timestamp + is BlobShape -> Node.from("") // String node for bytes + is StringShape -> Node.from("") + is NumberShape -> Node.from(0) + is BooleanShape -> Node.from(false) + is DocumentShape -> Node.objectNode() + else -> throw CodegenException("Unrecognized shape `$shape`") + } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/InstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/InstantiatorTest.kt index 1550f99b1e..5c1c009c4d 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/InstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/InstantiatorTest.kt @@ -66,6 +66,41 @@ class InstantiatorTest { member: WithBox, value: Integer } + + structure MyStructRequired { + @required + str: String, + @required + primitiveInt: PrimitiveInteger, + @required + int: Integer, + @required + ts: Timestamp, + @required + byte: Byte + @required + union: NestedUnion, + @required + structure: NestedStruct, + @required + list: MyList, + @required + map: NestedMap, + @required + doc: Document + } + + union NestedUnion { + struct: NestedStruct, + int: Integer + } + + structure NestedStruct { + @required + str: String, + @required + num: Integer + } """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } private val symbolProvider = testSymbolProvider(model) @@ -236,4 +271,51 @@ class InstantiatorTest { } writer.compileAndTest() } + + @Test + fun `generate struct with missing required members`() { + val structure = model.lookup("com.test#MyStructRequired") + val inner = model.lookup("com.test#Inner") + val nestedStruct = model.lookup("com.test#NestedStruct") + val union = model.lookup("com.test#NestedUnion") + val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) + val data = Node.parse("{}") + val writer = RustWriter.forModule("model") + structure.renderWithModelBuilder(model, symbolProvider, writer) + inner.renderWithModelBuilder(model, symbolProvider, writer) + nestedStruct.renderWithModelBuilder(model, symbolProvider, writer) + UnionGenerator(model, symbolProvider, writer, union).render() + writer.test { + writer.withBlock("let result = ", ";") { + sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } + writer.write( + """ + use std::collections::HashMap; + use aws_smithy_types::{DateTime, Document}; + + let expected = MyStructRequired { + str: Some("".into()), + primitive_int: 0, + int: Some(0), + ts: Some(DateTime::from_secs(0)), + byte: Some(0), + union: Some(NestedUnion::Struct(NestedStruct { + str: Some("".into()), + num: Some(0), + })), + structure: Some(NestedStruct { + str: Some("".into()), + num: Some(0), + }), + list: Some(vec![]), + map: Some(HashMap::new()), + doc: Some(Document::Object(HashMap::new())), + }; + assert_eq!(result, expected); + """, + ) + } + writer.compileAndTest() + } } diff --git a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy index c1e84b4586..087d99b750 100644 --- a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy +++ b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy @@ -34,7 +34,10 @@ service Config { uri: "/", body: "{\"as\": 5, \"async\": true}", bodyMediaType: "application/json", - headers: {"Content-Type": "application/x-amz-json-1.1"} + headers: { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": "Config.ReservedWordsAsMembers", + }, } ]) operation ReservedWordsAsMembers { @@ -78,7 +81,10 @@ structure Type { uri: "/", body: "{\"regular_string\": \"hello!\"}", bodyMediaType: "application/json", - headers: {"Content-Type": "application/x-amz-json-1.1"} + headers: { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": "Config.StructureNamePunning", + }, } ]) operation StructureNamePunning { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 516e9940cb..f8eb568ecf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.rustlang.RustModule import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.DefaultPublicModules import software.amazon.smithy.rust.codegen.client.smithy.RustCrate import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport @@ -37,15 +38,11 @@ open class ServerServiceGenerator( * which assigns a symbol location to each shape. */ fun render() { + rustCrate.withModule(DefaultPublicModules["operation"]!!) { writer -> + ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer) + } + for (operation in operations) { - rustCrate.useShapeWriter(operation) { operationWriter -> - protocolGenerator.serverRenderOperation( - operationWriter, - operation, - ) - ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, operation, operationWriter) - .render() - } if (operation.errors.isNotEmpty()) { rustCrate.withModule(RustModule.Error) { writer -> renderCombinedErrors(writer, operation) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index c76fd5a317..ebf52a91e0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -5,7 +5,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.FloatShape @@ -24,6 +26,7 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.client.rustlang.Attribute import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.client.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.rustlang.Visibility import software.amazon.smithy.rust.codegen.client.rustlang.asType @@ -32,10 +35,12 @@ import software.amazon.smithy.rust.codegen.client.rustlang.rust import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.client.rustlang.withBlock +import software.amazon.smithy.rust.codegen.client.rustlang.writable import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.RuntimeType import software.amazon.smithy.rust.codegen.client.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.client.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.client.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.client.testutil.TokioTest @@ -49,29 +54,46 @@ import software.amazon.smithy.rust.codegen.client.util.orNull import software.amazon.smithy.rust.codegen.client.util.outputShape import software.amazon.smithy.rust.codegen.client.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator import java.util.logging.Logger import kotlin.reflect.KFunction1 +private const val PROTOCOL_TEST_HELPER_MODULE_NAME = "protocol_test_helper" + /** * Generate protocol tests for an operation */ class ServerProtocolTestGenerator( private val coreCodegenContext: CoreCodegenContext, private val protocolSupport: ProtocolSupport, - private val operationShape: OperationShape, - private val writer: RustWriter, + private val protocolGenerator: ProtocolGenerator, ) { private val logger = Logger.getLogger(javaClass.name) private val model = coreCodegenContext.model - private val inputShape = operationShape.inputShape(coreCodegenContext.model) - private val outputShape = operationShape.outputShape(coreCodegenContext.model) private val symbolProvider = coreCodegenContext.symbolProvider - private val operationSymbol = symbolProvider.toSymbol(operationShape) private val operationIndex = OperationIndex.of(coreCodegenContext.model) - private val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" - private val operationErrorName = "crate::error::${operationSymbol.name}Error" + + private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } + + private val operationInputOutputTypes = operations.associateWith { + val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) + val operationSymbol = symbolProvider.toSymbol(it) + + val inputT = inputSymbol.fullName + val t = outputSymbol.fullName + val outputT = if (it.errors.isEmpty()) { + t + } else { + val errorType = RuntimeType("${operationSymbol.name}Error", null, "crate::error") + val e = errorType.fullyQualifiedName() + "Result<$t, $e>" + } + + inputT to outputT + } private val instantiator = with(coreCodegenContext) { Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) @@ -82,8 +104,10 @@ class ServerProtocolTestGenerator( "SmithyHttp" to CargoDependency.SmithyHttp(coreCodegenContext.runtimeConfig).asType(), "Http" to CargoDependency.Http.asType(), "Hyper" to CargoDependency.Hyper.asType(), + "Tower" to CargoDependency.Tower.asType(), "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType(), "AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!"), + "Router" to ServerRuntimeType.Router(coreCodegenContext.runtimeConfig), ) sealed class TestCase { @@ -92,7 +116,7 @@ class ServerProtocolTestGenerator( abstract val protocol: ShapeId abstract val testType: TestType - data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() { + data class RequestTest(val testCase: HttpRequestTestCase, val operationShape: OperationShape) : TestCase() { override val id: String = testCase.id override val documentation: String? = testCase.documentation.orNull() override val protocol: ShapeId = testCase.protocol @@ -114,9 +138,96 @@ class ServerProtocolTestGenerator( } } - fun render() { + fun render(writer: RustWriter) { + renderTestHelper(writer) + + for (operation in operations) { + protocolGenerator.serverRenderOperation(writer, operation) + renderOperationTestCases(operation, writer) + } + } + + /** + * Render a test helper module to: + * + * - generate a dynamic builder for each handler, and + * - construct a Tower service to exercise each test case. + */ + private fun renderTestHelper(writer: RustWriter) { + val operationNames = operations.map { it.toName() } + val operationRegistryName = "OperationRegistry" + val operationRegistryBuilderName = "${operationRegistryName}Builder" + + fun renderRegistryBuilderTypeParams() = writable { + operations.forEach { + val (inputT, outputT) = operationInputOutputTypes[it]!! + writeInline("Fun<$inputT, $outputT>, (), ") + } + } + + fun renderRegistryBuilderMethods() = writable { + operations.withIndex().forEach { + val (inputT, outputT) = operationInputOutputTypes[it.value]!! + val operationName = operationNames[it.index] + write(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") + } + } + + val moduleMeta = RustMetadata( + additionalAttributes = listOf( + Attribute.Cfg("test"), + Attribute.AllowDeadCode, + ), + + visibility = Visibility.PUBCRATE, + ) + writer.withModule(PROTOCOL_TEST_HELPER_MODULE_NAME, moduleMeta) { + rustTemplate( + """ + use #{Tower}::Service as _; + + pub(crate) type Fun = fn(Input) -> std::pin::Pin + Send>>; + + type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, #{RegistryBuilderTypeParams:W}>; + + fn create_operation_registry_builder() -> RegistryBuilder { + crate::operation_registry::$operationRegistryBuilderName::default() + #{RegistryBuilderMethods:W} + } + + /// The operation full name is a concatenation of `.`. + pub(crate) async fn build_router_and_make_request( + http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, + operation_full_name: &str, + f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, + ) { + let mut router: #{Router} = f(create_operation_registry_builder()) + .build() + .expect("unable to build operation registry") + .into(); + let http_response = router + .call(http_request) + .await + .expect("unable to make an HTTP request"); + let operation_extension = http_response.extensions() + .get::<#{SmithyHttpServer}::extension::OperationExtension>() + .expect("extension `OperationExtension` not found"); + #{AssertEq}(operation_extension.absolute(), operation_full_name); + } + """, + "RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(), + "RegistryBuilderMethods" to renderRegistryBuilderMethods(), + *codegenScope, + ) + } + } + + private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { + val outputShape = operationShape.outputShape(coreCodegenContext.model) + val operationSymbol = symbolProvider.toSymbol(operationShape) + val requestTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it) } + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) } val responseTests = operationShape.getTrait() ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> @@ -141,23 +252,26 @@ class ServerProtocolTestGenerator( visibility = Visibility.PRIVATE, ) writer.withModule(testModuleName, moduleMeta) { - renderAllTestCases(allTests) + renderAllTestCases(operationShape, allTests) } } } - private fun RustWriter.renderAllTestCases(allTests: List) { + private fun RustWriter.renderAllTestCases(operationShape: OperationShape, allTests: List) { allTests.forEach { + val operationSymbol = symbolProvider.toSymbol(operationShape) renderTestCaseBlock(it, this) { when (it) { - is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) - is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) - is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase) + is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase, operationShape, operationSymbol) + is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape, operationShape, operationSymbol) + is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(it.testCase, operationSymbol) } } } } + private fun OperationShape.toName(): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase()) + /** * Filter out test cases that are disabled or don't match the service protocol */ @@ -182,8 +296,8 @@ class ServerProtocolTestGenerator( if (howToFixIt == null) { it } else { - val fixed = howToFixIt(it.testCase) - TestCase.RequestTest(fixed) + val fixed = howToFixIt(it.testCase, it.operationShape) + TestCase.RequestTest(fixed, it.operationShape) } } is TestCase.ResponseTest -> { @@ -236,16 +350,18 @@ class ServerProtocolTestGenerator( */ private fun RustWriter.renderHttpRequestTestCase( httpRequestTestCase: HttpRequestTestCase, + operationShape: OperationShape, + operationSymbol: Symbol, ) { if (!protocolSupport.requestDeserialization) { rust("/* test case disabled for this protocol (not yet supported) */") return } with(httpRequestTestCase) { - renderHttpRequest(uri, headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { - checkParams(httpRequestTestCase, this) + checkRequest(operationShape, operationSymbol, httpRequestTestCase, this) } // Explicitly warn if the test case defined parameters that we aren't doing anything with @@ -272,7 +388,12 @@ class ServerProtocolTestGenerator( private fun RustWriter.renderHttpResponseTestCase( testCase: HttpResponseTestCase, shape: StructureShape, + operationShape: OperationShape, + operationSymbol: Symbol, ) { + val operationImplementationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" + val operationErrorName = "crate::error::${operationSymbol.name}Error" + if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() ) @@ -308,10 +429,10 @@ class ServerProtocolTestGenerator( * We are given a request definition and a response definition, and we have to assert that the request is rejected * with the given response. */ - private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase) { + private fun RustWriter.renderHttpMalformedRequestTestCase(testCase: HttpMalformedRequestTestCase, operationSymbol: Symbol) { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), headers, body.orNull(), queryParams, host.orNull()) + renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) } val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" @@ -328,6 +449,7 @@ class ServerProtocolTestGenerator( private fun RustWriter.renderHttpRequest( uri: String, + method: String, headers: Map, body: String?, queryParams: List, @@ -338,6 +460,7 @@ class ServerProtocolTestGenerator( ##[allow(unused_mut)] let mut http_request = http::Request::builder() .uri("$uri") + .method("$method") """, *codegenScope, ) @@ -372,20 +495,44 @@ class ServerProtocolTestGenerator( } } - private fun checkParams(httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { - rustWriter.writeInline("let expected = ") - instantiator.render(rustWriter, inputShape, httpRequestTestCase.params) - rustWriter.write(";") + private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + val inputShape = operationShape.inputShape(coreCodegenContext.model) + val outputShape = operationShape.outputShape(coreCodegenContext.model) - val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - rustWriter.rustTemplate( + val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + rustWriter.withBlock( """ - let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request); - let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0; + super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( + http_request, + "${operationShape.id.namespace}.${operationSymbol.name}", + &|builder| { + builder.${operationShape.toName()}((|input| Box::pin(async move { """, - *codegenScope, - ) + "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await", + + ) { + // Construct expected request. + rustWriter.withBlock("let expected = ", ";") { + instantiator.render(this, inputShape, httpRequestTestCase.params) + } + + checkRequestParams(inputShape, rustWriter) + + // Construct a dummy response. + rustWriter.withBlock("let response = ", ";") { + instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } + + if (operationShape.errors.isEmpty()) { + rustWriter.write("response") + } else { + rustWriter.write("Ok(response)") + } + } + } + + private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) { if (inputShape.hasStreamingMember(model)) { // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members // and handle the equality assertion separately. @@ -395,7 +542,7 @@ class ServerProtocolTestGenerator( rustWriter.rustTemplate( """ #{AssertEq}( - parsed.$memberName.collect().await.unwrap().into_bytes(), + input.$memberName.collect().await.unwrap().into_bytes(), expected.$memberName.collect().await.unwrap().into_bytes() ); """, @@ -404,7 +551,7 @@ class ServerProtocolTestGenerator( } else { rustWriter.rustTemplate( """ - #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); + #{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); """, *codegenScope, ) @@ -423,19 +570,21 @@ class ServerProtocolTestGenerator( when (coreCodegenContext.model.expectShape(member.target)) { is DoubleShape, is FloatShape -> { rustWriter.addUseImports( - RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals").toSymbol(), + RuntimeType.ProtocolTestHelper(coreCodegenContext.runtimeConfig, "FloatEquals") + .toSymbol(), ) rustWriter.rust( """ - assert!(parsed.$memberName.float_equals(&expected.$memberName), - "Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, parsed.$memberName); + assert!(input.$memberName.float_equals(&expected.$memberName), + "Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, input.$memberName); """, ) } + else -> { rustWriter.rustTemplate( """ - #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); + #{AssertEq}(input.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); """, *codegenScope, ) @@ -443,7 +592,7 @@ class ServerProtocolTestGenerator( } } } else { - rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope) + rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope) } } } @@ -457,7 +606,8 @@ class ServerProtocolTestGenerator( // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to // invoke it with (like in the case of an `httpResponseTest` test case). - // checkHttpOperationExtension(rustWriter) + // In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()` + // the operation handler trait implementation instead of directly calling `from_request()`. // If no request body is defined, then no assertions are made about the body of the message. if (testCase.body.isPresent) { @@ -470,11 +620,10 @@ class ServerProtocolTestGenerator( checkHeaders(rustWriter, "&http_response.headers()", testCase.headers) // We can't check that the `OperationExtension` is set in the response, because it is set in the implementation - // of the operation `Handler` trait, a code path that does not get exercised by `httpRequestTest` test cases. - // TODO(https://github.com/awslabs/smithy-rs/issues/1212): We could change test case generation so as to `call()` - // the operation handler trait implementation instead of directly calling `from_request()`, or we could run an - // actual service. - // checkHttpOperationExtension(rustWriter) + // of the operation `Handler` trait, a code path that does not get exercised when we don't have a request to + // invoke it with (like in the case of an `httpResponseTest` test case). + // In https://github.com/awslabs/smithy-rs/pull/1708: We did change `httpResponseTest`s generation to `call()` + // the operation handler trait implementation instead of directly calling `from_request()`. // If no request body is defined, then no assertions are made about the body of the message. if (testCase.body.isEmpty) return @@ -522,22 +671,6 @@ class ServerProtocolTestGenerator( } } - private fun checkHttpOperationExtension(rustWriter: RustWriter) { - rustWriter.rustTemplate( - """ - let operation_extension = http_response.extensions() - .get::<#{SmithyHttpServer}::extension::OperationExtension>() - .expect("extension `OperationExtension` not found"); - """.trimIndent(), - *codegenScope, - ) - rustWriter.writeWithNoFormatting( - """ - assert_eq!(operation_extension.absolute(), format!("{}.{}", "${operationShape.id.namespace}", "${operationSymbol.name}")); - """.trimIndent(), - ) - } - private fun checkStatusCode(rustWriter: RustWriter, statusCode: Int) { rustWriter.rustTemplate( """ @@ -782,7 +915,7 @@ class ServerProtocolTestGenerator( // or because they are flaky private val DisableTests = setOf() - private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase { + private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase { val params = Node.parse( """ { @@ -798,7 +931,7 @@ class ServerProtocolTestGenerator( return testCase.toBuilder().params(params).build() } - private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -813,7 +946,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -828,7 +961,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -875,7 +1008,7 @@ class ServerProtocolTestGenerator( """.trimMargin(), ).asObjectNode().get(), ).build() - private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase = + private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = testCase.toBuilder().params( Node.parse( """ @@ -889,6 +1022,9 @@ class ServerProtocolTestGenerator( ).asObjectNode().get(), ).build() + private fun fixAwsJson11MissingHeaderXAmzTarget(testCase: HttpRequestTestCase, operationShape: OperationShape): HttpRequestTestCase = + testCase.toBuilder().putHeader("x-amz-target", "JsonProtocol.${operationShape.id.name}").build() + // These are tests whose definitions in the `awslabs/smithy` repository are wrong. // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete. // TODO(https://github.com/awslabs/smithy-rs/issues/1288): Contribute a PR to fix them upstream. @@ -899,6 +1035,45 @@ class ServerProtocolTestGenerator( Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues, Pair(RestJson, "RestJsonAllQueryStringTypes") to ::fixRestJsonAllQueryStringTypes, Pair(RestJson, "RestJsonQueryStringEscaping") to ::fixRestJsonQueryStringEscaping, + + // https://github.com/awslabs/smithy/pull/1392 + // Missing `X-Amz-Target` in response header + Pair(AwsJson11, "AwsJson11Enums") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11ListsSerializeNull") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11MapsSerializeNullValues") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "AwsJson11ServersDontDeserializeNullStructureValues") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "PutAndGetInlineDocumentsInput") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "json_1_1_client_sends_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "json_1_1_service_supports_empty_payload_for_no_input_shape") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "sends_requests_to_slash") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_blob_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_boolean_shapes_false") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_boolean_shapes_true") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_double_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_empty_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_float_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_integer_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_long_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_list_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_of_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_map_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_recursive_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_string_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_string_shapes_with_jsonvalue_trait") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_members_with_locationname_traits") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_structure_which_have_no_members") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_httpdate_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_iso8601_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, + Pair(AwsJson11, "serializes_timestamp_shapes_with_unixtimestamp_timestampformat") to ::fixAwsJson11MissingHeaderXAmzTarget, ) private val BrokenResponseTests: Map, KFunction1> = mapOf()