Refactor Lambda response streaming. (#696)
* Refactor Lambda response streaming.

Remove the separate streaming.rs from the lambda-runtime crate and merge its logic into the `run` method.

Add a FunctionResponse enum to capture both buffered and streaming responses.

Add an IntoFunctionResponse trait that converts a `Serialize` response into FunctionResponse::BufferedResponse and a `Stream` response into FunctionResponse::StreamingResponse. Existing handler functions should continue to work; see the sketch below.
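A rough sketch of what these types look like, inferred from how the diff below uses them; the actual definitions live in lambda-runtime/src/types.rs, which this commit view does not include:

```rust
// Sketch only: shapes inferred from the diff below, not copied from the crate.

/// Prelude metadata sent before the streamed body (status, headers, cookies).
pub struct MetadataPrelude {
    pub status_code: http::StatusCode,
    pub headers: http::HeaderMap,
    pub cookies: Vec<String>,
}

/// A streaming response: prelude metadata plus a stream of body chunks.
pub struct StreamResponse<S> {
    pub metadata_prelude: MetadataPrelude,
    pub stream: S,
}

/// Either a buffered (JSON-serialized) response or a streaming one.
pub enum FunctionResponse<B, S> {
    BufferedResponse(B),
    StreamingResponse(StreamResponse<S>),
}

/// Conversion used by the generic `run` to accept both response shapes.
pub trait IntoFunctionResponse<B, S> {
    fn into_response(self) -> FunctionResponse<B, S>;
}
```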

Improve error handling in response streaming: return error trailers to report errors instead of panicking.

* Add comments for reporting midstream errors using error trailers

* Remove "pub" from internal run method
bnusunny authored Sep 15, 2023
1 parent e2d51ad commit cf72bb0
Showing 9 changed files with 266 additions and 308 deletions.
2 changes: 1 addition & 1 deletion examples/basic-streaming-response/README.md
@@ -6,7 +6,7 @@
2. Build the function with `cargo lambda build --release`
3. Deploy the function to AWS Lambda with `cargo lambda deploy --enable-function-url --iam-role YOUR_ROLE`
4. Enable Lambda streaming response on Lambda console: change the function url's invoke mode to `RESPONSE_STREAM`
5. Verify the function works: `curl <function-url>`. The results should be streamed back with 0.5 second pause between each word.
5. Verify the function works: `curl -v -N <function-url>`. The results should be streamed back with a 0.5-second pause between each word.

## Build for ARM 64

18 changes: 8 additions & 10 deletions examples/basic-streaming-response/src/main.rs
@@ -1,9 +1,9 @@
use hyper::{body::Body, Response};
use lambda_runtime::{service_fn, Error, LambdaEvent};
use hyper::body::Body;
use lambda_runtime::{service_fn, Error, LambdaEvent, StreamResponse};
use serde_json::Value;
use std::{thread, time::Duration};

async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
async fn func(_event: LambdaEvent<Value>) -> Result<StreamResponse<Body>, Error> {
let messages = vec!["Hello", "world", "from", "Lambda!"];

let (mut tx, rx) = Body::channel();
@@ -15,12 +15,10 @@ async fn func(_event: LambdaEvent<Value>) -> Result<Response<Body>, Error> {
}
});

let resp = Response::builder()
.header("content-type", "text/html")
.header("CustomHeader", "outerspace")
.body(rx)?;

Ok(resp)
Ok(StreamResponse {
metadata_prelude: Default::default(),
stream: rx,
})
}

#[tokio::main]
@@ -34,6 +32,6 @@ async fn main() -> Result<(), Error> {
.without_time()
.init();

lambda_runtime::run_with_streaming_response(service_fn(func)).await?;
lambda_runtime::run(service_fn(func)).await?;
Ok(())
}
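The refactored example above uses `Default::default()` for the prelude, dropping the custom headers the old `Response`-based version set. A hedged sketch of how a handler could still set them explicitly, using the `MetadataPrelude` fields visible in the lambda-http diff further down (the `html_stream` helper is illustrative only):

```rust
use http::{header, HeaderMap, HeaderValue, StatusCode};
use hyper::Body;
use lambda_runtime::{MetadataPrelude, StreamResponse};

// Illustrative helper: build a streaming response with an explicit prelude
// instead of MetadataPrelude::default().
fn html_stream(rx: Body) -> StreamResponse<Body> {
    let mut headers = HeaderMap::new();
    headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));

    StreamResponse {
        metadata_prelude: MetadataPrelude {
            status_code: StatusCode::OK,
            headers,
            cookies: Vec::new(),
        },
        stream: rx,
    }
}
```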
1 change: 1 addition & 0 deletions lambda-http/Cargo.toml
@@ -33,6 +33,7 @@ lambda_runtime = { path = "../lambda-runtime", version = "0.8" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_urlencoded = "0.7"
tokio-stream = "0.1.2"
mime = "0.3"
encoding_rs = "0.8"
url = "2.2"
67 changes: 63 additions & 4 deletions lambda-http/src/streaming.rs
@@ -1,13 +1,19 @@
use crate::http::header::SET_COOKIE;
use crate::tower::ServiceBuilder;
use crate::Request;
use crate::{request::LambdaRequest, RequestExt};
pub use aws_lambda_events::encodings::Body as LambdaEventBody;
use bytes::Bytes;
pub use http::{self, Response};
use http_body::Body;
use lambda_runtime::LambdaEvent;
pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service};
pub use lambda_runtime::{
self, service_fn, tower, tower::ServiceExt, Error, FunctionResponse, LambdaEvent, MetadataPrelude, Service,
StreamResponse,
};
use std::fmt::{Debug, Display};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_stream::Stream;

/// Starts the Lambda Rust runtime and streams the response back. See [Configure Lambda
/// Streaming Response](https://docs.aws.amazon.com/lambda/latest/dg/configuration-response-streaming.html).
@@ -28,7 +34,60 @@
let event: Request = req.payload.into();
event.with_lambda_context(req.context)
})
.service(handler);
.service(handler)
.map_response(|res| {
let (parts, body) = res.into_parts();

lambda_runtime::run_with_streaming_response(svc).await
let mut prelude_headers = parts.headers;

let cookies = prelude_headers.get_all(SET_COOKIE);
let cookies = cookies
.iter()
.map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
.collect::<Vec<String>>();

prelude_headers.remove(SET_COOKIE);

let metadata_prelude = MetadataPrelude {
headers: prelude_headers,
status_code: parts.status,
cookies,
};

StreamResponse {
metadata_prelude,
stream: BodyStream { body },
}
});

lambda_runtime::run(svc).await
}

pub struct BodyStream<B> {
pub(crate) body: B,
}

impl<B> BodyStream<B>
where
B: Body + Unpin + Send + 'static,
B::Data: Into<Bytes> + Send,
B::Error: Into<Error> + Send + Debug,
{
fn project(self: Pin<&mut Self>) -> Pin<&mut B> {
unsafe { self.map_unchecked_mut(|s| &mut s.body) }
}
}

impl<B> Stream for BodyStream<B>
where
B: Body + Unpin + Send + 'static,
B::Data: Into<Bytes> + Send,
B::Error: Into<Error> + Send + Debug,
{
type Item = Result<B::Data, B::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let body = self.project();
body.poll_data(cx)
}
}
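For reference, a hedged usage sketch of the lambda-http wrapper refactored above, assuming it stays exported as `lambda_http::run_with_streaming_response` (its signature sits outside this hunk):

```rust
use hyper::Body;
use lambda_http::{run_with_streaming_response, service_fn, Error, Request, Response};

// Handlers keep returning an http::Response whose body implements http_body::Body;
// the wrapper splits it into a MetadataPrelude plus a BodyStream as shown above.
async fn handler(_req: Request) -> Result<Response<Body>, Error> {
    Ok(Response::builder()
        .header("content-type", "text/plain")
        .body(Body::from("streamed from lambda_http"))?)
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    run_with_streaming_response(service_fn(handler)).await
}
```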
2 changes: 2 additions & 0 deletions lambda-runtime/Cargo.toml
@@ -43,3 +43,5 @@ tokio-stream = "0.1.2"
lambda_runtime_api_client = { version = "0.8", path = "../lambda-runtime-api-client" }
serde_path_to_error = "0.1.11"
http-serde = "1.1.3"
base64 = "0.20.0"
http-body = "0.4"
29 changes: 20 additions & 9 deletions lambda-runtime/src/lib.rs
@@ -7,6 +7,7 @@
//! Create a type that conforms to the [`tower::Service`] trait. This type can
//! then be passed to the `lambda_runtime::run` function, which launches
//! and runs the Lambda runtime.
use bytes::Bytes;
use futures::FutureExt;
use hyper::{
client::{connect::Connection, HttpConnector},
@@ -20,6 +21,7 @@ use std::{
env,
fmt::{self, Debug, Display},
future::Future,
marker::PhantomData,
panic,
};
use tokio::io::{AsyncRead, AsyncWrite};
@@ -35,11 +37,8 @@ mod simulated;
/// Types available to a Lambda function.
mod types;

mod streaming;
pub use streaming::run_with_streaming_response;

use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
pub use types::{Context, LambdaEvent};
pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse};

/// Error type that lambdas may result in
pub type Error = lambda_runtime_api_client::Error;
@@ -97,17 +96,21 @@
C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
async fn run<F, A, B>(
async fn run<F, A, R, B, S, D, E>(
&self,
incoming: impl Stream<Item = Result<http::Response<hyper::Body>, Error>> + Send,
mut handler: F,
) -> Result<(), Error>
where
F: Service<LambdaEvent<A>>,
F::Future: Future<Output = Result<B, F::Error>>,
F::Future: Future<Output = Result<R, F::Error>>,
F::Error: fmt::Debug + fmt::Display,
A: for<'de> Deserialize<'de>,
R: IntoFunctionResponse<B, S>,
B: Serialize,
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
let client = &self.client;
tokio::pin!(incoming);
@@ -177,6 +180,8 @@
EventCompletionRequest {
request_id,
body: response,
_unused_b: PhantomData,
_unused_s: PhantomData,
}
.into_req()
}
@@ -243,13 +248,17 @@
/// Ok(event.payload)
/// }
/// ```
pub async fn run<A, B, F>(handler: F) -> Result<(), Error>
pub async fn run<A, F, R, B, S, D, E>(handler: F) -> Result<(), Error>
where
F: Service<LambdaEvent<A>>,
F::Future: Future<Output = Result<B, F::Error>>,
F::Future: Future<Output = Result<R, F::Error>>,
F::Error: fmt::Debug + fmt::Display,
A: for<'de> Deserialize<'de>,
R: IntoFunctionResponse<B, S>,
B: Serialize,
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
trace!("Loading config from env");
let config = Config::from_env()?;
@@ -293,7 +302,7 @@ mod endpoint_tests {
use lambda_runtime_api_client::Client;
use serde_json::json;
use simulated::DuplexStreamWrapper;
use std::{convert::TryFrom, env};
use std::{convert::TryFrom, env, marker::PhantomData};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
select,
@@ -430,6 +439,8 @@ mod endpoint_tests {
let req = EventCompletionRequest {
request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
body: "done",
_unused_b: PhantomData::<&str>,
_unused_s: PhantomData::<Body>,
};
let req = req.into_req()?;

94 changes: 83 additions & 11 deletions lambda-runtime/src/requests.rs
@@ -1,9 +1,15 @@
use crate::{types::Diagnostic, Error};
use crate::types::ToStreamErrorTrailer;
use crate::{types::Diagnostic, Error, FunctionResponse, IntoFunctionResponse};
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::{Method, Request, Response, Uri};
use hyper::Body;
use lambda_runtime_api_client::build_request;
use serde::Serialize;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::str::FromStr;
use tokio_stream::{Stream, StreamExt};

pub(crate) trait IntoRequest {
fn into_req(self) -> Result<Request<Body>, Error>;
@@ -65,23 +71,87 @@ fn test_next_event_request() {
}

// /runtime/invocation/{AwsRequestId}/response
pub(crate) struct EventCompletionRequest<'a, T> {
pub(crate) struct EventCompletionRequest<'a, R, B, S, D, E>
where
R: IntoFunctionResponse<B, S>,
B: Serialize,
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
pub(crate) request_id: &'a str,
pub(crate) body: T,
pub(crate) body: R,
pub(crate) _unused_b: PhantomData<B>,
pub(crate) _unused_s: PhantomData<S>,
}

impl<'a, T> IntoRequest for EventCompletionRequest<'a, T>
impl<'a, R, B, S, D, E> IntoRequest for EventCompletionRequest<'a, R, B, S, D, E>
where
T: for<'serialize> Serialize,
R: IntoFunctionResponse<B, S>,
B: Serialize,
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
fn into_req(self) -> Result<Request<Body>, Error> {
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
let uri = Uri::from_str(&uri)?;
let body = serde_json::to_vec(&self.body)?;
let body = Body::from(body);
match self.body.into_response() {
FunctionResponse::BufferedResponse(body) => {
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
let uri = Uri::from_str(&uri)?;

let req = build_request().method(Method::POST).uri(uri).body(body)?;
Ok(req)
let body = serde_json::to_vec(&body)?;
let body = Body::from(body);

let req = build_request().method(Method::POST).uri(uri).body(body)?;
Ok(req)
}
FunctionResponse::StreamingResponse(mut response) => {
let uri = format!("/2018-06-01/runtime/invocation/{}/response", self.request_id);
let uri = Uri::from_str(&uri)?;

let mut builder = build_request().method(Method::POST).uri(uri);
let req_headers = builder.headers_mut().unwrap();

req_headers.insert("Transfer-Encoding", "chunked".parse()?);
req_headers.insert("Lambda-Runtime-Function-Response-Mode", "streaming".parse()?);
// Report midstream errors using error trailers.
// See the details in Lambda Developer Doc: https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming
req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Type".parse()?);
req_headers.append("Trailer", "Lambda-Runtime-Function-Error-Body".parse()?);
req_headers.insert(
"Content-Type",
"application/vnd.awslambda.http-integration-response".parse()?,
);

// Default Content-Type if the handler did not set one.
let prelude_headers = &mut response.metadata_prelude.headers;
prelude_headers
.entry(CONTENT_TYPE)
.or_insert("application/octet-stream".parse()?);

let metadata_prelude = serde_json::to_string(&response.metadata_prelude)?;

tracing::trace!(?metadata_prelude);

let (mut tx, rx) = Body::channel();

tokio::spawn(async move {
tx.send_data(metadata_prelude.into()).await.unwrap();
tx.send_data("\u{0}".repeat(8).into()).await.unwrap();

while let Some(chunk) = response.stream.next().await {
let chunk = match chunk {
Ok(chunk) => chunk.into(),
Err(err) => err.into().to_tailer().into(),
};
tx.send_data(chunk).await.unwrap();
}
});

let req = builder.body(rx)?;
Ok(req)
}
}
}
}

@@ -90,6 +160,8 @@ fn test_event_completion_request() {
let req = EventCompletionRequest {
request_id: "id",
body: "hello, world!",
_unused_b: PhantomData::<&str>,
_unused_s: PhantomData::<Body>,
};
let req = req.into_req().unwrap();
let expected = Uri::from_static("/2018-06-01/runtime/invocation/id/response");
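The `to_tailer()` call in the streaming branch above comes from the `ToStreamErrorTrailer` trait, whose implementation is not part of this view. A purely hypothetical sketch of the trailer text it might produce, using only the trailer names declared on the request:

```rust
// Hypothetical sketch: not the crate's ToStreamErrorTrailer impl, just the
// general shape of an error reported through the two declared trailer fields.
fn error_trailer(error_type: &str, error_body: &str) -> String {
    format!(
        "Lambda-Runtime-Function-Error-Type: {error_type}\r\n\
         Lambda-Runtime-Function-Error-Body: {error_body}\r\n"
    )
}
```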