diff --git a/src/session.rs b/src/session.rs index a3328b0..8c24920 100644 --- a/src/session.rs +++ b/src/session.rs @@ -275,21 +275,19 @@ where fn call(&mut self, mut request: Request) -> Self::Future { let session_layer = self.layer.clone(); - let cookie_values = request + // Multiple cookies may be all concatenated into a single Cookie header + // separated with semicolons (HTTP/1.1 behaviour) or into multiple separate + // Cookie headers (HTTP/2 behaviour). Search for the session cookie from + // all Cookie headers, assuming both forms are possible + let cookie_value = request .headers() - .get(COOKIE) - .map(|cookies| cookies.to_str()); - - let cookie_value = if let Some(Ok(cookies)) = cookie_values { - cookies - .split(';') - .map(|cookie| cookie.trim()) - .filter_map(|cookie| Cookie::parse_encoded(cookie).ok()) - .filter(|cookie| cookie.name() == session_layer.cookie_name) - .find_map(|cookie| self.layer.verify_signature(cookie.value()).ok()) - } else { - None - }; + .get_all(COOKIE) + .iter() + .filter_map(|cookie_header| cookie_header.to_str().ok()) + .flat_map(|cookie_header| cookie_header.split(';')) + .filter_map(|cookie_header| Cookie::parse_encoded(cookie_header.trim()).ok()) + .filter(|cookie| cookie.name() == session_layer.cookie_name) + .find_map(|cookie| self.layer.verify_signature(cookie.value()).ok()); let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); @@ -366,7 +364,7 @@ mod tests { use axum::http::{Request, Response}; use http::{ header::{COOKIE, SET_COOKIE}, - StatusCode, + HeaderValue, StatusCode, }; use hyper::Body; use rand::Rng; @@ -432,6 +430,74 @@ mod tests { assert_eq!(counter, Counter { counter: 1 }); } + #[tokio::test] + async fn multiple_cookies_in_single_header() { + let secret = rand::thread_rng().gen::<[u8; 64]>(); + let store = MemoryStore::new(); + let session_layer = SessionLayer::new(store, &secret); + let mut service = ServiceBuilder::new() + .layer(session_layer) + .service_fn(increment); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone(); + + // build a Cookie header that contains two cookies: an unrelated dummy cookie, + // and the given session cookie + let request_cookie = + HeaderValue::from_str(&format!("key=value; {}", session_cookie.to_str().unwrap())) + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..]; + let counter: Counter = serde_json::from_slice(json_bs).unwrap(); + assert_eq!(counter, Counter { counter: 0 }); + + let mut request = Request::get("/").body(Body::empty()).unwrap(); + request.headers_mut().insert(COOKIE, request_cookie); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..]; + let counter: Counter = serde_json::from_slice(json_bs).unwrap(); + assert_eq!(counter, Counter { counter: 1 }); + } + + #[tokio::test] + async fn multiple_cookie_headers() { + let secret = rand::thread_rng().gen::<[u8; 64]>(); + let store = MemoryStore::new(); + let session_layer = SessionLayer::new(store, &secret); + let mut service = ServiceBuilder::new() + .layer(session_layer) + .service_fn(increment); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone(); + let dummy_cookie = HeaderValue::from_str("key=value").unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..]; + let counter: Counter = serde_json::from_slice(json_bs).unwrap(); + assert_eq!(counter, Counter { counter: 0 }); + + let mut request = Request::get("/").body(Body::empty()).unwrap(); + request.headers_mut().append(COOKIE, dummy_cookie); + request.headers_mut().append(COOKIE, session_cookie); + let res = service.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..]; + let counter: Counter = serde_json::from_slice(json_bs).unwrap(); + assert_eq!(counter, Counter { counter: 1 }); + } + #[tokio::test] async fn invalid_session_sets_cookie() { let secret = rand::thread_rng().gen::<[u8; 64]>();