From 278091c0566138b9be05d41011d001c2adc14a38 Mon Sep 17 00:00:00 2001 From: Joseph Lenton Date: Thu, 22 Aug 2024 09:59:12 +0200 Subject: [PATCH] refactor: rearrange TestResponse assertion methods (#98) * Move the main assert_status_* functions to be above the others. * Move `into_websocket` above assertion functions. --- src/test_response.rs | 198 +++++++++++++++++++++---------------------- 1 file changed, 99 insertions(+), 99 deletions(-) diff --git a/src/test_response.rs b/src/test_response.rs index 91fc50b..fd93ead 100644 --- a/src/test_response.rs +++ b/src/test_response.rs @@ -576,6 +576,64 @@ impl TestResponse { }) } + /// Consumes the request, turning it into a `TestWebSocket`. + /// If this cannot be done, then the response will panic. + /// + /// *Note*, this requires the server to be running on a real HTTP + /// port. Either using a randomly assigned port, or a specified one. + /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details. + /// + /// # Example + /// + /// ```rust + /// # async fn test() -> Result<(), Box> { + /// # + /// use ::axum::Router; + /// use ::axum_test::TestServer; + /// use ::axum_test::TestServerConfig; + /// + /// let app = Router::new(); + /// let config = TestServerConfig::builder().http_transport().build(); + /// let server = TestServer::new_with_config(app, config)?; + /// + /// let mut websocket = server + /// .get_websocket(&"/my-web-socket-end-point") + /// .await + /// .into_websocket() + /// .await; + /// + /// websocket.send_text("Hello!").await; + /// # + /// # Ok(()) } + /// ``` + /// + #[cfg(feature = "ws")] + #[must_use] + pub async fn into_websocket(self) -> TestWebSocket { + use crate::transport_layer::TransportLayerType; + + // Using the mock approach will just fail. + if self.websockets.transport_type != TransportLayerType::Http { + unimplemented!("WebSocket requires a HTTP based transport layer, see `TestServerConfig::transport`"); + } + + let debug_request_format = self.debug_request_format().to_string(); + + let on_upgrade = self.websockets.maybe_on_upgrade.with_context(|| { + format!("Expected WebSocket upgrade to be found, it is None, for request {debug_request_format}") + }) + .unwrap(); + + let upgraded = on_upgrade + .await + .with_context(|| { + format!("Failed to upgrade connection for, for request {debug_request_format}") + }) + .unwrap(); + + TestWebSocket::new(upgraded).await + } + /// This performs an assertion comparing the whole body of the response, /// against the text provided. #[track_caller] @@ -657,6 +715,33 @@ impl TestResponse { assert_eq!(*other, self.form::()); } + /// Assert the response status code matches the one given. + #[track_caller] + pub fn assert_status(&self, expected_status_code: StatusCode) { + let received_debug = StatusCodeFormatter(self.status_code); + let expected_debug = StatusCodeFormatter(expected_status_code); + let debug_request_format = self.debug_request_format(); + + assert_eq!( + expected_status_code, self.status_code, + "Expected status code to be {expected_debug}, received {received_debug}, for request {debug_request_format}", + ); + } + + /// Assert the response status code does **not** match the one given. + #[track_caller] + pub fn assert_not_status(&self, expected_status_code: StatusCode) { + let received_debug = StatusCodeFormatter(self.status_code); + let expected_debug = StatusCodeFormatter(expected_status_code); + let debug_request_format = self.debug_request_format(); + + assert_ne!( + expected_status_code, + self.status_code, + "Expected status code to not be {expected_debug}, received {received_debug}, for request {debug_request_format}" + ); + } + /// Assert that the status code is **within** the 2xx range. /// i.e. The range from 200-299. #[track_caller] @@ -667,7 +752,7 @@ impl TestResponse { assert!( 200 <= status_code && status_code <= 299, - "Expect status code within 2xx range, got {received_debug}, for request {debug_request_format}" + "Expect status code within 2xx range, received {received_debug}, for request {debug_request_format}" ); } @@ -681,10 +766,22 @@ impl TestResponse { assert!( status_code < 200 || 299 < status_code, - "Expect status code outside 2xx range, got {received_debug}, for request {debug_request_format}", + "Expect status code outside 2xx range, received {received_debug}, for request {debug_request_format}", ); } + /// Assert the response status code is 200. + #[track_caller] + pub fn assert_status_ok(&self) { + self.assert_status(StatusCode::OK) + } + + /// Assert the response status code is **not** 200. + #[track_caller] + pub fn assert_status_not_ok(&self) { + self.assert_not_status(StatusCode::OK) + } + /// Assert the response status code is 400. #[track_caller] pub fn assert_status_bad_request(&self) { @@ -729,18 +826,6 @@ impl TestResponse { self.assert_status(StatusCode::TOO_MANY_REQUESTS) } - /// Assert the response status code is 200. - #[track_caller] - pub fn assert_status_ok(&self) { - self.assert_status(StatusCode::OK) - } - - /// Assert the response status code is **not** 200. - #[track_caller] - pub fn assert_status_not_ok(&self) { - self.assert_not_status(StatusCode::OK) - } - /// Assert the response status code is 101. /// /// This type of code is used in Web Socket connection when @@ -756,91 +841,6 @@ impl TestResponse { self.assert_status(StatusCode::SERVICE_UNAVAILABLE) } - /// Assert the response status code matches the one given. - #[track_caller] - pub fn assert_status(&self, expected_status_code: StatusCode) { - let status_code = self.status_code.as_u16(); - let received_debug = StatusCodeFormatter(self.status_code); - let expected_debug = StatusCodeFormatter(expected_status_code); - let debug_request_format = self.debug_request_format(); - - assert_eq!( - expected_status_code, status_code, - "Expected status code {expected_debug}, got {received_debug}, for request {debug_request_format}", - ); - } - - /// Assert the response status code does **not** match the one given. - #[track_caller] - pub fn assert_not_status(&self, expected_status_code: StatusCode) { - let expected_debug = StatusCodeFormatter(expected_status_code); - let debug_request_format = self.debug_request_format(); - - assert_ne!( - expected_status_code, - self.status_code(), - "Expected status code to not be {expected_debug}, it is, for request {debug_request_format}" - ); - } - - /// Consumes the request, turning it into a `TestWebSocket`. - /// If this cannot be done, then the response will panic. - /// - /// *Note*, this requires the server to be running on a real HTTP - /// port. Either using a randomly assigned port, or a specified one. - /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details. - /// - /// # Example - /// - /// ```rust - /// # async fn test() -> Result<(), Box> { - /// # - /// use ::axum::Router; - /// use ::axum_test::TestServer; - /// use ::axum_test::TestServerConfig; - /// - /// let app = Router::new(); - /// let config = TestServerConfig::builder().http_transport().build(); - /// let server = TestServer::new_with_config(app, config)?; - /// - /// let mut websocket = server - /// .get_websocket(&"/my-web-socket-end-point") - /// .await - /// .into_websocket() - /// .await; - /// - /// websocket.send_text("Hello!").await; - /// # - /// # Ok(()) } - /// ``` - /// - #[cfg(feature = "ws")] - #[must_use] - pub async fn into_websocket(self) -> TestWebSocket { - use crate::transport_layer::TransportLayerType; - - // Using the mock approach will just fail. - if self.websockets.transport_type != TransportLayerType::Http { - unimplemented!("WebSocket requires a HTTP based transport layer, see `TestServerConfig::transport`"); - } - - let debug_request_format = self.debug_request_format().to_string(); - - let on_upgrade = self.websockets.maybe_on_upgrade.with_context(|| { - format!("Expected WebSocket upgrade to be found, it is None, for request {debug_request_format}") - }) - .unwrap(); - - let upgraded = on_upgrade - .await - .with_context(|| { - format!("Failed to upgrade connection for, for request {debug_request_format}") - }) - .unwrap(); - - TestWebSocket::new(upgraded).await - } - fn debug_request_format<'a>(&'a self) -> RequestPathFormatter<'a> { RequestPathFormatter::new(&self.method, &self.full_request_url.as_str(), None) }