From e299a35a9b54ce4c05ba728ff237c67ca4e5a2d9 Mon Sep 17 00:00:00 2001 From: Jeff Foster Date: Mon, 24 Jun 2024 22:05:21 -0700 Subject: [PATCH] fix(tonic): flush accumulated ready messages when status received #1423 introduced logic to buffer multiple ready messages in order to amortize the cost of sends to the underlying transport. This also introduced a change in behavior for tonic in the following scenario: A stream of ready messages less than the yield threshold is trailed by a status. Previously the ready messages would all have been yielded from the stream and sent, followed by the status. After the change was introduced the status is yielded from the stream and sent but the accumulated ready messages in the buffer are never sent out. This change adjusts the logic to restore the previous behavior while still retaining the amoritization benefits. Namely it flushes the accumulated ready messages prior to yielding the status ensuring they are sent out from the stream in the order they are read. --- tests/integration_tests/tests/status.rs | 49 +++++++++++++++++++++++++ tonic/src/codec/encode.rs | 13 ++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index df6bc4b3b..99e20c695 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -194,3 +194,52 @@ async fn status_from_server_stream_with_source() { let source = error.source().unwrap(); source.downcast_ref::().unwrap(); } + +#[tokio::test] +async fn message_and_then_status_from_server_stream() { + integration_tests::trace_init(); + + struct Svc; + + #[tonic::async_trait] + impl test_stream_server::TestStream for Svc { + type StreamCallStream = Stream; + + async fn stream_call( + &self, + _: Request, + ) -> Result, Status> { + let s = tokio_stream::iter(vec![ + Ok(OutputStream {}), + Err::(Status::unavailable("foo")), + ]); + Ok(Response::new(Box::pin(s) as Self::StreamCallStream)) + } + } + + let svc = test_stream_server::TestStreamServer::new(Svc); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve("127.0.0.1:1340".parse().unwrap()) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = test_stream_client::TestStreamClient::connect("http://127.0.0.1:1340") + .await + .unwrap(); + + let mut stream = client + .stream_call(InputStream {}) + .await + .unwrap() + .into_inner(); + + assert_eq!(stream.message().await.unwrap(), Some(OutputStream {})); + assert_eq!(stream.message().await.unwrap_err().message(), "foo"); + assert_eq!(stream.message().await.unwrap(), None); +} diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 82b4eb61d..0b5de1bda 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -74,6 +74,7 @@ where max_message_size: Option, buf: BytesMut, uncompression_buf: BytesMut, + error: Option, } impl EncodedBytes @@ -112,6 +113,7 @@ where max_message_size, buf, uncompression_buf, + error: None, } } } @@ -131,9 +133,14 @@ where max_message_size, buf, uncompression_buf, + error, } = self.project(); let buffer_settings = encoder.buffer_settings(); + if let Some(status) = error.take() { + return Poll::Ready(Some(Err(status))); + } + loop { match source.as_mut().poll_next(cx) { Poll::Pending if buf.is_empty() => { @@ -163,7 +170,11 @@ where } } Poll::Ready(Some(Err(status))) => { - return Poll::Ready(Some(Err(status))); + if buf.is_empty() { + return Poll::Ready(Some(Err(status))); + } + *error = Some(status); + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } } }