From c309063254dff42fd05afc5e56b0b0371b905758 Mon Sep 17 00:00:00 2001 From: Juan Alvarez Date: Thu, 13 May 2021 12:20:33 -0500 Subject: [PATCH] feat(tonic-web): implement grpc <-> grpc-web protocol translation (#455) tonic-web enables tonic servers to handle requests from grpc-web clients directly, without the need of an external proxy. Co-authored-by: John Hernandez Co-authored-by: zancas --- tonic-web/.gitignore | 6 + tonic-web/Cargo.toml | 6 + tonic-web/interop/.dockerignore | 2 + tonic-web/interop/Cargo.toml | 13 + tonic-web/interop/Dockerfile | 28 + tonic-web/interop/README.md | 27 + tonic-web/interop/client/interop_client.js | 204 +++++++ tonic-web/interop/client/package.json | 19 + tonic-web/interop/client/test.proto | 68 +++ tonic-web/interop/client/test.sh | 8 + tonic-web/interop/src/main.rs | 17 + tonic-web/tests/integration/Cargo.toml | 19 + tonic-web/tests/integration/build.rs | 11 + tonic-web/tests/integration/proto/test.proto | 19 + tonic-web/tests/integration/src/lib.rs | 69 +++ tonic-web/tests/integration/tests/grpc.rs | 167 ++++++ tonic-web/tests/integration/tests/grpc_web.rs | 154 +++++ tonic-web/tonic-web/Cargo.toml | 21 + tonic-web/tonic-web/src/call.rs | 305 ++++++++++ tonic-web/tonic-web/src/config.rs | 168 ++++++ tonic-web/tonic-web/src/cors.rs | 402 +++++++++++++ tonic-web/tonic-web/src/lib.rs | 133 +++++ tonic-web/tonic-web/src/service.rs | 559 ++++++++++++++++++ tonic/src/transport/server/mod.rs | 22 +- 24 files changed, 2445 insertions(+), 2 deletions(-) create mode 100644 tonic-web/.gitignore create mode 100644 tonic-web/Cargo.toml create mode 100644 tonic-web/interop/.dockerignore create mode 100644 tonic-web/interop/Cargo.toml create mode 100644 tonic-web/interop/Dockerfile create mode 100644 tonic-web/interop/README.md create mode 100644 tonic-web/interop/client/interop_client.js create mode 100644 tonic-web/interop/client/package.json create mode 100644 tonic-web/interop/client/test.proto create mode 100755 tonic-web/interop/client/test.sh create mode 100644 tonic-web/interop/src/main.rs create mode 100644 tonic-web/tests/integration/Cargo.toml create mode 100644 tonic-web/tests/integration/build.rs create mode 100644 tonic-web/tests/integration/proto/test.proto create mode 100644 tonic-web/tests/integration/src/lib.rs create mode 100644 tonic-web/tests/integration/tests/grpc.rs create mode 100644 tonic-web/tests/integration/tests/grpc_web.rs create mode 100644 tonic-web/tonic-web/Cargo.toml create mode 100644 tonic-web/tonic-web/src/call.rs create mode 100644 tonic-web/tonic-web/src/config.rs create mode 100644 tonic-web/tonic-web/src/cors.rs create mode 100644 tonic-web/tonic-web/src/lib.rs create mode 100644 tonic-web/tonic-web/src/service.rs diff --git a/tonic-web/.gitignore b/tonic-web/.gitignore new file mode 100644 index 000000000..79d35a04c --- /dev/null +++ b/tonic-web/.gitignore @@ -0,0 +1,6 @@ +node_modules +/target +binary +text +run.sh +package-lock.json \ No newline at end of file diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml new file mode 100644 index 000000000..33fe4632d --- /dev/null +++ b/tonic-web/Cargo.toml @@ -0,0 +1,6 @@ +[workspace] +members = [ + "tonic-web", + "interop", + "tests/integration" +] \ No newline at end of file diff --git a/tonic-web/interop/.dockerignore b/tonic-web/interop/.dockerignore new file mode 100644 index 000000000..e62ad6ce5 --- /dev/null +++ b/tonic-web/interop/.dockerignore @@ -0,0 +1,2 @@ +* +!client/* \ No newline at end of file diff --git a/tonic-web/interop/Cargo.toml b/tonic-web/interop/Cargo.toml new file mode 100644 index 000000000..746b97c05 --- /dev/null +++ b/tonic-web/interop/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "web-interop" +version = "0.1.0" +authors = ["Juan Alvarez "] +publish = false +edition = "2018" + +[dependencies] +interop = { path = "../../interop" } +tonic = { path = "../../tonic" } +tonic-web = { path = "../tonic-web" } +tokio = { version = "1.0.1", features = ["rt-multi-thread", "macros"] } + diff --git a/tonic-web/interop/Dockerfile b/tonic-web/interop/Dockerfile new file mode 100644 index 000000000..640bf440b --- /dev/null +++ b/tonic-web/interop/Dockerfile @@ -0,0 +1,28 @@ +FROM node:12-stretch + +RUN apt-get install -y unzip + +WORKDIR /tmp + +RUN curl -sSL https://github.com/protocolbuffers/protobuf/releases/download/v3.14.0/\ +protoc-3.14.0-linux-x86_64.zip -o protoc.zip && \ + unzip -qq protoc.zip && \ + cp ./bin/protoc /usr/local/bin/protoc + +RUN curl -sSL https://github.com/grpc/grpc-web/releases/download/1.2.1/\ +protoc-gen-grpc-web-1.2.1-linux-x86_64 -o /usr/local/bin/protoc-gen-grpc-web && \ + chmod +x /usr/local/bin/protoc-gen-grpc-web + +WORKDIR / + +COPY ./client ./ + +RUN echo "\nloglevel=error\n" >> $HOME/.npmrc && npm install && mkdir -p binary text + +RUN protoc -I=. ./test.proto\ + --js_out=import_style=commonjs:./text\ + --grpc-web_out=import_style=commonjs,mode=grpcwebtext:./text + +RUN protoc -I=. ./test.proto\ + --js_out=import_style=commonjs:./binary\ + --grpc-web_out=import_style=commonjs,mode=grpcweb:./binary diff --git a/tonic-web/interop/README.md b/tonic-web/interop/README.md new file mode 100644 index 000000000..a225af417 --- /dev/null +++ b/tonic-web/interop/README.md @@ -0,0 +1,27 @@ +## Running interop tests + +Start the server: + +```bash +cd tonic-web/interop +cargo run +``` + +Build the client docker image: + +```bash + cd tonic-web/interop + docker build -t grpcweb-client . +``` + +Run tests on linux: + +```bash +docker run --network=host --rm grpcweb-client /test.sh +``` + +Run tests on docker desktop: + +```bash +docker run --rm grpcweb-client /test.sh host.docker.internal +``` diff --git a/tonic-web/interop/client/interop_client.js b/tonic-web/interop/client/interop_client.js new file mode 100644 index 000000000..c41bcc1b0 --- /dev/null +++ b/tonic-web/interop/client/interop_client.js @@ -0,0 +1,204 @@ +/** + * + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Adapted from https://github.com/grpc/grpc-web/tree/master/test/interop + +global.XMLHttpRequest = require("xhr2"); + +const parseArgs = require('minimist'); +const argv = parseArgs(process.argv, { + string: ['mode', 'host'] +}); + +const SERVER_HOST = `http://${argv.host || "localhost"}:9999`; + +if (argv.mode === 'binary') { + console.log('Testing tonic-web mode (binary)...'); +} else { + console.log('Testing tonic-web mode (text)...'); +} +console.log('Tonic server:', SERVER_HOST); + +const PROTO_PATH = argv.mode === 'binary' ? './binary' : './text'; + +const { + Empty, + SimpleRequest, + StreamingOutputCallRequest, + EchoStatus, + Payload, + ResponseParameters +} = require(`${PROTO_PATH}/test_pb.js`); + +const {TestServiceClient} = require(`${PROTO_PATH}/test_grpc_web_pb.js`); + +const assert = require('assert'); +const grpc = {}; +grpc.web = require('grpc-web'); + +function multiDone(done, count) { + return function () { + count -= 1; + if (count <= 0) { + done(); + } + }; +} + +function doEmptyUnary(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + testService.emptyCall(new Empty(), null, (err, response) => { + assert.ifError(err); + assert(response instanceof Empty); + done(); + }); +} + +function doLargeUnary(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + const req = new SimpleRequest(); + const size = 314159; + + const payload = new Payload(); + payload.setBody('0'.repeat(271828)); + + req.setPayload(payload); + req.setResponseSize(size); + + testService.unaryCall(req, null, (err, response) => { + assert.ifError(err); + assert.equal(response.getPayload().getBody().length, size); + done(); + }); +} + +function doServerStreaming(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + const sizes = [31415, 9, 2653, 58979]; + + const responseParams = sizes.map((size, idx) => { + const param = new ResponseParameters(); + param.setSize(size); + param.setIntervalUs(idx * 10); + return param; + }); + + const req = new StreamingOutputCallRequest(); + req.setResponseParametersList(responseParams); + + const stream = testService.streamingOutputCall(req); + + done = multiDone(done, sizes.length); + let numCallbacks = 0; + stream.on('data', (response) => { + assert.equal(response.getPayload().getBody().length, sizes[numCallbacks]); + numCallbacks++; + done(); + }); +} + +function doCustomMetadata(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + done = multiDone(done, 3); + + const req = new SimpleRequest(); + const size = 314159; + const ECHO_INITIAL_KEY = 'x-grpc-test-echo-initial'; + const ECHO_INITIAL_VALUE = 'test_initial_metadata_value'; + const ECHO_TRAILING_KEY = 'x-grpc-test-echo-trailing-bin'; + const ECHO_TRAILING_VALUE = 0xababab; + + const payload = new Payload(); + payload.setBody('0'.repeat(271828)); + + req.setPayload(payload); + req.setResponseSize(size); + + const call = testService.unaryCall(req, { + [ECHO_INITIAL_KEY]: ECHO_INITIAL_VALUE, + [ECHO_TRAILING_KEY]: ECHO_TRAILING_VALUE + }, (err, response) => { + assert.ifError(err); + assert.equal(response.getPayload().getBody().length, size); + done(); + }); + + call.on('metadata', (metadata) => { + assert(ECHO_INITIAL_KEY in metadata); + assert.equal(metadata[ECHO_INITIAL_KEY], ECHO_INITIAL_VALUE); + done(); + }); + + call.on('status', (status) => { + assert('metadata' in status); + assert(ECHO_TRAILING_KEY in status.metadata); + assert.equal(status.metadata[ECHO_TRAILING_KEY], ECHO_TRAILING_VALUE); + done(); + }); +} + +function doStatusCodeAndMessage(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + const req = new SimpleRequest(); + + const TEST_STATUS_MESSAGE = 'test status message'; + const echoStatus = new EchoStatus(); + echoStatus.setCode(2); + echoStatus.setMessage(TEST_STATUS_MESSAGE); + + req.setResponseStatus(echoStatus); + + testService.unaryCall(req, {}, (err, response) => { + assert(err); + assert('code' in err); + assert('message' in err); + assert.equal(err.code, 2); + assert.equal(err.message, TEST_STATUS_MESSAGE); + done(); + }); +} + +function doUnimplementedMethod(done) { + const testService = new TestServiceClient(SERVER_HOST, null, null); + testService.unimplementedCall(new Empty(), {}, (err, response) => { + assert(err); + assert('code' in err); + assert.equal(err.code, 12); + done(); + }); +} + +const testCases = { + 'empty_unary': {testFunc: doEmptyUnary}, + 'large_unary': {testFunc: doLargeUnary}, + 'server_streaming': { + testFunc: doServerStreaming, + skipBinaryMode: true + }, + 'custom_metadata': {testFunc: doCustomMetadata}, + 'status_code_and_message': {testFunc: doStatusCodeAndMessage}, + 'unimplemented_method': {testFunc: doUnimplementedMethod} +}; + + +describe('tonic-web interop tests', function () { + Object.keys(testCases).forEach((testCase) => { + if (argv.mode === 'binary' && testCases[testCase].skipBinaryMode) return; + it('should pass ' + testCase, testCases[testCase].testFunc); + }); +}); diff --git a/tonic-web/interop/client/package.json b/tonic-web/interop/client/package.json new file mode 100644 index 000000000..515a15d1a --- /dev/null +++ b/tonic-web/interop/client/package.json @@ -0,0 +1,19 @@ +{ + "name": "grpc-web-interop-test", + "version": "0.1.0", + "description": "gRPC-Web Interop Test Client", + "license": "Apache-2.0", + "private": true, + "scripts": { + "test": "mocha -b --timeout 500 ./interop_client.js" + }, + "dependencies": { + "google-protobuf": "~3.14.0", + "grpc-web": "~1.2.1" + }, + "devDependencies": { + "minimist": "~1.2.5", + "mocha": "~7.1.1", + "xhr2": "~0.2.0" + } +} diff --git a/tonic-web/interop/client/test.proto b/tonic-web/interop/client/test.proto new file mode 100644 index 000000000..7df6a8567 --- /dev/null +++ b/tonic-web/interop/client/test.proto @@ -0,0 +1,68 @@ + +// Copyright 2015-2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Adapted from https://github.com/grpc/grpc-web/tree/master/src/proto/grpc/testing + +syntax = "proto3"; + +package grpc.testing; + +service TestService { + rpc EmptyCall(grpc.testing.Empty) returns (grpc.testing.Empty); + rpc UnaryCall(SimpleRequest) returns (SimpleResponse); + rpc StreamingOutputCall(StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); + rpc UnimplementedCall(grpc.testing.Empty) returns (grpc.testing.Empty); +} + +message Empty {} + +message BoolValue { + bool value = 1; +} + +message Payload { + bytes body = 2; +} + +message EchoStatus { + int32 code = 1; + string message = 2; +} + +message SimpleRequest { + int32 response_size = 2; + Payload payload = 3; + EchoStatus response_status = 7; +} + +message SimpleResponse { + Payload payload = 1; +} + +message ResponseParameters { + int32 size = 1; + int32 interval_us = 2; +} + +message StreamingOutputCallRequest { + repeated ResponseParameters response_parameters = 2; + Payload payload = 3; + EchoStatus response_status = 7; +} + +message StreamingOutputCallResponse { + Payload payload = 1; +} diff --git a/tonic-web/interop/client/test.sh b/tonic-web/interop/client/test.sh new file mode 100755 index 000000000..72e7caef9 --- /dev/null +++ b/tonic-web/interop/client/test.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -e + +HOST=${1:-"localhost"} + +npm test -- --host="$HOST" +npm test -- --mode=binary --host="$HOST" \ No newline at end of file diff --git a/tonic-web/interop/src/main.rs b/tonic-web/interop/src/main.rs new file mode 100644 index 000000000..de5279ccf --- /dev/null +++ b/tonic-web/interop/src/main.rs @@ -0,0 +1,17 @@ +use interop::server::{EchoHeadersSvc, TestService, TestServiceServer}; +use tonic::transport::Server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = ([127, 0, 0, 1], 9999).into(); + let test_svc = TestServiceServer::new(TestService::default()); + let with_echo = EchoHeadersSvc::new(test_svc); + + Server::builder() + .accept_http1(true) + .add_service(tonic_web::enable(with_echo)) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/tonic-web/tests/integration/Cargo.toml b/tonic-web/tests/integration/Cargo.toml new file mode 100644 index 000000000..dd6348fcf --- /dev/null +++ b/tonic-web/tests/integration/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "integration" +version = "0.1.0" +authors = ["Juan Alvarez "] +publish = false +edition = "2018" + +[dependencies] +tonic = { path = "../../../tonic" } +tonic-web = { path = "../../tonic-web" } +prost = "0.7" +tokio = { version = "1.0.1", features = ["macros", "rt", "net"] } +base64 = "0.13" +bytes = "1.0" +tokio-stream = { version = "0.1", features = ["net"] } +hyper = "0.14" + +[build-dependencies] +tonic-build = { path = "../../../tonic-build" } diff --git a/tonic-web/tests/integration/build.rs b/tonic-web/tests/integration/build.rs new file mode 100644 index 000000000..88418c2df --- /dev/null +++ b/tonic-web/tests/integration/build.rs @@ -0,0 +1,11 @@ +fn main() { + let protos = &["proto/test.proto"]; + + tonic_build::configure() + .compile(protos, &["proto"]) + .unwrap(); + + protos + .iter() + .for_each(|file| println!("cargo:rerun-if-changed={}", file)); +} diff --git a/tonic-web/tests/integration/proto/test.proto b/tonic-web/tests/integration/proto/test.proto new file mode 100644 index 000000000..76e01fd10 --- /dev/null +++ b/tonic-web/tests/integration/proto/test.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package test; + +service Test { + rpc UnaryCall(Input) returns (Output); + rpc ServerStream(Input) returns (stream Output); + rpc ClientStream(stream Input) returns (Output); +} + +message Input { + int32 id = 1; + string desc = 2; +} + +message Output { + int32 id = 1; + string desc = 2; +} diff --git a/tonic-web/tests/integration/src/lib.rs b/tonic-web/tests/integration/src/lib.rs new file mode 100644 index 000000000..a712dc1f0 --- /dev/null +++ b/tonic-web/tests/integration/src/lib.rs @@ -0,0 +1,69 @@ +use std::pin::Pin; + +use tokio_stream::{self as stream, StreamExt, Stream}; +use tonic::{Request, Response, Status, Streaming}; + +use pb::{test_server::Test, Input, Output}; + +pub mod pb { + tonic::include_proto!("test"); +} + +type BoxStream = Pin> + Send + Sync + 'static>>; + +pub struct Svc; + +#[tonic::async_trait] +impl Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let req = req.into_inner(); + + if &req.desc == "boom" { + Err(Status::invalid_argument("invalid boom")) + } else { + Ok(Response::new(Output { + id: req.id, + desc: req.desc, + })) + } + } + + type ServerStreamStream = BoxStream; + + async fn server_stream( + &self, + req: Request, + ) -> Result, Status> { + let req = req.into_inner(); + + Ok(Response::new(Box::pin(stream::iter(vec![1, 2]).map( + move |n| { + Ok(Output { + id: req.id, + desc: format!("{}-{}", n, req.desc), + }) + }, + )))) + } + + async fn client_stream( + &self, + req: Request>, + ) -> Result, Status> { + let out = Output { + id: 0, + desc: "".into(), + }; + + Ok(Response::new( + req.into_inner() + .fold(out, |mut acc, input| { + let input = input.unwrap(); + acc.id += input.id; + acc.desc += &input.desc; + acc + }) + .await, + )) + } +} diff --git a/tonic-web/tests/integration/tests/grpc.rs b/tonic-web/tests/integration/tests/grpc.rs new file mode 100644 index 000000000..0ab9ad419 --- /dev/null +++ b/tonic-web/tests/integration/tests/grpc.rs @@ -0,0 +1,167 @@ +use std::future::Future; +use std::net::SocketAddr; + +use tokio::net::TcpListener; +use tokio::time::Duration; +use tokio::{join, try_join}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::{self as stream, StreamExt}; +use tonic::transport::{Channel, Error, Server}; +use tonic::{Response, Streaming}; + +use integration::pb::{test_client::TestClient, test_server::TestServer, Input}; +use integration::Svc; + +#[tokio::test] +async fn smoke_unary() { + let (mut c1, mut c2, mut c3, mut c4) = spawn().await.expect("clients"); + + let (r1, r2, r3, r4) = try_join!( + c1.unary_call(input()), + c2.unary_call(input()), + c3.unary_call(input()), + c4.unary_call(input()), + ) + .expect("responses"); + + assert!(meta(&r1) == meta(&r2) && meta(&r2) == meta(&r3) && meta(&r3) == meta(&r4)); + assert!(data(&r1) == data(&r2) && data(&r2) == data(&r3) && data(&r3) == data(&r4)); +} + +#[tokio::test] +async fn smoke_client_stream() { + let (mut c1, mut c2, mut c3, mut c4) = spawn().await.expect("clients"); + + let input_stream = || stream::iter(vec![input(), input()]); + + let (r1, r2, r3, r4) = try_join!( + c1.client_stream(input_stream()), + c2.client_stream(input_stream()), + c3.client_stream(input_stream()), + c4.client_stream(input_stream()), + ) + .expect("responses"); + + assert!(meta(&r1) == meta(&r2) && meta(&r2) == meta(&r3) && meta(&r3) == meta(&r4)); + assert!(data(&r1) == data(&r2) && data(&r2) == data(&r3) && data(&r3) == data(&r4)); +} + +#[tokio::test] +async fn smoke_server_stream() { + let (mut c1, mut c2, mut c3, mut c4) = spawn().await.expect("clients"); + + let (r1, r2, r3, r4) = try_join!( + c1.server_stream(input()), + c2.server_stream(input()), + c3.server_stream(input()), + c4.server_stream(input()), + ) + .expect("responses"); + + assert!(meta(&r1) == meta(&r2) && meta(&r2) == meta(&r3) && meta(&r3) == meta(&r4)); + + let r1 = stream(r1).await; + let r2 = stream(r2).await; + let r3 = stream(r3).await; + let r4 = stream(r4).await; + + assert!(&r1 == &r2 && &r2 == &r3 && &r3 == &r4); +} +#[tokio::test] +async fn smoke_error() { + let (mut c1, mut c2, mut c3, mut c4) = spawn().await.expect("clients"); + + let boom = Input { + id: 1, + desc: "boom".to_owned(), + }; + + let (r1, r2, r3, r4) = join!( + c1.unary_call(boom.clone()), + c2.unary_call(boom.clone()), + c3.unary_call(boom.clone()), + c4.unary_call(boom.clone()), + ); + + let s1 = r1.unwrap_err(); + let s2 = r2.unwrap_err(); + let s3 = r3.unwrap_err(); + let s4 = r4.unwrap_err(); + + assert!(status(&s1) == status(&s2) && status(&s2) == status(&s3) && status(&s3) == status(&s4)) +} + +async fn bind() -> (TcpListener, String) { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let lis = TcpListener::bind(addr).await.expect("listener"); + let url = format!("http://{}", lis.local_addr().unwrap()); + + (lis, url) +} + +async fn grpc(accept_h1: bool) -> (impl Future>, String) { + let (listener, url) = bind().await; + + let fut = Server::builder() + .accept_http1(accept_h1) + .add_service(TestServer::new(Svc)) + .serve_with_incoming(TcpListenerStream::new(listener)); + + (fut, url) +} + +async fn grpc_web(accept_h1: bool) -> (impl Future>, String) { + let (listener, url) = bind().await; + + let svc = tonic_web::config() + .allow_origins(vec!["http://foo.com"]) + .enable(TestServer::new(Svc)); + + let fut = Server::builder() + .accept_http1(accept_h1) + .add_service(svc) + .serve_with_incoming(TcpListenerStream::new(listener)); + + (fut, url) +} + +type Client = TestClient; + +async fn spawn() -> Result<(Client, Client, Client, Client), Error> { + let ((s1, u1), (s2, u2), (s3, u3), (s4, u4)) = + join!(grpc(true), grpc(false), grpc_web(true), grpc_web(false)); + + let _ = tokio::spawn(async move { join!(s1, s2, s3, s4) }); + + tokio::time::sleep(Duration::from_millis(30)).await; + + try_join!( + TestClient::connect(u1), + TestClient::connect(u2), + TestClient::connect(u3), + TestClient::connect(u4) + ) +} + +fn input() -> Input { + Input { + id: 1, + desc: "one".to_owned(), + } +} + +fn meta(r: &Response) -> String { + format!("{:?}", r.metadata()) +} + +fn data(r: &Response) -> &T { + r.get_ref() +} + +async fn stream(r: Response>) -> Vec { + r.into_inner().collect::, _>>().await.unwrap() +} + +fn status(s: &tonic::Status) -> (String, tonic::Code) { + (format!("{:?}", s.metadata()), s.code()) +} diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs new file mode 100644 index 000000000..a833ef426 --- /dev/null +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -0,0 +1,154 @@ +use std::net::SocketAddr; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use hyper::http::{header, StatusCode}; +use hyper::{Body, Client, Method, Request, Uri}; +use prost::Message; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::Server; + +use integration::pb::{test_server::TestServer, Input, Output}; +use integration::Svc; + +#[tokio::test] +async fn binary_request() { + let server_url = spawn("http://example.com").await; + let client = Client::new(); + + let req = build_request(server_url, "grpc-web", "grpc-web"); + let res = client.request(req).await.unwrap(); + let content_type = res.headers().get(header::CONTENT_TYPE).unwrap().clone(); + let content_type = content_type.to_str().unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(content_type, "application/grpc-web+proto"); + + let (message, trailers) = decode_body(res.into_body(), content_type).await; + let expected = Output { + id: 1, + desc: "one".to_owned(), + }; + + assert_eq!(message, expected); + assert_eq!(&trailers[..], b"grpc-status:0\r\n"); +} + +#[tokio::test] +async fn text_request() { + let server_url = spawn("http://example.com").await; + let client = Client::new(); + + let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); + let res = client.request(req).await.unwrap(); + let content_type = res.headers().get(header::CONTENT_TYPE).unwrap().clone(); + let content_type = content_type.to_str().unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(content_type, "application/grpc-web-text+proto"); + + let (message, trailers) = decode_body(res.into_body(), content_type).await; + let expected = Output { + id: 1, + desc: "one".to_owned(), + }; + + assert_eq!(message, expected); + assert_eq!(&trailers[..], b"grpc-status:0\r\n"); +} + +#[tokio::test] +async fn origin_not_allowed() { + let server_url = spawn("http://foo.com").await; + let client = Client::new(); + + let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); + let res = client.request(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); +} + + +async fn spawn(allowed_origin: &str) -> String { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.expect("listener"); + let url = format!("http://{}", listener.local_addr().unwrap()); + let listener_stream = TcpListenerStream::new(listener); + + let svc = tonic_web::config() + .allow_origins(vec![allowed_origin]) + .enable(TestServer::new(Svc)); + + let _ = tokio::spawn(async move { + Server::builder() + .accept_http1(true) + .add_service(svc) + .serve_with_incoming(listener_stream) + .await + .unwrap() + }); + + url +} + +fn encode_body() -> Bytes { + let input = Input { + id: 1, + desc: "one".to_owned(), + }; + + let mut buf = BytesMut::with_capacity(1024); + buf.reserve(5); + unsafe { + buf.advance_mut(5); + } + + input.encode(&mut buf).unwrap(); + + let len = buf.len() - 5; + { + let mut buf = &mut buf[..5]; + buf.put_u8(0); + buf.put_u32(len as u32); + } + + buf.split_to(len + 5).freeze() +} + +fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { + use header::{ACCEPT, CONTENT_TYPE, ORIGIN}; + + let request_uri = format!("{}/{}/{}", base_uri, "test.Test", "UnaryCall") + .parse::() + .unwrap(); + + let bytes = match content_type { + "grpc-web" => encode_body(), + "grpc-web-text" => base64::encode(encode_body()).into(), + _ => panic!("invalid content type {}", content_type), + }; + + Request::builder() + .method(Method::POST) + .header(CONTENT_TYPE, format!("application/{}", content_type)) + .header(ORIGIN, "http://example.com") + .header(ACCEPT, format!("application/{}", accept)) + .uri(request_uri) + .body(Body::from(bytes)) + .unwrap() +} + +async fn decode_body(body: Body, content_type: &str) -> (Output, Bytes) { + let mut body = hyper::body::to_bytes(body).await.unwrap(); + + if content_type == "application/grpc-web-text+proto" { + body = base64::decode(body).unwrap().into() + } + + body.advance(1); + let len = body.get_u32(); + let msg = Output::decode(&mut body.split_to(len as usize)).expect("decode"); + body.advance(5); + + (msg, body) +} diff --git a/tonic-web/tonic-web/Cargo.toml b/tonic-web/tonic-web/Cargo.toml new file mode 100644 index 000000000..72b7eff89 --- /dev/null +++ b/tonic-web/tonic-web/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tonic-web" +version = "0.1.0" +authors = ["Juan Alvarez "] +edition = "2018" + +[dependencies] +tonic = { path = "../../tonic", default-features = false, features = ["transport"] } +http = "0.2" +base64 = "0.13" +futures-core = "0.3" +bytes = "1.0" +hyper = "0.14" +http-body = "0.4" +tower-service = "0.3" +tracing = "0.1" +pin-project = "1" + +[dev-dependencies] +tokio = { version = "1.0.1", features = ["macros", "rt"] } +tonic = { path = "../../tonic", default-features = false, features = ["transport", "tls"] } diff --git a/tonic-web/tonic-web/src/call.rs b/tonic-web/tonic-web/src/call.rs new file mode 100644 index 000000000..26a52d291 --- /dev/null +++ b/tonic-web/tonic-web/src/call.rs @@ -0,0 +1,305 @@ +use std::error::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_core::{ready, Stream}; +use http::{header, HeaderMap, HeaderValue}; +use http_body::{Body, SizeHint}; +use pin_project::pin_project; +use tonic::Status; + +use self::content_types::*; + +pub(crate) mod content_types { + use http::{header::CONTENT_TYPE, HeaderMap}; + + pub(crate) const GRPC_WEB: &str = "application/grpc-web"; + pub(crate) const GRPC_WEB_PROTO: &str = "application/grpc-web+proto"; + pub(crate) const GRPC_WEB_TEXT: &str = "application/grpc-web-text"; + pub(crate) const GRPC_WEB_TEXT_PROTO: &str = "application/grpc-web-text+proto"; + + pub(crate) fn is_grpc_web(headers: &HeaderMap) -> bool { + matches!( + content_type(headers), + Some(GRPC_WEB) | Some(GRPC_WEB_PROTO) | Some(GRPC_WEB_TEXT) | Some(GRPC_WEB_TEXT_PROTO) + ) + } + + fn content_type(headers: &HeaderMap) -> Option<&str> { + headers.get(CONTENT_TYPE).and_then(|val| val.to_str().ok()) + } +} + +const BUFFER_SIZE: usize = 8 * 1024; + +const FRAME_HEADER_SIZE: usize = 5; + +// 8th (MSB) bit of the 1st gRPC frame byte +// denotes an uncompressed trailer (as part of the body) +const GRPC_WEB_TRAILERS_BIT: u8 = 0b10000000; + +#[derive(Copy, Clone, PartialEq, Debug)] +enum Direction { + Request, + Response, +} + +#[derive(Copy, Clone, PartialEq, Debug)] +pub(crate) enum Encoding { + Base64, + None, +} + +#[pin_project] +pub(crate) struct GrpcWebCall { + #[pin] + inner: B, + buf: BytesMut, + direction: Direction, + encoding: Encoding, + poll_trailers: bool, +} + +impl GrpcWebCall { + pub(crate) fn request(inner: B, encoding: Encoding) -> Self { + Self::new(inner, Direction::Request, encoding) + } + + pub(crate) fn response(inner: B, encoding: Encoding) -> Self { + Self::new(inner, Direction::Response, encoding) + } + + fn new(inner: B, direction: Direction, encoding: Encoding) -> Self { + GrpcWebCall { + inner, + buf: BytesMut::with_capacity(match (direction, encoding) { + (Direction::Response, Encoding::Base64) => BUFFER_SIZE, + _ => 0, + }), + direction, + encoding, + poll_trailers: true, + } + } + + // This is to avoid passing a slice of bytes with a length that the base64 + // decoder would consider invalid. + #[inline] + fn max_decodable(&self) -> usize { + (self.buf.len() / 4) * 4 + } + + fn decode_chunk(mut self: Pin<&mut Self>) -> Result, Status> { + // not enough bytes to decode + if self.buf.is_empty() || self.buf.len() < 4 { + return Ok(None); + } + + // Split `buf` at the largest index that is multiple of 4. Decode the + // returned `Bytes`, keeping the rest for the next attempt to decode. + let index = self.max_decodable(); + + base64::decode(self.as_mut().project().buf.split_to(index)) + .map(|decoded| Some(Bytes::from(decoded))) + .map_err(internal_error) + } +} + +impl GrpcWebCall +where + B: Body, + B::Error: Error, +{ + fn poll_decode( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.encoding { + Encoding::Base64 => loop { + if let Some(bytes) = self.as_mut().decode_chunk()? { + return Poll::Ready(Some(Ok(bytes))); + } + + let mut this = self.as_mut().project(); + + match ready!(this.inner.as_mut().poll_data(cx)) { + Some(Ok(data)) => this.buf.put(data), + Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), + None => { + return if this.buf.has_remaining() { + Poll::Ready(Some(Err(internal_error("malformed base64 request")))) + } else { + Poll::Ready(None) + } + } + } + }, + + Encoding::None => match ready!(self.project().inner.poll_data(cx)) { + Some(res) => Poll::Ready(Some(res.map_err(internal_error))), + None => Poll::Ready(None), + }, + } + } + + fn poll_encode( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self.as_mut().project(); + + if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { + if *this.encoding == Encoding::Base64 { + res = res.map(|b| base64::encode(b).into()) + } + + return Poll::Ready(Some(res.map_err(internal_error))); + } + + // this flag is needed because the inner stream never + // returns Poll::Ready(None) when polled for trailers + if *this.poll_trailers { + return match ready!(this.inner.poll_trailers(cx)) { + Ok(Some(map)) => { + let mut frame = make_trailers_frame(map); + + if *this.encoding == Encoding::Base64 { + frame = base64::encode(frame).into_bytes(); + } + + *this.poll_trailers = false; + Poll::Ready(Some(Ok(frame.into()))) + } + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(internal_error(e)))), + }; + } + + Poll::Ready(None) + } +} + +impl Body for GrpcWebCall +where + B: Body, + B::Error: Error, +{ + type Data = Bytes; + type Error = Status; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.direction { + Direction::Request => self.poll_decode(cx), + Direction::Response => self.poll_encode(cx), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>, Self::Error>> { + Poll::Ready(Ok(None)) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} + +impl Stream for GrpcWebCall +where + B: Body, + B::Error: Error, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Body::poll_data(self, cx) + } +} + +impl Encoding { + pub(crate) fn from_content_type(headers: &HeaderMap) -> Encoding { + Self::from_header(headers.get(header::CONTENT_TYPE)) + } + + pub(crate) fn from_accept(headers: &HeaderMap) -> Encoding { + Self::from_header(headers.get(header::ACCEPT)) + } + + pub(crate) fn to_content_type(&self) -> &'static str { + match self { + Encoding::Base64 => GRPC_WEB_TEXT_PROTO, + Encoding::None => GRPC_WEB_PROTO, + } + } + + fn from_header(value: Option<&HeaderValue>) -> Encoding { + match value.and_then(|val| val.to_str().ok()) { + Some(GRPC_WEB_TEXT_PROTO) | Some(GRPC_WEB_TEXT) => Encoding::Base64, + _ => Encoding::None, + } + } +} + +fn internal_error(e: impl std::fmt::Display) -> Status { + Status::internal(format!("tonic-web: {}", e)) +} + +// Key-value pairs encoded as a HTTP/1 headers block (without the terminating newline) +fn encode_trailers(trailers: HeaderMap) -> Vec { + trailers.iter().fold(Vec::new(), |mut acc, (key, value)| { + acc.put_slice(key.as_ref()); + acc.push(b':'); + acc.put_slice(value.as_bytes()); + acc.put_slice(b"\r\n"); + acc + }) +} + +fn make_trailers_frame(trailers: HeaderMap) -> Vec { + let trailers = encode_trailers(trailers); + let len = trailers.len(); + assert!(len <= u32::MAX as usize); + + let mut frame = Vec::with_capacity(len + FRAME_HEADER_SIZE); + frame.push(GRPC_WEB_TRAILERS_BIT); + frame.put_u32(len as u32); + frame.extend(trailers); + + frame +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encoding_constructors() { + let cases = &[ + (GRPC_WEB, Encoding::None), + (GRPC_WEB_PROTO, Encoding::None), + (GRPC_WEB_TEXT, Encoding::Base64), + (GRPC_WEB_TEXT_PROTO, Encoding::Base64), + ("foo", Encoding::None), + ]; + + let mut headers = HeaderMap::new(); + + for case in cases { + headers.insert(header::CONTENT_TYPE, case.0.parse().unwrap()); + headers.insert(header::ACCEPT, case.0.parse().unwrap()); + + assert_eq!(Encoding::from_content_type(&headers), case.1, "{}", case.0); + assert_eq!(Encoding::from_accept(&headers), case.1, "{}", case.0); + } + } +} diff --git a/tonic-web/tonic-web/src/config.rs b/tonic-web/tonic-web/src/config.rs new file mode 100644 index 000000000..e27698f2e --- /dev/null +++ b/tonic-web/tonic-web/src/config.rs @@ -0,0 +1,168 @@ +use std::collections::{BTreeSet, HashSet}; +use std::convert::TryFrom; +use std::time::Duration; + +use http::{header::HeaderName, HeaderValue}; +use tonic::body::BoxBody; +use tonic::transport::NamedService; +use tower_service::Service; + +use crate::service::GrpcWeb; +use crate::BoxError; + +const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); + +const DEFAULT_EXPOSED_HEADERS: [&str; 2] = ["grpc-status", "grpc-message"]; + +/// A Configuration builder for grpc_web services. +/// +/// `Config` can be used to tweak the behavior of tonic_web services. Currently, +/// `Config` instances only expose cors settings. However, since tonic_web is designed to work +/// with grpc-web compliant clients only, some cors options have specific default values and not +/// all settings are configurable. +/// +/// ## Default values and configuration options +/// +/// * `allow-origin`: All origins allowed by default. Configurable, but null and wildcard origins +/// are not supported. +/// * `allow-methods`: `[POST,OPTIONS]`. Not configurable. +/// * `allow-headers`: Set to whatever the `OPTIONS` request carries. Not configurable. +/// * `allow-credentials`: `true`. Configurable. +/// * `max-age`: `86400`. Configurable. +/// * `expose-headers`: `grpc-status,grpc-message`. Configurable but values can only be added. +/// `grpc-status` and `grpc-message` will always be exposed. +#[derive(Debug, Clone)] +pub struct Config { + pub(crate) allowed_origins: AllowedOrigins, + pub(crate) exposed_headers: HashSet, + pub(crate) max_age: Option, + pub(crate) allow_credentials: bool, +} + +#[derive(Debug, Clone)] +pub(crate) enum AllowedOrigins { + Any, + #[allow(clippy::mutable_key_type)] + Only(BTreeSet), +} + +impl AllowedOrigins { + pub(crate) fn is_allowed(&self, origin: &HeaderValue) -> bool { + match self { + AllowedOrigins::Any => true, + AllowedOrigins::Only(origins) => origins.contains(origin), + } + } +} + +impl Config { + pub(crate) fn new() -> Config { + Config { + allowed_origins: AllowedOrigins::Any, + exposed_headers: DEFAULT_EXPOSED_HEADERS + .iter() + .cloned() + .map(HeaderName::from_static) + .collect(), + max_age: Some(DEFAULT_MAX_AGE), + allow_credentials: true, + } + } + + /// Allow any origin to access this resource. + /// + /// This is the default value. + pub fn allow_all_origins(self) -> Config { + Self { + allowed_origins: AllowedOrigins::Any, + ..self + } + } + + /// Only allow a specific set of origins to access this resource. + /// + /// ## Example + /// + /// ``` + /// tonic_web::config().allow_origins(vec!["http://a.com", "http://b.com"]); + /// ``` + pub fn allow_origins(self, origins: I) -> Config + where + I: IntoIterator, + HeaderValue: TryFrom, + { + // false positive when using HeaderValue, which uses Bytes internally + // https://rust-lang.github.io/rust-clippy/master/index.html#mutable_key_type + #[allow(clippy::mutable_key_type)] + let origins = origins + .into_iter() + .map(|v| match TryFrom::try_from(v) { + Ok(uri) => uri, + Err(_) => panic!("invalid origin"), + }) + .collect(); + + Self { + allowed_origins: AllowedOrigins::Only(origins), + ..self + } + } + + /// Adds multiple headers to the list of exposed headers. + /// + /// Default: `grpc-status,grpc-message`. These will always be included. + pub fn expose_headers(mut self, headers: I) -> Config + where + I: IntoIterator, + HeaderName: TryFrom, + { + let iter = headers + .into_iter() + .map(|header| match TryFrom::try_from(header) { + Ok(header) => header, + Err(_) => panic!("invalid header"), + }); + + self.exposed_headers.extend(iter); + self + } + + /// Defines the maximum cache lifetime for operations allowed on this + /// resource. + /// + /// Default: "86400" (24 hours) + pub fn max_age>>(self, max_age: T) -> Config { + Self { + max_age: max_age.into(), + ..self + } + } + + /// If true, the `access-control-allow-credentials` will be sent. + /// + /// Default: true + pub fn allow_credentials(self, allow_credentials: bool) -> Config { + Self { + allow_credentials, + ..self + } + } + + /// enable a tonic service to handle grpc-web requests with this configuration values. + pub fn enable(&self, service: S) -> GrpcWeb + where + S: Service, Response = http::Response>, + S: NamedService + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + tracing::trace!("enabled for {}", S::NAME); + GrpcWeb::new(service, self.clone()) + } +} + +impl Default for Config { + fn default() -> Self { + Config::new() + } +} diff --git a/tonic-web/tonic-web/src/cors.rs b/tonic-web/tonic-web/src/cors.rs new file mode 100644 index 000000000..20b9650b2 --- /dev/null +++ b/tonic-web/tonic-web/src/cors.rs @@ -0,0 +1,402 @@ +use std::sync::Arc; + +pub(crate) use http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS as ALLOW_CREDENTIALS; +pub(crate) use http::header::ACCESS_CONTROL_ALLOW_HEADERS as ALLOW_HEADERS; +pub(crate) use http::header::ACCESS_CONTROL_ALLOW_METHODS as ALLOW_METHODS; +pub(crate) use http::header::ACCESS_CONTROL_ALLOW_ORIGIN as ALLOW_ORIGIN; +pub(crate) use http::header::ACCESS_CONTROL_EXPOSE_HEADERS as EXPOSE_HEADERS; +pub(crate) use http::header::ACCESS_CONTROL_MAX_AGE as MAX_AGE; +pub(crate) use http::header::ACCESS_CONTROL_REQUEST_HEADERS as REQUEST_HEADERS; +pub(crate) use http::header::ACCESS_CONTROL_REQUEST_METHOD as REQUEST_METHOD; +pub(crate) use http::header::ORIGIN; +use http::{header, HeaderMap, HeaderValue, Method}; +use tracing::debug; + +use crate::config::Config; + +const DEFAULT_ALLOWED_METHODS: &[Method; 2] = &[Method::POST, Method::OPTIONS]; + +#[derive(Debug, Clone)] +pub(crate) struct Cors { + cache: Arc, +} + +#[derive(Debug, PartialEq)] +pub(crate) enum Error { + OriginNotAllowed, + MethodNotAllowed, +} + +#[derive(Clone, Debug)] +struct Cache { + config: Config, + expose_headers: HeaderValue, + allow_methods: HeaderValue, + allow_credentials: HeaderValue, +} + +impl Cors { + pub(crate) fn new(config: Config) -> Cors { + let expose_headers = join_header_value(&config.exposed_headers).unwrap(); + let allow_methods = HeaderValue::from_static("POST,OPTIONS"); + let allow_credentials = HeaderValue::from_static("true"); + + let cache = Arc::new(Cache { + config, + expose_headers, + allow_methods, + allow_credentials, + }); + + Cors { cache } + } + + fn is_method_allowed(&self, header: Option<&HeaderValue>) -> bool { + match header { + Some(value) => match Method::from_bytes(value.as_bytes()) { + Ok(method) => DEFAULT_ALLOWED_METHODS.contains(&method), + Err(_) => { + debug!("access-control-request-method {:?} is not valid", value); + false + } + }, + None => { + debug!("access-control-request-method is missing"); + false + } + } + } + + pub(crate) fn preflight( + &self, + req_headers: &HeaderMap, + origin: &HeaderValue, + request_headers_header: &HeaderValue, + ) -> Result { + if !self.is_origin_allowed(origin) { + return Err(Error::OriginNotAllowed); + } + + if !self.is_method_allowed(req_headers.get(REQUEST_METHOD)) { + return Err(Error::MethodNotAllowed); + } + + let mut headers = self.common_headers(origin.clone()); + headers.insert(ALLOW_METHODS, self.cache.allow_methods.clone()); + headers.insert(ALLOW_HEADERS, request_headers_header.clone()); + + if let Some(max_age) = self.cache.config.max_age { + headers.insert(MAX_AGE, HeaderValue::from(max_age.as_secs())); + } + + Ok(headers) + } + + pub(crate) fn simple(&self, headers: &HeaderMap) -> Result { + match headers.get(header::ORIGIN) { + Some(origin) if self.is_origin_allowed(origin) => { + Ok(self.common_headers(origin.clone())) + } + Some(_) => Err(Error::OriginNotAllowed), + None => Ok(HeaderMap::new()), + } + } + + fn common_headers(&self, origin: HeaderValue) -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ALLOW_ORIGIN, origin); + headers.insert(EXPOSE_HEADERS, self.cache.expose_headers.clone()); + + if self.cache.config.allow_credentials { + headers.insert(ALLOW_CREDENTIALS, self.cache.allow_credentials.clone()); + } + + headers + } + + fn is_origin_allowed(&self, origin: &HeaderValue) -> bool { + self.cache.config.allowed_origins.is_allowed(origin) + } + + #[cfg(test)] + pub(crate) fn __check_preflight(&self, headers: &HeaderMap) -> Result { + self.preflight( + headers, + headers.get(ORIGIN).unwrap(), + headers.get(REQUEST_HEADERS).unwrap(), + ) + } +} + +#[cfg(test)] +impl Default for Cors { + fn default() -> Self { + Cors::new(Config::default()) + } +} + +fn join_header_value(values: I) -> Result +where + I: IntoIterator, + I::Item: AsRef, +{ + let mut values = values.into_iter(); + let mut value = Vec::new(); + + if let Some(v) = values.next() { + value.extend(v.as_ref().as_bytes()); + } + for v in values { + value.push(b','); + value.extend(v.as_ref().as_bytes()); + } + HeaderValue::from_bytes(&value) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_value_eq { + ($header:expr, $expected:expr) => { + fn sorted(value: &str) -> Vec<&str> { + let mut vec = value.split(",").collect::>(); + vec.sort(); + vec + } + + assert_eq!(sorted($header.to_str().unwrap()), sorted($expected)) + }; + } + + fn value(s: &str) -> HeaderValue { + s.parse().unwrap() + } + + impl From for Cors { + fn from(c: Config) -> Self { + Cors::new(c) + } + } + + #[test] + #[should_panic] + #[ignore] + fn origin_is_valid_url() { + Config::new().allow_origins(vec!["foo"]); + } + + mod preflight { + use super::*; + + fn preflight_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, value("http://example.com")); + headers.insert(REQUEST_METHOD, value("POST")); + headers.insert(REQUEST_HEADERS, value("x-grpc-web")); + headers + } + + #[test] + fn default_config() { + let cors = Cors::default(); + let headers = cors.__check_preflight(&preflight_headers()).unwrap(); + + assert_eq!(headers[ALLOW_ORIGIN], "http://example.com"); + assert_eq!(headers[ALLOW_METHODS], "POST,OPTIONS"); + assert_eq!(headers[ALLOW_HEADERS], "x-grpc-web"); + assert_eq!(headers[ALLOW_CREDENTIALS], "true"); + assert_eq!(headers[MAX_AGE], "86400"); + assert_value_eq!(&headers[EXPOSE_HEADERS], "grpc-status,grpc-message"); + } + + #[test] + fn any_origin() { + let cors: Cors = Config::new().allow_all_origins().into(); + + assert!(cors.__check_preflight(&preflight_headers()).is_ok()); + } + + #[test] + fn origin_list() { + let cors: Cors = Config::new() + .allow_origins(vec![ + HeaderValue::from_static("http://a.com"), + HeaderValue::from_static("http://b.com"), + ]) + .into(); + + let mut req_headers = preflight_headers(); + req_headers.insert(ORIGIN, value("http://b.com")); + + assert!(cors.__check_preflight(&req_headers).is_ok()); + } + + #[test] + fn origin_not_allowed() { + let cors: Cors = Config::new().allow_origins(vec!["http://a.com"]).into(); + + let err = cors.__check_preflight(&preflight_headers()).unwrap_err(); + + assert_eq!(err, Error::OriginNotAllowed) + } + + #[test] + fn disallow_credentials() { + let cors = Cors::new(Config::new().allow_credentials(false)); + let headers = cors.__check_preflight(&preflight_headers()).unwrap(); + + assert!(!headers.contains_key(ALLOW_CREDENTIALS)); + } + + #[test] + fn expose_headers_are_merged() { + let cors = Cors::new(Config::new().expose_headers(vec!["x-request-id"])); + let headers = cors.__check_preflight(&preflight_headers()).unwrap(); + + assert_value_eq!( + &headers[EXPOSE_HEADERS], + "x-request-id,grpc-message,grpc-status" + ); + } + + #[test] + fn allow_headers_echo_request_headers() { + let cors = Cors::default(); + let mut request_headers = preflight_headers(); + request_headers.insert(REQUEST_HEADERS, value("x-grpc-web,foo,x-request-id")); + + let headers = cors.__check_preflight(&request_headers).unwrap(); + + assert_value_eq!(&headers[ALLOW_HEADERS], "x-grpc-web,foo,x-request-id"); + } + + #[test] + fn missing_request_method() { + let cors = Cors::default(); + let mut request_headers = preflight_headers(); + request_headers.remove(REQUEST_METHOD); + + let err = cors.__check_preflight(&request_headers).unwrap_err(); + + assert_eq!(err, Error::MethodNotAllowed); + } + + #[test] + fn only_options_and_post_allowed() { + let cors = Cors::default(); + + for method in &[ + Method::GET, + Method::DELETE, + Method::TRACE, + Method::PATCH, + Method::PUT, + Method::HEAD, + ] { + let mut request_headers = preflight_headers(); + request_headers.insert(REQUEST_METHOD, value(method.as_str())); + + assert_eq!( + cors.__check_preflight(&request_headers).unwrap_err(), + Error::MethodNotAllowed, + ) + } + } + + #[test] + fn custom_max_age() { + use std::time::Duration; + + let cors = Cors::new(Config::new().max_age(Duration::from_secs(99))); + let headers = cors.__check_preflight(&preflight_headers()).unwrap(); + + assert_eq!(headers[MAX_AGE], "99"); + } + + #[test] + fn no_max_age() { + let cors = Cors::new(Config::new().max_age(None)); + let headers = cors.__check_preflight(&preflight_headers()).unwrap(); + + assert!(!headers.contains_key(MAX_AGE)); + } + } + + mod simple { + use super::*; + + fn request_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, value("http://example.com")); + headers + } + + #[test] + fn default_config() { + let cors = Cors::default(); + let headers = cors.simple(&request_headers()).unwrap(); + + assert_eq!(headers[ALLOW_ORIGIN], "http://example.com"); + assert_eq!(headers[ALLOW_CREDENTIALS], "true"); + assert_value_eq!(&headers[EXPOSE_HEADERS], "grpc-message,grpc-status"); + + assert!(!headers.contains_key(ALLOW_HEADERS)); + assert!(!headers.contains_key(ALLOW_METHODS)); + assert!(!headers.contains_key(MAX_AGE)); + } + + #[test] + fn any_origin() { + let cors: Cors = Config::new().allow_all_origins().into(); + + assert!(cors.simple(&request_headers()).is_ok()); + } + + #[test] + fn origin_list() { + let cors: Cors = Config::new() + .allow_origins(vec![ + HeaderValue::from_static("http://a.com"), + HeaderValue::from_static("http://b.com"), + ]) + .into(); + + let mut req_headers = request_headers(); + req_headers.insert(ORIGIN, value("http://b.com")); + + assert!(cors.simple(&req_headers).is_ok()); + } + + #[test] + fn origin_not_allowed() { + let cors: Cors = Config::new().allow_origins(vec!["http://a.com"]).into(); + + let err = cors.simple(&request_headers()).unwrap_err(); + + assert_eq!(err, Error::OriginNotAllowed) + } + + #[test] + fn disallow_credentials() { + let cors = Cors::new(Config::new().allow_credentials(false)); + let headers = cors.simple(&request_headers()).unwrap(); + + assert!(!headers.contains_key(ALLOW_CREDENTIALS)); + } + + #[test] + fn expose_headers_are_merged() { + let cors: Cors = Config::new() + .expose_headers(vec!["x-hello", "custom-1"]) + .into(); + + let headers = cors.simple(&request_headers()).unwrap(); + + assert_value_eq!( + &headers[EXPOSE_HEADERS], + "grpc-message,grpc-status,x-hello,custom-1" + ); + } + } +} diff --git a/tonic-web/tonic-web/src/lib.rs b/tonic-web/tonic-web/src/lib.rs new file mode 100644 index 000000000..7e98fa18c --- /dev/null +++ b/tonic-web/tonic-web/src/lib.rs @@ -0,0 +1,133 @@ +//! grpc-web protocol translation for [`tonic`] services. +//! +//! [`tonic_web`] enables tonic servers to handle requests from [grpc-web] clients directly, +//! without the need of an external proxy. It achieves this by wrapping individual tonic services +//! with a [tower] service that performs the translation between protocols and handles `cors` +//! requests. +//! +//! ## Getting Started +//! +//! ```toml +//! [dependencies] +//! tonic_web = "0.1" +//! ``` +//! +//! ## Enabling tonic services +//! +//! The easiest way to get started, is to call the [`enable`] function with your tonic service +//! and allow the tonic server to accept HTTP/1.1 requests: +//! +//! ```ignore +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! let addr = "[::1]:50051".parse().unwrap(); +//! let greeter = GreeterServer::new(MyGreeter::default()); +//! +//! Server::builder() +//! .accept_http1(true) +//! .add_service(tonic_web::enable(greeter)) +//! .serve(addr) +//! .await?; +//! +//! Ok(()) +//! } +//! +//! ``` +//! This will apply a default configuration that works well with grpc-web clients out of the box. +//! See the [`Config`] documentation for details. +//! +//! Alternatively, if you have a tls enabled server, you could skip setting `accept_http1` to `true`. +//! This works because the browser will handle `ALPN`. +//! +//! ```ignore +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! let cert = tokio::fs::read("server.pem").await?; +//! let key = tokio::fs::read("server.key").await?; +//! let identity = Identity::from_pem(cert, key); +//! +//! let addr = "[::1]:50051".parse().unwrap(); +//! let greeter = GreeterServer::new(MyGreeter::default()); +//! +//! // No need to enable HTTP/1 +//! Server::builder() +//! .tls_config(ServerTlsConfig::new().identity(identity))? +//! .add_service(tonic_web::enable(greeter)) +//! .serve(addr) +//! .await?; +//! +//! Ok(()) +//! } +//! ``` +//! +//! ## Limitations +//! +//! * `tonic_web` is designed to work with grpc-web-compliant clients only. It is not expected to +//! handle arbitrary HTTP/x.x requests or bespoke protocols. +//! * Similarly, the cors support implemented by this crate will *only* handle grpc-web and +//! grpc-web preflight requests. +//! * Currently, grpc-web clients can only perform `unary` and `server-streaming` calls. These +//! are the only requests this crate is designed to handle. Support for client and bi-directional +//! streaming will be officially supported when clients do. +//! * There is no support for web socket transports. +//! +//! +//! [`tonic`]: https://github.com/hyperium/tonic +//! [`tonic_web`]: https://github.com/hyperium/tonic +//! [grpc-web]: https://github.com/grpc/grpc-web +//! [tower]: https://github.com/tower-rs/tower +//! [`enable`]: crate::enable() +//! [`Config`]: crate::Config +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] + +pub use config::Config; + +mod call; +mod config; +mod cors; +mod service; + +use crate::service::GrpcWeb; +use std::future::Future; +use std::pin::Pin; +use tonic::body::BoxBody; +use tonic::transport::NamedService; +use tower_service::Service; + +/// enable a tonic service to handle grpc-web requests with the default configuration. +/// +/// Shortcut for `tonic_web::config().enable(service)` +pub fn enable(service: S) -> GrpcWeb +where + S: Service, Response = http::Response>, + S: NamedService + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, +{ + config().enable(service) +} + +/// returns a default [`Config`] instance for configuring services. +/// +/// ## Example +/// +/// ``` +/// let config = tonic_web::config() +/// .allow_origins(vec!["http://foo.com"]) +/// .allow_credentials(false) +/// .expose_headers(vec!["x-request-id"]); +/// +/// // let greeter = config.enable(Greeter); +/// // let route_guide = config.enable(RouteGuide); +/// ``` +pub fn config() -> Config { + Config::default() +} + +type BoxError = Box; +type BoxFuture = Pin> + Send>>; diff --git a/tonic-web/tonic-web/src/service.rs b/tonic-web/tonic-web/src/service.rs new file mode 100644 index 000000000..df7a14988 --- /dev/null +++ b/tonic-web/tonic-web/src/service.rs @@ -0,0 +1,559 @@ +use std::task::{Context, Poll}; + +use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; +use hyper::Body; +use tonic::body::BoxBody; +use tonic::transport::NamedService; +use tower_service::Service; +use tracing::{debug, trace}; + +use crate::call::content_types::is_grpc_web; +use crate::call::{Encoding, GrpcWebCall}; +use crate::cors::Cors; +use crate::cors::{ORIGIN, REQUEST_HEADERS}; +use crate::{BoxError, BoxFuture, Config}; + +const GRPC: &str = "application/grpc"; + +#[derive(Debug, Clone)] +pub struct GrpcWeb { + inner: S, + cors: Cors, +} + +#[derive(Debug, PartialEq)] +enum RequestKind<'a> { + // The request is considered a grpc-web request if its `content-type` + // header is exactly one of: + // + // - "application/grpc-web" + // - "application/grpc-web+proto" + // - "application/grpc-web-text" + // - "application/grpc-web-text+proto" + GrpcWeb { + method: &'a Method, + encoding: Encoding, + accept: Encoding, + }, + // The request is considered a grpc-web preflight request if all these + // conditions are met: + // + // - the request method is `OPTIONS` + // - request headers include `origin` + // - `access-control-request-headers` header is present and includes `x-grpc-web` + GrpcWebPreflight { + origin: &'a HeaderValue, + request_headers: &'a HeaderValue, + }, + // All other requests, including `application/grpc` + Other(http::Version), +} + +impl GrpcWeb { + pub(crate) fn new(inner: S, config: Config) -> Self { + GrpcWeb { + inner, + cors: Cors::new(config), + } + } +} + +impl GrpcWeb +where + S: Service, Response = Response> + Send + 'static, +{ + fn no_content(&self, headers: HeaderMap) -> BoxFuture { + let mut res = Response::builder() + .status(StatusCode::NO_CONTENT) + .body(BoxBody::empty()) + .unwrap(); + + res.headers_mut().extend(headers); + + Box::pin(async { Ok(res) }) + } + + fn response(&self, status: StatusCode) -> BoxFuture { + Box::pin(async move { + Ok(Response::builder() + .status(status) + .body(BoxBody::empty()) + .unwrap()) + }) + } +} + +impl Service> for GrpcWeb +where + S: Service, Response = Response> + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + match RequestKind::new(req.headers(), req.method(), req.version()) { + // A valid grpc-web request, regardless of HTTP version. + // + // If the request includes an `origin` header, we verify it is allowed + // to access the resource, an HTTP 403 response is returned otherwise. + // + // If the origin is allowed to access the resource or there is no + // `origin` header present, translate the request into a grpc request, + // call the inner service, and translate the response back to + // grpc-web. + RequestKind::GrpcWeb { + method: &Method::POST, + encoding, + accept, + } => match self.cors.simple(req.headers()) { + Ok(headers) => { + trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept); + + let fut = self.inner.call(coerce_request(req, encoding)); + + Box::pin(async move { + let mut res = coerce_response(fut.await?, accept); + res.headers_mut().extend(headers); + Ok(res) + }) + } + Err(e) => { + debug!(kind = "simple", error=?e, ?req); + self.response(StatusCode::FORBIDDEN) + } + }, + + // The request's content-type matches one of the 4 supported grpc-web + // content-types, but the request method is not `POST`. + // This is not a valid grpc-web request, return HTTP 405. + RequestKind::GrpcWeb { .. } => { + debug!(kind = "simple", error="method not allowed", method = ?req.method()); + self.response(StatusCode::METHOD_NOT_ALLOWED) + } + + // A valid grpc-web preflight request, regardless of HTTP version. + // This is handled by the cors module. + RequestKind::GrpcWebPreflight { + origin, + request_headers, + } => match self.cors.preflight(req.headers(), origin, request_headers) { + Ok(headers) => { + trace!(kind = "preflight", path = ?req.uri().path(), ?origin); + self.no_content(headers) + } + Err(e) => { + debug!(kind = "preflight", error = ?e, ?req); + self.response(StatusCode::FORBIDDEN) + } + }, + + // All http/2 requests that are not grpc-web or grpc-web preflight + // are passed through to the inner service, whatever they are. + RequestKind::Other(Version::HTTP_2) => { + debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); + Box::pin(self.inner.call(req)) + } + + // Return HTTP 400 for all other requests. + RequestKind::Other(_) => { + debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE)); + self.response(StatusCode::BAD_REQUEST) + } + } + } +} + +impl NamedService for GrpcWeb { + const NAME: &'static str = S::NAME; +} + +impl<'a> RequestKind<'a> { + fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self { + if is_grpc_web(headers) { + return RequestKind::GrpcWeb { + method, + encoding: Encoding::from_content_type(headers), + accept: Encoding::from_accept(headers), + }; + } + + if let (&Method::OPTIONS, Some(origin), Some(value)) = + (method, headers.get(ORIGIN), headers.get(REQUEST_HEADERS)) + { + match value.to_str() { + Ok(h) if h.contains("x-grpc-web") => { + return RequestKind::GrpcWebPreflight { + origin, + request_headers: value, + }; + } + _ => {} + } + } + + RequestKind::Other(version) + } +} + +// Mutating request headers to conform to a gRPC request is not really +// necessary for us at this point. We could remove most of these except +// maybe for inserting `header::TE`, which tonic should check? +fn coerce_request(mut req: Request, encoding: Encoding) -> Request { + req.headers_mut().remove(header::CONTENT_LENGTH); + + req.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static(GRPC)); + + req.headers_mut() + .insert(header::TE, HeaderValue::from_static("trailers")); + + req.headers_mut().insert( + header::ACCEPT_ENCODING, + HeaderValue::from_static("identity,deflate,gzip"), + ); + + req.map(|b| GrpcWebCall::request(b, encoding)) + .map(Body::wrap_stream) +} + +fn coerce_response(res: Response, encoding: Encoding) -> Response { + let mut res = res + .map(|b| GrpcWebCall::response(b, encoding)) + .map(BoxBody::new); + + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(encoding.to_content_type()), + ); + + res +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::call::content_types::*; + use http::header::{CONTENT_TYPE, ORIGIN}; + + #[derive(Clone)] + struct Svc; + + impl tower_service::Service> for Svc { + type Response = Response; + type Error = String; + type Future = BoxFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Request) -> Self::Future { + Box::pin(async { Ok(Response::new(BoxBody::empty())) }) + } + } + + impl NamedService for Svc { + const NAME: &'static str = "test"; + } + + mod grpc_web { + use super::*; + use http::HeaderValue; + + fn request() -> Request { + Request::builder() + .method(Method::POST) + .header(CONTENT_TYPE, GRPC_WEB) + .header(ORIGIN, "http://example.com") + .body(Body::empty()) + .unwrap() + } + + #[tokio::test] + async fn default_cors_config() { + let mut svc = crate::enable(Svc); + let res = svc.call(request()).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn without_origin() { + let mut svc = crate::enable(Svc); + + let mut req = request(); + req.headers_mut().remove(ORIGIN); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn origin_not_allowed() { + let mut svc = crate::config() + .allow_origins(vec!["http://localhost"]) + .enable(Svc); + + let res = svc.call(request()).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN) + } + + #[tokio::test] + async fn only_post_allowed() { + let mut svc = crate::enable(Svc); + + for method in &[ + Method::GET, + Method::PUT, + Method::DELETE, + Method::HEAD, + Method::OPTIONS, + Method::PATCH, + ] { + let mut req = request(); + *req.method_mut() = method.clone(); + + let res = svc.call(req).await.unwrap(); + + assert_eq!( + res.status(), + StatusCode::METHOD_NOT_ALLOWED, + "{} should not be allowed", + method + ); + } + } + + #[tokio::test] + async fn grpc_web_content_types() { + let mut svc = crate::enable(Svc); + + for ct in &[GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_PROTO, GRPC_WEB] { + let mut req = request(); + req.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static(ct)); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + } + } + + mod options { + use super::*; + use crate::cors::{REQUEST_HEADERS, REQUEST_METHOD}; + use http::HeaderValue; + + const SUCCESS: StatusCode = StatusCode::NO_CONTENT; + + fn request() -> Request { + Request::builder() + .method(Method::OPTIONS) + .header(ORIGIN, "http://example.com") + .header(REQUEST_HEADERS, "x-grpc-web") + .header(REQUEST_METHOD, "POST") + .body(Body::empty()) + .unwrap() + } + + #[tokio::test] + async fn origin_not_allowed() { + let mut svc = crate::config() + .allow_origins(vec!["http://foo.com"]) + .enable(Svc); + + let res = svc.call(request()).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn missing_request_method() { + let mut svc = crate::enable(Svc); + + let mut req = request(); + req.headers_mut().remove(REQUEST_METHOD); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn only_post_and_options_allowed() { + let mut svc = crate::enable(Svc); + + for method in &[ + Method::GET, + Method::PUT, + Method::DELETE, + Method::HEAD, + Method::PATCH, + ] { + let mut req = request(); + req.headers_mut().insert( + REQUEST_METHOD, + HeaderValue::from_maybe_shared(method.to_string()).unwrap(), + ); + + let res = svc.call(req).await.unwrap(); + + assert_eq!( + res.status(), + StatusCode::FORBIDDEN, + "{} should not be allowed", + method + ); + } + } + + #[tokio::test] + async fn h1_missing_origin_is_err() { + let mut svc = crate::enable(Svc); + let mut req = request(); + req.headers_mut().remove(ORIGIN); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn h2_missing_origin_is_ok() { + let mut svc = crate::enable(Svc); + + let mut req = request(); + *req.version_mut() = Version::HTTP_2; + req.headers_mut().remove(ORIGIN); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn h1_missing_x_grpc_web_header_is_err() { + let mut svc = crate::enable(Svc); + + let mut req = request(); + req.headers_mut().remove(REQUEST_HEADERS); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn h2_missing_x_grpc_web_header_is_ok() { + let mut svc = crate::enable(Svc); + + let mut req = request(); + *req.version_mut() = Version::HTTP_2; + req.headers_mut().remove(REQUEST_HEADERS); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_grpc_web_preflight() { + let mut svc = crate::enable(Svc); + let res = svc.call(request()).await.unwrap(); + + assert_eq!(res.status(), SUCCESS); + } + } + + mod grpc { + use super::*; + use http::HeaderValue; + + fn request() -> Request { + Request::builder() + .version(Version::HTTP_2) + .header(CONTENT_TYPE, GRPC) + .body(Body::empty()) + .unwrap() + } + + #[tokio::test] + async fn h2_is_ok() { + let mut svc = crate::enable(Svc); + + let req = request(); + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK) + } + + #[tokio::test] + async fn h1_is_err() { + let mut svc = crate::enable(Svc); + + let req = Request::builder() + .header(CONTENT_TYPE, GRPC) + .body(Body::empty()) + .unwrap(); + + let res = svc.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST) + } + + #[tokio::test] + async fn content_type_variants() { + let mut svc = crate::enable(Svc); + + for variant in &["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] { + let mut req = request(); + req.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_maybe_shared(format!("application/{}", variant)).unwrap(), + ); + + let res = svc.call(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK) + } + } + } + + mod other { + use super::*; + + fn request() -> Request { + Request::builder() + .header(CONTENT_TYPE, "application/text") + .body(Body::empty()) + .unwrap() + } + + #[tokio::test] + async fn h1_is_err() { + let mut svc = crate::enable(Svc); + let res = svc.call(request()).await.unwrap(); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST) + } + + #[tokio::test] + async fn h2_is_ok() { + let mut svc = crate::enable(Svc); + let mut req = request(); + *req.version_mut() = Version::HTTP_2; + + let res = svc.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK) + } + } +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 35cddf615..7417532ac 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -72,6 +72,7 @@ pub struct Server { http2_keepalive_interval: Option, http2_keepalive_timeout: Option, max_frame_size: Option, + accept_http1: bool, } /// A stack based `Service` router. @@ -137,6 +138,7 @@ impl Server { pub fn builder() -> Self { Server { tcp_nodelay: true, + accept_http1: false, ..Default::default() } } @@ -287,7 +289,22 @@ impl Server { } } - /// Intercept inbound requests and add a [`tracing::Span`] to each response future. + /// Allow this server to accept http1 requests. + /// + /// Accepting http1 requests is only useful when developing `grpc-web` + /// enabled services. If this setting is set to `true` but services are + /// not correctly configured to handle grpc-web requests, your server may + /// return confusing (but correct) protocol errors. + /// + /// Default is `false`. + pub fn accept_http1(self, accept_http1: bool) -> Self { + Server { + accept_http1, + ..self + } + } + + /// Intercept inbound headers and add a [`tracing::Span`] to each response future. pub fn trace_fn(self, f: F) -> Self where F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static, @@ -365,6 +382,7 @@ impl Server { let max_concurrent_streams = self.max_concurrent_streams; let timeout = self.timeout; let max_frame_size = self.max_frame_size; + let http2_only = !self.accept_http1; let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self @@ -382,7 +400,7 @@ impl Server { }; let server = hyper::Server::builder(incoming) - .http2_only(true) + .http2_only(http2_only) .http2_initial_connection_window_size(init_connection_window_size) .http2_initial_stream_window_size(init_stream_window_size) .http2_max_concurrent_streams(max_concurrent_streams)