Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sniff the first byte to glean if the incoming request is a single or batch request #419

Merged
merged 6 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 34 additions & 35 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ use jsonrpsee_types::{
v2::{
error::JsonRpcErrorCode,
params::Id,
request::{JsonRpcInvalidRequest, JsonRpcNotification, JsonRpcRequest},
request::{JsonRpcNotification, JsonRpcRequest},
},
TEN_MB_SIZE_BYTES,
};
use jsonrpsee_utils::hyper_helpers::read_response_to_body;
use jsonrpsee_utils::server::{
helpers::{collect_batch_response, send_error},
helpers::{collect_batch_response, prepare_error, send_error},
rpc_module::Methods,
};

Expand Down Expand Up @@ -215,43 +215,42 @@ impl Server {
let (tx, mut rx) = mpsc::unbounded::<String>();
// Is this a single request or a batch (or error)?
let mut single = true;

// For reasons outlined [here](https://github.com/serde-rs/json/issues/497), `RawValue` can't be
// used with untagged enums at the moment. This means we can't use an `SingleOrBatch` untagged
// enum here and have to try each case individually: first the single request case, then the
// batch case and lastly the error. For the worst case – unparseable input – we make three calls
// to [`serde_json::from_slice`] which is pretty annoying.
// Our [issue](https://github.com/paritytech/jsonrpsee/issues/296).
if let Ok(req) = serde_json::from_slice::<JsonRpcRequest>(&body) {
// NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here.
methods.execute(&tx, req, 0).await;
} else if let Ok(_req) = serde_json::from_slice::<JsonRpcNotification<Option<&RawValue>>>(&body)
{
return Ok::<_, HyperError>(response::ok_response("".into()));
} else if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcRequest>>(&body) {
if !batch.is_empty() {
single = false;
for req in batch {
type Notif<'a> = JsonRpcNotification<'a, Option<&'a RawValue>>;
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
match body.get(0) {
// Single request or notification
Some(b'{') => {
if let Ok(req) = serde_json::from_slice::<JsonRpcRequest>(&body) {
// NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here.
methods.execute(&tx, req, 0).await;
} else if let Ok(_req) = serde_json::from_slice::<Notif>(&body) {
return Ok::<_, HyperError>(response::ok_response("".into()));
} else {
let (id, code) = prepare_error(&body);
send_error(id, &tx, code.into());
}
} else {
send_error(Id::Null, &tx, JsonRpcErrorCode::InvalidRequest.into());
}
} else if let Ok(_batch) =
serde_json::from_slice::<Vec<JsonRpcNotification<Option<&RawValue>>>>(&body)
{
return Ok::<_, HyperError>(response::ok_response("".into()));
} else {
log::error!(
"[service_fn], Cannot parse request body={:?}",
String::from_utf8_lossy(&body[..cmp::min(body.len(), 1024)])
);
let (id, code) = match serde_json::from_slice::<JsonRpcInvalidRequest>(&body) {
Ok(req) => (req.id, JsonRpcErrorCode::InvalidRequest),
Err(_) => (Id::Null, JsonRpcErrorCode::ParseError),
};
send_error(id, &tx, code.into());
// Bacth of requests or notifications
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Bacth of requests or notifications
// Batch of requests or notifications

Some(b'[') => {
if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcRequest>>(&body) {
if !batch.is_empty() {
single = false;
for req in batch {
methods.execute(&tx, req, 0).await;
}
} else {
send_error(Id::Null, &tx, JsonRpcErrorCode::InvalidRequest.into());
}
} else if let Ok(_batch) = serde_json::from_slice::<Vec<Notif>>(&body) {
return Ok::<_, HyperError>(response::ok_response("".into()));
} else {
let (id, code) = prepare_error(&body);
send_error(id, &tx, code.into());
}
}
// Garbage request
_ => send_error(Id::Null, &tx, JsonRpcErrorCode::ParseError.into()),
}

// Closes the receiving half of a channel without dropping it. This prevents any further
// messages from being sent on the channel.
rx.close();
Expand Down
46 changes: 46 additions & 0 deletions http-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,52 @@ async fn invalid_batched_method_calls() {
assert_eq!(response.body, parse_error(Id::Null));
}

#[tokio::test]
async fn garbage_request_fails() {
let addr = server().await;
let uri = to_http_uri(addr);

let req = r#"dsdfs fsdsfds"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"{ "#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#" {"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"{}"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"{sds}"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"["#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"[dsds]"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#" [{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}]"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));

let req = r#"[]"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, invalid_request(Id::Null));

let req = r#"[{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.body, parse_error(Id::Null));
}

#[tokio::test]
async fn should_return_method_not_found() {
let addr = server().with_default_timeout().await.unwrap();
Expand Down
10 changes: 10 additions & 0 deletions utils/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use jsonrpsee_types::v2::error::{JsonRpcError, JsonRpcErrorCode, JsonRpcErrorObject};
use jsonrpsee_types::v2::params::{Id, TwoPointZero};
use jsonrpsee_types::v2::request::JsonRpcInvalidRequest;
use jsonrpsee_types::v2::response::JsonRpcResponse;
use serde::Serialize;

Expand Down Expand Up @@ -38,6 +39,15 @@ pub fn send_error(id: Id, tx: &MethodSink, error: JsonRpcErrorObject) {
}
}

/// Figure out if this is a sufficiently complete request that we can extract an [`Id`] out of, or just plain
/// unparseable garbage.
pub fn prepare_error(data: &[u8]) -> (Id<'_>, JsonRpcErrorCode) {
match serde_json::from_slice::<JsonRpcInvalidRequest>(&data) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match serde_json::from_slice::<JsonRpcInvalidRequest>(&data) {
match serde_json::from_slice::<JsonRpcInvalidRequest>(data) {

Ok(JsonRpcInvalidRequest { id }) => (id, JsonRpcErrorCode::InvalidRequest),
Err(_) => (Id::Null, JsonRpcErrorCode::ParseError),
}
}

/// Read all the results of all method calls in a batch request from the ['Stream']. Format the result into a single
/// `String` appropriately wrapped in `[`/`]`.
pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver<String>) -> String {
Expand Down
74 changes: 36 additions & 38 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ use std::task::{Context, Poll};
use std::{net::SocketAddr, sync::Arc};

use crate::types::{
error::Error,
v2::error::JsonRpcErrorCode,
v2::params::Id,
v2::request::{JsonRpcInvalidRequest, JsonRpcRequest},
TEN_MB_SIZE_BYTES,
error::Error, v2::error::JsonRpcErrorCode, v2::params::Id, v2::request::JsonRpcRequest, TEN_MB_SIZE_BYTES,
};
use futures_channel::mpsc;
use futures_util::future::{join_all, FutureExt};
Expand All @@ -50,7 +46,7 @@ use tokio::{
};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

use jsonrpsee_utils::server::helpers::{collect_batch_response, send_error};
use jsonrpsee_utils::server::helpers::{collect_batch_response, prepare_error, send_error};
use jsonrpsee_utils::server::rpc_module::{ConnectionId, Methods};

/// Default maximum connections allowed.
Expand Down Expand Up @@ -275,40 +271,42 @@ async fn background_task(
continue;
}

// For reasons outlined [here](https://github.com/serde-rs/json/issues/497), `RawValue` can't be used with
// untagged enums at the moment. This means we can't use an `SingleOrBatch` untagged enum here and have to try
// each case individually: first the single request case, then the batch case and lastly the error. For the
// worst case – unparseable input – we make three calls to [`serde_json::from_slice`] which is pretty annoying.
// Our [issue](https://github.com/paritytech/jsonrpsee/issues/296).
if let Ok(req) = serde_json::from_slice::<JsonRpcRequest>(&data) {
log::debug!("recv: {:?}", req);
methods.execute(&tx, req, conn_id).await;
} else if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcRequest>>(&data) {
if !batch.is_empty() {
// Batch responses must be sent back as a single message so we read the results from each request in the
// batch and read the results off of a new channel, `rx_batch`, and then send the complete batch response
// back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();

join_all(batch.into_iter().map(|req| methods.execute(&tx_batch, req, conn_id))).await;

// Closes the receiving half of a channel without dropping it. This prevents any further messages from
// being sent on the channel.
rx_batch.close();
let results = collect_batch_response(rx_batch).await;
if let Err(err) = tx.unbounded_send(results) {
log::error!("Error sending batch response to the client: {:?}", err)
match data.get(0) {
Some(b'{') => {
if let Ok(req) = serde_json::from_slice::<JsonRpcRequest>(&data) {
log::debug!("recv: {:?}", req);
methods.execute(&tx, req, conn_id).await;
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
}
} else {
send_error(Id::Null, &tx, JsonRpcErrorCode::InvalidRequest.into());
}
} else {
let (id, code) = match serde_json::from_slice::<JsonRpcInvalidRequest>(&data) {
Ok(req) => (req.id, JsonRpcErrorCode::InvalidRequest),
Err(_) => (Id::Null, JsonRpcErrorCode::ParseError),
};

send_error(id, &tx, code.into());
Some(b'[') => {
if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcRequest>>(&data) {
if !batch.is_empty() {
// Batch responses must be sent back as a single message so we read the results from each request in the
// batch and read the results off of a new channel, `rx_batch`, and then send the complete batch response
// back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();

join_all(batch.into_iter().map(|req| methods.execute(&tx_batch, req, conn_id))).await;

// Closes the receiving half of a channel without dropping it. This prevents any further messages from
// being sent on the channel.
rx_batch.close();
let results = collect_batch_response(rx_batch).await;
if let Err(err) = tx.unbounded_send(results) {
log::error!("Error sending batch response to the client: {:?}", err)
}
} else {
send_error(Id::Null, &tx, JsonRpcErrorCode::InvalidRequest.into());
}
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
}
}
_ => send_error(Id::Null, &tx, JsonRpcErrorCode::ParseError.into()),
}
}
Ok(())
Expand Down
46 changes: 46 additions & 0 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,52 @@ async fn batch_method_call_where_some_calls_fail() {
);
}

#[tokio::test]
async fn garbage_request_fails() {
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).await.unwrap();

let req = r#"dsdfs fsdsfds"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"{ "#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#" {"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"{}"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"{sds}"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"["#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"[dsds]"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#" [{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}]"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));

let req = r#"[]"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, invalid_request(Id::Null));

let req = r#"[{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));
}

#[tokio::test]
async fn single_method_call_with_params_works() {
let addr = server().await;
Expand Down