Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Default encoding/decoding limits #1335

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-util = "0.3"
prost = "0.11"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tonic = {path = "../../tonic"}
tracing-subscriber = {version = "0.3", features = ["env-filter"]}

[dev-dependencies]
async-stream = "0.3"
Expand All @@ -25,7 +26,7 @@ tokio-stream = {version = "0.1.5", features = ["net"]}
tower = {version = "0.4", features = []}
tower-http = { version = "0.4", features = ["set-header", "trace"] }
tower-service = "0.3"
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
tracing = "0.1"

[build-dependencies]
tonic-build = {path = "../../tonic-build"}
Expand Down
11 changes: 11 additions & 0 deletions tests/integration_tests/proto/test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ service Test {

message Input {}
message Output {}

service Test1 {
rpc UnaryCall(Input1) returns (Output1);
}

message Input1 {
bytes buf = 1;
}
message Output1 {
bytes buf = 1;
}
6 changes: 6 additions & 0 deletions tests/integration_tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ pub mod mock {
}
}
}

pub fn trace_init() {
let _ = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
}
278 changes: 278 additions & 0 deletions tests/integration_tests/tests/max_message_size.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
use integration_tests::{
pb::{test1_client, test1_server, Input1, Output1},
trace_init,
};
use tonic::{
transport::{Endpoint, Server},
Code, Request, Response, Status,
};

#[test]
fn max_message_recv_size() {
trace_init();

// Server recv
assert_server_recv_max_success(128);
// 5 is the size of the gRPC header
assert_server_recv_max_success((4 * 1024 * 1024) - 5);
// 4mb is the max recv size
assert_server_recv_max_failure(4 * 1024 * 1024);
assert_server_recv_max_failure(4 * 1024 * 1024 + 1);
assert_server_recv_max_failure(8 * 1024 * 1024);

// Client recv
assert_client_recv_max_success(128);
// 5 is the size of the gRPC header
assert_client_recv_max_success((4 * 1024 * 1024) - 5);
// 4mb is the max recv size
assert_client_recv_max_failure(4 * 1024 * 1024);
assert_client_recv_max_failure(4 * 1024 * 1024 + 1);
assert_client_recv_max_failure(8 * 1024 * 1024);

// Custom limit settings
assert_test_case(TestCase {
// 5 is the size of the gRPC header
server_blob_size: 1024 - 5,
client_recv_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
server_blob_size: 1024,
client_recv_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});

assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 1024 - 5,
server_recv_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
client_blob_size: 1024,
server_recv_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});
}

#[test]
fn max_message_send_size() {
trace_init();

// Check client send limit works
assert_test_case(TestCase {
client_blob_size: 4 * 1024 * 1024,
server_recv_max: Some(usize::MAX),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 1024 - 5,
server_recv_max: Some(usize::MAX),
client_send_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the size of the gRPC header
client_blob_size: 4 * 1024 * 1024,
server_recv_max: Some(usize::MAX),
// Set client send limit to 1024
client_send_max: Some(1024),
// TODO: This should return OutOfRange
// https://github.com/hyperium/tonic/issues/1334
expected_code: Some(Code::Internal),
..Default::default()
});

// Check server send limit works
assert_test_case(TestCase {
server_blob_size: 4 * 1024 * 1024,
client_recv_max: Some(usize::MAX),
..Default::default()
});
assert_test_case(TestCase {
// 5 is the gRPC header size
server_blob_size: 1024 - 5,
client_recv_max: Some(usize::MAX),
// Set server send limit to 1024
server_send_max: Some(1024),
..Default::default()
});
assert_test_case(TestCase {
server_blob_size: 4 * 1024 * 1024,
client_recv_max: Some(usize::MAX),
// Set server send limit to 1024
server_send_max: Some(1024),
expected_code: Some(Code::OutOfRange),
..Default::default()
});
}

// Track caller doesn't work on async fn so we extract the async part
// into a sync version and assert the response there using track track_caller
// so that when this does panic it tells us which line in the test failed not
// where we placed the panic call.

#[track_caller]
fn assert_server_recv_max_success(size: usize) {
let case = TestCase {
client_blob_size: size,
server_blob_size: 0,
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_server_recv_max_failure(size: usize) {
let case = TestCase {
client_blob_size: size,
server_blob_size: 0,
expected_code: Some(Code::OutOfRange),
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_client_recv_max_success(size: usize) {
let case = TestCase {
client_blob_size: 0,
server_blob_size: size,
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_client_recv_max_failure(size: usize) {
let case = TestCase {
client_blob_size: 0,
server_blob_size: size,
expected_code: Some(Code::OutOfRange),
..Default::default()
};

assert_test_case(case);
}

#[track_caller]
fn assert_test_case(case: TestCase) {
let res = max_message_run(&case);

match (case.expected_code, res) {
(Some(_), Ok(())) => panic!("Expected failure, but got success"),
(Some(code), Err(status)) => {
if status.code() != code {
panic!(
"Expected failure, got failure but wrong code, got: {:?}",
status
)
}
}

(None, Err(status)) => panic!("Expected success, but got failure, got: {:?}", status),

_ => (),
}
}

#[derive(Default)]
struct TestCase {
client_blob_size: usize,
server_blob_size: usize,
client_recv_max: Option<usize>,
server_recv_max: Option<usize>,
client_send_max: Option<usize>,
server_send_max: Option<usize>,

expected_code: Option<Code>,
}

#[tokio::main]
async fn max_message_run(case: &TestCase) -> Result<(), Status> {
let client_blob = vec![0; case.client_blob_size];
let server_blob = vec![0; case.server_blob_size];

let (client, server) = tokio::io::duplex(1024);

struct Svc(Vec<u8>);

#[tonic::async_trait]
impl test1_server::Test1 for Svc {
async fn unary_call(&self, _req: Request<Input1>) -> Result<Response<Output1>, Status> {
Ok(Response::new(Output1 {
buf: self.0.clone(),
}))
}
}

let svc = test1_server::Test1Server::new(Svc(server_blob));

let svc = if let Some(size) = case.server_recv_max {
svc.max_decoding_message_size(size)
} else {
svc
};

let svc = if let Some(size) = case.server_send_max {
svc.max_encoding_message_size(size)
} else {
svc
};

tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)]))
.await
.unwrap();
});

// Move client to an option so we can _move_ the inner value
// on the first attempt to connect. All other attempts will fail.
let mut client = Some(client);
let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(tower::service_fn(move |_| {
let client = client.take();

async move {
if let Some(client) = client {
Ok(client)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Client already taken",
))
}
}
}))
.await
.unwrap();

let client = test1_client::Test1Client::new(channel);

let client = if let Some(size) = case.client_recv_max {
client.max_decoding_message_size(size)
} else {
client
};

let mut client = if let Some(size) = case.client_send_max {
client.max_encoding_message_size(size)
} else {
client
};

let req = Request::new(Input1 {
buf: client_blob.clone(),
});

client.unary_call(req).await.map(|_| ())
}
4 changes: 4 additions & 0 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ pub(crate) fn generate_internal<T: Service>(
}

/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}

/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
Expand Down
4 changes: 4 additions & 0 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ pub(crate) fn generate_internal<T: Service>(

let configure_max_message_size_methods = quote! {
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.max_decoding_message_size = Some(limit);
self
}

/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.max_encoding_message_size = Some(limit);
Expand Down
8 changes: 8 additions & 0 deletions tonic-health/src/generated/grpc.health.v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,16 @@ pub mod health_client {
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
Expand Down Expand Up @@ -282,12 +286,16 @@ pub mod health_server {
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.max_decoding_message_size = Some(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.max_encoding_message_size = Some(limit);
Expand Down
Loading