diff --git a/http-client/src/client.rs b/http-client/src/client.rs index 7d981e6dac..c1bc3164d1 100644 --- a/http-client/src/client.rs +++ b/http-client/src/client.rs @@ -27,8 +27,8 @@ impl HttpClientBuilder { /// Build the HTTP client with target to connect to. pub fn build(self, target: impl AsRef) -> Result { - let transport = HttpTransportClient::new(target, self.max_request_body_size) - .map_err(|e| Error::TransportError(Box::new(e)))?; + let transport = + HttpTransportClient::new(target, self.max_request_body_size).map_err(|e| Error::Transport(Box::new(e)))?; Ok(HttpClient { transport, request_id: AtomicU64::new(0) }) } } @@ -55,7 +55,7 @@ impl Client for HttpClient { self.transport .send(serde_json::to_string(¬if).map_err(Error::ParseError)?) .await - .map_err(|e| Error::TransportError(Box::new(e))) + .map_err(|e| Error::Transport(Box::new(e))) } /// Perform a request towards the server. @@ -71,7 +71,7 @@ impl Client for HttpClient { .transport .send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?) .await - .map_err(|e| Error::TransportError(Box::new(e)))?; + .map_err(|e| Error::Transport(Box::new(e)))?; let response: JsonRpcResponse<_> = match serde_json::from_slice(&body) { Ok(response) => response, @@ -110,7 +110,7 @@ impl Client for HttpClient { .transport .send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?) .await - .map_err(|e| Error::TransportError(Box::new(e)))?; + .map_err(|e| Error::Transport(Box::new(e)))?; let rps: Vec> = match serde_json::from_slice(&body) { Ok(response) => response, diff --git a/http-server/src/module.rs b/http-server/src/module.rs index f4d42f7efb..65e4017411 100644 --- a/http-server/src/module.rs +++ b/http-server/src/module.rs @@ -1,5 +1,10 @@ -use jsonrpsee_types::{traits::RpcMethod, v2::params::RpcParams, Error}; -use jsonrpsee_utils::server::{send_response, Methods}; +use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE}; +use jsonrpsee_types::{ + error::{CallError, Error, InvalidParams}, + traits::RpcMethod, + v2::params::RpcParams, +}; +use jsonrpsee_utils::server::{send_error, send_response, Methods}; use serde::Serialize; use std::sync::Arc; @@ -31,16 +36,17 @@ impl RpcModule { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: RpcMethod, + F: RpcMethod, { self.verify_method_name(method_name)?; self.methods.insert( method_name, Box::new(move |id, params, tx, _| { - let result = callback(params)?; - - send_response(id, tx, result); + match callback(params) { + Ok(res) => send_response(id, tx, res), + Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + }; Ok(()) }), @@ -82,7 +88,7 @@ impl RpcContextModule { where Context: Send + Sync + 'static, R: Serialize, - F: Fn(RpcParams, &Context) -> Result + Send + Sync + 'static, + F: Fn(RpcParams, &Context) -> Result + Send + Sync + 'static, { self.module.verify_method_name(method_name)?; @@ -91,10 +97,18 @@ impl RpcContextModule { self.module.methods.insert( method_name, Box::new(move |id, params, tx, _| { - let result = callback(params, &*ctx)?; - - send_response(id, tx, result); - + match callback(params, &*ctx) { + Ok(res) => send_response(id, tx, res), + Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::Failed(err)) => { + let err = JsonRpcErrorObject { + code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), + message: &err.to_string(), + data: None, + }; + send_error(id, tx, err) + } + }; Ok(()) }), ); diff --git a/http-server/src/server.rs b/http-server/src/server.rs index 3fc9123574..df4aa7fa6f 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -36,7 +36,7 @@ use hyper::{ service::{make_service_fn, service_fn}, Error as HyperError, }; -use jsonrpsee_types::error::{Error, GenericTransportError}; +use jsonrpsee_types::error::{Error, GenericTransportError, InvalidParams}; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcRequest}; use jsonrpsee_types::v2::{error::JsonRpcErrorCode, params::RpcParams}; use jsonrpsee_utils::{ @@ -129,7 +129,7 @@ impl Server { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: Fn(RpcParams) -> Result + Send + Sync + 'static, + F: Fn(RpcParams) -> Result + Send + Sync + 'static, { self.root.register_method(method_name, callback) } diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs index 318d8192b4..16e9328f65 100644 --- a/http-server/src/tests.rs +++ b/http-server/src/tests.rs @@ -2,9 +2,10 @@ use std::net::SocketAddr; -use crate::HttpServerBuilder; +use crate::{HttpServerBuilder, RpcContextModule}; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, StatusCode}; +use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext}; +use jsonrpsee_types::error::CallError; use serde_json::Value as JsonValue; async fn server() -> SocketAddr { @@ -23,6 +24,35 @@ async fn server() -> SocketAddr { addr } +/// Run server with user provided context. +pub async fn server_with_context() -> SocketAddr { + let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); + + let ctx = TestContext; + let mut rpc_ctx = RpcContextModule::new(ctx); + + rpc_ctx + .register_method("should_err", |_p, ctx| { + let _ = ctx.err().map_err(|e| CallError::Failed(e.into()))?; + Ok("err") + }) + .unwrap(); + + rpc_ctx + .register_method("should_ok", |_p, ctx| { + let _ = ctx.ok().map_err(|e| CallError::Failed(e.into()))?; + Ok("ok") + }) + .unwrap(); + + let rpc_module = rpc_ctx.into_module(); + server.register_module(rpc_module).unwrap(); + let addr = server.local_addr().unwrap(); + + tokio::spawn(async { server.start().await }); + addr +} + #[tokio::test] async fn single_method_call_works() { let _ = env_logger::try_init(); @@ -54,14 +84,45 @@ async fn single_method_call_with_params() { let addr = server().await; let uri = to_http_uri(addr); - std::thread::sleep(std::time::Duration::from_secs(2)); - let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; let response = http_request(req.into(), uri).await.unwrap(); assert_eq!(response.status, StatusCode::OK); assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } +#[tokio::test] +async fn single_method_call_with_faulty_params_returns_err() { + let addr = server().await; + let uri = to_http_uri(addr); + + let req = r#"{"jsonrpc":"2.0","method":"add", "params":["Invalid"],"id":1}"#; + let response = http_request(req.into(), uri).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, invalid_params(Id::Num(1))); +} + +#[tokio::test] +async fn single_method_call_with_faulty_context() { + let addr = server_with_context().await; + let uri = to_http_uri(addr); + + let req = r#"{"jsonrpc":"2.0","method":"should_err", "params":[],"id":1}"#; + let response = http_request(req.into(), uri).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, invalid_context("RPC context failed", Id::Num(1))); +} + +#[tokio::test] +async fn single_method_call_with_ok_context() { + let addr = server_with_context().await; + let uri = to_http_uri(addr); + + let req = r#"{"jsonrpc":"2.0","method":"should_ok", "params":[],"id":1}"#; + let response = http_request(req.into(), uri).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, ok_response("ok".into(), Id::Num(1))); +} + #[tokio::test] async fn valid_batched_method_calls() { let _ = env_logger::try_init(); diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 09cce1d9dd..4bf6acaed1 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" [dependencies] async-std = "1.9" +anyhow = "1" futures-channel = "0.3" futures-util = "0.3" hyper = { version = "0.14", features = ["full"] } diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index a7ee7dc109..c0318c6750 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -57,6 +57,14 @@ pub fn invalid_params(id: Id) -> String { ) } +pub fn invalid_context(msg: &str, id: Id) -> String { + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{}"}},"id":{}}}"#, + msg, + serde_json::to_string(&id).unwrap() + ) +} + pub fn internal_error(id: Id) -> String { format!( r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error"}},"id":{}}}"#, diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index 7bce7eeb68..ec853d58f3 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -18,6 +18,17 @@ pub use hyper::{Body, HeaderMap, StatusCode, Uri}; type Error = Box; +pub struct TestContext; + +impl TestContext { + pub fn ok(&self) -> Result<(), anyhow::Error> { + Ok(()) + } + pub fn err(&self) -> Result<(), anyhow::Error> { + Err(anyhow::anyhow!("RPC context failed")) + } +} + /// Request Id #[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index c7743550c9..e809ef8318 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -203,14 +203,14 @@ async fn wss_works() { #[tokio::test] async fn ws_with_non_ascii_url_doesnt_hang_or_panic() { let err = WsClientBuilder::default().build("wss://♥♥♥♥♥♥∀∂").await; - assert!(matches!(err, Err(Error::TransportError(_)))); + assert!(matches!(err, Err(Error::Transport(_)))); } #[tokio::test] async fn http_with_non_ascii_url_doesnt_hang_or_panic() { let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap(); let err: Result<(), Error> = client.request("system_chain", JsonRpcParams::NoParams).await; - assert!(matches!(err, Err(Error::TransportError(_)))); + assert!(matches!(err, Err(Error::Transport(_)))); } #[tokio::test] diff --git a/types/src/error.rs b/types/src/error.rs index 1483c7e7b1..a76f2f169e 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -16,12 +16,36 @@ impl fmt::Display for Mismatch { } } +/// Invalid params. +#[derive(Debug)] +pub struct InvalidParams; + +/// Error that occurs when a call failed. +#[derive(Debug, thiserror::Error)] +pub enum CallError { + #[error("Invalid params in the RPC call")] + /// Invalid params in the call. + InvalidParams(InvalidParams), + #[error("RPC Call failed: {0}")] + /// The call failed. + Failed(#[source] Box), +} + +impl From for CallError { + fn from(params: InvalidParams) -> Self { + Self::InvalidParams(params) + } +} + /// Error type. #[derive(Debug, thiserror::Error)] pub enum Error { + /// Error that occurs when a call failed. + #[error("Server call failed: {0}")] + Call(CallError), /// Networking error or error on the low-level protocol layer. #[error("Networking or low-level protocol error: {0}")] - TransportError(#[source] Box), + Transport(#[source] Box), /// JSON-RPC request error. #[error("JSON-RPC request error: {0:?}")] Request(#[source] JsonRpcErrorAlloc), @@ -34,7 +58,7 @@ pub enum Error { /// The background task has been terminated. #[error("The background task been terminated because: {0}; restart required")] RestartNeeded(String), - /// Failed to parse the data that the server sent back to us. + /// Failed to parse the data. #[error("Parse error: {0}")] ParseError(#[source] serde_json::Error), /// Invalid subscription ID. @@ -43,9 +67,6 @@ pub enum Error { /// Invalid request ID. #[error("Invalid request ID")] InvalidRequestId, - /// Invalid params in the RPC call. - #[error("Invalid params in the RPC call")] - InvalidParams, /// A request with the same request ID has already been registered. #[error("A request with the same request ID has already been registered")] DuplicateRequestId, diff --git a/types/src/traits.rs b/types/src/traits.rs index b6a657a923..96bafbdc3a 100644 --- a/types/src/traits.rs +++ b/types/src/traits.rs @@ -47,6 +47,6 @@ pub trait SubscriptionClient: Client { } /// JSON-RPC server interface for managing method calls. -pub trait RpcMethod: Fn(RpcParams) -> Result + Send + Sync + 'static {} +pub trait RpcMethod: Fn(RpcParams) -> Result + Send + Sync + 'static {} -impl RpcMethod for T where T: Fn(RpcParams) -> Result + Send + Sync + 'static {} +impl RpcMethod for T where T: Fn(RpcParams) -> Result + Send + Sync + 'static {} diff --git a/types/src/v2/error.rs b/types/src/v2/error.rs index 700093047f..e48875be65 100644 --- a/types/src/v2/error.rs +++ b/types/src/v2/error.rs @@ -80,6 +80,8 @@ pub const INVALID_PARAMS_CODE: i32 = -32602; pub const INVALID_REQUEST_CODE: i32 = -32600; /// Method not found error code. pub const METHOD_NOT_FOUND_CODE: i32 = -32601; +/// Custom server error when a call failed. +pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000; /// Parse error message pub const PARSE_ERROR_MSG: &str = "Parse error"; diff --git a/types/src/v2/mod.rs b/types/src/v2/mod.rs index a3f977ab7e..aba31248db 100644 --- a/types/src/v2/mod.rs +++ b/types/src/v2/mod.rs @@ -1,4 +1,4 @@ -use crate::Error; +use crate::error::Error; use serde::de::DeserializeOwned; use serde_json::value::RawValue; @@ -12,11 +12,11 @@ pub mod request; pub mod response; /// Parse request ID from RawValue. -pub fn parse_request_id(raw: Option<&RawValue>) -> Result { +pub fn parse_request_id(raw: Option<&RawValue>) -> Result { match raw { None => Err(Error::InvalidRequestId), Some(v) => { - let val = serde_json::from_str(v.get()).map_err(Error::ParseError)?; + let val = serde_json::from_str(v.get()).map_err(|_| Error::InvalidRequestId)?; Ok(val) } } diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index 31fe824231..a97544078f 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -1,4 +1,4 @@ -use crate::error::Error; +use crate::error::InvalidParams; use alloc::collections::BTreeMap; use serde::de::{self, Deserializer, Unexpected, Visitor}; use serde::ser::Serializer; @@ -78,18 +78,18 @@ impl<'a> RpcParams<'a> { } /// Attempt to parse all parameters as array or map into type T - pub fn parse(self) -> Result + pub fn parse(self) -> Result where T: Deserialize<'a>, { match self.0 { - None => Err(Error::InvalidParams), - Some(params) => serde_json::from_str(params).map_err(|_| Error::InvalidParams), + None => Err(InvalidParams), + Some(params) => serde_json::from_str(params).map_err(|_| InvalidParams), } } /// Attempt to parse only the first parameter from an array into type T - pub fn one(self) -> Result + pub fn one(self) -> Result where T: Deserialize<'a>, { diff --git a/ws-client/src/client.rs b/ws-client/src/client.rs index c783fb3ed6..08ef718da6 100644 --- a/ws-client/src/client.rs +++ b/ws-client/src/client.rs @@ -253,7 +253,7 @@ impl<'a> WsClientBuilder<'a> { let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); let (err_tx, err_rx) = oneshot::channel(); - let (sockaddrs, host, mode) = parse_url(url).map_err(|e| Error::TransportError(Box::new(e)))?; + let (sockaddrs, host, mode) = parse_url(url).map_err(|e| Error::Transport(Box::new(e)))?; let builder = WsTransportClientBuilder { sockaddrs, @@ -265,7 +265,7 @@ impl<'a> WsClientBuilder<'a> { max_request_body_size: self.max_request_body_size, }; - let (sender, receiver) = builder.build().await.map_err(|e| Error::TransportError(Box::new(e)))?; + let (sender, receiver) = builder.build().await.map_err(|e| Error::Transport(Box::new(e)))?; async_std::task::spawn(async move { background_task(sender, receiver, from_front, err_tx, max_capacity_per_subscription).await; @@ -518,7 +518,7 @@ async fn background_task( .expect("ID unused checked above; qed"), Err(e) => { log::warn!("[backend]: client request failed: {:?}", e); - let _ = request.send_back.map(|s| s.send(Err(Error::TransportError(Box::new(e))))); + let _ = request.send_back.map(|s| s.send(Err(Error::Transport(Box::new(e))))); } } } @@ -535,7 +535,7 @@ async fn background_task( .expect("Request ID unused checked above; qed"), Err(e) => { log::warn!("[backend]: client subscription failed: {:?}", e); - let _ = sub.send_back.send(Err(Error::TransportError(Box::new(e)))); + let _ = sub.send_back.send(Err(Error::Transport(Box::new(e)))); } }, // User dropped a subscription. @@ -600,7 +600,7 @@ async fn background_task( } Either::Right((Some(Err(e)), _)) => { log::error!("Error: {:?} terminating client", e); - let _ = front_error.send(Error::TransportError(Box::new(e))); + let _ = front_error.send(Error::Transport(Box::new(e))); return; } Either::Right((None, _)) => { diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index fee1b28c13..12d2ca058a 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -27,5 +27,6 @@ tokio-stream = { version = "0.1.1", features = ["net"] } tokio-util = { version = "0.6", features = ["compat"] } [dev-dependencies] +env_logger = "0.8" jsonrpsee-test-utils = { path = "../test-utils" } jsonrpsee-ws-client = { path = "../ws-client" } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 5be73959a2..03021c8841 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -37,7 +37,7 @@ use tokio::net::{TcpListener, ToSocketAddrs}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_util::compat::TokioAsyncReadCompatExt; -use jsonrpsee_types::error::Error; +use jsonrpsee_types::error::{Error, InvalidParams}; use jsonrpsee_types::v2::error::JsonRpcErrorCode; use jsonrpsee_types::v2::params::{JsonRpcNotificationParams, RpcParams, TwoPointZero}; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcNotification, JsonRpcRequest}; @@ -105,7 +105,7 @@ impl Server { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: Fn(RpcParams) -> Result + Send + Sync + 'static, + F: Fn(RpcParams) -> Result + Send + Sync + 'static, { self.root.register_method(method_name, callback) } diff --git a/ws-server/src/server/module.rs b/ws-server/src/server/module.rs index 142898165a..3dd6840566 100644 --- a/ws-server/src/server/module.rs +++ b/ws-server/src/server/module.rs @@ -1,7 +1,10 @@ use crate::server::{RpcParams, SubscriptionId, SubscriptionSink}; -use jsonrpsee_types::error::Error; -use jsonrpsee_types::traits::RpcMethod; -use jsonrpsee_utils::server::{send_response, Methods}; +use jsonrpsee_types::{error::InvalidParams, traits::RpcMethod, v2::error::CALL_EXECUTION_FAILED_CODE}; +use jsonrpsee_types::{ + error::{CallError, Error}, + v2::error::{JsonRpcErrorCode, JsonRpcErrorObject}, +}; +use jsonrpsee_utils::server::{send_error, send_response, Methods}; use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::Serialize; @@ -35,16 +38,17 @@ impl RpcModule { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: RpcMethod, + F: RpcMethod, { self.verify_method_name(method_name)?; self.methods.insert( method_name, Box::new(move |id, params, tx, _| { - let result = callback(params)?; - - send_response(id, tx, result); + match callback(params) { + Ok(res) => send_response(id, tx, res), + Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + }; Ok(()) }), @@ -95,7 +99,7 @@ impl RpcModule { self.methods.insert( unsubscribe_method_name, Box::new(move |id, params, tx, conn| { - let sub_id = params.one()?; + let sub_id = params.one().map_err(|e| anyhow::anyhow!("{:?}", e))?; subscribers.lock().remove(&(conn, sub_id)); @@ -142,7 +146,7 @@ impl RpcContextModule { where Context: Send + Sync + 'static, R: Serialize, - F: Fn(RpcParams, &Context) -> Result + Send + Sync + 'static, + F: Fn(RpcParams, &Context) -> Result + Send + Sync + 'static, { self.module.verify_method_name(method_name)?; @@ -151,14 +155,22 @@ impl RpcContextModule { self.module.methods.insert( method_name, Box::new(move |id, params, tx, _| { - let result = callback(params, &*ctx)?; - - send_response(id, tx, result); + match callback(params, &*ctx) { + Ok(res) => send_response(id, tx, res), + Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::Failed(err)) => { + let err = JsonRpcErrorObject { + code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), + message: &err.to_string(), + data: None, + }; + send_error(id, tx, err) + } + }; Ok(()) }), ); - Ok(()) } diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index 9c7ffe8107..9fb3b2c67a 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -1,16 +1,15 @@ #![cfg(test)] -use crate::WsServer; -use futures_channel::oneshot::{self, Sender}; +use crate::{RpcContextModule, WsServer}; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, WebSocketTestClient}; -use jsonrpsee_types::error::Error; +use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient}; +use jsonrpsee_types::error::{CallError, Error}; use serde_json::Value as JsonValue; use std::net::SocketAddr; /// Spawns a dummy `JSONRPC v2 WebSocket` /// It has two hardcoded methods: "say_hello" and "add" -pub async fn server(server_started: Sender) { +pub async fn server() -> SocketAddr { let mut server = WsServer::new("127.0.0.1:0").await.unwrap(); server @@ -26,17 +25,45 @@ pub async fn server(server_started: Sender) { Ok(sum) }) .unwrap(); - server_started.send(server.local_addr().unwrap()).unwrap(); + let addr = server.local_addr().unwrap(); - server.start().await; + tokio::spawn(async { server.start().await }); + addr +} + +/// Run server with user provided context. +pub async fn server_with_context() -> SocketAddr { + let mut server = WsServer::new("127.0.0.1:0").await.unwrap(); + + let ctx = TestContext; + let mut rpc_ctx = RpcContextModule::new(ctx); + + rpc_ctx + .register_method("should_err", |_p, ctx| { + let _ = ctx.err().map_err(|e| CallError::Failed(e.into()))?; + Ok("err") + }) + .unwrap(); + + rpc_ctx + .register_method("should_ok", |_p, ctx| { + let _ = ctx.ok().map_err(|e| CallError::Failed(e.into()))?; + Ok("ok") + }) + .unwrap(); + + let rpc_module = rpc_ctx.into_module(); + server.register_module(rpc_module).unwrap(); + let addr = server.local_addr().unwrap(); + + tokio::spawn(async { server.start().await }); + addr } #[tokio::test] async fn single_method_call_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); for i in 0..10 { let req = format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i); @@ -48,22 +75,49 @@ async fn single_method_call_works() { #[tokio::test] async fn single_method_call_with_params_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; let response = client.send_request_text(req).await.unwrap(); assert_eq!(response, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } +#[tokio::test] +async fn single_method_call_with_faulty_params_returns_err() { + let _ = env_logger::try_init(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"add", "params":["Invalid"],"id":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, invalid_params(Id::Num(1))); +} + +#[tokio::test] +async fn single_method_call_with_faulty_context() { + let addr = server_with_context().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"should_err", "params":[],"id":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, invalid_context("RPC context failed", Id::Num(1))); +} + +#[tokio::test] +async fn single_method_call_with_ok_context() { + let addr = server_with_context().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"should_ok", "params":[],"id":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, ok_response("ok".into(), Id::Num(1))); +} + #[tokio::test] async fn single_method_send_binary() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; let response = client.send_request_binary(req.as_bytes()).await.unwrap(); @@ -72,10 +126,8 @@ async fn single_method_send_binary() { #[tokio::test] async fn should_return_method_not_found() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"bar","id":"foo"}"#; let response = client.send_request_text(req).await.unwrap(); @@ -84,11 +136,9 @@ async fn should_return_method_not_found() { #[tokio::test] async fn invalid_json_id_missing_value() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"say_hello","id"}"#; let response = client.send_request_text(req).await.unwrap(); // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), it MUST be Null. @@ -97,11 +147,9 @@ async fn invalid_json_id_missing_value() { #[tokio::test] async fn invalid_request_object() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; let response = client.send_request_text(req).await.unwrap(); assert_eq!(response, invalid_request(Id::Num(1))); @@ -131,11 +179,9 @@ async fn register_same_subscribe_unsubscribe_is_err() { #[tokio::test] async fn parse_error_request_should_not_close_connection() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); let invalid_request = r#"{"jsonrpc":"2.0","method":"bar","params":[1,"id":99}"#; let response1 = client.send_request_text(invalid_request).await.unwrap(); assert_eq!(response1, parse_error(Id::Null)); @@ -146,11 +192,9 @@ async fn parse_error_request_should_not_close_connection() { #[tokio::test] async fn invalid_request_should_not_close_connection() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; let response = client.send_request_text(req).await.unwrap(); assert_eq!(response, invalid_request(Id::Num(1)));