From b07c7313f635613d0600bf8840999732b686a80a Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sun, 1 Oct 2023 10:07:45 +0200 Subject: [PATCH] cors: Don't overwrite vary header set by the inner service --- tower-http/src/cors/mod.rs | 7 +++++++ tower-http/src/cors/tests.rs | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 tower-http/src/cors/tests.rs diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 6883ac31..0524122d 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -681,6 +681,13 @@ where match self.project().inner.project() { KindProj::CorsCall { future, headers } => { let mut response: Response = ready!(future.poll(cx))?; + + // vary header can have multiple values, don't overwrite + // previously-set value(s). + if let Some(vary) = headers.remove(header::VARY) { + headers.append(header::VARY, vary); + } + // extend will overwrite previous headers of remaining names response.headers_mut().extend(headers.drain()); Poll::Ready(Ok(response)) diff --git a/tower-http/src/cors/tests.rs b/tower-http/src/cors/tests.rs new file mode 100644 index 00000000..4eccc41c --- /dev/null +++ b/tower-http/src/cors/tests.rs @@ -0,0 +1,33 @@ +use std::convert::Infallible; + +use http::{header, HeaderValue, Request, Response}; +use hyper::Body; +use tower::{service_fn, util::ServiceExt, Layer}; + +use crate::cors::CorsLayer; + +#[tokio::test] +#[allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const +)] +async fn vary_set_by_inner_service() { + const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); + const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static( + "origin, access-control-request-method, access-control-request-headers", + ); + + async fn inner_svc(_: Request) -> Result, Infallible> { + Ok(Response::builder() + .header(header::VARY, CUSTOM_VARY_HEADERS) + .body(Body::empty()) + .unwrap()) + } + + let svc = CorsLayer::permissive().layer(service_fn(inner_svc)); + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + let mut vary_headers = res.headers().get_all(header::VARY).into_iter(); + assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS)); + assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS)); + assert_eq!(vary_headers.next(), None); +}