Skip to content

Commit

Permalink
feat(channel): Make channel feature additive
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Feb 21, 2024
1 parent 18a2b30 commit 4aa7354
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 27 deletions.
4 changes: 2 additions & 2 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,13 @@ hyper-warp-multiplex = ["hyper-warp"]
uds = ["tokio-stream/net", "dep:tower", "dep:hyper"]
streaming = ["tokio-stream", "dep:h2"]
mock = ["tokio-stream", "dep:tower"]
tower = ["dep:hyper", "dep:tower", "dep:http"]
tower = ["dep:hyper", "tower/timeout", "dep:http"]
json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
compression = ["tonic/gzip"]
tls = ["tonic/tls"]
tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"]
dynamic-load-balance = ["dep:tower"]
timeout = ["tokio/time", "dep:tower"]
timeout = ["tokio/time", "tower/timeout"]
tls-client-auth = ["tonic/tls"]
types = ["dep:tonic-types"]
h2c = ["dep:hyper", "dep:tower", "dep:http"]
Expand Down
23 changes: 14 additions & 9 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,26 @@ version = "0.11.0"
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
default = ["channel", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
transport = [
"dep:async-stream",
"dep:axum",
"channel",
"dep:h2",
"dep:hyper",
"dep:hyper", "hyper?/server",
"dep:tokio", "tokio?/net", "tokio?/time",
"dep:tower",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
"transport",
"dep:hyper", "hyper?/client",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"dep:hyper-timeout",
]
channel = []

# [[bench]]
# name = "bench_main"
Expand All @@ -68,13 +71,15 @@ async-trait = {version = "0.1.13", optional = true}

# transport
h2 = {version = "0.3.24", optional = true}
hyper = {version = "0.14.26", features = ["full"], optional = true}
hyper-timeout = {version = "0.4", optional = true}
hyper = {version = "0.14.26", features = ["http1", "http2", "runtime", "stream"], optional = true}
tokio = {version = "1.0.1", optional = true}
tokio-stream = "0.1"
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
tower = {version = "0.4.7", default-features = false, optional = true}
axum = {version = "0.6.9", default_features = false, optional = true}

# channel
hyper-timeout = {version = "0.4", optional = true}

# rustls
async-stream = { version = "0.3", optional = true }
rustls-pki-types = { version = "1.0", optional = true }
Expand Down
6 changes: 6 additions & 0 deletions tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ struct ErrorImpl {
#[derive(Debug)]
pub(crate) enum Kind {
Transport,
#[cfg(feature = "channel")]
InvalidUri,
#[cfg(feature = "channel")]
InvalidUserAgent,
}

Expand All @@ -35,18 +37,22 @@ impl Error {
Error::new(Kind::Transport).with(source)
}

#[cfg(feature = "channel")]
pub(crate) fn new_invalid_uri() -> Self {
Error::new(Kind::InvalidUri)
}

#[cfg(feature = "channel")]
pub(crate) fn new_invalid_user_agent() -> Self {
Error::new(Kind::InvalidUserAgent)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
#[cfg(feature = "channel")]
Kind::InvalidUri => "invalid URI",
#[cfg(feature = "channel")]
Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
Expand Down
7 changes: 5 additions & 2 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
//!
//! [rustls]: https://docs.rs/rustls/0.16.0/rustls/

#[cfg(feature = "channel")]
pub mod channel;
pub mod server;

Expand All @@ -110,10 +111,11 @@ pub use self::tls::Certificate;
pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter};
pub use hyper::{Body, Uri};

#[cfg(feature = "channel")]
pub(crate) use self::service::executor::Executor;

#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[cfg(all(feature = "channel", feature = "tls"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))]
pub use self::channel::ClientTlsConfig;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
Expand All @@ -122,4 +124,5 @@ pub use self::server::ServerTlsConfig;
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub use self::tls::Identity;

#[cfg(feature = "channel")]
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
8 changes: 8 additions & 0 deletions tonic/src/transport/service/io.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::transport::server::Connected;
#[cfg(feature = "channel")]
use hyper::client::connect::{Connected as HyperConnected, Connection};
use std::io;
use std::io::IoSlice;
Expand All @@ -15,20 +16,24 @@ pub(in crate::transport) trait Io:

impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}

#[cfg(feature = "channel")]
pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

#[cfg(feature = "channel")]
impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

#[cfg(feature = "channel")]
impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

#[cfg(feature = "channel")]
impl Connected for BoxedIo {
type ConnectInfo = NoneConnectInfo;

Expand All @@ -37,9 +42,11 @@ impl Connected for BoxedIo {
}
}

#[cfg(feature = "channel")]
#[derive(Copy, Clone)]
pub(crate) struct NoneConnectInfo;

#[cfg(feature = "channel")]
impl AsyncRead for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
Expand All @@ -50,6 +57,7 @@ impl AsyncRead for BoxedIo {
}
}

#[cfg(feature = "channel")]
impl AsyncWrite for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down
22 changes: 15 additions & 7 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
#[cfg(feature = "channel")]
mod add_origin;
#[cfg(feature = "channel")]
mod connection;
#[cfg(feature = "channel")]
mod connector;
#[cfg(feature = "channel")]
mod discover;
#[cfg(feature = "channel")]
pub(crate) mod executor;
pub(crate) mod grpc_timeout;
mod io;
#[cfg(feature = "channel")]
mod reconnect;
mod router;
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "channel")]
mod user_agent;

pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
pub(crate) use self::connector::Connector;
pub(crate) use self::discover::DynamicServiceStream;
pub(crate) use self::executor::SharedExec;
pub(crate) use self::grpc_timeout::GrpcTimeout;
pub(crate) use self::io::ServerIo;
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
pub(crate) use self::user_agent::UserAgent;
pub(crate) use self::tls::TlsAcceptor;
#[cfg(all(feature = "channel", feature = "tls"))]
pub(crate) use self::tls::TlsConnector;
#[cfg(feature = "channel")]
pub(crate) use self::{
add_origin::AddOrigin, connection::Connection, connector::Connector,
discover::DynamicServiceStream, executor::SharedExec, user_agent::UserAgent,
};

pub use self::router::Routes;
pub use self::router::RoutesBuilder;
23 changes: 16 additions & 7 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use std::{
io::Cursor,
{fmt, sync::Arc},
};
use std::io::Cursor;
use std::{fmt, sync::Arc};

use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
#[cfg(feature = "channel")]
use rustls_pki_types::ServerName;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "channel")]
use tokio_rustls::{rustls::ClientConfig, TlsConnector as RustlsConnector};
use tokio_rustls::{
rustls::{server::WebPkiClientVerifier, ClientConfig, RootCertStore, ServerConfig},
TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
TlsAcceptor as RustlsAcceptor,
};

#[cfg(feature = "channel")]
use super::io::BoxedIo;
use crate::transport::{
server::{Connected, TlsStream},
Expand All @@ -21,17 +24,20 @@ const ALPN_H2: &[u8] = b"h2";

#[derive(Debug)]
enum TlsError {
#[cfg(feature = "channel")]
H2NotNegotiated,
CertificateParseError,
PrivateKeyParseError,
}

#[cfg(feature = "channel")]
#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
}

#[cfg(feature = "channel")]
impl TlsConnector {
pub(crate) fn new(
ca_cert: Option<Certificate>,
Expand Down Expand Up @@ -67,6 +73,7 @@ impl TlsConnector {
})
}

#[cfg(feature = "channel")]
pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
Expand All @@ -84,6 +91,7 @@ impl TlsConnector {
}
}

#[cfg(feature = "channel")]
impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
Expand Down Expand Up @@ -145,6 +153,7 @@ impl fmt::Debug for TlsAcceptor {
impl fmt::Display for TlsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
#[cfg(feature = "channel")]
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
TlsError::PrivateKeyParseError => write!(
Expand Down

0 comments on commit 4aa7354

Please sign in to comment.