diff --git a/Cargo.lock b/Cargo.lock index 99f1cb0b..6451d01b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" @@ -118,6 +118,15 @@ dependencies = [ "byte-tools", ] +[[package]] +name = "buf-list" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f462e45b27db47403356859af1cb4bfbbd0021cb7b7d10db6ea40958bb4e2c48" +dependencies = [ + "bytes", +] + [[package]] name = "bumpalo" version = "3.12.0" @@ -304,6 +313,7 @@ dependencies = [ "async-stream", "async-trait", "base64 0.21.0", + "buf-list", "bytes", "camino", "chrono", @@ -1039,9 +1049,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" [[package]] name = "pin-utils" @@ -1710,22 +1720,22 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.19.2" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ + "autocfg", "bytes", "libc", "memchr", "mio", "num_cpus", - "once_cell", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "winapi", + "windows-sys", ] [[package]] diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index fd540658..0f7946b0 100644 --- 
a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -76,6 +76,7 @@ version = "0.8.12" features = [ "uuid1" ] [dev-dependencies] +buf-list = "1.0.0" expectorate = "1.0.6" hyper-rustls = "0.23.2" hyper-staticfile = "0.9" @@ -111,4 +112,4 @@ features = [ "max_level_trace", "release_max_level_debug" ] version_check = "0.9.4" [features] -usdt-probes = [ "usdt/asm" ] +usdt-probes = ["usdt/asm"] diff --git a/dropshot/src/extractor/body.rs b/dropshot/src/extractor/body.rs index 1ffa566b..0a50b79c 100644 --- a/dropshot/src/extractor/body.rs +++ b/dropshot/src/extractor/body.rs @@ -6,7 +6,7 @@ use crate::api_description::ApiEndpointParameter; use crate::api_description::ApiSchemaGenerator; use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode}; use crate::error::HttpError; -use crate::http_util::http_read_body; +use crate::http_util::http_dump_body; use crate::http_util::CONTENT_TYPE_JSON; use crate::schema_util::make_subschema_for; use crate::server::ServerContext; @@ -14,11 +14,17 @@ use crate::ExclusiveExtractor; use crate::ExtractorMetadata; use crate::RequestContext; use async_trait::async_trait; +use bytes::BufMut; use bytes::Bytes; +use bytes::BytesMut; +use futures::Stream; +use futures::TryStreamExt; +use hyper::body::HttpBody; use schemars::schema::InstanceType; use schemars::schema::SchemaObject; use schemars::JsonSchema; use serde::de::DeserializeOwned; +use std::convert::Infallible; use std::fmt::Debug; // TypedBody: body extractor for formats that can be deserialized to a specific @@ -46,23 +52,22 @@ impl /// to the content type, and deserialize it to an instance of `BodyType`. 
async fn http_request_load_body( rqctx: &RequestContext, - mut request: hyper::Request, + request: hyper::Request, ) -> Result, HttpError> where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { let server = &rqctx.server; - let body = http_read_body( - request.body_mut(), - server.config.request_body_max_bytes, - ) - .await?; + let (parts, body) = request.into_parts(); + let body = StreamingBody::new(body, server.config.request_body_max_bytes) + .into_bytes_mut() + .await?; // RFC 7231 §3.1.1.1: media types are case insensitive and may // be followed by whitespace and/or a parameter (e.g., charset), // which we currently ignore. - let content_type = parts - .headers .get(http::header::CONTENT_TYPE) .map(|hv| { hv.to_str().map_err(|e| { @@ -184,38 +189,200 @@ impl UntypedBody { impl ExclusiveExtractor for UntypedBody { async fn from_request( rqctx: &RequestContext, - mut request: hyper::Request, + request: hyper::Request, ) -> Result { let server = &rqctx.server; - let body_bytes = http_read_body( - request.body_mut(), - server.config.request_body_max_bytes, - ) - .await?; - Ok(UntypedBody { content: body_bytes }) + let body = request.into_body(); + let body_bytes = + StreamingBody::new(body, server.config.request_body_max_bytes) + .into_bytes_mut() + .await?; + Ok(UntypedBody { content: body_bytes.freeze() }) } fn metadata( _content_type: ApiEndpointBodyContentType, ) -> ExtractorMetadata { - ExtractorMetadata { - parameters: vec![ApiEndpointParameter::new_body( - ApiEndpointBodyContentType::Bytes, - true, - ApiSchemaGenerator::Static { - schema: Box::new( - SchemaObject { - instance_type: Some(InstanceType::String.into()), - format: Some(String::from("binary")), - ..Default::default() - } - .into(), - ), - dependencies: indexmap::IndexMap::default(), - }, - vec![], - )], - extension_mode: ExtensionMode::None, + untyped_metadata() + } +} + +// StreamingBody: body extractor that provides a streaming representation 
of the body. + +/// An extractor for streaming the contents of the HTTP request body, making the +/// raw bytes available to the consumer. +#[derive(Debug)] +pub struct StreamingBody { + body: hyper::Body, + cap: usize, +} + +impl StreamingBody { + fn new(body: hyper::Body, cap: usize) -> Self { + Self { body, cap } + } + + /// Not part of the public API. Used only for doctests. + #[doc(hidden)] + pub fn __from_bytes(data: Bytes) -> Self { + let cap = data.len(); + let stream = futures::stream::iter([Ok::<_, Infallible>(data)]); + let body = hyper::Body::wrap_stream(stream); + Self { body, cap } + } + + /// Converts `self` into a stream. + /// + /// The `Stream` produces values of type `Result`. + /// + /// # Errors + /// + /// The stream produces an [`HttpError`] if any of the following cases occur: + /// + /// * A network error occurred. + /// * `request_body_max_bytes` was exceeded for this request. + /// + /// # Examples + /// + /// Buffer a `StreamingBody` in-memory, into a + /// [`BufList`](https://docs.rs/buf-list/latest/buf_list/struct.BufList.html) + /// (a segmented list of [`Bytes`] chunks). 
+ /// + /// ``` + /// use buf_list::BufList; + /// use dropshot::{HttpError, StreamingBody}; + /// use futures::prelude::*; + /// # use std::iter::FromIterator; + /// + /// async fn into_buf_list(body: StreamingBody) -> Result { + /// body.into_stream().try_collect().await + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// # let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar")); + /// # assert_eq!( + /// # into_buf_list(body).await.unwrap().into_iter().next(), + /// # Some(bytes::Bytes::from("foobar")), + /// # ); + /// # } + /// ``` + /// + /// --- + /// + /// Write a `StreamingBody` to an [`AsyncWrite`](tokio::io::AsyncWrite), + /// for example a [`tokio::fs::File`], without buffering it into memory: + /// + /// ``` + /// use dropshot::{HttpError, StreamingBody}; + /// use futures::prelude::*; + /// use tokio::io::{AsyncWrite, AsyncWriteExt}; + /// + /// async fn write_all( + /// body: StreamingBody, + /// writer: &mut W, + /// ) -> Result<(), HttpError> { + /// let stream = body.into_stream(); + /// tokio::pin!(stream); + /// + /// while let Some(res) = stream.next().await { + /// let mut data = res?; + /// writer.write_all_buf(&mut data).await.map_err(|error| { + /// HttpError::for_unavail(None, format!("write failed: {error}")) + /// })?; + /// } + /// + /// Ok(()) + /// } + /// + /// # #[tokio::main] + /// # async fn main() { + /// # let body = StreamingBody::__from_bytes(bytes::Bytes::from("foobar")); + /// # let mut writer = vec![]; + /// # write_all(body, &mut writer).await.unwrap(); + /// # assert_eq!(writer, &b"foobar"[..]); + /// # } + /// ``` + pub fn into_stream( + mut self, + ) -> impl Stream> + Send { + async_stream::try_stream! 
{ + let mut bytes_read: usize = 0; + while let Some(buf_res) = self.body.data().await { + let buf = buf_res?; + let len = buf.len(); + + if bytes_read + len > self.cap { + http_dump_body(&mut self.body).await?; + // TODO-correctness check status code + Err(HttpError::for_bad_request( + None, + format!("request body exceeded maximum size of {} bytes", self.cap), + ))?; + } + + bytes_read += len; + yield buf; + } + + // Read the trailers as well, even though we're not going to do anything + // with them. + self.body.trailers().await?; } } + + /// Converts `self` into a [`BytesMut`], buffering the entire request body in + /// memory. Not public API because most users of this should use + /// `UntypedBody` instead. + async fn into_bytes_mut(self) -> Result { + self.into_stream() + .try_fold(BytesMut::new(), |mut out, chunk| { + out.put(chunk); + futures::future::ok(out) + }) + .await + } +} + +#[async_trait] +impl ExclusiveExtractor for StreamingBody { + async fn from_request( + rqctx: &RequestContext, + request: hyper::Request, + ) -> Result { + let server = &rqctx.server; + + Ok(Self { + body: request.into_body(), + cap: server.config.request_body_max_bytes, + }) + } + + fn metadata( + _content_type: ApiEndpointBodyContentType, + ) -> ExtractorMetadata { + untyped_metadata() + } +} + +fn untyped_metadata() -> ExtractorMetadata { + ExtractorMetadata { + parameters: vec![ApiEndpointParameter::new_body( + ApiEndpointBodyContentType::Bytes, + true, + ApiSchemaGenerator::Static { + schema: Box::new( + SchemaObject { + instance_type: Some(InstanceType::String.into()), + format: Some(String::from("binary")), + ..Default::default() + } + .into(), + ), + dependencies: indexmap::IndexMap::default(), + }, + vec![], + )], + extension_mode: ExtensionMode::None, + } } diff --git a/dropshot/src/extractor/mod.rs b/dropshot/src/extractor/mod.rs index 103141c6..a9740103 100644 --- a/dropshot/src/extractor/mod.rs +++ b/dropshot/src/extractor/mod.rs @@ -11,6 +11,7 @@ pub use 
common::RequestExtractor; pub use common::SharedExtractor; mod body; +pub use body::StreamingBody; pub use body::TypedBody; pub use body::UntypedBody; diff --git a/dropshot/src/http_util.rs b/dropshot/src/http_util.rs index 91ce1b72..23b477ed 100644 --- a/dropshot/src/http_util.rs +++ b/dropshot/src/http_util.rs @@ -1,7 +1,6 @@ // Copyright 2020 Oxide Computer Company //! General-purpose HTTP-related facilities -use bytes::BufMut; use bytes::Bytes; use hyper::body::HttpBody; use serde::de::DeserializeOwned; @@ -21,56 +20,6 @@ pub const CONTENT_TYPE_NDJSON: &str = "application/x-ndjson"; /// MIME type for form/urlencoded data pub const CONTENT_TYPE_URL_ENCODED: &str = "application/x-www-form-urlencoded"; -/// Reads the rest of the body from the request up to the given number of bytes. -/// If the body fits within the specified cap, a buffer is returned with all the -/// bytes read. If not, an error is returned. -pub async fn http_read_body( - body: &mut T, - cap: usize, -) -> Result -where - T: HttpBody + std::marker::Unpin, -{ - // This looks a lot like the implementation of hyper::body::to_bytes(), but - // applies the requested cap. We've skipped the optimization for the - // 1-buffer case for now, as it seems likely this implementation will change - // anyway. - // TODO should this use some Stream interface instead? - // TODO why does this look so different in type signature (Data=Bytes, - // std::marker::Unpin, &mut T) - // TODO Error type shouldn't have to be hyper Error -- Into should - // work too? - // TODO do we need to use saturating_add() here? 
- let mut parts = std::vec::Vec::new(); - let mut nbytesread: usize = 0; - while let Some(maybebuf) = body.data().await { - let buf = maybebuf?; - let bufsize = buf.len(); - - if nbytesread + bufsize > cap { - http_dump_body(body).await?; - // TODO-correctness check status code - return Err(HttpError::for_bad_request( - None, - format!("request body exceeded maximum size of {} bytes", cap), - )); - } - - nbytesread += bufsize; - parts.put(buf); - } - - // Read the trailers as well, even though we're not going to do anything - // with them. - body.trailers().await?; - // TODO-correctness why does the is_end_stream() assertion fail and the next - // one panic? - // assert!(body.is_end_stream()); - // assert!(body.data().await.is_none()); - // assert!(body.trailers().await?.is_none()); - Ok(parts.into()) -} - /// Reads the rest of the body from the request, dropping all the bytes. This is /// useful after encountering error conditions. pub async fn http_dump_body(body: &mut T) -> Result diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index d9700f4f..c19d8744 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -212,7 +212,8 @@ //! [query_params: Query,] //! [path_params: Path

,] //! [body_param: TypedBody,] -//! [body_param: UntypedBody,] +//! [body_param: UntypedBody,] +//! [body_param: StreamingBody,] //! [raw_request: RawRequest,] //! ) -> Result //! ``` @@ -234,14 +235,17 @@ //! body as JSON (or form/url-encoded) and deserializing it into an instance //! of type `J`. `J` must implement `serde::Deserialize` and `schemars::JsonSchema`. //! * [`UntypedBody`] extracts the raw bytes of the request body. +//! * [`StreamingBody`] provides the raw bytes of the request body as a +//! [`Stream`](futures::Stream) of [`Bytes`](bytes::Bytes) chunks. //! * [`RawRequest`] provides access to the underlying [`hyper::Request`]. The //! hope is that this would generally not be needed. It can be useful to //! implement functionality not provided by Dropshot. //! -//! `Query` and `Path` impl `SharedExtractor`. `TypedBody`, `UntypedBody`, and -//! `RawRequest` impl `ExclusiveExtractor`. Your function may accept 0-3 -//! extractors, but only one can be `ExclusiveExtractor`, and it must be the -//! last one. Otherwise, the order of extractor arguments does not matter. +//! `Query` and `Path` impl `SharedExtractor`. `TypedBody`, `UntypedBody`, +//! `StreamingBody`, and `RawRequest` impl `ExclusiveExtractor`. Your function +//! may accept 0-3 extractors, but only one can be `ExclusiveExtractor`, and it +//! must be the last one. Otherwise, the order of extractor arguments does not +//! matter. //! //! If the handler accepts any extractors and the corresponding extraction //! 
cannot be completed, the request fails with status code 400 and an error @@ -603,6 +607,7 @@ pub use extractor::Path; pub use extractor::Query; pub use extractor::RawRequest; pub use extractor::SharedExtractor; +pub use extractor::StreamingBody; pub use extractor::TypedBody; pub use extractor::UntypedBody; pub use handler::http_response_found; diff --git a/dropshot/tests/fail/bad_endpoint1.stderr b/dropshot/tests/fail/bad_endpoint1.stderr index 7de7a399..d365cb7b 100644 --- a/dropshot/tests/fail/bad_endpoint1.stderr +++ b/dropshot/tests/fail/bad_endpoint1.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint1.rs:20:1 diff --git a/dropshot/tests/fail/bad_endpoint11.stderr b/dropshot/tests/fail/bad_endpoint11.stderr index 79748382..bc581eef 100644 --- a/dropshot/tests/fail/bad_endpoint11.stderr +++ b/dropshot/tests/fail/bad_endpoint11.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint11.rs:12:1 diff --git a/dropshot/tests/fail/bad_endpoint13.stderr b/dropshot/tests/fail/bad_endpoint13.stderr index 2e41d095..5fd2e58f 100644 --- a/dropshot/tests/fail/bad_endpoint13.stderr +++ b/dropshot/tests/fail/bad_endpoint13.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint13.rs:18:1 diff --git a/dropshot/tests/fail/bad_endpoint2.stderr b/dropshot/tests/fail/bad_endpoint2.stderr index eb1e80c1..f83dafd8 100644 --- a/dropshot/tests/fail/bad_endpoint2.stderr +++ b/dropshot/tests/fail/bad_endpoint2.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint2.rs:13:1 diff --git a/dropshot/tests/fail/bad_endpoint8.stderr b/dropshot/tests/fail/bad_endpoint8.stderr index 25518037..7bd680ae 100644 --- a/dropshot/tests/fail/bad_endpoint8.stderr +++ b/dropshot/tests/fail/bad_endpoint8.stderr @@ -4,7 +4,8 @@ error: Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result --> tests/fail/bad_endpoint8.rs:19:1 diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 9aa8e302..41702b9c 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -36,6 +36,7 @@ use dropshot::Path; use dropshot::Query; use dropshot::RawRequest; use dropshot::RequestContext; +use dropshot::StreamingBody; use dropshot::TypedBody; use dropshot::UntypedBody; use dropshot::WebsocketChannelResult; @@ -43,6 +44,7 @@ use dropshot::WebsocketConnection; use dropshot::CONTENT_TYPE_JSON; use futures::stream::StreamExt; use futures::SinkExt; +use futures::TryStreamExt; use http::StatusCode; use hyper::Body; use hyper::Method; @@ -71,6 +73,7 @@ fn demo_api() -> ApiDescription { api.register(demo_handler_path_param_uuid).unwrap(); api.register(demo_handler_path_param_u32).unwrap(); api.register(demo_handler_untyped_body).unwrap(); + api.register(demo_handler_streaming_body).unwrap(); api.register(demo_handler_raw_request).unwrap(); api.register(demo_handler_delete).unwrap(); api.register(demo_handler_headers).unwrap(); @@ -727,6 +730,57 @@ async fn test_untyped_body() { testctx.teardown().await; } +// Test `StreamingBody`. 
+#[tokio::test] +async fn test_streaming_body() { + let api = demo_api(); + let testctx = common::test_setup("test_streaming_body", api); + let client = &testctx.client_testctx; + + // Success case: empty body + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/streaming_body", + "".into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoStreaming = read_json(&mut response).await; + assert_eq!(json.nbytes, 0); + + // Success case: non-empty content + let body = vec![0u8; 1024]; + let mut response = client + .make_request_with_body( + Method::PUT, + "/testing/streaming_body", + body.into(), + StatusCode::OK, + ) + .await + .unwrap(); + let json: DemoStreaming = read_json(&mut response).await; + assert_eq!(json.nbytes, 1024); + + // Error case: body too large. + let big_body = vec![0u8; 1025]; + let error = client + .make_request_with_body( + Method::PUT, + "/testing/streaming_body", + big_body.into(), + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); + assert_eq!( + error.message, + "request body exceeded maximum size of 1024 bytes" + ); +} + // Test `RawRequest`. 
#[tokio::test] async fn test_raw_request() { @@ -1096,6 +1150,26 @@ async fn demo_handler_untyped_body( Ok(HttpResponseOk(DemoUntyped { nbytes, as_utf8 })) } +#[derive(Deserialize, Serialize, JsonSchema)] +pub struct DemoStreaming { + pub nbytes: usize, +} +#[endpoint { + method = PUT, + path = "/testing/streaming_body" +}] +async fn demo_handler_streaming_body( + _rqctx: RequestContext, + body: StreamingBody, +) -> Result, HttpError> { + let nbytes = body + .into_stream() + .try_fold(0, |acc, v| futures::future::ok(acc + v.len())) + .await?; + + Ok(HttpResponseOk(DemoStreaming { nbytes })) +} + #[derive(Deserialize, Serialize, JsonSchema)] pub struct DemoRaw { pub nbytes: usize, diff --git a/dropshot_endpoint/src/lib.rs b/dropshot_endpoint/src/lib.rs index 1f100fb4..d0f0fb26 100644 --- a/dropshot_endpoint/src/lib.rs +++ b/dropshot_endpoint/src/lib.rs @@ -85,7 +85,8 @@ const USAGE: &str = "Endpoint handlers must have the following signature: [query_params: Query,] [path_params: Path

,] [body_param: TypedBody,] - [body_param: UntypedBody,] + [body_param: UntypedBody,] + [body_param: StreamingBody,] [raw_request: RawRequest,] ) -> Result";