Skip to content

Commit

Permalink
[servers] return error if context or params fails (#295)
Browse files Browse the repository at this point in the history
* ret err if context/params fails

* address grumbles: specific error_code context fail

* address grumbles: make env_logger dev-dependency

* address grumbles: add tests

* chore(deps): remove unused deps

* address grumbles: rename types and docs

* address grumbles: more renaming.

* fix build
  • Loading branch information
niklasad1 authored May 4, 2021
1 parent 2cae10b commit b51abec
Show file tree
Hide file tree
Showing 18 changed files with 273 additions and 98 deletions.
10 changes: 5 additions & 5 deletions http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ impl HttpClientBuilder {

/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport = HttpTransportClient::new(target, self.max_request_body_size)
.map_err(|e| Error::TransportError(Box::new(e)))?;
let transport =
HttpTransportClient::new(target, self.max_request_body_size).map_err(|e| Error::Transport(Box::new(e)))?;
Ok(HttpClient { transport, request_id: AtomicU64::new(0) })
}
}
Expand All @@ -55,7 +55,7 @@ impl Client for HttpClient {
self.transport
.send(serde_json::to_string(&notif).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))
.map_err(|e| Error::Transport(Box::new(e)))
}

/// Perform a request towards the server.
Expand All @@ -71,7 +71,7 @@ impl Client for HttpClient {
.transport
.send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))?;
.map_err(|e| Error::Transport(Box::new(e)))?;

let response: JsonRpcResponse<_> = match serde_json::from_slice(&body) {
Ok(response) => response,
Expand Down Expand Up @@ -110,7 +110,7 @@ impl Client for HttpClient {
.transport
.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?)
.await
.map_err(|e| Error::TransportError(Box::new(e)))?;
.map_err(|e| Error::Transport(Box::new(e)))?;

let rps: Vec<JsonRpcResponse<_>> = match serde_json::from_slice(&body) {
Ok(response) => response,
Expand Down
36 changes: 25 additions & 11 deletions http-server/src/module.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use jsonrpsee_types::{traits::RpcMethod, v2::params::RpcParams, Error};
use jsonrpsee_utils::server::{send_response, Methods};
use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
error::{CallError, Error, InvalidParams},
traits::RpcMethod,
v2::params::RpcParams,
};
use jsonrpsee_utils::server::{send_error, send_response, Methods};
use serde::Serialize;
use std::sync::Arc;

Expand Down Expand Up @@ -31,16 +36,17 @@ impl RpcModule {
pub fn register_method<F, R>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
R: Serialize,
F: RpcMethod<R>,
F: RpcMethod<R, InvalidParams>,
{
self.verify_method_name(method_name)?;

self.methods.insert(
method_name,
Box::new(move |id, params, tx, _| {
let result = callback(params)?;

send_response(id, tx, result);
match callback(params) {
Ok(res) => send_response(id, tx, res),
Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
};

Ok(())
}),
Expand Down Expand Up @@ -82,7 +88,7 @@ impl<Context> RpcContextModule<Context> {
where
Context: Send + Sync + 'static,
R: Serialize,
F: Fn(RpcParams, &Context) -> Result<R, Error> + Send + Sync + 'static,
F: Fn(RpcParams, &Context) -> Result<R, CallError> + Send + Sync + 'static,
{
self.module.verify_method_name(method_name)?;

Expand All @@ -91,10 +97,18 @@ impl<Context> RpcContextModule<Context> {
self.module.methods.insert(
method_name,
Box::new(move |id, params, tx, _| {
let result = callback(params, &*ctx)?;

send_response(id, tx, result);

match callback(params, &*ctx) {
Ok(res) => send_response(id, tx, res),
Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
Err(CallError::Failed(err)) => {
let err = JsonRpcErrorObject {
code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE),
message: &err.to_string(),
data: None,
};
send_error(id, tx, err)
}
};
Ok(())
}),
);
Expand Down
4 changes: 2 additions & 2 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use hyper::{
service::{make_service_fn, service_fn},
Error as HyperError,
};
use jsonrpsee_types::error::{Error, GenericTransportError};
use jsonrpsee_types::error::{Error, GenericTransportError, InvalidParams};
use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcRequest};
use jsonrpsee_types::v2::{error::JsonRpcErrorCode, params::RpcParams};
use jsonrpsee_utils::{
Expand Down Expand Up @@ -129,7 +129,7 @@ impl Server {
pub fn register_method<F, R>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
R: Serialize,
F: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static,
F: Fn(RpcParams) -> Result<R, InvalidParams> + Send + Sync + 'static,
{
self.root.register_method(method_name, callback)
}
Expand Down
69 changes: 65 additions & 4 deletions http-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

use std::net::SocketAddr;

use crate::HttpServerBuilder;
use crate::{HttpServerBuilder, RpcContextModule};
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, StatusCode};
use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext};
use jsonrpsee_types::error::CallError;
use serde_json::Value as JsonValue;

async fn server() -> SocketAddr {
Expand All @@ -23,6 +24,35 @@ async fn server() -> SocketAddr {
addr
}

/// Run server with user provided context.
pub async fn server_with_context() -> SocketAddr {
let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap();

let ctx = TestContext;
let mut rpc_ctx = RpcContextModule::new(ctx);

rpc_ctx
.register_method("should_err", |_p, ctx| {
let _ = ctx.err().map_err(|e| CallError::Failed(e.into()))?;
Ok("err")
})
.unwrap();

rpc_ctx
.register_method("should_ok", |_p, ctx| {
let _ = ctx.ok().map_err(|e| CallError::Failed(e.into()))?;
Ok("ok")
})
.unwrap();

let rpc_module = rpc_ctx.into_module();
server.register_module(rpc_module).unwrap();
let addr = server.local_addr().unwrap();

tokio::spawn(async { server.start().await });
addr
}

#[tokio::test]
async fn single_method_call_works() {
let _ = env_logger::try_init();
Expand Down Expand Up @@ -54,14 +84,45 @@ async fn single_method_call_with_params() {
let addr = server().await;
let uri = to_http_uri(addr);

std::thread::sleep(std::time::Duration::from_secs(2));

let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1)));
}

#[tokio::test]
async fn single_method_call_with_faulty_params_returns_err() {
let addr = server().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"add", "params":["Invalid"],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, invalid_params(Id::Num(1)));
}

#[tokio::test]
async fn single_method_call_with_faulty_context() {
let addr = server_with_context().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"should_err", "params":[],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, invalid_context("RPC context failed", Id::Num(1)));
}

#[tokio::test]
async fn single_method_call_with_ok_context() {
let addr = server_with_context().await;
let uri = to_http_uri(addr);

let req = r#"{"jsonrpc":"2.0","method":"should_ok", "params":[],"id":1}"#;
let response = http_request(req.into(), uri).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response("ok".into(), Id::Num(1)));
}

#[tokio::test]
async fn valid_batched_method_calls() {
let _ = env_logger::try_init();
Expand Down
1 change: 1 addition & 0 deletions test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2018"

[dependencies]
async-std = "1.9"
anyhow = "1"
futures-channel = "0.3"
futures-util = "0.3"
hyper = { version = "0.14", features = ["full"] }
Expand Down
8 changes: 8 additions & 0 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ pub fn invalid_params(id: Id) -> String {
)
}

pub fn invalid_context(msg: &str, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{}"}},"id":{}}}"#,
msg,
serde_json::to_string(&id).unwrap()
)
}

pub fn internal_error(id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error"}},"id":{}}}"#,
Expand Down
11 changes: 11 additions & 0 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ pub use hyper::{Body, HeaderMap, StatusCode, Uri};

type Error = Box<dyn std::error::Error>;

pub struct TestContext;

impl TestContext {
pub fn ok(&self) -> Result<(), anyhow::Error> {
Ok(())
}
pub fn err(&self) -> Result<(), anyhow::Error> {
Err(anyhow::anyhow!("RPC context failed"))
}
}

/// Request Id
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ async fn wss_works() {
#[tokio::test]
async fn ws_with_non_ascii_url_doesnt_hang_or_panic() {
let err = WsClientBuilder::default().build("wss://♥♥♥♥♥♥∀∂").await;
assert!(matches!(err, Err(Error::TransportError(_))));
assert!(matches!(err, Err(Error::Transport(_))));
}

#[tokio::test]
async fn http_with_non_ascii_url_doesnt_hang_or_panic() {
let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap();
let err: Result<(), Error> = client.request("system_chain", JsonRpcParams::NoParams).await;
assert!(matches!(err, Err(Error::TransportError(_))));
assert!(matches!(err, Err(Error::Transport(_))));
}

#[tokio::test]
Expand Down
31 changes: 26 additions & 5 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,36 @@ impl<T: fmt::Display> fmt::Display for Mismatch<T> {
}
}

/// Invalid params.
#[derive(Debug)]
pub struct InvalidParams;

/// Error that occurs when a call failed.
#[derive(Debug, thiserror::Error)]
pub enum CallError {
#[error("Invalid params in the RPC call")]
/// Invalid params in the call.
InvalidParams(InvalidParams),
#[error("RPC Call failed: {0}")]
/// The call failed.
Failed(#[source] Box<dyn std::error::Error + Send + Sync>),
}

impl From<InvalidParams> for CallError {
fn from(params: InvalidParams) -> Self {
Self::InvalidParams(params)
}
}

/// Error type.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Error that occurs when a call failed.
#[error("Server call failed: {0}")]
Call(CallError),
/// Networking error or error on the low-level protocol layer.
#[error("Networking or low-level protocol error: {0}")]
TransportError(#[source] Box<dyn std::error::Error + Send + Sync>),
Transport(#[source] Box<dyn std::error::Error + Send + Sync>),
/// JSON-RPC request error.
#[error("JSON-RPC request error: {0:?}")]
Request(#[source] JsonRpcErrorAlloc),
Expand All @@ -34,7 +58,7 @@ pub enum Error {
/// The background task has been terminated.
#[error("The background task been terminated because: {0}; restart required")]
RestartNeeded(String),
/// Failed to parse the data that the server sent back to us.
/// Failed to parse the data.
#[error("Parse error: {0}")]
ParseError(#[source] serde_json::Error),
/// Invalid subscription ID.
Expand All @@ -43,9 +67,6 @@ pub enum Error {
/// Invalid request ID.
#[error("Invalid request ID")]
InvalidRequestId,
/// Invalid params in the RPC call.
#[error("Invalid params in the RPC call")]
InvalidParams,
/// A request with the same request ID has already been registered.
#[error("A request with the same request ID has already been registered")]
DuplicateRequestId,
Expand Down
4 changes: 2 additions & 2 deletions types/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub trait SubscriptionClient: Client {
}

/// JSON-RPC server interface for managing method calls.
pub trait RpcMethod<R>: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static {}
pub trait RpcMethod<R, E>: Fn(RpcParams) -> Result<R, E> + Send + Sync + 'static {}

impl<R, T> RpcMethod<R> for T where T: Fn(RpcParams) -> Result<R, Error> + Send + Sync + 'static {}
impl<R, T, E> RpcMethod<R, E> for T where T: Fn(RpcParams) -> Result<R, E> + Send + Sync + 'static {}
2 changes: 2 additions & 0 deletions types/src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub const INVALID_PARAMS_CODE: i32 = -32602;
pub const INVALID_REQUEST_CODE: i32 = -32600;
/// Method not found error code.
pub const METHOD_NOT_FOUND_CODE: i32 = -32601;
/// Custom server error when a call failed.
pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000;

/// Parse error message
pub const PARSE_ERROR_MSG: &str = "Parse error";
Expand Down
6 changes: 3 additions & 3 deletions types/src/v2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Error;
use crate::error::Error;
use serde::de::DeserializeOwned;
use serde_json::value::RawValue;

Expand All @@ -12,11 +12,11 @@ pub mod request;
pub mod response;

/// Parse request ID from RawValue.
pub fn parse_request_id<T: DeserializeOwned>(raw: Option<&RawValue>) -> Result<T, crate::Error> {
pub fn parse_request_id<T: DeserializeOwned>(raw: Option<&RawValue>) -> Result<T, Error> {
match raw {
None => Err(Error::InvalidRequestId),
Some(v) => {
let val = serde_json::from_str(v.get()).map_err(Error::ParseError)?;
let val = serde_json::from_str(v.get()).map_err(|_| Error::InvalidRequestId)?;
Ok(val)
}
}
Expand Down
Loading

0 comments on commit b51abec

Please sign in to comment.