Skip to content

Commit

Permalink
Move heartbeat logic to ClientCMD.
Browse files Browse the repository at this point in the history
  • Loading branch information
shachlanAmazon committed Jun 25, 2023
1 parent 73c7094 commit c716cbb
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 132 deletions.
85 changes: 75 additions & 10 deletions babushka-core/src/client/client_cmd.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
use super::get_redis_connection_info;
use super::reconnecting_connection::ReconnectingConnection;
use crate::connection_request::{ConnectionRequest, TlsMode};
use crate::retry_strategies::RetryStrategy;
use logger_core::{log_debug, log_trace};
use redis::RedisResult;
use std::sync::Arc;
use tokio::task;

use super::get_redis_connection_info;
use super::reconnecting_connection::ReconnectingConnection;
/// Owner of the client's primary connection whose only job is to react to being dropped:
/// its `Drop` implementation marks the connection as dropped. `ClientCMD` keeps this
/// behind an `Arc`, so that happens once, when the last handle to it is released.
struct DropWrapper {
/// Connection to the primary node in the client.
primary: ReconnectingConnection,
}

impl Drop for DropWrapper {
/// Flags the primary connection as dropped. The heartbeat task holds its own clone of
/// the connection and polls this flag, so setting it here is what lets that task exit
/// once the wrapper is gone.
fn drop(&mut self) {
self.primary.mark_as_dropped();
}
}

#[derive(Clone)]
pub struct ClientCMD {
/// Connection to the primary node in the client.
primary: ReconnectingConnection,
inner: Arc<DropWrapper>,
}

impl ClientCMD {
Expand All @@ -27,27 +39,80 @@ impl ClientCMD {
tls_mode,
)
.await?;

Ok(Self { primary })
Self::start_heartbeat(primary.clone());
Ok(Self {
inner: Arc::new(DropWrapper { primary }),
})
}

pub async fn send_packed_command(
&mut self,
cmd: &redis::Cmd,
) -> redis::RedisResult<redis::Value> {
self.primary.send_packed_command(cmd).await
log_trace("ClientCMD", "sending command");
let mut connection = self.inner.primary.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) async fn send_packed_commands(
pub async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisResult<Vec<redis::Value>> {
self.primary.send_packed_commands(cmd, offset, count).await
let mut connection = self.inner.primary.get_connection().await?;
let result = connection.send_packed_commands(cmd, offset, count).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect().await;
Err(err)
}
_ => result,
}
}

pub(super) fn get_db(&self) -> i64 {
self.primary.get_db()
self.inner.primary.get_db()
}

/// Spawns a background task that periodically pings the server over
/// `reconnecting_connection` and triggers a reconnect when the ping fails.
///
/// The task's join handle is discarded, so the task keeps running on its own. On every
/// iteration it:
/// 1. sleeps for `HEARTBEAT_SLEEP_DURATION`;
/// 2. exits if the connection has been marked as dropped;
/// 3. skips the iteration if no connection is currently available (a reconnect is in
///    progress), and tries again after the next sleep;
/// 4. otherwise sends `PING` and calls `reconnect()` if the error reports a dropped or
///    refused connection. Other errors are ignored.
fn start_heartbeat(reconnecting_connection: ReconnectingConnection) {
    task::spawn(async move {
        loop {
            tokio::time::sleep(super::HEARTBEAT_SLEEP_DURATION).await;
            if reconnecting_connection.is_dropped() {
                log_debug(
                    "ClientCMD",
                    "heartbeat stopped after connection was dropped",
                );
                // Client was dropped, heartbeat can stop.
                return;
            }

            let Some(mut connection) = reconnecting_connection.try_get_connection().await else {
                // The heartbeat is NOT stopped here: this beat is skipped and the loop
                // checks again after the next sleep, so the message says "skipped".
                log_debug(
                    "ClientCMD",
                    "heartbeat skipped while connection is reconnecting",
                );
                continue;
            };
            log_debug("ClientCMD", "performing heartbeat");
            if connection
                .send_packed_command(&redis::cmd("PING"))
                .await
                .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal())
            {
                log_debug("ClientCMD", "heartbeat triggered reconnect");
                reconnecting_connection.reconnect().await;
            }
        }
    });
}
}
3 changes: 2 additions & 1 deletion babushka-core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use std::io;
use std::time::Duration;
mod client_cmd;
mod reconnecting_connection;
pub use reconnecting_connection::HEARTBEAT_SLEEP_DURATION;

pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1);

pub trait BabushkaClient: ConnectionLike + Send + Clone {}

Expand Down
161 changes: 40 additions & 121 deletions babushka-core/src/client/reconnecting_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,40 +36,9 @@ struct InnerReconnectingConnection {
backend: ConnectionBackend,
}

pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1);

impl InnerReconnectingConnection {
fn is_dropped(&self) -> bool {
self.backend.client_dropped_flagged.load(Ordering::Relaxed)
}

async fn try_get_connection(&self) -> Option<MultiplexedConnection> {
let guard = self.state.lock().await;
if let ConnectionState::Connected(connection) = &*guard {
Some(connection.clone())
} else {
None
}
}
}

/// The separation between an inner and outer connection is because the outer connection is clonable, and the inner connection needs to be dropped when no outer connection exists.
struct DropWrapper(Arc<InnerReconnectingConnection>);

impl Drop for DropWrapper {
fn drop(&mut self) {
self.0
.backend
.client_dropped_flagged
.store(true, Ordering::Relaxed);
}
}

#[derive(Clone)]
pub(super) struct ReconnectingConnection {
/// All of the connection's clones point to the same internal wrapper, which will be dropped only once,
/// when all of the clones have been dropped.
inner: Arc<DropWrapper>,
inner: Arc<InnerReconnectingConnection>,
}

async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult<MultiplexedConnection> {
Expand All @@ -92,10 +61,10 @@ async fn try_create_connection(

let connection = Retry::spawn(retry_strategy.get_iterator(), action).await?;
Ok(ReconnectingConnection {
inner: Arc::new(DropWrapper(Arc::new(InnerReconnectingConnection {
inner: Arc::new(InnerReconnectingConnection {
state: Mutex::new(ConnectionState::Connected(connection)),
backend: connection_backend,
}))),
}),
})
}

Expand Down Expand Up @@ -129,7 +98,7 @@ impl ReconnectingConnection {
connection_retry_strategy: RetryStrategy,
redis_connection_info: RedisConnectionInfo,
tls_mode: TlsMode,
) -> RedisResult<Self> {
) -> RedisResult<ReconnectingConnection> {
log_debug(
"connection creation",
format!("Attempting connection to {address}"),
Expand All @@ -141,68 +110,51 @@ impl ReconnectingConnection {
client_dropped_flagged: AtomicBool::new(false),
};
let connection = try_create_connection(client, connection_retry_strategy).await?;
Self::start_heartbeat(connection.inner.0.clone());
log_debug(
"connection creation",
format!("Connection to {address} created"),
);
Ok(connection)
}

async fn get_connection(&self) -> Result<MultiplexedConnection, RedisError> {
loop {
self.inner
.0
.backend
.connection_available_signal
.wait()
.await;
if let Some(connection) = self.inner.0.try_get_connection().await {
return Ok(connection);
}
}
/// Returns `true` once the connection has been flagged as dropped
/// (reads the backend's `client_dropped_flagged` atomic with relaxed ordering).
pub(super) fn is_dropped(&self) -> bool {
    let dropped_flag = &self.inner.backend.client_dropped_flagged;
    dropped_flag.load(Ordering::Relaxed)
}

fn start_heartbeat(reconnecting_connection: Arc<InnerReconnectingConnection>) {
task::spawn(async move {
loop {
tokio::time::sleep(HEARTBEAT_SLEEP_DURATION).await;
if reconnecting_connection.is_dropped() {
log_debug(
"ReconnectingConnection",
"heartbeat stopped after client was dropped",
);
// Client was dropped, heartbeat can stop.
return;
}
/// Flags the connection as dropped by setting the backend's `client_dropped_flagged`
/// atomic to `true`; subsequent `is_dropped()` calls on any clone observe the flag.
pub(super) fn mark_as_dropped(&self) {
    let dropped_flag = &self.inner.backend.client_dropped_flagged;
    dropped_flag.store(true, Ordering::Relaxed)
}

let Some(mut connection) = reconnecting_connection.try_get_connection().await else {
log_debug(
"ReconnectingConnection",
"heartbeat stopped while client is reconnecting",
);
// Client is reconnecting, heartbeat can stop. It will be restarted by the reconnect attempt once it succeeds.
return;
};
log_debug("ReconnectingConnection", "performing heartbeat");
if connection
.req_packed_command(&redis::cmd("PING"))
.await
.is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal())
{
log_debug("ReconnectingConnection", "heartbeat triggered reconnect");
Self::reconnect(&reconnecting_connection).await;
}
/// Returns a clone of the underlying multiplexed connection if the state is currently
/// `Connected`, or `None` for any other state. Waits only for the state lock, not for a
/// connection to become available.
pub(super) async fn try_get_connection(&self) -> Option<MultiplexedConnection> {
    let state = self.inner.state.lock().await;
    match &*state {
        ConnectionState::Connected(active) => Some(active.clone()),
        _ => None,
    }
}

pub(super) async fn get_connection(&self) -> Result<MultiplexedConnection, RedisError> {
loop {
self.inner.backend.connection_available_signal.wait().await;
if let Some(connection) = self.try_get_connection().await {
return Ok(connection);
}
});
}
}

async fn reconnect(connection: &Arc<InnerReconnectingConnection>) {
pub(super) async fn reconnect(&self) {
{
let mut guard = connection.state.lock().await;
let mut guard = self.inner.state.lock().await;
match &*guard {
ConnectionState::Connected(_) => {
connection.backend.connection_available_signal.reset();
self.inner.backend.connection_available_signal.reset();
}
_ => {
log_trace("reconnect", "already started");
Expand All @@ -213,13 +165,13 @@ impl ReconnectingConnection {
*guard = ConnectionState::Reconnecting;
};
log_debug("reconnect", "starting");
let inner_connection_clone = connection.clone();
let connection_clone = self.clone();
// The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the
// background, regardless of whether the calling task is dropped or not.
task::spawn(async move {
let client = &inner_connection_clone.backend.connection_info;
let client = &connection_clone.inner.backend.connection_info;
for sleep_duration in internal_retry_iterator() {
if inner_connection_clone.is_dropped() {
if connection_clone.is_dropped() {
log_debug(
"ReconnectingConnection",
"reconnect stopped after client was dropped",
Expand All @@ -231,15 +183,15 @@ impl ReconnectingConnection {
match get_multiplexed_connection(client).await {
Ok(connection) => {
{
let mut guard = inner_connection_clone.state.lock().await;
let mut guard = connection_clone.inner.state.lock().await;
log_debug("reconnect", "completed succesfully");
inner_connection_clone
connection_clone
.inner
.backend
.connection_available_signal
.set();
*guard = ConnectionState::Connected(connection);
}
Self::start_heartbeat(inner_connection_clone);
return;
}
Err(_) => tokio::time::sleep(sleep_duration).await,
Expand All @@ -248,41 +200,8 @@ impl ReconnectingConnection {
});
}

pub(super) async fn send_packed_command(
&mut self,
cmd: &redis::Cmd,
) -> redis::RedisResult<redis::Value> {
log_trace("ReconnectingConnection", "sending command");
let mut connection = self.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_connection_dropped() => {
Self::reconnect(&self.inner.0).await;
Err(err)
}
_ => result,
}
}

pub(super) async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisResult<Vec<redis::Value>> {
let mut connection = self.get_connection().await?;
let result = connection.send_packed_commands(cmd, offset, count).await;
match result {
Err(err) if err.is_connection_dropped() => {
Self::reconnect(&self.inner.0).await;
Err(err)
}
_ => result,
}
}

pub(super) fn get_db(&self) -> i64 {
let guard = self.inner.0.state.blocking_lock();
let guard = self.inner.state.blocking_lock();
match &*guard {
ConnectionState::Connected(connection) => connection.get_db(),
_ => -1,
Expand Down

0 comments on commit c716cbb

Please sign in to comment.