diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index ebf52a91e0..221a4735c4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -52,6 +52,7 @@ import software.amazon.smithy.rust.codegen.client.util.inputShape import software.amazon.smithy.rust.codegen.client.util.isStreaming import software.amazon.smithy.rust.codegen.client.util.orNull import software.amazon.smithy.rust.codegen.client.util.outputShape +import software.amazon.smithy.rust.codegen.client.util.toPascalCase import software.amazon.smithy.rust.codegen.client.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType @@ -75,6 +76,7 @@ class ServerProtocolTestGenerator( private val symbolProvider = coreCodegenContext.symbolProvider private val operationIndex = OperationIndex.of(coreCodegenContext.model) + private val serviceName = coreCodegenContext.serviceShape.id.name.toPascalCase() private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } private val operationInputOutputTypes = operations.associateWith { @@ -357,6 +359,8 @@ class ServerProtocolTestGenerator( rust("/* test case disabled for this protocol (not yet supported) */") return } + + // Test against original `OperationRegistryBuilder`. with(httpRequestTestCase) { renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } @@ -364,6 +368,14 @@ class ServerProtocolTestGenerator( checkRequest(operationShape, operationSymbol, httpRequestTestCase, this) } + // Test against new service builder. + with(httpRequestTestCase) { + renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + } + if (protocolSupport.requestBodyDeserialization) { + checkRequest2(operationShape, operationSymbol, httpRequestTestCase, this) + } + // Explicitly warn if the test case defined parameters that we aren't doing anything with with(httpRequestTestCase) { if (authScheme.isPresent) { @@ -495,11 +507,34 @@ class ServerProtocolTestGenerator( } } - private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + /** Returns the body of the request test. */ + private fun checkRequestHandler(operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase) = writable { val inputShape = operationShape.inputShape(coreCodegenContext.model) val outputShape = operationShape.outputShape(coreCodegenContext.model) + // Construct expected request. + withBlock("let expected = ", ";") { + instantiator.render(this, inputShape, httpRequestTestCase.params) + } + + checkRequestParams(inputShape, this) + + // Construct a dummy response. + withBlock("let response = ", ";") { + instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } + + if (operationShape.errors.isEmpty()) { + write("response") + } else { + write("Ok(response)") + } + } + + /** Checks the request using the `OperationRegistryBuilder`. */ + private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + rustWriter.withBlock( """ super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( @@ -509,29 +544,33 @@ class ServerProtocolTestGenerator( builder.${operationShape.toName()}((|input| Box::pin(async move { """, - "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await", + "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;", ) { - // Construct expected request. - rustWriter.withBlock("let expected = ", ";") { - instantiator.render(this, inputShape, httpRequestTestCase.params) - } - - checkRequestParams(inputShape, rustWriter) - - // Construct a dummy response. - rustWriter.withBlock("let response = ", ";") { - instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) - } - - if (operationShape.errors.isEmpty()) { - rustWriter.write("response") - } else { - rustWriter.write("Ok(response)") - } + checkRequestHandler(operationShape, httpRequestTestCase)() } } + /** Checks the request using the new service builder. */ + private fun checkRequest2(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + val operationName = RustReservedWords.escapeIfNeeded(operationSymbol.name.toSnakeCase()) + rustWriter.rustTemplate( + """ + let service = crate::service::$serviceName::unchecked_builder() + .$operationName(|input: $inputT| async move { + #{Body:W} + }) + .build::<#{Hyper}::body::Body>(); + let http_response = #{Tower}::ServiceExt::oneshot(service, http_request) + .await + .expect("unable to make an HTTP request"); + """, + "Body" to checkRequestHandler(operationShape, httpRequestTestCase), + *codegenScope, + ) + } + private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) { if (inputShape.hasStreamingMember(model)) { // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members