diff --git a/auth/src/api/cache.rs b/auth/src/api/cache.rs index dcb380151..4ff2cc09f 100644 --- a/auth/src/api/cache.rs +++ b/auth/src/api/cache.rs @@ -16,6 +16,7 @@ use std::{ time::Duration, }; use tower::{Layer, Service}; +use tracing::error; use ttl_cache::TtlCache; use super::RouterState; @@ -43,10 +44,7 @@ impl CacheManagement for CacheManager { } fn invalidate(&self, key: &str) -> Option { - self.cache - .write() - .unwrap() - .remove(key) + self.cache.write().unwrap().remove(key) } } @@ -136,42 +134,82 @@ where let public_key = self.state.key_manager.public_key(); let cache_manager = self.state.cache_manager.clone(); - // TODO: error handling. return Box::pin(async move { let response: Response = future.await?; - let public_key = public_key.clone(); + // Return response directly if it failed. + if response.status() != StatusCode::OK { + return Ok(response); + } + // We'll re-use the parts in the response if all goes well. let (parts, body) = response.into_parts(); - let bytes = hyper::body::to_bytes(body).await.unwrap(); + let body = match hyper::body::to_bytes(body).await { + Ok(body) => body, + Err(error) => { + error!( + error = &error as &dyn std::error::Error, + "failed to get response body" + ); - let Ok(value): Result = serde_json::from_slice(&bytes) else { - return - Ok(Response::builder() - .status(StatusCode::UNAUTHORIZED) + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Default::default()) + .unwrap()); + } + }; + + let value: Value = match serde_json::from_slice(&body) { + Ok(value) => value, + Err(error) => { + error!( + error = &error as &dyn std::error::Error, + "response body is malformed" + ); + + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) .body(Default::default()) .unwrap()); + } }; + let Some(jwt) = value["token"].as_str() else { - return - Ok(Response::builder() + error!("response json is missing 'token' key"); + + return Ok(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Default::default()) .unwrap()); - }; - let claim = Claim::from_token(jwt, &public_key).unwrap(); + let public_key = public_key.clone(); + + let claim = match Claim::from_token(jwt, &public_key) { + Ok(claim) => claim, + Err(status) => { + return Ok(Response::builder() + .status(status) + .body(Default::default()) + .unwrap()); + } + }; // Expiration time (as UTC timestamp). let exp = claim.exp; - let expiration_timestamp = Utc + let Some(expiration_timestamp) = Utc .timestamp_opt(exp as i64, 0) .single() - .ok_or(StatusCode::INTERNAL_SERVER_ERROR) - .unwrap(); + else { + error!("expiration timestamp is out of range number"); + + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Default::default()) + .unwrap()); + }; let duration = expiration_timestamp - Utc::now(); @@ -182,8 +220,10 @@ where Duration::from_secs(duration.num_seconds() as u64), ); + // Request succeeded and JWT was cached. Convert the body bytes back into a HttpBody, + // and return it along with the original response parts. let body = - ::map_err(bytes.into(), axum::Error::new).boxed_unsync(); + ::map_err(body.into(), axum::Error::new).boxed_unsync(); Ok(Response::from_parts(parts, body)) }); diff --git a/auth/src/api/handlers.rs b/auth/src/api/handlers.rs index 4ae7aaabe..b87716ef8 100644 --- a/auth/src/api/handlers.rs +++ b/auth/src/api/handlers.rs @@ -62,20 +62,18 @@ pub(crate) async fn logout( State(cache_manager): State, headers: HeaderMap, ) { - // TODO: this is a POC, needs refactor and error handling. + // If there is a cookie, extract it and try to get the id. let cache_key = if let Ok(Some(cookie)) = headers.typed_try_get::() { - if let Some(id) = cookie.get("shuttle.sid") { - Some(id.to_string()) - } else { - None - } + cookie.get("shuttle.sid").map(|id| id.to_string()) } else { None - } - .unwrap(); + }; - // Clear the session's associated JWT from the cache. - cache_manager.invalidate(&cache_key); + // If there was an id in the cookie, clear it from the cache. + if let Some(key) = cache_key { + // Clear the session's associated JWT from the cache. + cache_manager.invalidate(&key); + } session.destroy(); }