diff --git a/src/connector.rs b/src/connector.rs index 02a29a8..f4eeea2 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -69,7 +69,7 @@ where // dst.scheme() would need to derive Eq to be matchable; // use an if cascade instead match dst.scheme() { - Some(scheme) if scheme == &http::uri::Scheme::HTTP => { + Some(scheme) if scheme == &http::uri::Scheme::HTTP && !self.force_https => { let future = self.http.call(dst); return Box::pin(async move { Ok(MaybeHttpsStream::Http(future.await.map_err(Into::into)?)) @@ -199,3 +199,104 @@ pub trait ResolveServerName { uri: &Uri, ) -> Result, Box>; } + +#[cfg(all( + test, + any(feature = "ring", feature = "aws-lc-rs"), + any( + feature = "rustls-native-certs", + feature = "webpki-roots", + feature = "rustls-platform-verifier", + ) +))] +mod tests { + use std::future::poll_fn; + + use http::Uri; + use hyper_util::client::legacy::connect::HttpConnector; + use tower_service::Service; + + use super::HttpsConnector; + use crate::{ConfigBuilderExt, HttpsConnectorBuilder}; + + fn tls_config() -> rustls::ClientConfig { + #[cfg(feature = "rustls-native-certs")] + return rustls::ClientConfig::builder() + .with_native_roots() + .unwrap() + .with_no_client_auth(); + + #[cfg(feature = "webpki-roots")] + return rustls::ClientConfig::builder() + .with_webpki_roots() + .with_no_client_auth(); + + #[cfg(feature = "rustls-platform-verifier")] + return rustls::ClientConfig::builder() + .with_platform_verifier() + .with_no_client_auth(); + } + + fn https_or_http_connector() -> HttpsConnector { + HttpsConnectorBuilder::new() + .with_tls_config(tls_config()) + .https_or_http() + .enable_http1() + .build() + } + + fn https_only_connector() -> HttpsConnector { + HttpsConnectorBuilder::new() + .with_tls_config(tls_config()) + .https_only() + .enable_http1() + .build() + } + + async fn oneshot(mut service: S, req: Req) -> Result + where + S: Service, + { + poll_fn(|cx| service.poll_ready(cx)).await?; + service.call(req).await + } + + fn https_uri() -> Uri { + Uri::from_static("https://google.com") + } + + fn http_uri() -> Uri { + Uri::from_static("http://google.com") + } + + #[tokio::test] + async fn connects_https() { + oneshot(https_or_http_connector(), https_uri()) + .await + .unwrap(); + } + + #[tokio::test] + async fn connects_http() { + oneshot(https_or_http_connector(), http_uri()) + .await + .unwrap(); + } + + #[tokio::test] + async fn connects_https_only() { + oneshot(https_only_connector(), https_uri()) + .await + .unwrap(); + } + + #[tokio::test] + async fn enforces_https_only() { + let message = oneshot(https_only_connector(), http_uri()) + .await + .unwrap_err() + .to_string(); + + assert_eq!(message, "unsupported scheme http"); + } +}