diff --git a/dropshot/src/api_description.rs b/dropshot/src/api_description.rs index c371b3aa..c64109ee 100644 --- a/dropshot/src/api_description.rs +++ b/dropshot/src/api_description.rs @@ -31,6 +31,7 @@ use std::collections::HashSet; /// provided explicitly--as well as parameters and a description which can be /// inferred from function parameter types and doc comments (respectively). #[derive(Debug)] +#[non_exhaustive] pub struct ApiEndpoint { pub operation_id: String, pub handler: Box>, @@ -38,6 +39,7 @@ pub struct ApiEndpoint { pub path: String, pub parameters: Vec, pub body_content_type: ApiEndpointBodyContentType, + pub request_body_max_bytes: Option, pub response: ApiEndpointResponse, pub summary: Option, pub description: Option, @@ -72,6 +74,7 @@ impl<'a, Context: ServerContext> ApiEndpoint { path: path.to_string(), parameters: func_parameters.parameters, body_content_type, + request_body_max_bytes: None, response, summary: None, description: None, @@ -92,6 +95,11 @@ impl<'a, Context: ServerContext> ApiEndpoint { self } + pub fn request_body_max_bytes(mut self, max_bytes: usize) -> Self { + self.request_body_max_bytes = Some(max_bytes); + self + } + pub fn tag(mut self, tag: T) -> Self { self.tags.push(tag.to_string()); self diff --git a/dropshot/src/extractor/body.rs b/dropshot/src/extractor/body.rs index 82b457f3..67f40800 100644 --- a/dropshot/src/extractor/body.rs +++ b/dropshot/src/extractor/body.rs @@ -57,9 +57,8 @@ async fn http_request_load_body( where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { - let server = &rqctx.server; let (parts, body) = request.into_parts(); - let body = StreamingBody::new(body, server.config.request_body_max_bytes) + let body = StreamingBody::new(body, rqctx.request_body_max_bytes) .into_bytes_mut() .await?; @@ -191,12 +190,10 @@ impl ExclusiveExtractor for UntypedBody { rqctx: &RequestContext, request: hyper::Request, ) -> Result { - let server = &rqctx.server; let body = request.into_body(); - let body_bytes = - StreamingBody::new(body, server.config.request_body_max_bytes) - .into_bytes_mut() - .await?; + let body_bytes = StreamingBody::new(body, rqctx.request_body_max_bytes) + .into_bytes_mut() + .await?; Ok(UntypedBody { content: body_bytes.freeze() }) } @@ -396,11 +393,9 @@ impl ExclusiveExtractor for StreamingBody { rqctx: &RequestContext, request: hyper::Request, ) -> Result { - let server = &rqctx.server; - Ok(Self { body: request.into_body(), - cap: server.config.request_body_max_bytes, + cap: rqctx.request_body_max_bytes, }) } diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 38adc2ad..b2bbf842 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -72,6 +72,7 @@ pub type HttpHandlerResult = Result, HttpError>; /// Handle for various interfaces useful during request processing. #[derive(Debug)] +#[non_exhaustive] pub struct RequestContext { /// shared server state pub server: Arc>, @@ -79,6 +80,10 @@ pub struct RequestContext { pub path_variables: VariableSet, /// expected request body mime type pub body_content_type: ApiEndpointBodyContentType, + /// Maximum request body size: typically the same as + /// [`server.config.request_body_max_bytes`], but can be overridden for an + /// individual request + pub request_body_max_bytes: usize, /// unique id assigned to this request pub request_id: String, /// logger for this specific request diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index c19d8744..aecc2bd3 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -191,6 +191,9 @@ //! //! // Optional fields //! tags = [ "all", "your", "OpenAPI", "tags" ], +//! // An optional limit for the request body size that overrides the +//! // default server config. +//! request_body_max_bytes = 1048576, //! }] //! ``` //! diff --git a/dropshot/src/router.rs b/dropshot/src/router.rs index b8149d5d..9003affc 100644 --- a/dropshot/src/router.rs +++ b/dropshot/src/router.rs @@ -207,10 +207,12 @@ impl MapValue for VariableValue { /// corresponding values in the actual path, and the expected body /// content type. #[derive(Debug)] +#[non_exhaustive] pub struct RouterLookupResult<'a, Context: ServerContext> { pub handler: &'a dyn RouteHandler, pub variables: VariableSet, pub body_content_type: ApiEndpointBodyContentType, + pub request_body_max_bytes: Option, } impl HttpRouterNode { @@ -483,6 +485,7 @@ impl HttpRouter { handler: &*handler.handler, variables, body_content_type: handler.body_content_type.clone(), + request_body_max_bytes: handler.request_body_max_bytes, }) .ok_or_else(|| { HttpError::for_status(None, StatusCode::METHOD_NOT_ALLOWED) @@ -766,6 +769,7 @@ mod test { parameters: vec![], body_content_type: ApiEndpointBodyContentType::default(), response: ApiEndpointResponse::default(), + request_body_max_bytes: None, summary: None, description: None, tags: vec![], diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 7b3519ab..a62f63a9 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -789,6 +789,9 @@ async fn http_request_handle( request: RequestInfo::from(&request), path_variables: lookup_result.variables, body_content_type: lookup_result.body_content_type, + request_body_max_bytes: lookup_result + .request_body_max_bytes + .unwrap_or(server.config.request_body_max_bytes), request_id: request_id.to_string(), log: request_log, }; diff --git a/dropshot/src/websocket.rs b/dropshot/src/websocket.rs index eeddfcf2..33f62314 100644 --- a/dropshot/src/websocket.rs +++ b/dropshot/src/websocket.rs @@ -335,6 +335,7 @@ mod tests { request: RequestInfo::from(&request), path_variables: Default::default(), body_content_type: Default::default(), + request_body_max_bytes: 0, request_id: "".to_string(), log: log.clone(), }; diff --git a/dropshot/tests/fail/bad_endpoint20.rs b/dropshot/tests/fail/bad_endpoint20.rs new file mode 100644 index 00000000..76c9492c --- /dev/null +++ b/dropshot/tests/fail/bad_endpoint20.rs @@ -0,0 +1,22 @@ +// Copyright 2023 Oxide Computer Company + +#![allow(unused_imports)] + +use dropshot::endpoint; +use dropshot::HttpError; +use dropshot::HttpResponseOk; +use dropshot::UntypedBody; + +#[endpoint { + method = GET, + path = "/test", + request_body_max_bytes = true, +}] +async fn bad_request_body_max_bytes( + _rqctx: RequestContext<()>, + param: UntypedBody, +) -> Result, HttpError> { + Ok(HttpResponseOk(())) +} + +fn main() {} diff --git a/dropshot/tests/fail/bad_endpoint20.stderr b/dropshot/tests/fail/bad_endpoint20.stderr new file mode 100644 index 00000000..c2d55b80 --- /dev/null +++ b/dropshot/tests/fail/bad_endpoint20.stderr @@ -0,0 +1,5 @@ +error: expected u64, but found `true` + --> tests/fail/bad_endpoint20.rs:13:30 + | +13 | request_body_max_bytes = true, + | ^^^^ diff --git a/dropshot/tests/fail/bad_endpoint21.rs b/dropshot/tests/fail/bad_endpoint21.rs new file mode 100644 index 00000000..f778410c --- /dev/null +++ b/dropshot/tests/fail/bad_endpoint21.rs @@ -0,0 +1,26 @@ +// Copyright 2023 Oxide Computer Company + +// This does not currently work, but we may want to support it in the future. + +#![allow(unused_imports)] + +use dropshot::endpoint; +use dropshot::HttpError; +use dropshot::HttpResponseOk; +use dropshot::UntypedBody; + +const MAX_REQUEST_BYTES: u64 = 400_000; + +#[endpoint { + method = GET, + path = "/test", + request_body_max_bytes = MAX_REQUEST_BYTES, +}] +async fn bad_request_body_max_bytes( + _rqctx: RequestContext<()>, + param: UntypedBody, +) -> Result, HttpError> { + Ok(HttpResponseOk(())) +} + +fn main() {} diff --git a/dropshot/tests/fail/bad_endpoint21.stderr b/dropshot/tests/fail/bad_endpoint21.stderr new file mode 100644 index 00000000..9f145f06 --- /dev/null +++ b/dropshot/tests/fail/bad_endpoint21.stderr @@ -0,0 +1,5 @@ +error: expected u64, but found `MAX_REQUEST_BYTES` + --> tests/fail/bad_endpoint21.rs:17:30 + | +17 | request_body_max_bytes = MAX_REQUEST_BYTES, + | ^^^^^^^^^^^^^^^^^ diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 9e802d16..6dc57651 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -72,8 +72,11 @@ fn demo_api() -> ApiDescription { api.register(demo_handler_path_param_string).unwrap(); api.register(demo_handler_path_param_uuid).unwrap(); api.register(demo_handler_path_param_u32).unwrap(); + api.register(demo_large_typed_body).unwrap(); api.register(demo_handler_untyped_body).unwrap(); + api.register(demo_handler_large_untyped_body).unwrap(); api.register(demo_handler_streaming_body).unwrap(); + api.register(demo_handler_large_streaming_body).unwrap(); api.register(demo_handler_raw_request).unwrap(); api.register(demo_handler_delete).unwrap(); api.register(demo_handler_headers).unwrap(); @@ -644,6 +647,48 @@ async fn test_demo_path_param_u32() { testctx.teardown().await; } +// Test a `TypedBody` with a large payload. +#[tokio::test] +async fn test_large_typed_body() { + let api = demo_api(); + let testctx = common::test_setup("test_large_typed_body", api); + let client = &testctx.client_testctx; + + // This serializes to exactly 2058 bytes. + let body = DemoLargeTypedBody { body: vec![0; 1024] }; + let body_json = serde_json::to_string(&body).unwrap(); + assert_eq!(body_json.len(), 2058); + let mut response = client + .make_request_with_body( + Method::GET, + "/testing/large_typed_body", + body_json.into(), + StatusCode::OK, + ) + .await + .unwrap(); + let response_json: DemoLargeTypedBody = read_json(&mut response).await; + assert_eq!(body, response_json); + + // This serializes to 2060 bytes, which is over the limit. + let body = DemoLargeTypedBody { body: vec![0; 1025] }; + let body_json = serde_json::to_string(&body).unwrap(); + assert_eq!(body_json.len(), 2060); + let error = client + .make_request_with_body( + Method::GET, + "/testing/large_typed_body", + body_json.into(), + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); + assert_eq!( + error.message, + "request body exceeded maximum size of 2058 bytes" + ); +} + // Test `UntypedBody`. #[tokio::test] async fn test_untyped_body() { @@ -727,6 +772,36 @@ async fn test_untyped_body() { assert_eq!(json.nbytes, 4); assert_eq!(json.as_utf8, Some(String::from("tμv"))); + // Success case: Large body endpoint. + let large_body = vec![0u8; 2048]; + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/large_untyped_body", + large_body.into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoUntyped = read_json(&mut response).await; + assert_eq!(json.nbytes, 2048); + + // Error case: Large body endpoint failure. + let large_body = vec![0u8; 2049]; + let error = client + .make_request_with_body( + Method::PUT, + "/testing/large_untyped_body", + large_body.into(), + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); + assert_eq!( + error.message, + "request body exceeded maximum size of 2048 bytes" + ); + testctx.teardown().await; } @@ -779,6 +854,36 @@ async fn test_streaming_body() { error.message, "request body exceeded maximum size of 1024 bytes" ); + + // Success case: Large body endpoint. + let large_body = vec![0u8; 2048]; + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/large_streaming_body", + large_body.into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoUntyped = read_json(&mut response).await; + assert_eq!(json.nbytes, 2048); + + // Error case: Large body endpoint failure. + let large_body = vec![0u8; 2049]; + let error = client + .make_request_with_body( + Method::PUT, + "/testing/large_streaming_body", + large_body.into(), + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); + assert_eq!( + error.message, + "request body exceeded maximum size of 2048 bytes" + ); } // Test `RawRequest`. @@ -1122,6 +1227,23 @@ async fn demo_handler_path_param_u32( http_echo(&path_params.into_inner()) } +#[derive(Debug, Deserialize, Serialize, JsonSchema, Eq, PartialEq)] +pub struct DemoLargeTypedBody { + pub body: Vec, +} +#[endpoint { + method = GET, + path = "/testing/large_typed_body", + // This is 2058 rather than 2048 because that's what the test requires. + request_body_max_bytes = 2058, +}] +async fn demo_large_typed_body( + _rqctx: RequestCtx, + body: TypedBody, +) -> Result, HttpError> { + http_echo(&body.into_inner()) +} + #[derive(Deserialize, Serialize, JsonSchema)] pub struct DemoUntyped { pub nbytes: usize, @@ -1150,6 +1272,26 @@ async fn demo_handler_untyped_body( Ok(HttpResponseOk(DemoUntyped { nbytes, as_utf8 })) } +#[endpoint { + method = PUT, + path = "/testing/large_untyped_body", + request_body_max_bytes = 2048, +}] +async fn demo_handler_large_untyped_body( + _rqctx: RequestContext, + query: Query, + body: UntypedBody, +) -> Result, HttpError> { + let nbytes = body.as_bytes().len(); + let as_utf8 = if query.into_inner().parse_str.unwrap_or(false) { + Some(String::from(body.as_str()?)) + } else { + None + }; + + Ok(HttpResponseOk(DemoUntyped { nbytes, as_utf8 })) +} + #[derive(Deserialize, Serialize, JsonSchema)] pub struct DemoStreaming { pub nbytes: usize, @@ -1170,6 +1312,23 @@ async fn demo_handler_streaming_body( Ok(HttpResponseOk(DemoStreaming { nbytes })) } +#[endpoint { + method = PUT, + path = "/testing/large_streaming_body", + request_body_max_bytes = 2048, +}] +async fn demo_handler_large_streaming_body( + _rqctx: RequestContext, + body: StreamingBody, +) -> Result, HttpError> { + let nbytes = body + .into_stream() + .try_fold(0, |acc, v| futures::future::ok(acc + v.len())) + .await?; + + Ok(HttpResponseOk(DemoStreaming { nbytes })) +} + #[derive(Deserialize, Serialize, JsonSchema)] pub struct DemoRaw { pub nbytes: usize, diff --git a/dropshot_endpoint/src/lib.rs b/dropshot_endpoint/src/lib.rs index ad3c9b45..f29131b4 100644 --- a/dropshot_endpoint/src/lib.rs +++ b/dropshot_endpoint/src/lib.rs @@ -55,6 +55,8 @@ struct EndpointMetadata { unpublished: bool, #[serde(default)] deprecated: bool, + #[serde(default)] + request_body_max_bytes: Option, content_type: Option, _dropshot_crate: Option, } @@ -242,6 +244,7 @@ fn do_channel( tags, unpublished, deprecated, + request_body_max_bytes: None, content_type: Some("application/json".to_string()), _dropshot_crate, }; @@ -388,6 +391,13 @@ fn do_endpoint_inner( quote! {} }; + let request_body_max_bytes = + metadata.request_body_max_bytes.map(|max_bytes| { + quote! { + .request_body_max_bytes(#max_bytes) + } + }); + let dropshot = get_crate(metadata._dropshot_crate); let first_arg = match ast.sig.inputs.first() { @@ -591,6 +601,7 @@ fn do_endpoint_inner( #(#tags)* #visible #deprecated + #request_body_max_bytes } } else { quote! {