diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index fc33b0d37..a6b07be2f 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -23,6 +23,7 @@ futures = "0.3" tower = { version = "0.4", features = [] } http-body = "0.4" http = "0.2" +tracing-subscriber = "0.2" [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/build.rs b/tests/integration_tests/build.rs index a091e9483..143257fc7 100644 --- a/tests/integration_tests/build.rs +++ b/tests/integration_tests/build.rs @@ -1,3 +1,4 @@ fn main() { tonic_build::compile_protos("proto/test.proto").unwrap(); + tonic_build::compile_protos("proto/stream.proto").unwrap(); } diff --git a/tests/integration_tests/proto/stream.proto b/tests/integration_tests/proto/stream.proto new file mode 100644 index 000000000..ed4672aae --- /dev/null +++ b/tests/integration_tests/proto/stream.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package stream; + +service TestStream { + rpc StreamCall(InputStream) returns (stream OutputStream); +} + +message InputStream {} +message OutputStream {} diff --git a/tests/integration_tests/src/lib.rs b/tests/integration_tests/src/lib.rs index 9b06ad25c..3c7987a0e 100644 --- a/tests/integration_tests/src/lib.rs +++ b/tests/integration_tests/src/lib.rs @@ -1,3 +1,4 @@ pub mod pb { tonic::include_proto!("test"); + tonic::include_proto!("stream"); } diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 6725784d1..ef57d5086 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,6 +1,9 @@ use bytes::Bytes; use futures_util::FutureExt; -use integration_tests::pb::{test_client, test_server, Input, Output}; +use integration_tests::pb::{ + test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, + OutputStream, +}; use std::time::Duration; use tokio::sync::oneshot; use tonic::metadata::{MetadataMap, MetadataValue}; @@ -117,3 +120,61 @@ async fn status_with_metadata() { jh.await.unwrap(); } + +type Stream = std::pin::Pin< + Box> + Send + Sync + 'static>, +>; + +#[tokio::test] +async fn status_from_server_stream() { + trace_init(); + + struct Svc; + + #[tonic::async_trait] + impl test_stream_server::TestStream for Svc { + type StreamCallStream = Stream; + + async fn stream_call( + &self, + _: Request, + ) -> Result, Status> { + let s = futures::stream::iter(vec![ + Err::(Status::unavailable("foo")), + Err::(Status::unavailable("bar")), + ]); + Ok(Response::new(Box::pin(s) as Self::StreamCallStream)) + } + } + + let svc = test_stream_server::TestStreamServer::new(Svc); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve("127.0.0.1:1339".parse().unwrap()) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = test_stream_client::TestStreamClient::connect("http://127.0.0.1:1339") + .await + .unwrap(); + + let mut stream = client + .stream_call(InputStream {}) + .await + .unwrap() + .into_inner(); + + assert_eq!(stream.message().await.unwrap_err().message(), "foo"); + assert_eq!(stream.message().await.unwrap(), None); +} + +fn trace_init() { + let _ = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); +} diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index dc68d7c52..e4a23cfce 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -258,7 +258,11 @@ impl Stream for Streaming { match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { Ok(trailer) => { if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - return Some(Err(e)).into(); + if let Some(e) = e { + return Some(Err(e)).into(); + } else { + return Poll::Ready(None); + } } else { self.trailers = trailer.map(MetadataMap::from_headers); } diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 58f4f99da..a42c19512 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -88,6 +88,7 @@ pub(crate) struct EncodeBody { inner: S, error: Option, role: Role, + is_end_stream: bool, } impl EncodeBody @@ -99,6 +100,7 @@ where inner, error: None, role: Role::Client, + is_end_stream: false, } } @@ -107,6 +109,7 @@ where inner, error: None, role: Role::Server, + is_end_stream: false, } } } @@ -119,7 +122,7 @@ where type Error = Status; fn is_end_stream(&self) -> bool { - false + self.is_end_stream } fn poll_data( @@ -148,7 +151,13 @@ where Role::Client => Poll::Ready(Ok(None)), Role::Server => { let self_proj = self.project(); + + if *self_proj.is_end_stream { + return Poll::Ready(Ok(None)); + } + let status = if let Some(status) = self_proj.error.take() { + *self_proj.is_end_stream = true; status } else { Status::new(Code::Ok, "") diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 674be2c56..e6270232b 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -617,13 +617,13 @@ impl Error for Status {} pub(crate) fn infer_grpc_status( trailers: Option<&HeaderMap>, status_code: http::StatusCode, -) -> Result<(), Status> { +) -> Result<(), Option> { if let Some(trailers) = trailers { if let Some(status) = Status::from_header_map(&trailers) { if status.code() == Code::Ok { return Ok(()); } else { - return Err(status); + return Err(status.into()); } } } @@ -638,6 +638,13 @@ pub(crate) fn infer_grpc_status( | http::StatusCode::BAD_GATEWAY | http::StatusCode::SERVICE_UNAVAILABLE | http::StatusCode::GATEWAY_TIMEOUT => Code::Unavailable, + // We got a 200 but no trailers, we can infer that this request is finished. + // + // This can happen when a streaming response sends two Status but + // gRPC requires that we end the stream after the first status. + // + // https://github.com/hyperium/tonic/issues/681 + http::StatusCode::OK => return Err(None), _ => Code::Unknown, }; @@ -646,7 +653,7 @@ pub(crate) fn infer_grpc_status( status_code.as_u16(), ); let status = Status::new(code, msg); - Err(status) + Err(status.into()) } // ===== impl Code =====