Skip to content

Commit

Permalink
Split RuntimeError and RequestRejection by protocol
Browse files Browse the repository at this point in the history
As outlined in the [Protocol Specific Errors] of the [Service Builder
Improvements RFC], `RuntimeError` should be split up into smaller,
protocol specific, errors which accurately model the failure cases of
each protocol.

The same goes for `RequestRejection`.

Closes #1703.

[Protocol Specific Errors]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md#protocol-specific-errors
[Service Builder Improvements RFC]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md
  • Loading branch information
david-perez committed Mar 30, 2023
1 parent 92316f7 commit c47f972
Show file tree
Hide file tree
Showing 24 changed files with 808 additions and 462 deletions.
1 change: 1 addition & 0 deletions codegen-server-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels ->
CodegenTest(
"aws.protocoltests.json#JsonProtocol",
"json_rpc11",
// TODO We probably can remove these now.
extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
),
CodegenTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ open class ServerCodegenVisitor(

val baseModel = baselineTransform(context.model)
val service = settings.getService(baseModel)
val (protocol, generator) =
val (protocolShape, protocolGeneratorFactory) =
ServerProtocolLoader(
codegenDecorator.protocols(
service.id,
ServerProtocolLoader.DefaultProtocols,
),
)
.protocolFor(context.model, service)
protocolGeneratorFactory = generator
this.protocolGeneratorFactory = protocolGeneratorFactory

model = codegenDecorator.transformModel(service, baseModel)

Expand All @@ -145,7 +145,7 @@ open class ServerCodegenVisitor(
serverSymbolProviders.symbolProvider,
null,
service,
protocol,
protocolShape,
settings,
serverSymbolProviders.unconstrainedShapeSymbolProvider,
serverSymbolProviders.constrainedShapeSymbolProvider,
Expand All @@ -169,7 +169,7 @@ open class ServerCodegenVisitor(
settings.codegenConfig,
codegenContext.expectModuleDocProvider(),
)
protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
}

/**
Expand Down Expand Up @@ -315,7 +315,12 @@ open class ServerCodegenVisitor(
writer: RustWriter,
) {
if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) {
val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator)
val serverBuilderGenerator = ServerBuilderGenerator(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGenerator.render(rustCrate, writer)

if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
Expand All @@ -336,7 +341,12 @@ open class ServerCodegenVisitor(

if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) {
val serverBuilderGeneratorWithoutPublicConstrainedTypes =
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator)
ServerBuilderGeneratorWithoutPublicConstrainedTypes(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer)

writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,19 @@ package software.amazon.smithy.rust.codegen.server.smithy
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Object used *exclusively* in the runtime of the server, for separation concerns.
* Analogous to the companion object in [RuntimeType]; see its documentation for details.
* For a runtime type that is used in the client, or in both the client and the server, use [RuntimeType] directly.
*/
object ServerRuntimeType {
fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)
fun router(runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError")

fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection")

fun responseRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::ResponseRejection")

fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")
fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")

fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig)
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitI
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -67,11 +68,7 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
override val shapeId: ShapeId =
ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse)

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -89,7 +86,8 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -66,11 +67,7 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
override val shapeId: ShapeId = SHAPE_ID

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -87,7 +84,8 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
Expand Down Expand Up @@ -92,6 +93,7 @@ class ServerBuilderGenerator(
val codegenContext: ServerCodegenContext,
private val shape: StructureShape,
private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator,
private val protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -148,7 +150,7 @@ class ServerBuilderGenerator(
ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down Expand Up @@ -222,7 +224,8 @@ class ServerBuilderGenerator(
"""
#{Converter:W}
""",
"Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(),
"Converter" to
customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(protocol),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule

/**
Expand All @@ -49,6 +50,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
private val codegenContext: ServerCodegenContext,
shape: StructureShape,
validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -85,7 +87,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Collection of methods that will be invoked by the respective generators to generate code to convert constraint
Expand All @@ -26,7 +27,7 @@ interface ValidationExceptionConversionGenerator {
* Convert from a top-level operation input's constraint violation into
* `aws_smithy_http_server::rejection::RequestRejection`.
*/
fun renderImplFromConstraintViolationForRequestRejection(): Writable
fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable

// Simple shapes.
fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection<StringTraitInfo>): Writable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJson
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

private fun allOperations(codegenContext: CodegenContext): List<OperationShape> {
val index = TopDownIndex.of(codegenContext.model)
return index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id }
}

interface ServerProtocol : Protocol {
/** The path such that `aws_smithy_http_server::proto::$path` points to the protocol's module. */
val protocolModulePath: String;

/** Returns the Rust marker struct enjoying `OperationShape`. */
fun markerStruct(): RuntimeType

Expand Down Expand Up @@ -76,6 +74,17 @@ interface ServerProtocol : Protocol {
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false

// TODO Docs
fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::rejection::RequestRejection")
fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::rejection::ResponseRejection")
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::${protocolModulePath}::runtime_error::RuntimeError")
}

class ServerAwsJsonProtocol(
Expand All @@ -84,6 +93,12 @@ class ServerAwsJsonProtocol(
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
private val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String
get() = when (version) {
is AwsJsonVersion.Json10 -> "aws_json_10"
is AwsJsonVersion.Json11 -> "aws_json_11"
}

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) {
Expand All @@ -107,12 +122,8 @@ class ServerAwsJsonProtocol(

override fun markerStruct(): RuntimeType {
return when (version) {
is AwsJsonVersion.Json10 -> {
ServerRuntimeType.protocol("AwsJson1_0", "aws_json_10", runtimeConfig)
}
is AwsJsonVersion.Json11 -> {
ServerRuntimeType.protocol("AwsJson1_1", "aws_json_11", runtimeConfig)
}
is AwsJsonVersion.Json10 -> ServerRuntimeType.protocol("AwsJson1_0", protocolModulePath, runtimeConfig)
is AwsJsonVersion.Json11 -> ServerRuntimeType.protocol("AwsJson1_1", protocolModulePath, runtimeConfig)
}
}

Expand All @@ -139,6 +150,16 @@ class ServerAwsJsonProtocol(
AwsJsonVersion.Json10 -> "new_aws_json_10_router"
AwsJsonVersion.Json11 -> "new_aws_json_11_router"
}

override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::RequestRejection")
override fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::ResponseRejection")
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::runtime_error::RuntimeError")
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand All @@ -150,6 +171,8 @@ class ServerRestJsonProtocol(
) : RestJson(serverCodegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String = "rest_json_1"

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
Expand All @@ -173,7 +196,8 @@ class ServerRestJsonProtocol(
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", "rest_json_1", runtimeConfig)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand All @@ -196,8 +220,9 @@ class ServerRestXmlProtocol(
codegenContext: CodegenContext,
) : RestXml(codegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath = "rest_xml"

override fun markerStruct() = ServerRuntimeType.protocol("RestXml", "rest_xml", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestXml", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand Down
Loading

0 comments on commit c47f972

Please sign in to comment.