Skip to content

Commit

Permalink
fix(rpc module): failed unsubscribe middleware (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 authored Jun 13, 2022
1 parent 6421530 commit 6888804
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
5 changes: 3 additions & 2 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl MethodSink {
}
};

if let Err(err) = self.tx.unbounded_send(json) {
if let Err(err) = self.send_raw(json) {
tracing::warn!("Error sending response {:?}", err);
false
} else {
Expand All @@ -147,7 +147,7 @@ impl MethodSink {
}
};

if let Err(err) = self.tx.unbounded_send(json) {
if let Err(err) = self.send_raw(json) {
tracing::warn!("Error sending response {:?}", err);
}

Expand All @@ -162,6 +162,7 @@ impl MethodSink {
/// Send a raw JSON-RPC message to the client, `MethodSink` does not check verify the validity
/// of the JSON being sent.
pub fn send_raw(&self, raw_json: String) -> Result<(), mpsc::TrySendError<String>> {
tracing::trace!("send: {:?}", raw_json);
self.tx.unbounded_send(raw_json)
}

Expand Down
31 changes: 24 additions & 7 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub type AsyncMethod<'a> = Arc<
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, ConnectionId, Option<ResourceGuard>) -> BoxFuture<'a, bool>,
>;
/// Method callback for subscriptions.
pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnState) -> bool>;
pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, MethodSink, ConnState) -> bool>;
// Method callback to unsubscribe.
type UnsubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId) -> bool>;

Expand Down Expand Up @@ -411,17 +411,19 @@ impl Methods {
let close_notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");
let notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");

let _result = match self.method(&req.method).map(|c| &c.callback) {
let result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
(cb)(id, params, sink, conn_state)
}
Some(MethodKind::Unsubscription(cb)) => (cb)(id, params, &sink, 0),
};

tracing::trace!("[Methods::inner_call]: method: `{}` success: {}", req.method, result);

let resp = rx_sink.next().await.expect("tx and rx still alive; qed");

(resp, rx_sink, notify)
Expand Down Expand Up @@ -727,7 +729,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
tracing::error!(
tracing::warn!(
"unsubscribe call '{}' failed: couldn't parse subscription id={:?} request id={:?}",
unsubscribe_method_name,
params,
Expand All @@ -736,11 +738,20 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
return sink.send_response(id, false);
}
};
let sub_id = sub_id.into_owned();
let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() };

let result = subscribers.lock().remove(&key).is_some();

let result = subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some();
if !result {
tracing::warn!(
"unsubscribe call `{}` subscription key={:?} not an active subscription",
unsubscribe_method_name,
key,
);
}

sink.send_response(id, result)
// if both the message was successful and the subscription was removed.
sink.send_response(id, result) && result
})),
);
}
Expand Down Expand Up @@ -1063,11 +1074,17 @@ impl Subscription {
n.handle().notify_one()
}
}

/// Get the subscription ID
pub fn subscription_id(&self) -> &RpcSubscriptionId {
&self.sub_id
}

/// Check whether the subscription is closed.
pub fn is_closed(&self) -> bool {
self.close_notify.is_none()
}

/// Returns `Some((val, sub_id))` for the next element of type T from the underlying stream,
/// otherwise `None` if the subscription was closed.
///
Expand Down
47 changes: 47 additions & 0 deletions tests/tests/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
// DEALINGS IN THE SOFTWARE.

use std::collections::HashMap;
use std::time::Duration;

use futures::StreamExt;
use jsonrpsee::core::error::{Error, SubscriptionClosed};
use jsonrpsee::core::server::rpc_module::*;
use jsonrpsee::types::error::{CallError, ErrorCode, ErrorObject};
use jsonrpsee::types::{EmptyParams, Params};
use serde::{Deserialize, Serialize};
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

// Helper macro to assert that a binding is of a specific type.
macro_rules! assert_type {
Expand Down Expand Up @@ -311,3 +315,46 @@ async fn subscribing_without_server_bad_params() {
matches!(sub, Error::Call(CallError::Custom(e)) if e.message().contains("invalid length 0, expected an array of length 1 at line 1 column 2") && e.code() == ErrorCode::InvalidParams.code())
);
}

#[tokio::test]
async fn subscribe_unsubscribe_without_server() {
let mut module = RpcModule::new(());
module
.register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};

let interval = interval(Duration::from_millis(200));
let stream = IntervalStream::new(interval).map(move |_| 1);

tokio::spawn(async move {
sink.pipe_from_stream(stream).await;
});
})
.unwrap();

async fn subscribe_and_assert(module: &RpcModule<()>) {
let sub = module.subscribe("my_sub", EmptyParams::new()).await.unwrap();

let ser_id = serde_json::to_string(sub.subscription_id()).unwrap();

// Unsubscribe should be valid.
let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id);
let (response, _) = module.raw_json_request(&unsub_req).await.unwrap();

assert_eq!(response, r#"{"jsonrpc":"2.0","result":true,"id":1}"#);

// Unsubscribe already performed; should be error.
let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id);
let (response, _) = module.raw_json_request(&unsub_req).await.unwrap();

assert_eq!(response, r#"{"jsonrpc":"2.0","result":false,"id":1}"#);
}

let sub1 = subscribe_and_assert(&module);
let sub2 = subscribe_and_assert(&module);

futures::future::join(sub1, sub2).await;
}
4 changes: 2 additions & 2 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ async fn background_task(
let result = if let Some(cn) = bounded_subscriptions.acquire() {
let conn_state =
ConnState { conn_id, close_notify: cn, id_provider: &*id_provider };
callback(id, params, &sink, conn_state)
callback(id, params, sink.clone(), conn_state)
} else {
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
Expand Down Expand Up @@ -600,7 +600,7 @@ async fn background_task(
close_notify: cn,
id_provider: &*id_provider,
};
callback(id, params, &sink_batch, conn_state)
callback(id, params, sink_batch.clone(), conn_state)
} else {
sink_batch.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
Expand Down

0 comments on commit 6888804

Please sign in to comment.