diff --git a/tonic-examples/src/tls/client.rs b/tonic-examples/src/tls/client.rs index b8346763c..6370ba838 100644 --- a/tonic-examples/src/tls/client.rs +++ b/tonic-examples/src/tls/client.rs @@ -3,19 +3,23 @@ pub mod pb { } use pb::{client::EchoClient, EchoRequest}; -use tonic::transport::{Certificate, Channel}; +use tonic::transport::{Certificate, Channel, ClientTlsConfig}; #[tokio::main] async fn main() -> Result<(), Box> { let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?; let ca = Certificate::from_pem(pem); + let tls = ClientTlsConfig::with_rustls() + .ca_certificate(ca) + .domain_name("example.com") + .clone(); + let channel = Channel::from_static("http://[::1]:50051") - .rustls_tls(ca, Some("example.com".into())) + .tls_config(&tls) .channel(); let mut client = EchoClient::new(channel); - let request = tonic::Request::new(EchoRequest { message: "hello".into(), }); diff --git a/tonic-examples/src/tls/server.rs b/tonic-examples/src/tls/server.rs index 5713a4fd0..0bac205aa 100644 --- a/tonic-examples/src/tls/server.rs +++ b/tonic-examples/src/tls/server.rs @@ -5,7 +5,7 @@ pub mod pb { use pb::{EchoRequest, EchoResponse}; use std::collections::VecDeque; use tonic::{ - transport::{Identity, Server}, + transport::{Identity, Server, ServerTlsConfig}, Request, Response, Status, Streaming, }; @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box> { let server = EchoServer::default(); Server::builder() - .rustls_tls(identity) + .tls_config(ServerTlsConfig::with_rustls().identity(identity)) .clone() .serve(addr, pb::server::EchoServer::new(server)) .await?; diff --git a/tonic-interop/src/bin/client.rs b/tonic-interop/src/bin/client.rs index 2012df61a..0b4283072 100644 --- a/tonic-interop/src/bin/client.rs +++ b/tonic-interop/src/bin/client.rs @@ -1,6 +1,6 @@ use std::time::Duration; use structopt::{clap::arg_enum, StructOpt}; -use tonic::transport::{Certificate, Endpoint}; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; use tonic_interop::client; #[derive(StructOpt)] @@ -33,7 +33,12 @@ async fn main() -> Result<(), Box> { if matches.use_tls { let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?; let ca = Certificate::from_pem(pem); - endpoint.openssl_tls(ca, Some("foo.test.google.fr".into())); + + endpoint.tls_config( + ClientTlsConfig::with_openssl() + .ca_certificate(ca) + .domain_name("foo.test.google.fr"), + ); } let channel = endpoint.channel(); diff --git a/tonic-interop/src/bin/server.rs b/tonic-interop/src/bin/server.rs index 7a65473ab..dd0a7d052 100644 --- a/tonic-interop/src/bin/server.rs +++ b/tonic-interop/src/bin/server.rs @@ -2,7 +2,7 @@ use http::header::HeaderName; use structopt::StructOpt; use tonic::body::BoxBody; use tonic::client::GrpcService; -use tonic::transport::{Identity, Server}; +use tonic::transport::{Identity, Server, ServerTlsConfig}; use tonic_interop::{server, MergeTrailers}; #[derive(StructOpt)] @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box> { let key = tokio::fs::read("tonic-interop/data/server1.key").await?; let identity = Identity::from_pem(cert, key); - builder.openssl_tls(identity); + builder.tls_config(ServerTlsConfig::with_openssl().identity(identity)); } builder.interceptor_fn(|svc, req| { diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index ab11b9922..2d3de4387 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -1,6 +1,9 @@ use super::channel::Channel; #[cfg(feature = "tls")] -use super::{service::TlsConnector, tls::Certificate}; +use super::{ + service::TlsConnector, + tls::{Certificate, TlsProvider}, +}; use bytes::Bytes; use http::uri::{InvalidUriBytes, Uri}; use std::{ @@ -122,64 +125,6 @@ impl Endpoint { self } - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.openssl_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_openssl(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - - /// Enable TLS and apply the CA as the root certificate. - /// - /// Providing an optional domain to override. If `None` is passed to this - /// the TLS implementation will use the `Uri` that was used to create the - /// `Endpoint` builder. - /// - /// ```no_run - /// # use tonic::transport::{Certificate, Endpoint}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// let ca = std::fs::read_to_string("ca.pem")?; - /// - /// let ca = Certificate::from_pem(ca); - /// - /// builder.rustls_tls(ca, "example.com".to_string()); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into>) -> &mut Self { - let domain = domain - .into() - .unwrap_or_else(|| self.uri.clone().to_string()); - let tls = TlsConnector::new_with_rustls(ca, domain).unwrap(); - self.tls = Some(tls); - self - } - /// Intercept outbound HTTP Request headers; pub fn intercept_headers(&mut self, f: F) -> &mut Self where @@ -189,6 +134,13 @@ impl Endpoint { self } + /// Configures TLS for the endpoint. + #[cfg(feature = "tls")] + pub fn tls_config(&mut self, tls_config: &ClientTlsConfig) -> &mut Self { + self.tls = Some(tls_config.tls_connector(self.uri.clone()).unwrap()); + self + } + /// Create a channel from this config. pub fn channel(&self) -> Channel { Channel::connect(self.clone()) @@ -252,3 +204,104 @@ impl fmt::Debug for Endpoint { f.debug_struct("Endpoint").finish() } } + +/// Configures TLS settings for endpoints. +#[cfg(feature = "tls")] +#[derive(Clone)] +pub struct ClientTlsConfig { + provider: TlsProvider, + domain: Option, + cert: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +#[cfg(feature = "tls")] +impl fmt::Debug for ClientTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ClientTlsConfig { + /// Creates a new `ClientTlsConfig` using OpenSSL. + #[cfg(feature = "openssl")] + pub fn with_openssl() -> Self { + Self::new(TlsProvider::OpenSsl) + } + + /// Creates a new `ClientTlsConfig` using Rustls. + #[cfg(feature = "rustls")] + pub fn with_rustls() -> Self { + Self::new(TlsProvider::Rustls) + } + + fn new(provider: TlsProvider) -> Self { + ClientTlsConfig { + provider, + domain: None, + cert: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the domain name against which to verify the server's TLS certificate. + pub fn domain_name(&mut self, domain_name: impl Into) -> &mut Self { + self.domain = Some(domain_name.into()); + self + } + + /// Sets the CA Certificate against which to verify the server's TLS certificate. + pub fn ca_certificate(&mut self, ca_certificate: Certificate) -> &mut Self { + self.cert = Some(ca_certificate); + self + } + + /// Use options specified by the given `SslConnector` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(&mut self, connector: openssl1::ssl::SslConnector) -> &mut Self { + self.openssl_raw = Some(connector); + self + } + + /// Use options specified by the given `ClientConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config( + &mut self, + config: tokio_rustls::rustls::ClientConfig, + ) -> &mut Self { + self.rustls_raw = Some(config); + self + } + + fn tls_connector(&self, uri: Uri) -> Result { + let domain = match &self.domain { + None => uri.to_string(), + Some(domain) => domain.clone(), + }; + match self.provider { + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => match &self.openssl_raw { + None => TlsConnector::new_with_openssl_cert(self.cert.clone(), domain), + Some(r) => TlsConnector::new_with_openssl_raw(r.clone(), domain), + }, + #[cfg(feature = "rustls")] + TlsProvider::Rustls => match &self.rustls_raw { + None => TlsConnector::new_with_rustls_cert(self.cert.clone(), domain), + Some(c) => TlsConnector::new_with_rustls_raw(c.clone(), domain), + }, + } + } +} diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index d963186e4..846719ef5 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -19,7 +19,7 @@ //! ## Client //! //! ```no_run -//! # use tonic::transport::{Channel, Certificate}; +//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; //! # use tonic::client::GrpcService;; @@ -29,7 +29,9 @@ //! let cert = std::fs::read_to_string("ca.pem")?; //! //! let mut channel = Channel::from_static("https://example.com") -//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string()) +//! .tls_config(ClientTlsConfig::with_rustls() +//! .ca_certificate(Certificate::from_pem(&cert)) +//! .domain_name("example.com".to_string())) //! .timeout(Duration::from_secs(5)) //! .rate_limit(5, Duration::from_secs(1)) //! .concurrency_limit(256) @@ -43,7 +45,7 @@ //! ## Server //! //! ```no_run -//! # use tonic::transport::{Server, Identity}; +//! # use tonic::transport::{Server, Identity, ServerTlsConfig}; //! # use tower::{Service, service_fn}; //! # use futures_util::future::{err, ok}; //! # #[cfg(feature = "rustls")] @@ -55,7 +57,8 @@ //! let addr = "[::1]:50051".parse()?; //! //! Server::builder() -//! .rustls_tls(Identity::from_pem(&cert, &key)) +//! .tls_config(ServerTlsConfig::with_rustls() +//! .identity(Identity::from_pem(&cert, &key))) //! .concurrency_limit_per_connection(256) //! .interceptor_fn(|svc, req| { //! println!("Request: {:?}", req); @@ -89,4 +92,9 @@ pub use self::server::Server; pub use self::tls::{Certificate, Identity}; pub use hyper::Body; +#[cfg(feature = "tls")] +pub use self::endpoint::ClientTlsConfig; +#[cfg(feature = "tls")] +pub use self::server::ServerTlsConfig; + pub(crate) use self::error::ErrorKind; diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 9878a2b0f..1cfbc999c 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -2,7 +2,10 @@ use super::service::{layer_fn, BoxedIo, ServiceBuilderExt}; #[cfg(feature = "tls")] -use super::{service::TlsAcceptor, tls::Identity}; +use super::{ + service::TlsAcceptor, + tls::{Identity, TlsProvider}, +}; use crate::body::BoxBody; use futures_core::Stream; use futures_util::{ready, try_future::MapErr, TryFutureExt, TryStreamExt}; @@ -62,49 +65,10 @@ impl Server { } impl Server { - /// Set the [`Identity`] of this server using `openssl`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.openssl_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "openssl")] - #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] - pub fn openssl_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_openssl(identity).unwrap(); - self.tls = Some(acceptor); - self - } - - /// Set the [`Identity`] of this server using `rustls`. - /// - /// ```no_run - /// # use tonic::transport::{Identity, Server}; - /// # fn dothing() -> Result<(), Box> { - /// # let mut builder = Server::builder(); - /// let cert = std::fs::read_to_string("server.pem")?; - /// let key = std::fs::read_to_string("server.key")?; - /// - /// let identity = Identity::from_pem(&cert, &key); - /// - /// builder.rustls_tls(identity); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "rustls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] - pub fn rustls_tls(&mut self, identity: Identity) -> &mut Self { - let acceptor = TlsAcceptor::new_with_rustls(identity).unwrap(); - self.tls = Some(acceptor); + /// Configure TLS for this server. + #[cfg(feature = "tls")] + pub fn tls_config(&mut self, tls_config: &ServerTlsConfig) -> &mut Self { + self.tls = Some(tls_config.tls_acceptor().unwrap()); self } @@ -255,6 +219,97 @@ impl fmt::Debug for Server { } } +/// Configures TLS settings for servers. +#[cfg(feature = "tls")] +#[derive(Clone)] +pub struct ServerTlsConfig { + provider: TlsProvider, + identity: Option, + #[cfg(feature = "openssl")] + openssl_raw: Option, + #[cfg(feature = "rustls")] + rustls_raw: Option, +} + +#[cfg(feature = "tls")] +impl fmt::Debug for ServerTlsConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerTlsConfig") + .field("provider", &self.provider) + .finish() + } +} + +#[cfg(feature = "tls")] +impl ServerTlsConfig { + /// Creates a new `ServerTlsConfig` using OpenSSL. + #[cfg(feature = "openssl")] + pub fn with_openssl() -> Self { + Self::new(TlsProvider::OpenSsl) + } + + /// Creates a new `ServerTlsConfig` using Rustls. + #[cfg(feature = "rustls")] + pub fn with_rustls() -> Self { + Self::new(TlsProvider::Rustls) + } + + /// Creates a new `ServerTlsConfig` backed by the specified provider. Enable the `openssl` or + /// `rustls` features of the `tonic` crate to use OpenSSL or Rustls respectively. + fn new(provider: TlsProvider) -> Self { + ServerTlsConfig { + provider, + identity: None, + #[cfg(feature = "openssl")] + openssl_raw: None, + #[cfg(feature = "rustls")] + rustls_raw: None, + } + } + + /// Sets the [`Identity`] of the server. + pub fn identity(&mut self, identity: Identity) -> &mut Self { + self.identity = Some(identity); + self + } + + /// Use options specified by the given `SslAcceptor` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "openssl")] + pub fn openssl_connector(&mut self, acceptor: openssl1::ssl::SslAcceptor) -> &mut Self { + self.openssl_raw = Some(acceptor); + self + } + + /// Use options specified by the given `ServerConfig` to configure TLS. + /// + /// This overrides all other TLS options set via other means. + #[cfg(feature = "rustls")] + pub fn rustls_client_config( + &mut self, + config: tokio_rustls::rustls::ServerConfig, + ) -> &mut Self { + self.rustls_raw = Some(config); + self + } + + fn tls_acceptor(&self) -> Result { + match self.provider { + #[cfg(feature = "openssl")] + TlsProvider::OpenSsl => match &self.openssl_raw { + None => TlsAcceptor::new_with_openssl_identity(self.identity.clone().unwrap()), + Some(acceptor) => TlsAcceptor::new_with_openssl_raw(acceptor.clone()), + }, + #[cfg(feature = "rustls")] + TlsProvider::Rustls => match &self.rustls_raw { + None => TlsAcceptor::new_with_rustls_identity(self.identity.clone().unwrap()), + Some(config) => TlsAcceptor::new_with_rustls_raw(config.clone()), + }, + } + } +} + #[derive(Debug)] struct TcpIncoming { inner: conn::AddrIncoming, diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index de45e3aa8..ab3f5b9d6 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -56,35 +56,60 @@ enum Connector { impl TlsConnector { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl( - cert: Certificate, + pub(crate) fn new_with_openssl_cert( + cert: Option, domain: String, ) -> Result { let mut config = SslConnector::builder(SslMethod::tls())?; - config.set_alpn_protos(ALPN_H2_WIRE)?; - let ca = X509::from_pem(&cert.pem[..])?; - - config.cert_store_mut().add_cert(ca)?; + if let Some(cert) = cert { + let ca = X509::from_pem(&cert.pem[..])?; + config.cert_store_mut().add_cert(ca)?; + } - let config = config.build(); + Ok(Self { + inner: Connector::Openssl(config.build()), + domain: Arc::new(domain), + }) + } + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + ssl_connector: openssl1::ssl::SslConnector, + domain: String, + ) -> Result { Ok(Self { - inner: Connector::Openssl(config), + inner: Connector::Openssl(ssl_connector), domain: Arc::new(domain), }) } #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(cert: Certificate, domain: String) -> Result { - let mut buf = std::io::Cursor::new(&cert.pem[..]); - + pub(crate) fn new_with_rustls_cert( + cert: Option, + domain: String, + ) -> Result { let mut config = ClientConfig::new(); - - config.root_store.add_pem_file(&mut buf).unwrap(); config.set_protocols(&[Vec::from(&ALPN_H2[..])]); + if cert.is_some() { + let cert = cert.unwrap(); + let mut buf = std::io::Cursor::new(&cert.pem[..]); + config.root_store.add_pem_file(&mut buf).unwrap(); + } + + Ok(Self { + inner: Connector::Rustls(Arc::new(config)), + domain: Arc::new(domain), + }) + } + + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ClientConfig, + domain: String, + ) -> Result { Ok(Self { inner: Connector::Rustls(Arc::new(config)), domain: Arc::new(domain), @@ -167,7 +192,7 @@ enum Acceptor { impl TlsAcceptor { #[cfg(feature = "openssl")] - pub(crate) fn new_with_openssl(identity: Identity) -> Result { + pub(crate) fn new_with_openssl_identity(identity: Identity) -> Result { let key = PKey::private_key_from_pem(&identity.key[..])?; let cert = X509::from_pem(&identity.cert.pem[..])?; @@ -185,6 +210,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "openssl")] + pub(crate) fn new_with_openssl_raw( + acceptor: openssl1::ssl::SslAcceptor, + ) -> Result { + Ok(Self { + inner: Acceptor::Openssl(acceptor), + }) + } + #[cfg(feature = "rustls")] fn load_rustls_private_key( mut cursor: std::io::Cursor<&[u8]>, @@ -209,7 +243,7 @@ impl TlsAcceptor { } #[cfg(feature = "rustls")] - pub(crate) fn new_with_rustls(identity: Identity) -> Result { + pub(crate) fn new_with_rustls_identity(identity: Identity) -> Result { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); match pemfile::certs(&mut cert) { @@ -238,6 +272,15 @@ impl TlsAcceptor { }) } + #[cfg(feature = "rustls")] + pub(crate) fn new_with_rustls_raw( + config: tokio_rustls::rustls::ServerConfig, + ) -> Result { + Ok(Self { + inner: Acceptor::Rustls(Arc::new(config)), + }) + } + pub(crate) async fn connect(&self, io: TcpStream) -> Result { let io = match &self.inner { #[cfg(feature = "openssl")] diff --git a/tonic/src/transport/tls.rs b/tonic/src/transport/tls.rs index f8fb9916d..33c8fec6b 100644 --- a/tonic/src/transport/tls.rs +++ b/tonic/src/transport/tls.rs @@ -1,3 +1,14 @@ +/// Selects a library to provide TLS. +#[derive(Clone, Debug)] +pub(crate) enum TlsProvider { + /// Use OpenSSL for TLS. + #[cfg(feature = "openssl")] + OpenSsl, + /// Use OpenSSL for TLS. + #[cfg(feature = "rustls")] + Rustls, +} + /// Represents a X509 certificate. #[derive(Debug, Clone)] pub struct Certificate {