diff --git a/tests/Cargo.toml b/tests/Cargo.toml index e05fbd1f6e..3f7f55f4c1 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -9,7 +9,7 @@ publish = false [dev-dependencies] beef = { version = "0.5.1", features = ["impl_serde"] } -futures-channel = { version = "0.3.14", default-features = false } +futures = { version = "0.3.14", default-features = false, features = ["std"] } jsonrpsee = { path = "../jsonrpsee", features = ["full"] } tokio = { version = "1", features = ["full"] } serde_json = "1" diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 860ac3278d..5af9073bcd 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -24,9 +24,10 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures_channel::oneshot; +use futures::channel::oneshot; use jsonrpsee::{ http_server::HttpServerBuilder, + types::Error, ws_server::{WsServerBuilder, WsStopHandle}, RpcModule, }; @@ -47,7 +48,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) module .register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { std::thread::spawn(move || loop { - let _ = sink.send(&"hello from subscription"); + if let Err(Error::SubscriptionClosed(_)) = sink.send(&"hello from subscription") { + break; + } std::thread::sleep(Duration::from_millis(50)); }); Ok(()) @@ -57,7 +60,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) module .register_subscription("subscribe_foo", "unsubscribe_foo", |_, mut sink, _| { std::thread::spawn(move || loop { - let _ = sink.send(&1337); + if let Err(Error::SubscriptionClosed(_)) = sink.send(&1337) { + break; + } std::thread::sleep(Duration::from_millis(100)); }); Ok(()) @@ -69,7 +74,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) let mut count: usize = params.one()?; std::thread::spawn(move || loop { count = count.wrapping_add(1); - let _ = sink.send(&count); + if let Err(Error::SubscriptionClosed(_)) = sink.send(&count) { + break; + } std::thread::sleep(Duration::from_millis(100)); }); Ok(()) diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 54696ec4d2..2c700dc5bf 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -33,6 +33,7 @@ use helpers::{http_server, websocket_server, websocket_server_with_subscription} use jsonrpsee::{ http_client::HttpClientBuilder, types::{ + error::SubscriptionClosedError, traits::{Client, SubscriptionClient}, v2::ParamsSer, Error, JsonValue, Subscription, @@ -137,17 +138,16 @@ async fn ws_subscription_several_clients_with_drop() { let (client, hello_sub, foo_sub) = clients.remove(i); drop(hello_sub); drop(foo_sub); - // Send this request to make sure that the client's background thread hasn't - // been canceled. assert!(client.is_connected()); drop(client); } - // make sure nothing weird happened after dropping half the clients (should be `unsubscribed` in the server) + // make sure nothing weird happened after dropping half of the clients (should be `unsubscribed` in the server) // would be good to know that subscriptions actually were removed but not possible to verify at // this layer. for _ in 0..10 { - for (_client, hello_sub, foo_sub) in &mut clients { + for (client, hello_sub, foo_sub) in &mut clients { + assert!(client.is_connected()); let hello = hello_sub.next().await.unwrap().unwrap(); let foo = foo_sub.next().await.unwrap().unwrap(); assert_eq!(&hello, "hello from subscription"); @@ -295,3 +295,45 @@ async fn ws_close_pending_subscription_when_server_terminated() { panic!("subscription keeps sending messages after server shutdown"); } + +#[tokio::test] +async fn ws_server_should_stop_subscription_after_client_drop() { + use futures::{channel::mpsc, SinkExt, StreamExt}; + use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; + + let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let server_url = format!("ws://{}", server.local_addr().unwrap()); + + let (tx, mut rx) = mpsc::channel(1); + let mut module = RpcModule::new(tx); + + module + .register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, mut tx| { + tokio::spawn(async move { + let close_err = loop { + if let Err(Error::SubscriptionClosed(err)) = sink.send(&1) { + break err; + } + tokio::time::sleep(Duration::from_millis(100)).await; + }; + let send_back = Arc::make_mut(&mut tx); + send_back.feed(close_err).await.unwrap(); + }); + Ok(()) + }) + .unwrap(); + + tokio::spawn(async move { server.start(module).await }); + + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + let mut sub: Subscription = + client.subscribe("subscribe_hello", ParamsSer::NoParams, "unsubscribe_hello").await.unwrap(); + + let res = sub.next().await.unwrap(); + + assert_eq!(res.as_ref(), Some(&1)); + drop(client); + // assert that the server received `SubscriptionClosed` after the client was dropped. + assert!(matches!(rx.next().await.unwrap(), SubscriptionClosedError { .. })); +} diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index 0ddb703726..e9daddfb33 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -28,7 +28,7 @@ use std::net::SocketAddr; -use futures_channel::oneshot; +use futures::channel::oneshot; use jsonrpsee::{ws_client::*, ws_server::WsServerBuilder}; use serde_json::value::RawValue; diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 9ebf788a9a..7047a437e3 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -566,12 +566,12 @@ impl SubscriptionSink { impl Drop for SubscriptionSink { fn drop(&mut self) { - self.close(format!("Subscription: {} is closed and dropped", self.uniq_sub.sub_id)); + self.close(format!("Subscription: {} closed", self.uniq_sub.sub_id)); } } fn subscription_closed_err(sub_id: u64) -> Error { - Error::SubscriptionClosed(format!("Subscription {} is closed but not yet dropped", sub_id).into()) + Error::SubscriptionClosed(format!("Subscription {} closed", sub_id).into()) } #[cfg(test)] @@ -765,6 +765,6 @@ mod tests { // The subscription is now closed let my_sub = my_sub_stream.next().await.unwrap(); let my_sub = serde_json::from_str::>(&my_sub).unwrap(); - assert_eq!(my_sub.params.result, format!("Subscription: {} is closed and dropped", sub_id).to_string().into()); + assert_eq!(my_sub.params.result, format!("Subscription: {} closed", sub_id).into()); } } diff --git a/ws-client/src/client.rs b/ws-client/src/client.rs index 18ef9eab30..beab6e59bb 100644 --- a/ws-client/src/client.rs +++ b/ws-client/src/client.rs @@ -298,6 +298,12 @@ impl WsClient { } } +impl Drop for WsClient { + fn drop(&mut self) { + self.to_back.close_channel(); + } +} + #[async_trait] impl Client for WsClient { async fn notification<'a>(&self, method: &'a str, params: ParamsSer<'a>) -> Result<(), Error> { @@ -522,7 +528,7 @@ async fn background_task( // There is nothing to do just terminate. Either::Left((None, _)) => { log::trace!("[backend]: frontend dropped; terminate client"); - return; + break; } Either::Left((Some(FrontToBack::Batch(batch)), _)) => { @@ -617,7 +623,7 @@ async fn background_task( Ok(None) => (), Err(err) => { let _ = front_error.send(err); - return; + break; } } } @@ -656,19 +662,22 @@ async fn background_task( serde_json::from_slice::(&raw) ); let _ = front_error.send(Error::Custom("Unparsable response".into())); - return; + break; } } Either::Right((Some(Err(e)), _)) => { log::error!("Error: {:?} terminating client", e); let _ = front_error.send(Error::Transport(e.into())); - return; + break; } Either::Right((None, _)) => { log::error!("[backend]: WebSocket receiver dropped; terminate client"); let _ = front_error.send(Error::Custom("WebSocket receiver dropped".into())); - return; + break; } } } + + // Send close message to the server. + let _ = sender.close().await; } diff --git a/ws-client/src/manager.rs b/ws-client/src/manager.rs index 126a2d51f0..3e67e3ca86 100644 --- a/ws-client/src/manager.rs +++ b/ws-client/src/manager.rs @@ -289,7 +289,7 @@ impl RequestManager { } } - /// Get a mutable reference to underlying `Sink` in order to send incmoing notifications to the subscription. + /// Get a mutable reference to underlying `Sink` in order to send incoming notifications to the subscription. /// /// Returns `Some` if the `method` was registered as a NotificationHandler otherwise `None`. pub fn as_notification_handler_mut(&mut self, method: String) -> Option<&mut SubscriptionSink> { diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 715d95e675..7b999a4833 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -153,6 +153,11 @@ impl Sender { self.inner.flush().await?; Ok(()) } + + /// Send a close message and close the connection. + pub async fn close(&mut self) -> Result<(), WsError> { + self.inner.close().await.map_err(Into::into) + } } impl Receiver { diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 89eea96133..c91b28ed49 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -209,16 +209,16 @@ async fn background_task( conn_id: ConnectionId, methods: Methods, max_request_body_size: u32, - stop_monitor: StopMonitor, + stop_server: StopMonitor, ) -> Result<(), Error> { // And we can finally transition to a websocket background_task. let (mut sender, mut receiver) = server.into_builder().finish(); let (tx, mut rx) = mpsc::unbounded::(); + let stop_server2 = stop_server.clone(); - let stop_monitor2 = stop_monitor.clone(); // Send results back to the client. tokio::spawn(async move { - while !stop_monitor2.shutdown_requested() { + while !stop_server2.shutdown_requested() { match rx.next().await { Some(response) => { log::debug!("send: {}", response); @@ -228,20 +228,24 @@ async fn background_task( None => break, }; } - - drop(stop_monitor2); // terminate connection. let _ = sender.close().await; + // NOTE(niklasad1): when the receiver is dropped no further requests or subscriptions + // will be possible. }); // Buffer for incoming data. let mut data = Vec::with_capacity(100); let mut method_executors = FutureDriver::default(); - while !stop_monitor.shutdown_requested() { + while !stop_server.shutdown_requested() { data.clear(); - method_executors.select_with(receiver.receive_data(&mut data)).await?; + if let Err(e) = method_executors.select_with(receiver.receive_data(&mut data)).await { + log::error!("Could not receive WS data: {:?}; closing connection", e); + tx.close_channel(); + return Err(e.into()); + } if data.len() > max_request_body_size as usize { log::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size); @@ -295,9 +299,6 @@ async fn background_task( // Drive all running methods to completion method_executors.await; - // Drop the monitor for this task since we are shutting down - drop(stop_monitor); - Ok(()) }