Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ws server terminate subscriptions when connection is closed by the client. #483

Merged
merged 16 commits into from
Sep 24, 2021
Merged
2 changes: 1 addition & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ publish = false
[dev-dependencies]
beef = { version = "0.5.1", features = ["impl_serde"] }
env_logger = "0.9"
futures-channel = { version = "0.3.14", default-features = false }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tokio = { version = "1", features = ["full"] }
serde_json = "1"
Expand Down
15 changes: 11 additions & 4 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures_channel::oneshot;
use futures::channel::oneshot;
use jsonrpsee::{
http_server::HttpServerBuilder,
types::Error,
ws_server::{WsServerBuilder, WsStopHandle},
RpcModule,
};
Expand All @@ -47,7 +48,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle)
module
.register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, _| {
std::thread::spawn(move || loop {
let _ = sink.send(&"hello from subscription");
if let Err(Error::SubscriptionClosed(_)) = sink.send(&"hello from subscription") {
break;
}
std::thread::sleep(Duration::from_millis(50));
});
Ok(())
Expand All @@ -57,7 +60,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle)
module
.register_subscription("subscribe_foo", "unsubscribe_foo", |_, mut sink, _| {
std::thread::spawn(move || loop {
let _ = sink.send(&1337);
if let Err(Error::SubscriptionClosed(_)) = sink.send(&1337) {
break;
}
std::thread::sleep(Duration::from_millis(100));
});
Ok(())
Expand All @@ -69,7 +74,9 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle)
let mut count: usize = params.one()?;
std::thread::spawn(move || loop {
count = count.wrapping_add(1);
let _ = sink.send(&count);
if let Err(Error::SubscriptionClosed(_)) = sink.send(&count) {
break;
}
std::thread::sleep(Duration::from_millis(100));
});
Ok(())
Expand Down
49 changes: 45 additions & 4 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,16 @@ async fn ws_subscription_several_clients_with_drop() {
let (client, hello_sub, foo_sub) = clients.remove(i);
drop(hello_sub);
drop(foo_sub);
// Send this request to make sure that the client's background thread hasn't
// been canceled.
assert!(client.is_connected());
drop(client);
}

// make sure nothing weird happened after dropping half the clients (should be `unsubscribed` in the server)
// make sure nothing weird happened after dropping half of the clients (should be `unsubscribed` in the server)
// would be good to know that subscriptions actually were removed but not possible to verify at
// this layer.
for _ in 0..10 {
for (_client, hello_sub, foo_sub) in &mut clients {
for (client, hello_sub, foo_sub) in &mut clients {
assert!(client.is_connected());
let hello = hello_sub.next().await.unwrap().unwrap();
let foo = foo_sub.next().await.unwrap().unwrap();
assert_eq!(&hello, "hello from subscription");
Expand Down Expand Up @@ -295,3 +294,45 @@ async fn ws_close_pending_subscription_when_server_terminated() {

panic!("subscription keeps sending messages after server shutdown");
}

#[tokio::test]
async fn ws_server_should_stop_subscription_after_client_drop() {
use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());

let (tx, mut rx) = mpsc::channel(1);
let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, mut tx| {
tokio::spawn(async move {
let close_err = loop {
if let Err(Error::SubscriptionClosed(err)) = sink.send(&1) {
break err;
}
tokio::time::sleep(Duration::from_millis(100)).await;
};
let send_back = Arc::make_mut(&mut tx);
send_back.feed(close_err).await.unwrap();
});
Ok(())
})
.unwrap();

tokio::spawn(async move { server.start(module).await });

let client = WsClientBuilder::default().build(&server_url).await.unwrap();

let mut sub: Subscription<usize> =
client.subscribe("subscribe_hello", ParamsSer::NoParams, "unsubscribe_hello").await.unwrap();

let res = sub.next().await.unwrap();

assert_eq!(res.as_ref(), Some(&1));
drop(client);
// assert that the server received `SubscriptionClosed` after the client was dropped.
assert!(rx.next().await.is_some());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit opaque unless one knows exactly what is going on.

Maybe we can do something like this?

	let sub_closed_err = rx.next().await.expect("Server received `SubscriptionClosed` after the client was dropped.");
	assert!(format!("{:?}", sub_closed_err).contains("is closed but not yet dropped"));

It's a bit annoying that SubscriptionClosedErr has a private member (and that there's no way for the client to know what the subscription ID is) or we could have used assert_matches here for better readability...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yepp, I think we could replace SubscriptionClosed(String) with a better type.

Perhaps with a kind enum, to distinguish why the subscription was closed the error message is quite useful atm

}
2 changes: 1 addition & 1 deletion tests/tests/proc_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

use std::net::SocketAddr;

use futures_channel::oneshot;
use futures::channel::oneshot;
use jsonrpsee::{ws_client::*, ws_server::WsServerBuilder};
use serde_json::value::RawValue;

Expand Down
4 changes: 4 additions & 0 deletions types/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ pub enum FrontToBack {
// Such operations will be blocked until a response is received or the background
// thread has been terminated.
SubscriptionClosed(SubscriptionId),
/// Terminate background thread.
///
/// This implies that all pending operations are ignored and the connection is terminated.
Terminate,
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

impl<Notif> Subscription<Notif>
Expand Down
23 changes: 18 additions & 5 deletions ws-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ impl WsClient {
}
}

impl Drop for WsClient {
fn drop(&mut self) {
self.to_back.send(FrontToBack::Terminate).now_or_never();
}
}

#[async_trait]
impl Client for WsClient {
async fn notification<'a>(&self, method: &'a str, params: ParamsSer<'a>) -> Result<(), Error> {
Expand Down Expand Up @@ -522,7 +528,7 @@ async fn background_task(
// There is nothing to do just terminate.
Either::Left((None, _)) => {
log::trace!("[backend]: frontend dropped; terminate client");
return;
break;
}

Either::Left((Some(FrontToBack::Batch(batch)), _)) => {
Expand Down Expand Up @@ -606,6 +612,10 @@ async fn background_task(
log::trace!("[backend] unregistering notification handler: {:?}", method);
let _ = manager.remove_notification_handler(method);
}
// User dropped the client.
Either::Left((Some(FrontToBack::Terminate), _)) => {
break;
}
Either::Right((Some(Ok(raw)), _)) => {
// Single response to a request.
if let Ok(single) = serde_json::from_slice::<Response<_>>(&raw) {
Expand All @@ -617,7 +627,7 @@ async fn background_task(
Ok(None) => (),
Err(err) => {
let _ = front_error.send(err);
return;
break;
}
}
}
Expand Down Expand Up @@ -656,19 +666,22 @@ async fn background_task(
serde_json::from_slice::<serde_json::Value>(&raw)
);
let _ = front_error.send(Error::Custom("Unparsable response".into()));
return;
break;
}
}
Either::Right((Some(Err(e)), _)) => {
log::error!("Error: {:?} terminating client", e);
let _ = front_error.send(Error::Transport(e.into()));
return;
break;
}
Either::Right((None, _)) => {
log::error!("[backend]: WebSocket receiver dropped; terminate client");
let _ = front_error.send(Error::Custom("WebSocket receiver dropped".into()));
return;
break;
}
}
}

// Send WebSocket close reason to the server (this might fail if the server terminated the connection).
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
let _ = sender.close().await;
}
7 changes: 6 additions & 1 deletion ws-client/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ impl RequestManager {
}
}

/// Get a mutable reference to underlying `Sink` in order to send incmoing notifications to the subscription.
/// Get a mutable reference to underlying `Sink` in order to send incoming notifications to the subscription.
///
/// Returns `Some` if the `method` was registered as a NotificationHandler otherwise `None`.
pub fn as_notification_handler_mut(&mut self, method: String) -> Option<&mut SubscriptionSink> {
Expand All @@ -302,6 +302,11 @@ impl RequestManager {
pub fn get_request_id_by_subscription_id(&self, sub_id: &SubscriptionId) -> Option<RequestId> {
self.subscriptions.get(sub_id).copied()
}

/// Get all active subscriptions.
pub fn subscriptions(&self) -> Vec<(SubscriptionId, RequestId)> {
self.subscriptions.iter().map(|(s, r)| (s.clone(), *r)).collect()
}
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions ws-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ impl Sender {
self.inner.flush().await?;
Ok(())
}

/// Send a close message and close the connection.
pub async fn close(&mut self) -> Result<(), WsError> {
self.inner.close().await.map_err(Into::into)
}
}

impl Receiver {
Expand Down
26 changes: 16 additions & 10 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::future::{FutureDriver, StopHandle, StopMonitor};
Expand Down Expand Up @@ -209,16 +211,19 @@ async fn background_task(
conn_id: ConnectionId,
methods: Methods,
max_request_body_size: u32,
stop_monitor: StopMonitor,
stop_server: StopMonitor,
) -> Result<(), Error> {
// And we can finally transition to a websocket background_task.
let (mut sender, mut receiver) = server.into_builder().finish();
let (tx, mut rx) = mpsc::unbounded::<String>();
let stop_conn = Arc::new(AtomicBool::new(false));

let stop_server2 = stop_server.clone();
let stop_conn2 = stop_conn.clone();
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved

let stop_monitor2 = stop_monitor.clone();
// Send results back to the client.
tokio::spawn(async move {
while !stop_monitor2.shutdown_requested() {
while !stop_server2.shutdown_requested() && !stop_conn2.load(Ordering::SeqCst) {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
match rx.next().await {
Some(response) => {
log::debug!("send: {}", response);
Expand All @@ -228,20 +233,24 @@ async fn background_task(
None => break,
};
}

drop(stop_monitor2);
// terminate connection.
let _ = sender.close().await;
// NOTE(niklasad1): when the receiver is dropped no further requests or subscriptions
// will be possible.
});

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let mut method_executors = FutureDriver::default();

while !stop_monitor.shutdown_requested() {
while !stop_server.shutdown_requested() {
data.clear();

method_executors.select_with(receiver.receive_data(&mut data)).await?;
if let Err(e) = method_executors.select_with(receiver.receive_data(&mut data)).await {
Copy link
Member Author

@niklasad1 niklasad1 Sep 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE; this will kill a connection on any soketto::connection::Error, not just if the client terminated the connection

In a follow-up PR I think we should try to do some graceful error handling

log::error!("Could not receive WS data: {:?}; closing connection", e);
stop_conn.store(true, Ordering::SeqCst);
return Err(e.into());
}

if data.len() > max_request_body_size as usize {
log::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size);
Expand Down Expand Up @@ -295,9 +304,6 @@ async fn background_task(
// Drive all running methods to completion
method_executors.await;

// Drop the monitor for this task since we are shutting down
drop(stop_monitor);

Ok(())
}

Expand Down