Skip to content

Commit

Permalink
Lazy initialize mime accept header (#2629)
Browse files Browse the repository at this point in the history
## Description

Do not parse and initialize the mime, known at compile time, on every
request.

See
#2607 (comment)


## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de>
  • Loading branch information
82marbag authored and david-perez committed May 22, 2023
1 parent ddd8add commit a879942
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,36 @@ class ServerHttpBoundProtocolTraitImplGenerator(
outputSymbol: Symbol,
operationShape: OperationShape,
) {
val operationName = symbolProvider.toSymbol(operationShape).name
val staticContentType = "CONTENT_TYPE_${operationName.uppercase()}"
val verifyAcceptHeader = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) {
if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), &$staticContentType) {
return Err(#{RequestRejection}::NotAcceptable);
}
""",
*codegenScope,
)
}
}
val verifyAcceptHeaderStaticContentTypeInit = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
val init = when (contentType) {
"application/json" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;"
"application/octet-stream" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;"
"application/x-www-form-urlencoded" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;"
else ->
"""
static $staticContentType: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = #{OnceCell}::sync::Lazy::new(|| {
${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid")
});
"""
}
rustTemplate(init, *codegenScope)
}
}
val verifyRequestContentTypeHeader = writable {
operationShape
.inputShape(model)
Expand All @@ -215,6 +233,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
// TODO(https://github.com/awslabs/smithy-rs/issues/2238): Remove the `Pin<Box<dyn Future>>` and replace with thin wrapper around `Collect`.
rustTemplate(
"""
#{verifyAcceptHeaderStaticContentTypeInit:W}
#{PinProjectLite}::pin_project! {
/// A [`Future`](std::future::Future) aggregating the body bytes of a [`Request`] and constructing the
/// [`${inputSymbol.name}`](#{I}) using modelled bindings.
Expand Down Expand Up @@ -267,6 +286,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
"Marker" to protocol.markerStruct(),
"parse_request" to serverParseRequest(operationShape),
"verifyAcceptHeader" to verifyAcceptHeader,
"verifyAcceptHeaderStaticContentTypeInit" to verifyAcceptHeaderStaticContentTypeInit,
"verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader,
)

Expand Down
41 changes: 29 additions & 12 deletions rust-runtime/aws-smithy-http-server/src/protocols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,10 @@ pub fn content_type_header_classifier(
Ok(())
}

pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool {
pub fn accept_header_classifier(headers: &HeaderMap, content_type: &mime::Mime) -> bool {
if !headers.contains_key(http::header::ACCEPT) {
return true;
}
// Must be of the form: type/subtype
let content_type = content_type
.parse::<mime::Mime>()
.expect("BUG: MIME parsing failed, content_type is not valid");
headers
.get_all(http::header::ACCEPT)
.into_iter()
Expand Down Expand Up @@ -195,41 +191,62 @@ mod tests {
#[test]
fn valid_accept_header_classifier_multiple_values() {
let valid_request = req_accept("text/strings, application/json, invalid");
assert!(accept_header_classifier(&valid_request, "application/json"));
assert!(accept_header_classifier(
&valid_request,
&"application/json".parse().unwrap()
));
}

#[test]
fn invalid_accept_header_classifier() {
let invalid_request = req_accept("text/invalid, invalid, invalid/invalid");
assert!(!accept_header_classifier(&invalid_request, "application/json"));
assert!(!accept_header_classifier(
&invalid_request,
&"application/json".parse().unwrap()
));
}

#[test]
fn valid_accept_header_classifier_star() {
let valid_request = req_accept("application/*");
assert!(accept_header_classifier(&valid_request, "application/json"));
assert!(accept_header_classifier(
&valid_request,
&"application/json".parse().unwrap()
));
}

#[test]
fn valid_accept_header_classifier_star_star() {
let valid_request = req_accept("*/*");
assert!(accept_header_classifier(&valid_request, "application/json"));
assert!(accept_header_classifier(
&valid_request,
&"application/json".parse().unwrap()
));
}

#[test]
fn valid_empty_accept_header_classifier() {
assert!(accept_header_classifier(&HeaderMap::new(), "application/json"));
assert!(accept_header_classifier(
&HeaderMap::new(),
&"application/json".parse().unwrap()
));
}

#[test]
fn valid_accept_header_classifier_with_params() {
let valid_request = req_accept("application/json; q=30, */*");
assert!(accept_header_classifier(&valid_request, "application/json"));
assert!(accept_header_classifier(
&valid_request,
&"application/json".parse().unwrap()
));
}

#[test]
fn valid_accept_header_classifier() {
let valid_request = req_accept("application/json");
assert!(accept_header_classifier(&valid_request, "application/json"));
assert!(accept_header_classifier(
&valid_request,
&"application/json".parse().unwrap()
));
}
}

0 comments on commit a879942

Please sign in to comment.