From cb010c974f98eb044bcb087f8aeae9f2fcd36fa9 Mon Sep 17 00:00:00 2001 From: James Nugent Date: Mon, 7 Oct 2019 13:53:17 +0200 Subject: [PATCH] Rework TLS configuration to use a builder This commit reworks TLS configuration of both servers and endpoints in order to provide a more flexible API. We now add options to configure the selected TLS library using the appropriate 'native' configuration structures, as well as retaining the existing simplier interface which is compatible with both. The new API can also be easily extended to support simple interfaces for configuring mTLS and a range of other options without creating sprawl in the builders for `Server` and `Endpoint`. --- tonic-examples/src/tls/client.rs | 9 +- tonic-examples/src/tls/server.rs | 6 +- tonic-interop/src/bin/client.rs | 8 +- tonic-interop/src/bin/server.rs | 4 +- tonic/src/transport/endpoint.rs | 158 ++++++++++++++++++----------- tonic/src/transport/mod.rs | 17 +++- tonic/src/transport/server.rs | 133 ++++++++++++++++-------- tonic/src/transport/service/tls.rs | 74 +++++++++++--- tonic/src/transport/tls.rs | 15 +++ 9 files changed, 295 insertions(+), 129 deletions(-) diff --git a/tonic-examples/src/tls/client.rs b/tonic-examples/src/tls/client.rs index b8346763c..75a59c68f 100644 --- a/tonic-examples/src/tls/client.rs +++ b/tonic-examples/src/tls/client.rs @@ -3,19 +3,22 @@ pub mod pb { } use pb::{client::EchoClient, EchoRequest}; -use tonic::transport::{Certificate, Channel}; +use tonic::transport::{Certificate, Channel, ClientTlsConfig, TlsProvider}; #[tokio::main] async fn main() -> Result<(), Box> { let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?; let ca = Certificate::from_pem(pem); + let tls = ClientTlsConfig::new(TlsProvider::Rustls) + .ca_certificate(ca) + .domain_name(Some("example.com".into())); + let channel = Channel::from_static("http://[::1]:50051") - .rustls_tls(ca, Some("example.com".into())) + .tls_config(tls) .channel(); let mut client = EchoClient::new(channel); - let request = tonic::Request::new(EchoRequest { message: "hello".into(), }); diff --git a/tonic-examples/src/tls/server.rs b/tonic-examples/src/tls/server.rs index 5713a4fd0..38b046490 100644 --- a/tonic-examples/src/tls/server.rs +++ b/tonic-examples/src/tls/server.rs @@ -5,7 +5,7 @@ pub mod pb { use pb::{EchoRequest, EchoResponse}; use std::collections::VecDeque; use tonic::{ - transport::{Identity, Server}, + transport::{Identity, Server, ServerTlsConfig, TlsProvider}, Request, Response, Status, Streaming, }; @@ -58,8 +58,10 @@ async fn main() -> Result<(), Box> { let addr = "[::1]:50051".parse().unwrap(); let server = EchoServer::default(); + let tls = ServerTlsConfig::new(TlsProvider::Rustls).identity(identity); + Server::builder() - .rustls_tls(identity) + .tls_config(tls) .clone() .serve(addr, pb::server::EchoServer::new(server)) .await?; diff --git a/tonic-interop/src/bin/client.rs b/tonic-interop/src/bin/client.rs index 2012df61a..f5ac2ed00 100644 --- a/tonic-interop/src/bin/client.rs +++ b/tonic-interop/src/bin/client.rs @@ -1,6 +1,6 @@ use std::time::Duration; use structopt::{clap::arg_enum, StructOpt}; -use tonic::transport::{Certificate, Endpoint}; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, TlsProvider}; use tonic_interop::client; #[derive(StructOpt)] @@ -33,7 +33,11 @@ async fn main() -> Result<(), Box> { if matches.use_tls { let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?; let ca = Certificate::from_pem(pem); - endpoint.openssl_tls(ca, Some("foo.test.google.fr".into())); + + let tls = ClientTlsConfig::new(TlsProvider::OpenSsl) + .ca_certificate(ca) + .domain_name(Some("foo.test.google.fr".into())); + endpoint.tls_config(tls); } let channel = endpoint.channel(); diff --git a/tonic-interop/src/bin/server.rs b/tonic-interop/src/bin/server.rs index 7a65473ab..0bfe68b53 100644 --- a/tonic-interop/src/bin/server.rs +++ b/tonic-interop/src/bin/server.rs @@ -2,7 +2,7 @@ use http::header::HeaderName; use structopt::StructOpt; use tonic::body::BoxBody; use tonic::client::GrpcService; -use tonic::transport::{Identity, Server}; +use tonic::transport::{Identity, Server, ServerTlsConfig, TlsProvider}; use tonic_interop::{server, MergeTrailers}; #[derive(StructOpt)] @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box> { let key = tokio::fs::read("tonic-interop/data/server1.key").await?; let identity = Identity::from_pem(cert, key); - builder.openssl_tls(identity); + builder.tls_config(ServerTlsConfig::new(TlsProvider::OpenSsl).identity(identity)); } builder.interceptor_fn(|svc, req| { diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index ab11b9922..c237504b7 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -122,64 +122,6 @@ impl Endpoint { self } - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.openssl_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_openssl(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.rustls_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_rustls(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - /// Intercept outbound HTTP Request headers; pub fn intercept_headers(&mut self, f: F) -> &mut Self where @@ -189,6 +131,13 @@ impl Endpoint { self } + /// Configures TLS for the endpoint. + #[cfg(feature = "tls")] + pub fn tls_config(&mut self, tls_config: ClientTlsConfig) -> &mut Self { + self.tls = tls_config.tls_connector(self.uri.clone()); + self + } + /// Create a channel from this config. pub fn channel(&self) -> Channel { Channel::connect(self.clone()) @@ -252,3 +201,96 @@ impl fmt::Debug for Endpoint { f.debug_struct("Endpoint").finish() } } + +/// Configures TLS settings for endpoints. +#[cfg(feature = "tls")] +pub struct ClientTlsConfig { + provider: TlsProvider, + domain: Option, + cert: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +#[cfg(feature = "tls")] +impl fmt::Debug for ClientTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ClientTlsConfig { + /// Creates a new `ClientTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: crate::transport::TlsProvider) -> Self { + ClientTlsConfig { + provider, + domain: None, + cert: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the domain name against which to verify the server's TLS certificate. If set to `None` + /// (the default), the address specified + pub fn domain_name(mut self, domain_name: impl Into>) -> Self { + self.domain = domain_name.into(); + self + } + + /// Sets the CA Certificate against which to verify the server's TLS certificate. + pub fn ca_certificate(mut self, ca_certificate: Certificate) -> Self { + self.cert = Some(ca_certificate); + self + } + + /// Use options specified by the given `SslConnector` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, connector: openssl1::ssl::SslConnector) -> Self { + self.openssl_raw = Some(connector); + self + } + + /// Use options specified by the given `ClientConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ClientConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_connector(self, uri: Uri) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(connector) = self.openssl_raw { + return Some(TlsConnector::new_with_openssl_raw(connector).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_openssl_cert(self.cert, domain).unwrap()); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsConnector::new_with_rustls_raw(config).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_rustls_cert(self.cert, domain).unwrap()); + } + } + } + } +} diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index d963186e4..384fff5eb 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -19,7 +19,7 @@ //! ## Client //! //! ```no_run -//! # use tonic::transport::{Channel, Certificate}; +//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig, TlsProvider}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; //! # use tonic::client::GrpcService;; @@ -29,7 +29,9 @@ //! let cert = std::fs::read_to_string("ca.pem")?; //! //! let mut channel = Channel::from_static("https://example.com") -//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string()) +//! .tls_config(ClientTlsConfig::new(TlsProvider::Rustls) +//! .ca_certificate(Certificate::from_pem(&cert)) +//! .domain_name("example.com".to_string())) //! .timeout(Duration::from_secs(5)) //! .rate_limit(5, Duration::from_secs(1)) //! .concurrency_limit(256) @@ -43,7 +45,7 @@ //! ## Server //! //! ```no_run -//! # use tonic::transport::{Server, Identity}; +//! # use tonic::transport::{Server, Identity, ServerTlsConfig, TlsProvider}; //! # use tower::{Service, service_fn}; //! # use futures_util::future::{err, ok}; //! # #[cfg(feature = "rustls")] @@ -55,7 +57,8 @@ //! let addr = "[::1]:50051".parse()?; //! //! Server::builder() -//! .rustls_tls(Identity::from_pem(&cert, &key)) +//! .tls_config(ServerTlsConfig::new(TlsProvider::Rustls) +//! .identity(Identity::from_pem(&cert, &key))) //! .concurrency_limit_per_connection(256) //! .interceptor_fn(|svc, req| { //! println!("Request: {:?}", req); @@ -89,4 +92,10 @@ pub use self::server::Server; pub use self::tls::{Certificate, Identity}; pub use hyper::Body; +#[cfg(feature = "tls")] +pub use self::endpoint::ClientTlsConfig; +#[cfg(feature = "tls")] +pub use self::server::ServerTlsConfig; +pub use self::tls::TlsProvider; + pub(crate) use self::error::ErrorKind; diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 34cd32711..d0338ad17 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -60,49 +60,10 @@ impl Server { } impl Server { - /// Set the [`Identity`] of this server using `openssl`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.openssl_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_openssl(identity).unwrap(); - self.tls = Some(acceptor); - self - } - - /// Set the [`Identity`] of this server using `rustls`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.rustls_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_rustls(identity).unwrap(); - self.tls = Some(acceptor); + /// Configure TLS for this server. + #[cfg(feature = "tls")] + pub fn tls_config(&mut self, tls_config: ServerTlsConfig) -> &mut Self { + self.tls = tls_config.tls_acceptor(); self } @@ -247,6 +208,92 @@ impl fmt::Debug for Server { } } +/// Configures TLS settings for servers. +#[cfg(feature = "tls")] +pub struct ServerTlsConfig { + provider: TlsProvider, + identity: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +#[cfg(feature = "tls")] +impl fmt::Debug for ServerTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ServerTlsConfig { + /// Creates a new `ServerTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: crate::transport::TlsProvider) -> Self { + ServerTlsConfig { + provider, + identity: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the [`Identity`] of the server. + pub fn identity(mut self, identity: Identity) -> Self { + self.identity = Some(identity); + self + } + + /// Use options specified by the given `SslAcceptor` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, acceptor: openssl1::ssl::SslAcceptor) -> Self { + self.openssl_raw = Some(acceptor); + self + } + + /// Use options specified by the given `ServerConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ServerConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_acceptor(self) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(acceptor) = self.openssl_raw { + return Some(TlsAcceptor::new_with_openssl_raw(acceptor).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_openssl_identity(self.identity.unwrap()).unwrap(), + ); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsAcceptor::new_with_rustls_raw(config).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_rustls_identity(self.identity.unwrap()).unwrap(), + ); + } + } + } + } +} + #[derive(Debug)] struct TcpIncoming { inner: conn::AddrIncoming, diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index de45e3aa8..36bcf7e7a 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -56,41 +56,67 @@ enum Connector { impl TlsConnector { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl( - cert: Certificate, + pub(crate) fn new_with_openssl_cert( + cert: Option, domain: String, ) -> Result { let mut config = SslConnector::builder(SslMethod::tls())?; - config.set_alpn_protos(ALPN_H2_WIRE)?; - let ca = X509::from_pem(&cert.pem[..])?; - - config.cert_store_mut().add_cert(ca)?; + if cert.is_some() { + let cert = cert.unwrap(); + let ca = X509::from_pem(&cert.pem[..])?; + config.cert_store_mut().add_cert(ca)?; + } - let config = config.build(); + // let config = config.build(); Ok(Self { - inner: Connector::Openssl(config), + inner: Connector::Openssl(config.build()), domain: Arc::new(domain), }) } - #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(cert: Certificate, domain: String) -> Result { - let mut buf = std::io::Cursor::new(&cert.pem[..]); + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + ssl_connector: openssl1::ssl::SslConnector, + ) -> Result { + Ok(Self { + inner: Connector::Openssl(ssl_connector), + domain: Arc::new("".into()), + }) + } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_cert( + cert: Option, + domain: String, + ) -> Result { let mut config = ClientConfig::new(); - - config.root_store.add_pem_file(&mut buf).unwrap(); config.set_protocols(&[Vec::from(&ALPN_H2[..])]); + if cert.is_some() { + let cert = cert.unwrap(); + let mut buf = std::io::Cursor::new(&cert.pem[..]); + config.root_store.add_pem_file(&mut buf).unwrap(); + } + Ok(Self { inner: Connector::Rustls(Arc::new(config)), domain: Arc::new(domain), }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ClientConfig, + ) -> Result { + Ok(Self { + inner: Connector::Rustls(Arc::new(config)), + domain: Arc::new("".into()), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let tls_io = match &self.inner { #[cfg(feature = "openssl")] @@ -167,7 +193,7 @@ enum Acceptor { impl TlsAcceptor { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl(identity: Identity) -> Result { + pub(crate) fn new_with_openssl_identity(identity: Identity) -> Result { let key = PKey::private_key_from_pem(&identity.key[..])?; let cert = X509::from_pem(&identity.cert.pem[..])?; @@ -185,6 +211,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + acceptor: openssl1::ssl::SslAcceptor, + ) -> Result { + Ok(Self { + inner: Acceptor::Openssl(acceptor), + }) + } + #[cfg(feature = "rustls")] fn load_rustls_private_key( mut cursor: std::io::Cursor<&[u8]>, @@ -209,7 +244,7 @@ impl TlsAcceptor { } #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(identity: Identity) -> Result { + pub(crate) fn new_with_rustls_identity(identity: Identity) -> Result { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); match pemfile::certs(&mut cert) { @@ -238,6 +273,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ServerConfig, + ) -> Result { + Ok(Self { + inner: Acceptor::Rustls(Arc::new(config)), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let io = match &self.inner { #[cfg(feature = "openssl")] diff --git a/tonic/src/transport/tls.rs b/tonic/src/transport/tls.rs index f8fb9916d..ca6c5d952 100644 --- a/tonic/src/transport/tls.rs +++ b/tonic/src/transport/tls.rs @@ -1,3 +1,18 @@ +/// Selects a library to provide TLS. +#[derive(Debug)] +pub enum TlsProvider { + /// Do not enable TLS. To use OpenSSL or Rustls, enable the `openssl` or `rustls` features of + /// the `tonic` crate. + #[allow(dead_code)] + None, + /// Use OpenSSL for TLS. + #[cfg(feature = "openssl")] + OpenSsl, + /// Use OpenSSL for TLS. + #[cfg(feature = "rustls")] + Rustls, +} + /// Represents a X509 certificate. #[derive(Debug, Clone)] pub struct Certificate {