fix(ws server): fix shutdown on connection closed #1103

Merged: 18 commits, Apr 27, 2023
Changes from 5 commits
15 changes: 13 additions & 2 deletions server/src/tests/ws.rs
@@ -24,6 +24,8 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::time::Duration;

use crate::server::BatchRequestConfig;
use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::types::SubscriptionId;
@@ -816,10 +818,19 @@ async fn notif_is_ignored() {

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);

init_logger();

let (handle, addr) = {
let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();
let server = ServerBuilder::default()
// Make sure that the ping_interval doesn't force the connection to be closed
.ping_interval(MAX_TIMEOUT.checked_mul(10).unwrap())
.build("127.0.0.1:0")
.with_default_timeout()
.await
.unwrap()
.unwrap();

let mut module = RpcModule::new(());

@@ -847,5 +858,5 @@ async fn drop_client_with_pending_calls_works() {
// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}
59 changes: 45 additions & 14 deletions server/src/transport/ws.rs
@@ -230,7 +230,7 @@ pub(crate) async fn background_task<L: Logger>(
sender: Sender,
mut receiver: Receiver,
svc: ServiceData<L>,
) -> Result<(), Error> {
) -> Result<Shutdown, Error> {
let ServiceData {
methods,
max_request_body_size,
@@ -250,7 +250,7 @@ pub(crate) async fn background_task<L: Logger>(
} = svc;

let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let (mut conn_tx, conn_rx) = oneshot::channel();
let (conn_tx, conn_rx) = oneshot::channel();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection);
let pending_calls = FuturesUnordered::new();
@@ -272,11 +272,11 @@ pub(crate) async fn background_task<L: Logger>(
stopped = stop;
permit
}
None => break Ok(()),
None => break Ok(Shutdown::ConnectionClosed),
@jsdw (Collaborator), Apr 26, 2023:

Am I right in thinking that:

  • Stopped means the user has manually stopped the server, so we want to gracefully close the connection
  • ConnectionClosed means the connection was closed for some other reason (e.g. a network issue)?

@niklasad1 (Member Author), Apr 26, 2023:

Yes, this only fails if the send_task has completed and the receiver of the bounded channel has been dropped.

I think this can only occur if send_ping fails, i.e. the connection was closed.
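
The distinction discussed in this thread can be sketched in isolation. The enum below mirrors the `Shutdown` type added in this PR; the helper `should_wait_for_pending` is a hypothetical name introduced only for illustration and is not part of the change. It assumes, as the diff suggests, that a graceful stop or an error still waits for pending calls, while a closed connection does not.

```rust
/// Why the per-connection task stopped (mirrors the `Shutdown` enum in this PR).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Shutdown {
    /// The user stopped the server, so the connection is closed gracefully.
    Stopped,
    /// The connection was closed for another reason (peer or network).
    ConnectionClosed,
}

/// Hypothetical helper (not in the PR): decide whether the task should
/// wait for in-flight calls before it exits.
pub fn should_wait_for_pending<E>(result: &Result<Shutdown, E>) -> bool {
    match result {
        // Graceful stop or an error: give in-flight calls a chance to finish.
        Ok(Shutdown::Stopped) | Err(_) => true,
        // Nobody is left to receive the responses, so do not wait.
        Ok(Shutdown::ConnectionClosed) => false,
    }
}
```

Under these assumptions, `should_wait_for_pending::<()>(&Ok(Shutdown::ConnectionClosed))` is `false`, which corresponds to the `Ok(Shutdown::ConnectionClosed) => ()` arm in the diff.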

};

match try_recv(&mut receiver, &mut data, stopped).await {
Receive::Shutdown => break Ok(()),
Receive::Shutdown => break Ok(Shutdown::Stopped),
Receive::Ok(stop) => {
stopped = stop;
}
@@ -286,7 +286,7 @@ pub(crate) async fn background_task<L: Logger>(
match err {
SokettoError::Closed => {
tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
break Ok(());
break Ok(Shutdown::ConnectionClosed);
}
SokettoError::MessageTooLarge { current, maximum } => {
tracing::debug!(
@@ -330,14 +330,33 @@ pub(crate) async fn background_task<L: Logger>(
// This is not strictly needed because `tokio::spawn` will drive these to completion
// but it's preferred that the `stop_handle.stopped()` should not return until all methods have been
// executed and the connection has been closed.
tokio::select! {
// All pending calls executed.
_ = pending_calls.for_each(|_| async {}) => {
_ = conn_tx.send(());
match result {
Ok(Shutdown::Stopped) | Err(_) => {
// Soketto doesn't have a way to signal when the connection is closed
// thus just throw away the data and terminate the stream once the connection has
// been terminated.
//
// The receiver is not cancel-safe, so it is used in a stream to enforce that.
let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
if let Err(SokettoError::Closed) = receiver.receive(&mut data).await {
None
} else {
Some(((), (receiver, data)))
}
});

let pending = pending_calls.for_each(|_| async {});
let disconnect = disconnect_stream.for_each(|_| async {});
@niklasad1 (Member Author), Apr 26, 2023:

It's possible that the connection is terminated during the graceful shutdown, and this ensures that it is aborted once the close message is sent.

However, this is slightly error-prone if no proper close is sent. It would be neat if soketto had something to indicate when the connection is closed.
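
The "throw away data until the connection reports closed" idea can be sketched without any async machinery. This is a synchronous stand-in, not the PR's code: `RecvError`, `MockReceiver` and `drain_until_closed` are hypothetical names, and the real change drives soketto's receiver through `futures_util::stream::unfold` instead of a loop.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the transport error type; only `Closed` matters here.
#[derive(Debug, PartialEq)]
pub enum RecvError {
    Closed,
    Other,
}

/// Hypothetical receiver that replays a scripted sequence of results.
pub struct MockReceiver {
    script: VecDeque<Result<Vec<u8>, RecvError>>,
}

impl MockReceiver {
    pub fn new(script: Vec<Result<Vec<u8>, RecvError>>) -> Self {
        Self { script: script.into() }
    }

    /// Returns the next scripted result; reports `Closed` once the script runs out.
    pub fn receive(&mut self, data: &mut Vec<u8>) -> Result<(), RecvError> {
        match self.script.pop_front() {
            Some(Ok(bytes)) => {
                data.extend_from_slice(&bytes);
                Ok(())
            }
            Some(Err(e)) => Err(e),
            None => Err(RecvError::Closed),
        }
    }
}

/// Discards everything received until the receiver reports `Closed`.
/// Returns how many receive attempts were thrown away before that.
pub fn drain_until_closed(receiver: &mut MockReceiver) -> usize {
    let mut data = Vec::new();
    let mut discarded = 0;
    loop {
        match receiver.receive(&mut data) {
            Err(RecvError::Closed) => return discarded,
            // Data and any other error are ignored, as in the diff's stream,
            // where only `Closed` ends the stream.
            _ => {
                data.clear();
                discarded += 1;
            }
        }
    }
}
```

The mock reports `Closed` when its script is exhausted, so the loop always ends here. A receiver that never reports `Closed` would keep this loop spinning, which is the error-prone case mentioned in the comment above.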


tokio::select! {
_ = pending => (),
_ = disconnect => (),
}
Collaborator:

So basically, whichever finishes first out of disconnect (which looks like any final things to be received?) and pending (which looks like anything to be sent back?) will lead to the thing ending?

Collaborator:

I find it a bit hard to follow the logic here!

Member Author:

Yeah, it's a way to detect whether the connection has terminated while we wait for the pending futures to complete.
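
The "whichever finishes first ends the wait" behaviour can be illustrated with a small race. This sketch deliberately swaps in OS threads and a `std::sync::mpsc` channel in place of `tokio::select!`, so it only models the outcome and not the async mechanics. `Finished` and `race` are hypothetical names, and the two delays stand in for "pending calls completed" and "peer disconnected".

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Which of the two waits finished first.
#[derive(Debug, PartialEq, Eq)]
pub enum Finished {
    PendingCalls,
    Disconnected,
}

/// Races two simulated waits and reports the first one to finish.
pub fn race(pending_delay: Duration, disconnect_delay: Duration) -> Finished {
    let (tx, rx) = mpsc::channel();

    let tx_pending = tx.clone();
    thread::spawn(move || {
        thread::sleep(pending_delay);
        // The send fails if the receiver is already gone; that is fine here.
        let _ = tx_pending.send(Finished::PendingCalls);
    });

    thread::spawn(move || {
        thread::sleep(disconnect_delay);
        let _ = tx.send(Finished::Disconnected);
    });

    // The first message wins; the slower thread's message is never read.
    rx.recv().expect("both senders are alive until they send")
}
```

One difference worth keeping in mind: `tokio::select!` drops the losing future, whereas the slower thread in this sketch keeps running until its sleep ends. With a short disconnect delay and a long pending delay, `race` returns `Finished::Disconnected`, which matches the case where the server stops waiting because the client has already gone.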

}
// The connection was closed, no point of waiting for the pending calls.
_ = conn_tx.closed() => {}
}
Ok(Shutdown::ConnectionClosed) => (),
};

_ = conn_tx.send(());

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
@@ -352,7 +371,11 @@ async fn send_task(
stop: oneshot::Receiver<()>,
) {
// Interval at which `pings` are sent out continuously.
let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
let mut ping_interval = tokio::time::interval(ping_interval);
// The first tick completes immediately, so consume it here to make sure the stream doesn't yield before the ping_interval has elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
@@ -384,15 +407,18 @@ async fn send_task(
}

// Handle timer intervals.
Either::Right((Either::Left((_, stop)), next_rx)) => {
Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => {
if let Err(err) = send_ping(&mut ws_sender).await {
tracing::debug!("WS transport error: send ping failed: {}", err);
break;
}

rx_item = next_rx;
futs = future::select(ping_interval.next(), stop);
}

Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"),

// Server is stopped.
Either::Right((Either::Right(_), _)) => {
break;
@@ -558,3 +584,8 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
}
};
}

pub(crate) enum Shutdown {
Stopped,
ConnectionClosed,
}