Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rpc module): stream API for SubscriptionSink #639

Merged
merged 24 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
72024ea
feat(rpc module): add_stream to subscription sink
niklasad1 Jan 5, 2022
fe176df
fix some nits
niklasad1 Jan 5, 2022
eebc73d
Merge remote-tracking branch 'origin/master' into na-hacky-sink-add-s…
niklasad1 Jan 5, 2022
a29f988
unify parameters to rpc methods
niklasad1 Jan 12, 2022
6aa22e2
Update core/src/server/rpc_module.rs
niklasad1 Jan 12, 2022
9bdea0d
Update tests/tests/integration_tests.rs
niklasad1 Jan 13, 2022
79a8e55
address grumbles
niklasad1 Jan 13, 2022
7e81acf
fix subscription tests
niklasad1 Jan 13, 2022
6982598
new type for `SubscriptionCallback` and glue code
niklasad1 Jan 17, 2022
d589d24
remove unsed code
niklasad1 Jan 17, 2022
92bb97e
remove todo
niklasad1 Jan 18, 2022
b9598ca
add missing feature tokio/macros
niklasad1 Jan 18, 2022
bce48da
make `add_stream` cancel-safe
niklasad1 Jan 18, 2022
6d16927
rename add_stream and return status
niklasad1 Jan 19, 2022
a47a965
fix nits
niklasad1 Jan 19, 2022
07f80e2
rename stream API -> streamify
niklasad1 Jan 19, 2022
bedf808
Update core/src/server/rpc_module.rs
niklasad1 Jan 19, 2022
03ef669
provide proper close reason
niklasad1 Jan 19, 2022
ef7b965
Merge remote-tracking branch 'origin/na-hacky-sink-add-stream' into n…
niklasad1 Jan 19, 2022
f7aa544
spelling
niklasad1 Jan 19, 2022
72f4726
consume_and_streamify + docs
niklasad1 Jan 20, 2022
7a20a52
fmt
niklasad1 Jan 20, 2022
de4ac6a
rename API pipe_from_stream
niklasad1 Jan 20, 2022
fa15cfb
improve logging; indicate which subscription method that failed
niklasad1 Jan 20, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ anyhow = "1"
arrayvec = "0.7.1"
async-trait = "0.1"
beef = { version = "0.5.1", features = ["impl_serde"] }
async-channel = { version = "1.6", optional = true }
thiserror = "1"
futures-channel = { version = "0.3.14", default-features = false }
futures-util = { version = "0.3.14", default-features = false, optional = true }
Expand All @@ -29,6 +30,7 @@ tokio = { version = "1.8", features = ["rt"], optional = true }
default = []
http-helpers = ["futures-util"]
server = [
"async-channel",
"futures-util",
"rustc-hash",
"tracing",
Expand Down
2 changes: 1 addition & 1 deletion core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::io;

use crate::{to_json_raw_value, Error};
use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use futures_util::StreamExt;
use jsonrpsee_types::error::{
CallError, ErrorCode, ErrorObject, ErrorResponse, CALL_EXECUTION_FAILED_CODE, OVERSIZED_RESPONSE_CODE,
OVERSIZED_RESPONSE_MSG, UNKNOWN_ERROR_CODE,
Expand Down
121 changes: 82 additions & 39 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::to_json_raw_value;
use crate::traits::{IdProvider, ToRpcParams};
use beef::Cow;
use futures_channel::{mpsc, oneshot};
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt};
use jsonrpsee_types::error::{invalid_subscription_err, ErrorCode, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionPayload, SubscriptionResponse,
Expand All @@ -51,16 +51,32 @@ use serde::{de::DeserializeOwned, Serialize};
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId, &dyn IdProvider) -> bool>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, Option<ConnState>) -> bool>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured.
pub type AsyncMethod<'a> = Arc<
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, Option<ResourceGuard>, &dyn IdProvider) -> BoxFuture<'a, bool>,
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, ConnectionId, Option<ResourceGuard>) -> BoxFuture<'a, bool>,
>;
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, mpsc::UnboundedSender<String>);
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>);

/// Data for stateful connections.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Channel to know whether the connection is closed or not.
pub close: async_channel::Receiver<()>,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: this panics if count > usize::MAX / 2

but if we reach that we likely have other problems such as OOM :)

/// ID provider.
pub id_provider: &'a dyn IdProvider,
}

impl<'a> std::fmt::Debug for ConnState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish()
}
}

type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, oneshot::Receiver<()>)>>>;

Expand Down Expand Up @@ -157,24 +173,23 @@ impl MethodCallback {
pub fn execute(
&self,
sink: &MethodSink,
conn_state: Option<ConnState>,
req: Request<'_>,
conn_id: ConnectionId,
claimed: Option<ResourceGuard>,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));

let result = match &self.callback {
MethodKind::Sync(callback) => {
tracing::trace!(
"[MethodCallback::execute] Executing sync callback, params={:?}, req.id={:?}, conn_id={:?}",
"[MethodCallback::execute] Executing sync callback, params={:?}, req.id={:?}, conn_state={:?}",
params,
id,
conn_id
conn_state
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
);

let result = (callback)(id, params, sink, conn_id, id_gen);
let result = (callback)(id, params, sink, conn_state);

// Release claimed resources
drop(claimed);
Expand All @@ -185,14 +200,15 @@ impl MethodCallback {
let sink = sink.clone();
let params = params.into_owned();
let id = id.into_owned();
let conn_id = conn_state.map(|s| s.conn_id).unwrap_or(0);
tracing::trace!(
"[MethodCallback::execute] Executing async callback, params={:?}, req.id={:?}, conn_id={:?}",
"[MethodCallback::execute] Executing async callback, params={:?}, req.id={:?}, conn_state={:?}",
params,
id,
conn_id
conn_id,
);

MethodResult::Async((callback)(id, params, sink, claimed, id_gen))
MethodResult::Async((callback)(id, params, sink, conn_id, claimed))
}
};

Expand Down Expand Up @@ -307,16 +323,10 @@ impl Methods {
}

/// Attempt to execute a callback, sending the resulting JSON (success or error) to the specified sink.
pub fn execute(
&self,
sink: &MethodSink,
req: Request,
conn_id: ConnectionId,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
pub fn execute(&self, sink: &MethodSink, conn_state: Option<ConnState>, req: Request) -> MethodResult<bool> {
tracing::trace!("[Methods::execute] Executing request: {:?}", req);
match self.callbacks.get(&*req.method) {
Some(callback) => callback.execute(sink, req, conn_id, None, id_gen),
Some(callback) => callback.execute(sink, conn_state, req, None),
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
MethodResult::Sync(false)
Expand All @@ -329,15 +339,14 @@ impl Methods {
pub fn execute_with_resources<'r>(
&self,
sink: &MethodSink,
conn_state: Option<ConnState<'r>>,
req: Request<'r>,
conn_id: ConnectionId,
resources: &Resources,
id_gen: &dyn IdProvider,
) -> Result<(&'static str, MethodResult<bool>), Cow<'r, str>> {
tracing::trace!("[Methods::execute_with_resources] Executing request: {:?}", req);
match self.callbacks.get_key_value(&*req.method) {
Some((&name, callback)) => match callback.claim(&req.method, resources) {
Ok(guard) => Ok((name, callback.execute(sink, req, conn_id, Some(guard), id_gen))),
Ok(guard) => Ok((name, callback.execute(sink, conn_state, req, Some(guard)))),
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
Expand Down Expand Up @@ -425,14 +434,17 @@ impl Methods {
/// Wrapper over [`Methods::execute`] to execute a callback.
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx.clone());
let sink = MethodSink::new(tx);
let (close_tx, close_rx) = async_channel::unbounded();

if let MethodResult::Async(fut) = self.execute(&sink, req, 0, &RandomIntegerIdProvider) {
let conn_state = Some(ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider });

if let MethodResult::Async(fut) = self.execute(&sink, conn_state, req) {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
fut.await;
}

let resp = rx.next().await.expect("tx and rx still alive; qed");
(resp, rx, tx)
(resp, rx, close_tx)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -527,7 +539,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, _, _| match callback(params, &*ctx) {
MethodCallback::new_sync(Arc::new(move |id, params, sink, _| match callback(params, &*ctx) {
Ok(res) => sink.send_response(id, res),
Err(err) => sink.send_call_error(id, err),
})),
Expand All @@ -550,7 +562,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| {
let ctx = ctx.clone();
let future = async move {
let result = match callback(params, ctx).await {
Expand Down Expand Up @@ -585,7 +597,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| {
let ctx = ctx.clone();

tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -671,12 +683,13 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn_id, id_provider| {
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn| {
let (conn_tx, conn_rx) = oneshot::channel::<()>();
let c = conn.expect("conn must be Some; this is bug");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it bust be Some, why don't we restrict it on parameter level?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's way better but I think it requires additional callback types right? see my comment below.

thoughts?

Copy link
Member Author

@niklasad1 niklasad1 Jan 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I had a look at the code to solve this properly at parameter level one need to lookup at the actual callback to avoid to pass down a bunch of unused parameters.

  // server code
  let result = match RpcModule::as_callback(&req.method) {
     None => {
        send_error(sink, msg),
        continue
     }
     Some(MethodKind::Sync(cb)) => (cb)(id, params, sink, conn_id),
     Some(MethodKind::Async(cb)) => (cb)(id, params, sink, resources).await,
     // should be used for subscriptions...
     // servers that don't support subscriptions should throw an error here...
     Some(MethodKind::Subscription) => (cb)(id, params, sink, resources, conn_state),
  };

  // modify RpcModule::register_subscription 
  pub fn register_subscription<F>(
     &mut self,
     subscribe_method_name: &'static str,
     notif_method_name: &'static str,
     unsubscribe_method_name: &'static str,
     callback: F,
 ) -> Result<(), Error> {
   .... 
   ....
   self.methods.mut_callbacks().insert(
      subscribe_method_name,
      MethodCallback::new_subscription(Arc::new(move |id, params, method_sink, conn| {
        ... 
     }
  );
 }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one benefit is not to have to clone this "introduced channel connection_state" for every method instead only for subscriptions where it's actually used... not sure I liked how the old code abstracted this away

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already thought about splitting async calls from subscriptions when doing my refactoring earlier, and for regular method calls returning values instead of passing in the sink as a parameter, I reckon that would make things much more readable and straight forward, and potentially make the final binary smaller. So if you want to go that route and add another enum variant I think that's cool, and I can do a PR that switches method calls to have a return value later on :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I already added the enum variant in this PR but just a hacky draft to check that it works and to show you and David what I had in mind :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds great to get rid of the sink for the synchronous calls.


let sub_id = {
let sub_id: RpcSubscriptionId = id_provider.next_id().into_owned();
let uniq_sub = SubscriptionKey { conn_id, sub_id: sub_id.clone() };
let sub_id: RpcSubscriptionId = c.id_provider.next_id().into_owned();
let uniq_sub = SubscriptionKey { conn_id: c.conn_id, sub_id: sub_id.clone() };

subscribers.lock().insert(uniq_sub, (method_sink.clone(), conn_rx));

Expand All @@ -687,9 +700,10 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let sink = SubscriptionSink {
inner: method_sink.clone(),
close: c.close,
method: notif_method_name,
subscribers: subscribers.clone(),
uniq_sub: SubscriptionKey { conn_id, sub_id },
uniq_sub: SubscriptionKey { conn_id: c.conn_id, sub_id },
is_connected: Some(conn_tx),
};
if let Err(err) = callback(params, sink, ctx.clone()) {
Expand All @@ -710,7 +724,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_id, _| {
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_state| {
let c = conn_state.expect("conn must be Some; this is bug");

let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
Expand All @@ -727,7 +743,11 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
};
let sub_id = sub_id.into_owned();

if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id: sub_id.clone() }).is_some() {
if subscribers
.lock()
.remove(&SubscriptionKey { conn_id: c.conn_id, sub_id: sub_id.clone() })
.is_some()
{
sink.send_response(id, "Unsubscribed")
} else {
let err = to_json_raw_value(&format!(
Expand Down Expand Up @@ -764,6 +784,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Close
close: async_channel::Receiver<()>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand All @@ -786,9 +808,30 @@ impl SubscriptionSink {
self.inner_send(msg).map_err(Into::into)
}

/// Consume the sink by passing a stream to be sent via the sink.
pub async fn add_stream<S, T>(mut self, mut stream: S)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about calling this into_stream? I think "add" implies there could be more than one and that it doesn't quite relay the information about the important changes that this call makes to the sink.

Copy link
Member Author

@niklasad1 niklasad1 Jan 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like add_stream either, but into_stream is not really great either it doesn't return the stream....

maybe run_stream, from_stream, spawn_stream or something else?!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair; I'd have liked as_stream but as_ is "taken" with different semantics so can't do that.

Of your suggestions I like from_stream the best.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this method consumes a stream, feeding the items into the subscription?

I guess I'd go with something like consume_stream or read_from_stream. into_, as_, and from_ all sortof feel like I should expect some result back from this call to me!

Copy link
Member Author

@niklasad1 niklasad1 Jan 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this method consumes a stream, feeding the items into the subscription?

yes

I guess we could return a type that impls Sink/SinkExt instead here to make it more readable and flexible i.e, to deal with errors and so on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

streamify()?
I think consume_stream is so-so. Yes, we do consume it, but that's not really the point. Rather we're "hooking up" the stream to the sink and leave it there for the duration of the subscription.
with_stream?

Copy link
Collaborator

@jsdw jsdw Jan 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pipe, maybe? we're piping a stream into the subscription.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sink.pipe_from_stream? I quite like pipe!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like pipe_from_stream, let's settle for that?

where
S: Stream<Item = T> + Unpin,
T: Serialize,
{
loop {
tokio::select! {
Some(item) = stream.next() => {
if let Err(Error::SubscriptionClosed(_)) = self.send(&item) {
break;
}
},
// No messages should be sent over this channel (just ignore and continue)
Some(_) = self.close.next() => {},
// Stream or connection was dropped => close stream.
else => break,
}
}
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
self.inner.is_closed() || self.close.is_closed()
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
Expand All @@ -806,7 +849,7 @@ impl SubscriptionSink {
self.inner.send_raw(msg).map_err(|_| Some(SubscriptionClosedReason::ConnectionReset))
}
Some(_) => Err(Some(SubscriptionClosedReason::Unsubscribed)),
// NOTE(niklasad1): this should be unreachble, after the first error is detected the subscription is closed.
// NOTE(niklasad1): this should be unreachable, after the first error is detected the subscription is closed.
None => Err(None),
};

Expand Down Expand Up @@ -850,15 +893,15 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
tx: mpsc::UnboundedSender<String>,
tx: async_channel::Sender<()>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}

impl Subscription {
/// Close the subscription channel.
pub fn close(&mut self) {
self.tx.close_channel();
self.tx.close();
}

/// Get the subscription ID
Expand Down
5 changes: 2 additions & 3 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Error as HyperError, Method};
use jsonrpsee_core::error::{Error, GenericTransportError};
use jsonrpsee_core::http_helpers::{self, read_body};
use jsonrpsee_core::id_providers::NoopIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink};
use jsonrpsee_core::server::resource_limiting::Resources;
Expand Down Expand Up @@ -457,7 +456,7 @@ async fn process_validated_request(
middleware.on_call(req.method.as_ref());

// NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here.
match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) {
match methods.execute_with_resources(&sink, None, req, &resources) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
}
Expand All @@ -483,7 +482,7 @@ async fn process_validated_request(
let middleware = &middleware;

join_all(batch.into_iter().filter_map(move |req| {
match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) {
match methods.execute_with_resources(&sink, None, req, &resources) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
None
Expand Down
38 changes: 38 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,44 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
assert!(matches!(close_err.close_reason(), &SubscriptionClosedReason::ConnectionReset));
}

#[tokio::test]
async fn ws_server_cancels_stream_after_reset_conn() {
use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());

let (tx, mut rx) = mpsc::channel(1);
let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_never_produce", "n", "unsubscribe_never_produce", |_, sink, mut tx| {
// create stream that doesn't produce items.
let stream = futures::stream::empty::<usize>();
tokio::spawn(async move {
sink.add_stream(stream).await;
let send_back = Arc::make_mut(&mut tx);
send_back.feed(()).await.unwrap();
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
});
Ok(())
})
.unwrap();

server.start(module).unwrap();

let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let _sub1: Subscription<usize> =
client.subscribe("subscribe_never_produce", None, "unsubscribe_never_produce").await.unwrap();
let _sub2: Subscription<usize> =
client.subscribe("subscribe_never_produce", None, "unsubscribe_never_produce").await.unwrap();

// terminate connection.
drop(client);
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
Comment on lines +396 to +397
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I felt like I understood the test until I got here. I thought rx would produce None after the client was dropped and so it'd be assert!(rx.next().await.is_none()). What am I missing? :/

Copy link
Member Author

@niklasad1 niklasad1 Jan 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha, maybe it's more clear as you described it but the test actually sends a message on the channel when the subscription terminated.

the reason why is because the tx is kept in the RpcModule and can't be dropped in the subscribe callback.

}

#[tokio::test]
async fn ws_batch_works() {
let server_addr = websocket_server().await;
Expand Down
Loading