From 56bd2ff6b987d1e515166c2a866cdb4ddbaf1c28 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Sun, 11 Jul 2021 11:23:39 +0100 Subject: [PATCH 1/7] Do not register methods on servers --- benches/helpers.rs | 10 ++-- examples/http.rs | 5 +- examples/proc_macro.rs | 5 +- examples/weather.rs | 6 +-- examples/ws.rs | 5 +- examples/ws_sub_with_params.rs | 5 +- examples/ws_subscription.rs | 5 +- http-server/src/server.rs | 29 ++++++----- http-server/src/tests.rs | 15 +++--- proc-macros/src/lib.rs | 6 +-- proc-macros/tests/ui/correct/basic.rs | 5 +- proc-macros/tests/ui/correct/only_server.rs | 3 +- tests/tests/helpers.rs | 15 +++--- tests/tests/new_proc_macros.rs | 5 +- utils/src/server/rpc_module.rs | 56 ++++++++++++--------- ws-server/src/server.rs | 20 ++------ ws-server/src/tests.rs | 31 +++++------- 17 files changed, 97 insertions(+), 129 deletions(-) diff --git a/benches/helpers.rs b/benches/helpers.rs index 64e7da54d9..e8dc250b3a 100644 --- a/benches/helpers.rs +++ b/benches/helpers.rs @@ -14,14 +14,13 @@ pub(crate) const UNSUB_METHOD_NAME: &str = "unsub"; pub async fn http_server() -> String { let (server_started_tx, server_started_rx) = oneshot::channel(); tokio::spawn(async move { - let mut server = + let server = HttpServerBuilder::default().max_request_body_size(u32::MAX).build("127.0.0.1:0".parse().unwrap()).unwrap(); let mut module = RpcModule::new(()); module.register_method(SYNC_METHOD_NAME, |_, _| Ok("lo")).unwrap(); module.register_async_method(ASYNC_METHOD_NAME, |_, _| (async { Ok("lo") }).boxed()).unwrap(); - server.register_module(module).unwrap(); server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(module).await }); format!("http://{}", server_started_rx.await.unwrap()) } @@ -30,7 +29,7 @@ pub async fn http_server() -> String { pub async fn ws_server() -> String { let (server_started_tx, server_started_rx) = oneshot::channel(); tokio::spawn(async move { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); module.register_method(SYNC_METHOD_NAME, |_, _| Ok("lo")).unwrap(); module.register_async_method(ASYNC_METHOD_NAME, |_, _| (async { Ok("lo") }).boxed()).unwrap(); @@ -42,9 +41,8 @@ pub async fn ws_server() -> String { }) .unwrap(); - server.register_module(module).unwrap(); server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(module).await }); format!("ws://{}", server_started_rx.await.unwrap()) } diff --git a/examples/http.rs b/examples/http.rs index 761c279f39..5b7f822dd9 100644 --- a/examples/http.rs +++ b/examples/http.rs @@ -47,12 +47,11 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse()?)?; + let server = HttpServerBuilder::default().build("127.0.0.1:0".parse()?)?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| Ok("lo"))?; - server.register_module(module).unwrap(); let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/examples/proc_macro.rs b/examples/proc_macro.rs index a384b6b2c2..7ffebfaac3 100644 --- a/examples/proc_macro.rs +++ b/examples/proc_macro.rs @@ -52,12 +52,11 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse()?)?; + let server = HttpServerBuilder::default().build("127.0.0.1:0".parse()?)?; let mut module = RpcModule::new(()); module.register_method("state_getPairs", |_, _| Ok(vec![1, 2, 3]))?; - server.register_module(module).unwrap(); let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/examples/weather.rs b/examples/weather.rs index c7bab569e3..779da59dae 100644 --- a/examples/weather.rs +++ b/examples/weather.rs @@ -99,7 +99,7 @@ struct WeatherApiCx { } async fn run_server() -> anyhow::Result { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; + let server = WsServerBuilder::default().build("127.0.0.1:0").await?; let api_client = restson::RestClient::new("http://api.openweathermap.org").unwrap(); let last_weather = Weather::default(); @@ -125,9 +125,7 @@ async fn run_server() -> anyhow::Result { }) .unwrap(); - server.register_module(module).unwrap(); - let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/examples/ws.rs b/examples/ws.rs index 33be48c572..baf5ef32db 100644 --- a/examples/ws.rs +++ b/examples/ws.rs @@ -44,11 +44,10 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; + let server = WsServerBuilder::default().build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| Ok("lo"))?; - server.register_module(module).unwrap(); let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/examples/ws_sub_with_params.rs b/examples/ws_sub_with_params.rs index 3a56f9b363..0c41a03b2b 100644 --- a/examples/ws_sub_with_params.rs +++ b/examples/ws_sub_with_params.rs @@ -53,7 +53,7 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { const LETTERS: &str = "abcdefghijklmnopqrstuvxyz"; - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; + let server = WsServerBuilder::default().build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module .register_subscription("sub_one_param", "unsub_one_param", |params, mut sink, _| { @@ -76,8 +76,7 @@ async fn run_server() -> anyhow::Result { }) .unwrap(); - server.register_module(module).unwrap(); let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/examples/ws_subscription.rs b/examples/ws_subscription.rs index 2942d06858..906375ba07 100644 --- a/examples/ws_subscription.rs +++ b/examples/ws_subscription.rs @@ -54,7 +54,7 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; + let server = WsServerBuilder::default().build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module.register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { std::thread::spawn(move || loop { @@ -65,8 +65,7 @@ async fn run_server() -> anyhow::Result { }); Ok(()) })?; - server.register_module(module).unwrap(); let addr = server.local_addr()?; - tokio::spawn(async move { server.start().await }); + tokio::spawn(server.start(module)); Ok(addr) } diff --git a/http-server/src/server.rs b/http-server/src/server.rs index a7ad679bcb..1f94dcccb1 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -37,7 +37,6 @@ use jsonrpsee_types::v2::error::JsonRpcErrorCode; use jsonrpsee_types::v2::params::Id; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcNotification, JsonRpcRequest}; use jsonrpsee_utils::hyper_helpers::read_response_to_body; -use jsonrpsee_utils::server::rpc_module::RpcModule; use jsonrpsee_utils::server::{ helpers::{collect_batch_response, send_error}, rpc_module::Methods, @@ -155,18 +154,18 @@ pub struct Server { } impl Server { - /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. - /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] - /// is returned. Note that the [`RpcModule`] is consumed after this call. - pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { - self.methods.merge(module.into_methods())?; - Ok(()) - } - - /// Returns a `Vec` with all the method names registered on this server. - pub fn method_names(&self) -> Vec<&'static str> { - self.methods.method_names() - } + // /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. + // /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] + // /// is returned. Note that the [`RpcModule`] is consumed after this call. + // pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { + // self.methods.merge(module.into_methods())?; + // Ok(()) + // } + + // /// Returns a `Vec` with all the method names registered on this server. + // pub fn method_names(&self) -> Vec<&'static str> { + // self.methods.method_names() + // } /// Returns socket address to which the server is bound. pub fn local_addr(&self) -> Result { @@ -179,15 +178,15 @@ impl Server { } /// Start the server. - pub async fn start(self) -> Result<(), Error> { + pub async fn start(self, methods: impl Into) -> Result<(), Error> { // Lock the stop mutex so existing stop handles can wait for server to stop. // It will be unlocked once this function returns. let _stop_handle = self.stop_handle.lock().await; - let methods = Arc::new(self.methods); let max_request_body_size = self.max_request_body_size; let access_control = self.access_control; let mut stop_receiver = self.stop_pair.1; + let methods = methods.into(); let make_service = make_service_fn(move |_| { let methods = methods.clone(); diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs index 14532e04b1..0afd04d1d2 100644 --- a/http-server/src/tests.rs +++ b/http-server/src/tests.rs @@ -16,7 +16,7 @@ async fn server() -> SocketAddr { } async fn server_with_handles() -> (SocketAddr, JoinHandle>, StopHandle) { - let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); + let server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); let ctx = TestContext; let mut module = RpcModule::new(ctx); let addr = server.local_addr().unwrap(); @@ -60,9 +60,8 @@ async fn server_with_handles() -> (SocketAddr, JoinHandle>, St }) .unwrap(); - server.register_module(module).unwrap(); let stop_handle = server.stop_handle(); - let join_handle = tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() }); + let join_handle = tokio::spawn(async move { server.start(module).with_default_timeout().await.unwrap() }); (addr, join_handle, stop_handle) } @@ -296,8 +295,7 @@ async fn can_register_modules() { let cx2 = Vec::::new(); let mut mod2 = RpcModule::new(cx2); - let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); - assert_eq!(server.method_names().len(), 0); + assert_eq!(mod1.method_names().count(), 0); mod1.register_method("bla", |_, cx| Ok(format!("Gave me {}", cx))).unwrap(); mod1.register_method("bla2", |_, cx| Ok(format!("Gave me {}", cx))).unwrap(); mod2.register_method("yada", |_, cx| Ok(format!("Gave me {:?}", cx))).unwrap(); @@ -305,14 +303,13 @@ async fn can_register_modules() { // Won't register, name clashes mod2.register_method("bla", |_, cx| Ok(format!("Gave me {:?}", cx))).unwrap(); - server.register_module(mod1).unwrap(); - assert_eq!(server.method_names().len(), 2); + assert_eq!(mod1.method_names().count(), 2); - let err = server.register_module(mod2).unwrap_err(); + let err = mod1.merge(mod2).unwrap_err(); let expected_err = Error::MethodAlreadyRegistered(String::from("bla")); assert_eq!(err.to_string(), expected_err.to_string()); - assert_eq!(server.method_names().len(), 2); + assert_eq!(mod1.method_names().count(), 2); } #[tokio::test] diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 80af4b1c04..7453202d6c 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -325,14 +325,14 @@ pub fn rpc_client_api(input_token_stream: TokenStream) -> TokenStream { /// /// std::thread::spawn(move || { /// let rt = tokio::runtime::Runtime::new().unwrap(); -/// let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); +/// let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); /// // `into_rpc()` method was generated inside of the `RpcServer` trait under the hood. -/// server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); +/// let module = RpcServerImpl.into_rpc().unwrap(); /// /// rt.block_on(async move { /// server_started_tx.send(server.local_addr().unwrap()).unwrap(); /// -/// server.start().await +/// server.start(module).await /// }); /// }); /// diff --git a/proc-macros/tests/ui/correct/basic.rs b/proc-macros/tests/ui/correct/basic.rs index e7023cfad2..d8e97a52c0 100644 --- a/proc-macros/tests/ui/correct/basic.rs +++ b/proc-macros/tests/ui/correct/basic.rs @@ -51,13 +51,12 @@ pub async fn websocket_server() -> SocketAddr { std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); - server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); + let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(RpcServerImpl.into_rpc().unwrap()).await }); }); diff --git a/proc-macros/tests/ui/correct/only_server.rs b/proc-macros/tests/ui/correct/only_server.rs index 0751c9d7f9..cc79c1c471 100644 --- a/proc-macros/tests/ui/correct/only_server.rs +++ b/proc-macros/tests/ui/correct/only_server.rs @@ -41,12 +41,11 @@ pub async fn websocket_server() -> SocketAddr { std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); - server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(RpcServerImpl.into_rpc().unwrap()).await }); }); diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 2b6991355e..c92243d4d3 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -39,7 +39,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); @@ -86,10 +86,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) }) .unwrap(); - server.register_module(module).unwrap(); rt.block_on(async move { server_started_tx.send((server.local_addr().unwrap(), server.stop_handle())).unwrap(); - server.start().await + server.start(module).await }); }); @@ -101,15 +100,14 @@ pub async fn websocket_server() -> SocketAddr { std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); - server.register_module(module).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(module).await }); }); @@ -117,13 +115,12 @@ pub async fn websocket_server() -> SocketAddr { } pub async fn http_server() -> SocketAddr { - let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); + let server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); module.register_method("notif", |_, _| Ok("")).unwrap(); - server.register_module(module).unwrap(); - tokio::spawn(async move { server.start().await.unwrap() }); + tokio::spawn(server.start(module)); addr } diff --git a/tests/tests/new_proc_macros.rs b/tests/tests/new_proc_macros.rs index fdf729a72d..03c7971a39 100644 --- a/tests/tests/new_proc_macros.rs +++ b/tests/tests/new_proc_macros.rs @@ -55,13 +55,12 @@ pub async fn websocket_server() -> SocketAddr { std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); - server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); + let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start().await + server.start(RpcServerImpl.into_rpc().unwrap()).await }); }); diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index ee09d1b4ff..979b024ba3 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -14,6 +14,7 @@ use rustc_hash::FxHashMap; use serde::Serialize; use std::fmt::Debug; use std::sync::Arc; +use std::ops::{Deref, DerefMut}; /// A `Method` is an RPC endpoint, callable with a standard JSON-RPC request, /// implemented as a function pointer to a `Fn` function taking four arguments: @@ -110,7 +111,9 @@ impl Methods { /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`. /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, mut other: Methods) -> Result<(), Error> { + pub fn merge(&mut self, other: impl Into) -> Result<(), Error> { + let mut other = other.into(); + for name in other.callbacks.keys() { self.verify_method_name(name)?; } @@ -137,9 +140,23 @@ impl Methods { } } - /// Returns a `Vec` with all the method names registered on this server. - pub fn method_names(&self) -> Vec<&'static str> { - self.callbacks.keys().copied().collect() + /// Returns an `Iterator` with all the method names registered on this server. + pub fn method_names(&self) -> impl Iterator + '_ { + self.callbacks.keys().copied() + } +} + +impl Deref for RpcModule { + type Target = Methods; + + fn deref(&self) -> &Methods { + &self.methods + } +} + +impl DerefMut for RpcModule { + fn deref_mut(&mut self) -> &mut Methods { + &mut self.methods } } @@ -157,18 +174,11 @@ impl RpcModule { pub fn new(ctx: Context) -> Self { Self { ctx: Arc::new(ctx), methods: Default::default() } } +} - /// Convert a module into methods. Consumes self. - pub fn into_methods(self) -> Methods { - self.methods - } - - /// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`. - /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, other: RpcModule) -> Result<(), Error> { - self.methods.merge(other.methods)?; - - Ok(()) +impl From> for Methods { + fn from(module: RpcModule) -> Methods { + module.methods } } @@ -445,9 +455,8 @@ mod tests { mod1.merge(mod2).unwrap(); - let methods = mod1.into_methods(); - assert!(methods.method("bla with Vec context").is_some()); - assert!(methods.method("bla with String context").is_some()); + assert!(mod1.method("bla with Vec context").is_some()); + assert!(mod1.method("bla with String context").is_some()); } #[test] @@ -456,9 +465,8 @@ mod tests { let mut cxmodule = RpcModule::new(cx); let _subscription = cxmodule.register_subscription("hi", "goodbye", |_, _, _| Ok(())); - let methods = cxmodule.into_methods(); - assert!(methods.method("hi").is_some()); - assert!(methods.method("goodbye").is_some()); + assert!(cxmodule.method("hi").is_some()); + assert!(cxmodule.method("goodbye").is_some()); } #[test] @@ -468,9 +476,7 @@ mod tests { module.register_method("hello_world", |_: RpcParams, _| Ok(())).unwrap(); module.register_alias("hello_foobar", "hello_world").unwrap(); - let methods = module.into_methods(); - - assert!(methods.method("hello_world").is_some()); - assert!(methods.method("hello_foobar").is_some()); + assert!(module.method("hello_world").is_some()); + assert!(module.method("hello_foobar").is_some()); } } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 020b8826c6..471dc96d2c 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -49,7 +49,7 @@ use jsonrpsee_types::v2::error::JsonRpcErrorCode; use jsonrpsee_types::v2::params::Id; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcRequest}; use jsonrpsee_utils::server::helpers::{collect_batch_response, send_error}; -use jsonrpsee_utils::server::rpc_module::{ConnectionId, Methods, RpcModule}; +use jsonrpsee_utils::server::rpc_module::{ConnectionId, Methods}; /// Default maximum connections allowed. const MAX_CONNECTIONS: u64 = 100; @@ -67,19 +67,6 @@ pub struct Server { } impl Server { - /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. - /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] - /// is returned. Note that the [`RpcModule`] is consumed after this call. - pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { - self.methods.merge(module.into_methods())?; - Ok(()) - } - - /// Returns a `Vec` with all the method names registered on this server. - pub fn method_names(&self) -> Vec<&'static str> { - self.methods.method_names() - } - /// Returns socket address to which the server is bound. pub fn local_addr(&self) -> Result { self.listener.local_addr().map_err(Into::into) @@ -91,15 +78,14 @@ impl Server { } /// Start responding to connections requests. This will block current thread until the server is stopped. - pub async fn start(self) { + pub async fn start(self, methods: impl Into) { // Acquire read access to the lock such that additional reader(s) may share this lock. // Write access to this lock will only be possible after the server and all background tasks have stopped. let _stop_handle = self.stop_handle.read().await; let shutdown = self.stop_pair.0; + let methods = methods.into(); - let methods = self.methods; let mut id = 0; - let mut driver = ConnDriver::new(self.listener, self.stop_pair.1); loop { diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index bf3b286bb4..de0c4de515 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -34,7 +34,7 @@ async fn server() -> SocketAddr { /// It has two hardcoded methods: "say_hello" and "add" /// Returns the address together with handles for server future and server stop. async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); + let server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); let mut module = RpcModule::new(()); module .register_method("say_hello", |_, _| { @@ -71,16 +71,15 @@ async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) { .unwrap(); let addr = server.local_addr().unwrap(); - server.register_module(module).unwrap(); let stop_handle = server.stop_handle(); - let join_handle = tokio::spawn(server.start()); + let join_handle = tokio::spawn(server.start(module)); (addr, join_handle, stop_handle) } /// Run server with user provided context. async fn server_with_context() -> SocketAddr { - let mut server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); + let server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); let ctx = TestContext; let mut rpc_module = RpcModule::new(ctx); @@ -121,10 +120,9 @@ async fn server_with_context() -> SocketAddr { }) .unwrap(); - server.register_module(rpc_module).unwrap(); let addr = server.local_addr().unwrap(); - tokio::spawn(server.start()); + tokio::spawn(server.start(rpc_module)); addr } @@ -132,12 +130,11 @@ async fn server_with_context() -> SocketAddr { async fn can_set_the_max_request_body_size() { let addr = "127.0.0.1:0"; // Rejects all requests larger than 10 bytes - let mut server = WsServerBuilder::default().max_request_body_size(10).build(addr).await.unwrap(); + let server = WsServerBuilder::default().max_request_body_size(10).build(addr).await.unwrap(); let mut module = RpcModule::new(()); module.register_method("anything", |_p, _cx| Ok(())).unwrap(); - server.register_module(module).unwrap(); let addr = server.local_addr().unwrap(); - tokio::spawn(async { server.start().await }); + tokio::spawn(server.start(module)); let mut client = WebSocketTestClient::new(addr).await.unwrap(); @@ -156,13 +153,12 @@ async fn can_set_the_max_request_body_size() { async fn can_set_max_connections() { let addr = "127.0.0.1:0"; // Server that accepts max 2 connections - let mut server = WsServerBuilder::default().max_connections(2).build(addr).await.unwrap(); + let server = WsServerBuilder::default().max_connections(2).build(addr).await.unwrap(); let mut module = RpcModule::new(()); module.register_method("anything", |_p, _cx| Ok(())).unwrap(); - server.register_module(module).unwrap(); let addr = server.local_addr().unwrap(); - tokio::spawn(async { server.start().await }); + tokio::spawn(server.start(module)); let conn1 = WebSocketTestClient::new(addr).await; let conn2 = WebSocketTestClient::new(addr).await; @@ -438,8 +434,8 @@ async fn can_register_modules() { let cx2 = Vec::::new(); let mut mod2 = RpcModule::new(cx2); - let mut server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); - assert_eq!(server.method_names().len(), 0); + assert_eq!(mod1.method_names().count(), 0); + assert_eq!(mod2.method_names().count(), 0); mod1.register_method("bla", |_, cx| Ok(format!("Gave me {}", cx))).unwrap(); mod1.register_method("bla2", |_, cx| Ok(format!("Gave me {}", cx))).unwrap(); mod2.register_method("yada", |_, cx| Ok(format!("Gave me {:?}", cx))).unwrap(); @@ -447,12 +443,11 @@ async fn can_register_modules() { // Won't register, name clashes mod2.register_method("bla", |_, cx| Ok(format!("Gave me {:?}", cx))).unwrap(); - server.register_module(mod1).unwrap(); - assert_eq!(server.method_names().len(), 2); - let err = server.register_module(mod2).unwrap_err(); + assert_eq!(mod1.method_names().count(), 2); + let err = mod1.merge(mod2).unwrap_err(); let _expected_err = Error::MethodAlreadyRegistered(String::from("bla")); assert!(matches!(err, _expected_err)); - assert_eq!(server.method_names().len(), 2); + assert_eq!(mod1.method_names().count(), 2); } #[tokio::test] From 04c312ede7c5d705515c275da1ca8602dd73b707 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Sun, 11 Jul 2021 11:24:02 +0100 Subject: [PATCH 2/7] fmt --- utils/src/server/rpc_module.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 979b024ba3..e4a2642131 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -13,8 +13,8 @@ use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::Serialize; use std::fmt::Debug; -use std::sync::Arc; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; /// A `Method` is an RPC endpoint, callable with a standard JSON-RPC request, /// implemented as a function pointer to a `Fn` function taking four arguments: From 861a894f16e4e290e34271bec3cef9e02f5642d1 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Sun, 11 Jul 2021 15:35:44 +0200 Subject: [PATCH 3/7] Infallible `to_rpc` proc macro --- proc-macros/src/lib.rs | 3 +-- proc-macros/src/new/render_server.rs | 14 +++++++++----- proc-macros/tests/ui/correct/basic.rs | 2 +- proc-macros/tests/ui/correct/only_server.rs | 4 ++-- tests/tests/new_proc_macros.rs | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 7453202d6c..b3afac7617 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -327,12 +327,11 @@ pub fn rpc_client_api(input_token_stream: TokenStream) -> TokenStream { /// let rt = tokio::runtime::Runtime::new().unwrap(); /// let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); /// // `into_rpc()` method was generated inside of the `RpcServer` trait under the hood. -/// let module = RpcServerImpl.into_rpc().unwrap(); /// /// rt.block_on(async move { /// server_started_tx.send(server.local_addr().unwrap()).unwrap(); /// -/// server.start(module).await +/// server.start(RpcServerImpl.into_rpc()).await /// }); /// }); /// diff --git a/proc-macros/src/new/render_server.rs b/proc-macros/src/new/render_server.rs index 0851035000..f5bc620052 100644 --- a/proc-macros/src/new/render_server.rs +++ b/proc-macros/src/new/render_server.rs @@ -106,13 +106,17 @@ impl RpcDescription { Ok(quote! { #[doc = #doc_comment] - fn into_rpc(self) -> Result<#rpc_module, #jrps_error> { - let mut rpc = #rpc_module::new(self); + fn into_rpc(self) -> #rpc_module { + let inner = move || -> Result<#rpc_module, #jrps_error> { + let mut rpc = #rpc_module::new(self); - #(#methods)* - #(#subscriptions)* + #(#methods)* + #(#subscriptions)* - Ok(rpc) + Ok(rpc) + }; + + inner().expect("Proc macro method names should never conflict") } }) } diff --git a/proc-macros/tests/ui/correct/basic.rs b/proc-macros/tests/ui/correct/basic.rs index d8e97a52c0..f310ece891 100644 --- a/proc-macros/tests/ui/correct/basic.rs +++ b/proc-macros/tests/ui/correct/basic.rs @@ -56,7 +56,7 @@ pub async fn websocket_server() -> SocketAddr { rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start(RpcServerImpl.into_rpc().unwrap()).await + server.start(RpcServerImpl.into_rpc()).await }); }); diff --git a/proc-macros/tests/ui/correct/only_server.rs b/proc-macros/tests/ui/correct/only_server.rs index cc79c1c471..19922bfbe9 100644 --- a/proc-macros/tests/ui/correct/only_server.rs +++ b/proc-macros/tests/ui/correct/only_server.rs @@ -40,12 +40,12 @@ pub async fn websocket_server() -> SocketAddr { std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); - let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + let server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start(RpcServerImpl.into_rpc().unwrap()).await + server.start(RpcServerImpl.into_rpc()).await }); }); diff --git a/tests/tests/new_proc_macros.rs b/tests/tests/new_proc_macros.rs index 03c7971a39..9b386ea472 100644 --- a/tests/tests/new_proc_macros.rs +++ b/tests/tests/new_proc_macros.rs @@ -60,7 +60,7 @@ pub async fn websocket_server() -> SocketAddr { rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); - server.start(RpcServerImpl.into_rpc().unwrap()).await + server.start(RpcServerImpl.into_rpc()).await }); }); From 88e013a73c38bcd358d057d1551ff1417dfe5b8c Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Mon, 12 Jul 2021 13:16:51 +0200 Subject: [PATCH 4/7] Remove dead code --- http-server/src/server.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/http-server/src/server.rs b/http-server/src/server.rs index 1f94dcccb1..a4420ef031 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -154,19 +154,6 @@ pub struct Server { } impl Server { - // /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. - // /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] - // /// is returned. Note that the [`RpcModule`] is consumed after this call. - // pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { - // self.methods.merge(module.into_methods())?; - // Ok(()) - // } - - // /// Returns a `Vec` with all the method names registered on this server. - // pub fn method_names(&self) -> Vec<&'static str> { - // self.methods.method_names() - // } - /// Returns socket address to which the server is bound. pub fn local_addr(&self) -> Result { self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into())) From 32c574c169ab64b02d61c861fa2d7110ddd29354 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Mon, 12 Jul 2021 20:23:35 +0200 Subject: [PATCH 5/7] Check for duplicate names at compile time --- proc-macros/src/new/render_server.rs | 120 ++++++++++++++++----------- 1 file changed, 73 insertions(+), 47 deletions(-) diff --git a/proc-macros/src/new/render_server.rs b/proc-macros/src/new/render_server.rs index 75e6a741dc..f41f9613d3 100644 --- a/proc-macros/src/new/render_server.rs +++ b/proc-macros/src/new/render_server.rs @@ -1,6 +1,7 @@ use super::RpcDescription; -use proc_macro2::TokenStream as TokenStream2; -use quote::quote; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, quote_spanned}; +use std::collections::HashSet; impl RpcDescription { pub(super) fn render_server(&self) -> Result { @@ -47,57 +48,81 @@ impl RpcDescription { let jrps_error = self.jrps_server_item(quote! { types::Error }); let rpc_module = self.jrps_server_item(quote! { RpcModule }); - let methods = self.methods.iter().map(|method| { - // Rust method to invoke (e.g. `self.(...)`). - let rust_method_name = &method.signature.sig.ident; - // Name of the RPC method (e.g. `foo_makeSpam`). - let rpc_method_name = self.rpc_identifier(&method.name); - // `parsing` is the code associated with parsing structure from the - // provided `RpcParams` object. - // `params_seq` is the comma-delimited sequence of parametsrs. - let is_method = true; - let (parsing, params_seq) = self.render_params_decoding(&method.params, is_method); - - if method.signature.sig.asyncness.is_some() { - quote! { - rpc.register_async_method(#rpc_method_name, |params, context| { - let fut = async move { + let mut registered = HashSet::new(); + let mut errors = Vec::new(); + let mut check_name = |name: String, span: Span| { + if registered.contains(&name) { + let message = format!("{:?} is already defined", name); + errors.push(quote_spanned!(span => compile_error!(#message);)); + } else { + registered.insert(name); + } + }; + + let methods = self + .methods + .iter() + .map(|method| { + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &method.signature.sig.ident; + // Name of the RPC method (e.g. `foo_makeSpam`). + let rpc_method_name = self.rpc_identifier(&method.name); + // `parsing` is the code associated with parsing structure from the + // provided `RpcParams` object. + // `params_seq` is the comma-delimited sequence of parametsrs. + let is_method = true; + let (parsing, params_seq) = self.render_params_decoding(&method.params, is_method); + + check_name(rpc_method_name.clone(), rust_method_name.span()); + + if method.signature.sig.asyncness.is_some() { + quote! { + rpc.register_async_method(#rpc_method_name, |params, context| { + let fut = async move { + #parsing + Ok(context.as_ref().#rust_method_name(#params_seq).await) + }; + Box::pin(fut) + })?; + } + } else { + quote! { + rpc.register_method(#rpc_method_name, |params, context| { #parsing - Ok(context.as_ref().#rust_method_name(#params_seq).await) - }; - Box::pin(fut) - })?; + Ok(context.#rust_method_name(#params_seq)) + })?; + } } - } else { + }) + .collect::>(); + + let subscriptions = self + .subscriptions + .iter() + .map(|sub| { + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &sub.signature.sig.ident; + // Name of the RPC method to subscribe (e.g. `foo_sub`). + let rpc_sub_name = self.rpc_identifier(&sub.name); + // Name of the RPC method to unsubscribe (e.g. `foo_sub`). + let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method); + // `parsing` is the code associated with parsing structure from the + // provided `RpcParams` object. + // `params_seq` is the comma-delimited sequence of parametsrs. + let is_method = false; + let (parsing, params_seq) = self.render_params_decoding(&sub.params, is_method); + + check_name(rpc_sub_name.clone(), rust_method_name.span()); + check_name(rpc_unsub_name.clone(), rust_method_name.span()); + quote! { - rpc.register_method(#rpc_method_name, |params, context| { + rpc.register_subscription(#rpc_sub_name, #rpc_unsub_name, |params, sink, context| { #parsing - Ok(context.#rust_method_name(#params_seq)) + Ok(context.as_ref().#rust_method_name(sink, #params_seq)) })?; } - } - }); - - let subscriptions = self.subscriptions.iter().map(|sub| { - // Rust method to invoke (e.g. `self.(...)`). - let rust_method_name = &sub.signature.sig.ident; - // Name of the RPC method to subscribe (e.g. `foo_sub`). - let rpc_sub_name = self.rpc_identifier(&sub.name); - // Name of the RPC method to unsubscribe (e.g. `foo_sub`). - let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method); - // `parsing` is the code associated with parsing structure from the - // provided `RpcParams` object. - // `params_seq` is the comma-delimited sequence of parametsrs. - let is_method = false; - let (parsing, params_seq) = self.render_params_decoding(&sub.params, is_method); - - quote! { - rpc.register_subscription(#rpc_sub_name, #rpc_unsub_name, |params, sink, context| { - #parsing - Ok(context.as_ref().#rust_method_name(sink, #params_seq)) - })?; - } - }); + }) + .collect::>(); let doc_comment = "Collects all the methods and subscriptions defined in the trait \ and adds them into a single `RpcModule`."; @@ -108,6 +133,7 @@ impl RpcDescription { let inner = move || -> Result<#rpc_module, #jrps_error> { let mut rpc = #rpc_module::new(self); + #(#errors)* #(#methods)* #(#subscriptions)* From 910f09fba50247e750178fa1fccb2d4b6009848d Mon Sep 17 00:00:00 2001 From: Maciej Hirsz Date: Mon, 12 Jul 2021 20:36:58 +0200 Subject: [PATCH 6/7] Add a UI test for name conflicts --- .../tests/ui/incorrect/rpc/rpc_name_conflict.rs | 13 +++++++++++++ .../tests/ui/incorrect/rpc/rpc_name_conflict.stderr | 5 +++++ 2 files changed, 18 insertions(+) create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.rs create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.stderr diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.rs b/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.rs new file mode 100644 index 0000000000..ac8f457453 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.rs @@ -0,0 +1,13 @@ +use jsonrpsee::proc_macros::rpc; + +// Associated items are forbidden. +#[rpc(client, server)] +pub trait MethodNameConflict { + #[method(name = "foo")] + async fn foo(&self) -> u8; + + #[method(name = "foo")] + async fn bar(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.stderr new file mode 100644 index 0000000000..213f6fbacc --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_name_conflict.stderr @@ -0,0 +1,5 @@ +error: "foo" is already defined + --> $DIR/rpc_name_conflict.rs:10:11 + | +10 | async fn bar(&self) -> u8; + | ^^^ From 0521efd10debad17f3758e2d65e42951fafa3c2f Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Mon, 12 Jul 2021 20:48:05 +0200 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: David --- proc-macros/src/new/render_server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proc-macros/src/new/render_server.rs b/proc-macros/src/new/render_server.rs index f41f9613d3..38c3d40135 100644 --- a/proc-macros/src/new/render_server.rs +++ b/proc-macros/src/new/render_server.rs @@ -102,7 +102,7 @@ impl RpcDescription { .map(|sub| { // Rust method to invoke (e.g. `self.(...)`). let rust_method_name = &sub.signature.sig.ident; - // Name of the RPC method to subscribe (e.g. `foo_sub`). + // Name of the RPC method to subscribe to (e.g. `foo_sub`). let rpc_sub_name = self.rpc_identifier(&sub.name); // Name of the RPC method to unsubscribe (e.g. `foo_sub`). let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method); @@ -140,7 +140,7 @@ impl RpcDescription { Ok(rpc) }; - inner().expect("Proc macro method names should never conflict") + inner().expect("RPC macro method names should never conflict") } }) }