diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 07544ab9b..7ed68da07 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -2,7 +2,7 @@ use super::compression::{compress, CompressionEncoding, SingleMessageCompression use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; -use futures_util::{ready, StreamExt, TryStream, TryStreamExt}; +use futures_util::{ready, stream::Fuse, StreamExt, TryStreamExt}; use http::HeaderMap; use http_body::Body; use pin_project::pin_project; @@ -13,6 +13,7 @@ use std::{ use tokio_stream::Stream; pub(super) const BUFFER_SIZE: usize = 8 * 1024; +const YIELD_THRESHOLD: usize = 32 * 1024; pub(crate) fn encode_server( encoder: T, @@ -25,15 +26,13 @@ where T: Encoder, U: Stream>, { - let stream = encode( + let stream = EncodedBytes::new( encoder, source, compression_encoding, compression_override, max_message_size, - ) - .into_stream(); - + ); EncodeBody::new_server(stream) } @@ -47,55 +46,125 @@ where T: Encoder, U: Stream, { - let stream = encode( + let stream = EncodedBytes::new( encoder, source.map(Ok), compression_encoding, SingleMessageCompressionOverride::default(), max_message_size, - ) - .into_stream(); + ); EncodeBody::new_client(stream) } -fn encode( - mut encoder: T, - source: U, +/// Combinator for efficient encoding of messages into reasonably sized buffers. +/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut, +/// splitting off and yielding a buffer when either: +/// * The delegate stream polls as not ready, or +/// * The encoded buffer surpasses YIELD_THRESHOLD. +#[pin_project(project = EncodedBytesProj)] +#[derive(Debug)] +pub(crate) struct EncodedBytes +where + T: Encoder, + U: Stream>, +{ + #[pin] + source: Fuse, + encoder: T, compression_encoding: Option, - compression_override: SingleMessageCompressionOverride, max_message_size: Option, -) -> impl TryStream + buf: BytesMut, + uncompression_buf: BytesMut, +} + +impl EncodedBytes where T: Encoder, U: Stream>, { - let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + fn new( + encoder: T, + source: U, + compression_encoding: Option, + compression_override: SingleMessageCompressionOverride, + max_message_size: Option, + ) -> Self { + let buf = BytesMut::with_capacity(BUFFER_SIZE); - let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable - { - None - } else { - compression_encoding - }; + let compression_encoding = + if compression_override == SingleMessageCompressionOverride::Disable { + None + } else { + compression_encoding + }; - let mut uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) - } else { - BytesMut::new() - }; + let uncompression_buf = if compression_encoding.is_some() { + BytesMut::with_capacity(BUFFER_SIZE) + } else { + BytesMut::new() + }; + + return EncodedBytes { + source: source.fuse(), + encoder, + compression_encoding, + max_message_size, + buf, + uncompression_buf, + }; + } +} - source.map(move |result| { - let item = result?; +impl Stream for EncodedBytes +where + T: Encoder, + U: Stream>, +{ + type Item = Result; - encode_item( - &mut encoder, - &mut buf, - &mut uncompression_buf, + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let EncodedBytesProj { + mut source, + encoder, compression_encoding, max_message_size, - item, - ) - }) + buf, + uncompression_buf, + } = self.project(); + + loop { + match source.as_mut().poll_next(cx) { + Poll::Pending if buf.is_empty() => { + return Poll::Pending; + } + Poll::Ready(None) if buf.is_empty() => { + return Poll::Ready(None); + } + Poll::Pending | Poll::Ready(None) => { + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + } + Poll::Ready(Some(Ok(item))) => { + if let Err(status) = encode_item( + encoder, + buf, + uncompression_buf, + *compression_encoding, + *max_message_size, + item, + ) { + return Poll::Ready(Some(Err(status))); + } + + if buf.len() >= YIELD_THRESHOLD { + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + } + } + Poll::Ready(Some(Err(status))) => { + return Poll::Ready(Some(Err(status))); + } + } + } + } } fn encode_item( @@ -105,10 +174,12 @@ fn encode_item( compression_encoding: Option, max_message_size: Option, item: T::Item, -) -> Result +) -> Result<(), Status> where T: Encoder, { + let offset = buf.len(); + buf.reserve(HEADER_SIZE); unsafe { buf.advance_mut(HEADER_SIZE); @@ -132,14 +203,14 @@ where } // now that we know length, we can write the header - finish_encoding(compression_encoding, max_message_size, buf) + finish_encoding(compression_encoding, max_message_size, &mut buf[offset..]) } fn finish_encoding( compression_encoding: Option, max_message_size: Option, - buf: &mut BytesMut, -) -> Result { + buf: &mut [u8], +) -> Result<(), Status> { let len = buf.len() - HEADER_SIZE; let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE); if len > limit { @@ -163,7 +234,7 @@ fn finish_encoding( buf.put_u32(len as u32); } - Ok(buf.split_to(len + HEADER_SIZE).freeze()) + Ok(()) } #[derive(Debug)]