diff --git a/README.md b/README.md index 2ae15a71..f07d37ef 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,8 @@ let client = Client::new_plain(vec![connection]).await.unwrap(); // create a topic let topic = "my_topic"; -client.create_topic( +let controller_client = client.controller_client().await.unwrap(); +controller_client.create_topic( topic, 2, // partitions 1, // replication factor diff --git a/src/client/controller.rs b/src/client/controller.rs new file mode 100644 index 00000000..f37e411b --- /dev/null +++ b/src/client/controller.rs @@ -0,0 +1,169 @@ +use std::sync::Arc; + +use tokio::sync::Mutex; +use tracing::{debug, error, info}; + +use crate::{ + backoff::{Backoff, BackoffConfig}, + client::{Error, Result}, + connection::{BrokerConnection, BrokerConnector}, + messenger::RequestError, + protocol::{ + error::Error as ProtocolError, + messages::{CreateTopicRequest, CreateTopicsRequest}, + primitives::{Int16, Int32, NullableString, String_}, + }, +}; + +#[derive(Debug)] +pub struct ControllerClient { + brokers: Arc, + + backoff_config: BackoffConfig, + + /// Cached connection to the controller broker, if any + current_broker: Mutex>, +} + +impl ControllerClient { + pub(super) fn new(brokers: Arc) -> Self { + Self { + brokers, + backoff_config: Default::default(), + current_broker: Mutex::new(None), + } + } + + /// Create a topic via the cluster controller, retrying on non-fatal errors + pub async fn create_topic( + &self, + name: impl Into + Send, + num_partitions: i32, + replication_factor: i16, + ) -> Result<()> { + let request = &CreateTopicsRequest { + topics: vec![CreateTopicRequest { + name: String_(name.into()), + num_partitions: Int32(num_partitions), + replication_factor: Int16(replication_factor), + assignments: vec![], + configs: vec![], + tagged_fields: None, + }], + // TODO: Expose as configuration parameter + timeout_ms: Int32(5_000), + validate_only: None, + tagged_fields: None, + }; + + self.maybe_retry("create_topic", || async move { + let broker = self.get_cached_controller_broker().await?; + let response = 
broker.request(request).await?; + + if response.topics.len() != 1 { + return Err(Error::InvalidResponse(format!( + "Expected a single topic in response, got {}", + response.topics.len() + ))); + } + + let topic = response.topics.into_iter().next().unwrap(); + + match topic.error { + None => Ok(()), + Some(protocol_error) => match topic.error_message { + Some(NullableString(Some(msg))) => Err(Error::ServerError(protocol_error, msg)), + _ => Err(Error::ServerError(protocol_error, Default::default())), + }, + } + }) + .await + } + + /// Takes a `request_name` and a function yielding a fallible future + /// and handles certain classes of error + async fn maybe_retry(&self, request_name: &str, f: R) -> Result + where + R: (Fn() -> F) + Send + Sync, + F: std::future::Future> + Send, + { + let mut backoff = Backoff::new(&self.backoff_config); + + loop { + let error = match f().await { + Ok(v) => return Ok(v), + Err(e) => e, + }; + + match error { + Error::Request(RequestError::Poisoned(_) | RequestError::IO(_)) + | Error::Connection(_) => self.invalidate_cached_controller_broker().await, + Error::ServerError(ProtocolError::LeaderNotAvailable, _) => {} + Error::ServerError(ProtocolError::OffsetNotAvailable, _) => {} + Error::ServerError(ProtocolError::NotController, _) => { + self.invalidate_cached_controller_broker().await; + } + _ => { + error!( + e=%error, + request_name, + "request encountered fatal error", + ); + return Err(error); + } + } + + let backoff = backoff.next(); + info!( + e=%error, + request_name, + backoff_secs=backoff.as_secs(), + "request encountered non-fatal error - backing off", + ); + tokio::time::sleep(backoff).await; + } + } + + /// Gets a cached [`BrokerConnection`] to any cluster controller. 
+ async fn get_cached_controller_broker(&self) -> Result { + let mut current_broker = self.current_broker.lock().await; + if let Some(broker) = &*current_broker { + return Ok(Arc::clone(broker)); + } + + info!("Creating new controller broker connection",); + + let controller_id = self + .get_controller_id(self.brokers.get_cached_arbitrary_broker().await?) + .await?; + let broker = self.brokers.connect(controller_id).await?.ok_or_else(|| { + Error::InvalidResponse(format!( + "Controller {} not found in metadata response", + controller_id + )) + })?; + + *current_broker = Some(Arc::clone(&broker)); + Ok(broker) + } + + /// Invalidates the cached controller broker. + /// + /// The next call to `[ControllerClient::get_cached_controller_broker]` will get a new connection + pub async fn invalidate_cached_controller_broker(&self) { + debug!("Invalidating cached controller broker"); + self.current_broker.lock().await.take(); + } + + /// Retrieve the broker ID of the controller + async fn get_controller_id(&self, broker: BrokerConnection) -> Result { + let metadata = self.brokers.request_metadata(broker, Some(vec![])).await?; + + let controller_id = metadata + .controller_id + .ok_or_else(|| Error::InvalidResponse("Leader is NULL".to_owned()))? 
+ .0; + + Ok(controller_id) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index b94cd094..f8fdac16 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,22 +3,20 @@ use std::sync::Arc; use thiserror::Error; use crate::{ - client::partition::PartitionClient, - connection::BrokerConnector, - protocol::{ - messages::{CreateTopicRequest, CreateTopicsRequest}, - primitives::*, - }, + client::partition::PartitionClient, connection::BrokerConnector, protocol::primitives::Boolean, topic::Topic, }; pub mod consumer; +pub mod controller; pub mod error; pub mod partition; pub mod producer; use error::{Error, Result}; +use self::controller::ControllerClient; + #[derive(Debug, Error)] pub enum ProduceError { #[error(transparent)] @@ -67,6 +65,11 @@ impl Client { Ok(Self { brokers }) } + /// Returns a client for performing certain cluster-wide operations. + pub async fn controller_client(&self) -> Result { + Ok(ControllerClient::new(Arc::clone(&self.brokers))) + } + /// Returns a client for performing operations on a specific partition pub async fn partition_client( &self, @@ -84,7 +87,7 @@ impl Client { pub async fn list_topics(&self) -> Result> { let response = self .brokers - .request_metadata(self.brokers.get_arbitrary_cached_broker().await?, None) + .request_metadata(self.brokers.get_cached_arbitrary_broker().await?, None) .await?; Ok(response @@ -101,47 +104,4 @@ impl Client { }) .collect()) } - - /// Create a topic - pub async fn create_topic( - &self, - name: impl Into + Send, - num_partitions: i32, - replication_factor: i16, - ) -> Result<()> { - let broker = self.brokers.get_arbitrary_cached_broker().await?; - let response = broker - .request(CreateTopicsRequest { - topics: vec![CreateTopicRequest { - name: String_(name.into()), - num_partitions: Int32(num_partitions), - replication_factor: Int16(replication_factor), - assignments: vec![], - configs: vec![], - tagged_fields: None, - }], - // TODO: Expose as configuration parameter - timeout_ms: 
Int32(5_000), - validate_only: None, - tagged_fields: None, - }) - .await?; - - if response.topics.len() != 1 { - return Err(Error::InvalidResponse(format!( - "Expected a single topic in response, got {}", - response.topics.len() - ))); - } - - let topic = response.topics.into_iter().next().unwrap(); - - match topic.error { - None => Ok(()), - Some(protocol_error) => match topic.error_message { - Some(NullableString(Some(msg))) => Err(Error::ServerError(protocol_error, msg)), - _ => Err(Error::ServerError(protocol_error, Default::default())), - }, - } - } } diff --git a/src/client/partition.rs b/src/client/partition.rs index 16e421c0..cc59f49a 100644 --- a/src/client/partition.rs +++ b/src/client/partition.rs @@ -441,7 +441,7 @@ impl PartitionClient { ); let leader = self - .get_leader(self.brokers.get_arbitrary_cached_broker().await?) + .get_leader(self.brokers.get_cached_arbitrary_broker().await?) .await?; let broker = self.brokers.connect(leader).await?.ok_or_else(|| { Error::InvalidResponse(format!( diff --git a/src/connection.rs b/src/connection.rs index 8d1ac370..a345f51e 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,6 @@ use rand::prelude::*; use std::sync::Arc; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; use thiserror::Error; use tokio::io::BufStream; @@ -50,8 +50,10 @@ pub struct BrokerConnector { /// Discovered brokers in the cluster, including bootstrap brokers topology: BrokerTopology, - /// The current cached broker - current_broker: Mutex>, + /// The cached arbitrary broker. + /// + /// This one is used for metadata queries. 
+ cached_arbitrary_broker: Mutex>, /// The backoff configuration on error backoff_config: BackoffConfig, @@ -72,7 +74,7 @@ impl BrokerConnector { Self { bootstrap_brokers, topology: Default::default(), - current_broker: Mutex::new(None), + cached_arbitrary_broker: Mutex::new(None), backoff_config: Default::default(), tls_config, max_message_size, @@ -82,7 +84,7 @@ impl BrokerConnector { /// Fetch and cache broker metadata pub async fn refresh_metadata(&self) -> Result<()> { self.request_metadata( - self.get_arbitrary_cached_broker().await?, + self.get_cached_arbitrary_broker().await?, // Not interested in topic metadata Some(vec![]), ) @@ -133,16 +135,18 @@ impl BrokerConnector { tokio::time::sleep(backoff).await; }; + // Since the metadata request contains information about the cluster state, use it to update our view. self.topology.update(&response.brokers); + Ok(response) } - /// Invalidates the current cached broker + /// Invalidates the cached arbitrary broker. /// - /// The next call to `[BrokerPool::get_cached_broker]` will get a new connection - #[allow(dead_code)] + /// The next call to `[BrokerConnector::get_cached_arbitrary_broker]` will get a new connection pub async fn invalidate_cached_arbitrary_broker(&self) { - self.current_broker.lock().await.take(); + debug!("Invalidating cached arbitrary broker"); + self.cached_arbitrary_broker.lock().await.take(); } /// Returns a new connection to the broker with the provided id @@ -171,8 +175,8 @@ impl BrokerConnector { } /// Gets a cached [`BrokerConnection`] to any broker - pub async fn get_arbitrary_cached_broker(&self) -> Result { - let mut current_broker = self.current_broker.lock().await; + pub async fn get_cached_arbitrary_broker(&self) -> Result { + let mut current_broker = self.cached_arbitrary_broker.lock().await; if let Some(broker) = &*current_broker { return Ok(Arc::clone(broker)); } @@ -222,7 +226,7 @@ impl std::fmt::Debug for BrokerConnector { f.debug_struct("BrokerConnector") 
.field("bootstrap_brokers", &self.bootstrap_brokers) .field("topology", &self.topology) - .field("current_broker", &self.current_broker) + .field("cached_arbitrary_broker", &self.cached_arbitrary_broker) .field("backoff_config", &self.backoff_config) .field("tls_config", &tls_config) .field("max_message_size", &self.max_message_size) diff --git a/tests/client.rs b/tests/client.rs index 1362d85b..0f6ce131 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -29,9 +29,13 @@ async fn test_partition_leader() { let connection = maybe_skip_kafka_integration!(); let client = Client::new_plain(vec![connection]).await.unwrap(); + let controller_client = client.controller_client().await.unwrap(); let topic_name = random_topic_name(); - client.create_topic(&topic_name, 2, 1).await.unwrap(); + controller_client + .create_topic(&topic_name, 2, 1) + .await + .unwrap(); let client = client.partition_client(&topic_name, 0).await.unwrap(); tokio::time::timeout(Duration::from_secs(10), async move { loop { @@ -54,6 +58,7 @@ async fn test_topic_crud() { let connection = maybe_skip_kafka_integration!(); let client = Client::new_plain(vec![connection]).await.unwrap(); + let controller_client = client.controller_client().await.unwrap(); let topics = client.list_topics().await.unwrap(); let prefix = "test_topic_crud_"; @@ -66,23 +71,31 @@ async fn test_topic_crud() { } } } - let new_topic = format!("{}{}", prefix, max_id + 1); - client.create_topic(&new_topic, 2, 1).await.unwrap(); + controller_client + .create_topic(&new_topic, 2, 1) + .await + .unwrap(); - let topics = client.list_topics().await.unwrap(); + // might take a while to converge + tokio::time::timeout(Duration::from_millis(1_000), async { + loop { + let topics = client.list_topics().await.unwrap(); + let topic = topics.iter().find(|t| t.name == new_topic); + if topic.is_some() { + return; + } - let topic = topics.iter().find(|t| t.name == new_topic); - assert!( - topic.is_some(), - "topic {} not found in {:?}", - new_topic, 
- topics - ); - let topic = topic.unwrap(); - assert_eq!(topic.partitions.len(), 2); + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); - let err = client.create_topic(&new_topic, 1, 1).await.unwrap_err(); + let err = controller_client + .create_topic(&new_topic, 1, 1) + .await + .unwrap_err(); match err { ClientError::ServerError(ProtocolError::TopicAlreadyExists, _) => {} _ => panic!("Unexpected error: {}", err), @@ -141,7 +154,8 @@ async fn test_produce_empty() { let n_partitions = 2; let client = Client::new_plain(vec![connection]).await.unwrap(); - client + let controller_client = client.controller_client().await.unwrap(); + controller_client .create_topic(&topic_name, n_partitions, 1) .await .unwrap(); @@ -179,7 +193,8 @@ async fn test_get_high_watermark() { let n_partitions = 1; let client = Client::new_plain(vec![connection.clone()]).await.unwrap(); - client + let controller_client = client.controller_client().await.unwrap(); + controller_client .create_topic(&topic_name, n_partitions, 1) .await .unwrap(); @@ -230,7 +245,8 @@ where let n_partitions = 2; let client = Client::new_plain(vec![connection.clone()]).await.unwrap(); - client + let controller_client = client.controller_client().await.unwrap(); + controller_client .create_topic(&topic_name, n_partitions, 1) .await .unwrap(); diff --git a/tests/consumer.rs b/tests/consumer.rs index e7db417e..4be4b942 100644 --- a/tests/consumer.rs +++ b/tests/consumer.rs @@ -18,9 +18,10 @@ async fn test_stream_consumer() { let connection = maybe_skip_kafka_integration!(); let client = Client::new_plain(vec![connection]).await.unwrap(); + let controller_client = client.controller_client().await.unwrap(); let topic = random_topic_name(); - client.create_topic(&topic, 1, 1).await.unwrap(); + controller_client.create_topic(&topic, 1, 1).await.unwrap(); let record = record(); diff --git a/tests/producer.rs b/tests/producer.rs index 6e5e195e..0ded065a 100644 --- a/tests/producer.rs +++ 
b/tests/producer.rs @@ -15,9 +15,10 @@ async fn test_batch_producer() { let connection = maybe_skip_kafka_integration!(); let client = Client::new_plain(vec![connection]).await.unwrap(); + let controller_client = client.controller_client().await.unwrap(); let topic = random_topic_name(); - client.create_topic(&topic, 1, 1).await.unwrap(); + controller_client.create_topic(&topic, 1, 1).await.unwrap(); let record = record();