From 16fca784abe07be8f52fbd5d61e8b5ba8930c4d4 Mon Sep 17 00:00:00 2001 From: Najib Ishaq Date: Mon, 2 Dec 2024 14:43:05 -0500 Subject: [PATCH] wip: search algorithms rework --- crates/abd-clam/src/cakes/mod.rs | 282 ++++++++++-------- .../src/cakes/new_search/knn_breadth_first.rs | 212 +++++++++++++ .../src/cakes/new_search/knn_depth_first.rs | 182 +++++++++++ .../src/cakes/new_search/knn_linear.rs | 32 ++ .../src/cakes/new_search/knn_repeated_rnn.rs | 129 ++++++++ crates/abd-clam/src/cakes/new_search/mod.rs | 83 ++++++ .../src/cakes/new_search/rnn_clustered.rs | 247 +++++++++++++++ .../src/cakes/new_search/rnn_linear.rs | 30 ++ 8 files changed, 1069 insertions(+), 128 deletions(-) create mode 100644 crates/abd-clam/src/cakes/new_search/knn_breadth_first.rs create mode 100644 crates/abd-clam/src/cakes/new_search/knn_depth_first.rs create mode 100644 crates/abd-clam/src/cakes/new_search/knn_linear.rs create mode 100644 crates/abd-clam/src/cakes/new_search/knn_repeated_rnn.rs create mode 100644 crates/abd-clam/src/cakes/new_search/mod.rs create mode 100644 crates/abd-clam/src/cakes/new_search/rnn_clustered.rs create mode 100644 crates/abd-clam/src/cakes/new_search/rnn_linear.rs diff --git a/crates/abd-clam/src/cakes/mod.rs b/crates/abd-clam/src/cakes/mod.rs index 38574be8..d42be84e 100644 --- a/crates/abd-clam/src/cakes/mod.rs +++ b/crates/abd-clam/src/cakes/mod.rs @@ -2,6 +2,7 @@ mod cluster; mod dataset; +pub mod new_search; mod search; pub use cluster::{ParSearchable, PermutedBall, Searchable}; @@ -15,25 +16,10 @@ pub mod tests { use distances::{number::Float, Number}; use rand::prelude::*; - use test_case::test_case; - use crate::{ - cluster::{ - adapter::{BallAdapter, ParBallAdapter}, - ParPartition, Partition, - }, - dataset::{AssociatesMetadataMut, ParDataset}, - metric::{Euclidean, Levenshtein, ParMetric}, - Ball, Cluster, Dataset, FlatVec, - }; + use crate::{metric::ParMetric, Dataset, FlatVec}; - use super::{search::Algorithm, ParSearchable, PermutedBall}; - - /// Type alias for the algorithms and their correctness checkers. - pub type Algs = Vec<( - Algorithm, - fn(Vec<(usize, T)>, Vec<(usize, T)>, &str, &FlatVec) -> bool, - )>; + use super::{new_search::ParSearch, ParSearchable}; /// Generate 1d line data for testing. #[allow(clippy::pedantic)] @@ -51,6 +37,14 @@ pub mod tests { FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) } + /// Generate random data for testing. + #[allow(clippy::pedantic)] + pub fn gen_random_data(car: usize, dim: usize, max: T, seed: u64) -> FlatVec, usize> { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let data = symagen::random_data::random_tabular(car, dim, -max, max, &mut rng); + FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) + } + /// Check the search results using the indices of hits. #[allow(clippy::pedantic)] pub fn check_search_by_index( @@ -106,14 +100,46 @@ pub mod tests { true } - /// Generate random data for testing. - #[allow(clippy::pedantic)] - pub fn gen_random_data(car: usize, dim: usize, max: T, seed: u64) -> FlatVec, usize> { - let mut rng = rand::rngs::StdRng::seed_from_u64(seed); - let data = symagen::random_data::random_tabular(car, dim, -max, max, &mut rng); - FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) + /// Check a ranged search algorithm. 
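+    ///
+    /// Runs `alg` through both `Search::search` and `ParSearch::par_search`
+    /// and compares each result against an exhaustive scan of `data` via
+    /// `data.rnn`.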
+ pub fn check_rnn( + root: &C, + data: &FlatVec, + metric: &M, + query: &I, + radius: U, + alg: &A, + ) -> bool + where + I: Debug + Send + Sync, + U: Number, + Me: Debug + Send + Sync, + M: ParMetric, + C: ParSearchable, M>, + A: ParSearch, M, C, U>, + { + let true_hits = data.rnn(query, radius, metric).collect::>(); + + let pred_hits = alg.search(data, metric, root, query, &radius); + assert_eq!(pred_hits.len(), true_hits.len(), "Rnn search failed: {pred_hits:?}"); + check_search_by_index(true_hits.clone(), pred_hits, "RnnClustered", data); + + let pred_hits = alg.par_search(data, metric, root, query, &radius); + assert_eq!( + pred_hits.len(), + true_hits.len(), + "Parallel Rnn search failed: {pred_hits:?}" + ); + check_search_by_index(true_hits, pred_hits, "Par RnnClustered", data); + + true } + /// Type alias for the algorithms and their correctness checkers. + pub type Algs = Vec<( + super::Algorithm, + fn(Vec<(usize, T)>, Vec<(usize, T)>, &str, &FlatVec) -> bool, + )>; + /// Check the search results of the algorithms. pub fn check_search( algs: &Algs, @@ -127,7 +153,7 @@ pub mod tests { where I: Send + Sync, T: Number, - D: ParDataset, + D: crate::dataset::ParDataset, M: ParMetric, C: ParSearchable, { @@ -141,108 +167,108 @@ pub mod tests { true } - #[test_case(1_000, 10)] - #[test_case(10_000, 10)] - #[test_case(1_000, 100)] - #[test_case(10_000, 100)] - fn vectors(car: usize, dim: usize) { - let mut algs: Algs, f32, usize> = vec![]; - for radius in [0.1, 1.0] { - algs.push((Algorithm::RnnClustered(radius), check_search_by_index)); - } - for k in [1, 10, 100] { - algs.push((Algorithm::KnnRepeatedRnn(k, 2.0), check_search_by_distance)); - algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance)); - algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance)); - } - - let seed = 42; - let data = gen_random_data(car, dim, 10.0, seed); - let metric = Euclidean; - let criteria = |c: &Ball<_>| c.cardinality() > 1; - let seed = Some(seed); - let query = &vec![0.0; dim]; - - let ball = Ball::new_tree(&data, &metric, &criteria, seed); - check_search(&algs, &data, &metric, &ball, query, "ball", &data); - - let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); - check_search(&algs, &data, &metric, &ball, query, "ball", &data); - - let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); - check_search(&algs, &perm_data, &metric, &off_ball, query, "off_ball", &perm_data); - - let (par_off_ball, per_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric); - check_search( - &algs, - &per_perm_data, - &metric, - &par_off_ball, - query, - "par_off_ball", - &per_perm_data, - ); - } - - #[test_case::test_case(16, 16, 2)] - #[test_case::test_case(32, 16, 3)] - fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> { - let mut algs: Algs = vec![]; - for radius in [4, 12] { - algs.push((Algorithm::RnnClustered(radius), check_search_by_index)); - } - for k in [1, 10] { - algs.push((Algorithm::KnnRepeatedRnn(k, 2), check_search_by_distance)); - algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance)); - algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance)); - } - - let seed_length = 30; - let alphabet = "ACTGN".chars().collect::>(); - let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet); - let penalties = distances::strings::Penalties::default(); - let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7); - let len_delta = 
seed_length / 10; - let (metadata, data) = symagen::random_edits::generate_clumped_data( - &seed_string, - penalties, - &alphabet, - num_clumps, - clump_size, - clump_radius, - inter_clump_distance_range, - len_delta, - ) - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - let query = &seed_string; - - let metric = Levenshtein; - let data = FlatVec::new(data)?.with_metadata(&metadata)?; - - let criteria = |c: &Ball<_>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &metric, &criteria, seed); - check_search(&algs, &data, &metric, &ball, query, "ball", &data); - - let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); - check_search(&algs, &data, &metric, &ball, query, "ball", &data); - - let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); - check_search(&algs, &perm_data, &metric, &perm_ball, query, "off_ball", &perm_data); - - let (par_perm_ball, par_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric); - check_search( - &algs, - &par_perm_data, - &metric, - &par_perm_ball, - query, - "par_off_ball", - &par_perm_data, - ); - - Ok(()) - } + // #[test_case(1_000, 10)] + // #[test_case(10_000, 10)] + // #[test_case(1_000, 100)] + // #[test_case(10_000, 100)] + // fn vectors(car: usize, dim: usize) { + // let mut algs: Algs, f32, usize> = vec![]; + // for radius in [0.1, 1.0] { + // algs.push((Algorithm::RnnClustered(radius), check_search_by_index)); + // } + // for k in [1, 10, 100] { + // algs.push((Algorithm::KnnRepeatedRnn(k, 2.0), check_search_by_distance)); + // algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance)); + // algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance)); + // } + + // let seed = 42; + // let data = gen_random_data(car, dim, 10.0, seed); + // let metric = Euclidean; + // let criteria = |c: &Ball<_>| c.cardinality() > 1; + // let seed = Some(seed); + // let query = &vec![0.0; dim]; + + // let ball = Ball::new_tree(&data, &metric, &criteria, seed); + // check_search(&algs, &data, &metric, &ball, query, "ball", &data); + + // let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + // check_search(&algs, &data, &metric, &ball, query, "ball", &data); + + // let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); + // check_search(&algs, &perm_data, &metric, &off_ball, query, "off_ball", &perm_data); + + // let (par_off_ball, per_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric); + // check_search( + // &algs, + // &per_perm_data, + // &metric, + // &par_off_ball, + // query, + // "par_off_ball", + // &per_perm_data, + // ); + // } + + // #[test_case::test_case(16, 16, 2)] + // #[test_case::test_case(32, 16, 3)] + // fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> { + // let mut algs: Algs = vec![]; + // for radius in [4, 12] { + // algs.push((Algorithm::RnnClustered(radius), check_search_by_index)); + // } + // for k in [1, 10] { + // algs.push((Algorithm::KnnRepeatedRnn(k, 2), check_search_by_distance)); + // algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance)); + // algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance)); + // } + + // let seed_length = 30; + // let alphabet = "ACTGN".chars().collect::>(); + // let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet); + // let penalties = distances::strings::Penalties::default(); + // let inter_clump_distance_range = (clump_radius 
* 5, clump_radius * 7); + // let len_delta = seed_length / 10; + // let (metadata, data) = symagen::random_edits::generate_clumped_data( + // &seed_string, + // penalties, + // &alphabet, + // num_clumps, + // clump_size, + // clump_radius, + // inter_clump_distance_range, + // len_delta, + // ) + // .into_iter() + // .unzip::<_, _, Vec<_>, Vec<_>>(); + // let query = &seed_string; + + // let metric = Levenshtein; + // let data = FlatVec::new(data)?.with_metadata(&metadata)?; + + // let criteria = |c: &Ball<_>| c.cardinality() > 1; + // let seed = Some(42); + + // let ball = Ball::new_tree(&data, &metric, &criteria, seed); + // check_search(&algs, &data, &metric, &ball, query, "ball", &data); + + // let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + // check_search(&algs, &data, &metric, &ball, query, "ball", &data); + + // let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); + // check_search(&algs, &perm_data, &metric, &perm_ball, query, "off_ball", &perm_data); + + // let (par_perm_ball, par_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric); + // check_search( + // &algs, + // &par_perm_data, + // &metric, + // &par_perm_ball, + // query, + // "par_off_ball", + // &par_perm_data, + // ); + + // Ok(()) + // } } diff --git a/crates/abd-clam/src/cakes/new_search/knn_breadth_first.rs b/crates/abd-clam/src/cakes/new_search/knn_breadth_first.rs new file mode 100644 index 00000000..18ec2c0c --- /dev/null +++ b/crates/abd-clam/src/cakes/new_search/knn_breadth_first.rs @@ -0,0 +1,212 @@ +//! K-Nearest Neighbors search using a Breadth First sieve. + +use core::cmp::{min, Ordering}; + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{ParSearchable, Searchable}, + dataset::ParDataset, + metric::ParMetric, + Cluster, Dataset, Metric, SizedHeap, +}; + +use super::{ParSearch, Search}; + +/// K-Nearest Neighbors search using a Breadth First sieve. 
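+///
+/// The sieve keeps a pool of candidate clusters, each keyed by `d_max`, the
+/// theoretical maximum distance from the query to any point in the cluster.
+/// Every iteration splits the pool into clusters that are needed, possibly
+/// needed, or not needed to guarantee `k` hits, drains the leaves among the
+/// first two groups into the hit heap, and replaces the surviving parents
+/// with their children.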
+pub struct KnnBreadthFirst; + +impl, M: Metric, C: Searchable> Search + for KnnBreadthFirst +{ + fn search(&self, data: &D, metric: &M, root: &C, query: &I, &k: &usize) -> Vec<(usize, T)> { + let mut candidates = Vec::new(); + let mut hits = SizedHeap::<(T, usize)>::new(Some(k)); + + let d = root.query_to_center(query, data, metric); + candidates.push((d_max(root, d), root)); + + while !candidates.is_empty() { + let [needed, maybe_needed, _] = split_candidates(&mut candidates, k); + + let (leaves, parents) = needed + .into_iter() + .chain(maybe_needed) + .partition::, _>(|(_, c)| c.is_leaf()); + + for (d, c) in leaves { + if c.is_singleton() { + c.indices().for_each(|i| hits.push((d, i))); + } else { + c.query_to_all(query, data, metric).for_each(|(i, d)| hits.push((d, i))); + } + } + + candidates = Vec::new(); + for (_, p) in parents { + p.child_clusters() + .map(|c| (c, c.query_to_center(query, data, metric))) + .for_each(|(c, d)| candidates.push((d_max(c, d), c))); + } + } + + hits.items().map(|(d, i)| (i, d)).collect() + } +} + +impl, M: ParMetric, C: ParSearchable> + ParSearch for KnnBreadthFirst +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I, &k: &usize) -> Vec<(usize, T)> { + let mut candidates = Vec::new(); + let mut hits = SizedHeap::<(T, usize)>::new(Some(k)); + + let d = root.query_to_center(query, data, metric); + candidates.push((d_max(root, d), root)); + + while !candidates.is_empty() { + let [needed, maybe_needed, _] = split_candidates(&mut candidates, k); + + let (leaves, parents) = needed + .into_iter() + .chain(maybe_needed) + .partition::, _>(|(_, c)| c.is_leaf()); + + for (d, c) in leaves { + if c.is_singleton() { + c.indices().for_each(|i| hits.push((d, i))); + } else { + c.par_query_to_all(query, data, metric) + .collect::>() + .into_iter() + .for_each(|(i, d)| hits.push((d, i))); + } + } + + candidates = Vec::new(); + let distances = parents + .into_par_iter() + .flat_map(|(_, p)| p.child_clusters().collect::>()) + .map(|c| (c, c.par_query_to_center(query, data, metric))) + .collect::>(); + distances + .into_iter() + .for_each(|(c, d)| candidates.push((d_max(c, d), c))); + } + + hits.items().map(|(d, i)| (i, d)).collect() + } +} + +/// Returns the theoretical maximum distance from the query to a point in the cluster. +fn d_max>(c: &C, d: T) -> T { + c.radius() + d +} + +/// Splits the candidates three ways: those needed to get to k hits, those that +/// might be needed to get to k hits, and those that are not needed to get to k +/// hits. +fn split_candidates<'a, T: Number, C: Cluster>(candidates: &mut [(T, &'a C)], k: usize) -> [Vec<(T, &'a C)>; 3] { + let threshold_index = quick_partition(candidates, k); + let threshold = candidates[threshold_index].0; + + let (needed, others) = candidates.iter().partition::, _>(|(d, _)| *d < threshold); + + let (not_needed, maybe_needed) = others + .into_iter() + .map(|(d, c)| { + let diam = c.radius().double(); + if d <= diam { + (d, T::ZERO, c) + } else { + (d, d - diam, c) + } + }) + .partition::, _>(|(_, d, _)| *d > threshold); + + let not_needed = not_needed.into_iter().map(|(d, _, c)| (d, c)).collect(); + let maybe_needed = maybe_needed.into_iter().map(|(d, _, c)| (d, c)).collect(); + + [needed, maybe_needed, not_needed] +} + +/// The Quick Partition algorithm, which is a variant of the Quick Select +/// algorithm. 
It finds the k-th smallest element in a list of elements, while +/// also reordering the list so that all elements to the left of the k-th +/// smallest element are less than or equal to it, and all elements to the right +/// of the k-th smallest element are greater than or equal to it. +fn quick_partition>(items: &mut [(T, &C)], k: usize) -> usize { + qps(items, k, 0, items.len() - 1) +} + +/// The recursive helper function for the Quick Partition algorithm. +fn qps>(items: &mut [(T, &C)], k: usize, l: usize, r: usize) -> usize { + if l >= r { + min(l, r) + } else { + // Choose the pivot point + let pivot = l + (r - l) / 2; + let p = find_pivot(items, l, r, pivot); + + // Calculate the cumulative guaranteed cardinalities for the first p + // `Cluster`s + let cumulative_guarantees = items + .iter() + .take(p) + .scan(0, |acc, (_, c)| { + *acc += c.cardinality(); + Some(*acc) + }) + .collect::>(); + + // Calculate the guaranteed cardinality of the p-th `Cluster` + let guaranteed_p = if p > 0 { cumulative_guarantees[p - 1] } else { 0 }; + + match guaranteed_p.cmp(&k) { + Ordering::Equal => p, // Found the k-th smallest element + Ordering::Less => qps(items, k, p + 1, r), // Need to look to the right + Ordering::Greater => { + // The `Cluster` just before the p-th might be the one we need + let guaranteed_p_minus_one = if p > 1 { cumulative_guarantees[p - 2] } else { 0 }; + if p == 0 || guaranteed_p_minus_one < k { + p // Found the k-th smallest element + } else { + // Need to look to the left + qps(items, k, l, p - 1) + } + } + } + } +} + +/// Moves pivot point and swaps elements around so that all elements to left +/// of pivot are less than or equal to pivot and all elements to right of pivot +/// are greater than pivot. +fn find_pivot(items: &mut [(T, &C)], l: usize, r: usize, pivot: usize) -> usize +where + T: Number, + C: Cluster, +{ + // Move pivot to the end + items.swap(pivot, r); + + // Partition around pivot + let (mut a, mut b) = (l, l); + // Invariant: a <= b <= r + while b < r { + // If the current element is less than the pivot, swap it with the + // element at a and increment a. + if items[b].0 < items[r].0 { + items.swap(a, b); + a += 1; + } + // Increment b + b += 1; + } + + // Move pivot to its final position + items.swap(a, r); + + a +} diff --git a/crates/abd-clam/src/cakes/new_search/knn_depth_first.rs b/crates/abd-clam/src/cakes/new_search/knn_depth_first.rs new file mode 100644 index 00000000..cd4af620 --- /dev/null +++ b/crates/abd-clam/src/cakes/new_search/knn_depth_first.rs @@ -0,0 +1,182 @@ +//! K-Nearest Neighbors search using a Depth First sieve. + +use core::cmp::Reverse; + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{ParSearchable, Searchable}, + dataset::{ParDataset, SizedHeap}, + metric::ParMetric, + Cluster, Dataset, Metric, +}; + +use super::{ParSearch, Search}; + +/// K-Nearest Neighbors search using a Depth First sieve. +pub struct KnnDepthFirst; + +impl, M: Metric, C: Searchable> Search + for KnnDepthFirst +{ + fn search(&self, data: &D, metric: &M, root: &C, query: &I, &k: &usize) -> Vec<(usize, T)> { + let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); + let mut hits = SizedHeap::<(T, usize)>::new(Some(k)); + + let d = root.query_to_center(query, data, metric); + candidates.push((Reverse(d_min(root, d)), root)); + + while !hits.is_full() // We do not have enough hits. + || (!candidates.is_empty() // We have candidates. 
+ && hits // and + .peek() + .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit + >= candidates // is farther than + .peek() // the closest candidate + .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) + { + let (d, leaf) = pop_till_leaf(data, metric, query, &mut candidates); + leaf_into_hits(data, metric, query, &mut hits, d, leaf); + } + hits.items().map(|(d, i)| (i, d)).collect() + } +} + +impl, M: ParMetric, C: ParSearchable> + ParSearch for KnnDepthFirst +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I, &k: &usize) -> Vec<(usize, T)> { + let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); + let mut hits = SizedHeap::<(T, usize)>::new(Some(k)); + + let d = root.par_query_to_center(query, data, metric); + candidates.push((Reverse(d_min(root, d)), root)); + + while !hits.is_full() // We do not have enough hits. + || (!candidates.is_empty() // We have candidates. + && hits // and + .peek() + .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit + >= candidates // is farther than + .peek() // the closest candidate + .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) + { + par_pop_till_leaf(data, metric, query, &mut candidates); + par_leaf_into_hits(data, metric, query, &mut hits, &mut candidates); + } + hits.items().map(|(d, i)| (i, d)).collect() + } +} + +/// Calculates the theoretical best case distance for a point in a cluster, i.e., +/// the closest a point in a given cluster could possibly be to the query. +pub fn d_min>(c: &C, d: T) -> T { + if d < c.radius() { + T::ZERO + } else { + d - c.radius() + } +} + +/// Pops from the top of `candidates` until the top candidate is a leaf cluster. +/// Then, pops and returns the leaf cluster. +fn pop_till_leaf<'a, I, T, D, M, C>( + data: &D, + metric: &M, + query: &I, + candidates: &mut SizedHeap<(Reverse, &'a C)>, +) -> (T, &'a C) +where + T: Number + 'a, + D: Dataset, + M: Metric, + C: Searchable, +{ + while candidates + .peek() // The top candidate is a leaf + .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| !c.is_leaf()) + { + let parent = candidates + .pop() + .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| c); + for child in parent.child_clusters() { + candidates.push((Reverse(d_min(child, child.query_to_center(query, data, metric))), child)); + } + } + candidates + .pop() + .map_or_else(|| unreachable!("`candidates` is non-empty"), |(Reverse(d), c)| (d, c)) +} + +/// Parallel version of `pop_till_leaf`. +fn par_pop_till_leaf<'a, I, T, D, M, C>(data: &D, metric: &M, query: &I, candidates: &mut SizedHeap<(Reverse, &'a C)>) +where + I: Send + Sync, + T: Number + 'a, + D: ParDataset, + M: ParMetric, + C: ParSearchable, +{ + while candidates + .peek() // The top candidate + .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| !c.is_leaf()) + // is not a leaf + { + let parent = candidates + .pop() + .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| c); + let children = parent.child_clusters().collect::>(); + let indices = children.iter().map(|c| c.arg_center()).collect::>(); + indices + .into_par_iter() + .map(|j| (j, data.par_query_to_one(query, j, metric))) + .collect::>() + .into_iter() + .zip(children) + .for_each(|((_, d), c)| candidates.push((Reverse(d_min(c, d)), c))); + } +} + +/// Pops from the top of `candidates` and adds its points to `hits`. 
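+///
+/// If the leaf is a singleton, all of its points are inserted with the
+/// center's distance `d`; otherwise the distance from the query to every
+/// point in the leaf is computed.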
+fn leaf_into_hits(data: &D, metric: &M, query: &I, hits: &mut SizedHeap<(T, usize)>, d: T, leaf: &C)
+where
+    T: Number,
+    D: Dataset,
+    M: Metric,
+    C: Searchable,
+{
+    if leaf.is_singleton() {
+        leaf.indices().for_each(|i| hits.push((d, i)));
+    } else {
+        leaf.query_to_all(query, data, metric)
+            .for_each(|(i, d)| hits.push((d, i)));
+    };
+}
+
+/// Parallel version of `leaf_into_hits`.
+fn par_leaf_into_hits(
+    data: &D,
+    metric: &M,
+    query: &I,
+    hits: &mut SizedHeap<(T, usize)>,
+    candidates: &mut SizedHeap<(Reverse, &C)>,
+) where
+    I: Send + Sync,
+    T: Number,
+    D: ParDataset,
+    M: ParMetric,
+    C: ParSearchable,
+{
+    let (d, leaf) = candidates
+        .pop()
+        .map_or_else(|| unreachable!("`candidates` is non-empty"), |(Reverse(d), c)| (d, c));
+    if leaf.is_singleton() {
+        leaf.indices().for_each(|i| hits.push((d, i)));
+    } else {
+        leaf.par_query_to_all(query, data, metric)
+            .collect::>()
+            .into_iter()
+            .for_each(|(i, d)| hits.push((d, i)));
+    };
+}
diff --git a/crates/abd-clam/src/cakes/new_search/knn_linear.rs b/crates/abd-clam/src/cakes/new_search/knn_linear.rs
new file mode 100644
index 00000000..44977c48
--- /dev/null
+++ b/crates/abd-clam/src/cakes/new_search/knn_linear.rs
@@ -0,0 +1,32 @@
+//! k-NN search using a linear scan of the dataset.
+
+use distances::Number;
+use rayon::prelude::*;
+
+use crate::{
+    cakes::{ParSearchable, Searchable},
+    dataset::ParDataset,
+    metric::ParMetric,
+    Dataset, Metric,
+};
+
+use super::{ParSearch, Search};
+
+/// k-NN search using a linear scan of the dataset.
+pub struct KnnLinear;
+
+impl, M: Metric, C: Searchable> Search
+    for KnnLinear
+{
+    fn search(&self, data: &D, metric: &M, _: &C, query: &I, &k: &usize) -> Vec<(usize, T)> {
+        data.knn(query, k, metric).collect()
+    }
+}
+
+impl, M: ParMetric, C: ParSearchable>
+    ParSearch for KnnLinear
+{
+    fn par_search(&self, data: &D, metric: &M, _: &C, query: &I, &k: &usize) -> Vec<(usize, T)> {
+        data.par_knn(query, k, metric).collect()
+    }
+}
diff --git a/crates/abd-clam/src/cakes/new_search/knn_repeated_rnn.rs b/crates/abd-clam/src/cakes/new_search/knn_repeated_rnn.rs
new file mode 100644
index 00000000..ff9f2e4e
--- /dev/null
+++ b/crates/abd-clam/src/cakes/new_search/knn_repeated_rnn.rs
@@ -0,0 +1,129 @@
+//! k-NN search using repeated ranged nearest neighbor searches with increasing radii.
+
+use distances::{number::Multiplication, Number};
+
+use crate::{
+    cakes::{ParSearchable, Searchable},
+    dataset::ParDataset,
+    metric::ParMetric,
+    Cluster, Dataset, Metric, SizedHeap, LFD,
+};
+
+use super::{
+    rnn_clustered::{leaf_search, par_leaf_search, par_tree_search, tree_search},
+    ParSearch, ParSearchParams, Search, SearchParams,
+};
+
+/// Parameters for `KnnRepeatedRnn`: the number of neighbors `k` and the
+/// maximum multiplier for the search radius between successive ranged
+/// searches.
+pub struct RadiusMultiplier(pub usize, pub T);
+
+impl SearchParams for RadiusMultiplier {}
+
+impl ParSearchParams for RadiusMultiplier {}
+
+/// k-NN search using repeated clustered ranged searches with increasing radii.
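+///
+/// The first radius is the root's radius scaled by a multiplier estimated
+/// from its local fractal dimension (LFD). The clustered ranged search is
+/// then repeated: the radius is doubled until at least one hit is confirmed,
+/// and afterwards grown by an LFD-derived multiplier (capped by the supplied
+/// maximum) until at least `k` hits are confirmed. The `k` nearest of those
+/// hits are returned.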
+pub struct KnnRepeatedRnn; + +impl, M: Metric, C: Searchable> Search> + for KnnRepeatedRnn +{ + fn search( + &self, + data: &D, + metric: &M, + root: &C, + query: &I, + &RadiusMultiplier(k, max_multiplier): &RadiusMultiplier, + ) -> Vec<(usize, T)> { + let max_multiplier = max_multiplier.as_f32(); + let mut radius = root.radius().as_f32(); + + let mut multiplier = LFD::multiplier_for_k(root.lfd(), root.cardinality(), k).min(max_multiplier); + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + + let [mut confirmed, mut straddlers] = tree_search(data, metric, root, query, T::from(radius)); + + let mut num_confirmed = count_hits(&confirmed); + while num_confirmed == 0 { + radius = radius.double(); + [confirmed, straddlers] = tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } + + while num_confirmed < k { + let (lfd, car) = mean_lfd(&confirmed, &straddlers); + multiplier = LFD::multiplier_for_k(lfd, car, k) + .min(max_multiplier) + .max(f32::ONE + f32::EPSILON); + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + [confirmed, straddlers] = tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } + + let mut knn = SizedHeap::new(Some(k)); + leaf_search(data, metric, confirmed, straddlers, query, T::from(radius)) + .into_iter() + .for_each(|(i, d)| knn.push((d, i))); + knn.items().map(|(d, i)| (i, d)).collect() + } +} + +impl, M: ParMetric, C: ParSearchable> + ParSearch> for KnnRepeatedRnn +{ + fn par_search( + &self, + data: &D, + metric: &M, + root: &C, + query: &I, + &RadiusMultiplier(k, max_multiplier): &RadiusMultiplier, + ) -> Vec<(usize, T)> { + let max_multiplier = max_multiplier.as_f32(); + let mut radius = root.radius().as_f32(); + + let mut multiplier = LFD::multiplier_for_k(root.lfd(), root.cardinality(), k).min(max_multiplier); + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + + let [mut confirmed, mut straddlers] = par_tree_search(data, metric, root, query, T::from(radius)); + + let mut num_confirmed = count_hits(&confirmed); + while num_confirmed == 0 { + radius = radius.double(); + [confirmed, straddlers] = par_tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } + + while num_confirmed < k { + let (lfd, car) = mean_lfd(&confirmed, &straddlers); + multiplier = LFD::multiplier_for_k(lfd, car, k) + .min(max_multiplier) + .max(f32::ONE + f32::EPSILON); + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + [confirmed, straddlers] = par_tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } + + let mut knn = SizedHeap::new(Some(k)); + par_leaf_search(data, metric, confirmed, straddlers, query, T::from(radius)) + .into_iter() + .for_each(|(i, d)| knn.push((d, i))); + knn.items().map(|(d, i)| (i, d)).collect() + } +} + +/// Count the total cardinality of the clusters. +fn count_hits>(hits: &[(&C, T)]) -> usize { + hits.iter().map(|(c, _)| c.cardinality()).sum() +} + +/// Calculate the weighted mean of the LFDs of the clusters. 
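+///
+/// Each LFD is weighted by its cluster's cardinality, i.e. the mean is
+/// `sum(lfd_i * |C_i|) / sum(|C_i|)`, and the total cardinality is returned
+/// alongside it.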
+fn mean_lfd>(confirmed: &[(&C, T)], straddlers: &[(&C, T)]) -> (f32, usize) { + let (lfd, car) = confirmed + .iter() + .chain(straddlers.iter()) + .map(|&(c, _)| (c.lfd(), c.cardinality())) + .fold((0.0, 0), |(lfd, car), (l, c)| (l.mul_add(c.as_f32(), lfd), car + c)); + (lfd / car.as_f32(), car) +} diff --git a/crates/abd-clam/src/cakes/new_search/mod.rs b/crates/abd-clam/src/cakes/new_search/mod.rs new file mode 100644 index 00000000..ba271516 --- /dev/null +++ b/crates/abd-clam/src/cakes/new_search/mod.rs @@ -0,0 +1,83 @@ +//! Entropy scaling search algorithms and supporting traits. + +use distances::Number; +use rayon::prelude::*; + +use crate::{dataset::ParDataset, metric::ParMetric, Dataset, Metric}; + +use super::{ParSearchable, Searchable}; + +mod knn_breadth_first; +mod knn_depth_first; +mod knn_linear; +mod knn_repeated_rnn; +mod rnn_clustered; +mod rnn_linear; + +pub use knn_breadth_first::KnnBreadthFirst; +pub use knn_depth_first::KnnDepthFirst; +pub use knn_linear::KnnLinear; +pub use knn_repeated_rnn::KnnRepeatedRnn; +pub use rnn_clustered::RnnClustered; +pub use rnn_linear::RnnLinear; + +/// Search parameters for individual entropy scaling search algorithms. +pub trait SearchParams {} + +/// A blanket implementation of `SearchParams` for all `Number` types. +impl SearchParams for T {} + +/// Common trait for entropy scaling search algorithms. +pub trait Search, M: Metric, C: Searchable, P: SearchParams> { + /// Perform a search using the given parameters. + /// + /// # Arguments + /// + /// * `data` - The dataset to search. + /// * `metric` - The metric to use for distance calculations. + /// * `root` - The root of the tree to search. + /// * `query` - The query to search around. + /// * `params` - The parameters to use for the search. + /// + /// # Returns + /// + /// A vector of pairs, where each pair contains the index of an item in the + /// dataset and the distance from the query to that item. + fn search(&self, data: &D, metric: &M, root: &C, query: &I, params: &P) -> Vec<(usize, T)>; + + /// Batched version of `Search::search`. + fn batch_search(&self, data: &D, metric: &M, root: &C, queries: &[I], params: &P) -> Vec> { + queries + .iter() + .map(|query| self.search(data, metric, root, query, params)) + .collect() + } +} + +/// Parallel version of `SearchParams`. +pub trait ParSearchParams: SearchParams + Send + Sync {} + +/// A blanket implementation of `ParSearchParams` for all `Number` types. +impl ParSearchParams for T {} + +/// Parallel version of `Search`. +pub trait ParSearch< + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + C: ParSearchable, + P: ParSearchParams, +>: Search + Send + Sync +{ + /// Parallel version of `Search::search`. + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I, params: &P) -> Vec<(usize, T)>; + + /// Parallel version of `Search::batch_search`. + fn par_batch_search(&self, data: &D, metric: &M, root: &C, queries: &[I], params: &P) -> Vec> { + queries + .par_iter() + .map(|query| self.par_search(data, metric, root, query, params)) + .collect() + } +} diff --git a/crates/abd-clam/src/cakes/new_search/rnn_clustered.rs b/crates/abd-clam/src/cakes/new_search/rnn_clustered.rs new file mode 100644 index 00000000..dc987782 --- /dev/null +++ b/crates/abd-clam/src/cakes/new_search/rnn_clustered.rs @@ -0,0 +1,247 @@ +//! Ranged Nearest Neighbors search using a tree, as described in the CHESS +//! paper. 
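+//!
+//! The search runs in two phases: a coarse-grained tree search that labels
+//! clusters as confirmed (their ball lies entirely inside the query ball) or
+//! straddling (their ball overlaps the query ball), followed by a
+//! fine-grained leaf search over those clusters.
+//!
+//! A minimal usage sketch; it mirrors the `line` test at the bottom of this
+//! file, and `data`, `metric`, `criteria`, `seed`, `query`, and `radius` are
+//! assumed to be set up as in that test:
+//!
+//! ```ignore
+//! // Build a tree over the dataset, then run the clustered ranged search.
+//! let root = Ball::new_tree(&data, &metric, &criteria, seed);
+//! let hits: Vec<(usize, _)> = RnnClustered.search(&data, &metric, &root, query, &radius);
+//! ```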
+ +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{ParSearchable, Searchable}, + dataset::ParDataset, + metric::ParMetric, + Dataset, Metric, +}; + +use super::{ParSearch, Search}; + +/// Ranged Nearest Neighbors search using a tree. +pub struct RnnClustered; + +impl, M: Metric, C: Searchable> Search for RnnClustered { + fn search(&self, data: &D, metric: &M, root: &C, query: &I, &radius: &T) -> Vec<(usize, T)> { + let [confirmed, straddlers] = tree_search(data, metric, root, query, radius); + leaf_search(data, metric, confirmed, straddlers, query, radius) + } +} + +impl, M: ParMetric, C: ParSearchable> + ParSearch for RnnClustered +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I, &radius: &T) -> Vec<(usize, T)> { + let [confirmed, straddlers] = par_tree_search(data, metric, root, query, radius); + par_leaf_search(data, metric, confirmed, straddlers, query, radius) + } +} + +/// Perform coarse-grained tree search. +/// +/// # Arguments +/// +/// - `data` - The dataset to search. +/// - `root` - The root of the tree to search. +/// - `query` - The query to search around. +/// - `radius` - The radius to search within. +/// +/// # Returns +/// +/// A 2-slice of vectors of 2-tuples, where the first element in the slice +/// is the confirmed clusters, i.e. those that are contained within the +/// query ball, and the second element is the straddlers, i.e. those that +/// overlap the query ball. The 2-tuples are the clusters and the distance +/// from the query to the cluster center. +pub fn tree_search<'a, I, T, D, M, C>(data: &D, metric: &M, root: &'a C, query: &I, radius: T) -> [Vec<(&'a C, T)>; 2] +where + T: Number + 'a, + D: Dataset, + M: Metric, + C: Searchable, +{ + let mut confirmed = Vec::new(); + let mut straddlers = Vec::new(); + let mut candidates = vec![root]; + + let (mut terminal, mut non_terminal): (Vec<_>, Vec<_>); + while !candidates.is_empty() { + (terminal, non_terminal) = candidates + .into_iter() + .map(|c| (c, c.query_to_center(query, data, metric))) + .filter(|&(c, d)| d <= (c.radius() + radius)) + .partition(|&(c, d)| (c.radius() + d) <= radius); + confirmed.append(&mut terminal); + + (terminal, non_terminal) = non_terminal.into_iter().partition(|&(c, _)| c.is_leaf()); + straddlers.append(&mut terminal); + + candidates = non_terminal.into_iter().flat_map(|(c, _)| c.child_clusters()).collect(); + } + + [confirmed, straddlers] +} + +/// Parallel version of `tree_search`. +pub fn par_tree_search<'a, I, T, D, M, C>( + data: &D, + metric: &M, + root: &'a C, + query: &I, + radius: T, +) -> [Vec<(&'a C, T)>; 2] +where + I: Send + Sync, + T: Number + 'a, + D: ParDataset, + M: ParMetric, + C: ParSearchable, +{ + let mut confirmed = Vec::new(); + let mut straddlers = Vec::new(); + let mut candidates = vec![root]; + + let (mut terminal, mut non_terminal): (Vec<_>, Vec<_>); + while !candidates.is_empty() { + (terminal, non_terminal) = candidates + .into_par_iter() + .map(|c| (c, c.par_query_to_center(query, data, metric))) + .filter(|&(c, d)| d <= (c.radius() + radius)) + .partition(|&(c, d)| (c.radius() + d) < radius); + confirmed.append(&mut terminal); + + (terminal, non_terminal) = non_terminal.into_iter().partition(|&(c, _)| c.is_leaf()); + straddlers.append(&mut terminal); + + candidates = non_terminal.into_iter().flat_map(|(c, _)| c.child_clusters()).collect(); + } + + [confirmed, straddlers] +} + +/// Perform fine-grained leaf search. +/// +/// # Arguments +/// +/// - `data` - The dataset to search. 
+/// - `confirmed` - The confirmed clusters from the tree search. All points +/// in these clusters are guaranteed to be within the query ball. +/// - `straddlers` - The straddlers from the tree search. These clusters +/// overlap the query ball, but not all points in the cluster are guaranteed +/// to be within the query ball. +/// - `query` - The query to search around. +/// - `radius` - The radius to search within. +/// +/// # Returns +/// +/// The `(index, distance)` pairs of the points within the query ball. +pub fn leaf_search( + data: &D, + metric: &M, + confirmed: Vec<(&C, T)>, + straddlers: Vec<(&C, T)>, + query: &I, + radius: T, +) -> Vec<(usize, T)> +where + T: Number, + D: Dataset, + M: Metric, + C: Searchable, +{ + let hits = confirmed.into_iter().flat_map(|(c, d)| { + if c.is_singleton() { + c.indices().map(|i| (i, d)).collect::>() + } else { + c.query_to_all(query, data, metric).collect() + } + }); + + let distances = straddlers + .into_iter() + .flat_map(|(c, _)| c.query_to_all(query, data, metric)) + .filter(|&(_, d)| d <= radius); + + hits.chain(distances).collect() +} + +/// Parallel version of `leaf_search`. +pub fn par_leaf_search( + data: &D, + metric: &M, + confirmed: Vec<(&C, T)>, + straddlers: Vec<(&C, T)>, + query: &I, + radius: T, +) -> Vec<(usize, T)> +where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + C: ParSearchable, +{ + let hits = confirmed.into_par_iter().flat_map(|(c, d)| { + if c.is_singleton() { + c.indices().map(|i| (i, d)).collect::>() + } else { + c.par_query_to_all(query, data, metric).collect() + } + }); + + let distances = straddlers + .into_par_iter() + .flat_map(|(c, _)| c.par_query_to_all(query, data, metric)) + .filter(|&(_, d)| d <= radius); + + hits.chain(distances).collect() +} + +#[cfg(test)] +mod tests { + use crate::{ + cakes::PermutedBall, + cluster::{adapter::BallAdapter, ParPartition, Partition}, + metric::{AbsoluteDifference, Hypotenuse}, + Ball, Cluster, + }; + + use super::RnnClustered; + + use crate::cakes::tests::{check_rnn, gen_grid_data, gen_line_data}; + + #[test] + fn line() { + let data = gen_line_data(10); + let metric = AbsoluteDifference; + let query = &0; + + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let seed = Some(42); + + let ball = Ball::new_tree(&data, &metric, &criteria, seed); + for radius in 0..=4 { + assert!(check_rnn(&ball, &data, &metric, query, radius, &RnnClustered)); + } + + let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball, data, &metric); + for radius in 0..=4 { + assert!(check_rnn(&off_ball, &perm_data, &metric, query, radius, &RnnClustered)); + } + } + + #[test] + fn grid() { + let data = gen_grid_data(10); + let metric = Hypotenuse; + let query = &(0.0, 0.0); + + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let seed = Some(42); + + let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + for radius in [1.0, 4.0, 8.0, 16.0, 32.0] { + assert!(check_rnn(&ball, &data, &metric, query, radius, &RnnClustered)); + } + + let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball, data, &metric); + for radius in [1.0, 4.0, 8.0, 16.0, 32.0] { + assert!(check_rnn(&off_ball, &perm_data, &metric, query, radius, &RnnClustered)); + } + } +} diff --git a/crates/abd-clam/src/cakes/new_search/rnn_linear.rs b/crates/abd-clam/src/cakes/new_search/rnn_linear.rs new file mode 100644 index 00000000..5335dc3c --- /dev/null +++ b/crates/abd-clam/src/cakes/new_search/rnn_linear.rs @@ -0,0 +1,30 @@ +//! Ranged nearest neighbor search using a linear scan of the dataset. 
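+//!
+//! This is the exhaustive baseline: it ignores the tree and scans the whole
+//! dataset, so it is mainly useful for verifying the clustered algorithms
+//! and for benchmarking them against brute force.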
+ +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{ParSearchable, Searchable}, + dataset::ParDataset, + metric::ParMetric, + Dataset, Metric, +}; + +use super::{ParSearch, Search}; + +/// Ranged nearest neighbor search using a linear scan of the dataset. +pub struct RnnLinear; + +impl, M: Metric, C: Searchable> Search for RnnLinear { + fn search(&self, data: &D, metric: &M, _: &C, query: &I, &radius: &T) -> Vec<(usize, T)> { + data.rnn(query, radius, metric).collect() + } +} + +impl, M: ParMetric, C: ParSearchable> + ParSearch for RnnLinear +{ + fn par_search(&self, data: &D, metric: &M, _: &C, query: &I, &radius: &T) -> Vec<(usize, T)> { + data.par_rnn(query, radius, metric).collect() + } +}
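+
+#[cfg(test)]
+mod tests {
+    // A minimal smoke-test sketch that mirrors the `line` test in
+    // `rnn_clustered.rs`; it assumes the same test helpers (`check_rnn`,
+    // `gen_line_data`) and the `AbsoluteDifference` metric. Since `check_rnn`
+    // uses a linear scan as its ground truth, this mainly exercises the
+    // `Search` and `ParSearch` plumbing for `RnnLinear`.
+    use crate::{cluster::Partition, metric::AbsoluteDifference, Ball, Cluster};
+
+    use super::RnnLinear;
+
+    use crate::cakes::tests::{check_rnn, gen_line_data};
+
+    #[test]
+    fn line() {
+        let data = gen_line_data(10);
+        let metric = AbsoluteDifference;
+        let query = &0;
+
+        let criteria = |c: &Ball<_>| c.cardinality() > 1;
+        let seed = Some(42);
+
+        let ball = Ball::new_tree(&data, &metric, &criteria, seed);
+        for radius in 0..=4 {
+            assert!(check_rnn(&ball, &data, &metric, query, radius, &RnnLinear));
+        }
+    }
+}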