Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[servers] return error if context or params fails #295

Merged
merged 9 commits into from
May 4, 2021
Merged
10 changes: 5 additions & 5 deletions http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ impl HttpClientBuilder {

/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport = HttpTransportClient::new(target, self.max_request_body_size)
.map_err(|e| Error::TransportError(Box::new(e)))?;
let transport =
HttpTransportClient::new(target, self.max_request_body_size).map_err(|e| Error::Transport(Box::new(e)))?;
Ok(HttpClient { transport, request_id: AtomicU64::new(0) })
}
}
Expand All @@ -55,7 +55,7 @@ impl Client for HttpClient {
self.transport
.send(serde_json::to_string(&notif).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))
.map_err(|e| Error::Transport(Box::new(e)))
}

/// Perform a request towards the server.
Expand All @@ -71,7 +71,7 @@ impl Client for HttpClient {
.transport
.send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))?;
.map_err(|e| Error::Transport(Box::new(e)))?;

let response: JsonRpcResponse<_> = match serde_json::from_slice(&body) {
Ok(response) => response,
Expand Down Expand Up @@ -110,7 +110,7 @@ impl Client for HttpClient {
.transport
.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))?;
.map_err(|e| Error::Transport(Box::new(e)))?;

let rps: Vec<JsonRpcResponse<_>> = match serde_json::from_slice(&body) {
Ok(response) => response,
Expand Down
36 changes: 25 additions & 11 deletions http-server/src/module.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use jsonrpsee_types::{traits::RpcMethod, v2::params::RpcParams, Error};
use jsonrpsee_utils::server::{send_response, Methods};
use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
error::{CallError, Error, InvalidParams},
traits::RpcMethod,
v2::params::RpcParams,
};
use jsonrpsee_utils::server::{send_error, send_response, Methods};
use serde::Serialize;
use std::sync::Arc;

Expand Down Expand Up @@ -31,16 +36,17 @@ impl RpcModule {
pub fn register_method<F, R>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
R: Serialize,
F: RpcMethod<R>,
F: RpcMethod<R, InvalidParams>,
{
self.verify_method_name(method_name)?;

self.methods.insert(
method_name,
Box::new(move |id, params, tx, _| {
let result = callback(params)?;

send_response(id, tx, result);
match callback(params) {
Ok(res) => send_response(id, tx, res),
Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
};

Ok(())
}),
Expand Down Expand Up @@ -82,7 +88,7 @@ impl<Context> RpcContextModule<Context> {
where
Context: Send + Sync + 'static,
R: Serialize,
F: Fn(RpcParams, &Context) -> Result<R, Error> + Send + Sync + 'static,
F: Fn(RpcParams, &Context) -> Result<R, CallError> + Send + Sync + 'static,
{
self.module.verify_method_name(method_name)?;

Expand All @@ -91,10 +97,18 @@ impl<Context> RpcContextModule<Context> {
self.module.methods.insert(
method_name,
Box::new(move |id, params, tx, _| {
let result = callback(params, &*ctx)?;

send_response(id, tx, result);

match callback(params, &*ctx) {
Ok(res) => send_response(id, tx, res),
Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
Err(CallError::Failed(err)) => {
let err = JsonRpcErrorObject {
code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE),
message: &err.to_string(),
data: None,
};
send_error(id, tx, err)
}
};
Ok(())
}),
);
Expand Down
4 changes: 2 additions & 2 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use hyper::{
service::{make_service_fn, service_fn},
Error as HyperError,
};
use jsonrpsee_types::error::{Error, GenericTransportError};
use jsonrpsee_types::error::{Error, GenericTransportError, InvalidParams};
use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcRequest};
use jsonrpsee_types::v2::{error::JsonRpcErrorCode, params::RpcParams};
use jsonrpsee_utils::{
Expand Down Expand Up @@ -129,7 +129,7 @@ impl Server {
pub fn register_method<F, R>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
R: Serialize,
F: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static,
F: Fn(RpcParams) -> Result<R, InvalidParams> + Send + Sync + 'static,
{
self.root.register_method(method_name, callback)
}
Expand Down
69 changes: 65 additions & 4 deletions http-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

use std::net::SocketAddr;

use crate::HttpServerBuilder;
use crate::{HttpServerBuilder, RpcContextModule};
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, StatusCode};
use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext};
use jsonrpsee_types::error::CallError;
use serde_json::Value as JsonValue;

async fn server() -> SocketAddr {
Expand All @@ -23,6 +24,35 @@ async fn server() -> SocketAddr {
addr
}

/// Run server with user provided context.
pub async fn server_with_context() -> SocketAddr {
let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap();

let ctx = TestContext;
let mut rpc_ctx = RpcContextModule::new(ctx);

rpc_ctx
.register_method("should_err", |_p, ctx| {
let _ = ctx.err().map_err(|e| CallError::Failed(e.into()))?;
Ok("err")
})
.unwrap();

rpc_ctx
.register_method("should_ok", |_p, ctx| {
let _ = ctx.ok().map_err(|e| CallError::Failed(e.into()))?;
Ok("ok")
})
.unwrap();

let rpc_module = rpc_ctx.into_module();
server.register_module(rpc_module).unwrap();
let addr = server.local_addr().unwrap();

tokio::spawn(async { server.start().await });
addr
}

#[tokio::test]
async fn single_method_call_works() {
let _ = env_logger::try_init();
Expand Down Expand Up @@ -54,14 +84,45 @@ async fn single_method_call_with_params() {
let addr = server().await;
let uri = to_http_uri(addr);

std::thread::sleep(std::time::Duration::from_secs(2));

let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1)));
}

#[tokio::test]
async fn single_method_call_with_faulty_params_returns_err() {
let addr = server().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"add", "params":["Invalid"],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, invalid_params(Id::Num(1)));
}

niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
#[tokio::test]
async fn single_method_call_with_faulty_context() {
let addr = server_with_context().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"should_err", "params":[],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, invalid_context("RPC context failed", Id::Num(1)));
}

#[tokio::test]
async fn single_method_call_with_ok_context() {
let addr = server_with_context().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"should_ok", "params":[],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response("ok".into(), Id::Num(1)));
}

#[tokio::test]
async fn valid_batched_method_calls() {
let _ = env_logger::try_init();
Expand Down
1 change: 1 addition & 0 deletions test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2018"

[dependencies]
async-std = "1.9"
anyhow = "1"
futures-channel = "0.3"
futures-util = "0.3"
hyper = { version = "0.14", features = ["full"] }
Expand Down
8 changes: 8 additions & 0 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ pub fn invalid_params(id: Id) -> String {
)
}

pub fn invalid_context(msg: &str, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{}"}},"id":{}}}"#,
msg,
serde_json::to_string(&id).unwrap()
)
}

pub fn internal_error(id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error"}},"id":{}}}"#,
Expand Down
11 changes: 11 additions & 0 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ pub use hyper::{Body, HeaderMap, StatusCode, Uri};

type Error = Box<dyn std::error::Error>;

pub struct TestContext;

impl TestContext {
pub fn ok(&self) -> Result<(), anyhow::Error> {
Ok(())
}
pub fn err(&self) -> Result<(), anyhow::Error> {
Err(anyhow::anyhow!("RPC context failed"))
}
}

/// Request Id
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ async fn wss_works() {
#[tokio::test]
async fn ws_with_non_ascii_url_doesnt_hang_or_panic() {
let err = WsClientBuilder::default().build("wss://♥♥♥♥♥♥∀∂").await;
assert!(matches!(err, Err(Error::TransportError(_))));
assert!(matches!(err, Err(Error::Transport(_))));
}

#[tokio::test]
async fn http_with_non_ascii_url_doesnt_hang_or_panic() {
let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap();
let err: Result<(), Error> = client.request("system_chain", JsonRpcParams::NoParams).await;
assert!(matches!(err, Err(Error::TransportError(_))));
assert!(matches!(err, Err(Error::Transport(_))));
}

#[tokio::test]
Expand Down
31 changes: 26 additions & 5 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,36 @@ impl<T: fmt::Display> fmt::Display for Mismatch<T> {
}
}

/// Invalid params.
#[derive(Debug)]
pub struct InvalidParams;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we need this zst, i.e. why the ServerCallError::InvalidParams variant needs an unnamed field. I'll read more and perhaps I'll figure it out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was just that register_method_without_context can only fail when the params couldn't be parsed so it was introduced to avoid matching on unreachable enum variants


/// Error that occurs when a call failed.
#[derive(Debug, thiserror::Error)]
pub enum CallError {
#[error("Invalid params in the RPC call")]
/// Invalid params in the call.
InvalidParams(InvalidParams),
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
#[error("RPC Call failed: {0}")]
/// The call failed.
Failed(#[source] Box<dyn std::error::Error + Send + Sync>),
}

impl From<InvalidParams> for CallError {
fn from(params: InvalidParams) -> Self {
Self::InvalidParams(params)
}
}

/// Error type.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Error that occurs when a call failed.
#[error("Server call failed: {0}")]
Call(CallError),
/// Networking error or error on the low-level protocol layer.
#[error("Networking or low-level protocol error: {0}")]
TransportError(#[source] Box<dyn std::error::Error + Send + Sync>),
Transport(#[source] Box<dyn std::error::Error + Send + Sync>),
/// JSON-RPC request error.
#[error("JSON-RPC request error: {0:?}")]
Request(#[source] JsonRpcErrorAlloc),
Expand All @@ -34,7 +58,7 @@ pub enum Error {
/// The background task has been terminated.
#[error("The background task been terminated because: {0}; restart required")]
RestartNeeded(String),
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
/// Failed to parse the data that the server sent back to us.
/// Failed to parse the data.
#[error("Parse error: {0}")]
ParseError(#[source] serde_json::Error),
/// Invalid subscription ID.
Expand All @@ -43,9 +67,6 @@ pub enum Error {
/// Invalid request ID.
#[error("Invalid request ID")]
InvalidRequestId,
/// Invalid params in the RPC call.
#[error("Invalid params in the RPC call")]
InvalidParams,
/// A request with the same request ID has already been registered.
#[error("A request with the same request ID has already been registered")]
DuplicateRequestId,
Expand Down
4 changes: 2 additions & 2 deletions types/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub trait SubscriptionClient: Client {
}

/// JSON-RPC server interface for managing method calls.
pub trait RpcMethod<R>: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static {}
pub trait RpcMethod<R, E>: Fn(RpcParams) -> Result<R, E> + Send + Sync + 'static {}

impl<R, T> RpcMethod<R> for T where T: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static {}
impl<R, T, E> RpcMethod<R, E> for T where T: Fn(RpcParams) -> Result<R, E> + Send + Sync + 'static {}
2 changes: 2 additions & 0 deletions types/src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub const INVALID_PARAMS_CODE: i32 = -32602;
pub const INVALID_REQUEST_CODE: i32 = -32600;
/// Method not found error code.
pub const METHOD_NOT_FOUND_CODE: i32 = -32601;
/// Custom server error when a call failed.
pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000;

/// Parse error message
pub const PARSE_ERROR_MSG: &str = "Parse error";
Expand Down
6 changes: 3 additions & 3 deletions types/src/v2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Error;
use crate::error::Error;
use serde::de::DeserializeOwned;
use serde_json::value::RawValue;

Expand All @@ -12,11 +12,11 @@ pub mod request;
pub mod response;

/// Parse request ID from RawValue.
pub fn parse_request_id<T: DeserializeOwned>(raw: Option<&RawValue>) -> Result<T, crate::Error> {
pub fn parse_request_id<T: DeserializeOwned>(raw: Option<&RawValue>) -> Result<T, Error> {
match raw {
None => Err(Error::InvalidRequestId),
Some(v) => {
let val = serde_json::from_str(v.get()).map_err(Error::ParseError)?;
let val = serde_json::from_str(v.get()).map_err(|_| Error::InvalidRequestId)?;
Ok(val)
}
}
Expand Down
Loading