From 73c7094c89eb510677642a34708cfcf3fd030f9b Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Sun, 11 Jun 2023 14:48:12 +0000 Subject: [PATCH 1/2] ReconnectingConnection: Add heartbeat check. --- babushka-core/src/client/mod.rs | 1 + .../src/client/reconnecting_connection.rs | 98 ++++++++++++++----- babushka-core/tests/test_client_cmd.rs | 38 +++++++ 3 files changed, 112 insertions(+), 25 deletions(-) diff --git a/babushka-core/src/client/mod.rs b/babushka-core/src/client/mod.rs index a5acb0c854..077b6d8183 100644 --- a/babushka-core/src/client/mod.rs +++ b/babushka-core/src/client/mod.rs @@ -10,6 +10,7 @@ use std::io; use std::time::Duration; mod client_cmd; mod reconnecting_connection; +pub use reconnecting_connection::HEARTBEAT_SLEEP_DURATION; pub trait BabushkaClient: ConnectionLike + Send + Clone {} diff --git a/babushka-core/src/client/reconnecting_connection.rs b/babushka-core/src/client/reconnecting_connection.rs index 8806706064..0b4630d114 100644 --- a/babushka-core/src/client/reconnecting_connection.rs +++ b/babushka-core/src/client/reconnecting_connection.rs @@ -36,6 +36,23 @@ struct InnerReconnectingConnection { backend: ConnectionBackend, } +pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1); + +impl InnerReconnectingConnection { + fn is_dropped(&self) -> bool { + self.backend.client_dropped_flagged.load(Ordering::Relaxed) + } + + async fn try_get_connection(&self) -> Option { + let guard = self.state.lock().await; + if let ConnectionState::Connected(connection) = &*guard { + Some(connection.clone()) + } else { + None + } + } +} + /// The separation between an inner and outer connection is because the outer connection is clonable, and the inner connection needs to be dropped when no outer connection exists. struct DropWrapper(Arc); @@ -124,6 +141,7 @@ impl ReconnectingConnection { client_dropped_flagged: AtomicBool::new(false), }; let connection = try_create_connection(client, connection_retry_strategy).await?; + Self::start_heartbeat(connection.inner.0.clone()); log_debug( "connection creation", format!("Connection to {address} created"), @@ -139,21 +157,52 @@ impl ReconnectingConnection { .connection_available_signal .wait() .await; - { - let guard = self.inner.0.state.lock().await; - if let ConnectionState::Connected(connection) = &*guard { - return Ok(connection.clone()); - } - }; + if let Some(connection) = self.inner.0.try_get_connection().await { + return Ok(connection); + } } } - async fn reconnect(&self) { + fn start_heartbeat(reconnecting_connection: Arc) { + task::spawn(async move { + loop { + tokio::time::sleep(HEARTBEAT_SLEEP_DURATION).await; + if reconnecting_connection.is_dropped() { + log_debug( + "ReconnectingConnection", + "heartbeat stopped after client was dropped", + ); + // Client was dropped, heartbeat can stop. + return; + } + + let Some(mut connection) = reconnecting_connection.try_get_connection().await else { + log_debug( + "ReconnectingConnection", + "heartbeat stopped while client is reconnecting", + ); + // Client is reconnecting, heartbeat can stop. It will be restarted by the reconnect attempt once it succeeds. + return; + }; + log_debug("ReconnectingConnection", "performing heartbeat"); + if connection + .req_packed_command(&redis::cmd("PING")) + .await + .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) + { + log_debug("ReconnectingConnection", "heartbeat triggered reconnect"); + Self::reconnect(&reconnecting_connection).await; + } + } + }); + } + + async fn reconnect(connection: &Arc) { { - let mut guard = self.inner.0.state.lock().await; + let mut guard = connection.state.lock().await; match &*guard { ConnectionState::Connected(_) => { - self.inner.0.backend.connection_available_signal.reset(); + connection.backend.connection_available_signal.reset(); } _ => { log_trace("reconnect", "already started"); @@ -164,18 +213,14 @@ impl ReconnectingConnection { *guard = ConnectionState::Reconnecting; }; log_debug("reconnect", "starting"); - let inner_connection_clone = self.inner.0.clone(); + let inner_connection_clone = connection.clone(); // The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the // background, regardless of whether the calling task is dropped or not. task::spawn(async move { let client = &inner_connection_clone.backend.connection_info; for sleep_duration in internal_retry_iterator() { - if inner_connection_clone - .backend - .client_dropped_flagged - .load(Ordering::Relaxed) - { - log_trace( + if inner_connection_clone.is_dropped() { + log_debug( "ReconnectingConnection", "reconnect stopped after client was dropped", ); @@ -185,13 +230,16 @@ impl ReconnectingConnection { log_debug("connection creation", "Creating multiplexed connection"); match get_multiplexed_connection(client).await { Ok(connection) => { - let mut guard = inner_connection_clone.state.lock().await; - log_debug("reconnect", "completed succesfully"); - inner_connection_clone - .backend - .connection_available_signal - .set(); - *guard = ConnectionState::Connected(connection); + { + let mut guard = inner_connection_clone.state.lock().await; + log_debug("reconnect", "completed succesfully"); + inner_connection_clone + .backend + .connection_available_signal + .set(); + *guard = ConnectionState::Connected(connection); + } + Self::start_heartbeat(inner_connection_clone); return; } Err(_) => tokio::time::sleep(sleep_duration).await, @@ -209,7 +257,7 @@ impl ReconnectingConnection { let result = connection.send_packed_command(cmd).await; match result { Err(err) if err.is_connection_dropped() => { - self.reconnect().await; + Self::reconnect(&self.inner.0).await; Err(err) } _ => result, @@ -226,7 +274,7 @@ impl ReconnectingConnection { let result = connection.send_packed_commands(cmd, offset, count).await; match result { Err(err) if err.is_connection_dropped() => { - self.reconnect().await; + Self::reconnect(&self.inner.0).await; Err(err) } _ => result, diff --git a/babushka-core/tests/test_client_cmd.rs b/babushka-core/tests/test_client_cmd.rs index b2c387d7f3..d118417efd 100644 --- a/babushka-core/tests/test_client_cmd.rs +++ b/babushka-core/tests/test_client_cmd.rs @@ -23,6 +23,7 @@ mod client_cmd_tests { let address = server.get_client_addr(); drop(server); + // we use another thread, so that the creation of the server won't block the client work. let thread = std::thread::spawn(move || { block_on_all(async move { let mut get_command = redis::Cmd::new(); @@ -53,4 +54,41 @@ mod client_cmd_tests { thread.join().unwrap(); }); } + + #[rstest] + #[timeout(LONG_CMD_TEST_TIMEOUT)] + fn test_detect_disconnect_and_reconnect_using_heartbeat(#[values(false, true)] use_tls: bool) { + let (sender, receiver) = tokio::sync::oneshot::channel(); + block_on_all(async move { + let mut test_basics = setup_test_basics(use_tls).await; + let server = test_basics.server; + let address = server.get_client_addr(); + println!("dropping server"); + drop(server); + + // we use another thread, so that the creation of the server won't block the client work. + std::thread::spawn(move || { + block_on_all(async move { + let new_server = RedisServer::new_with_addr_and_modules(address.clone(), &[]); + wait_for_server_to_become_ready(&address).await; + let _ = sender.send(new_server); + }) + }); + + let _new_server = receiver.await; + tokio::time::sleep(babushka::client::HEARTBEAT_SLEEP_DURATION + Duration::from_secs(1)) + .await; + + let mut get_command = redis::Cmd::new(); + get_command + .arg("GET") + .arg("test_detect_disconnect_and_reconnect_using_heartbeat"); + let get_result = test_basics + .client + .send_packed_command(&get_command) + .await + .unwrap(); + assert_eq!(get_result, Value::Nil); + }); + } } From c716cbb5947ca7f7a8d58ab88c3bc599f58f3738 Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Tue, 13 Jun 2023 15:52:18 +0000 Subject: [PATCH 2/2] Move heartbeat logic to ClientCMD. --- babushka-core/src/client/client_cmd.rs | 85 +++++++-- babushka-core/src/client/mod.rs | 3 +- .../src/client/reconnecting_connection.rs | 161 +++++------------- 3 files changed, 117 insertions(+), 132 deletions(-) diff --git a/babushka-core/src/client/client_cmd.rs b/babushka-core/src/client/client_cmd.rs index 4b46d1718d..688931dfbe 100644 --- a/babushka-core/src/client/client_cmd.rs +++ b/babushka-core/src/client/client_cmd.rs @@ -1,14 +1,26 @@ +use super::get_redis_connection_info; +use super::reconnecting_connection::ReconnectingConnection; use crate::connection_request::{ConnectionRequest, TlsMode}; use crate::retry_strategies::RetryStrategy; +use logger_core::{log_debug, log_trace}; use redis::RedisResult; +use std::sync::Arc; +use tokio::task; -use super::get_redis_connection_info; -use super::reconnecting_connection::ReconnectingConnection; +struct DropWrapper { + /// Connection to the primary node in the client. + primary: ReconnectingConnection, +} + +impl Drop for DropWrapper { + fn drop(&mut self) { + self.primary.mark_as_dropped(); + } +} #[derive(Clone)] pub struct ClientCMD { - /// Connection to the primary node in the client. - primary: ReconnectingConnection, + inner: Arc, } impl ClientCMD { @@ -27,27 +39,80 @@ impl ClientCMD { tls_mode, ) .await?; - - Ok(Self { primary }) + Self::start_heartbeat(primary.clone()); + Ok(Self { + inner: Arc::new(DropWrapper { primary }), + }) } pub async fn send_packed_command( &mut self, cmd: &redis::Cmd, ) -> redis::RedisResult { - self.primary.send_packed_command(cmd).await + log_trace("ClientCMD", "sending command"); + let mut connection = self.inner.primary.get_connection().await?; + let result = connection.send_packed_command(cmd).await; + match result { + Err(err) if err.is_connection_dropped() => { + self.inner.primary.reconnect().await; + Err(err) + } + _ => result, + } } - pub(super) async fn send_packed_commands( + pub async fn send_packed_commands( &mut self, cmd: &redis::Pipeline, offset: usize, count: usize, ) -> redis::RedisResult> { - self.primary.send_packed_commands(cmd, offset, count).await + let mut connection = self.inner.primary.get_connection().await?; + let result = connection.send_packed_commands(cmd, offset, count).await; + match result { + Err(err) if err.is_connection_dropped() => { + self.inner.primary.reconnect().await; + Err(err) + } + _ => result, + } } pub(super) fn get_db(&self) -> i64 { - self.primary.get_db() + self.inner.primary.get_db() + } + + fn start_heartbeat(reconnecting_connection: ReconnectingConnection) { + task::spawn(async move { + loop { + tokio::time::sleep(super::HEARTBEAT_SLEEP_DURATION).await; + if reconnecting_connection.is_dropped() { + log_debug( + "ClientCMD", + "heartbeat stopped after connection was dropped", + ); + // Client was dropped, heartbeat can stop. + return; + } + + let Some(mut connection) = reconnecting_connection.try_get_connection().await else { + log_debug( + "ClientCMD", + "heartbeat stopped while connection is reconnecting", + ); + // Client is reconnecting.. + continue; + }; + log_debug("ClientCMD", "performing heartbeat"); + if connection + .send_packed_command(&redis::cmd("PING")) + .await + .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) + { + log_debug("ClientCMD", "heartbeat triggered reconnect"); + reconnecting_connection.reconnect().await; + } + } + }); } } diff --git a/babushka-core/src/client/mod.rs b/babushka-core/src/client/mod.rs index 077b6d8183..4d4d25c867 100644 --- a/babushka-core/src/client/mod.rs +++ b/babushka-core/src/client/mod.rs @@ -10,7 +10,8 @@ use std::io; use std::time::Duration; mod client_cmd; mod reconnecting_connection; -pub use reconnecting_connection::HEARTBEAT_SLEEP_DURATION; + +pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1); pub trait BabushkaClient: ConnectionLike + Send + Clone {} diff --git a/babushka-core/src/client/reconnecting_connection.rs b/babushka-core/src/client/reconnecting_connection.rs index 0b4630d114..6a0bcb711d 100644 --- a/babushka-core/src/client/reconnecting_connection.rs +++ b/babushka-core/src/client/reconnecting_connection.rs @@ -36,40 +36,9 @@ struct InnerReconnectingConnection { backend: ConnectionBackend, } -pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1); - -impl InnerReconnectingConnection { - fn is_dropped(&self) -> bool { - self.backend.client_dropped_flagged.load(Ordering::Relaxed) - } - - async fn try_get_connection(&self) -> Option { - let guard = self.state.lock().await; - if let ConnectionState::Connected(connection) = &*guard { - Some(connection.clone()) - } else { - None - } - } -} - -/// The separation between an inner and outer connection is because the outer connection is clonable, and the inner connection needs to be dropped when no outer connection exists. -struct DropWrapper(Arc); - -impl Drop for DropWrapper { - fn drop(&mut self) { - self.0 - .backend - .client_dropped_flagged - .store(true, Ordering::Relaxed); - } -} - #[derive(Clone)] pub(super) struct ReconnectingConnection { - /// All of the connection's clones point to the same internal wrapper, which will be dropped only once, - /// when all of the clones have been dropped. - inner: Arc, + inner: Arc, } async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult { @@ -92,10 +61,10 @@ async fn try_create_connection( let connection = Retry::spawn(retry_strategy.get_iterator(), action).await?; Ok(ReconnectingConnection { - inner: Arc::new(DropWrapper(Arc::new(InnerReconnectingConnection { + inner: Arc::new(InnerReconnectingConnection { state: Mutex::new(ConnectionState::Connected(connection)), backend: connection_backend, - }))), + }), }) } @@ -129,7 +98,7 @@ impl ReconnectingConnection { connection_retry_strategy: RetryStrategy, redis_connection_info: RedisConnectionInfo, tls_mode: TlsMode, - ) -> RedisResult { + ) -> RedisResult { log_debug( "connection creation", format!("Attempting connection to {address}"), @@ -141,7 +110,6 @@ impl ReconnectingConnection { client_dropped_flagged: AtomicBool::new(false), }; let connection = try_create_connection(client, connection_retry_strategy).await?; - Self::start_heartbeat(connection.inner.0.clone()); log_debug( "connection creation", format!("Connection to {address} created"), @@ -149,60 +117,44 @@ impl ReconnectingConnection { Ok(connection) } - async fn get_connection(&self) -> Result { - loop { - self.inner - .0 - .backend - .connection_available_signal - .wait() - .await; - if let Some(connection) = self.inner.0.try_get_connection().await { - return Ok(connection); - } - } + pub(super) fn is_dropped(&self) -> bool { + self.inner + .backend + .client_dropped_flagged + .load(Ordering::Relaxed) } - fn start_heartbeat(reconnecting_connection: Arc) { - task::spawn(async move { - loop { - tokio::time::sleep(HEARTBEAT_SLEEP_DURATION).await; - if reconnecting_connection.is_dropped() { - log_debug( - "ReconnectingConnection", - "heartbeat stopped after client was dropped", - ); - // Client was dropped, heartbeat can stop. - return; - } + pub(super) fn mark_as_dropped(&self) { + self.inner + .backend + .client_dropped_flagged + .store(true, Ordering::Relaxed) + } - let Some(mut connection) = reconnecting_connection.try_get_connection().await else { - log_debug( - "ReconnectingConnection", - "heartbeat stopped while client is reconnecting", - ); - // Client is reconnecting, heartbeat can stop. It will be restarted by the reconnect attempt once it succeeds. - return; - }; - log_debug("ReconnectingConnection", "performing heartbeat"); - if connection - .req_packed_command(&redis::cmd("PING")) - .await - .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) - { - log_debug("ReconnectingConnection", "heartbeat triggered reconnect"); - Self::reconnect(&reconnecting_connection).await; - } + pub(super) async fn try_get_connection(&self) -> Option { + let guard = self.inner.state.lock().await; + if let ConnectionState::Connected(connection) = &*guard { + Some(connection.clone()) + } else { + None + } + } + + pub(super) async fn get_connection(&self) -> Result { + loop { + self.inner.backend.connection_available_signal.wait().await; + if let Some(connection) = self.try_get_connection().await { + return Ok(connection); } - }); + } } - async fn reconnect(connection: &Arc) { + pub(super) async fn reconnect(&self) { { - let mut guard = connection.state.lock().await; + let mut guard = self.inner.state.lock().await; match &*guard { ConnectionState::Connected(_) => { - connection.backend.connection_available_signal.reset(); + self.inner.backend.connection_available_signal.reset(); } _ => { log_trace("reconnect", "already started"); @@ -213,13 +165,13 @@ impl ReconnectingConnection { *guard = ConnectionState::Reconnecting; }; log_debug("reconnect", "starting"); - let inner_connection_clone = connection.clone(); + let connection_clone = self.clone(); // The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the // background, regardless of whether the calling task is dropped or not. task::spawn(async move { - let client = &inner_connection_clone.backend.connection_info; + let client = &connection_clone.inner.backend.connection_info; for sleep_duration in internal_retry_iterator() { - if inner_connection_clone.is_dropped() { + if connection_clone.is_dropped() { log_debug( "ReconnectingConnection", "reconnect stopped after client was dropped", @@ -231,15 +183,15 @@ impl ReconnectingConnection { match get_multiplexed_connection(client).await { Ok(connection) => { { - let mut guard = inner_connection_clone.state.lock().await; + let mut guard = connection_clone.inner.state.lock().await; log_debug("reconnect", "completed succesfully"); - inner_connection_clone + connection_clone + .inner .backend .connection_available_signal .set(); *guard = ConnectionState::Connected(connection); } - Self::start_heartbeat(inner_connection_clone); return; } Err(_) => tokio::time::sleep(sleep_duration).await, @@ -248,41 +200,8 @@ impl ReconnectingConnection { }); } - pub(super) async fn send_packed_command( - &mut self, - cmd: &redis::Cmd, - ) -> redis::RedisResult { - log_trace("ReconnectingConnection", "sending command"); - let mut connection = self.get_connection().await?; - let result = connection.send_packed_command(cmd).await; - match result { - Err(err) if err.is_connection_dropped() => { - Self::reconnect(&self.inner.0).await; - Err(err) - } - _ => result, - } - } - - pub(super) async fn send_packed_commands( - &mut self, - cmd: &redis::Pipeline, - offset: usize, - count: usize, - ) -> redis::RedisResult> { - let mut connection = self.get_connection().await?; - let result = connection.send_packed_commands(cmd, offset, count).await; - match result { - Err(err) if err.is_connection_dropped() => { - Self::reconnect(&self.inner.0).await; - Err(err) - } - _ => result, - } - } - pub(super) fn get_db(&self) -> i64 { - let guard = self.inner.0.state.blocking_lock(); + let guard = self.inner.state.blocking_lock(); match &*guard { ConnectionState::Connected(connection) => connection.get_db(), _ => -1,