Skip to content

Commit

Permalink
Rework TLS configuration to use a builder
Browse files Browse the repository at this point in the history
This commit reworks TLS configuration of both servers and endpoints in
order to provide a more flexible API. We now add options to configure
the selected TLS library using the appropriate 'native' configuration
structures, as well as retaining the existing simplier interface which
is compatible with both.

The new API can also be easily extended to support simple interfaces for
configuring mTLS and a range of other options without creating sprawl
in the builders for `Server` and `Endpoint`.
  • Loading branch information
jen20 committed Oct 7, 2019
1 parent 959791c commit 43f60fa
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 129 deletions.
9 changes: 6 additions & 3 deletions tonic-examples/src/tls/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@ pub mod pb {
}

use pb::{client::EchoClient, EchoRequest};
use tonic::transport::{Certificate, Channel};
use tonic::transport::{Certificate, Channel, ClientTlsConfig, TlsProvider};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?;
let ca = Certificate::from_pem(pem);

let tls = ClientTlsConfig::new(TlsProvider::Rustls)
.ca_certificate(ca)
.domain_name(Some("example.com".into()));

let channel = Channel::from_static("http://[::1]:50051")
.rustls_tls(ca, Some("example.com".into()))
.tls(tls)
.channel();

let mut client = EchoClient::new(channel);

let request = tonic::Request::new(EchoRequest {
message: "hello".into(),
});
Expand Down
6 changes: 4 additions & 2 deletions tonic-examples/src/tls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod pb {
use pb::{EchoRequest, EchoResponse};
use std::collections::VecDeque;
use tonic::{
transport::{Identity, Server},
transport::{Identity, Server, ServerTlsConfig, TlsProvider},
Request, Response, Status, Streaming,
};

Expand Down Expand Up @@ -58,8 +58,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let server = EchoServer::default();

let tls = ServerTlsConfig::new(TlsProvider::Rustls).identity(identity);

Server::builder()
.rustls_tls(identity)
.tls(tls)
.clone()
.serve(addr, pb::server::EchoServer::new(server))
.await?;
Expand Down
8 changes: 6 additions & 2 deletions tonic-interop/src/bin/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;
use structopt::{clap::arg_enum, StructOpt};
use tonic::transport::{Certificate, Endpoint};
use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, TlsProvider};
use tonic_interop::client;

#[derive(StructOpt)]
Expand Down Expand Up @@ -33,7 +33,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if matches.use_tls {
let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?;
let ca = Certificate::from_pem(pem);
endpoint.openssl_tls(ca, Some("foo.test.google.fr".into()));

let tls = ClientTlsConfig::new(TlsProvider::OpenSsl)
.ca_certificate(ca)
.domain_name(Some("foo.test.google.fr".into()));
endpoint.tls(tls);
}

let channel = endpoint.channel();
Expand Down
4 changes: 2 additions & 2 deletions tonic-interop/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use http::header::HeaderName;
use structopt::StructOpt;
use tonic::body::BoxBody;
use tonic::client::GrpcService;
use tonic::transport::{Identity, Server};
use tonic::transport::{Identity, Server, ServerTlsConfig, TlsProvider};
use tonic_interop::{server, MergeTrailers};

#[derive(StructOpt)]
Expand All @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let key = tokio::fs::read("tonic-interop/data/server1.key").await?;

let identity = Identity::from_pem(cert, key);
builder.openssl_tls(identity);
builder.tls(ServerTlsConfig::new(TlsProvider::OpenSsl).identity(identity));
}

builder.interceptor_fn(|svc, req| {
Expand Down
157 changes: 99 additions & 58 deletions tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::channel::Channel;
#[cfg(feature = "tls")]
use super::{service::TlsConnector, tls::Certificate};
use crate::transport::tls::TlsProvider;
use bytes::Bytes;
use http::uri::{InvalidUriBytes, Uri};
use std::{
Expand Down Expand Up @@ -122,64 +123,6 @@ impl Endpoint {
self
}

/// Enable TLS and apply the CA as the root certificate.
///
/// Providing an optional domain to override. If `None` is passed to this
/// the TLS implementation will use the `Uri` that was used to create the
/// `Endpoint` builder.
///
/// ```no_run
/// # use tonic::transport::{Certificate, Endpoint};
/// # fn dothing() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut builder = Endpoint::from_static("https://example.com");
/// let ca = std::fs::read_to_string("ca.pem")?;
///
/// let ca = Certificate::from_pem(ca);
///
/// builder.openssl_tls(ca, "example.com".to_string());
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into<Option<String>>) -> &mut Self {
let domain = domain
.into()
.unwrap_or_else(|| self.uri.clone().to_string());
let tls = TlsConnector::new_with_openssl(ca, domain).unwrap();
self.tls = Some(tls);
self
}

/// Enable TLS and apply the CA as the root certificate.
///
/// Providing an optional domain to override. If `None` is passed to this
/// the TLS implementation will use the `Uri` that was used to create the
/// `Endpoint` builder.
///
/// ```no_run
/// # use tonic::transport::{Certificate, Endpoint};
/// # fn dothing() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut builder = Endpoint::from_static("https://example.com");
/// let ca = std::fs::read_to_string("ca.pem")?;
///
/// let ca = Certificate::from_pem(ca);
///
/// builder.rustls_tls(ca, "example.com".to_string());
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into<Option<String>>) -> &mut Self {
let domain = domain
.into()
.unwrap_or_else(|| self.uri.clone().to_string());
let tls = TlsConnector::new_with_rustls(ca, domain).unwrap();
self.tls = Some(tls);
self
}

/// Intercept outbound HTTP Request headers;
pub fn intercept_headers<F>(&mut self, f: F) -> &mut Self
where
Expand All @@ -189,6 +132,12 @@ impl Endpoint {
self
}

/// Configures TLS for the endpoint.
pub fn tls(&mut self, tls_config: ClientTlsConfig) -> &mut Self {
self.tls = tls_config.tls_connector(self.uri.clone());
self
}

/// Create a channel from this config.
pub fn channel(&self) -> Channel {
Channel::connect(self.clone())
Expand Down Expand Up @@ -252,3 +201,95 @@ impl fmt::Debug for Endpoint {
f.debug_struct("Endpoint").finish()
}
}

/// Configures TLS settings for endpoints.
#[cfg(feature = "tls")]
pub struct ClientTlsConfig {
provider: TlsProvider,
domain: Option<String>,
cert: Option<Certificate>,
#[cfg(feature = "openssl")]
openssl_raw: Option<openssl1::ssl::SslConnector>,
#[cfg(feature = "rustls")]
rustls_raw: Option<tokio_rustls::rustls::ClientConfig>,
}

impl fmt::Debug for ClientTlsConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientTlsConfig")
.field("provider", &self.provider)
.finish()
}
}

#[cfg(feature = "tls")]
impl ClientTlsConfig {
/// Creates a new `ClientTlsConfig` backed by the specified provider. Enable the `openssl` or
/// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively.
pub fn new(provider: TlsProvider) -> Self {
ClientTlsConfig {
provider,
domain: None,
cert: None,
#[cfg(feature = "openssl")]
openssl_raw: None,
#[cfg(feature = "rustls")]
rustls_raw: None,
}
}

/// Sets the domain name against which to verify the server's TLS certificate. If set to `None`
/// (the default), the address specified
pub fn domain_name(mut self, domain_name: impl Into<Option<String>>) -> Self {
self.domain = domain_name.into();
self
}

/// Sets the CA Certificate against which to verify the server's TLS certificate.
pub fn ca_certificate(mut self, ca_certificate: Certificate) -> Self {
self.cert = Some(ca_certificate);
self
}

/// Use options specified by the given `SslConnector` to configure TLS.
///
/// This overrides all other TLS options set via other means.
#[cfg(feature = "openssl")]
pub fn openssl_connector(mut self, connector: openssl1::ssl::SslConnector) -> Self {
self.openssl_raw = Some(connector);
self
}

/// Use options specified by the given `ClientConfig` to configure TLS.
///
/// This overrides all other TLS options set via other means.
#[cfg(feature = "rustls")]
pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ClientConfig) -> Self {
self.rustls_raw = Some(config);
self
}

fn tls_connector(self, uri: Uri) -> Option<TlsConnector> {
match self.provider {
TlsProvider::None => None,
#[cfg(feature = "openssl")]
TlsProvider::OpenSsl => {
if let Some(connector) = self.openssl_raw {
return Some(TlsConnector::new_with_openssl_raw(connector).unwrap());
} else {
let domain = self.domain.unwrap_or_else(|| uri.to_string());
return Some(TlsConnector::new_with_openssl_cert(self.cert, domain).unwrap());
}
}
#[cfg(feature = "rustls")]
TlsProvider::Rustls => {
if let Some(config) = self.rustls_raw {
return Some(TlsConnector::new_with_rustls_raw(config).unwrap());
} else {
let domain = self.domain.unwrap_or_else(|| uri.to_string());
return Some(TlsConnector::new_with_rustls_cert(self.cert, domain).unwrap());
}
}
}
}
}
15 changes: 11 additions & 4 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! ## Client
//!
//! ```no_run
//! # use tonic::transport::{Channel, Certificate};
//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig, TlsProvider};
//! # use std::time::Duration;
//! # use tonic::body::BoxBody;
//! # use tonic::client::GrpcService;;
Expand All @@ -29,7 +29,9 @@
//! let cert = std::fs::read_to_string("ca.pem")?;
//!
//! let mut channel = Channel::from_static("https://example.com")
//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string())
//! .tls(ClientTlsConfig::new(TlsProvider::Rustls)
//! .ca_certificate(Certificate::from_pem(&cert))
//! .domain_name("example.com".to_string()))
//! .timeout(Duration::from_secs(5))
//! .rate_limit(5, Duration::from_secs(1))
//! .concurrency_limit(256)
Expand All @@ -43,7 +45,7 @@
//! ## Server
//!
//! ```no_run
//! # use tonic::transport::{Server, Identity};
//! # use tonic::transport::{Server, Identity, ServerTlsConfig, TlsProvider};
//! # use tower::{Service, service_fn};
//! # use futures_util::future::{err, ok};
//! # #[cfg(feature = "rustls")]
Expand All @@ -55,7 +57,8 @@
//! let addr = "[::1]:50051".parse()?;
//!
//! Server::builder()
//! .rustls_tls(Identity::from_pem(&cert, &key))
//! .tls(ServerTlsConfig::new(TlsProvider::Rustls)
//! .identity(Identity::from_pem(&cert, &key)))
//! .concurrency_limit_per_connection(256)
//! .interceptor_fn(|svc, req| {
//! println!("Request: {:?}", req);
Expand Down Expand Up @@ -89,4 +92,8 @@ pub use self::server::Server;
pub use self::tls::{Certificate, Identity};
pub use hyper::Body;

pub use self::endpoint::ClientTlsConfig;
pub use self::server::ServerTlsConfig;
pub use self::tls::TlsProvider;

pub(crate) use self::error::ErrorKind;
Loading

0 comments on commit 43f60fa

Please sign in to comment.