diff --git a/core/Cargo.toml b/core/Cargo.toml index 95beee09b0..d8ea6326a3 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -10,7 +10,6 @@ license = "MIT" anyhow = "1" arrayvec = "0.7.1" async-trait = "0.1" -async-channel = { version = "1.6", optional = true } beef = { version = "0.5.1", features = ["impl_serde"] } thiserror = "1" futures-channel = { version = "0.3.14", default-features = false } @@ -24,26 +23,27 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1", features = ["raw_value"] } soketto = "0.7.1" parking_lot = { version = "0.12", optional = true } -tokio = { version = "1.8", features = ["rt"], optional = true } +tokio = { version = "1.8", optional = true } [features] default = [] http-helpers = ["futures-util"] server = [ - "async-channel", "futures-util", "rustc-hash", "tracing", "parking_lot", "rand", - "tokio", + "tokio/rt", + "tokio/sync", ] client = ["futures-util"] async-client = [ "client", "rustc-hash", - "tokio/sync", "tokio/macros", + "tokio/rt", + "tokio/sync", "tokio/time", "tracing" ] diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index bbe8fc8df5..7f5f62875e 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -37,6 +37,7 @@ use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec use crate::traits::{IdProvider, ToRpcParams}; use futures_channel::{mpsc, oneshot}; use futures_util::future::Either; +use futures_util::pin_mut; use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt}; use jsonrpsee_types::error::{ErrorCode, CALL_EXECUTION_FAILED_CODE}; use jsonrpsee_types::{ @@ -45,6 +46,7 @@ use jsonrpsee_types::{ use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::Notify; /// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request, /// implemented as a function pointer to a `Fn` function taking four arguments: @@ -61,22 +63,27 @@ pub type SubscriptionMethod = Arc, async_channel::Sender<()>); -/// Data for stateful connections. +/// Raw response from an RPC +/// A 3-tuple containing: +/// - Call result as a `String`, +/// - a [`mpsc::UnboundedReceiver`] to receive future subscription results +/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect. +pub type RawRpcResponse = (String, mpsc::UnboundedReceiver, Arc); + +/// Helper struct to manage subscriptions. pub struct ConnState<'a> { /// Connection ID pub conn_id: ConnectionId, - /// Channel to know whether the connection is closed or not. - pub close: async_channel::Receiver<()>, + /// Get notified when the connection to subscribers is closed. + pub close_notify: Arc, /// ID provider. pub id_provider: &'a dyn IdProvider, } impl<'a> std::fmt::Debug for ConnState<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish() + f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close_notify).finish() } } @@ -366,25 +373,26 @@ impl Methods { /// Execute a callback. async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse { - let (tx, mut rx) = mpsc::unbounded(); - let sink = MethodSink::new(tx); - let (close_tx, close_rx) = async_channel::unbounded(); - + let (tx_sink, mut rx_sink) = mpsc::unbounded(); + let sink = MethodSink::new(tx_sink); let id = req.id.clone(); let params = Params::new(req.params.map(|params| params.get())); + let notify = Arc::new(Notify::new()); let _result = match self.method(&req.method).map(|c| &c.callback) { None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()), Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink), Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await, Some(MethodKind::Subscription(cb)) => { - let conn_state = ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider }; + let close_notify = notify.clone(); + let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider }; (cb)(id, params, &sink, conn_state) } }; - let resp = rx.next().await.expect("tx and rx still alive; qed"); - (resp, rx, close_tx) + let resp = rx_sink.next().await.expect("tx and rx still alive; qed"); + + (resp, rx_sink, notify) } /// Helper to create a subscription on the `RPC module` without having to spin up a server. @@ -416,10 +424,11 @@ impl Methods { let params = params.to_rpc_params()?; let req = Request::new(sub_method.into(), Some(¶ms), Id::Number(0)); tracing::trace!("[Methods::subscribe] Calling subscription method: {:?}, params: {:?}", sub_method, params); - let (response, rx, tx) = self.inner_call(req).await; + let (response, rx, close_notify) = self.inner_call(req).await; let subscription_response = serde_json::from_str::>(&response)?; let sub_id = subscription_response.result.into_owned(); - Ok(Subscription { sub_id, rx, tx }) + let close_notify = Some(close_notify); + Ok(Subscription { sub_id, rx, close_notify }) } /// Returns an `Iterator` with all the method names registered on this server. @@ -633,6 +642,7 @@ impl RpcModule { let ctx = self.ctx.clone(); let subscribers = Subscribers::default(); + // Subscribe { let subscribers = subscribers.clone(); self.methods.mut_callbacks().insert( @@ -653,7 +663,7 @@ impl RpcModule { let sink = SubscriptionSink { inner: method_sink.clone(), - close: conn.close, + close_notify: Some(conn.close_notify), method: notif_method_name, subscribers: subscribers.clone(), uniq_sub: SubscriptionKey { conn_id: conn.conn_id, sub_id }, @@ -674,6 +684,7 @@ impl RpcModule { ); } + // Unsubscribe { self.methods.mut_callbacks().insert( unsubscribe_method_name, @@ -725,8 +736,8 @@ impl RpcModule { pub struct SubscriptionSink { /// Sink. inner: MethodSink, - /// Close - close: async_channel::Receiver<()>, + /// Get notified when subscribers leave so we can exit + close_notify: Option>, /// MethodCallback. method: &'static str, /// Unique subscription. @@ -773,47 +784,45 @@ impl SubscriptionSink { S: Stream + Unpin, T: Serialize, { - let mut close_stream = self.close.clone(); - let mut item = stream.next(); - let mut close = close_stream.next(); - - loop { - match futures_util::future::select(item, close).await { - Either::Left((Some(result), c)) => { - match self.send(&result) { - Ok(_) => (), - Err(Error::SubscriptionClosed(close_reason)) => { - self.close(&close_reason); - break Ok(()); - } - Err(err) => { - tracing::error!("subscription `{}` failed to send item got error: {:?}", self.method, err); - break Err(err); - } - }; - close = c; - item = stream.next(); - } - // No messages should be sent over this channel - // if that occurred just ignore and continue. - Either::Right((Some(_), i)) => { - item = i; - close = close_stream.next(); - } - // Connection terminated. - Either::Right((None, _)) => { - self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset)); - break Ok(()); + if let Some(close_notify) = self.close_notify.clone() { + let mut stream_item = stream.next(); + let closed_fut = close_notify.notified(); + pin_mut!(closed_fut); + loop { + match futures_util::future::select(stream_item, closed_fut).await { + // The app sent us a value to send back to the subscribers + Either::Left((Some(result), next_closed_fut)) => { + match self.send(&result) { + Ok(_) => (), + Err(Error::SubscriptionClosed(close_reason)) => { + self.close(&close_reason); + break Ok(()); + } + Err(err) => { + break Err(err); + } + }; + stream_item = stream.next(); + closed_fut = next_closed_fut; + } + // Stream terminated. + Either::Left((None, _)) => break Ok(()), + // The subscriber went away without telling us. + Either::Right(((), _)) => { + self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset)); + break Ok(()); + } } - // Stream terminated. - Either::Left((None, _)) => break Ok(()), } + } else { + // The sink is closed. + Ok(()) } } /// Returns whether this channel is closed without needing a context. pub fn is_closed(&self) -> bool { - self.inner.is_closed() || self.close.is_closed() + self.inner.is_closed() || self.close_notify.is_none() } fn build_message(&self, result: &T) -> Result { @@ -853,7 +862,7 @@ impl SubscriptionSink { self.inner_close(Some(&close_reason)); } - /// Provide close from `SubscriptionClosed`. + /// Close the subscription sink with the provided [`SubscriptionClosed`]. pub fn close(&mut self, close_reason: &SubscriptionClosed) { self.inner_close(Some(close_reason)); } @@ -880,7 +889,7 @@ impl Drop for SubscriptionSink { /// Wrapper struct that maintains a subscription "mainly" for testing. #[derive(Debug)] pub struct Subscription { - tx: async_channel::Sender<()>, + close_notify: Option>, rx: mpsc::UnboundedReceiver, sub_id: RpcSubscriptionId<'static>, } @@ -888,9 +897,11 @@ pub struct Subscription { impl Subscription { /// Close the subscription channel. pub fn close(&mut self) { - self.tx.close(); + tracing::trace!("[Subscription::close] Notifying"); + if let Some(n) = self.close_notify.take() { + n.notify_one() + } } - /// Get the subscription ID pub fn subscription_id(&self) -> &RpcSubscriptionId { &self.sub_id @@ -903,6 +914,10 @@ impl Subscription { /// /// If the decoding the value as `T` fails. pub async fn next(&mut self) -> Option), Error>> { + if self.close_notify.is_none() { + tracing::debug!("[Subscription::next] Closed."); + return Some(Err(Error::SubscriptionClosed(SubscriptionClosedReason::ConnectionReset.into()))); + } let raw = self.rx.next().await?; let res = match serde_json::from_str::>(&raw) { Ok(r) => Ok((r.params.result, r.params.subscription.into_owned())), diff --git a/tests/Cargo.toml b/tests/Cargo.toml index cce59c40d5..6af376230d 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -17,3 +17,5 @@ tracing = "0.1" serde = "1" serde_json = "1" hyper = { version = "0.14", features = ["http1", "client"] } +tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } +tokio-stream = "0.1" diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 3cd8c5955b..3eec2ba10b 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -30,13 +30,16 @@ use std::sync::Arc; use std::time::Duration; +use futures::TryStreamExt; use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription}; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; -use jsonrpsee::core::error::SubscriptionClosedReason; +use jsonrpsee::core::error::{SubscriptionClosed, SubscriptionClosedReason}; use jsonrpsee::core::{Error, JsonValue}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::rpc_params; use jsonrpsee::ws_client::WsClientBuilder; +use tokio::time::interval; +use tokio_stream::wrappers::IntervalStream; mod helpers; @@ -379,6 +382,11 @@ async fn ws_server_should_stop_subscription_after_client_drop() { #[tokio::test] async fn ws_server_cancels_stream_after_reset_conn() { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + use futures::{channel::mpsc, SinkExt, StreamExt}; use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; @@ -415,6 +423,54 @@ async fn ws_server_cancels_stream_after_reset_conn() { assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped"); } +#[tokio::test] +async fn ws_server_subscribe_with_stream() { + use futures::StreamExt; + use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; + + let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let server_url = format!("ws://{}", server.local_addr().unwrap()); + + let mut module = RpcModule::new(()); + + module + .register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, sink, _| { + tokio::spawn(async move { + let interval = interval(Duration::from_millis(50)); + let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); + + sink.pipe_from_stream(stream).await.unwrap(); + }); + Ok(()) + }) + .unwrap(); + server.start(module).unwrap(); + + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + let mut sub1: Subscription = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap(); + let mut sub2: Subscription = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap(); + + let (r1, r2) = futures::future::try_join( + sub1.by_ref().take(2).try_collect::>(), + sub2.by_ref().take(3).try_collect::>(), + ) + .await + .unwrap(); + + assert_eq!(r1, vec![1, 2]); + assert_eq!(r2, vec![1, 2, 3]); + + // Be rude, don't run the destructor + std::mem::forget(sub2); + + // sub1 is still in business, read remaining items. + assert_eq!(sub1.by_ref().take(3).try_collect::>().await.unwrap(), vec![3, 4, 5]); + + let exp = SubscriptionClosed::new(SubscriptionClosedReason::Server("No close reason provided".to_string())); + // The server closed down the subscription it will send a close reason. + assert!(matches!(sub1.next().await, Some(Err(Error::SubscriptionClosed(close_reason))) if close_reason == exp)); +} + #[tokio::test] async fn ws_batch_works() { let server_addr = websocket_server().await; diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index fe3bf585eb..0bdc8fa25f 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -229,6 +229,11 @@ async fn subscribing_without_server() { #[tokio::test] async fn close_test_subscribing_without_server() { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + let mut module = RpcModule::new(()); module .register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| { @@ -236,6 +241,7 @@ async fn close_test_subscribing_without_server() { // make sure to only send one item sink.send(&"lo").unwrap(); while !sink.is_closed() { + tracing::debug!("[test] Sink is open, sleeping"); std::thread::sleep(std::time::Duration::from_millis(500)); } // Get the close reason. @@ -251,14 +257,28 @@ async fn close_test_subscribing_without_server() { let (val, id) = my_sub.next::().await.unwrap().unwrap(); assert_eq!(&val, "lo"); assert_eq!(&id, my_sub.subscription_id()); + let mut my_sub2 = std::mem::ManuallyDrop::new(module.subscribe("my_sub", EmptyParams::new()).await.unwrap()); - // close the subscription to ensure it doesn't return any items. + // Close the subscription to ensure it doesn't return any items. my_sub.close(); + tracing::info!("[test] closed first sub"); - // In this case, the unsubscribe method was not called and + // The first subscription was not closed using the unsubscribe method and // it will be treated as the connection was closed. let exp = SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset); assert!( matches!(my_sub.next::().await, Some(Err(Error::SubscriptionClosed(close_reason))) if close_reason == exp) ); + + // The second subscription still works + let (val, _) = my_sub2.next::().await.unwrap().unwrap(); + assert_eq!(val, "lo".to_string()); + // Simulate a rude client that disconnects suddenly. + unsafe { + std::mem::ManuallyDrop::drop(&mut my_sub2); + } + + assert!( + matches!(my_sub2.next::().await, Some(Err(Error::SubscriptionClosed(close_reason))) if close_reason == exp) + ); } diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 4fbf1ce9a1..7828e7ccad 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -10,7 +10,6 @@ homepage = "https://github.com/paritytech/jsonrpsee" documentation = "https://docs.rs/jsonrpsee-ws-server" [dependencies] -async-channel = "1.6.1" futures-channel = "0.3.14" futures-util = { version = "0.3.14", default-features = false, features = ["io", "async-await-macro"] } jsonrpsee-types = { path = "../types", version = "0.9.0" } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 1ed10bf746..29da616eb3 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -49,6 +49,7 @@ use soketto::connection::Error as SokettoError; use soketto::handshake::{server::Response, Server as SokettoServer}; use soketto::Sender; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio::sync::Notify; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; /// Default maximum connections allowed. @@ -298,7 +299,8 @@ async fn background_task( builder.set_max_message_size(max_request_body_size as usize); let (mut sender, mut receiver) = builder.finish(); let (tx, mut rx) = mpsc::unbounded::(); - let (conn_tx, conn_rx) = async_channel::unbounded(); + let close_notify = Arc::new(Notify::new()); + let close_notify_server_stop = close_notify.clone(); let stop_server2 = stop_server.clone(); let sink = MethodSink::new_with_limit(tx, max_request_body_size); @@ -324,7 +326,7 @@ async fn background_task( // Force `conn_tx` to this async block and close it down // when the connection closes to be on safe side. - conn_tx.close(); + close_notify_server_stop.notify_one(); }); // Buffer for incoming data. @@ -433,8 +435,9 @@ async fn background_task( }, MethodKind::Subscription(callback) => match method.claim(&req.method, &resources) { Ok(guard) => { + let cn = close_notify.clone(); let conn_state = - ConnState { conn_id, close: conn_rx.clone(), id_provider: &*id_provider }; + ConnState { conn_id, close_notify: cn, id_provider: &*id_provider }; let result = callback(id, params, &sink, conn_state); middleware.on_result(name, result, request_start); @@ -466,8 +469,8 @@ async fn background_task( let methods = &methods; let sink = sink.clone(); let id_provider = id_provider.clone(); + let close_notify2 = close_notify.clone(); - let conn_rx2 = conn_rx.clone(); let fut = async move { // Batch responses must be sent back as a single message so we read the results from each // request in the batch and read the results off of a new channel, `rx_batch`, and then send the @@ -533,11 +536,9 @@ async fn background_task( MethodKind::Subscription(callback) => { match method_callback.claim(&req.method, resources) { Ok(guard) => { - let conn_state = ConnState { - conn_id, - close: conn_rx2.clone(), - id_provider: &*id_provider, - }; + let close_notify = close_notify2.clone(); + let conn_state = + ConnState { conn_id, close_notify, id_provider: &*id_provider }; let result = callback(id, params, &sink_batch, conn_state); middleware.on_result(&req.method, result, request_start);