From cf72bb05c59c3ca094d169e924774f17972f4b2d Mon Sep 17 00:00:00 2001
From: Harold Sun
Date: Fri, 15 Sep 2023 10:46:17 +0800
Subject: [PATCH] Refactor Lambda response streaming. (#696)

* Refactor Lambda response streaming.

Remove the separate streaming.rs from the lambda-runtime crate and merge it
into the `run` method.

Added a FunctionResponse enum to capture both buffered and streaming
responses.

Added an IntoFunctionResponse trait to convert `Serialize` responses into
FunctionResponse::BufferedResponse, and `Stream` responses into
FunctionResponse::StreamingResponse. Existing handler functions should
continue to work.

Improved error handling in response streaming: errors are reported back
through trailers instead of panicking.

* Add comments for reporting midstream errors using error trailers

* Remove "pub" from internal run method
---
 examples/basic-streaming-response/README.md   |   2 +-
 examples/basic-streaming-response/src/main.rs |  18 +-
 lambda-http/Cargo.toml                        |   1 +
 lambda-http/src/streaming.rs                  |  67 ++++-
 lambda-runtime/Cargo.toml                     |   2 +
 lambda-runtime/src/lib.rs                     |  29 +-
 lambda-runtime/src/requests.rs                |  94 +++++-
 lambda-runtime/src/streaming.rs               | 272 ------------------
 lambda-runtime/src/types.rs                   |  89 +++++-
 9 files changed, 266 insertions(+), 308 deletions(-)
 delete mode 100644 lambda-runtime/src/streaming.rs

diff --git a/examples/basic-streaming-response/README.md b/examples/basic-streaming-response/README.md
index 3b68f518..ac744a33 100644
--- a/examples/basic-streaming-response/README.md
+++ b/examples/basic-streaming-response/README.md
@@ -6,7 +6,7 @@
 2. Build the function with `cargo lambda build --release`
 3. Deploy the function to AWS Lambda with `cargo lambda deploy --enable-function-url --iam-role YOUR_ROLE`
 4. Enable Lambda streaming response on Lambda console: change the function url's invoke mode to `RESPONSE_STREAM`
-5. Verify the function works: `curl <url>`. The results should be streamed back with 0.5 second pause between each word.
+5. Verify the function works: `curl -v -N <url>`. The results should be streamed back with 0.5 second pause between each word.
 
 ## Build for ARM 64

diff --git a/examples/basic-streaming-response/src/main.rs b/examples/basic-streaming-response/src/main.rs
index d90ebd33..9d505206 100644
--- a/examples/basic-streaming-response/src/main.rs
+++ b/examples/basic-streaming-response/src/main.rs
@@ -1,9 +1,9 @@
-use hyper::{body::Body, Response};
-use lambda_runtime::{service_fn, Error, LambdaEvent};
+use hyper::body::Body;
+use lambda_runtime::{service_fn, Error, LambdaEvent, StreamResponse};
 use serde_json::Value;
 use std::{thread, time::Duration};
 
-async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
+async fn func(_event: LambdaEvent<Value>) -> Result<StreamResponse<Body>, Error> {
     let messages = vec!["Hello", "world", "from", "Lambda!"];
 
     let (mut tx, rx) = Body::channel();
@@ -15,12 +15,10 @@ async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
         }
     });
 
-    let resp = Response::builder()
-        .header("content-type", "text/html")
-        .header("CustomHeader", "outerspace")
-        .body(rx)?;
-
-    Ok(resp)
+    Ok(StreamResponse {
+        metadata_prelude: Default::default(),
+        stream: rx,
+    })
 }
 
 #[tokio::main]
@@ -34,6 +32,6 @@ async fn main() -> Result<(), Error> {
         .without_time()
         .init();
 
-    lambda_runtime::run_with_streaming_response(service_fn(func)).await?;
+    lambda_runtime::run(service_fn(func)).await?;
     Ok(())
 }
diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml
index be111092..ea4a5fba 100644
--- a/lambda-http/Cargo.toml
+++ b/lambda-http/Cargo.toml
@@ -33,6 +33,7 @@ lambda_runtime = { path = "../lambda-runtime", version = "0.8" }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 serde_urlencoded = "0.7"
+tokio-stream = "0.1.2"
 mime = "0.3"
 encoding_rs = "0.8"
 url = "2.2"
diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs
index 9a27d915..a59cf700 100644
--- a/lambda-http/src/streaming.rs
+++ b/lambda-http/src/streaming.rs
@@ -1,3 +1,4 @@
+use crate::http::header::SET_COOKIE;
 use crate::tower::ServiceBuilder;
 use crate::Request;
 use crate::{request::LambdaRequest, RequestExt};
@@ -5,9 +6,14 @@ pub use aws_lambda_events::encodings::Body as LambdaEventBody;
 use bytes::Bytes;
 pub use http::{self, Response};
 use http_body::Body;
-use lambda_runtime::LambdaEvent;
-pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service};
+pub use lambda_runtime::{
+    self, service_fn, tower, tower::ServiceExt, Error, FunctionResponse, LambdaEvent, MetadataPrelude, Service,
+    StreamResponse,
+};
 use std::fmt::{Debug, Display};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio_stream::Stream;
 
 /// Starts the Lambda Rust runtime and stream response back [Configure Lambda
 /// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html).
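For context on the hunk that follows: lambda_http handlers keep returning a plain `http::Response` whose body implements `http_body::Body`, and `run_with_streaming_response` converts it into the runtime's new `StreamResponse`. The sketch below is illustrative only (not part of this patch); it assumes a binary that depends on `lambda_http` and `hyper`, and that `run_with_streaming_response` is re-exported at the lambda_http crate root.

```rust
use hyper::Body;
use lambda_http::{run_with_streaming_response, service_fn, Error, Request, Response};

// Handler shape accepted by lambda_http's streaming entry point: any
// `http::Response<B>` where `B: http_body::Body` with byte-like data.
async fn handler(_req: Request) -> Result<Response<Body>, Error> {
    let (mut tx, rx) = Body::channel();

    tokio::spawn(async move {
        for chunk in ["Hello", " ", "world"] {
            // Send errors are ignored to keep the sketch short.
            let _ = tx.send_data(chunk.into()).await;
        }
    });

    // Status, headers and cookies set here end up in the metadata prelude;
    // the body is forwarded to the client as it is produced.
    Ok(Response::builder()
        .status(200)
        .header("content-type", "text/plain")
        .body(rx)?)
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    run_with_streaming_response(service_fn(handler)).await
}
```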
@@ -28,7 +34,60 @@ where let event: Request = req.payload.into(); event.with_lambda_context(req.context) }) - .service(handler); + .service(handler) + .map_response(|res| { + let (parts, body) = res.into_parts(); - lambda_runtime::run_with_streaming_response(svc).await + let mut prelude_headers = parts.headers; + + let cookies = prelude_headers.get_all(SET_COOKIE); + let cookies = cookies + .iter() + .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string()) + .collect::>(); + + prelude_headers.remove(SET_COOKIE); + + let metadata_prelude = MetadataPrelude { + headers: prelude_headers, + status_code: parts.status, + cookies, + }; + + StreamResponse { + metadata_prelude, + stream: BodyStream { body }, + } + }); + + lambda_runtime::run(svc).await +} + +pub struct BodyStream { + pub(crate) body: B, +} + +impl BodyStream +where + B: Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + fn project(self: Pin<&mut Self>) -> Pin<&mut B> { + unsafe { self.map_unchecked_mut(|s| &mut s.body) } + } +} + +impl Stream for BodyStream +where + B: Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let body = self.project(); + body.poll_data(cx) + } } diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index 1137f845..9202b1c1 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -43,3 +43,5 @@ tokio-stream = "0.1.2" lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" } serde_path_to_error = "0.1.11" http-serde = "1.1.3" +base64 = "0.20.0" +http-body = "0.4" diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index e3ffd49d..18b1066e 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -7,6 +7,7 @@ //! Create a type that conforms to the [`tower::Service`] trait. This type can //! then be passed to the the `lambda_runtime::run` function, which launches //! and runs the Lambda runtime. +use bytes::Bytes; use futures::FutureExt; use hyper::{ client::{connect::Connection, HttpConnector}, @@ -20,6 +21,7 @@ use std::{ env, fmt::{self, Debug, Display}, future::Future, + marker::PhantomData, panic, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -35,11 +37,8 @@ mod simulated; /// Types available to a Lambda function. 
mod types; -mod streaming; -pub use streaming::run_with_streaming_response; - use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}; -pub use types::{Context, LambdaEvent}; +pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse}; /// Error type that lambdas may result in pub type Error = lambda_runtime_api_client::Error; @@ -97,17 +96,21 @@ where C::Error: Into>, C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, { - async fn run( + async fn run( &self, incoming: impl Stream, Error>> + Send, mut handler: F, ) -> Result<(), Error> where F: Service>, - F::Future: Future>, + F::Future: Future>, F::Error: fmt::Debug + fmt::Display, A: for<'de> Deserialize<'de>, + R: IntoFunctionResponse, B: Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, { let client = &self.client; tokio::pin!(incoming); @@ -177,6 +180,8 @@ where EventCompletionRequest { request_id, body: response, + _unused_b: PhantomData, + _unused_s: PhantomData, } .into_req() } @@ -243,13 +248,17 @@ where /// Ok(event.payload) /// } /// ``` -pub async fn run(handler: F) -> Result<(), Error> +pub async fn run(handler: F) -> Result<(), Error> where F: Service>, - F::Future: Future>, + F::Future: Future>, F::Error: fmt::Debug + fmt::Display, A: for<'de> Deserialize<'de>, + R: IntoFunctionResponse, B: Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, { trace!("Loading config from env"); let config = Config::from_env()?; @@ -293,7 +302,7 @@ mod endpoint_tests { use lambda_runtime_api_client::Client; use serde_json::json; use simulated::DuplexStreamWrapper; - use std::{convert::TryFrom, env}; + use std::{convert::TryFrom, env, marker::PhantomData}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, select, @@ -430,6 +439,8 @@ mod endpoint_tests { let req = EventCompletionRequest { request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", body: "done", + _unused_b: PhantomData::<&str>, + _unused_s: PhantomData::, }; let req = req.into_req()?; diff --git a/lambda-runtime/src/requests.rs b/lambda-runtime/src/requests.rs index 26257d20..8e72fc2d 100644 --- a/lambda-runtime/src/requests.rs +++ b/lambda-runtime/src/requests.rs @@ -1,9 +1,15 @@ -use crate::{types::Diagnostic, Error}; +use crate::types::ToStreamErrorTrailer; +use crate::{types::Diagnostic, Error, FunctionResponse, IntoFunctionResponse}; +use bytes::Bytes; +use http::header::CONTENT_TYPE; use http::{Method, Request, Response, Uri}; use hyper::Body; use lambda_runtime_api_client::build_request; use serde::Serialize; +use std::fmt::Debug; +use std::marker::PhantomData; use std::str::FromStr; +use tokio_stream::{Stream, StreamExt}; pub(crate) trait IntoRequest { fn into_req(self) -> Result, Error>; @@ -65,23 +71,87 @@ fn test_next_event_request() { } // /runtime/invocation/{AwsRequestId}/response -pub(crate) struct EventCompletionRequest<'a, T> { +pub(crate) struct EventCompletionRequest<'a, R, B, S, D, E> +where + R: IntoFunctionResponse, + B: Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, +{ pub(crate) request_id: &'a str, - pub(crate) body: T, + pub(crate) body: R, + pub(crate) _unused_b: PhantomData, + pub(crate) _unused_s: PhantomData, } -impl<'a, T> IntoRequest for EventCompletionRequest<'a, T> +impl<'a, R, B, S, D, E> IntoRequest for EventCompletionRequest<'a, R, B, S, D, E> where - T: for<'serialize> Serialize, + R: IntoFunctionResponse, + B: 
Serialize, + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, { fn into_req(self) -> Result, Error> { - let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id); - let uri = Uri::from_str(&uri)?; - let body = serde_json::to_vec(&self.body)?; - let body = Body::from(body); + match self.body.into_response() { + FunctionResponse::BufferedResponse(body) => { + let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id); + let uri = Uri::from_str(&uri)?; - let req = build_request().method(Method::POST).uri(uri).body(body)?; - Ok(req) + let body = serde_json::to_vec(&body)?; + let body = Body::from(body); + + let req = build_request().method(Method::POST).uri(uri).body(body)?; + Ok(req) + } + FunctionResponse::StreamingResponse(mut response) => { + let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id); + let uri = Uri::from_str(&uri)?; + + let mut builder = build_request().method(Method::POST).uri(uri); + let req_headers = builder.headers_mut().unwrap(); + + req_headers.insert("Transfer-Encoding", "chunked".parse()?); + req_headers.insert("Lambda-Runtime-Function-Response-Mode", "streaming".parse()?); + // Report midstream errors using error trailers. + // See the details in Lambda Developer Doc: https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming + req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Type".parse()?); + req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Body".parse()?); + req_headers.insert( + "Content-Type", + "application/vnd.awslambda.http-integration-response".parse()?, + ); + + // default Content-Type + let preloud_headers = &mut response.metadata_prelude.headers; + preloud_headers + .entry(CONTENT_TYPE) + .or_insert("application/octet-stream".parse()?); + + let metadata_prelude = serde_json::to_string(&response.metadata_prelude)?; + + tracing::trace!(?metadata_prelude); + + let (mut tx, rx) = Body::channel(); + + tokio::spawn(async move { + tx.send_data(metadata_prelude.into()).await.unwrap(); + tx.send_data("\u{0}".repeat(8).into()).await.unwrap(); + + while let Some(chunk) = response.stream.next().await { + let chunk = match chunk { + Ok(chunk) => chunk.into(), + Err(err) => err.into().to_tailer().into(), + }; + tx.send_data(chunk).await.unwrap(); + } + }); + + let req = builder.body(rx)?; + Ok(req) + } + } } } @@ -90,6 +160,8 @@ fn test_event_completion_request() { let req = EventCompletionRequest { request_id: "id", body: "hello, world!", + _unused_b: PhantomData::<&str>, + _unused_s: PhantomData::, }; let req = req.into_req().unwrap(); let expected = Uri::from_static("/2018-06-01/runtime/invocation/id/response"); diff --git a/lambda-runtime/src/streaming.rs b/lambda-runtime/src/streaming.rs deleted file mode 100644 index 5ea369ad..00000000 --- a/lambda-runtime/src/streaming.rs +++ /dev/null @@ -1,272 +0,0 @@ -use crate::{ - build_event_error_request, deserializer, incoming, type_name_of_val, Config, Context, Error, EventErrorRequest, - IntoRequest, LambdaEvent, Runtime, -}; -use bytes::Bytes; -use futures::FutureExt; -use http::header::{CONTENT_TYPE, SET_COOKIE}; -use http::{HeaderMap, Method, Request, Response, StatusCode, Uri}; -use hyper::body::HttpBody; -use hyper::{client::connect::Connection, Body}; -use lambda_runtime_api_client::{build_request, Client}; -use serde::{Deserialize, Serialize}; -use std::str::FromStr; -use std::{ - env, - fmt::{self, Debug, Display}, - future::Future, - panic, 
-}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_stream::{Stream, StreamExt}; -use tower::{Service, ServiceExt}; -use tracing::{error, trace, Instrument}; - -/// Starts the Lambda Rust runtime and stream response back [Configure Lambda -/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html). -/// -/// # Example -/// ```no_run -/// use hyper::{body::Body, Response}; -/// use lambda_runtime::{service_fn, Error, LambdaEvent}; -/// use std::{thread, time::Duration}; -/// use serde_json::Value; -/// -/// #[tokio::main] -/// async fn main() -> Result<(), Error> { -/// lambda_runtime::run_with_streaming_response(service_fn(func)).await?; -/// Ok(()) -/// } -/// async fn func(_event: LambdaEvent) -> Result, Error> { -/// let messages = vec!["Hello ", "world ", "from ", "Lambda!"]; -/// -/// let (mut tx, rx) = Body::channel(); -/// -/// tokio::spawn(async move { -/// for message in messages.iter() { -/// tx.send_data((*message).into()).await.unwrap(); -/// thread::sleep(Duration::from_millis(500)); -/// } -/// }); -/// -/// let resp = Response::builder() -/// .header("content-type", "text/plain") -/// .header("CustomHeader", "outerspace") -/// .body(rx)?; -/// -/// Ok(resp) -/// } -/// ``` -pub async fn run_with_streaming_response(handler: F) -> Result<(), Error> -where - F: Service>, - F::Future: Future, F::Error>>, - F::Error: Debug + Display, - A: for<'de> Deserialize<'de>, - B: HttpBody + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - trace!("Loading config from env"); - let config = Config::from_env()?; - let client = Client::builder().build().expect("Unable to create a runtime client"); - let runtime = Runtime { client, config }; - - let client = &runtime.client; - let incoming = incoming(client); - runtime.run_with_streaming_response(incoming, handler).await -} - -impl Runtime -where - C: Service + Clone + Send + Sync + Unpin + 'static, - C::Future: Unpin + Send, - C::Error: Into>, - C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, -{ - async fn run_with_streaming_response( - &self, - incoming: impl Stream, Error>> + Send, - mut handler: F, - ) -> Result<(), Error> - where - F: Service>, - F::Future: Future, F::Error>>, - F::Error: fmt::Debug + fmt::Display, - A: for<'de> Deserialize<'de>, - B: HttpBody + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, - { - let client = &self.client; - tokio::pin!(incoming); - while let Some(next_event_response) = incoming.next().await { - trace!("New event arrived (run loop)"); - let event = next_event_response?; - let (parts, body) = event.into_parts(); - - #[cfg(debug_assertions)] - if parts.status == http::StatusCode::NO_CONTENT { - // Ignore the event if the status code is 204. - // This is a way to keep the runtime alive when - // there are no events pending to be processed. 
- continue; - } - - let ctx: Context = Context::try_from(parts.headers)?; - let ctx: Context = ctx.with_config(&self.config); - let request_id = &ctx.request_id.clone(); - - let request_span = match &ctx.xray_trace_id { - Some(trace_id) => { - env::set_var("_X_AMZN_TRACE_ID", trace_id); - tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id) - } - None => { - env::remove_var("_X_AMZN_TRACE_ID"); - tracing::info_span!("Lambda runtime invoke", requestId = request_id) - } - }; - - // Group the handling in one future and instrument it with the span - async { - let body = hyper::body::to_bytes(body).await?; - trace!("incoming request payload - {}", std::str::from_utf8(&body)?); - - #[cfg(debug_assertions)] - if parts.status.is_server_error() { - error!("Lambda Runtime server returned an unexpected error"); - return Err(parts.status.to_string().into()); - } - - let lambda_event = match deserializer::deserialize(&body, ctx) { - Ok(lambda_event) => lambda_event, - Err(err) => { - let req = build_event_error_request(request_id, err)?; - client.call(req).await.expect("Unable to send response to Runtime APIs"); - return Ok(()); - } - }; - - let req = match handler.ready().await { - Ok(handler) => { - // Catches panics outside of a `Future` - let task = panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(lambda_event))); - - let task = match task { - // Catches panics inside of the `Future` - Ok(task) => panic::AssertUnwindSafe(task).catch_unwind().await, - Err(err) => Err(err), - }; - - match task { - Ok(response) => match response { - Ok(response) => { - trace!("Ok response from handler (run loop)"); - EventCompletionStreamingRequest { - request_id, - body: response, - } - .into_req() - } - Err(err) => build_event_error_request(request_id, err), - }, - Err(err) => { - error!("{:?}", err); - let error_type = type_name_of_val(&err); - let msg = if let Some(msg) = err.downcast_ref::<&str>() { - format!("Lambda panicked: {msg}") - } else { - "Lambda panicked".to_string() - }; - EventErrorRequest::new(request_id, error_type, &msg).into_req() - } - } - } - Err(err) => build_event_error_request(request_id, err), - }?; - - client.call(req).await.expect("Unable to send response to Runtime APIs"); - Ok::<(), Error>(()) - } - .instrument(request_span) - .await?; - } - Ok(()) - } -} - -pub(crate) struct EventCompletionStreamingRequest<'a, B> { - pub(crate) request_id: &'a str, - pub(crate) body: Response, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct MetadataPrelude { - #[serde(serialize_with = "http_serde::status_code::serialize")] - status_code: StatusCode, - #[serde(serialize_with = "http_serde::header_map::serialize")] - headers: HeaderMap, - cookies: Vec, -} - -impl<'a, B> IntoRequest for EventCompletionStreamingRequest<'a, B> -where - B: HttpBody + Unpin + Send + 'static, - B::Data: Into + Send, - B::Error: Into + Send + Debug, -{ - fn into_req(self) -> Result, Error> { - let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id); - let uri = Uri::from_str(&uri)?; - - let (parts, mut body) = self.body.into_parts(); - - let mut builder = build_request().method(Method::POST).uri(uri); - let req_headers = builder.headers_mut().unwrap(); - - req_headers.insert("Transfer-Encoding", "chunked".parse()?); - req_headers.insert("Lambda-Runtime-Function-Response-Mode", "streaming".parse()?); - req_headers.insert( - "Content-Type", - "application/vnd.awslambda.http-integration-response".parse()?, - ); - - let mut 
prelude_headers = parts.headers; - // default Content-Type - prelude_headers - .entry(CONTENT_TYPE) - .or_insert("application/octet-stream".parse()?); - - let cookies = prelude_headers.get_all(SET_COOKIE); - let cookies = cookies - .iter() - .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string()) - .collect::>(); - prelude_headers.remove(SET_COOKIE); - - let metadata_prelude = serde_json::to_string(&MetadataPrelude { - status_code: parts.status, - headers: prelude_headers, - cookies, - })?; - - trace!(?metadata_prelude); - - let (mut tx, rx) = Body::channel(); - - tokio::spawn(async move { - tx.send_data(metadata_prelude.into()).await.unwrap(); - tx.send_data("\u{0}".repeat(8).into()).await.unwrap(); - - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - tx.send_data(chunk.into()).await.unwrap(); - } - }); - - let req = builder.body(rx)?; - Ok(req) - } -} diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 87d6ded5..27a4a9ae 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -1,11 +1,14 @@ use crate::{Config, Error}; -use http::{HeaderMap, HeaderValue}; +use bytes::Bytes; +use http::{HeaderMap, HeaderValue, StatusCode}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, convert::TryFrom, + fmt::Debug, time::{Duration, SystemTime}, }; +use tokio_stream::Stream; #[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -182,6 +185,90 @@ impl LambdaEvent { } } +/// Metadata prelude for a stream response. +#[derive(Debug, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct MetadataPrelude { + #[serde(with = "http_serde::status_code")] + /// The HTTP status code. + pub status_code: StatusCode, + #[serde(with = "http_serde::header_map")] + /// The HTTP headers. + pub headers: HeaderMap, + /// The HTTP cookies. + pub cookies: Vec, +} + +pub trait ToStreamErrorTrailer { + /// Convert the hyper error into a stream error trailer. + fn to_tailer(&self) -> String; +} + +impl ToStreamErrorTrailer for Error { + fn to_tailer(&self) -> String { + format!( + "Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: {}\r\n", + base64::encode(self.to_string()) + ) + } +} + +/// A streaming response that contains the metadata prelude and the stream of bytes that will be +/// sent to the client. +#[derive(Debug)] +pub struct StreamResponse { + /// The metadata prelude. + pub metadata_prelude: MetadataPrelude, + /// The stream of bytes that will be sent to the client. + pub stream: S, +} + +/// An enum representing the response of a function that can return either a buffered +/// response of type `B` or a streaming response of type `S`. +pub enum FunctionResponse { + /// A buffered response containing the entire payload of the response. This is useful + /// for responses that can be processed quickly and have a relatively small payload size(<= 6MB). + BufferedResponse(B), + /// A streaming response that delivers the payload incrementally. This is useful for + /// large payloads(> 6MB) or responses that take a long time to generate. The client can start + /// processing the response as soon as the first chunk is available, without waiting + /// for the entire payload to be generated. + StreamingResponse(StreamResponse), +} + +/// a trait that can be implemented for any type that can be converted into a FunctionResponse. +/// This allows us to use the `into` method to convert a type into a FunctionResponse. 
+pub trait IntoFunctionResponse { + /// Convert the type into a FunctionResponse. + fn into_response(self) -> FunctionResponse; +} + +impl IntoFunctionResponse for FunctionResponse { + fn into_response(self) -> FunctionResponse { + self + } +} + +impl IntoFunctionResponse for B +where + B: Serialize, +{ + fn into_response(self) -> FunctionResponse { + FunctionResponse::BufferedResponse(self) + } +} + +impl IntoFunctionResponse<(), S> for StreamResponse +where + S: Stream> + Unpin + Send + 'static, + D: Into + Send, + E: Into + Send + Debug, +{ + fn into_response(self) -> FunctionResponse<(), S> { + FunctionResponse::StreamingResponse(self) + } +} + #[cfg(test)] mod test { use super::*;
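To tie the new pieces together, a hedged sketch (not part of the patch) of the two handler shapes the refactored `lambda_runtime::run` accepts: a buffered `Serialize` handler and a streaming handler with an explicit `MetadataPrelude`. It assumes `hyper::Body` as the stream type, as the basic-streaming-response example above uses, and `serde_json` for the buffered payload.

```rust
use hyper::Body;
use lambda_runtime::{service_fn, Error, LambdaEvent, MetadataPrelude, StreamResponse};
use serde_json::{json, Value};

// Buffered path: any `Serialize` return type goes through
// `FunctionResponse::BufferedResponse`, so existing handlers compile unchanged.
async fn buffered(event: LambdaEvent<Value>) -> Result<Value, Error> {
    Ok(json!({ "echo": event.payload }))
}

// Streaming path: returning `StreamResponse` selects
// `FunctionResponse::StreamingResponse`. The prelude is written first,
// followed by eight NUL bytes, then the body chunks.
async fn streaming(_event: LambdaEvent<Value>) -> Result<StreamResponse<Body>, Error> {
    let (mut tx, rx) = Body::channel();
    tokio::spawn(async move {
        for word in ["streamed", " ", "response"] {
            let _ = tx.send_data(word.into()).await;
        }
    });

    // Default status code is 200; headers and cookies can be set explicitly.
    let mut metadata_prelude = MetadataPrelude::default();
    metadata_prelude
        .headers
        .insert("content-type", "text/plain".parse()?);

    Ok(StreamResponse {
        metadata_prelude,
        stream: rx,
    })
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Pick either handler; both satisfy the new `IntoFunctionResponse` bound.
    lambda_runtime::run(service_fn(streaming)).await
}
```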