Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ws server): fix shutdown on connection closed #1103

Merged
merged 18 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 52 additions & 18 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::time::Duration;

use crate::server::BatchRequestConfig;
use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::types::SubscriptionId;
Expand Down Expand Up @@ -815,25 +817,11 @@ async fn notif_is_ignored() {
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
async fn close_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);
init_logger();

let (handle, addr) = {
let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();

let mut module = RpcModule::new(());

module
.register_async_method("infinite_call", |_, _| async move {
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
};

let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap()).await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
Expand All @@ -847,5 +835,51 @@ async fn drop_client_with_pending_calls_works() {
// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);

init_logger();
let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap()).await;

{
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dq: I remember from the substrate CI that these constants could be exceeded on overcommitted days. Should we increase the timeout a bit here, or should this be sufficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know what I was thinking; I could embed an mpsc::Sender in the method callback and await 10 messages

Nice to get rid of sleeps, those shouldn't be used if possible :D

}

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

/// Start a test server that exposes a single `infinite_call` method which never completes.
///
/// `timeout` is installed as the server's ping interval so that pings don't force the
/// connection to be closed while a test is running.
///
/// Returns the handle of the started server together with the local address it is bound to.
async fn server_with_infinite_call(timeout: Duration) -> (crate::ServerHandle, std::net::SocketAddr) {
    // Make sure that the ping_interval doesn't force the connection to be closed
    let builder = ServerBuilder::default().ping_interval(timeout);
    let server = builder.build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();

    // Grab the bound address up front; `start` consumes the server below.
    let addr = server.local_addr().unwrap();

    let mut module = RpcModule::new(());
    module
        .register_async_method("infinite_call", |_, _| async move {
            // Never resolves, so every call to this method stays pending.
            futures_util::future::pending::<()>().await;
            "ok"
        })
        .unwrap();

    let handle = server.start(module).unwrap();
    (handle, addr)
}
91 changes: 70 additions & 21 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ pub(crate) async fn background_task<L: Logger>(
sender: Sender,
mut receiver: Receiver,
svc: ServiceData<L>,
) -> Result<(), Error> {
) -> Result<Shutdown, Error> {
let ServiceData {
methods,
max_request_body_size,
Expand All @@ -250,17 +250,17 @@ pub(crate) async fn background_task<L: Logger>(
} = svc;

let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let (mut conn_tx, conn_rx) = oneshot::channel();
let (conn_tx, conn_rx) = oneshot::channel();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection);
let pending_calls = FuturesUnordered::new();

// Spawn another task that sends out the responses on the Websocket.
tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let stopped = stop_handle.shutdown();
let stopped = stop_handle.clone().shutdown();

tokio::pin!(stopped);

Expand All @@ -272,11 +272,11 @@ pub(crate) async fn background_task<L: Logger>(
stopped = stop;
permit
}
None => break Ok(()),
None => break Ok(Shutdown::ConnectionClosed),
Copy link
Collaborator

@jsdw jsdw Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right in thinking that:

  • Stopped means the user has manually stopped the server, so we want to gracefully close the connection
  • ConnectionClosed means the connection was closed for some other reason (eg network issue or whatever)

Copy link
Member Author

@niklasad1 niklasad1 Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this only fails if the send_task has completed and the receiver of the bounded channel has been dropped.

I think this can only occur if send_ping fails, i.e., the connection was closed

};

match try_recv(&mut receiver, &mut data, stopped).await {
Receive::Shutdown => break Ok(()),
Receive::Shutdown => break Ok(Shutdown::Stopped),
Receive::Ok(stop) => {
stopped = stop;
}
Expand All @@ -286,7 +286,7 @@ pub(crate) async fn background_task<L: Logger>(
match err {
SokettoError::Closed => {
tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
break Ok(());
break Ok(Shutdown::ConnectionClosed);
}
SokettoError::MessageTooLarge { current, maximum } => {
tracing::debug!(
Expand Down Expand Up @@ -326,21 +326,11 @@ pub(crate) async fn background_task<L: Logger>(
// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
//
// This is not strictly needed because `tokio::spawn` will drive these to completion
// but it's preferred that `stop_handle.stopped()` should not return until all methods have been
// executed and the connection has been closed.
tokio::select! {
// All pending calls executed.
_ = pending_calls.for_each(|_| async {}) => {
_ = conn_tx.send(());
}
// The connection was closed, no point of waiting for the pending calls.
_ = conn_tx.closed() => {}
}
graceful_shutdown(&result, pending_calls, receiver, data, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
drop(stop_handle);
result
}

Expand All @@ -352,7 +342,11 @@ async fn send_task(
stop: oneshot::Receiver<()>,
) {
// Interval to send out continuously `pings`.
let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
let mut ping_interval = tokio::time::interval(ping_interval);
// This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
Expand Down Expand Up @@ -384,15 +378,18 @@ async fn send_task(
}

// Handle timer intervals.
Either::Right((Either::Left((_, stop)), next_rx)) => {
Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => {
if let Err(err) = send_ping(&mut ws_sender).await {
tracing::debug!("WS transport error: send ping failed: {}", err);
break;
}

rx_item = next_rx;
futs = future::select(ping_interval.next(), stop);
}

Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"),

// Server is stopped.
Either::Right((Either::Right(_), _)) => {
break;
Expand Down Expand Up @@ -558,3 +555,55 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
}
};
}

/// The reason why the background task of a WebSocket connection finished.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum Shutdown {
    /// The server was stopped while the connection was still open.
    Stopped,
    /// The connection itself was closed, e.g. the remote peer terminated it or the
    /// send task is gone.
    ConnectionClosed,
}

/// Enforce a graceful shutdown.
///
/// If `result` is `Ok(Shutdown::ConnectionClosed)` there is nothing left to wait for.
/// Otherwise (the server was stopped or the connection task failed) this waits until
/// whichever of the following happens first:
///
/// - all `pending_calls` have been executed,
/// - the peer closed the connection, or
/// - the receiving end of `conn_tx` has been dropped.
///
/// In every case the send task is then told to stop via `conn_tx` and awaited, so this
/// function only returns once the send task has finished.
async fn graceful_shutdown<F: Future>(
    result: &Result<Shutdown, Error>,
    pending_calls: FuturesUnordered<F>,
    receiver: Receiver,
    data: Vec<u8>,
    mut conn_tx: oneshot::Sender<()>,
    send_task_handle: tokio::task::JoinHandle<()>,
) {
    match result {
        Ok(Shutdown::ConnectionClosed) => (),
        Ok(Shutdown::Stopped) | Err(_) => {
            // Soketto doesn't have a way to signal when the connection is closed,
            // thus just throw away the data and terminate the stream once the connection
            // has been terminated.
            //
            // The receiver is not cancel-safe, so it's wrapped in a stream that owns it
            // across polls.
            let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
                // `receive` appends to the buffer; the payload is discarded anyway, so
                // reset it to avoid accumulating every message sent during shutdown.
                data.clear();

                if let Err(SokettoError::Closed) = receiver.receive(&mut data).await {
                    None
                } else {
                    Some(((), (receiver, data)))
                }
            });

            let graceful_shutdown = pending_calls.for_each(|_| async {});
            let disconnect = disconnect_stream.for_each(|_| async {});

            // Either all pending calls have finished or the connection was closed;
            // in both cases it is fine to terminate.
            tokio::select! {
                _ = graceful_shutdown => {}
                _ = disconnect => {}
                _ = conn_tx.closed() => {}
            }
        }
    };

    // Send a message to close down the "send task".
    _ = conn_tx.send(());
    // Ensure that the send task has been closed.
    _ = send_task_handle.await;
}