Skip to content

Commit

Permalink
feat: make Endpoint connector swappable
Browse files Browse the repository at this point in the history
Change the `connect()` API to lift up the use of `connector()` to
`Endpoint::connect()`, allowing users to provide their own
implementations (for example, Unix-domain sockets).
Any type which impls `tower_make::MakeConnection` is
suitable.

To avoid breaking the default case of HTTP(S), introduce
`connect_with_connector()` and retain `connect()`, which creates the
default connector according to the activated feature gate and passes it
to `connect_with_connector()`.

Fixes: hyperium#136
  • Loading branch information
akshayknarayan committed Dec 3, 2019
1 parent 4471a5f commit 36757ba
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 23 deletions.
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tokio-rustls = { version = "=0.12.0-alpha.5", optional = true }
rustls-native-certs = { version = "0.1", optional = true }

[dev-dependencies]
hyper-unix-connector = "0.1.1"
static_assertions = "1.0"
rand = "0.7.2"
criterion = "0.3"
Expand Down
45 changes: 41 additions & 4 deletions tonic/src/transport/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,16 @@ impl Channel {
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
pub fn balance_list_with_connector<C>(
list: impl Iterator<Item = Endpoint>,
connector: C,
) -> Self
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + Unpin + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
let list = list.collect::<Vec<_>>();

let buffer_size = list
Expand All @@ -111,16 +120,44 @@ impl Channel {
.next()
.and_then(|e| e.interceptor_headers.clone());

let discover = ServiceList::new(list);
let discover = ServiceList::new(list, connector);

Self::balance(discover, buffer_size, interceptor_headers)
}

pub(crate) async fn connect(endpoint: Endpoint) -> Result<Self, super::Error> {
/// Balance a list of [`Endpoint`]'s.
///
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

let list = list.collect::<Vec<_>>();

#[cfg(feature = "tls")]
let connector = {
let tls_connector = list.iter().next().and_then(|e| e.tls.clone());
super::service::connector(tls_connector)
};

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

Channel::balance_list_with_connector(list.into_iter(), connector)
}

pub(crate) async fn connect<C>(endpoint: Endpoint, connector: C) -> Result<Self, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE);
let interceptor_headers = endpoint.interceptor_headers.clone();

let svc = Connection::new(endpoint)
let svc = Connection::new(endpoint, connector)
.await
.map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?;

Expand Down
49 changes: 48 additions & 1 deletion tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,54 @@ impl Endpoint {

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, super::Error> {
Channel::connect(self.clone()).await
// Backwards API compatibility.
// Uses TCP if the TLS feature is not enabled, and TLS otherwise.

#[cfg(feature = "tls")]
let connector = super::service::connector(self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = super::service::connector();

self.connect_with_connector(connector).await
}

/// Create a channel using a custom connector.
///
/// The [`tower_make::MakeConnection`] requirement is an alias for `tower::Service<Uri, Response = AsyncRead +
/// Async Write>` - for example, a TCP stream as in [`Endpoint::connect`] above.
///
/// # Example
/// ```rust
/// use hyper::client::connect::HttpConnector;
/// use tonic::transport::Endpoint;
///
/// // note: This connector is the same as the default provided in `connect()`.
/// let mut connector = HttpConnector::new();
/// connector.enforce_http(false);
/// connector.set_nodelay(true);
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(connector); //.await
/// ```
///
/// # Example with non-default Connector
/// ```rust
/// // Use for unix-domain sockets
/// use hyper_unix_connector::UnixClient;
/// use tonic::transport::Endpoint;
///
/// let endpoint = Endpoint::from_static("http://example.com");
/// endpoint.connect_with_connector(UnixClient); //.await
/// ```
pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, super::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
Channel::connect(self.clone(), connector).await
}
}

Expand Down
24 changes: 12 additions & 12 deletions tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{connector, layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
use crate::{body::BoxBody, transport::Endpoint};
use hyper::client::conn::Builder;
use hyper::client::service::Connect as HyperConnect;
Expand Down Expand Up @@ -26,30 +26,30 @@ pub(crate) struct Connection {
}

impl Connection {
pub(crate) async fn new(endpoint: Endpoint) -> Result<Self, crate::Error> {
#[cfg(feature = "tls")]
let connector = connector(endpoint.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = connector();

pub(crate) async fn new<C>(endpoint: Endpoint, connector: C) -> Result<Self, crate::Error>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
let settings = Builder::new()
.http2_initial_stream_window_size(endpoint.init_stream_window_size)
.http2_initial_connection_window_size(endpoint.init_connection_window_size)
.http2_only(true)
.clone();

let mut connector = HyperConnect::new(connector, settings);
let initial_conn = connector.call(endpoint.uri.clone()).await?;
let conn = Reconnect::new(initial_conn, connector, endpoint.uri.clone());

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.optional_layer(endpoint.timeout.map(TimeoutLayer::new))
.optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
.into_inner();

let mut connector = HyperConnect::new(connector, settings);
let initial_conn = connector.call(endpoint.uri.clone()).await?;
let conn = Reconnect::new(initial_conn, connector, endpoint.uri.clone());

let inner = stack.layer(conn);

Ok(Self {
Expand Down
1 change: 1 addition & 0 deletions tonic/src/transport/service/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub(crate) fn connector(tls: Option<TlsConnector>) -> Connector {
Connector::new(tls)
}

#[derive(Clone)]
pub(crate) struct Connector {
http: HttpConnector,
#[cfg(feature = "tls")]
Expand Down
21 changes: 15 additions & 6 deletions tonic/src/transport/service/discover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,32 @@ use std::{
};
use tower::discover::{Change, Discover};

pub(crate) struct ServiceList {
pub(crate) struct ServiceList<C> {
list: VecDeque<Endpoint>,
connector: C,
connecting:
Option<Pin<Box<dyn Future<Output = Result<Connection, crate::Error>> + Send + 'static>>>,
i: usize,
}

impl ServiceList {
pub(crate) fn new(list: Vec<Endpoint>) -> Self {
impl<C> ServiceList<C> {
pub(crate) fn new(list: Vec<Endpoint>, connector: C) -> Self {
Self {
list: list.into(),
connector,
connecting: None,
i: 0,
}
}
}

impl Discover for ServiceList {
impl<C> Discover for ServiceList<C>
where
C: tower_make::MakeConnection<hyper::Uri> + Send + Clone + Unpin + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
type Key = usize;
type Service = Connection;
type Error = crate::Error;
Expand All @@ -49,7 +57,8 @@ impl Discover for ServiceList {
}

if let Some(endpoint) = self.list.pop_front() {
let fut = Connection::new(endpoint);
let c = &self.connector;
let fut = Connection::new(endpoint, c.clone());
self.connecting = Some(Box::pin(fut));
} else {
return Poll::Pending;
Expand All @@ -58,7 +67,7 @@ impl Discover for ServiceList {
}
}

impl fmt::Debug for ServiceList {
impl<C> fmt::Debug for ServiceList<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServiceList")
.field("list", &self.list)
Expand Down

0 comments on commit 36757ba

Please sign in to comment.