diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 89dbc75e0..214a86ebc 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -276,13 +276,13 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"] uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["tokio-stream", "dep:h2"] mock = ["tokio-stream", "dep:tower", "dep:hyper-util"] -tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] +tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "tower?/timeout", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"] dynamic-load-balance = ["dep:tower"] -timeout = ["tokio/time", "dep:tower"] +timeout = ["tokio/time", "dep:tower", "tower?/timeout"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 74df3a583..cb77bcaed 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -26,24 +26,29 @@ version = "0.11.0" codegen = ["dep:async-trait"] gzip = ["dep:flate2"] zstd = ["dep:zstd"] -default = ["transport", "codegen", "prost"] +default = ["channel", "codegen", "prost"] prost = ["dep:prost"] tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"] tls-roots = ["tls-roots-common", "dep:rustls-native-certs"] -tls-roots-common = ["tls"] +tls-roots-common = ["tls", "channel"] tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"] router = ["dep:axum"] transport = [ "router", "dep:async-stream", - "channel", "dep:h2", - "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", + "dep:hyper", "dep:hyper-util", "dep:socket2", "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", - "dep:tower", + "dep:tower", "tower?/util", "tower?/limit", +] +channel = [ + "transport", + "dep:hyper", "hyper?/client", + "dep:hyper-util", "hyper-util?/client-legacy", + "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make", + "dep:hyper-timeout", ] -channel = [] # [[bench]] # name = "bench_main" @@ -71,13 +76,12 @@ async-trait = {version = "0.1.13", optional = true} # transport async-stream = {version = "0.3", optional = true} h2 = {version = "0.4", optional = true} -hyper = {version = "1", features = ["full"], optional = true} -hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true } -hyper-timeout = {version = "0.5", optional = true} +hyper = {version = "1", features = ["http1", "http2", "server"], optional = true} +hyper-util = { version = ">=0.1.4, <0.2", features = ["service", "server-auto", "tokio"], optional = true } socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } tokio = {version = "1", default-features = false, optional = true} tokio-stream = { version = "0.1", features = ["net"] } -tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} +tower = {version = "0.4.7", default-features = false, optional = true} axum = {version = "0.7", default-features = false, optional = true} # rustls @@ -90,6 +94,9 @@ webpki-roots = { version = "0.26", optional = true } flate2 = {version = "1.0", optional = true} zstd = { version = "0.13.0", optional = true } +# channel +hyper-timeout = {version = "0.5", optional = true} + [dev-dependencies] bencher = "0.1.5" quickcheck = "1.0" diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 29137536f..1ca537e52 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -16,10 +16,9 @@ //! //! # Feature Flags //! -//! - `transport`: Enables the fully featured, batteries included client and server -//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default. -//! - `channel`: Enables just the full featured channel/client portion of the `transport` -//! feature. +//! - `transport`: Enables just the full featured server portion of the `channel` feature. +//! - `channel`: Enables the fully featured, batteries included client and server +//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default. //! - `codegen`: Enables all the required exports and optional dependencies required //! for [`tonic-build`]. Enabled by default. //! - `tls`: Enables the `rustls` based TLS options for the `transport` feature. Not diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 1b97d22d4..a1c2eb016 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -618,8 +618,10 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { // matches the spec of: // > The service is currently unavailable. This is most likely a transient condition that // > can be corrected if retried with a backoff. - #[cfg(feature = "transport")] - if let Some(connect) = err.downcast_ref::() { + #[cfg(feature = "channel")] + if let Some(connect) = + err.downcast_ref::() + { return Some(Status::unavailable(connect.to_string())); } diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 6014960a8..4961e03b6 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -1,10 +1,10 @@ -use super::super::service; +#[cfg(feature = "tls")] +use super::service::TlsConnector; +use super::service::{self, Executor, SharedExec}; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -#[cfg(feature = "tls")] -use crate::transport::service::TlsConnector; -use crate::transport::{service::SharedExec, Error, Executor}; +use crate::transport::Error; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use hyper::rt; diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 0983725f8..16a6ec160 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -1,6 +1,7 @@ //! Client implementation and builder. mod endpoint; +pub(crate) mod service; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -9,9 +10,8 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, DynamicServiceStream, SharedExec}; +use self::service::{Connection, DynamicServiceStream, Executor, SharedExec}; use crate::body::BoxBody; -use crate::transport::Executor; use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, diff --git a/tonic/src/transport/service/add_origin.rs b/tonic/src/transport/channel/service/add_origin.rs similarity index 100% rename from tonic/src/transport/service/add_origin.rs rename to tonic/src/transport/channel/service/add_origin.rs diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/channel/service/connection.rs similarity index 94% rename from tonic/src/transport/service/connection.rs rename to tonic/src/transport/channel/service/connection.rs index 47a505f48..2c34f3ed7 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -1,8 +1,7 @@ -use super::SharedExec; -use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; +use super::{AddOrigin, Reconnect, SharedExec, UserAgent}; use crate::{ body::{boxed, BoxBody}, - transport::{BoxFuture, Endpoint}, + transport::{service::GrpcTimeout, BoxFuture, Endpoint}, }; use http::Uri; use hyper::rt; @@ -36,7 +35,7 @@ impl Connection { C::Future: Unpin + Send, C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - let mut settings: Builder = Builder::new(endpoint.executor.clone()) + let mut settings: Builder = Builder::new(endpoint.executor.clone()) .initial_stream_window_size(endpoint.init_stream_window_size) .initial_connection_window_size(endpoint.init_connection_window_size) .keep_alive_interval(endpoint.http2_keep_alive_interval) @@ -158,12 +157,12 @@ impl tower::Service> for SendRequest { struct MakeSendRequestService { connector: C, - executor: super::SharedExec, - settings: Builder, + executor: SharedExec, + settings: Builder, } impl MakeSendRequestService { - fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { + fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { Self { connector, executor, diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/channel/service/connector.rs similarity index 98% rename from tonic/src/transport/service/connector.rs rename to tonic/src/transport/channel/service/connector.rs index 4c73d13f2..0bfe0a518 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/channel/service/connector.rs @@ -1,7 +1,7 @@ -use super::super::BoxFuture; -use super::io::BoxedIo; +use super::BoxedIo; #[cfg(feature = "tls")] -use super::tls::TlsConnector; +use super::TlsConnector; +use crate::transport::BoxFuture; use http::Uri; use std::fmt; use std::task::{Context, Poll}; diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/channel/service/discover.rs similarity index 96% rename from tonic/src/transport/service/discover.rs rename to tonic/src/transport/channel/service/discover.rs index b9356110e..b1d3c3331 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/channel/service/discover.rs @@ -1,5 +1,4 @@ -use super::connection::Connection; -use crate::transport::Endpoint; +use super::super::{Connection, Endpoint}; use hyper_util::client::legacy::connect::HttpConnector; use std::{ diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/channel/service/executor.rs similarity index 100% rename from tonic/src/transport/service/executor.rs rename to tonic/src/transport/channel/service/executor.rs diff --git a/tonic/src/transport/channel/service/io.rs b/tonic/src/transport/channel/service/io.rs new file mode 100644 index 000000000..084195ad6 --- /dev/null +++ b/tonic/src/transport/channel/service/io.rs @@ -0,0 +1,67 @@ +use std::io::{self, IoSlice}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use hyper::rt; +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; + +pub(in crate::transport) trait Io: + rt::Read + rt::Write + Send + 'static +{ +} + +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} + +pub(crate) struct BoxedIo(Pin>); + +impl BoxedIo { + pub(in crate::transport) fn new(io: I) -> Self { + BoxedIo(Box::pin(io)) + } +} + +impl Connection for BoxedIo { + fn connected(&self) -> HyperConnected { + HyperConnected::new() + } +} + +impl rt::Read for BoxedIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: rt::ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl rt::Write for BoxedIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} diff --git a/tonic/src/transport/channel/service/mod.rs b/tonic/src/transport/channel/service/mod.rs new file mode 100644 index 000000000..6ac7603a9 --- /dev/null +++ b/tonic/src/transport/channel/service/mod.rs @@ -0,0 +1,28 @@ +mod add_origin; +use self::add_origin::AddOrigin; + +mod user_agent; +use self::user_agent::UserAgent; + +mod reconnect; +use self::reconnect::Reconnect; + +mod connection; +pub(super) use self::connection::Connection; + +mod discover; +pub(super) use self::discover::DynamicServiceStream; + +mod io; +use self::io::BoxedIo; + +mod connector; +pub(crate) use self::connector::{ConnectError, Connector}; + +mod executor; +pub(super) use self::executor::{Executor, SharedExec}; + +#[cfg(feature = "tls")] +mod tls; +#[cfg(feature = "tls")] +pub(super) use self::tls::TlsConnector; diff --git a/tonic/src/transport/service/reconnect.rs b/tonic/src/transport/channel/service/reconnect.rs similarity index 100% rename from tonic/src/transport/service/reconnect.rs rename to tonic/src/transport/channel/service/reconnect.rs diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs new file mode 100644 index 000000000..c63396a70 --- /dev/null +++ b/tonic/src/transport/channel/service/tls.rs @@ -0,0 +1,83 @@ +use std::fmt; +use std::io::Cursor; +use std::sync::Arc; + +use hyper_util::rt::TokioIo; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::{ + rustls::{pki_types::ServerName, ClientConfig, RootCertStore}, + TlsConnector as RustlsConnector, +}; + +use super::io::BoxedIo; +use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2}; +use crate::transport::tls::{Certificate, Identity}; + +#[derive(Clone)] +pub(crate) struct TlsConnector { + config: Arc, + domain: Arc>, + assume_http2: bool, +} + +impl TlsConnector { + pub(crate) fn new( + ca_certs: Vec, + identity: Option, + domain: &str, + assume_http2: bool, + ) -> Result { + let builder = ClientConfig::builder(); + let mut roots = RootCertStore::empty(); + + #[cfg(feature = "tls-roots")] + roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + + #[cfg(feature = "tls-webpki-roots")] + roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + for cert in ca_certs { + add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + } + + let builder = builder.with_root_certificates(roots); + let mut config = match identity { + Some(identity) => { + let (client_cert, client_key) = load_identity(identity)?; + builder.with_client_auth_cert(client_cert, client_key)? + } + None => builder.with_no_client_auth(), + }; + + config.alpn_protocols.push(ALPN_H2.into()); + Ok(Self { + config: Arc::new(config), + domain: Arc::new(ServerName::try_from(domain)?.to_owned()), + assume_http2, + }) + } + + pub(crate) async fn connect(&self, io: I) -> Result + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let io = RustlsConnector::from(self.config.clone()) + .connect(self.domain.as_ref().to_owned(), io) + .await?; + + // Generally we require ALPN to be negotiated, but if the user has + // explicitly set `assume_http2` to true, we'll allow it to be missing. + let (_, session) = io.get_ref(); + let alpn_protocol = session.alpn_protocol(); + if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { + return Err(TlsError::H2NotNegotiated.into()); + } + Ok(BoxedIo::new(TokioIo::new(io))) + } +} + +impl fmt::Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector").finish() + } +} diff --git a/tonic/src/transport/service/user_agent.rs b/tonic/src/transport/channel/service/user_agent.rs similarity index 100% rename from tonic/src/transport/service/user_agent.rs rename to tonic/src/transport/channel/service/user_agent.rs diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index a3c64a65c..31f5a7a46 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -1,5 +1,5 @@ +use super::service::TlsConnector; use crate::transport::{ - service::TlsConnector, tls::{Certificate, Identity}, Error, }; diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs index d2a1c7bb2..92a910498 100644 --- a/tonic/src/transport/error.rs +++ b/tonic/src/transport/error.rs @@ -15,7 +15,9 @@ struct ErrorImpl { #[derive(Debug)] pub(crate) enum Kind { Transport, + #[cfg(feature = "channel")] InvalidUri, + #[cfg(feature = "channel")] InvalidUserAgent, } @@ -35,10 +37,12 @@ impl Error { Error::new(Kind::Transport).with(source) } + #[cfg(feature = "channel")] pub(crate) fn new_invalid_uri() -> Self { Error::new(Kind::InvalidUri) } + #[cfg(feature = "channel")] pub(crate) fn new_invalid_user_agent() -> Self { Error::new(Kind::InvalidUserAgent) } @@ -46,7 +50,9 @@ impl Error { fn description(&self) -> &str { match &self.inner.kind { Kind::Transport => "transport error", + #[cfg(feature = "channel")] Kind::InvalidUri => "invalid URI", + #[cfg(feature = "channel")] Kind::InvalidUserAgent => "user agent is not a valid header value", } } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index f29fd1c99..9ee24a557 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -89,6 +89,7 @@ //! //! [rustls]: https://docs.rs/rustls/0.16.0/rustls/ +#[cfg(feature = "channel")] pub mod channel; pub mod server; @@ -106,7 +107,6 @@ pub use self::error::Error; pub use self::server::Server; #[doc(inline)] pub use self::service::grpc_timeout::TimeoutExpired; -pub(crate) use self::service::ConnectError; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] @@ -116,10 +116,8 @@ pub use hyper::{body::Body, Uri}; #[cfg(feature = "tls")] pub use tokio_rustls::rustls::pki_types::CertificateDer; -pub(crate) use self::service::executor::Executor; - -#[cfg(feature = "tls")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls")))] +#[cfg(all(feature = "channel", feature = "tls"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))] pub use self::channel::ClientTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] @@ -128,4 +126,5 @@ pub use self::server::ServerTlsConfig; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; +#[cfg(feature = "channel")] use crate::service::router::BoxFuture; diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index cb2296cac..7821f691c 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,6 +1,4 @@ use crate::transport::server::Connected; -use hyper::rt; -use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,78 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; -pub(in crate::transport) trait Io: - rt::Read + rt::Write + Send + 'static -{ -} - -impl Io for T where T: rt::Read + rt::Write + Send + 'static {} - -pub(crate) struct BoxedIo(Pin>); - -impl BoxedIo { - pub(in crate::transport) fn new(io: I) -> Self { - BoxedIo(Box::pin(io)) - } -} - -impl Connection for BoxedIo { - fn connected(&self) -> HyperConnected { - HyperConnected::new() - } -} - -impl Connected for BoxedIo { - type ConnectInfo = NoneConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - NoneConnectInfo - } -} - -#[derive(Copy, Clone)] -pub(crate) struct NoneConnectInfo; - -impl rt::Read for BoxedIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: rt::ReadBufCursor<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -impl rt::Write for BoxedIo { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } -} - pub(crate) enum ServerIo { Io(IO), #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index eeae37f5d..b5904aa07 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,23 +1,9 @@ -mod add_origin; -mod connection; -mod connector; -mod discover; -pub(crate) mod executor; pub(crate) mod grpc_timeout; mod io; -mod reconnect; #[cfg(feature = "tls")] -mod tls; -mod user_agent; +pub(crate) mod tls; -pub(crate) use self::add_origin::AddOrigin; -pub(crate) use self::connection::Connection; -pub(crate) use self::connector::ConnectError; -pub(crate) use self::connector::Connector; -pub(crate) use self::discover::DynamicServiceStream; -pub(crate) use self::executor::SharedExec; pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; #[cfg(feature = "tls")] -pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; -pub(crate) use self::user_agent::UserAgent; +pub(crate) use self::tls::TlsAcceptor; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 0e38d87ee..7bbf210b3 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -6,99 +6,29 @@ use std::{ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{ rustls::{ - pki_types::{CertificateDer, PrivateKeyDer, ServerName}, + pki_types::{CertificateDer, PrivateKeyDer}, server::WebPkiClientVerifier, - ClientConfig, RootCertStore, ServerConfig, + RootCertStore, ServerConfig, }, - TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, + TlsAcceptor as RustlsAcceptor, }; -use super::io::BoxedIo; use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; -use hyper_util::rt::TokioIo; /// h2 alpn in plain format for rustls. -const ALPN_H2: &[u8] = b"h2"; +pub(crate) const ALPN_H2: &[u8] = b"h2"; #[derive(Debug)] -enum TlsError { +pub(crate) enum TlsError { + #[cfg(feature = "channel")] H2NotNegotiated, CertificateParseError, PrivateKeyParseError, } -#[derive(Clone)] -pub(crate) struct TlsConnector { - config: Arc, - domain: Arc>, - assume_http2: bool, -} - -impl TlsConnector { - pub(crate) fn new( - ca_certs: Vec, - identity: Option, - domain: &str, - assume_http2: bool, - ) -> Result { - let builder = ClientConfig::builder(); - let mut roots = RootCertStore::empty(); - - #[cfg(feature = "tls-roots")] - roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); - - #[cfg(feature = "tls-webpki-roots")] - roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - - for cert in ca_certs { - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; - } - - let builder = builder.with_root_certificates(roots); - let mut config = match identity { - Some(identity) => { - let (client_cert, client_key) = load_identity(identity)?; - builder.with_client_auth_cert(client_cert, client_key)? - } - None => builder.with_no_client_auth(), - }; - - config.alpn_protocols.push(ALPN_H2.into()); - Ok(Self { - config: Arc::new(config), - domain: Arc::new(ServerName::try_from(domain)?.to_owned()), - assume_http2, - }) - } - - pub(crate) async fn connect(&self, io: I) -> Result - where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - let io = RustlsConnector::from(self.config.clone()) - .connect(self.domain.as_ref().to_owned(), io) - .await?; - - // Generally we require ALPN to be negotiated, but if the user has - // explicitly set `assume_http2` to true, we'll allow it to be missing. - let (_, session) = io.get_ref(); - let alpn_protocol = session.alpn_protocol(); - if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { - return Err(TlsError::H2NotNegotiated.into()); - } - Ok(BoxedIo::new(TokioIo::new(io))) - } -} - -impl fmt::Debug for TlsConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConnector").finish() - } -} - #[derive(Clone)] pub(crate) struct TlsAcceptor { inner: Arc, @@ -154,6 +84,7 @@ impl fmt::Debug for TlsAcceptor { impl fmt::Display for TlsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + #[cfg(feature = "channel")] TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."), TlsError::PrivateKeyParseError => write!( @@ -166,7 +97,7 @@ impl fmt::Display for TlsError { impl std::error::Error for TlsError {} -fn load_identity( +pub(crate) fn load_identity( identity: Identity, ) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert)) @@ -180,7 +111,7 @@ fn load_identity( Ok((cert, key)) } -fn add_certs_from_pem( +pub(crate) fn add_certs_from_pem( mut certs: &mut dyn std::io::BufRead, roots: &mut RootCertStore, ) -> Result<(), crate::Error> {