Skip to content

Commit

Permalink
Replace async-channel (#708)
Browse files Browse the repository at this point in the history
* wip wip wip

Use tokio::sync::Notify to signal to the server when a subscriber has gone away without calling unsubscribe

* Cleanup

* Fmt

* More cleanup more TODOs

* fmt

* Address a few todos

* Update core/src/server/rpc_module.rs

Co-authored-by: Niklas Adolfsson <niklasadolfsson1@gmail.com>

* Update ws-server/src/server.rs

Co-authored-by: Niklas Adolfsson <niklasadolfsson1@gmail.com>

* Fix bad merge

* Test

* fmt

* fix test

* Finish test

* Cleanup
Add a second subscription to serverless test

* Update tests/tests/integration_tests.rs

Co-authored-by: Niklas Adolfsson <niklasadolfsson1@gmail.com>

* simplify test

* Review feedback: avoid allocations

* cleanup

* Remove async-channel

* remove async-channel deps

Co-authored-by: Niklas Adolfsson <niklasadolfsson1@gmail.com>
  • Loading branch information
dvdplm and niklasad1 authored Mar 9, 2022
1 parent 662676e commit 7c46458
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 74 deletions.
10 changes: 5 additions & 5 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ license = "MIT"
anyhow = "1"
arrayvec = "0.7.1"
async-trait = "0.1"
async-channel = { version = "1.6", optional = true }
beef = { version = "0.5.1", features = ["impl_serde"] }
thiserror = "1"
futures-channel = { version = "0.3.14", default-features = false }
Expand All @@ -24,26 +23,27 @@ serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.7.1"
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.8", features = ["rt"], optional = true }
tokio = { version = "1.8", optional = true }

[features]
default = []
http-helpers = ["futures-util"]
server = [
"async-channel",
"futures-util",
"rustc-hash",
"tracing",
"parking_lot",
"rand",
"tokio",
"tokio/rt",
"tokio/sync",
]
client = ["futures-util"]
async-client = [
"client",
"rustc-hash",
"tokio/sync",
"tokio/macros",
"tokio/rt",
"tokio/sync",
"tokio/time",
"tracing"
]
Expand Down
127 changes: 71 additions & 56 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::{mpsc, oneshot};
use futures_util::future::Either;
use futures_util::pin_mut;
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt};
use jsonrpsee_types::error::{ErrorCode, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
Expand All @@ -45,6 +46,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand All @@ -61,22 +63,27 @@ pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink,
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>);

/// Data for stateful connections.
/// Raw response from an RPC
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, Arc<Notify>);

/// Helper struct to manage subscriptions.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Channel to know whether the connection is closed or not.
pub close: async_channel::Receiver<()>,
/// Get notified when the connection to subscribers is closed.
pub close_notify: Arc<Notify>,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}

impl<'a> std::fmt::Debug for ConnState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish()
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close_notify).finish()
}
}

Expand Down Expand Up @@ -366,25 +373,26 @@ impl Methods {

/// Execute a callback.
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx);
let (close_tx, close_rx) = async_channel::unbounded();

let (tx_sink, mut rx_sink) = mpsc::unbounded();
let sink = MethodSink::new(tx_sink);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let notify = Arc::new(Notify::new());

let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let conn_state = ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider };
let close_notify = notify.clone();
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
};

let resp = rx.next().await.expect("tx and rx still alive; qed");
(resp, rx, close_tx)
let resp = rx_sink.next().await.expect("tx and rx still alive; qed");

(resp, rx_sink, notify)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -416,10 +424,11 @@ impl Methods {
let params = params.to_rpc_params()?;
let req = Request::new(sub_method.into(), Some(&params), Id::Number(0));
tracing::trace!("[Methods::subscribe] Calling subscription method: {:?}, params: {:?}", sub_method, params);
let (response, rx, tx) = self.inner_call(req).await;
let (response, rx, close_notify) = self.inner_call(req).await;
let subscription_response = serde_json::from_str::<Response<RpcSubscriptionId>>(&response)?;
let sub_id = subscription_response.result.into_owned();
Ok(Subscription { sub_id, rx, tx })
let close_notify = Some(close_notify);
Ok(Subscription { sub_id, rx, close_notify })
}

/// Returns an `Iterator` with all the method names registered on this server.
Expand Down Expand Up @@ -633,6 +642,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let subscribers = Subscribers::default();

// Subscribe
{
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
Expand All @@ -653,7 +663,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let sink = SubscriptionSink {
inner: method_sink.clone(),
close: conn.close,
close_notify: Some(conn.close_notify),
method: notif_method_name,
subscribers: subscribers.clone(),
uniq_sub: SubscriptionKey { conn_id: conn.conn_id, sub_id },
Expand All @@ -674,6 +684,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
);
}

// Unsubscribe
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
Expand Down Expand Up @@ -725,8 +736,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Close
close: async_channel::Receiver<()>,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -773,47 +784,45 @@ impl SubscriptionSink {
S: Stream<Item = T> + Unpin,
T: Serialize,
{
let mut close_stream = self.close.clone();
let mut item = stream.next();
let mut close = close_stream.next();

loop {
match futures_util::future::select(item, close).await {
Either::Left((Some(result), c)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
tracing::error!("subscription `{}` failed to send item got error: {:?}", self.method, err);
break Err(err);
}
};
close = c;
item = stream.next();
}
// No messages should be sent over this channel
// if that occurred just ignore and continue.
Either::Right((Some(_), i)) => {
item = i;
close = close_stream.next();
}
// Connection terminated.
Either::Right((None, _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
if let Some(close_notify) = self.close_notify.clone() {
let mut stream_item = stream.next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
loop {
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
Either::Left((Some(result), next_closed_fut)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
break Err(err);
}
};
stream_item = stream.next();
closed_fut = next_closed_fut;
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
// The subscriber went away without telling us.
Either::Right(((), _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
}
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
}
} else {
// The sink is closed.
Ok(())
}
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close.is_closed()
self.inner.is_closed() || self.close_notify.is_none()
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
Expand Down Expand Up @@ -853,7 +862,7 @@ impl SubscriptionSink {
self.inner_close(Some(&close_reason));
}

/// Provide close from `SubscriptionClosed`.
/// Close the subscription sink with the provided [`SubscriptionClosed`].
pub fn close(&mut self, close_reason: &SubscriptionClosed) {
self.inner_close(Some(close_reason));
}
Expand All @@ -880,17 +889,19 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
tx: async_channel::Sender<()>,
close_notify: Option<Arc<Notify>>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}

impl Subscription {
/// Close the subscription channel.
pub fn close(&mut self) {
self.tx.close();
tracing::trace!("[Subscription::close] Notifying");
if let Some(n) = self.close_notify.take() {
n.notify_one()
}
}

/// Get the subscription ID
pub fn subscription_id(&self) -> &RpcSubscriptionId {
&self.sub_id
Expand All @@ -903,6 +914,10 @@ impl Subscription {
///
/// If the decoding the value as `T` fails.
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, RpcSubscriptionId<'static>), Error>> {
if self.close_notify.is_none() {
tracing::debug!("[Subscription::next] Closed.");
return Some(Err(Error::SubscriptionClosed(SubscriptionClosedReason::ConnectionReset.into())));
}
let raw = self.rx.next().await?;
let res = match serde_json::from_str::<SubscriptionResponse<T>>(&raw) {
Ok(r) => Ok((r.params.result, r.params.subscription.into_owned())),
Expand Down
2 changes: 2 additions & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ tracing = "0.1"
serde = "1"
serde_json = "1"
hyper = { version = "0.14", features = ["http1", "client"] }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tokio-stream = "0.1"
58 changes: 57 additions & 1 deletion tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
use std::sync::Arc;
use std::time::Duration;

use futures::TryStreamExt;
use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription};
use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT};
use jsonrpsee::core::error::SubscriptionClosedReason;
use jsonrpsee::core::error::{SubscriptionClosed, SubscriptionClosedReason};
use jsonrpsee::core::{Error, JsonValue};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::ws_client::WsClientBuilder;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

mod helpers;

Expand Down Expand Up @@ -379,6 +382,11 @@ async fn ws_server_should_stop_subscription_after_client_drop() {

#[tokio::test]
async fn ws_server_cancels_stream_after_reset_conn() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

Expand Down Expand Up @@ -415,6 +423,54 @@ async fn ws_server_cancels_stream_after_reset_conn() {
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
}

#[tokio::test]
async fn ws_server_subscribe_with_stream() {
use futures::StreamExt;
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());

let mut module = RpcModule::new(());

module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, sink, _| {
tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream).await.unwrap();
});
Ok(())
})
.unwrap();
server.start(module).unwrap();

let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut sub1: Subscription<usize> = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap();
let mut sub2: Subscription<usize> = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap();

let (r1, r2) = futures::future::try_join(
sub1.by_ref().take(2).try_collect::<Vec<_>>(),
sub2.by_ref().take(3).try_collect::<Vec<_>>(),
)
.await
.unwrap();

assert_eq!(r1, vec![1, 2]);
assert_eq!(r2, vec![1, 2, 3]);

// Be rude, don't run the destructor
std::mem::forget(sub2);

// sub1 is still in business, read remaining items.
assert_eq!(sub1.by_ref().take(3).try_collect::<Vec<usize>>().await.unwrap(), vec![3, 4, 5]);

let exp = SubscriptionClosed::new(SubscriptionClosedReason::Server("No close reason provided".to_string()));
// The server closed down the subscription it will send a close reason.
assert!(matches!(sub1.next().await, Some(Err(Error::SubscriptionClosed(close_reason))) if close_reason == exp));
}

#[tokio::test]
async fn ws_batch_works() {
let server_addr = websocket_server().await;
Expand Down
Loading

0 comments on commit 7c46458

Please sign in to comment.