Skip to content

Commit

Permalink
feat(client): add client::connect::capture_connection() (#3144)
Browse files Browse the repository at this point in the history
Add `capture_connection` functionality. This allows callers to retrieve the `Connected` struct of the connection that was used internally by Hyper. This is in service of #2605.

Although this uses `http::Extensions` under the hood, the API exposed explicitly hides that detail.
  • Loading branch information
rcoh committed Feb 22, 2023
1 parent 37ed5a2 commit c849339
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 6 deletions.
15 changes: 11 additions & 4 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@ use http::uri::{Port, Scheme};
use http::{Method, Request, Response, Uri, Version};
use tracing::{debug, trace, warn};

use crate::body::{Body, HttpBody};
use crate::client::connect::CaptureConnectionExtension;
use crate::common::{
exec::BoxSendFuture, lazy as hyper_lazy, sync_wrapper::SyncWrapper, task, Future, Lazy, Pin,
Poll,
};
use crate::rt::Executor;

use super::conn;
use super::connect::{self, sealed::Connect, Alpn, Connected, Connection};
use super::pool::{
self, CheckoutIsClosedError, Key as PoolKey, Pool, Poolable, Pooled, Reservation,
};
#[cfg(feature = "tcp")]
use super::HttpConnector;
use crate::body::{Body, HttpBody};
use crate::common::{exec::BoxSendFuture, sync_wrapper::SyncWrapper, lazy as hyper_lazy, task, Future, Lazy, Pin, Poll};
use crate::rt::Executor;

/// A Client to make outgoing HTTP requests.
///
Expand Down Expand Up @@ -238,7 +243,9 @@ where
})
}
};

req.extensions_mut()
.get_mut::<CaptureConnectionExtension>()
.map(|conn| conn.set(&pooled.conn_info));
if pooled.is_http1() {
if req.version() == Version::HTTP_2 {
warn!("Connection is HTTP/1, but request requires HTTP/2");
Expand Down
180 changes: 179 additions & 1 deletion src/client/connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicBool, Ordering};
use std::ops::Deref;
use std::sync::Arc;

use ::http::Extensions;
use tokio::sync::watch;

cfg_feature! {
#![feature = "tcp"]
Expand Down Expand Up @@ -146,6 +148,114 @@ impl PoisonPill {
}
}

/// [`CaptureConnection`] allows callers to capture [`Connected`] information
///
/// To capture a connection for a request, use [`capture_connection`].
#[derive(Debug, Clone)]
pub struct CaptureConnection {
rx: watch::Receiver<Option<Connected>>,
}

/// Capture the connection for a given request
///
/// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait.
/// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon
/// as the connection is established.
///
/// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none.
///
/// # Examples
///
/// **Synchronous access**:
/// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been
/// established. This is ideal for situations where you are certain the connection has already
/// been established (e.g. after the response future has already completed).
/// ```rust
/// use hyper::client::connect::{capture_connection, CaptureConnection};
/// let mut request = http::Request::builder()
/// .uri("http://foo.com")
/// .body(())
/// .unwrap();
///
/// let captured_connection = capture_connection(&mut request);
/// // some time later after the request has been sent...
/// let connection_info = captured_connection.connection_metadata();
/// println!("we are connected! {:?}", connection_info.as_ref());
/// ```
///
/// **Asynchronous access**:
/// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the
/// connection is available.
///
/// ```rust
/// # #[cfg(feature = "runtime")]
/// # async fn example() {
/// use hyper::client::connect::{capture_connection, CaptureConnection};
/// let mut request = http::Request::builder()
/// .uri("http://foo.com")
/// .body(hyper::Body::empty())
/// .unwrap();
///
/// let mut captured = capture_connection(&mut request);
/// tokio::task::spawn(async move {
/// let connection_info = captured.wait_for_connection_metadata().await;
/// println!("we are connected! {:?}", connection_info.as_ref());
/// });
///
/// let client = hyper::Client::new();
/// client.request(request).await.expect("request failed");
/// # }
/// ```
pub fn capture_connection<B>(request: &mut crate::http::Request<B>) -> CaptureConnection {
let (tx, rx) = CaptureConnection::new();
request.extensions_mut().insert(tx);
rx
}

/// TxSide for [`CaptureConnection`]
///
/// This is inserted into `Extensions` to allow Hyper to back channel connection info
#[derive(Clone)]
pub(crate) struct CaptureConnectionExtension {
tx: Arc<watch::Sender<Option<Connected>>>,
}

impl CaptureConnectionExtension {
pub(crate) fn set(&self, connected: &Connected) {
self.tx.send_replace(Some(connected.clone()));
}
}

impl CaptureConnection {
/// Internal API to create the tx and rx half of [`CaptureConnection`]
pub(crate) fn new() -> (CaptureConnectionExtension, Self) {
let (tx, rx) = watch::channel(None);
(
CaptureConnectionExtension { tx: Arc::new(tx) },
CaptureConnection { rx },
)
}

/// Retrieve the connection metadata, if available
pub fn connection_metadata(&self) -> impl Deref<Target = Option<Connected>> + '_ {
self.rx.borrow()
}

/// Wait for the connection to be established
///
/// If a connection was established, this will always return `Some(...)`. If the request never
/// successfully connected (e.g. DNS resolution failure), this method will never return.
pub async fn wait_for_connection_metadata(
&mut self,
) -> impl Deref<Target = Option<Connected>> + '_ {
if self.rx.borrow().is_some() {
return self.rx.borrow();
}
let _ = self.rx.changed().await;
self.rx.borrow()
}
}

pub(super) struct Extra(Box<dyn ExtraInner>);

#[derive(Clone, Copy, Debug, PartialEq)]
Expand Down Expand Up @@ -233,7 +343,6 @@ impl Connected {

// Don't public expose that `Connected` is `Clone`, unsure if we want to
// keep that contract...
#[cfg(feature = "http2")]
pub(super) fn clone(&self) -> Connected {
Connected {
alpn: self.alpn.clone(),
Expand Down Expand Up @@ -394,6 +503,7 @@ pub(super) mod sealed {
#[cfg(test)]
mod tests {
use super::Connected;
use crate::client::connect::CaptureConnection;

#[derive(Clone, Debug, PartialEq)]
struct Ex1(usize);
Expand Down Expand Up @@ -452,4 +562,72 @@ mod tests {
assert_eq!(ex2.get::<Ex1>(), Some(&Ex1(99)));
assert_eq!(ex2.get::<Ex2>(), Some(&Ex2("hiccup")));
}

#[test]
fn test_sync_capture_connection() {
let (tx, rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
tx.set(&Connected::new().proxy(true));
assert_eq!(
rx.connection_metadata()
.as_ref()
.expect("connected should be set")
.is_proxied(),
true
);

// ensure it can be called multiple times
assert_eq!(
rx.connection_metadata()
.as_ref()
.expect("connected should be set")
.is_proxied(),
true
);
}

#[tokio::test]
async fn async_capture_connection() {
let (tx, mut rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
let test_task = tokio::spawn(async move {
assert_eq!(
rx.wait_for_connection_metadata()
.await
.as_ref()
.expect("connection should be set")
.is_proxied(),
true
);
// can be awaited multiple times
assert!(
rx.wait_for_connection_metadata().await.is_some(),
"should be awaitable multiple times"
);

assert_eq!(rx.connection_metadata().is_some(), true);
});
// can't be finished, we haven't set the connection yet
assert_eq!(test_task.is_finished(), false);
tx.set(&Connected::new().proxy(true));

assert!(test_task.await.is_ok());
}

#[tokio::test]
async fn capture_connection_sender_side_dropped() {
let (tx, mut rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
drop(tx);
assert!(rx.wait_for_connection_metadata().await.is_none());
}
}
34 changes: 33 additions & 1 deletion tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1121,10 +1121,11 @@ mod dispatch_impl {
use http::Uri;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_test::block_on;

use super::support;
use hyper::body::HttpBody;
use hyper::client::connect::{Connected, Connection, HttpConnector};
use hyper::client::connect::{capture_connection, Connected, Connection, HttpConnector};
use hyper::Client;

#[test]
Expand Down Expand Up @@ -1533,6 +1534,37 @@ mod dispatch_impl {
assert_eq!(connects.load(Ordering::Relaxed), 0);
}

#[test]
fn capture_connection_on_client() {
let _ = pretty_env_logger::try_init();

let _rt = support::runtime();
let connector = DebugConnector::new();

let client = Client::builder().build(connector);

let server = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = server.local_addr().unwrap();
thread::spawn(move || {
let mut sock = server.accept().unwrap().0;
//drop(server);
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
sock.set_write_timeout(Some(Duration::from_secs(5)))
.unwrap();
let mut buf = [0; 4096];
sock.read(&mut buf).expect("read 1");
sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.expect("write 1");
});
let mut req = Request::builder()
.uri(&*format!("http://{}/a", addr))
.body(Body::empty())
.unwrap();
let captured_conn = capture_connection(&mut req);
block_on(client.request(req)).expect("200 OK");
assert!(captured_conn.connection_metadata().is_some());
}

#[test]
fn client_keep_alive_0() {
let _ = pretty_env_logger::try_init();
Expand Down

0 comments on commit c849339

Please sign in to comment.