Skip to content

Commit

Permalink
feat: expose ConnectionGuard as request extension (#1443)
Browse files Browse the repository at this point in the history
* feat: expose ConnectionGuard as request extension

* test: assert connection guard information via json-rpc interface

* chore: format code
  • Loading branch information
dinhani authored Aug 19, 2024
1 parent 6c2de38 commit 9b66e9c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
4 changes: 3 additions & 1 deletion server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,9 @@ where
let curr_conns = max_conns - conn_guard.available_connections();
tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns);

request.extensions_mut().insert::<ConnectionId>(conn.conn_id.into());
let req_ext = request.extensions_mut();
req_ext.insert::<ConnectionGuard>(conn_guard.clone());
req_ext.insert::<ConnectionId>(conn.conn_id.into());

let is_upgrade_request = is_upgrade_request(&request);

Expand Down
12 changes: 10 additions & 2 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ use jsonrpsee::server::middleware::http::ProxyGetRequestLayer;

use jsonrpsee::server::middleware::rpc::RpcServiceT;
use jsonrpsee::server::{
serve_with_graceful_shutdown, stop_channel, PendingSubscriptionSink, RpcModule, RpcServiceBuilder, Server,
ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError,
serve_with_graceful_shutdown, stop_channel, ConnectionGuard, PendingSubscriptionSink, RpcModule, RpcServiceBuilder,
Server, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError,
};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::{Methods, SubscriptionCloseResponse};
Expand Down Expand Up @@ -165,6 +165,14 @@ pub async fn server() -> SocketAddr {
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _, _| "hello").unwrap();
module.register_method("get_connection_id", |_, _, ext| *ext.get::<u32>().unwrap()).unwrap();
module
.register_method("get_available_connections", |_, _, ext| {
ext.get::<ConnectionGuard>().unwrap().available_connections()
})
.unwrap();
module
.register_method("get_max_connections", |_, _, ext| ext.get::<ConnectionGuard>().unwrap().max_connections())
.unwrap();

module
.register_async_method("slow_hello", |_, _, _| async {
Expand Down
27 changes: 26 additions & 1 deletion tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async fn ws_method_call_works_over_proxy_stream() {
}

#[tokio::test]
async fn extensions_with_different_ws_clients() {
async fn connection_id_extension_with_different_ws_clients() {
init_logger();

let server_addr = server().await;
Expand All @@ -225,6 +225,31 @@ async fn extensions_with_different_ws_clients() {
assert_ne!(connection_id, second_connection_id);
}

#[tokio::test]
async fn connection_guard_extension_with_different_ws_clients() {
init_logger();

let server_addr = server().await;
let server_url = format!("ws://{}", server_addr);

// First connected retrieves initial information from ConnectionGuard.
let first_client = WsClientBuilder::default().build(&server_url).await.unwrap();
let first_max_connections: usize = first_client.request("get_max_connections", rpc_params![]).await.unwrap();
let first_available_connections: usize =
first_client.request("get_available_connections", rpc_params![]).await.unwrap();

assert_eq!(first_available_connections, first_max_connections - 1);

// Second client ensure max connections stays the same, but available connections is decreased.
let second_client = WsClientBuilder::default().build(&server_url).await.unwrap();
let second_max_connections: usize = second_client.request("get_max_connections", rpc_params![]).await.unwrap();
let second_available_connections: usize =
second_client.request("get_available_connections", rpc_params![]).await.unwrap();

assert_eq!(second_max_connections, first_max_connections);
assert_eq!(second_available_connections, second_max_connections - 2);
}

#[tokio::test]
async fn ws_method_call_str_id_works() {
init_logger();
Expand Down

0 comments on commit 9b66e9c

Please sign in to comment.