Skip to content

Commit

Permalink
feat(core): Default encoding/decoding limits
Browse files Browse the repository at this point in the history
This PR adds new defaults for both client and server max
encoding/decoding message size limits. By default, the max message
decoding size is `4MB` and the max message encoding size is
`usize::MAX`.

This is follow up work from #1274

BREAKING: Default max message encoding/decoding limits
  • Loading branch information
LucioFranco committed Mar 29, 2023
1 parent b3358dc commit e570c90
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 6 deletions.
3 changes: 2 additions & 1 deletion tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-util = "0.3"
prost = "0.11"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tonic = {path = "../../tonic"}
tracing-subscriber = {version = "0.3", features = ["env-filter"]}

[dev-dependencies]
async-stream = "0.3"
Expand All @@ -25,7 +26,7 @@ tokio-stream = {version = "0.1.5", features = ["net"]}
tower = {version = "0.4", features = []}
tower-http = { version = "0.4", features = ["set-header", "trace"] }
tower-service = "0.3"
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
tracing = "0.1"

[build-dependencies]
tonic-build = {path = "../../tonic-build"}
Expand Down
11 changes: 11 additions & 0 deletions tests/integration_tests/proto/test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ service Test {

message Input {}
message Output {}

service Test1 {
rpc UnaryCall(Input1) returns (Output1);
}

message Input1 {
bytes buf = 1;
}
message Output1 {
bytes buf = 1;
}
6 changes: 6 additions & 0 deletions tests/integration_tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ pub mod mock {
}
}
}

pub fn trace_init() {
let _ = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
}
278 changes: 278 additions & 0 deletions tests/integration_tests/tests/max_message_size.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
use integration_tests::{
pb::{test1_client, test1_server, Input1, Output1},
trace_init,
};
use tonic::{
transport::{Endpoint, Server},
Code, Request, Response, Status,
};

#[test]
fn max_message_recv_size() {
trace_init();

// Server recv
assert_server_recv_max_success(128);
// 5 is the size of the gRPC header
assert_server_recv_max_success((4 * 1024 * 1024) - 5);
// 4mb is the max recv size
assert_server_recv_max_failure(4 * 1024 * 1024);
assert_server_recv_max_failure(4 * 1024 * 1024 + 1);
assert_server_recv_max_failure(8 * 1024 * 1024);

// Client recv
assert_client_recv_max_success(128);
// 5 is the size of the gRPC header
assert_client_recv_max_success((4 * 1024 * 1024) - 5);
// 4mb is the max recv size
assert_client_recv_max_failure(4 * 1024 * 1024);
assert_client_recv_max_failure(4 * 1024 * 1024 + 1);
assert_client_recv_max_failure(8 * 1024 * 1024);

// Custom limit settings
assert_test_case(TestCase {
// 5 is the size of the gRPC header
server_blob_size: 1024 - 5,
client_recv_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
server_blob_size: 1024,
client_recv_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});

assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 1024 - 5,
server_recv_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
client_blob_size: 1024,
server_recv_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});
}

#[test]
fn max_message_send_size() {
trace_init();

// Check client send limit works
assert_test_case(TestCase {
client_blob_size: 4 * 1024 * 1024,
server_recv_max: Some(usize::MAX),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 1024 - 5,
server_recv_max: Some(usize::MAX),
client_send_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 4 * 1024 * 1024,
server_recv_max: Some(usize::MAX),
// Set client send limit to 1024
client_send_max: Some(1024),
// TODO: This should return OutOfRange
// https://github.com/hyperium/tonic/issues/1334
expected_code: Some(Code::Internal),
..Default::default()
});

// Check server send limit works
assert_test_case(TestCase {
server_blob_size: 4 * 1024 * 1024,
client_recv_max: Some(usize::MAX),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the gRPC header size
server_blob_size: 1024 - 5,
client_recv_max: Some(usize::MAX),
// Set server send limit to 1024
server_send_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
server_blob_size: 4 * 1024 * 1024,
client_recv_max: Some(usize::MAX),
// Set server send limit to 1024
server_send_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});
}

// Track caller doesn't work on async fn so we extract the async part
// into a sync version and assert the response there using track track_caller
// so that when this does panic it tells us which line in the test failed not
// where we placed the panic call.

#[track_caller]
fn assert_server_recv_max_success(size: usize) {
let case = TestCase {
client_blob_size: size,
server_blob_size: 0,
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_server_recv_max_failure(size: usize) {
let case = TestCase {
client_blob_size: size,
server_blob_size: 0,
expected_code: Some(Code::OutOfRange),
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_client_recv_max_success(size: usize) {
let case = TestCase {
client_blob_size: 0,
server_blob_size: size,
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_client_recv_max_failure(size: usize) {
let case = TestCase {
client_blob_size: 0,
server_blob_size: size,
expected_code: Some(Code::OutOfRange),
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_test_case(case: TestCase) {
let res = max_message_run(&case);

match (case.expected_code, res) {
(Some(_), Ok(())) => panic!("Expected failure, but got success"),
(Some(code), Err(status)) => {
if status.code() != code {
panic!(
"Expected failure, got failure but wrong code, got: {:?}",
status
)
}
}

(None, Err(status)) => panic!("Expected success, but got failure, got: {:?}", status),

_ => (),
}
}

#[derive(Default)]
struct TestCase {
client_blob_size: usize,
server_blob_size: usize,
client_recv_max: Option<usize>,
server_recv_max: Option<usize>,
client_send_max: Option<usize>,
server_send_max: Option<usize>,

expected_code: Option<Code>,
}

#[tokio::main]
async fn max_message_run(case: &TestCase) -> Result<(), Status> {
let client_blob = vec![0; case.client_blob_size];
let server_blob = vec![0; case.server_blob_size];

let (client, server) = tokio::io::duplex(1024);

struct Svc(Vec<u8>);

#[tonic::async_trait]
impl test1_server::Test1 for Svc {
async fn unary_call(&self, _req: Request<Input1>) -> Result<Response<Output1>, Status> {
Ok(Response::new(Output1 {
buf: self.0.clone(),
}))
}
}

let svc = test1_server::Test1Server::new(Svc(server_blob));

let svc = if let Some(size) = case.server_recv_max {
svc.max_decoding_message_size(size)
} else {
svc
};

let svc = if let Some(size) = case.server_send_max {
svc.max_encoding_message_size(size)
} else {
svc
};

tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)]))
.await
.unwrap();
});

// Move client to an option so we can _move_ the inner value
// on the first attempt to connect. All other attempts will fail.
let mut client = Some(client);
let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(tower::service_fn(move |_| {
let client = client.take();

async move {
if let Some(client) = client {
Ok(client)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Client already taken",
))
}
}
}))
.await
.unwrap();

let client = test1_client::Test1Client::new(channel);

let client = if let Some(size) = case.client_recv_max {
client.max_decoding_message_size(size)
} else {
client
};

let mut client = if let Some(size) = case.client_send_max {
client.max_encoding_message_size(size)
} else {
client
};

let req = Request::new(Input1 {
buf: client_blob.clone(),
});

client.unary_call(req).await.map(|_| ())
}
4 changes: 4 additions & 0 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ pub(crate) fn generate_internal<T: Service>(
}

/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}

/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
Expand Down
4 changes: 4 additions & 0 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ pub(crate) fn generate_internal<T: Service>(

let configure_max_message_size_methods = quote! {
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.max_decoding_message_size = Some(limit);
self
}

/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.max_encoding_message_size = Some(limit);
Expand Down
6 changes: 4 additions & 2 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::compression::{decompress, CompressionEncoding};
use super::{DecodeBuf, Decoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE};
use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
use bytes::{Buf, BufMut, BytesMut};
use futures_core::Stream;
Expand Down Expand Up @@ -174,7 +174,9 @@ impl StreamingInner {
};

let len = self.buf.get_u32() as usize;
let limit = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
let limit = self
.max_message_size
.unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
if len > limit {
return Err(Status::new(
Code::OutOfRange,
Expand Down
4 changes: 2 additions & 2 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
use super::{EncodeBuf, Encoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE};
use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
use crate::{Code, Status};
use bytes::{BufMut, Bytes, BytesMut};
use futures_core::{Stream, TryStream};
Expand Down Expand Up @@ -141,7 +141,7 @@ fn finish_encoding(
buf: &mut BytesMut,
) -> Result<Bytes, Status> {
let len = buf.len() - HEADER_SIZE;
let limit = max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
if len > limit {
return Err(Status::new(
Code::OutOfRange,
Expand Down
Loading

0 comments on commit e570c90

Please sign in to comment.