Skip to content

Commit

Permalink
refactor: rearrange TestResponse assertion methods (#98)
Browse files Browse the repository at this point in the history
* Move the main assert_status_* functions to be above the others.
 * Move `into_websocket` above assertion functions.
  • Loading branch information
JosephLenton authored Aug 22, 2024
1 parent f3fd041 commit 278091c
Showing 1 changed file with 99 additions and 99 deletions.
198 changes: 99 additions & 99 deletions src/test_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,64 @@ impl TestResponse {
})
}

/// Consumes the request, turning it into a `TestWebSocket`.
/// If this cannot be done, then the response will panic.
///
/// *Note*, this requires the server to be running on a real HTTP
/// port. Either using a randomly assigned port, or a specified one.
/// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
///
/// # Example
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use ::axum::Router;
/// use ::axum_test::TestServer;
/// use ::axum_test::TestServerConfig;
///
/// let app = Router::new();
/// let config = TestServerConfig::builder().http_transport().build();
/// let server = TestServer::new_with_config(app, config)?;
///
/// let mut websocket = server
/// .get_websocket(&"/my-web-socket-end-point")
/// .await
/// .into_websocket()
/// .await;
///
/// websocket.send_text("Hello!").await;
/// #
/// # Ok(()) }
/// ```
///
#[cfg(feature = "ws")]
#[must_use]
pub async fn into_websocket(self) -> TestWebSocket {
use crate::transport_layer::TransportLayerType;

// Using the mock approach will just fail.
if self.websockets.transport_type != TransportLayerType::Http {
unimplemented!("WebSocket requires a HTTP based transport layer, see `TestServerConfig::transport`");
}

let debug_request_format = self.debug_request_format().to_string();

let on_upgrade = self.websockets.maybe_on_upgrade.with_context(|| {
format!("Expected WebSocket upgrade to be found, it is None, for request {debug_request_format}")
})
.unwrap();

let upgraded = on_upgrade
.await
.with_context(|| {
format!("Failed to upgrade connection for, for request {debug_request_format}")
})
.unwrap();

TestWebSocket::new(upgraded).await
}

/// This performs an assertion comparing the whole body of the response,
/// against the text provided.
#[track_caller]
Expand Down Expand Up @@ -657,6 +715,33 @@ impl TestResponse {
assert_eq!(*other, self.form::<T>());
}

/// Assert the response status code matches the one given.
#[track_caller]
pub fn assert_status(&self, expected_status_code: StatusCode) {
let received_debug = StatusCodeFormatter(self.status_code);
let expected_debug = StatusCodeFormatter(expected_status_code);
let debug_request_format = self.debug_request_format();

assert_eq!(
expected_status_code, self.status_code,
"Expected status code to be {expected_debug}, received {received_debug}, for request {debug_request_format}",
);
}

/// Assert the response status code does **not** match the one given.
#[track_caller]
pub fn assert_not_status(&self, expected_status_code: StatusCode) {
let received_debug = StatusCodeFormatter(self.status_code);
let expected_debug = StatusCodeFormatter(expected_status_code);
let debug_request_format = self.debug_request_format();

assert_ne!(
expected_status_code,
self.status_code,
"Expected status code to not be {expected_debug}, received {received_debug}, for request {debug_request_format}"
);
}

/// Assert that the status code is **within** the 2xx range.
/// i.e. The range from 200-299.
#[track_caller]
Expand All @@ -667,7 +752,7 @@ impl TestResponse {

assert!(
200 <= status_code && status_code <= 299,
"Expect status code within 2xx range, got {received_debug}, for request {debug_request_format}"
"Expect status code within 2xx range, received {received_debug}, for request {debug_request_format}"
);
}

Expand All @@ -681,10 +766,22 @@ impl TestResponse {

assert!(
status_code < 200 || 299 < status_code,
"Expect status code outside 2xx range, got {received_debug}, for request {debug_request_format}",
"Expect status code outside 2xx range, received {received_debug}, for request {debug_request_format}",
);
}

/// Assert the response status code is 200.
#[track_caller]
pub fn assert_status_ok(&self) {
self.assert_status(StatusCode::OK)
}

/// Assert the response status code is **not** 200.
#[track_caller]
pub fn assert_status_not_ok(&self) {
self.assert_not_status(StatusCode::OK)
}

/// Assert the response status code is 400.
#[track_caller]
pub fn assert_status_bad_request(&self) {
Expand Down Expand Up @@ -729,18 +826,6 @@ impl TestResponse {
self.assert_status(StatusCode::TOO_MANY_REQUESTS)
}

/// Assert the response status code is 200.
#[track_caller]
pub fn assert_status_ok(&self) {
self.assert_status(StatusCode::OK)
}

/// Assert the response status code is **not** 200.
#[track_caller]
pub fn assert_status_not_ok(&self) {
self.assert_not_status(StatusCode::OK)
}

/// Assert the response status code is 101.
///
/// This type of code is used in Web Socket connection when
Expand All @@ -756,91 +841,6 @@ impl TestResponse {
self.assert_status(StatusCode::SERVICE_UNAVAILABLE)
}

/// Assert the response status code matches the one given.
#[track_caller]
pub fn assert_status(&self, expected_status_code: StatusCode) {
let status_code = self.status_code.as_u16();
let received_debug = StatusCodeFormatter(self.status_code);
let expected_debug = StatusCodeFormatter(expected_status_code);
let debug_request_format = self.debug_request_format();

assert_eq!(
expected_status_code, status_code,
"Expected status code {expected_debug}, got {received_debug}, for request {debug_request_format}",
);
}

/// Assert the response status code does **not** match the one given.
#[track_caller]
pub fn assert_not_status(&self, expected_status_code: StatusCode) {
let expected_debug = StatusCodeFormatter(expected_status_code);
let debug_request_format = self.debug_request_format();

assert_ne!(
expected_status_code,
self.status_code(),
"Expected status code to not be {expected_debug}, it is, for request {debug_request_format}"
);
}

/// Consumes the request, turning it into a `TestWebSocket`.
/// If this cannot be done, then the response will panic.
///
/// *Note*, this requires the server to be running on a real HTTP
/// port. Either using a randomly assigned port, or a specified one.
/// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
///
/// # Example
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use ::axum::Router;
/// use ::axum_test::TestServer;
/// use ::axum_test::TestServerConfig;
///
/// let app = Router::new();
/// let config = TestServerConfig::builder().http_transport().build();
/// let server = TestServer::new_with_config(app, config)?;
///
/// let mut websocket = server
/// .get_websocket(&"/my-web-socket-end-point")
/// .await
/// .into_websocket()
/// .await;
///
/// websocket.send_text("Hello!").await;
/// #
/// # Ok(()) }
/// ```
///
#[cfg(feature = "ws")]
#[must_use]
pub async fn into_websocket(self) -> TestWebSocket {
use crate::transport_layer::TransportLayerType;

// Using the mock approach will just fail.
if self.websockets.transport_type != TransportLayerType::Http {
unimplemented!("WebSocket requires a HTTP based transport layer, see `TestServerConfig::transport`");
}

let debug_request_format = self.debug_request_format().to_string();

let on_upgrade = self.websockets.maybe_on_upgrade.with_context(|| {
format!("Expected WebSocket upgrade to be found, it is None, for request {debug_request_format}")
})
.unwrap();

let upgraded = on_upgrade
.await
.with_context(|| {
format!("Failed to upgrade connection for, for request {debug_request_format}")
})
.unwrap();

TestWebSocket::new(upgraded).await
}

fn debug_request_format<'a>(&'a self) -> RequestPathFormatter<'a> {
RequestPathFormatter::new(&self.method, &self.full_request_url.as_str(), None)
}
Expand Down

0 comments on commit 278091c

Please sign in to comment.