Skip to content

Commit

Permalink
feat(tonic): make it easier to add tower middleware to servers (#651)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored May 19, 2021
1 parent 4dda4cb commit 4d2667d
Show file tree
Hide file tree
Showing 22 changed files with 734 additions and 301 deletions.
18 changes: 15 additions & 3 deletions examples/src/tower/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use service::AuthSvc;
use tower::ServiceBuilder;

use tonic::transport::Channel;
use tonic::{transport::Channel, Request, Status};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -11,9 +12,14 @@ pub mod hello_world {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://[::1]:50051").connect().await?;
let auth = AuthSvc::new(channel);

let mut client = GreeterClient::new(auth);
let channel = ServiceBuilder::new()
// Interceptors can be also be applied as middleware
.layer(tonic::service::interceptor_fn(intercept))
.layer_fn(AuthSvc::new)
.service(channel);

let mut client = GreeterClient::new(channel);

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
Expand All @@ -26,6 +32,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

// An interceptor function.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("received {:?}", req);
Ok(req)
}

mod service {
use http::{Request, Response};
use std::future::Future;
Expand Down
68 changes: 44 additions & 24 deletions examples/src/tower/server.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use hyper::{Body, Request as HyperRequest, Response as HyperResponse};
use std::task::{Context, Poll};
use tonic::{
body::BoxBody,
transport::{NamedService, Server},
Request, Response, Status,
use hyper::Body;
use std::{
task::{Context, Poll},
time::Duration,
};
use tower::Service;
use tonic::{body::BoxBody, transport::Server, Request, Response, Status};
use tower::{Layer, Service};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};
Expand Down Expand Up @@ -39,27 +38,52 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

println!("GreeterServer listening on {}", addr);

let svc = InterceptedService {
inner: GreeterServer::new(greeter),
};

Server::builder().add_service(svc).serve(addr).await?;
let svc = GreeterServer::new(greeter);

// The stack of middleware that our service will be wrapped in
let layer = tower::ServiceBuilder::new()
// Apply middleware from tower
.timeout(Duration::from_secs(30))
// Apply our own middleware
.layer(MyMiddlewareLayer::default())
// Interceptors can be also be applied as middleware
.layer(tonic::service::interceptor_fn(intercept))
.into_inner();

Server::builder()
// Wrap all services in the middleware stack
.layer(layer)
.add_service(svc)
.serve(addr)
.await?;

Ok(())
}

// An interceptor function.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
Ok(req)
}

#[derive(Debug, Clone, Default)]
struct MyMiddlewareLayer;

impl<S> Layer<S> for MyMiddlewareLayer {
type Service = MyMiddleware<S>;

fn layer(&self, service: S) -> Self::Service {
MyMiddleware { inner: service }
}
}

#[derive(Debug, Clone)]
struct InterceptedService<S> {
struct MyMiddleware<S> {
inner: S,
}

impl<S> Service<HyperRequest<Body>> for InterceptedService<S>
impl<S> Service<hyper::Request<Body>> for MyMiddleware<S>
where
S: Service<HyperRequest<Body>, Response = HyperResponse<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
Expand All @@ -70,7 +94,7 @@ where
self.inner.poll_ready(cx)
}

fn call(&mut self, req: HyperRequest<Body>) -> Self::Future {
fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
// This is necessary because tonic internally uses `tower::buffer::Buffer`.
// See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
// for details on why this is necessary
Expand All @@ -85,7 +109,3 @@ where
})
}
}

impl<S: NamedService> NamedService for InterceptedService<S> {
const NAME: &'static str = S::NAME;
}
3 changes: 3 additions & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ tokio-stream = { version = "0.1.5", features = ["net"] }
tower-service = "0.3"
hyper = "0.14"
futures = "0.3"
tower = { version = "0.4", features = [] }
http-body = "0.4"
http = "0.2"

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
113 changes: 113 additions & 0 deletions tests/integration_tests/tests/complex_tower_middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#![allow(unused_variables, dead_code)]

use http_body::Body;
use integration_tests::pb::{test_server, Input, Output};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tonic::{transport::Server, Request, Response, Status};
use tower::{layer::Layer, BoxError, Service};

// all we care about is that this compiles
async fn complex_tower_layers_work() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
unimplemented!()
}
}

let svc = test_server::TestServer::new(Svc);

Server::builder()
.layer(MyServiceLayer::new())
.add_service(svc)
.serve("127.0.0.1:1322".parse().unwrap())
.await
.unwrap();
}

#[derive(Debug, Clone)]
struct MyServiceLayer {}

impl MyServiceLayer {
fn new() -> Self {
unimplemented!()
}
}

impl<S> Layer<S> for MyServiceLayer {
type Service = MyService<S>;

fn layer(&self, inner: S) -> Self::Service {
unimplemented!()
}
}

#[derive(Debug, Clone)]
struct MyService<S> {
inner: S,
}

impl<S, R, ResBody> Service<R> for MyService<S>
where
S: Service<R, Response = http::Response<ResBody>>,
{
type Response = http::Response<MyBody<ResBody>>;
type Error = BoxError;
type Future = MyFuture<S::Future, ResBody>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
unimplemented!()
}

fn call(&mut self, req: R) -> Self::Future {
unimplemented!()
}
}

struct MyFuture<F, B> {
inner: F,
body: B,
}

impl<F, E, B> Future for MyFuture<F, B>
where
F: Future<Output = Result<http::Response<B>, E>>,
{
type Output = Result<http::Response<MyBody<B>>, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
unimplemented!()
}
}

struct MyBody<B> {
inner: B,
}

impl<B> Body for MyBody<B>
where
B: Body,
{
type Data = B::Data;
type Error = BoxError;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
unimplemented!()
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
unimplemented!()
}
}
20 changes: 13 additions & 7 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@ pub fn generate<T: Service>(
#connect

impl<T> #service_ident<T>
where T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::ResponseBody: Body + Send + Sync + 'static,
T::Error: Into<StdError>,
<T::ResponseBody as Body>::Error: Into<StdError> + Send, {
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::ResponseBody: Body + Send + Sync + 'static,
T::Error: Into<StdError>,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
let inner = tonic::client::Grpc::new(inner);
Self { inner }
}

pub fn with_interceptor(inner: T, interceptor: impl Into<tonic::Interceptor>) -> Self {
let inner = tonic::client::Grpc::with_interceptor(inner, interceptor);
Self { inner }
pub fn with_interceptor<F>(inner: T, interceptor: F) -> #service_ident<InterceptedService<T, F>>
where
F: FnMut(tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status>,
T: Service<http::Request<tonic::body::BoxBody>, Response = http::Response<T::ResponseBody>>,
<T as Service<http::Request<tonic::body::BoxBody>>>::Error: Into<StdError> + Send + Sync,
{
#service_ident::new(InterceptedService::new(inner, interceptor))
}

#methods
Expand Down
Loading

0 comments on commit 4d2667d

Please sign in to comment.