Skip to content

Commit

Permalink
feat: add n_route() to RouteProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolay-komarevskiy committed Aug 19, 2024
1 parent 386d353 commit 4401403
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,20 @@ where
{
fn route(&self) -> Result<Url, AgentError> {
let snapshot = self.routing_snapshot.load();
let node = snapshot.next().ok_or_else(|| {
let node = snapshot.next_node().ok_or_else(|| {
AgentError::RouteProviderError("No healthy API nodes found.".to_string())
})?;
Ok(node.to_routing_url())
}

fn n_routes(&self, n: usize) -> Result<Vec<Url>, AgentError> {
let snapshot = self.routing_snapshot.load();
let nodes = snapshot.next_n_nodes(n).ok_or_else(|| {
AgentError::RouteProviderError("No healthy API nodes found.".to_string())
})?;
let urls = nodes.iter().map(|n| n.to_routing_url()).collect();
Ok(urls)
}
}

impl<S> DynamicRouteProvider<S>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where
// - failure should never happen, but we trace it if it does
loop {
let snapshot = self.routing_snapshot.load();
if let Some(node) = snapshot.next() {
if let Some(node) = snapshot.next_node() {
match self.fetcher.fetch((&node).into()).await {
Ok(nodes) => {
let msg = Some(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ impl LatencyRoutingSnapshot {
/// Helper function to sample nodes based on their weights.
/// Here weight index is selected based on the input number in range [0, 1]
#[inline(always)]
fn weighted_sample(weights: &[f64], number: f64) -> Option<usize> {
fn weighted_sample(weighted_nodes: &[(f64, &Node)], number: f64) -> Option<usize> {
if !(0.0..=1.0).contains(&number) {
return None;
}
let sum: f64 = weights.iter().sum();
let sum: f64 = weighted_nodes.iter().map(|n| n.0).sum();
let mut weighted_number = number * sum;
for (idx, weight) in weights.iter().enumerate() {
for (idx, &(weight, _)) in weighted_nodes.iter().enumerate() {
weighted_number -= weight;
if weighted_number <= 0.0 {
return Some(idx);
Expand All @@ -69,21 +69,53 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
!self.weighted_nodes.is_empty()
}

fn next(&self) -> Option<Node> {
fn next_node(&self) -> Option<Node> {
// We select a node based on it's weight, using a stochastic weighted random sampling approach.
let weights = self
let weighted_nodes: Vec<_> = self
.weighted_nodes
.iter()
.map(|n| n.weight)
.collect::<Vec<_>>();
.map(|n| (n.weight, &n.node))
.collect();
// Generate a random float in the range [0, 1)
let mut rng = rand::thread_rng();
let rand_num = rng.gen::<f64>();
// Using this random float and an array of weights we get an index of the node.
let idx = weighted_sample(weights.as_slice(), rand_num);
let idx = weighted_sample(weighted_nodes.as_slice(), rand_num);
idx.map(|idx| self.weighted_nodes[idx].node.clone())
}

// Uses weighted random sampling algorithm with item replacement n times.
fn next_n_nodes(&self, n: usize) -> Option<Vec<Node>> {
if n == 0 {
return Some(Vec::new());
}

let n = std::cmp::min(n, self.weighted_nodes.len());

let mut nodes = Vec::with_capacity(n);

let mut weighted_nodes: Vec<_> = self
.weighted_nodes
.iter()
.map(|n| (n.weight, &n.node))
.collect();

let mut rng = rand::thread_rng();

for _ in 0..n {
// Generate a random float in the range [0, 1)
let rand_num = rng.gen::<f64>();
if let Some(idx) = weighted_sample(weighted_nodes.as_slice(), rand_num) {
let node = weighted_nodes[idx].1;
nodes.push(node.clone());
// Remove the item, so that it can't be selected anymore.
weighted_nodes.swap_remove(idx);
}
}

Some(nodes)
}

fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
let new_nodes = HashSet::from_iter(nodes.iter().cloned());
// Find nodes removed from topology.
Expand Down Expand Up @@ -143,7 +175,10 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {

#[cfg(test)]
mod tests {
use std::{collections::HashSet, time::Duration};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};

use simple_moving_average::SMA;

Expand All @@ -166,7 +201,7 @@ mod tests {
assert!(snapshot.weighted_nodes.is_empty());
assert!(snapshot.existing_nodes.is_empty());
assert!(!snapshot.has_nodes());
assert!(snapshot.next().is_none());
assert!(snapshot.next_node().is_none());
}

#[test]
Expand All @@ -181,7 +216,7 @@ mod tests {
assert!(!is_updated);
assert!(snapshot.weighted_nodes.is_empty());
assert!(!snapshot.has_nodes());
assert!(snapshot.next().is_none());
assert!(snapshot.next_node().is_none());
}

#[test]
Expand All @@ -201,7 +236,7 @@ mod tests {
Duration::from_secs(1)
);
assert_eq!(weighted_node.weight, 1.0);
assert_eq!(snapshot.next().unwrap(), node);
assert_eq!(snapshot.next_node().unwrap(), node);
// Check second update
let health = HealthCheckStatus::new(Some(Duration::from_secs(2)));
let is_updated = snapshot.update_node(&node, health);
Expand Down Expand Up @@ -232,7 +267,7 @@ mod tests {
assert_eq!(weighted_node.weight, 1.0 / avg_latency.as_secs_f64());
assert_eq!(snapshot.weighted_nodes.len(), 1);
assert_eq!(snapshot.existing_nodes.len(), 1);
assert_eq!(snapshot.next().unwrap(), node);
assert_eq!(snapshot.next_node().unwrap(), node);
}

#[test]
Expand Down Expand Up @@ -307,12 +342,13 @@ mod tests {

#[test]
fn test_weighted_sample() {
let node = &Node::new("api1.com").unwrap();
// Case 1: empty array
let arr: &[f64] = &[];
let arr = &[];
let idx = weighted_sample(arr, 0.5);
assert_eq!(idx, None);
// Case 2: single element in array
let arr: &[f64] = &[1.0];
let arr = &[(1.0, node)];
let idx = weighted_sample(arr, 0.0);
assert_eq!(idx, Some(0));
let idx = weighted_sample(arr, 1.0);
Expand All @@ -323,7 +359,7 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
// Case 3: two elements in array (second element has twice the weight of the first)
let arr: &[f64] = &[1.0, 2.0]; // prefixed_sum = [1.0, 3.0]
let arr = &[(1.0, node), (2.0, node)]; // // prefixed_sum = [1.0, 3.0]
let idx = weighted_sample(arr, 0.0); // 0.0 * 3.0 < 1.0
assert_eq!(idx, Some(0));
let idx = weighted_sample(arr, 0.33); // 0.33 * 3.0 < 1.0
Expand All @@ -338,7 +374,7 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
// Case 4: four elements in array
let arr: &[f64] = &[1.0, 2.0, 1.5, 2.5]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
let arr = &[(1.0, node), (2.0, node), (1.5, node), (2.5, node)]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
let idx = weighted_sample(arr, 0.14); // 0.14 * 7 < 1.0
assert_eq!(idx, Some(0)); // probability ~0.14
let idx = weighted_sample(arr, 0.15); // 0.15 * 7 > 1.0
Expand All @@ -359,4 +395,69 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
}

#[test]
#[ignore]
// This test is for manual runs to see the statistics for nodes selection probability.
fn test_stats_for_next_n_nodes() {
// Arrange
let mut snapshot = LatencyRoutingSnapshot::new();
let node_1 = Node::new("api1.com").unwrap();
let node_2 = Node::new("api2.com").unwrap();
let node_3 = Node::new("api3.com").unwrap();
let node_4 = Node::new("api4.com").unwrap();
let node_5 = Node::new("api5.com").unwrap();
let node_6 = Node::new("api6.com").unwrap();
let latency_mov_avg = LatencyMovAvg::from_zero(Duration::ZERO);
snapshot.weighted_nodes = vec![
WeightedNode {
node: node_2.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 8.0,
},
WeightedNode {
node: node_3.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 4.0,
},
WeightedNode {
node: node_1.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 16.0,
},
WeightedNode {
node: node_6.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 2.0,
},
WeightedNode {
node: node_5.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 1.0,
},
WeightedNode {
node: node_4.clone(),
latency_mov_avg: latency_mov_avg.clone(),
weight: 4.1,
},
];

let mut stats = HashMap::new();
let experiments = 30;
let select_nodes_count = 2;
for i in 0..experiments {
let nodes = snapshot.next_n_nodes(select_nodes_count).unwrap();
println!("Experiment {i}: selected nodes {nodes:?}");
for item in nodes.into_iter() {
*stats.entry(item).or_insert(1) += 1;
}
}
for (node, count) in stats {
println!(
"Node {:?} is selected with probability {}",
node.domain(),
count as f64 / experiments as f64
);
}
}
}
Loading

0 comments on commit 4401403

Please sign in to comment.