Skip to content

Commit

Permalink
Use Axum Request and Response in transport
Browse files Browse the repository at this point in the history
This commit is primarily converting Request and Response types within
the transport module to Axum 0.7 Request/Response. There is still more
to come to finish this conversion.

There are also small changes such as updating the hyper service builder
syntax.

Over the course of this commit,it was discovered that hyper-util is missing
`http2_max_pending_accept_reset_streams`.
  • Loading branch information
allan2 authored and alexrudy committed May 26, 2024
1 parent ea67be1 commit 081fad3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 53 deletions.
22 changes: 9 additions & 13 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ pub use endpoint::Endpoint;
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Request, Response,
};
use http::uri::{InvalidUri, Uri};
use hyper_util::client::legacy::connect::Connection as HyperConnection;
use std::{
fmt,
Expand All @@ -30,6 +26,7 @@ use tokio::{
sync::mpsc::{channel, Sender},
};

use axum::{extract::Request, response::Response, body::Body};
use tower::balance::p2c::Balance;
use tower::{
buffer::{self, Buffer},
Expand All @@ -38,7 +35,7 @@ use tower::{
Service,
};

type Svc = Either<Connection, BoxService<Request<BoxBody>, Response<hyper::Body>, crate::Error>>;
type Svc = Either<Connection, BoxService<Request, Response, crate::Error>>;

const DEFAULT_BUFFER_SIZE: usize = 1024;

Expand Down Expand Up @@ -67,14 +64,14 @@ const DEFAULT_BUFFER_SIZE: usize = 1024;
/// cloning the `Channel` type is cheap and encouraged.
#[derive(Clone)]
pub struct Channel {
svc: Buffer<Svc, Request<BoxBody>>,
svc: Buffer<Svc, Request>,
}

/// A future that resolves to an HTTP response.
///
/// This is returned by the `Service::call` on [`Channel`].
pub struct ResponseFuture {
inner: buffer::future::ResponseFuture<<Svc as Service<Request<BoxBody>>>::Future>,
inner: buffer::future::ResponseFuture<<Svc as Service<Request>>::Future>,
}

impl Channel {
Expand Down Expand Up @@ -200,24 +197,24 @@ impl Channel {
}
}

impl Service<http::Request<BoxBody>> for Channel {
type Response = http::Response<super::Body>;
impl Service<Request> for Channel {
type Response = Response;
type Error = super::Error;
type Future = ResponseFuture;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Service::poll_ready(&mut self.svc, cx).map_err(super::Error::from_source)
}

fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
fn call(&mut self, request: Request<Body>) -> Self::Future {
let inner = Service::call(&mut self.svc, request);

ResponseFuture { inner }
}
}

impl Future for ResponseFuture {
type Output = Result<Response<hyper::Body>, super::Error>;
type Output = Result<Response, super::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let val = ready!(Pin::new(&mut self.inner).poll(cx)).map_err(super::Error::from_source)?;
Expand All @@ -236,4 +233,3 @@ impl fmt::Debug for ResponseFuture {
f.debug_struct("ResponseFuture").finish()
}
}

23 changes: 13 additions & 10 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub use super::service::Routes;
pub use super::service::RoutesBuilder;

pub use conn::{Connected, TcpConnectInfo};
use hyper_util::rt::TokioExecutor;
#[cfg(feature = "tls")]
pub use tls::ServerTlsConfig;

Expand Down Expand Up @@ -534,16 +535,17 @@ impl<L> Server<L> {
_io: PhantomData,
};

let server = hyper::Server::builder(incoming)
.http2_only(http2_only)
.http2_initial_connection_window_size(init_connection_window_size)
.http2_initial_stream_window_size(init_stream_window_size)
.http2_max_concurrent_streams(max_concurrent_streams)
.http2_keep_alive_interval(http2_keepalive_interval)
.http2_keep_alive_timeout(http2_keepalive_timeout)
.http2_adaptive_window(http2_adaptive_window.unwrap_or_default())
.http2_max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
.http2_max_frame_size(max_frame_size);
let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.http2()
.initial_connection_window_size(init_connection_window_size)
.initial_stream_window_size(init_stream_window_size)
.max_concurrent_streams(max_concurrent_streams)
.keep_alive_interval(http2_keepalive_interval)
.keep_alive_timeout(http2_keepalive_timeout)
.adaptive_window(http2_adaptive_window.unwrap_or_default())
// FIXME: wait for this to be added to hyper-util
//.max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
.max_frame_size(max_frame_size);

if let Some(signal) = signal {
server
Expand Down Expand Up @@ -885,3 +887,4 @@ where
future::ready(Ok(svc))
}
}

25 changes: 9 additions & 16 deletions tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{
body::BoxBody,
transport::{BoxFuture, Endpoint},
};
use crate::transport::{BoxFuture, Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
use hyper::client::service::Connect as HyperConnect;
use hyper_util::client::legacy::connect::Connection as HyperConnection;
use hyper::client::conn::http2::Builder;
use hyper_util::client::legacy::connect::{Connect as HyperConnect, Connection as HyperConnection};
use std::{
fmt,
task::{Context, Poll},
Expand All @@ -21,9 +17,8 @@ use tower::{
};
use tower_service::Service;

pub(crate) type Request = http::Request<BoxBody>;
pub(crate) type Response = http::Response<hyper::Body>;

pub(crate) type Request = axum::extract::Request;
pub(crate) type Response = axum::response::Response;
pub(crate) struct Connection {
inner: BoxService<Request, Response, crate::Error>,
}
Expand All @@ -36,12 +31,10 @@ impl Connection {
C::Future: Unpin + Send,
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let mut settings = Builder::new()
.http2_initial_stream_window_size(endpoint.init_stream_window_size)
.http2_initial_connection_window_size(endpoint.init_connection_window_size)
.http2_only(true)
.http2_keep_alive_interval(endpoint.http2_keep_alive_interval)
.executor(endpoint.executor.clone())
let mut settings = Builder::new(endpoint.executor)
.initial_stream_window_size(endpoint.init_stream_window_size)
.initial_connection_window_size(endpoint.init_connection_window_size)
.keep_alive_interval(endpoint.http2_keep_alive_interval)
.clone();

if let Some(val) = endpoint.http2_keep_alive_timeout {
Expand Down
25 changes: 11 additions & 14 deletions tonic/src/transport/service/router.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use crate::{
body::{boxed, BoxBody},
server::NamedService,
};
use http::{Request, Response};
use hyper::Body;
use crate::{body::boxed, server::NamedService};
use axum::{extract::Request, response::Response};
use pin_project::pin_project;
use std::{
convert::Infallible,
Expand Down Expand Up @@ -31,7 +27,7 @@ impl RoutesBuilder {
/// Add a new service.
pub fn add_service<S>(&mut self, svc: S) -> &mut Self
where
S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
S: Service<Request, Response = Response, Error = Infallible>
+ NamedService
+ Clone
+ Send
Expand All @@ -53,7 +49,7 @@ impl Routes {
/// Create a new routes with `svc` already added to it.
pub fn new<S>(svc: S) -> Self
where
S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
S: Service<Request, Response = Response, Error = Infallible>
+ NamedService
+ Clone
+ Send
Expand All @@ -68,7 +64,7 @@ impl Routes {
/// Add a new service.
pub fn add_service<S>(mut self, svc: S) -> Self
where
S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
S: Service<Request, Response = Response, Error = Infallible>
+ NamedService
+ Clone
+ Send
Expand Down Expand Up @@ -103,8 +99,8 @@ async fn unimplemented() -> impl axum::response::IntoResponse {
(status, headers)
}

impl Service<Request<Body>> for Routes {
type Response = Response<BoxBody>;
impl Service<Request> for Routes {
type Response = Response;
type Error = crate::Error;
type Future = RoutesFuture;

Expand All @@ -113,13 +109,13 @@ impl Service<Request<Body>> for Routes {
Poll::Ready(Ok(()))
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
fn call(&mut self, req: Request) -> Self::Future {
RoutesFuture(self.router.call(req))
}
}

#[pin_project]
pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture<Body, Infallible>);
pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture<Infallible>);

impl fmt::Debug for RoutesFuture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand All @@ -128,7 +124,7 @@ impl fmt::Debug for RoutesFuture {
}

impl Future for RoutesFuture {
type Output = Result<Response<BoxBody>, crate::Error>;
type Output = Result<Response, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match ready!(self.project().0.poll(cx)) {
Expand All @@ -137,3 +133,4 @@ impl Future for RoutesFuture {
}
}
}

0 comments on commit 081fad3

Please sign in to comment.