diff --git a/tonic-examples/src/tls/client.rs b/tonic-examples/src/tls/client.rs index b8346763c..305fe61b8 100644 --- a/tonic-examples/src/tls/client.rs +++ b/tonic-examples/src/tls/client.rs @@ -3,19 +3,22 @@ pub mod pb { } use pb::{client::EchoClient, EchoRequest}; -use tonic::transport::{Certificate, Channel}; +use tonic::transport::{Certificate, Channel, ClientTlsConfig, TlsProvider}; #[tokio::main] async fn main() -> Result<(), Box> { let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?; let ca = Certificate::from_pem(pem); + let tls = ClientTlsConfig::new(TlsProvider::Rustls) + .ca_certificate(ca) + .domain_name(Some("example.com".into())); + let channel = Channel::from_static("http://[::1]:50051") - .rustls_tls(ca, Some("example.com".into())) + .tls(tls) .channel(); let mut client = EchoClient::new(channel); - let request = tonic::Request::new(EchoRequest { message: "hello".into(), }); diff --git a/tonic-examples/src/tls/server.rs b/tonic-examples/src/tls/server.rs index 5713a4fd0..16496ba03 100644 --- a/tonic-examples/src/tls/server.rs +++ b/tonic-examples/src/tls/server.rs @@ -5,7 +5,7 @@ pub mod pb { use pb::{EchoRequest, EchoResponse}; use std::collections::VecDeque; use tonic::{ - transport::{Identity, Server}, + transport::{Identity, Server, ServerTlsConfig, TlsProvider}, Request, Response, Status, Streaming, }; @@ -58,8 +58,10 @@ async fn main() -> Result<(), Box> { let addr = "[::1]:50051".parse().unwrap(); let server = EchoServer::default(); + let tls = ServerTlsConfig::new(TlsProvider::Rustls).identity(identity); + Server::builder() - .rustls_tls(identity) + .tls(tls) .clone() .serve(addr, pb::server::EchoServer::new(server)) .await?; diff --git a/tonic-interop/src/bin/client.rs b/tonic-interop/src/bin/client.rs index 2012df61a..e0f0392a2 100644 --- a/tonic-interop/src/bin/client.rs +++ b/tonic-interop/src/bin/client.rs @@ -1,6 +1,6 @@ use std::time::Duration; use structopt::{clap::arg_enum, StructOpt}; -use tonic::transport::{Certificate, Endpoint}; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, TlsProvider}; use tonic_interop::client; #[derive(StructOpt)] @@ -33,7 +33,11 @@ async fn main() -> Result<(), Box> { if matches.use_tls { let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?; let ca = Certificate::from_pem(pem); - endpoint.openssl_tls(ca, Some("foo.test.google.fr".into())); + + let tls = ClientTlsConfig::new(TlsProvider::OpenSsl) + .ca_certificate(ca) + .domain_name(Some("foo.test.google.fr".into())); + endpoint.tls(tls); } let channel = endpoint.channel(); diff --git a/tonic-interop/src/bin/server.rs b/tonic-interop/src/bin/server.rs index 7a65473ab..cf8ba8ea0 100644 --- a/tonic-interop/src/bin/server.rs +++ b/tonic-interop/src/bin/server.rs @@ -2,7 +2,7 @@ use http::header::HeaderName; use structopt::StructOpt; use tonic::body::BoxBody; use tonic::client::GrpcService; -use tonic::transport::{Identity, Server}; +use tonic::transport::{Identity, Server, ServerTlsConfig, TlsProvider}; use tonic_interop::{server, MergeTrailers}; #[derive(StructOpt)] @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box> { let key = tokio::fs::read("tonic-interop/data/server1.key").await?; let identity = Identity::from_pem(cert, key); - builder.openssl_tls(identity); + builder.tls(ServerTlsConfig::new(TlsProvider::OpenSsl).identity(identity)); } builder.interceptor_fn(|svc, req| { diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index ab11b9922..7d9230238 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -1,6 +1,7 @@ use super::channel::Channel; #[cfg(feature = "tls")] use super::{service::TlsConnector, tls::Certificate}; +use crate::transport::tls::TlsProvider; use bytes::Bytes; use http::uri::{InvalidUriBytes, Uri}; use std::{ @@ -122,64 +123,6 @@ impl Endpoint { self } - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.openssl_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_openssl(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.rustls_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_rustls(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - /// Intercept outbound HTTP Request headers; pub fn intercept_headers(&mut self, f: F) -> &mut Self where @@ -189,6 +132,12 @@ impl Endpoint { self } + /// Configures TLS for the endpoint. + pub fn tls(&mut self, tls_config: ClientTlsConfig) -> &mut Self { + self.tls = tls_config.tls_connector(self.uri.clone()); + self + } + /// Create a channel from this config. pub fn channel(&self) -> Channel { Channel::connect(self.clone()) @@ -252,3 +201,95 @@ impl fmt::Debug for Endpoint { f.debug_struct("Endpoint").finish() } } + +/// Configures TLS settings for endpoints. +#[cfg(feature = "tls")] +pub struct ClientTlsConfig { + provider: TlsProvider, + domain: Option, + cert: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +impl fmt::Debug for ClientTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ClientTlsConfig { + /// Creates a new `ClientTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: TlsProvider) -> Self { + ClientTlsConfig { + provider, + domain: None, + cert: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the domain name against which to verify the server's TLS certificate. If set to `None` + /// (the default), the address specified + pub fn domain_name(mut self, domain_name: impl Into>) -> Self { + self.domain = domain_name.into(); + self + } + + /// Sets the CA Certificate against which to verify the server's TLS certificate. + pub fn ca_certificate(mut self, ca_certificate: Certificate) -> Self { + self.cert = Some(ca_certificate); + self + } + + /// Use options specified by the given `SslConnector` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, connector: openssl1::ssl::SslConnector) -> Self { + self.openssl_raw = Some(connector); + self + } + + /// Use options specified by the given `ClientConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ClientConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_connector(self, uri: Uri) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(connector) = self.openssl_raw { + return Some(TlsConnector::new_with_openssl_raw(connector).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_openssl_cert(self.cert, domain).unwrap()); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsConnector::new_with_rustls_raw(config).unwrap()); + } else { + let domain = self.domain.unwrap_or_else(|| uri.to_string()); + return Some(TlsConnector::new_with_rustls_cert(self.cert, domain).unwrap()); + } + } + } + } +} diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index d963186e4..68e41e388 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -19,7 +19,7 @@ //! ## Client //! //! ```no_run -//! # use tonic::transport::{Channel, Certificate}; +//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig, TlsProvider}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; //! # use tonic::client::GrpcService;; @@ -29,7 +29,9 @@ //! let cert = std::fs::read_to_string("ca.pem")?; //! //! let mut channel = Channel::from_static("https://example.com") -//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string()) +//! .tls(ClientTlsConfig::new(TlsProvider::Rustls) +//! .ca_certificate(Certificate::from_pem(&cert)) +//! .domain_name("example.com".to_string())) //! .timeout(Duration::from_secs(5)) //! .rate_limit(5, Duration::from_secs(1)) //! .concurrency_limit(256) @@ -43,7 +45,7 @@ //! ## Server //! //! ```no_run -//! # use tonic::transport::{Server, Identity}; +//! # use tonic::transport::{Server, Identity, ServerTlsConfig, TlsProvider}; //! # use tower::{Service, service_fn}; //! # use futures_util::future::{err, ok}; //! # #[cfg(feature = "rustls")] @@ -55,7 +57,8 @@ //! let addr = "[::1]:50051".parse()?; //! //! Server::builder() -//! .rustls_tls(Identity::from_pem(&cert, &key)) +//! .tls(ServerTlsConfig::new(TlsProvider::Rustls) +//! .identity(Identity::from_pem(&cert, &key))) //! .concurrency_limit_per_connection(256) //! .interceptor_fn(|svc, req| { //! println!("Request: {:?}", req); @@ -89,4 +92,8 @@ pub use self::server::Server; pub use self::tls::{Certificate, Identity}; pub use hyper::Body; +pub use self::endpoint::ClientTlsConfig; +pub use self::server::ServerTlsConfig; +pub use self::tls::TlsProvider; + pub(crate) use self::error::ErrorKind; diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 34cd32711..987a19b65 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -4,6 +4,7 @@ use super::service::{layer_fn, BoxedIo, ServiceBuilderExt}; #[cfg(feature = "tls")] use super::{service::TlsAcceptor, tls::Identity}; use crate::body::BoxBody; +use crate::transport::TlsProvider; use futures_core::Stream; use futures_util::{ready, try_future::MapErr, TryFutureExt, TryStreamExt}; use http::{Request, Response}; @@ -60,49 +61,9 @@ impl Server { } impl Server { - /// Set the [`Identity`] of this server using `openssl`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.openssl_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_openssl(identity).unwrap(); - self.tls = Some(acceptor); - self - } - - /// Set the [`Identity`] of this server using `rustls`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.rustls_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_rustls(identity).unwrap(); - self.tls = Some(acceptor); + /// Configure TLS for this server. + pub fn tls(&mut self, tls_config: ServerTlsConfig) -> &mut Self { + self.tls = tls_config.tls_acceptor(); self } @@ -247,6 +208,89 @@ impl fmt::Debug for Server { } } +/// Configures TLS settings for servers. +#[cfg(feature = "tls")] +pub struct ServerTlsConfig { + provider: TlsProvider, + identity: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +impl fmt::Debug for ServerTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ServerTlsConfig { + /// Creates a new `ServerTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + pub fn new(provider: TlsProvider) -> Self { + ServerTlsConfig { + provider, + identity: None, + openssl_raw: None, + rustls_raw: None, + } + } + + /// Sets the [`Identity`] of the server. + pub fn identity(mut self, identity: Identity) -> Self { + self.identity = Some(identity); + self + } + + /// Use options specified by the given `SslAcceptor` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(mut self, acceptor: openssl1::ssl::SslAcceptor) -> Self { + self.openssl_raw = Some(acceptor); + self + } + + /// Use options specified by the given `ServerConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config(mut self, config: tokio_rustls::rustls::ServerConfig) -> Self { + self.rustls_raw = Some(config); + self + } + + fn tls_acceptor(self) -> Option { + match self.provider { + TlsProvider::None => None, + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => { + if let Some(acceptor) = self.openssl_raw { + return Some(TlsAcceptor::new_with_openssl_raw(acceptor).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_openssl_identity(self.identity.unwrap()).unwrap(), + ); + } + } + #[cfg(feature = "rustls")] + TlsProvider::Rustls => { + if let Some(config) = self.rustls_raw { + return Some(TlsAcceptor::new_with_rustls_raw(config).unwrap()); + } else { + return Some( + TlsAcceptor::new_with_rustls_identity(self.identity.unwrap()).unwrap(), + ); + } + } + } + } +} + #[derive(Debug)] struct TcpIncoming { inner: conn::AddrIncoming, diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index de45e3aa8..36bcf7e7a 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -56,41 +56,67 @@ enum Connector { impl TlsConnector { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl( - cert: Certificate, + pub(crate) fn new_with_openssl_cert( + cert: Option, domain: String, ) -> Result { let mut config = SslConnector::builder(SslMethod::tls())?; - config.set_alpn_protos(ALPN_H2_WIRE)?; - let ca = X509::from_pem(&cert.pem[..])?; - - config.cert_store_mut().add_cert(ca)?; + if cert.is_some() { + let cert = cert.unwrap(); + let ca = X509::from_pem(&cert.pem[..])?; + config.cert_store_mut().add_cert(ca)?; + } - let config = config.build(); + // let config = config.build(); Ok(Self { - inner: Connector::Openssl(config), + inner: Connector::Openssl(config.build()), domain: Arc::new(domain), }) } - #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(cert: Certificate, domain: String) -> Result { - let mut buf = std::io::Cursor::new(&cert.pem[..]); + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + ssl_connector: openssl1::ssl::SslConnector, + ) -> Result { + Ok(Self { + inner: Connector::Openssl(ssl_connector), + domain: Arc::new("".into()), + }) + } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_cert( + cert: Option, + domain: String, + ) -> Result { let mut config = ClientConfig::new(); - - config.root_store.add_pem_file(&mut buf).unwrap(); config.set_protocols(&[Vec::from(&ALPN_H2[..])]); + if cert.is_some() { + let cert = cert.unwrap(); + let mut buf = std::io::Cursor::new(&cert.pem[..]); + config.root_store.add_pem_file(&mut buf).unwrap(); + } + Ok(Self { inner: Connector::Rustls(Arc::new(config)), domain: Arc::new(domain), }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ClientConfig, + ) -> Result { + Ok(Self { + inner: Connector::Rustls(Arc::new(config)), + domain: Arc::new("".into()), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let tls_io = match &self.inner { #[cfg(feature = "openssl")] @@ -167,7 +193,7 @@ enum Acceptor { impl TlsAcceptor { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl(identity: Identity) -> Result { + pub(crate) fn new_with_openssl_identity(identity: Identity) -> Result { let key = PKey::private_key_from_pem(&identity.key[..])?; let cert = X509::from_pem(&identity.cert.pem[..])?; @@ -185,6 +211,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + acceptor: openssl1::ssl::SslAcceptor, + ) -> Result { + Ok(Self { + inner: Acceptor::Openssl(acceptor), + }) + } + #[cfg(feature = "rustls")] fn load_rustls_private_key( mut cursor: std::io::Cursor<&[u8]>, @@ -209,7 +244,7 @@ impl TlsAcceptor { } #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(identity: Identity) -> Result { + pub(crate) fn new_with_rustls_identity(identity: Identity) -> Result { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); match pemfile::certs(&mut cert) { @@ -238,6 +273,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ServerConfig, + ) -> Result { + Ok(Self { + inner: Acceptor::Rustls(Arc::new(config)), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let io = match &self.inner { #[cfg(feature = "openssl")] diff --git a/tonic/src/transport/tls.rs b/tonic/src/transport/tls.rs index f8fb9916d..8fbc012e7 100644 --- a/tonic/src/transport/tls.rs +++ b/tonic/src/transport/tls.rs @@ -1,3 +1,17 @@ +/// Selects a library to provide TLS. +#[derive(Debug)] +pub enum TlsProvider { + /// Do not enable TLS. To use OpenSSL or Rustls, enable the `openssl` or `rustls` features of + /// the `tonic` crate. + None, + /// Use OpenSSL for TLS. + #[cfg(feature = "openssl")] + OpenSsl, + /// Use OpenSSL for TLS. + #[cfg(feature = "rustls")] + Rustls, +} + /// Represents a X509 certificate. #[derive(Debug, Clone)] pub struct Certificate {