Skip to content

Commit

Permalink
Integrate new service builder into protocol tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry Barber committed Sep 9, 2022
1 parent 32dc9b7 commit ba8d772
Showing 1 changed file with 58 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import software.amazon.smithy.rust.codegen.client.util.inputShape
import software.amazon.smithy.rust.codegen.client.util.isStreaming
import software.amazon.smithy.rust.codegen.client.util.orNull
import software.amazon.smithy.rust.codegen.client.util.outputShape
import software.amazon.smithy.rust.codegen.client.util.toPascalCase
import software.amazon.smithy.rust.codegen.client.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
Expand All @@ -75,6 +76,7 @@ class ServerProtocolTestGenerator(
private val symbolProvider = coreCodegenContext.symbolProvider
private val operationIndex = OperationIndex.of(coreCodegenContext.model)

private val serviceName = coreCodegenContext.serviceShape.id.name.toPascalCase()
private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id }

private val operationInputOutputTypes = operations.associateWith {
Expand Down Expand Up @@ -357,13 +359,23 @@ class ServerProtocolTestGenerator(
rust("/* test case disabled for this protocol (not yet supported) */")
return
}

// Test against original `OperationRegistryBuilder`.
with(httpRequestTestCase) {
renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
}
if (protocolSupport.requestBodyDeserialization) {
checkRequest(operationShape, operationSymbol, httpRequestTestCase, this)
}

// Test against new service builder.
with(httpRequestTestCase) {
renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
}
if (protocolSupport.requestBodyDeserialization) {
checkRequest2(operationShape, operationSymbol, httpRequestTestCase, this)
}

// Explicitly warn if the test case defined parameters that we aren't doing anything with
with(httpRequestTestCase) {
if (authScheme.isPresent) {
Expand Down Expand Up @@ -495,11 +507,34 @@ class ServerProtocolTestGenerator(
}
}

private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) {
/** Returns the body of the request test. */
private fun checkRequestHandler(operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase) = writable {
val inputShape = operationShape.inputShape(coreCodegenContext.model)
val outputShape = operationShape.outputShape(coreCodegenContext.model)

// Construct expected request.
withBlock("let expected = ", ";") {
instantiator.render(this, inputShape, httpRequestTestCase.params)
}

checkRequestParams(inputShape, this)

// Construct a dummy response.
withBlock("let response = ", ";") {
instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true))
}

if (operationShape.errors.isEmpty()) {
write("response")
} else {
write("Ok(response)")
}
}

/** Checks the request using the `OperationRegistryBuilder`. */
private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) {
val (inputT, outputT) = operationInputOutputTypes[operationShape]!!

rustWriter.withBlock(
"""
super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request(
Expand All @@ -509,29 +544,33 @@ class ServerProtocolTestGenerator(
builder.${operationShape.toName()}((|input| Box::pin(async move {
""",

"})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await",
"})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;",

) {
// Construct expected request.
rustWriter.withBlock("let expected = ", ";") {
instantiator.render(this, inputShape, httpRequestTestCase.params)
}

checkRequestParams(inputShape, rustWriter)

// Construct a dummy response.
rustWriter.withBlock("let response = ", ";") {
instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true))
}

if (operationShape.errors.isEmpty()) {
rustWriter.write("response")
} else {
rustWriter.write("Ok(response)")
}
checkRequestHandler(operationShape, httpRequestTestCase)()
}
}

/** Checks the request using the new service builder. */
private fun checkRequest2(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) {
val (inputT, outputT) = operationInputOutputTypes[operationShape]!!
val operationName = RustReservedWords.escapeIfNeeded(operationSymbol.name.toSnakeCase())
rustWriter.rustTemplate(
"""
let service = crate::service::$serviceName::unchecked_builder()
.$operationName(|input: $inputT| async move {
#{Body:W}
})
.build::<#{Hyper}::body::Body>();
let http_response = #{Tower}::ServiceExt::oneshot(service, http_request)
.await
.expect("unable to make an HTTP request");
""",
"Body" to checkRequestHandler(operationShape, httpRequestTestCase),
*codegenScope,
)
}

private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) {
if (inputShape.hasStreamingMember(model)) {
// A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
Expand Down

0 comments on commit ba8d772

Please sign in to comment.