Skip to content

Commit

Permalink
Generate endpoint params in the orchestrator codegen (#2658)
Browse files Browse the repository at this point in the history
## Motivation and Context
This PR adds the codegen logic to generate endpoint parameters in the
endpoint params interceptor.

Fixes #2644.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
jdisanti authored and david-perez committed May 18, 2023
1 parent fdb3ecf commit 7d6d401
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 73 deletions.
8 changes: 0 additions & 8 deletions aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use aws_http::user_agent::AwsUserAgent;
use aws_runtime::invocation_id::InvocationId;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::endpoint::Params;
use aws_sdk_s3::Client;
use aws_smithy_client::dvr;
use aws_smithy_client::dvr::MediaType;
Expand All @@ -31,7 +30,6 @@ async fn sra_test() {
.build();
let client = Client::from_conf(config);
let fixup = FixupPlugin {
client: client.clone(),
timestamp: UNIX_EPOCH + Duration::from_secs(1624036048),
};

Expand All @@ -52,19 +50,13 @@ async fn sra_test() {
}

struct FixupPlugin {
client: Client,
timestamp: SystemTime,
}
impl RuntimePlugin for FixupPlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
) -> Result<(), aws_smithy_runtime_api::client::runtime_plugin::BoxError> {
let params_builder = Params::builder()
.set_region(self.client.conf().region().map(|c| c.as_ref().to_string()))
.bucket("test-bucket");

cfg.put(params_builder);
cfg.set_request_time(RequestTime::new(self.timestamp.clone()));
cfg.put(AwsUserAgent::for_tests());
cfg.put(InvocationId::for_tests());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ private class EndpointParametersRuntimePluginCustomization(
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rust("${operationName}EndpointParamsInterceptor")
}
// The finalizer interceptor should be registered last
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rust("${operationName}EndpointParamsFinalizerInterceptor")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class EndpointsDecorator : ClientCodegenDecorator {
override val name: String = "Endpoints"
override val order: Byte = 0

// TODO(enableNewSmithyRuntime): Remove `operationCustomizations` and `InjectEndpointInMakeOperation`
override fun operationCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@

package software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators

import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters
import software.amazon.smithy.rulesengine.traits.ContextIndex
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand All @@ -17,13 +24,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.orNull

class EndpointParamsInterceptorGenerator(
private val codegenContext: ClientCodegenContext,
) {
private val model = codegenContext.model
private val symbolProvider = codegenContext.symbolProvider
private val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext)
private val codegenScope = codegenContext.runtimeConfig.let { rc ->
val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext)
val runtimeApi = CargoDependency.smithyRuntimeApi(rc).toType()
Expand All @@ -34,29 +45,19 @@ class EndpointParamsInterceptorGenerator(
"ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"),
"ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"),
"EndpointResolverParams" to orchestrator.resolve("EndpointResolverParams"),
"HttpResponse" to orchestrator.resolve("HttpResponse"),
"HttpRequest" to orchestrator.resolve("HttpRequest"),
"HttpResponse" to orchestrator.resolve("HttpResponse"),
"Interceptor" to interceptors.resolve("Interceptor"),
"InterceptorContext" to interceptors.resolve("InterceptorContext"),
"InterceptorError" to interceptors.resolve("error::InterceptorError"),
"ParamsBuilder" to endpointTypesGenerator.paramsBuilder(),
"Params" to endpointTypesGenerator.paramsStruct(),
)
}

fun render(writer: RustWriter, operationShape: OperationShape) {
val operationName = symbolProvider.toSymbol(operationShape).name
renderInterceptor(
writer,
"${operationName}EndpointParamsInterceptor",
implInterceptorBodyForEndpointParams(operationShape),
)
renderInterceptor(
writer, "${operationName}EndpointParamsFinalizerInterceptor",
implInterceptorBodyForEndpointParamsFinalizer,
)
}

private fun renderInterceptor(writer: RustWriter, interceptorName: String, implBody: Writable) {
val operationInput = symbolProvider.toSymbol(operationShape.inputShape(model))
val interceptorName = "${operationName}EndpointParamsInterceptor"
writer.rustTemplate(
"""
##[derive(Debug)]
Expand All @@ -68,37 +69,78 @@ class EndpointParamsInterceptorGenerator(
context: &#{InterceptorContext}<#{HttpRequest}, #{HttpResponse}>,
cfg: &mut #{ConfigBag},
) -> Result<(), #{BoxError}> {
#{body:W}
let _input = context.input()?;
let _input = _input
.downcast_ref::<${operationInput.name}>()
.ok_or("failed to downcast to ${operationInput.name}")?;
#{endpoint_prefix:W}
// HACK: pull the handle out of the config bag until config is implemented right
let handle = cfg.get::<std::sync::Arc<crate::client::Handle>>()
.expect("the handle is hacked into the config bag");
let _config = &handle.conf;
let params = #{Params}::builder()
#{param_setters}
.build()
.map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
cfg.put(#{EndpointResolverParams}::new(params));
Ok(())
}
}
""",
*codegenScope,
"body" to implBody,
"endpoint_prefix" to endpointPrefix(operationShape),
"param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
)
}

private fun implInterceptorBodyForEndpointParams(operationShape: OperationShape): Writable = writable {
val operationInput = symbolProvider.toSymbol(operationShape.inputShape(model))
rustTemplate(
"""
let input = context.input()?;
let _input = input
.downcast_ref::<${operationInput.name}>()
.ok_or("failed to downcast to ${operationInput.name}")?;
let params_builder = cfg
.get::<#{ParamsBuilder}>()
.ok_or("missing endpoint params builder")?
.clone();
${"" /* TODO(EndpointResolver): Call setters on `params_builder` to update its fields by using values from `_input` */}
cfg.put(params_builder);
private fun paramSetters(operationShape: OperationShape, params: Parameters) = writable {
val idx = ContextIndex.of(codegenContext.model)
val memberParams = idx.getContextParams(operationShape).toList().sortedBy { it.first.memberName }
val builtInParams = params.toList().filter { it.isBuiltIn }
// first load builtins and their defaults
builtInParams.forEach { param ->
endpointTypesGenerator.builtInFor(param, "_config")?.also { defaultValue ->
rust(".set_${param.name.rustName()}(#W)", defaultValue)
}
}

#{endpoint_prefix:W}
idx.getClientContextParams(codegenContext.serviceShape).orNull()?.parameters?.forEach { (name, param) ->
val paramName = EndpointParamsGenerator.memberName(name)
val setterName = EndpointParamsGenerator.setterName(name)
if (param.type == ShapeType.BOOLEAN) {
rust(".$setterName(_config.$paramName)")
} else {
rust(".$setterName(_config.$paramName.clone())")
}
}

Ok(())
""",
*codegenScope,
"endpoint_prefix" to endpointPrefix(operationShape),
)
idx.getStaticContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
val setterName = EndpointParamsGenerator.setterName(name)
val value = param.value.toWritable()
rust(".$setterName(#W)", value)
}

// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
rust(
".${EndpointParamsGenerator.setterName(param.name)}(_input.$memberName.clone())",
)
}
}

private fun Node.toWritable(): Writable {
val node = this
return writable {
when (node) {
is StringNode -> rust("Some(${node.value.dq()}.to_string())")
is BooleanNode -> rust("Some(${node.value})")
else -> PANIC("unsupported default value: $node")
}
}
}

private fun endpointPrefix(operationShape: OperationShape): Writable = writable {
Expand All @@ -124,25 +166,4 @@ class EndpointParamsInterceptorGenerator(
rust("cfg.put(endpoint_prefix);")
}
}

private val implInterceptorBodyForEndpointParamsFinalizer: Writable = writable {
rustTemplate(
"""
let _ = context;
let params_builder = cfg
.get::<#{ParamsBuilder}>()
.ok_or("missing endpoint params builder")?
.clone();
let params = params_builder
.build()
.map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
cfg.put(
#{EndpointResolverParams}::new(params)
);
Ok(())
""",
*codegenScope,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class ServiceRuntimePluginGenerator(
val runtimeApi = RuntimeType.smithyRuntimeApi(rc)
arrayOf(
"AnonymousIdentityResolver" to runtimeApi.resolve("client::identity::AnonymousIdentityResolver"),
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"),
"ConfigBagAccessors" to runtimeApi.resolve("client::orchestrator::ConfigBagAccessors"),
Expand All @@ -85,6 +84,7 @@ class ServiceRuntimePluginGenerator(
"ResolveEndpoint" to http.resolve("endpoint::ResolveEndpoint"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
"SharedEndpointResolver" to http.resolve("endpoint::SharedEndpointResolver"),
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
"TraceProbe" to runtimeApi.resolve("client::orchestrator::TraceProbe"),
)
}
Expand All @@ -106,6 +106,9 @@ class ServiceRuntimePluginGenerator(
fn configure(&self, cfg: &mut #{ConfigBag}) -> Result<(), #{BoxError}> {
use #{ConfigBagAccessors};
// HACK: Put the handle into the config bag to work around config not being fully implemented yet
cfg.put(self.handle.clone());
let http_auth_schemes = #{HttpAuthSchemes}::builder()
#{http_auth_scheme_customizations}
.build();
Expand All @@ -118,9 +121,6 @@ class ServiceRuntimePluginGenerator(
#{SharedEndpointResolver}::from(self.handle.conf.endpoint_resolver()));
cfg.set_endpoint_resolver(endpoint_resolver);
${"" /* TODO(EndpointResolver): Create endpoint params builder from service config */}
cfg.put(#{Params}::builder());
// TODO(RuntimePlugins): Wire up standard retry
cfg.set_retry_strategy(#{NeverRetryStrategy}::new());
Expand Down

0 comments on commit 7d6d401

Please sign in to comment.