Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace async-channel #708

Merged
merged 23 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.7.1"
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.8", features = ["rt"], optional = true }
tokio = { version = "1.8", features = ["rt", "sync"], optional = true }

[features]
default = []
Expand Down
128 changes: 72 additions & 56 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand All @@ -62,22 +63,27 @@ pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink,
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>);

/// Data for stateful connections.
/// Raw response from an RPC
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, Arc<Notify>);

/// Helper struct to manage subscriptions.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Channel to know whether the connection is closed or not.
pub close: async_channel::Receiver<()>,
/// Get notified when the connection to subscribers is closed.
pub close_notify: Arc<Notify>,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}

impl<'a> std::fmt::Debug for ConnState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish()
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close_notify).finish()
}
}

Expand Down Expand Up @@ -367,25 +373,26 @@ impl Methods {

/// Execute a callback.
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx);
let (close_tx, close_rx) = async_channel::unbounded();

let (tx_sink, mut rx_sink) = mpsc::unbounded();
let sink = MethodSink::new(tx_sink);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let notify = Arc::new(Notify::new());
dvdplm marked this conversation as resolved.
Show resolved Hide resolved

let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let conn_state = ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider };
let close_notify = notify.clone();
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
};

let resp = rx.next().await.expect("tx and rx still alive; qed");
(resp, rx, close_tx)
let resp = rx_sink.next().await.expect("tx and rx still alive; qed");

(resp, rx_sink, notify)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -417,10 +424,11 @@ impl Methods {
let params = params.to_rpc_params()?;
let req = Request::new(sub_method.into(), Some(&params), Id::Number(0));
tracing::trace!("[Methods::subscribe] Calling subscription method: {:?}, params: {:?}", sub_method, params);
let (response, rx, tx) = self.inner_call(req).await;
let (response, rx, close_notify) = self.inner_call(req).await;
let subscription_response = serde_json::from_str::<Response<RpcSubscriptionId>>(&response)?;
let sub_id = subscription_response.result.into_owned();
Ok(Subscription { sub_id, rx, tx })
let close_notify = Some(close_notify);
Ok(Subscription { sub_id, rx, close_notify })
}

/// Returns an `Iterator` with all the method names registered on this server.
Expand Down Expand Up @@ -627,6 +635,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let subscribers = Subscribers::default();

// Subscribe
{
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
Expand All @@ -647,7 +656,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let sink = SubscriptionSink {
inner: method_sink.clone(),
close: conn.close,
close_notify: Some(conn.close_notify),
method: notif_method_name,
subscribers: subscribers.clone(),
uniq_sub: SubscriptionKey { conn_id: conn.conn_id, sub_id },
Expand All @@ -668,6 +677,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
);
}

// Unsubscribe
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
Expand Down Expand Up @@ -729,8 +739,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Close
close: async_channel::Receiver<()>,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -777,47 +787,47 @@ impl SubscriptionSink {
S: Stream<Item = T> + Unpin,
T: Serialize,
{
let mut close_stream = self.close.clone();
let mut item = stream.next();
let mut close = close_stream.next();

loop {
match futures_util::future::select(item, close).await {
Either::Left((Some(result), c)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
tracing::error!("subscription `{}` failed to send item got error: {:?}", self.method, err);
break Err(err);
}
};
close = c;
item = stream.next();
}
// No messages should be sent over this channel
// if that occurred just ignore and continue.
Either::Right((Some(_), i)) => {
item = i;
close = close_stream.next();
}
// Connection terminated.
Either::Right((None, _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
if let Some(close_notify) = self.close_notify.clone() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need the clone here @dvdplm?!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't move out the Arc<Notify> and I can't use a reference either (because send takes a mutable reference below). :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be possible with take here but it didn't work when I tried so I just wonder why :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take() compiles but then the tests fail. :/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Played a bit with this, and not sure there is a super clean solution. You either clone it here, or somehow decouple the closed_fut from self lifetime wise, which is really hard with pinned futures.

let mut item = stream.next();
tracing::trace!("[SubscriptionSink::pipe_from_stream] Entering loop");
loop {
match futures_util::future::select(item, Box::pin(close_notify.notified())).await {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
// The app sent us a value to send back to the subscribers
Either::Left((Some(result), _)) => {
tracing::trace!("[SubscriptionSink::pipe_from_stream] Left - sending a result back");
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
tracing::error!("subscription `{}` failed to send item. Error: {:?}", self.method, err);
break Err(err);
}
};
item = stream.next();
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
// The subscriber went away without telling us.
Either::Right(((), _)) => {
tracing::trace!("[SubscriptionSink::pipe_from_stream] Right - closing");
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
}
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
}
} else {
tracing::warn!("[SubscriptionSink::pipe_from_stream] We're closed.");
// TODO: (dp) Is this right? Should return `Err`?
return Ok(());
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close.is_closed()
self.inner.is_closed() || self.close_notify.is_none()
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
Expand Down Expand Up @@ -857,7 +867,7 @@ impl SubscriptionSink {
self.inner_close(Some(&close_reason));
}

/// Provide close from `SubscriptionClosed`.
/// Close the subscription sink with the provided [`SubscriptionClosed`].
pub fn close(&mut self, close_reason: &SubscriptionClosed) {
self.inner_close(Some(close_reason));
}
Expand All @@ -884,17 +894,19 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
tx: async_channel::Sender<()>,
close_notify: Option<Arc<Notify>>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}

impl Subscription {
/// Close the subscription channel.
pub fn close(&mut self) {
self.tx.close();
tracing::trace!("[Subscription::close] Notifying");
if let Some(n) = self.close_notify.take() {
n.notify_one()
}
}

/// Get the subscription ID
pub fn subscription_id(&self) -> &RpcSubscriptionId {
&self.sub_id
Expand All @@ -907,6 +919,10 @@ impl Subscription {
///
/// If the decoding the value as `T` fails.
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, RpcSubscriptionId<'static>), Error>> {
if self.close_notify.is_none() {
tracing::debug!("[Subscription::next] Closed.");
return Some(Err(Error::SubscriptionClosed(SubscriptionClosedReason::ConnectionReset.into())));
}
let raw = self.rx.next().await?;
let res = match serde_json::from_str::<SubscriptionResponse<T>>(&raw) {
Ok(r) => Ok((r.params.result, r.params.subscription.into_owned())),
Expand Down
1 change: 1 addition & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ tracing = "0.1"
serde = "1"
serde_json = "1"
hyper = { version = "0.14", features = ["http1", "client"] }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
68 changes: 68 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,11 @@ async fn ws_server_should_stop_subscription_after_client_drop() {

#[tokio::test]
async fn ws_server_cancels_stream_after_reset_conn() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

Expand Down Expand Up @@ -414,6 +419,69 @@ async fn ws_server_cancels_stream_after_reset_conn() {
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
}
// TODO: (dp) Finish this test: check that dropping one subscriber doesn't halt the stream etc
#[tokio::test]
async fn ws_server_subscribe_with_stream() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());

let (tx, mut rx) = mpsc::channel(1);
let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_10_ints", "n", "unsubscribe_10_ints", |_, sink, mut tx| {
use futures::task::Poll;
let mut int_counter = 1usize;
let stream = futures::stream::poll_fn(move |_| -> Poll<Option<usize>> {
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
if int_counter == 10 { return Poll::Ready(None); }
std::thread::sleep(Duration::from_millis(100));
let out = Poll::Ready(Some(int_counter));
int_counter += 1;
out
});

tokio::spawn(async move {
sink.pipe_from_stream(stream).await.unwrap();
let send_back = Arc::make_mut(&mut tx);
send_back.feed(()).await.unwrap();
});
Ok(())
})
.unwrap();
tracing::info!("[test] Starting server");
server.start(module).unwrap();

let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut sub1: Subscription<usize> =
client.subscribe("subscribe_10_ints", None, "unsubscribe_10_ints").await.unwrap();
let mut sub2: Subscription<usize> =
client.subscribe("subscribe_10_ints", None, "unsubscribe_10_ints").await.unwrap();
tracing::info!("[test] Subscribed");

assert_eq!(sub1.next().await.unwrap().unwrap(), 1);
assert_eq!(sub2.next().await.unwrap().unwrap(), 1);
assert_eq!(sub1.next().await.unwrap().unwrap(), 2);
assert_eq!(sub2.next().await.unwrap().unwrap(), 2);
assert_eq!(sub2.next().await.unwrap().unwrap(), 3);

// Be rude, don't run the destructor
std::mem::forget(sub2);
// Sub1 is still in business
assert_eq!(sub1.next().await.unwrap().unwrap(), 3);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved

// terminate connection.
// TODO: (dp) removing the drop changes nothing. Why is that?
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
drop(client);
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
}

#[tokio::test]
async fn ws_batch_works() {
Expand Down
8 changes: 8 additions & 0 deletions tests/tests/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,19 @@ async fn subscribing_without_server() {

#[tokio::test]
async fn close_test_subscribing_without_server() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let mut module = RpcModule::new(());
module
.register_subscription("my_sub", "my_sub", "my_unsub", |_, mut sink, _| {
std::thread::spawn(move || {
// make sure to only send one item
sink.send(&"lo").unwrap();
while !sink.is_closed() {
tracing::debug!("[test] Sink is open, sleeping");
std::thread::sleep(std::time::Duration::from_millis(500));
}
// Get the close reason.
Expand All @@ -247,12 +253,14 @@ async fn close_test_subscribing_without_server() {
})
.unwrap();

tracing::info!("[test] About to subscribe");
let mut my_sub = module.subscribe("my_sub", EmptyParams::new()).await.unwrap();
let (val, id) = my_sub.next::<String>().await.unwrap().unwrap();
assert_eq!(&val, "lo");
assert_eq!(&id, my_sub.subscription_id());

// close the subscription to ensure it doesn't return any items.
tracing::info!("[test] Closing the subscription");
my_sub.close();

// In this case, the unsubscribe method was not called and
Expand Down
Loading