diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index e65606ac1..411b53ff3 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -74,6 +74,7 @@ tokio-rustls = { version = "=0.12.0-alpha.5", optional = true } rustls-native-certs = { version = "0.1", optional = true } [dev-dependencies] +hyper-unix-connector = "0.1.1" static_assertions = "1.0" rand = "0.7.2" criterion = "0.3" diff --git a/tonic/src/transport/channel.rs b/tonic/src/transport/channel.rs index 2d8eba0dd..28d8f2874 100644 --- a/tonic/src/transport/channel.rs +++ b/tonic/src/transport/channel.rs @@ -97,7 +97,16 @@ impl Channel { /// /// This creates a [`Channel`] that will load balance accross all the /// provided endpoints. - pub fn balance_list(list: impl Iterator) -> Self { + pub fn balance_list_with_connector( + list: impl Iterator, + connector: C, + ) -> Self + where + C: tower_make::MakeConnection + Send + Clone + Unpin + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into> + Send, + { let list = list.collect::>(); let buffer_size = list @@ -111,16 +120,44 @@ impl Channel { .next() .and_then(|e| e.interceptor_headers.clone()); - let discover = ServiceList::new(list); + let discover = ServiceList::new(list, connector); Self::balance(discover, buffer_size, interceptor_headers) } - pub(crate) async fn connect(endpoint: Endpoint) -> Result { + /// Balance a list of [`Endpoint`]'s. + /// + /// This creates a [`Channel`] that will load balance accross all the + /// provided endpoints. + pub fn balance_list(list: impl Iterator) -> Self { + // Backwards API compatibility. + // Uses TCP if the TLS feature is not enabled, and TLS otherwise. + + let list = list.collect::>(); + + #[cfg(feature = "tls")] + let connector = { + let tls_connector = list.iter().next().and_then(|e| e.tls.clone()); + super::service::connector(tls_connector) + }; + + #[cfg(not(feature = "tls"))] + let connector = super::service::connector(); + + Channel::balance_list_with_connector(list.into_iter(), connector) + } + + pub(crate) async fn connect(endpoint: Endpoint, connector: C) -> Result + where + C: tower_make::MakeConnection + Send + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into> + Send, + { let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); let interceptor_headers = endpoint.interceptor_headers.clone(); - let svc = Connection::new(endpoint) + let svc = Connection::new(endpoint, connector) .await .map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?; diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index 0a4f32dcc..4a08ab5dc 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -157,7 +157,54 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - Channel::connect(self.clone()).await + // Backwards API compatibility. + // Uses TCP if the TLS feature is not enabled, and TLS otherwise. + + #[cfg(feature = "tls")] + let connector = super::service::connector(self.tls.clone()); + + #[cfg(not(feature = "tls"))] + let connector = super::service::connector(); + + self.connect_with_connector(connector).await + } + + /// Create a channel using a custom connector. + /// + /// The [`tower_make::MakeConnection`] requirement is an alias for `tower::Service` - for example, a TCP stream as in [`Endpoint::connect`] above. + /// + /// # Example + /// ```rust + /// use hyper::client::connect::HttpConnector; + /// use tonic::transport::Endpoint; + /// + /// // note: This connector is the same as the default provided in `connect()`. + /// let mut connector = HttpConnector::new(); + /// connector.enforce_http(false); + /// connector.set_nodelay(true); + /// + /// let endpoint = Endpoint::from_static("http://example.com"); + /// endpoint.connect_with_connector(connector); //.await + /// ``` + /// + /// # Example with non-default Connector + /// ```rust + /// // Use for unix-domain sockets + /// use hyper_unix_connector::UnixClient; + /// use tonic::transport::Endpoint; + /// + /// let endpoint = Endpoint::from_static("http://example.com"); + /// endpoint.connect_with_connector(UnixClient); //.await + /// ``` + pub async fn connect_with_connector(&self, connector: C) -> Result + where + C: tower_make::MakeConnection + Send + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into> + Send, + { + Channel::connect(self.clone(), connector).await } } diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 47693b35e..663148ed0 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,4 +1,4 @@ -use super::{connector, layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin}; +use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin}; use crate::{body::BoxBody, transport::Endpoint}; use hyper::client::conn::Builder; use hyper::client::service::Connect as HyperConnect; @@ -26,19 +26,23 @@ pub(crate) struct Connection { } impl Connection { - pub(crate) async fn new(endpoint: Endpoint) -> Result { - #[cfg(feature = "tls")] - let connector = connector(endpoint.tls.clone()); - - #[cfg(not(feature = "tls"))] - let connector = connector(); - + pub(crate) async fn new(endpoint: Endpoint, connector: C) -> Result + where + C: tower_make::MakeConnection + Send + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into> + Send, + { let settings = Builder::new() .http2_initial_stream_window_size(endpoint.init_stream_window_size) .http2_initial_connection_window_size(endpoint.init_connection_window_size) .http2_only(true) .clone(); + let mut connector = HyperConnect::new(connector, settings); + let initial_conn = connector.call(endpoint.uri.clone()).await?; + let conn = Reconnect::new(initial_conn, connector, endpoint.uri.clone()); + let stack = ServiceBuilder::new() .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) .optional_layer(endpoint.timeout.map(TimeoutLayer::new)) @@ -46,10 +50,6 @@ impl Connection { .optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let mut connector = HyperConnect::new(connector, settings); - let initial_conn = connector.call(endpoint.uri.clone()).await?; - let conn = Reconnect::new(initial_conn, connector, endpoint.uri.clone()); - let inner = stack.layer(conn); Ok(Self { diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index c02d17367..f8a63b88a 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -22,6 +22,7 @@ pub(crate) fn connector(tls: Option) -> Connector { Connector::new(tls) } +#[derive(Clone)] pub(crate) struct Connector { http: HttpConnector, #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 4635d8617..38808324b 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -9,24 +9,32 @@ use std::{ }; use tower::discover::{Change, Discover}; -pub(crate) struct ServiceList { +pub(crate) struct ServiceList { list: VecDeque, + connector: C, connecting: Option> + Send + 'static>>>, i: usize, } -impl ServiceList { - pub(crate) fn new(list: Vec) -> Self { +impl ServiceList { + pub(crate) fn new(list: Vec, connector: C) -> Self { Self { list: list.into(), + connector, connecting: None, i: 0, } } } -impl Discover for ServiceList { +impl Discover for ServiceList +where + C: tower_make::MakeConnection + Send + Clone + Unpin + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + C::Error: Into> + Send, +{ type Key = usize; type Service = Connection; type Error = crate::Error; @@ -49,7 +57,8 @@ impl Discover for ServiceList { } if let Some(endpoint) = self.list.pop_front() { - let fut = Connection::new(endpoint); + let c = &self.connector; + let fut = Connection::new(endpoint, c.clone()); self.connecting = Some(Box::pin(fut)); } else { return Poll::Pending; @@ -58,7 +67,7 @@ impl Discover for ServiceList { } } -impl fmt::Debug for ServiceList { +impl fmt::Debug for ServiceList { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ServiceList") .field("list", &self.list)