Skip to content

Commit

Permalink
Merge pull request #264 from shachlanAmazon/heartbeat
Browse files Browse the repository at this point in the history
ReconnectingConnection: Add heartbeat check.
  • Loading branch information
shachlanAmazon authored Jun 25, 2023
2 parents 6ed7aa6 + c716cbb commit 5f6cc1f
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 93 deletions.
85 changes: 75 additions & 10 deletions babushka-core/src/client/client_cmd.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
use super::get_redis_connection_info;
use super::reconnecting_connection::ReconnectingConnection;
use crate::connection_request::{ConnectionRequest, TlsMode};
use crate::retry_strategies::RetryStrategy;
use logger_core::{log_debug, log_trace};
use redis::RedisResult;
use std::sync::Arc;
use tokio::task;

use super::get_redis_connection_info;
use super::reconnecting_connection::ReconnectingConnection;
/// Wrapper around the client's primary connection.
/// Its `Drop` implementation marks the wrapped connection as dropped.
struct DropWrapper {
/// Connection to the primary node in the client.
primary: ReconnectingConnection,
}

impl Drop for DropWrapper {
/// Flags the wrapped connection as dropped when the wrapper is destroyed.
// The heartbeat loop polls this flag through `is_dropped()` and returns once it is set.
fn drop(&mut self) {
self.primary.mark_as_dropped();
}
}

#[derive(Clone)]
pub struct ClientCMD {
/// Connection to the primary node in the client.
primary: ReconnectingConnection,
inner: Arc<DropWrapper>,
}

impl ClientCMD {
Expand All @@ -27,27 +39,80 @@ impl ClientCMD {
tls_mode,
)
.await?;

Ok(Self { primary })
Self::start_heartbeat(primary.clone());
Ok(Self {
inner: Arc::new(DropWrapper { primary }),
})
}

pub async fn send_packed_command(
&mut self,
cmd: &redis::Cmd,
) -> redis::RedisResult<redis::Value> {
self.primary.send_packed_command(cmd).await
log_trace("ClientCMD", "sending command");
let mut connection = self.inner.primary.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) async fn send_packed_commands(
pub async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisResult<Vec<redis::Value>> {
self.primary.send_packed_commands(cmd, offset, count).await
let mut connection = self.inner.primary.get_connection().await?;
let result = connection.send_packed_commands(cmd, offset, count).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) fn get_db(&self) -> i64 {
self.primary.get_db()
self.inner.primary.get_db()
}

/// Spawns a detached background task that periodically checks the health of
/// `reconnecting_connection`.
///
/// Every cycle the task sleeps for `HEARTBEAT_SLEEP_DURATION` and then:
/// * exits if the connection has been marked as dropped;
/// * skips the cycle if no connection is currently available (a reconnect is in progress);
/// * otherwise sends `PING`, and triggers a reconnect if the command fails with a
///   dropped-connection or connection-refusal error.
fn start_heartbeat(reconnecting_connection: ReconnectingConnection) {
    task::spawn(async move {
        loop {
            tokio::time::sleep(super::HEARTBEAT_SLEEP_DURATION).await;
            if reconnecting_connection.is_dropped() {
                log_debug(
                    "ClientCMD",
                    "heartbeat stopped after connection was dropped",
                );
                // Client was dropped, heartbeat can stop.
                return;
            }

            let Some(mut connection) = reconnecting_connection.try_get_connection().await else {
                // The heartbeat keeps running; only this cycle is skipped, so the message
                // says "skipped" rather than "stopped".
                log_debug(
                    "ClientCMD",
                    "heartbeat skipped while connection is reconnecting",
                );
                // Client is reconnecting; check again on the next cycle.
                continue;
            };
            log_debug("ClientCMD", "performing heartbeat");
            let ping_result = connection.send_packed_command(&redis::cmd("PING")).await;
            // Only errors that indicate a broken connection trigger a reconnect;
            // any other error is ignored by the heartbeat.
            if ping_result
                .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal())
            {
                log_debug("ClientCMD", "heartbeat triggered reconnect");
                reconnecting_connection.reconnect().await;
            }
        }
    });
}
}
2 changes: 2 additions & 0 deletions babushka-core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::time::Duration;
mod client_cmd;
mod reconnecting_connection;

/// Time the client's heartbeat task sleeps between consecutive connection checks.
pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1);

pub trait BabushkaClient: ConnectionLike + Send + Clone {}

impl BabushkaClient for MultiplexedConnection {}
Expand Down
133 changes: 50 additions & 83 deletions babushka-core/src/client/reconnecting_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,9 @@ struct InnerReconnectingConnection {
backend: ConnectionBackend,
}

/// The separation between an inner and outer connection is because the outer connection is clonable, and the inner connection needs to be dropped when no outer connection exists.
struct DropWrapper(Arc<InnerReconnectingConnection>);

impl Drop for DropWrapper {
fn drop(&mut self) {
self.0
.backend
.client_dropped_flagged
.store(true, Ordering::Relaxed);
}
}

#[derive(Clone)]
pub(super) struct ReconnectingConnection {
/// All of the connection's clones point to the same internal wrapper, which will be dropped only once,
/// when all of the clones have been dropped.
inner: Arc<DropWrapper>,
inner: Arc<InnerReconnectingConnection>,
}

async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult<MultiplexedConnection> {
Expand All @@ -75,10 +61,10 @@ async fn try_create_connection(

let connection = Retry::spawn(retry_strategy.get_iterator(), action).await?;
Ok(ReconnectingConnection {
inner: Arc::new(DropWrapper(Arc::new(InnerReconnectingConnection {
inner: Arc::new(InnerReconnectingConnection {
state: Mutex::new(ConnectionState::Connected(connection)),
backend: connection_backend,
}))),
}),
})
}

Expand Down Expand Up @@ -112,7 +98,7 @@ impl ReconnectingConnection {
connection_retry_strategy: RetryStrategy,
redis_connection_info: RedisConnectionInfo,
tls_mode: TlsMode,
) -> RedisResult<Self> {
) -> RedisResult<ReconnectingConnection> {
log_debug(
"connection creation",
format!("Attempting connection to {address}"),
Expand All @@ -131,29 +117,44 @@ impl ReconnectingConnection {
Ok(connection)
}

async fn get_connection(&self) -> Result<MultiplexedConnection, RedisError> {
/// Returns `true` once the dropped flag has been set through `mark_as_dropped`.
pub(super) fn is_dropped(&self) -> bool {
    let dropped_flag = &self.inner.backend.client_dropped_flagged;
    dropped_flag.load(Ordering::Relaxed)
}

/// Sets the dropped flag. The flag lives in the shared inner state, so `is_dropped`
/// reports `true` afterwards on every clone of this connection.
pub(super) fn mark_as_dropped(&self) {
    let dropped_flag = &self.inner.backend.client_dropped_flagged;
    dropped_flag.store(true, Ordering::Relaxed);
}

/// Returns a clone of the underlying connection when the state is `Connected`,
/// and `None` for any other state. Only the state lock is awaited.
pub(super) async fn try_get_connection(&self) -> Option<MultiplexedConnection> {
    let state = self.inner.state.lock().await;
    match &*state {
        ConnectionState::Connected(live) => Some(live.clone()),
        _ => None,
    }
}

pub(super) async fn get_connection(&self) -> Result<MultiplexedConnection, RedisError> {
loop {
self.inner
.0
.backend
.connection_available_signal
.wait()
.await;
{
let guard = self.inner.0.state.lock().await;
if let ConnectionState::Connected(connection) = &*guard {
return Ok(connection.clone());
}
};
self.inner.backend.connection_available_signal.wait().await;
if let Some(connection) = self.try_get_connection().await {
return Ok(connection);
}
}
}

async fn reconnect(&self) {
pub(super) async fn reconnect(&self) {
{
let mut guard = self.inner.0.state.lock().await;
let mut guard = self.inner.state.lock().await;
match &*guard {
ConnectionState::Connected(_) => {
self.inner.0.backend.connection_available_signal.reset();
self.inner.backend.connection_available_signal.reset();
}
_ => {
log_trace("reconnect", "already started");
Expand All @@ -164,18 +165,14 @@ impl ReconnectingConnection {
*guard = ConnectionState::Reconnecting;
};
log_debug("reconnect", "starting");
let inner_connection_clone = self.inner.0.clone();
let connection_clone = self.clone();
// The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the
// background, regardless of whether the calling task is dropped or not.
task::spawn(async move {
let client = &inner_connection_clone.backend.connection_info;
let client = &connection_clone.inner.backend.connection_info;
for sleep_duration in internal_retry_iterator() {
if inner_connection_clone
.backend
.client_dropped_flagged
.load(Ordering::Relaxed)
{
log_trace(
if connection_clone.is_dropped() {
log_debug(
"ReconnectingConnection",
"reconnect stopped after client was dropped",
);
Expand All @@ -185,13 +182,16 @@ impl ReconnectingConnection {
log_debug("connection creation", "Creating multiplexed connection");
match get_multiplexed_connection(client).await {
Ok(connection) => {
let mut guard = inner_connection_clone.state.lock().await;
log_debug("reconnect", "completed succesfully");
inner_connection_clone
.backend
.connection_available_signal
.set();
*guard = ConnectionState::Connected(connection);
{
let mut guard = connection_clone.inner.state.lock().await;
log_debug("reconnect", "completed succesfully");
connection_clone
.inner
.backend
.connection_available_signal
.set();
*guard = ConnectionState::Connected(connection);
}
return;
}
Err(_) => tokio::time::sleep(sleep_duration).await,
Expand All @@ -200,41 +200,8 @@ impl ReconnectingConnection {
});
}

pub(super) async fn send_packed_command(
&mut self,
cmd: &redis::Cmd,
) -> redis::RedisResult<redis::Value> {
log_trace("ReconnectingConnection", "sending command");
let mut connection = self.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisResult<Vec<redis::Value>> {
let mut connection = self.get_connection().await?;
let result = connection.send_packed_commands(cmd, offset, count).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) fn get_db(&self) -> i64 {
let guard = self.inner.0.state.blocking_lock();
let guard = self.inner.state.blocking_lock();
match &*guard {
ConnectionState::Connected(connection) => connection.get_db(),
_ => -1,
Expand Down
38 changes: 38 additions & 0 deletions babushka-core/tests/test_client_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod client_cmd_tests {
let address = server.get_client_addr();
drop(server);

// We use another thread so that creating the server won't block the client's work.
let thread = std::thread::spawn(move || {
block_on_all(async move {
let mut get_command = redis::Cmd::new();
Expand Down Expand Up @@ -53,4 +54,41 @@ mod client_cmd_tests {
thread.join().unwrap();
});
}

/// Checks that a command succeeds after the server was dropped and re-created at the same
/// address, with no client command sent in between.
// NOTE(review): presumably the heartbeat is what detects the disconnect here, as the test name says - confirm.
#[rstest]
#[timeout(LONG_CMD_TEST_TIMEOUT)]
fn test_detect_disconnect_and_reconnect_using_heartbeat(#[values(false, true)] use_tls: bool) {
// Channel used to hand the re-created server back from the helper thread.
let (sender, receiver) = tokio::sync::oneshot::channel();
block_on_all(async move {
let mut test_basics = setup_test_basics(use_tls).await;
let server = test_basics.server;
// Remember the address so the replacement server can be created on it.
let address = server.get_client_addr();
println!("dropping server");
drop(server);

// We use another thread so that creating the server won't block the client's work.
std::thread::spawn(move || {
block_on_all(async move {
let new_server = RedisServer::new_with_addr_and_modules(address.clone(), &[]);
wait_for_server_to_become_ready(&address).await;
let _ = sender.send(new_server);
})
});

// Bound to a named variable so the received value is kept until the end of the scope.
let _new_server = receiver.await;
// Wait longer than one heartbeat interval before sending the command.
tokio::time::sleep(babushka::client::HEARTBEAT_SLEEP_DURATION + Duration::from_secs(1))
.await;

let mut get_command = redis::Cmd::new();
get_command
.arg("GET")
.arg("test_detect_disconnect_and_reconnect_using_heartbeat");
// The command is unwrapped directly: it must succeed on the first attempt.
let get_result = test_basics
.client
.send_packed_command(&get_command)
.await
.unwrap();
// The key was never set, so the expected reply is Nil.
assert_eq!(get_result, Value::Nil);
});
}
}

0 comments on commit 5f6cc1f

Please sign in to comment.