Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BrokerPool #4

Merged
merged 3 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ edition = "2021"
[dependencies]
async-trait = "0.1"
pin-project-lite = "0.2"
rand = "0.8"
thiserror = "1.0"
tokio = {version = "1.14", default-features = false, features = ["io-util", "net", "rt", "sync"]}
tokio = {version = "1.14", default-features = false, features = ["io-util", "net", "rt", "sync", "time"]}
varint-rs = "2.2"

[dev-dependencies]
Expand Down
113 changes: 113 additions & 0 deletions src/backoff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use rand::prelude::*;
use std::time::Duration;

/// Exponential backoff with jitter
///
/// See <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Initial backoff interval; also the lower bound of every sampled interval
    pub init_backoff: Duration,
    /// Upper cap applied to every sampled interval
    pub max_backoff: Duration,
    /// Multiplier applied to the previous interval to form the sampling range
    pub base: f64,
}

impl Default for BackoffConfig {
fn default() -> Self {
Self {
init_backoff: Duration::from_secs(1),
max_backoff: Duration::from_secs(500),
base: 3.,
}
}
}

/// [`Backoff`] can be created from a [`BackoffConfig`]
///
/// Consecutive calls to [`Backoff::next`] will return the next backoff interval
///
#[derive(Debug)]
pub struct Backoff<R> {
    /// Lower bound (seconds) for every sampled interval
    init_backoff: f64,
    /// Interval (seconds) the next call to [`Backoff::next`] will return
    next_backoff_secs: f64,
    /// Cap (seconds) applied to each newly sampled interval
    max_backoff_secs: f64,
    /// Multiplier widening the sampling range on each step
    base: f64,
    /// Randomness source used for the jitter
    rng: R,
}

impl Backoff<ThreadRng> {
    /// Create a new [`Backoff`] from the provided [`BackoffConfig`],
    /// drawing jitter from the thread-local random number generator.
    pub fn new(config: &BackoffConfig) -> Self {
        let rng = thread_rng();
        Self::new_with_rng(config, rng)
    }
}

impl<R: Rng> Backoff<R> {
    /// Create a new [`Backoff`] from `config`, drawing jitter from `rng`
    pub fn new_with_rng(config: &BackoffConfig, rng: R) -> Self {
        let init_backoff = config.init_backoff.as_secs_f64();
        Self {
            init_backoff,
            next_backoff_secs: init_backoff,
            max_backoff_secs: config.max_backoff.as_secs_f64(),
            base: config.base,
            rng,
        }
    }

    /// Returns the next backoff duration to wait for
    ///
    /// Samples uniformly from `[init_backoff, previous * base)` and caps the
    /// result at `max_backoff` ("decorrelated jitter").
    pub fn next(&mut self) -> Duration {
        let range_end = self.next_backoff_secs * self.base;
        // `gen_range` panics on an empty range, which a degenerate config can
        // produce (e.g. a zero `init_backoff`, or `base` < 1). Fall back to
        // the lower bound instead of panicking.
        let next_backoff = if range_end > self.init_backoff {
            self.max_backoff_secs
                .min(self.rng.gen_range(self.init_backoff..range_end))
        } else {
            self.max_backoff_secs.min(self.init_backoff)
        };
        Duration::from_secs_f64(std::mem::replace(&mut self.next_backoff_secs, next_backoff))
    }
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
}
// TODO: tests (either statistical or by using a fixed RNG)


#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::mock::StepRng;

    #[test]
    fn test_backoff() {
        let init_backoff_secs = 1.;
        let max_backoff_secs = 500.;
        let base = 3.;

        let config = BackoffConfig {
            init_backoff: Duration::from_secs_f64(init_backoff_secs),
            max_backoff: Duration::from_secs_f64(max_backoff_secs),
            base,
        };

        // Float comparison with a small absolute tolerance
        let assert_fuzzy_eq = |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{} != {}", a, b);

        // Create a static rng that takes the minimum of the range
        let rng = StepRng::new(0, 0);
        let mut backoff = Backoff::new_with_rng(&config, rng);

        // Sampling the minimum keeps the interval pinned at the initial value
        for _ in 0..20 {
            assert_eq!(backoff.next().as_secs_f64(), init_backoff_secs);
        }

        // Create a static rng that takes the maximum of the range
        let rng = StepRng::new(u64::MAX, 0);
        let mut backoff = Backoff::new_with_rng(&config, rng);

        // Sampling the maximum grows the interval geometrically (base^i)
        // until it hits the cap
        for i in 0..20 {
            let value = (base.powi(i) * init_backoff_secs).min(max_backoff_secs);
            assert_fuzzy_eq(backoff.next().as_secs_f64(), value);
        }

        // Create a static rng that takes the mid point of the range
        let rng = StepRng::new(u64::MAX / 2, 0);
        let mut backoff = Backoff::new_with_rng(&config, rng);

        // Sampling the midpoint yields the midpoint of
        // [init_backoff, previous * base], clamped to the cap
        let mut value = init_backoff_secs;
        for _ in 0..20 {
            assert_fuzzy_eq(backoff.next().as_secs_f64(), value);
            value =
                (init_backoff_secs + (value * base - init_backoff_secs) / 2.).min(max_backoff_secs);
        }
    }
}
71 changes: 7 additions & 64 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,16 @@
use std::collections::HashMap;

use tokio::{
io::BufStream,
net::{TcpStream, ToSocketAddrs},
};

use crate::{
messenger::Messenger,
protocol::{
api_key::ApiKey,
api_version::ApiVersion,
messages::{ApiVersionsRequest, RequestBody},
primitives::{CompactString, Int16, TaggedFields},
},
};
use crate::connection::BrokerPool;

pub struct Client {
#[allow(dead_code)]
messenger: Messenger<BufStream<TcpStream>>,
brokers: BrokerPool,
}

impl Client {
pub async fn new<A>(addr: A) -> Self
where
A: ToSocketAddrs,
{
let stream = TcpStream::connect(addr).await.unwrap();
let stream = BufStream::new(stream);
let messenger = Messenger::new(stream);
sync_versions(&messenger).await;
/// Create a new [`Client`] with the list of bootstrap brokers
pub async fn new(boostrap_brokers: Vec<String>) -> Self {
let mut brokers = BrokerPool::new(boostrap_brokers);
brokers.refresh_metadata().await.unwrap();

Self { messenger }
Self { brokers }
}
}

async fn sync_versions(messenger: &Messenger<BufStream<TcpStream>>) {
for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.0 .0 .0
..=ApiVersionsRequest::API_VERSION_RANGE.1 .0 .0)
.rev()
{
messenger.set_version_ranges(HashMap::from([(
ApiKey::ApiVersions,
(
ApiVersionsRequest::API_VERSION_RANGE.0,
ApiVersion(Int16(upper_bound)),
),
)]));

let body = ApiVersionsRequest {
client_software_name: CompactString(String::from("")),
client_software_version: CompactString(String::from("")),
tagged_fields: TaggedFields::default(),
};

if let Ok(response) = messenger.request(body).await {
if response.error_code.is_some() {
continue;
}

// TODO: check min and max are sane
let ranges = response
.api_keys
.into_iter()
.map(|x| (x.api_key, (x.min_version, x.max_version)))
.collect();
messenger.set_version_ranges(ranges);
return;
}
}

panic!("cannot sync")
}
99 changes: 99 additions & 0 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::sync::Arc;

use thiserror::Error;
use tokio::io::BufStream;
use tokio::net::TcpStream;
use tokio::sync::Mutex;

use crate::backoff::{Backoff, BackoffConfig};
use crate::messenger::Messenger;

/// A connection to a broker
///
/// Cheap to clone: the underlying [`Messenger`] is shared behind an [`Arc`].
pub type BrokerConnection = Arc<Messenger<BufStream<TcpStream>>>;

/// Errors that can occur when interacting with the broker pool
///
/// Currently has no variants; they will be added as fallible operations appear.
#[derive(Debug, Error)]
pub enum Error {}

/// Result alias defaulting the error type to this module's [`Error`]
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Maintains a list of brokers within the cluster and caches a connection to a broker
pub struct BrokerPool {
    /// Brokers used for the initial connection before any metadata is known
    bootstrap_brokers: Vec<String>,
    /// Discovered brokers in the cluster, including bootstrap brokers
    discovered_brokers: Vec<String>,
    /// The current cached broker
    current_broker: Mutex<Option<BrokerConnection>>,
    /// The backoff configuration on error
    backoff_config: BackoffConfig,
}

impl BrokerPool {
pub fn new(bootstrap_brokers: Vec<String>) -> Self {
Self {
bootstrap_brokers,
discovered_brokers: vec![],
current_broker: Mutex::new(None),
backoff_config: Default::default(),
}
}

/// Fetch and cache broker metadata
pub async fn refresh_metadata(&mut self) -> Result<()> {
self.get_cached_broker().await?;

//TODO: Get broker list
Ok(())
}

/// Invalidates the current cached broker
///
/// The next call to `[BrokerPool::get_cached_broker]` will get a new connection
#[allow(dead_code)]
pub async fn invalidate_cached_broker(&self) {
self.current_broker.lock().await.take();
}

/// Gets a cached [`BrokerConnection`] to any broker
pub async fn get_cached_broker(&self) -> Result<BrokerConnection> {
let mut current_broker = self.current_broker.lock().await;
if let Some(broker) = &*current_broker {
return Ok(Arc::clone(broker));
}

let brokers = if self.discovered_brokers.is_empty() {
&self.bootstrap_brokers
} else {
&self.discovered_brokers
};
Comment on lines +64 to +68
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about chaining this these two lists (discovered first, then falling back to bootstrap) in case the cluster got into some weird state (e.g. you have two brokers, the bootstrap broker doesn't report itself and after a while comes back and the discovered broker dies).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The discovered brokers should include the bootstrap brokers - will add docs to clarify


let mut backoff = Backoff::new(&self.backoff_config);

loop {
for broker in brokers {
let stream = match TcpStream::connect(&broker).await {
Ok(stream) => stream,
Err(e) => {
println!("Error connecting to broker {}: {}", broker, e);
continue;
}
};

let stream = BufStream::new(stream);
let messenger = Arc::new(Messenger::new(stream));
messenger.sync_versions().await;

*current_broker = Some(Arc::clone(&messenger));

return Ok(messenger);
}

let backoff = backoff.next();
println!(
"Failed to connect to any broker, backing off for {} seconds",
backoff.as_secs()
);
tokio::time::sleep(backoff).await;
}
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/// Exponential backoff with jitter, used when retrying broker connections
mod backoff;
/// Public client entry point
pub mod client;
/// Broker connection pooling and caching
mod connection;
/// Low-level request/response messaging with a single broker
mod messenger;
/// Kafka wire-protocol types
mod protocol;
50 changes: 47 additions & 3 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use tokio::{
task::JoinHandle,
};

use crate::protocol::messages::ApiVersionsRequest;
use crate::protocol::primitives::CompactString;
use crate::protocol::{
api_key::ApiKey,
api_version::ApiVersion,
Expand All @@ -34,6 +36,10 @@ struct Response {
data: Cursor<Vec<u8>>,
}

/// A connection to a single broker
///
/// Note: Requests to the same [`Messenger`] will be pipelined by Kafka
///
pub struct Messenger<RW> {
stream_write: Mutex<WriteHalf<RW>>,
correlation_id: AtomicI32,
Expand Down Expand Up @@ -101,8 +107,8 @@ where
}
}

pub fn set_version_ranges(&self, ranges: HashMap<ApiKey, (ApiVersion, ApiVersion)>) {
*self.version_ranges.write().expect("lock poissened") = ranges;
fn set_version_ranges(&self, ranges: HashMap<ApiKey, (ApiVersion, ApiVersion)>) {
*self.version_ranges.write().expect("lock poisoned") = ranges;
}

pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, RequestError>
Expand All @@ -113,7 +119,7 @@ where
let body_api_version = self
.version_ranges
.read()
.expect("lock poissened")
.expect("lock poisoned")
.get(&R::API_KEY)
.map(|range_server| match_versions(*range_server, R::API_VERSION_RANGE))
.flatten()
Expand Down Expand Up @@ -150,6 +156,44 @@ where

Ok(body)
}

/// Negotiate API versions with the broker
///
/// Walks the supported `ApiVersions` request versions from highest to
/// lowest; for the first version the broker answers without an error code,
/// stores the broker-reported version ranges for all API keys.
///
/// # Panics
/// Panics if the broker accepts none of our `ApiVersions` request versions.
pub async fn sync_versions(&self) {
    for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.0 .0 .0
        ..=ApiVersionsRequest::API_VERSION_RANGE.1 .0 .0)
        .rev()
    {
        // Temporarily pin the ApiVersions key to the candidate upper bound
        // so the probe request below is encoded at that version.
        self.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            (
                ApiVersionsRequest::API_VERSION_RANGE.0,
                ApiVersion(Int16(upper_bound)),
            ),
        )]));

        let body = ApiVersionsRequest {
            client_software_name: CompactString(String::from("")),
            client_software_version: CompactString(String::from("")),
            tagged_fields: TaggedFields::default(),
        };

        if let Ok(response) = self.request(body).await {
            // Broker rejected this version; try the next lower one.
            if response.error_code.is_some() {
                continue;
            }

            // TODO: check min and max are sane
            let ranges = response
                .api_keys
                .into_iter()
                .map(|x| (x.api_key, (x.min_version, x.max_version)))
                .collect();
            self.set_version_ranges(ranges);
            return;
        }
    }

    panic!("cannot sync")
}
}

impl<RW> Drop for Messenger<RW> {
Expand Down
Loading