Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into chore-release-v0.18.1
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 committed Apr 27, 2023
2 parents a96edc6 + 457d2d2 commit 2d24810
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 68 deletions.
4 changes: 2 additions & 2 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use crate::transport::{self, Error as TransportError, HttpTransportClient};
use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClient};
use crate::types::{NotificationSer, RequestSer, Response};
use async_trait::async_trait;
use hyper::body::HttpBody;
Expand Down Expand Up @@ -235,7 +235,7 @@ impl Default for HttpClientBuilder<Identity> {

/// JSON-RPC HTTP Client that provides functionality to perform method calls and notifications.
#[derive(Debug, Clone)]
pub struct HttpClient<S> {
pub struct HttpClient<S = HttpBackend> {
/// HTTP transport client.
transport: HttpTransportClient<S>,
/// Request timeout. Defaults to 60sec.
Expand Down
2 changes: 1 addition & 1 deletion server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
ws_builder.set_max_message_size(data.max_request_body_size as usize);
let (sender, receiver) = ws_builder.finish();

let _ = ws::background_task::<L>(sender, receiver, data).await;
ws::background_task::<L>(sender, receiver, data).await;
}
.in_current_span(),
);
Expand Down
92 changes: 69 additions & 23 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::time::Duration;

use crate::server::BatchRequestConfig;
use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::types::SubscriptionId;
Expand Down Expand Up @@ -815,43 +817,87 @@ async fn notif_is_ignored() {
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
async fn close_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);
const CONCURRENT_CALLS: usize = 10;
init_logger();

let (handle, addr) = {
let server = ServerBuilder::default()
.ping_interval(std::time::Duration::from_secs(60 * 60))
.build("127.0.0.1:0")
.with_default_timeout()
.await
.unwrap()
.unwrap();

let mut module = RpcModule::new(());

module
.register_async_method("infinite_call", |_, _| async move {
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
};
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();

let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}

// Assert that the server has received the calls.
for _ in 0..CONCURRENT_CALLS {
assert!(rx.recv().await.is_some());
}

client.close().await.unwrap();
assert!(client.receive().await.is_err());

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);
const CONCURRENT_CALLS: usize = 10;
init_logger();

let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await;

{
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..CONCURRENT_CALLS {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}
// Assert that the server has received the calls.
for _ in 0..CONCURRENT_CALLS {
assert!(rx.recv().await.is_some());
}
}

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

async fn server_with_infinite_call(
timeout: Duration,
tx: tokio::sync::mpsc::UnboundedSender<()>,
) -> (crate::ServerHandle, std::net::SocketAddr) {
let server = ServerBuilder::default()
// Make sure that the ping_interval doesn't force the connection to be closed
.ping_interval(timeout)
.build("127.0.0.1:0")
.with_default_timeout()
.await
.unwrap()
.unwrap();

let mut module = RpcModule::new(tx);

module
.register_async_method("infinite_call", |_, mut ctx| async move {
let tx = std::sync::Arc::make_mut(&mut ctx);
tx.send(()).unwrap();
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
}
104 changes: 62 additions & 42 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,7 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData
response
}

pub(crate) async fn background_task<L: Logger>(
sender: Sender,
mut receiver: Receiver,
svc: ServiceData<L>,
) -> Result<Shutdown, Error> {
pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Receiver, svc: ServiceData<L>) {
let ServiceData {
methods,
max_request_body_size,
Expand All @@ -256,11 +252,11 @@ pub(crate) async fn background_task<L: Logger>(
let pending_calls = FuturesUnordered::new();

// Spawn another task that sends out the responses on the Websocket.
tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let stopped = stop_handle.shutdown();
let stopped = stop_handle.clone().shutdown();

tokio::pin!(stopped);

Expand Down Expand Up @@ -300,7 +296,7 @@ pub(crate) async fn background_task<L: Logger>(
}
err => {
tracing::debug!("WS transport error: {}; terminate connection: {}", err, conn_id);
break Err(err.into());
break Err(err);
}
};
}
Expand All @@ -326,41 +322,11 @@ pub(crate) async fn background_task<L: Logger>(
// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
//
// This is not strictly not needed because `tokio::spawn` will drive these the completion
// but it's preferred that the `stop_handle.stopped()` should not return until all methods has been
// executed and the connection has been closed.
match result {
Ok(Shutdown::Stopped) | Err(_) => {
// Soketto doesn't have a way to signal when the connection is closed
// thus just throw the data and terminate the stream once the connection has
// been terminated.
//
// The receiver is not cancel-safe such that it used in stream to enforce that.
let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
if let Err(SokettoError::Closed) = receiver.receive(&mut data).await {
None
} else {
Some(((), (receiver, data)))
}
});

let pending = pending_calls.for_each(|_| async {});
let disconnect = disconnect_stream.for_each(|_| async {});

tokio::select! {
_ = pending => (),
_ = disconnect => (),
}
}
Ok(Shutdown::ConnectionClosed) => (),
};

_ = conn_tx.send(());
graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
result
drop(stop_handle);
}

/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`.
Expand All @@ -371,7 +337,11 @@ async fn send_task(
stop: oneshot::Receiver<()>,
) {
// Interval to send out continuously `pings`.
let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
let mut ping_interval = tokio::time::interval(ping_interval);
// This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
Expand Down Expand Up @@ -403,15 +373,18 @@ async fn send_task(
}

// Handle timer intervals.
Either::Right((Either::Left((_, stop)), next_rx)) => {
Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => {
if let Err(err) = send_ping(&mut ws_sender).await {
tracing::debug!("WS transport error: send ping failed: {}", err);
break;
}

rx_item = next_rx;
futs = future::select(ping_interval.next(), stop);
}

Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"),

// Server is stopped.
Either::Right((Either::Right(_), _)) => {
break;
Expand Down Expand Up @@ -578,7 +551,54 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
};
}

#[derive(Debug, Copy, Clone)]
pub(crate) enum Shutdown {
Stopped,
ConnectionClosed,
}

/// Enforce a graceful shutdown.
///
/// This will return once the connection has been terminated or all pending calls have been executed.
async fn graceful_shutdown<F: Future>(
result: Result<Shutdown, SokettoError>,
pending_calls: FuturesUnordered<F>,
receiver: Receiver,
data: Vec<u8>,
mut conn_tx: oneshot::Sender<()>,
send_task_handle: tokio::task::JoinHandle<()>,
) {
match result {
Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (),
Ok(Shutdown::Stopped) | Err(_) => {
// Soketto doesn't have a way to signal when the connection is closed
// thus just throw away the data and terminate the stream once the connection has
// been terminated.
//
// The receiver is not cancel-safe such that it's used in a stream to enforce that.
let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
if let Err(SokettoError::Closed) = receiver.receive(&mut data).await {
None
} else {
Some(((), (receiver, data)))
}
});

let graceful_shutdown = pending_calls.for_each(|_| async {});
let disconnect = disconnect_stream.for_each(|_| async {});

// All pending calls has been finished or the connection closed.
// Fine to terminate
tokio::select! {
_ = graceful_shutdown => {}
_ = disconnect => {}
_ = conn_tx.closed() => {}
}
}
};

// Send a message to close down the "send task".
_ = conn_tx.send(());
// Ensure that send task has been closed.
_ = send_task_handle.await;
}

0 comments on commit 2d24810

Please sign in to comment.