diff --git a/tests/auth.rs b/tests/auth.rs index 8d0cd02b..2cb29332 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -49,7 +49,7 @@ fn negotiate_auth_exists() { .send() .unwrap(); - assert!(!m.requests().is_empty()); + assert_eq!(m.requests_received(), 1); } #[cfg(all(feature = "spnego", windows))] diff --git a/tests/headers.rs b/tests/headers.rs index 066ff264..84143129 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -87,12 +87,8 @@ fn set_title_case_headers_to_true() { client.get(m.url()).unwrap(); - assert_eq!(m.request().method, "GET"); - m.request() - .headers - .iter() - .find(|(key, value)| key == "Foo-Bar" && value == "baz") - .expect("header not found"); + assert_eq!(m.request().method(), "GET"); + m.request().expect_header("Foo-Bar", "baz"); } #[test] diff --git a/tests/interceptors.rs b/tests/interceptors.rs index 0e4d4e29..c1d671ac 100644 --- a/tests/interceptors.rs +++ b/tests/interceptors.rs @@ -17,5 +17,5 @@ fn change_http_method_with_interceptor() { client.get(m.url()).unwrap(); - assert_eq!(m.request().method, "HEAD"); + assert_eq!(m.request().method(), "HEAD"); } diff --git a/tests/methods.rs b/tests/methods.rs index c5932a78..ee8fa8de 100644 --- a/tests/methods.rs +++ b/tests/methods.rs @@ -7,7 +7,7 @@ fn get_request() { isahc::get(m.url()).unwrap(); - assert_eq!(m.request().method, "GET"); + assert_eq!(m.request().method(), "GET"); } #[test] @@ -16,7 +16,7 @@ fn head_request() { isahc::head(m.url()).unwrap(); - assert_eq!(m.request().method, "HEAD"); + assert_eq!(m.request().method(), "HEAD"); } #[test] @@ -25,7 +25,7 @@ fn post_request() { isahc::post(m.url(), ()).unwrap(); - assert_eq!(m.request().method, "POST"); + assert_eq!(m.request().method(), "POST"); } #[test] @@ -34,7 +34,7 @@ fn put_request() { isahc::put(m.url(), ()).unwrap(); - assert_eq!(m.request().method, "PUT"); + assert_eq!(m.request().method(), "PUT"); } #[test] @@ -43,7 +43,7 @@ fn delete_request() { isahc::delete(m.url()).unwrap(); - assert_eq!(m.request().method, "DELETE"); + assert_eq!(m.request().method(), "DELETE"); } #[test] @@ -58,5 +58,5 @@ fn arbitrary_foobar_request() { .send() .unwrap(); - assert_eq!(m.request().method, "FOOBAR"); + assert_eq!(m.request().method(), "FOOBAR"); } diff --git a/tests/metrics.rs b/tests/metrics.rs index 6fb6690a..cc6cae56 100644 --- a/tests/metrics.rs +++ b/tests/metrics.rs @@ -8,7 +8,7 @@ fn metrics_are_disabled_by_default() { let response = isahc::get(m.url()).unwrap(); - assert!(!m.requests().is_empty()); + assert_eq!(m.requests_received(), 1); assert!(response.metrics().is_none()); } diff --git a/tests/net.rs b/tests/net.rs index a885732c..d133dc7b 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -28,7 +28,7 @@ fn local_addr_returns_expected_address() { let response = isahc::get(m.url()).unwrap(); - assert!(!m.requests().is_empty()); + assert_eq!(m.requests_received(), 1); assert_eq!(response.local_addr().unwrap().ip(), Ipv4Addr::LOCALHOST); assert!(response.local_addr().unwrap().port() > 0); } @@ -39,7 +39,7 @@ fn remote_addr_returns_expected_address() { let response = isahc::get(m.url()).unwrap(); - assert!(!m.requests().is_empty()); + assert_eq!(m.requests_received(), 1); assert_eq!(response.remote_addr(), Some(m.addr())); } diff --git a/tests/proxy.rs b/tests/proxy.rs index 8c4ff035..f3f23fc9 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -12,7 +12,7 @@ fn no_proxy() { .send() .unwrap(); - assert_eq!(m.requests().len(), 1); + assert_eq!(m.requests_received(), 1); } #[test] @@ -35,7 +35,7 @@ fn http_proxy() { // protocol. The request-target should be the absolute URI of our // upstream request target (see [RFC // 7230](https://tools.ietf.org/html/rfc7230), sections 5.3 and 5.7). - assert_eq!(m.request().url, upstream.to_string()); + assert_eq!(m.request().url(), upstream.to_string()); // Host should be the upstream authority, not the proxy host. m.request() .expect_header("host", upstream.authority().unwrap().as_str()); @@ -71,8 +71,8 @@ fn socks4_proxy() { .send() .unwrap(); - // ...expecting to receive it through the proxy. - assert_eq!(m.requests().len(), 1); + // ...expecting to receive it through the proxy.Z + assert_eq!(m.requests_received(), 1); } #[test] @@ -93,5 +93,5 @@ fn proxy_blacklist_works() { .send() .unwrap(); - assert_eq!(m.requests().len(), 1); + assert_eq!(m.requests_received(), 1); } diff --git a/tests/redirects.rs b/tests/redirects.rs index 3e72726e..a4e854fb 100644 --- a/tests/redirects.rs +++ b/tests/redirects.rs @@ -17,7 +17,7 @@ fn response_301_no_follow() { assert_eq!(response.headers()["Location"], "/2"); assert_eq!(response.effective_uri().unwrap().path(), "/"); - assert!(!m.requests().is_empty()); + assert_eq!(m.requests_received(), 1); } #[test] @@ -46,8 +46,8 @@ fn response_301_auto_follow() { assert_eq!(response.text().unwrap(), "ok"); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert!(!m1.requests().is_empty()); - assert!(!m2.requests().is_empty()); + assert_eq!(m1.requests_received(), 1); + assert_eq!(m2.requests_received(), 1); } #[test] @@ -82,8 +82,8 @@ fn headers_are_reset_every_redirect() { assert_eq!(response.headers()["X-Baz"], "zzz"); assert!(!response.headers().contains_key("X-Bar")); - assert!(!m1.requests().is_empty()); - assert!(!m2.requests().is_empty()); + assert_eq!(m1.requests_received(), 1); + assert_eq!(m2.requests_received(), 1); } #[test_case(301)] @@ -110,8 +110,8 @@ fn redirect_changes_post_to_get(status: u16) { assert_eq!(response.status(), 200); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert_eq!(m1.request().method, "POST"); - assert_eq!(m2.request().method, "GET"); + assert_eq!(m1.request().method(), "POST"); + assert_eq!(m2.request().method(), "GET"); } #[test_case(307)] @@ -137,8 +137,8 @@ fn redirect_also_sends_post(status: u16) { assert_eq!(response.status(), 200); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert_eq!(m1.request().method, "POST"); - assert_eq!(m2.request().method, "POST"); + assert_eq!(m1.request().method(), "POST"); + assert_eq!(m2.request().method(), "POST"); } // Issue #250 @@ -167,8 +167,8 @@ fn redirect_with_response_body() { assert_eq!(response.status(), 200); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert_eq!(m1.request().method, "POST"); - assert_eq!(m2.request().method, "GET"); + assert_eq!(m1.request().method(), "POST"); + assert_eq!(m2.request().method(), "GET"); } // Issue #250 @@ -194,8 +194,8 @@ fn redirect_policy_from_client() { assert_eq!(response.status(), 200); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert_eq!(m1.request().method, "POST"); - assert_eq!(m2.request().method, "GET"); + assert_eq!(m1.request().method(), "POST"); + assert_eq!(m2.request().method(), "GET"); } #[test] @@ -222,7 +222,7 @@ fn redirect_non_rewindable_body_returns_error() { assert_eq!(error, isahc::error::ErrorKind::RequestBodyNotRewindable); assert_eq!(error.remote_addr(), Some(m1.addr())); - assert_eq!(m1.request().method, "POST"); + assert_eq!(m1.request().method(), "POST"); } #[test] @@ -246,7 +246,7 @@ fn redirect_limit_is_respected() { assert_eq!(error.remote_addr(), Some(m.addr())); // After request (limit + 1) that returns a redirect should error. - assert_eq!(m.requests().len(), 6); + assert_eq!(m.requests_received(), 6); } #[test] @@ -314,6 +314,6 @@ fn redirect_with_unencoded_utf8_bytes_in_location() { assert_eq!(response.text().unwrap(), "ok"); assert_eq!(response.effective_uri().unwrap().to_string(), m2.url()); - assert!(!m1.requests().is_empty()); - assert!(!m2.requests().is_empty()); + assert_eq!(m1.requests_received(), 1); + assert_eq!(m2.requests_received(), 1); } diff --git a/tests/request_body.rs b/tests/request_body.rs index 3159c763..8d2618ba 100644 --- a/tests/request_body.rs +++ b/tests/request_body.rs @@ -33,7 +33,7 @@ fn request_with_body_of_known_size(method: &str) { .send() .unwrap(); - assert_eq!(m.request().method, method); + assert_eq!(m.request().method(), method); m.request() .expect_header("content-length", body.len().to_string()); m.request() @@ -63,7 +63,7 @@ fn request_with_body_of_unknown_size_uses_chunked_encoding(method: &str) { .send() .unwrap(); - assert_eq!(m.request().method, method); + assert_eq!(m.request().method(), method); m.request().expect_header("transfer-encoding", "chunked"); m.request().expect_body(body); } @@ -89,7 +89,7 @@ fn content_length_header_takes_precedence_over_body_objects_length(method: &str) .send() .unwrap(); - assert_eq!(m.request().method, method); + assert_eq!(m.request().method(), method); m.request().expect_header("content-length", "3"); m.request().expect_body("abc"); // truncated to 3 bytes } diff --git a/tests/response_body.rs b/tests/response_body.rs index ced29fab..a21ea75e 100644 --- a/tests/response_body.rs +++ b/tests/response_body.rs @@ -155,7 +155,9 @@ fn consume_unread_response_body() { let m = { let body = body.clone(); mock! { - body: body.clone(), + _ => { + body: body.clone(), + }, } }; @@ -173,7 +175,9 @@ fn consume_unread_response_body_async() { let m = { let body = body.clone(); mock! { - body: body.clone(), + _ => { + body: body.clone(), + }, } }; diff --git a/tests/status.rs b/tests/status.rs index ba32121b..681f5427 100644 --- a/tests/status.rs +++ b/tests/status.rs @@ -22,5 +22,5 @@ fn returns_correct_response_code(status: u16) { let response = isahc::get(m.url()).unwrap(); assert_eq!(response.status(), status); - assert_eq!(m.requests().len(), 1); + assert_eq!(m.requests_received(), 1); } diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 1e79c32d..13c3f9e4 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -27,7 +27,7 @@ fn request_errors_if_read_timeout_is_reached() { // Client should time-out. assert_matches!(result, Err(e) if e == isahc::error::ErrorKind::Timeout); - assert_eq!(m.requests().len(), 1); + assert_eq!(m.requests_received(), 1); } /// Issue #154 @@ -43,7 +43,9 @@ fn timeout_during_response_body_produces_error() { } let m = mock! { - body_reader: Cursor::new(vec![0; 100_000]).chain(SlowReader), + _ => { + body_reader: Cursor::new(vec![0; 100_000]).chain(SlowReader), + }, }; let mut response = Request::get(m.url()) diff --git a/testserver/Cargo.toml b/testserver/Cargo.toml index ce475687..50bafe30 100644 --- a/testserver/Cargo.toml +++ b/testserver/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] humantime = "2" +once_cell = "1" regex = "1.3" -threadpool = "1.8" +threadfin = "0.1.1" tiny_http = "0.9" diff --git a/testserver/src/lib.rs b/testserver/src/lib.rs index f2e669cb..20f7c407 100644 --- a/testserver/src/lib.rs +++ b/testserver/src/lib.rs @@ -1,104 +1,17 @@ +//! HTTP server for testing. + +#[macro_use] +mod macros; mod mock; +mod pool; mod request; mod responder; mod response; pub mod socks4; +pub use macros::macro_api; pub use mock::Mock; pub use request::Request; -pub use responder::Responder; +pub use responder::{RequestContext, Responder}; pub use response::Response; - -/// Macro to define a mock endpoint using a more concise DSL. -#[macro_export] -macro_rules! mock { - (@response($response:expr) status: $status:expr, $($tail:tt)*) => {{ - let mut response = $response; - - response.status_code = $status as u16; - - $crate::mock!(@response(response) $($tail)*) - }}; - - (@response($response:expr) body: $body:expr, $($tail:tt)*) => {{ - let mut response = $response; - - response = response.with_body_buf($body); - - $crate::mock!(@response(response) $($tail)*) - }}; - - (@response($response:expr) body_reader: $body:expr, $($tail:tt)*) => {{ - let mut response = $response; - - response = response.with_body_reader($body); - - $crate::mock!(@response(response) $($tail)*) - }}; - - (@response($response:expr) transfer_encoding: $value:expr, $($tail:tt)*) => {{ - let mut response = $response; - - if $value { - response.body_len = None; - } - - $crate::mock!(@response(response) $($tail)*) - }}; - - (@response($response:expr) delay: $delay:tt, $($tail:tt)*) => {{ - let duration = $crate::helpers::parse_duration(stringify!($delay)); - ::std::thread::sleep(duration); - - $crate::mock!(@response($response) $($tail)*) - }}; - - (@response($response:expr) headers { - $( - $name:literal: $value:expr, - )* - } $($tail:tt)*) => {{ - let mut response = $response; - - $( - response.headers.push(($name.to_string(), $value.to_string())); - )* - - $crate::mock!(@response(response) $($tail)*) - }}; - - (@response($response:expr)) => {{ - $response - }}; - - ($($inner:tt)*) => {{ - struct Responder(F); - - impl $crate::Responder for Responder - where - F: Send + Sync + 'static + Fn($crate::Request) -> Option<$crate::Response>, - { - fn respond(&self, request: $crate::Request) -> Option<$crate::Response> { - (self.0)(request) - } - } - - $crate::Mock::new(Responder(move |request| { - let mut response = $crate::Response::default(); - - let response = $crate::mock!(@response(response) $($inner)*); - - Some(response) - })) - }}; -} - -#[doc(hidden)] -pub mod helpers { - use std::time::Duration; - - pub fn parse_duration(s: &str) -> Duration { - humantime::parse_duration(s).unwrap() - } -} diff --git a/testserver/src/macros.rs b/testserver/src/macros.rs new file mode 100644 index 00000000..0c0f46ab --- /dev/null +++ b/testserver/src/macros.rs @@ -0,0 +1,191 @@ +/// Macro to define a mock endpoint using a more concise DSL. +#[macro_export] +macro_rules! mock { + () => { + $crate::mock! { + _ => {}, + } + }; + + ($($inner:tt)*) => {{ + let mut builder = $crate::Mock::builder(); + + $crate::__mock_impl!(@responders(builder) $($inner)*); + + builder.build() + }}; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __mock_impl { + ( + @responders($builder:ident) + /$($path:tt)? => $response:tt, + $($tail:tt)* + ) => { + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + if ctx.request().url() == stringify!(/$($path)*) { + $crate::__mock_impl!(@responder(ctx) $response); + } + })); + + $crate::__mock_impl!(@responders($builder) $($tail)*); + }; + + ( + @responders($builder:ident) + #$num:expr => writer |$writer:ident| {}, + $($tail:tt)* + ) => { + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + if ctx.request().number() == $num { + let mut $writer = ctx.into_raw(); + } + })); + + $crate::__mock_impl!(@responders($builder) $($tail)*); + }; + + ( + @responders($builder:ident) + #$num:expr => { + $($response_attrs:tt)* + }, + $($tail:tt)* + ) => { + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + if ctx.request().number() == $num { + let mut response = $crate::Response::default(); + + $crate::__mock_impl!(@response(response) $($response_attrs)*); + + ctx.send(response); + } + })); + + $crate::__mock_impl!(@responders($builder) $($tail)*); + }; + + ( + @responders($builder:ident) + _ => writer |$writer:ident| {}, + $($tail:tt)* + ) => { + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + $crate::__mock_impl!(@responder(ctx) $response); + })); + + $crate::__mock_impl!(@responders($builder) $($tail)*); + }; + + ( + @responders($builder:ident) + _ => { + $($response_attrs:tt)* + }, + $($tail:tt)* + ) => { + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + let mut response = $crate::Response::default(); + + $crate::__mock_impl!(@response(response) $($response_attrs)*); + + ctx.send(response); + })); + + $crate::__mock_impl!(@responders($builder) $($tail)*); + }; + + // For backwards compatibility. + (@responders($builder:ident) $($response_attrs:tt)+) => { + // $crate::__mock_impl!(@responders($builder) _ => { + // $($response_attrs)* + // }); + $builder = $builder.responder($crate::macro_api::ClosureResponder::new(move |ctx| { + let mut response = $crate::Response::default(); + + $crate::__mock_impl!(@response(response) $($response_attrs)*); + + ctx.send(response); + })); + }; + + (@responders($builder:ident)) => {}; + + (@response($response:ident) status: $status:expr, $($tail:tt)*) => { + $response.status_code = $status as u16; + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident) body: $body:expr, $($tail:tt)*) => { + $response = $response.with_body_buf($body); + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident) body_reader: $body:expr, $($tail:tt)*) => { + $response = $response.with_body_reader($body); + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident) transfer_encoding: $value:expr, $($tail:tt)*) => { + if $value { + $response.body_len = None; + } + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident) delay: $delay:tt, $($tail:tt)*) => { + let duration = $crate::macro_api::parse_duration(stringify!($delay)); + ::std::thread::sleep(duration); + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident) headers { + $( + $name:literal: $value:expr, + )* + } $($tail:tt)*) => { + $( + $response.headers.push(($name.to_string(), $value.to_string())); + )* + + $crate::__mock_impl!(@response($response) $($tail)*) + }; + + (@response($response:ident)) => {}; +} + +#[doc(hidden)] +pub mod macro_api { + use std::time::Duration; + + pub fn parse_duration(s: &str) -> Duration { + humantime::parse_duration(s).unwrap() + } + + pub struct ClosureResponder(F); + + impl ClosureResponder + where + F: Send + Sync + 'static + for<'r> Fn(&'r mut crate::RequestContext<'_>), + { + pub fn new(f: F) -> Self { + Self(f) + } + } + + impl crate::Responder for ClosureResponder + where + F: Send + Sync + 'static + for<'r> Fn(&'r mut crate::RequestContext<'_>), + { + fn respond(&self, ctx: &mut crate::RequestContext<'_>) { + (self.0)(ctx) + } + } +} diff --git a/testserver/src/mock.rs b/testserver/src/mock.rs index cb506b00..a02adaa0 100644 --- a/testserver/src/mock.rs +++ b/testserver/src/mock.rs @@ -7,67 +7,73 @@ //! Only HTTP/1.x is implemented, as newer HTTP versions are mostly the same //! semantically and are far more complex to deal with. -use crate::{request::Request, responder::*, response::Response}; +use crate::{pool::pool, request::Request, responder::*, response::Response}; use std::{ collections::VecDeque, io::{Cursor, Read, Write}, net::{SocketAddr, TcpStream}, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + Mutex, + }, thread, time::Duration, }; use tiny_http::Server; /// A mock HTTP endpoint. -pub struct Mock { - server: Arc, - requests: Arc>>, - responder: Arc, -} +#[derive(Clone)] +pub struct Mock(Arc); -impl Mock { - pub fn new(responder: R) -> Self { - let mock = Self { - server: Arc::new(Server::http("127.0.0.1:0").unwrap()), - requests: Default::default(), - responder: Arc::new(responder), - }; +struct Inner { + server: Server, - thread::spawn({ - let mock = mock.clone(); + requests: Mutex>, - move || { - for request in mock.server.incoming_requests() { - mock.handle_request(request); - } - } - }); + /// Number of requests received since the mock was created. + request_counter: AtomicU32, - mock.wait_until_ready(); + /// A list of responders. When receiving a request each responder is tried + /// in order until one returns a response. + responders: Vec>, +} - mock +impl Mock { + /// Create a new mock server with a single responder. + pub fn new(responder: R) -> Self { + Self::builder().responder(responder).build() } + /// Create a builder for creating a customized mock server. + pub fn builder() -> Builder { + Builder { + responders: vec![], + } + } + + /// Get the socket address of this mock server. pub fn addr(&self) -> SocketAddr { - self.server.server_addr() + self.0.server.server_addr() } + /// Get the HTTP URL of this mock server. pub fn url(&self) -> String { format!("http://{}/", self.addr()) } + /// Get the number of requests received so far by this mock. + pub fn requests_received(&self) -> u32 { + self.0.request_counter.load(Ordering::SeqCst) + } + /// Get the first request received by this mock. pub fn request(&self) -> Request { - let request = self.requests.lock().unwrap().get(0).cloned(); + let request = self.0.requests.lock().unwrap().front().cloned(); request.expect("no request received") } - /// Get all requests received by this mock. - pub fn requests(&self) -> Vec { - self.requests.lock().unwrap().iter().cloned().collect() - } - #[rustfmt::skip] fn is_ready(&self) -> bool { TcpStream::connect(self.addr()) @@ -98,19 +104,6 @@ impl Mock { panic!("mock server did not become ready after 9 tries"); } - fn respond(&self, request: Request) -> Response { - if let Some(response) = self.responder.respond(request.clone()) { - return response; - } - - Response { - status_code: 404, - headers: Vec::new(), - body: Box::new(std::io::empty()), - body_len: Some(0), - } - } - fn handle_request(&self, mut request: tiny_http::Request) { if request .headers() @@ -118,10 +111,8 @@ impl Mock { .find(|h| h.field.as_str() == "host" && h.value == "api.mock.local") .is_some() { - if let Some(response) = self.handle_api_request(&request) { - request.respond(response).unwrap(); - return; - } + self.handle_api_request(request); + return; } let mut body = Vec::new(); @@ -133,7 +124,8 @@ impl Mock { request.as_reader().read_to_end(&mut body).unwrap(); // Build a record of the request received. - let mock_request = Request { + let mut mock_request = Request { + number: self.0.request_counter.fetch_add(1, Ordering::SeqCst), method: request.method().to_string(), url: request.url().to_string(), headers: request @@ -144,46 +136,84 @@ impl Mock { body: Some(body), }; - self.requests + self.0 + .requests .lock() .unwrap() .push_back(mock_request.clone()); - let response = self.respond(mock_request); + let mut ctx = RequestContext::new(&mut mock_request, request); + + for responder in &self.0.responders { + responder.respond(&mut ctx); + + if ctx.http_request.is_none() { + break; + } + } - request.respond(response.into_http_response()).unwrap(); + if let Some(request) = ctx.http_request.take() { + request + .respond( + Response { + status_code: 404, + headers: Vec::new(), + body: Box::new(std::io::empty()), + body_len: Some(0), + } + .into_http_response(), + ) + .unwrap(); + } } - fn handle_api_request( - &self, - request: &tiny_http::Request, - ) -> Option>>> { + fn handle_api_request(&self, request: tiny_http::Request) { if request.url() == "/health" { - Some(tiny_http::Response::new( + let _ = request.respond(tiny_http::Response::new( 200.into(), vec![], - Cursor::new((&b"OK"[..]).into()), + Cursor::new(b"OK".to_vec()), Some(2), None, - )) - } else { - None + )); } } } -impl Default for Mock { - fn default() -> Self { - Self::new(DefaultResponder) - } +/// A builder for creating mock servers. +pub struct Builder { + responders: Vec>, } -impl Clone for Mock { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - requests: self.requests.clone(), - responder: self.responder.clone(), - } +impl Builder { + /// Add a responder to the mock. Responders are tried in the order that they + /// are added to the builder. + pub fn responder(mut self, responder: R) -> Self { + self.responders.push(Box::new(responder)); + self + } + + /// Start a new mock server. + pub fn build(self) -> Mock { + let mock = Mock(Arc::new(Inner { + server: Server::http("127.0.0.1:0").unwrap(), + requests: Default::default(), + request_counter: AtomicU32::new(0), + responders: self.responders, + })); + + pool().execute({ + let mock = mock.clone(); + + move || { + for request in mock.0.server.incoming_requests() { + mock.handle_request(request); + } + } + }); + + mock.wait_until_ready(); + + mock } } diff --git a/testserver/src/pool.rs b/testserver/src/pool.rs new file mode 100644 index 00000000..2e715e04 --- /dev/null +++ b/testserver/src/pool.rs @@ -0,0 +1,19 @@ +//! Shared thread pool for executing request handlers. +//! +//! While mocks can't share TCP servers since the approach of this library is +//! port-per-test, we _can_ share threads across all mock servers to make it not +//! quite as inefficient. + +use once_cell::sync::Lazy; +use threadfin::ThreadPool; + +/// Get access to the shared thread pool. +pub(crate) fn pool() -> &'static ThreadPool { + // Pool that crates pretty much as many threads as needed, while still + // allowing reuse. + static POOL: Lazy = Lazy::new(|| ThreadPool::builder() + .size(..100) + .build()); + + &*POOL +} diff --git a/testserver/src/request.rs b/testserver/src/request.rs index 236a6fe8..6320f1d7 100644 --- a/testserver/src/request.rs +++ b/testserver/src/request.rs @@ -2,13 +2,30 @@ use regex::Regex; #[derive(Clone, Debug, Eq, PartialEq)] pub struct Request { - pub method: String, - pub url: String, - pub headers: Vec<(String, String)>, - pub body: Option>, + pub(crate) number: u32, + pub(crate) method: String, + pub(crate) url: String, + pub(crate) headers: Vec<(String, String)>, + pub(crate) body: Option>, } impl Request { + pub fn method(&self) -> &str { + self.method.as_str() + } + + pub fn url(&self) -> &str { + self.url.as_str() + } + + /// Get the request number. + /// + /// This is a monotonically increasing number, starting from 0, that + /// indicates the order of requests received by the mock. + pub fn number(&self) -> u32 { + self.number + } + pub fn get_header(&self, name: impl AsRef) -> impl Iterator + '_ { let name_lower = name.as_ref().to_lowercase(); diff --git a/testserver/src/responder.rs b/testserver/src/responder.rs index 46e643b8..9f8bfcd3 100644 --- a/testserver/src/responder.rs +++ b/testserver/src/responder.rs @@ -1,18 +1,62 @@ +use std::{io::Write, thread::sleep, time::Duration}; + use crate::{request::Request, response::Response}; +/// Provides methods for responding to a request. +pub struct RequestContext<'r> { + pub(crate) request: &'r mut Request, + pub(crate) http_request: Option, + pub(crate) delay: Option, +} + +impl<'r> RequestContext<'r> { + pub(crate) fn new(request: &'r mut Request, http_request: tiny_http::Request) -> Self { + Self { + request, + http_request: Some(http_request), + delay: None, + } + } + + pub fn request(&self) -> &Request { + self.request + } + + pub fn send(&mut self, response: Response) { + if let Some(delay) = self.delay { + sleep(delay); + } + + self.http_request + .take() + .unwrap() + .respond(response.into_http_response()) + .unwrap(); + } + + pub fn into_raw(&mut self) -> impl Write { + self.http_request.take().unwrap().into_writer() + } + + pub fn set_delay(&mut self, delay: Duration) { + self.delay = Some(delay); + } +} + /// A responder is a request-response handler responsible for producing the /// responses returned by a mock endpoint. /// -/// Responders are not responsible for doing any assertions. +/// Responders are not responsible for doing any test assertions. pub trait Responder: Send + Sync + 'static { - fn respond(&self, request: Request) -> Option; + /// Respond to a request. + fn respond(&self, ctx: &mut RequestContext<'_>); } /// Simple responder that returns a general response. pub struct DefaultResponder; impl Responder for DefaultResponder { - fn respond(&self, _: Request) -> Option { - Some(Response::default()) + fn respond(&self, ctx: &mut RequestContext<'_>) { + ctx.send(Response::default()); } } diff --git a/testserver/src/server.rs b/testserver/src/server.rs deleted file mode 100644 index ea97534e..00000000 --- a/testserver/src/server.rs +++ /dev/null @@ -1,111 +0,0 @@ -use std::{io::{BufRead, BufReader, Read, Result}, net::TcpListener}; - -use httparse::parse_headers; - -pub(crate) struct Server { - listener: TcpListener, -} - -impl Server { - pub(crate) fn accept(&mut self) -> Result { - let (stream, addr) = self.listener.accept()?; - - let mut reader = GrowableBufReader::new(stream); - let mut headers = [httparse::EMPTY_HEADER; 16]; - let mut request = httparse::Request::new(&mut headers); - - loop { - reader.fill_buf_additional(8192)?; - - let result = request.parse(reader.buffer()); - match result { - Ok(httparse::Status::Partial) => continue, - Ok(httparse::Status::Complete(offset)) => { - reader.consume(offset); - }, - Err(_) => unimplemented!(), - } - } - - unimplemented!() - } -} - -pub(crate) struct Connection { - -} - -struct GrowableBufReader { - inner: R, - buffer: Vec, - low: usize, - high: usize, -} - -impl GrowableBufReader { - fn new(inner: R) -> Self { - Self { - inner, - buffer: Vec::with_capacity(8192), - low: 0, - high: 0, - } - } - - #[inline] - fn available(&self) -> usize { - self.high - self.low - } - - #[inline] - fn buffer(&self) -> &[u8] { - &self.buffer[self.low..self.high] - } - - fn fill_buf_additional(&mut self, max: usize) -> Result { - self.reserve(max); - let amt = self.inner.read(&mut self.buffer[self.high..])?; - self.high += amt; - - Ok(amt) - } - - fn reserve(&mut self, capacity: usize) { - let desired_buffer_size = self.high + capacity; - - if self.buffer.len() < desired_buffer_size { - self.buffer.resize(desired_buffer_size, 0); - } - } -} - -impl BufRead for GrowableBufReader { - fn fill_buf(&mut self) -> Result<&[u8]> { - if self.available() == 0 { - self.fill_buf_additional(8192)?; - } - - Ok(self.buffer()) - } - - fn consume(&mut self, amt: usize) { - if amt >= self.available() { - self.low = 0; - self.high = 0; - self.buffer.clear(); - } else { - self.low += amt; - } - } -} - -impl Read for GrowableBufReader { - fn read(&mut self, buf: &mut [u8]) -> Result { - let src = self.fill_buf()?; - let amt = buf.len().min(src.len()); - buf[..amt].copy_from_slice(&src[..amt]); - self.consume(amt); - - Ok(amt) - } -} diff --git a/testserver/src/socks4.rs b/testserver/src/socks4.rs index 973eab10..6d428d2d 100644 --- a/testserver/src/socks4.rs +++ b/testserver/src/socks4.rs @@ -4,15 +4,14 @@ use std::{ io::{self, BufRead, BufReader, Write}, net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream, ToSocketAddrs}, sync::Arc, - thread, }; -use threadpool::ThreadPool; + +use crate::pool::pool; #[derive(Clone, Debug)] pub struct Socks4Server { listener: Arc, addr: SocketAddr, - pool: ThreadPool, } impl Socks4Server { @@ -23,7 +22,6 @@ impl Socks4Server { Ok(Self { addr: listener.local_addr()?, listener: Arc::new(listener), - pool: ThreadPool::default(), }) } @@ -38,7 +36,7 @@ impl Socks4Server { Ok(connection) => { let s = self.clone(); - self.pool.execute(move || { + pool().execute(move || { s.handle(connection).unwrap(); }); } @@ -50,7 +48,7 @@ impl Socks4Server { } pub fn spawn(self) { - thread::spawn(move || self.run()); + pool().execute(move || self.run()); } fn handle(&self, connection: TcpStream) -> io::Result<()> { @@ -97,7 +95,7 @@ impl Socks4Server { client_writer.flush()?; // Copy bytes in and to the upstream in parallel. - self.pool.execute(move || { + pool().execute(move || { io::copy(&mut client_reader, &mut upstream_writer).unwrap(); });