Skip to content

Commit

Permalink
Add Support for Endpoint Prefix (#420)
Browse files Browse the repository at this point in the history
* Add Support for Endpoint Prefix

I decided to allow the httpLabel fields to remain optional for now (in the future, hopefully we'll be able to code generate on their required status and delete some code!).

In other exciting news, there are now no failing protocol tests!

* Simplify code and add test for unset

* Update codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EndpointTraitBindingGenerator.kt

* Fix clippy errors
  • Loading branch information
rcoh authored May 26, 2021
1 parent 7844b56 commit 70a3526
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

package software.amazon.smithy.rust.codegen.smithy.customizations

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.smithy.generators.OperationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig

class EndpointPrefixGenerator(private val protocolConfig: ProtocolConfig, private val shape: OperationShape) :
OperationCustomization() {
override fun section(section: OperationSection): Writable = when (section) {
is OperationSection.MutateRequest -> writable {
shape.getTrait(EndpointTrait::class.java).map { epTrait ->
val endpointTraitBindings = EndpointTraitBindings(
protocolConfig.model,
protocolConfig.symbolProvider,
protocolConfig.runtimeConfig,
shape,
epTrait
)
val buildError = OperationBuildError(protocolConfig.runtimeConfig)
withBlock("let endpoint_prefix = ", ";") {
endpointTraitBindings.render(this, "self")
}
rustBlock("match endpoint_prefix") {
rust("Ok(prefix) => { request.config_mut().insert(prefix); },")
rust("Err(err) => return Err(${buildError.serializationError(this, "err")})")
}
}
}
else -> emptySection
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.smithy.customize
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.smithy.customizations.AllowClippyLints
import software.amazon.smithy.rust.codegen.smithy.customizations.CrateVersionGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.EndpointPrefixGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.IdempotencyTokenGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.SmithyTypesPubUseGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization
Expand All @@ -26,7 +27,10 @@ class RequiredCustomizations : RustCodegenDecorator {
operation: OperationShape,
baseCustomizations: List<OperationCustomization>
): List<OperationCustomization> {
return baseCustomizations + IdempotencyTokenGenerator(protocolConfig, operation)
return baseCustomizations + IdempotencyTokenGenerator(protocolConfig, operation) + EndpointPrefixGenerator(
protocolConfig,
operation
)
}

override fun libRsCustomizations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): RuntimeTyp

fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).member("BuildError")

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
fun missingField(w: RustWriter, field: String, details: String) = "${w.format(runtimeConfig.operationBuildError())}::MissingField { field: ${field.dq()}, details: ${details.dq()} }"
fun invalidField(w: RustWriter, field: String, details: String) = "${w.format(runtimeConfig.operationBuildError())}::InvalidField { field: ${field.dq()}, details: ${details.dq()}.to_string() }"
fun serializationError(w: RustWriter, error: String) = "${w.format(runtimeConfig.operationBuildError())}::SerializationError($error.into())"
}

/** setter names will never hit a reserved word and therefore never need escaping */
fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.http.rustFormatString
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.util.inputShape

fun EndpointTrait.prefixFormatString(): String {
return this.hostPrefix.rustFormatString("", "")
}

fun RuntimeConfig.smithyHttp() = CargoDependency.SmithyHttp(this).asType()

class EndpointTraitBindings(
model: Model,
private val symbolProvider: RustSymbolProvider,
private val runtimeConfig: RuntimeConfig,
operationShape: OperationShape,
private val endpointTrait: EndpointTrait
) {
private val inputShape = operationShape.inputShape(model)
private val smithyHttp = runtimeConfig.smithyHttp()
private val endpointPrefix = smithyHttp.member("endpoint::EndpointPrefix")

/**
* Render the `EndpointPrefix` struct. [input] refers to the symbol referring to the input of this operation.
*
* Generates code like:
* ```rust
* EndpointPrefix::new(format!("{}.aws.com", input.bucket));
* ```
*
* The returned expression is a `Result<EndpointPrefix, UriError>`
*/
fun render(writer: RustWriter, input: String) {
// the Rust format pattern to make the endpoint prefix eg. "{}.foo"
val formatLiteral = endpointTrait.prefixFormatString()
if (endpointTrait.hostPrefix.labels.isEmpty()) {
// if there are no labels, we don't need string formatting
writer.rustTemplate(
"#{EndpointPrefix}::new($formatLiteral)",
"EndpointPrefix" to endpointPrefix
)
} else {
val operationBuildError = OperationBuildError(runtimeConfig)
writer.rustBlock("") {
// build a list of args: `labelname = "field"`
// these eventually end up in the format! macro invocation:
// ```format!("some.{endpoint}", endpoint = endpoint);```
val args = endpointTrait.hostPrefix.labels.map { label ->
val memberShape = inputShape.getMember(label.content).get()
val field = symbolProvider.toMemberName(memberShape)
val invalidFieldError = operationBuildError.invalidField(
writer,
field,
"$field was unset or empty but must be set as part of the endpoint prefix"
)
if (symbolProvider.toSymbol(memberShape).isOptional()) {
rust("let $field = $input.$field.as_deref().unwrap_or_default();")
} else {
// NOTE: this is dead code until we start respecting @required
rust("let $field = &$input.$field;")
}
rust(
"""
if $field.is_empty() {
return Err($invalidFieldError)
}
"""
)
"${label.content} = $field"
}
writer.rustTemplate(
"#{EndpointPrefix}::new(format!($formatLiteral, ${args.joinToString()}))",
"EndpointPrefix" to endpointPrefix
)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ abstract class HttpProtocolGenerator(
}
val operationName = symbolProvider.toSymbol(operationShape).name
operationWriter.documentShape(operationShape, model)
Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default)).render(operationWriter)
Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter)
operationWriter.rustBlock("pub struct $operationName") {
write("_private: ()")
}
Expand All @@ -118,8 +118,6 @@ abstract class HttpProtocolGenerator(
traitImplementations(operationWriter, operationShape)
}

data class ResponseBody(val type: String, val mutability: String)

protected fun httpBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) {
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter)
implBlockWriter.rustBlock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,19 @@ class HttpProtocolTestGenerator(
instantiator.render(this, inputShape, httpRequestTestCase.params)

rust(""".make_operation(&config).expect("operation failed to build");""")
rust("let (http_request, _) = input.into_request_response().0.into_parts();")
rust("let (http_request, parts) = input.into_request_response().0.into_parts();")
with(httpRequestTestCase) {
host.orNull()?.also { host ->
val withScheme = "http://$host"
rust(
"""
let mut http_request = http_request;
let ep = #T::endpoint::Endpoint::mutable(#T::Uri::from_static(${withScheme.dq()}));
ep.set_endpoint(http_request.uri_mut(), parts.lock().unwrap().get());
""",
CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType(), CargoDependency.Http.asType()
)
}
rust(
"""
assert_eq!(http_request.method(), ${method.dq()});
Expand Down Expand Up @@ -415,23 +426,10 @@ class HttpProtocolTestGenerator(
private val AwsJson11 = "aws.protocoltests.json#JsonProtocol"
private val RestJson = "aws.protocoltests.restjson#RestJson"
private val RestXml = "aws.protocoltests.restxml#RestXml"
private val ExpectFail = setOf(
// Endpoint trait https://github.com/awslabs/smithy-rs/issues/197
// This will also require running operations through the endpoint middleware (or moving endpoint middleware
// into operation construction
FailingTest(JsonRpc10, "AwsJson10EndpointTrait", Action.Request),
FailingTest(JsonRpc10, "AwsJson10EndpointTraitWithHostLabel", Action.Request),
FailingTest(AwsJson11, "AwsJson11EndpointTrait", Action.Request),
FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
FailingTest(RestXml, "RestXmlEndpointTraitWithHostLabelAndHttpBinding", Action.Request),
FailingTest(RestXml, "RestXmlEndpointTraitWithHostLabel", Action.Request),
FailingTest(RestXml, "RestXmlEndpointTrait", Action.Request)
)
private val ExpectFail = setOf<FailingTest>()
private val RunOnly: Set<String>? = null

// These tests are not even attempted to be compiled, either because they will not compile
// These tests are not even attempted to be generated, either because they will not compile
// or because they are flaky
private val DisableTests = setOf<String>()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait

fun HttpTrait.uriFormatString(): String {
val base = uri.segments.joinToString("/", prefix = "/") {
return uri.rustFormatString("/", "/")
}

fun SmithyPattern.rustFormatString(prefix: String, separator: String): String {
val base = segments.joinToString(separator = separator, prefix = prefix) {
when {
it.isLabel -> "{${it.content}}"
else -> it.content
Expand Down
Loading

0 comments on commit 70a3526

Please sign in to comment.