From f624e8dad2a8e705da45a428cc87567bb98e92fb Mon Sep 17 00:00:00 2001 From: James Nugent Date: Mon, 7 Oct 2019 13:53:17 +0200 Subject: [PATCH] Rework TLS configuration to use a builder This commit reworks TLS configuration of both servers and endpoints in order to provide a more flexible API. We now add options to configure the selected TLS library using the appropriate 'native' configuration structures, as well as retaining the existing simplier interface which is compatible with both. The new API can also be easily extended to support simple interfaces for configuring mTLS and a range of other options without creating sprawl in the builders for `Server` and `Endpoint`. --- tonic-examples/src/tls/client.rs | 9 +- tonic-examples/src/tls/server.rs | 6 +- tonic-interop/src/bin/client.rs | 8 +- tonic-interop/src/bin/server.rs | 4 +- tonic/src/transport/endpoint.rs | 157 ++++++++++++++++++----------- tonic/src/transport/mod.rs | 15 ++- tonic/src/transport/server.rs | 130 ++++++++++++++++-------- tonic/src/transport/service/tls.rs | 74 +++++++++++--- tonic/src/transport/tls.rs | 14 +++ 9 files changed, 288 insertions(+), 129 deletions(-) diff --git a/tonic-examples/src/tls/client.rs b/tonic-examples/src/tls/client.rs index b8346763c..305fe61b8 100644 --- a/tonic-examples/src/tls/client.rs +++ b/tonic-examples/src/tls/client.rs @@ -3,19 +3,22 @@ pub mod pb { } use pb::{client::EchoClient, EchoRequest}; -use tonic::transport::{Certificate, Channel}; +use tonic::transport::{Certificate, Channel, ClientTlsConfig, TlsProvider}; #[tokio::main] async fn main() -> Result<(), Box> { let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?; let ca = Certificate::from_pem(pem); + let tls = ClientTlsConfig::new(TlsProvider::Rustls) + .ca_certificate(ca) + .domain_name(Some("example.com".into())); + let channel = Channel::from_static("http://[::1]:50051") - .rustls_tls(ca, Some("example.com".into())) + .tls(tls) .channel(); let mut client = EchoClient::new(channel); - let request = tonic::Request::new(EchoRequest { message: "hello".into(), }); diff --git a/tonic-examples/src/tls/server.rs b/tonic-examples/src/tls/server.rs index 5713a4fd0..16496ba03 100644 --- a/tonic-examples/src/tls/server.rs +++ b/tonic-examples/src/tls/server.rs @@ -5,7 +5,7 @@ pub mod pb { use pb::{EchoRequest, EchoResponse}; use std::collections::VecDeque; use tonic::{ - transport::{Identity, Server}, + transport::{Identity, Server, ServerTlsConfig, TlsProvider}, Request, Response, Status, Streaming, }; @@ -58,8 +58,10 @@ async fn main() -> Result<(), Box> { let addr = "[::1]:50051".parse().unwrap(); let server = EchoServer::default(); + let tls = ServerTlsConfig::new(TlsProvider::Rustls).identity(identity); + Server::builder() - .rustls_tls(identity) + .tls(tls) .clone() .serve(addr, pb::server::EchoServer::new(server)) .await?; diff --git a/tonic-interop/src/bin/client.rs b/tonic-interop/src/bin/client.rs index 2012df61a..e0f0392a2 100644 --- a/tonic-interop/src/bin/client.rs +++ b/tonic-interop/src/bin/client.rs @@ -1,6 +1,6 @@ use std::time::Duration; use structopt::{clap::arg_enum, StructOpt}; -use tonic::transport::{Certificate, Endpoint}; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, TlsProvider}; use tonic_interop::client; #[derive(StructOpt)] @@ -33,7 +33,11 @@ async fn main() -> Result<(), Box> { if matches.use_tls { let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?; let ca = Certificate::from_pem(pem); - endpoint.openssl_tls(ca, Some("foo.test.google.fr".into())); + + let tls = ClientTlsConfig::new(TlsProvider::OpenSsl) + .ca_certificate(ca) + .domain_name(Some("foo.test.google.fr".into())); + endpoint.tls(tls); } let channel = endpoint.channel(); diff --git a/tonic-interop/src/bin/server.rs b/tonic-interop/src/bin/server.rs index 7a65473ab..cf8ba8ea0 100644 --- a/tonic-interop/src/bin/server.rs +++ b/tonic-interop/src/bin/server.rs @@ -2,7 +2,7 @@ use http::header::HeaderName; use structopt::StructOpt; use tonic::body::BoxBody; use tonic::client::GrpcService; -use tonic::transport::{Identity, Server}; +use tonic::transport::{Identity, Server, ServerTlsConfig, TlsProvider}; use tonic_interop::{server, MergeTrailers}; #[derive(StructOpt)] @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box> { let key = tokio::fs::read("tonic-interop/data/server1.key").await?; let identity = Identity::from_pem(cert, key); - builder.openssl_tls(identity); + builder.tls(ServerTlsConfig::new(TlsProvider::OpenSsl).identity(identity)); } builder.interceptor_fn(|svc, req| { diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index ab11b9922..7d9230238 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -1,6 +1,7 @@ use super::channel::Channel; #[cfg(feature = "tls")] use super::{service::TlsConnector, tls::Certificate}; +use crate::transport::tls::TlsProvider; use bytes::Bytes; use http::uri::{InvalidUriBytes, Uri}; use std::{ @@ -122,64 +123,6 @@ impl Endpoint { self } - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.openssl_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_openssl(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.rustls_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_rustls(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - /// Intercept outbound HTTP Request headers; pub fn intercept_headers(&mut self, f: F) -> &mut Self where @@ -189,6 +132,12 @@ impl Endpoint { self } + /// Configures TLS for the endpoint. + pub fn tls(&mut self, tls_config: ClientTlsConfig) -> &mut Self { + self.tls = tls_config.tls_connector(self.uri.clone()); + self + } + /// Create a channel from this config. pub fn channel(&self) -> Channel { Channel::connect(self.clone()) @@ -252,3 +201,95 @@ impl fmt::Debug for Endpoint { f.debug_struct("Endpoint").finish() } } + +/// Configures TLS settings for endpoints. +#[cfg(feature = "tls")] +pub struct ClientTlsConfig { + provider: TlsProvider, + domain: Option, + cert: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +impl fmt::Debug for ClientTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ClientTlsConfig { + /// Creates a new `ClientTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: TlsProvider) -> Self { + ClientTlsConfig { + provider, + domain: None, + cert: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the domain name against which to verify the server's TLS certificate. If set to `None` + /// (the default), the address specified + pub fn domain_name(mut self, domain_name: impl Into>) -> Self { + self.domain = domain_name.into(); + self + } + + /// Sets the CA Certificate against which to verify the server's TLS certificate. + pub fn ca_certificate(mut self, ca_certificate: Certificate) -> Self { + self.cert = Some(ca_certificate); + self + } + + /// Use options specified by the given `SslConnector` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, connector: openssl1::ssl::SslConnector) -> Self { + self.openssl_raw = Some(connector); + self + } + + /// Use options specified by the given `ClientConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ClientConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_connector(self, uri: Uri) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(connector) = self.openssl_raw { + return Some(TlsConnector::new_with_openssl_raw(connector).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_openssl_cert(self.cert, domain).unwrap()); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsConnector::new_with_rustls_raw(config).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_rustls_cert(self.cert, domain).unwrap()); + } + } + } + } +} diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index d963186e4..68e41e388 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -19,7 +19,7 @@ //! ## Client //! //! ```no_run -//! # use tonic::transport::{Channel, Certificate}; +//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig, TlsProvider}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; //! # use tonic::client::GrpcService;; @@ -29,7 +29,9 @@ //! let cert = std::fs::read_to_string("ca.pem")?; //! //! let mut channel = Channel::from_static("https://example.com") -//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string()) +//! .tls(ClientTlsConfig::new(TlsProvider::Rustls) +//! .ca_certificate(Certificate::from_pem(&cert)) +//! .domain_name("example.com".to_string())) //! .timeout(Duration::from_secs(5)) //! .rate_limit(5, Duration::from_secs(1)) //! .concurrency_limit(256) @@ -43,7 +45,7 @@ //! ## Server //! //! ```no_run -//! # use tonic::transport::{Server, Identity}; +//! # use tonic::transport::{Server, Identity, ServerTlsConfig, TlsProvider}; //! # use tower::{Service, service_fn}; //! # use futures_util::future::{err, ok}; //! # #[cfg(feature = "rustls")] @@ -55,7 +57,8 @@ //! let addr = "[::1]:50051".parse()?; //! //! Server::builder() -//! .rustls_tls(Identity::from_pem(&cert, &key)) +//! .tls(ServerTlsConfig::new(TlsProvider::Rustls) +//! .identity(Identity::from_pem(&cert, &key))) //! .concurrency_limit_per_connection(256) //! .interceptor_fn(|svc, req| { //! println!("Request: {:?}", req); @@ -89,4 +92,8 @@ pub use self::server::Server; pub use self::tls::{Certificate, Identity}; pub use hyper::Body; +pub use self::endpoint::ClientTlsConfig; +pub use self::server::ServerTlsConfig; +pub use self::tls::TlsProvider; + pub(crate) use self::error::ErrorKind; diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 34cd32711..987a19b65 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -4,6 +4,7 @@ use super::service::{layer_fn, BoxedIo, ServiceBuilderExt}; #[cfg(feature = "tls")] use super::{service::TlsAcceptor, tls::Identity}; use crate::body::BoxBody; +use crate::transport::TlsProvider; use futures_core::Stream; use futures_util::{ready, try_future::MapErr, TryFutureExt, TryStreamExt}; use http::{Request, Response}; @@ -60,49 +61,9 @@ impl Server { } impl Server { - /// Set the [`Identity`] of this server using `openssl`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.openssl_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_openssl(identity).unwrap(); - self.tls = Some(acceptor); - self - } - - /// Set the [`Identity`] of this server using `rustls`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.rustls_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_rustls(identity).unwrap(); - self.tls = Some(acceptor); + /// Configure TLS for this server. + pub fn tls(&mut self, tls_config: ServerTlsConfig) -> &mut Self { + self.tls = tls_config.tls_acceptor(); self } @@ -247,6 +208,89 @@ impl fmt::Debug for Server { } } +/// Configures TLS settings for servers. +#[cfg(feature = "tls")] +pub struct ServerTlsConfig { + provider: TlsProvider, + identity: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +impl fmt::Debug for ServerTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ServerTlsConfig { + /// Creates a new `ServerTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: TlsProvider) -> Self { + ServerTlsConfig { + provider, + identity: None, + openssl_raw: None, + rustls_raw: None, + } + } + + /// Sets the [`Identity`] of the server. + pub fn identity(mut self, identity: Identity) -> Self { + self.identity = Some(identity); + self + } + + /// Use options specified by the given `SslAcceptor` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, acceptor: openssl1::ssl::SslAcceptor) -> Self { + self.openssl_raw = Some(acceptor); + self + } + + /// Use options specified by the given `ServerConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ServerConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_acceptor(self) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(acceptor) = self.openssl_raw { + return Some(TlsAcceptor::new_with_openssl_raw(acceptor).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_openssl_identity(self.identity.unwrap()).unwrap(), + ); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsAcceptor::new_with_rustls_raw(config).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_rustls_identity(self.identity.unwrap()).unwrap(), + ); + } + } + } + } +} + #[derive(Debug)] struct TcpIncoming { inner: conn::AddrIncoming, diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index de45e3aa8..36bcf7e7a 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -56,41 +56,67 @@ enum Connector { impl TlsConnector { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl( - cert: Certificate, + pub(crate) fn new_with_openssl_cert( + cert: Option, domain: String, ) -> Result { let mut config = SslConnector::builder(SslMethod::tls())?; - config.set_alpn_protos(ALPN_H2_WIRE)?; - let ca = X509::from_pem(&cert.pem[..])?; - - config.cert_store_mut().add_cert(ca)?; + if cert.is_some() { + let cert = cert.unwrap(); + let ca = X509::from_pem(&cert.pem[..])?; + config.cert_store_mut().add_cert(ca)?; + } - let config = config.build(); + // let config = config.build(); Ok(Self { - inner: Connector::Openssl(config), + inner: Connector::Openssl(config.build()), domain: Arc::new(domain), }) } - #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(cert: Certificate, domain: String) -> Result { - let mut buf = std::io::Cursor::new(&cert.pem[..]); + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + ssl_connector: openssl1::ssl::SslConnector, + ) -> Result { + Ok(Self { + inner: Connector::Openssl(ssl_connector), + domain: Arc::new("".into()), + }) + } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_cert( + cert: Option, + domain: String, + ) -> Result { let mut config = ClientConfig::new(); - - config.root_store.add_pem_file(&mut buf).unwrap(); config.set_protocols(&[Vec::from(&ALPN_H2[..])]); + if cert.is_some() { + let cert = cert.unwrap(); + let mut buf = std::io::Cursor::new(&cert.pem[..]); + config.root_store.add_pem_file(&mut buf).unwrap(); + } + Ok(Self { inner: Connector::Rustls(Arc::new(config)), domain: Arc::new(domain), }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ClientConfig, + ) -> Result { + Ok(Self { + inner: Connector::Rustls(Arc::new(config)), + domain: Arc::new("".into()), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let tls_io = match &self.inner { #[cfg(feature = "openssl")] @@ -167,7 +193,7 @@ enum Acceptor { impl TlsAcceptor { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl(identity: Identity) -> Result { + pub(crate) fn new_with_openssl_identity(identity: Identity) -> Result { let key = PKey::private_key_from_pem(&identity.key[..])?; let cert = X509::from_pem(&identity.cert.pem[..])?; @@ -185,6 +211,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + acceptor: openssl1::ssl::SslAcceptor, + ) -> Result { + Ok(Self { + inner: Acceptor::Openssl(acceptor), + }) + } + #[cfg(feature = "rustls")] fn load_rustls_private_key( mut cursor: std::io::Cursor<&[u8]>, @@ -209,7 +244,7 @@ impl TlsAcceptor { } #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(identity: Identity) -> Result { + pub(crate) fn new_with_rustls_identity(identity: Identity) -> Result { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); match pemfile::certs(&mut cert) { @@ -238,6 +273,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ServerConfig, + ) -> Result { + Ok(Self { + inner: Acceptor::Rustls(Arc::new(config)), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let io = match &self.inner { #[cfg(feature = "openssl")] diff --git a/tonic/src/transport/tls.rs b/tonic/src/transport/tls.rs index f8fb9916d..8fbc012e7 100644 --- a/tonic/src/transport/tls.rs +++ b/tonic/src/transport/tls.rs @@ -1,3 +1,17 @@ +/// Selects a library to provide TLS. +#[derive(Debug)] +pub enum TlsProvider { + /// Do not enable TLS. To use OpenSSL or Rustls, enable the `openssl` or `rustls` features of + /// the `tonic` crate. + None, + /// Use OpenSSL for TLS. + #[cfg(feature = "openssl")] + OpenSsl, + /// Use OpenSSL for TLS. + #[cfg(feature = "rustls")] + Rustls, +} + /// Represents a X509 certificate. #[derive(Debug, Clone)] pub struct Certificate {