Skip to content

Commit

Permalink
wip: search algorithms rework
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Dec 2, 2024
1 parent 08e4568 commit 16fca78
Show file tree
Hide file tree
Showing 8 changed files with 1,069 additions and 128 deletions.
282 changes: 154 additions & 128 deletions crates/abd-clam/src/cakes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod cluster;
mod dataset;
pub mod new_search;
mod search;

pub use cluster::{ParSearchable, PermutedBall, Searchable};
Expand All @@ -15,25 +16,10 @@ pub mod tests {

use distances::{number::Float, Number};
use rand::prelude::*;
use test_case::test_case;

use crate::{
cluster::{
adapter::{BallAdapter, ParBallAdapter},
ParPartition, Partition,
},
dataset::{AssociatesMetadataMut, ParDataset},
metric::{Euclidean, Levenshtein, ParMetric},
Ball, Cluster, Dataset, FlatVec,
};
use crate::{metric::ParMetric, Dataset, FlatVec};

use super::{search::Algorithm, ParSearchable, PermutedBall};

/// Type alias for the algorithms and their correctness checkers.
pub type Algs<I, T, Me> = Vec<(
Algorithm<T>,
fn(Vec<(usize, T)>, Vec<(usize, T)>, &str, &FlatVec<I, Me>) -> bool,
)>;
use super::{new_search::ParSearch, ParSearchable};

/// Generate 1d line data for testing.
#[allow(clippy::pedantic)]
Expand All @@ -51,6 +37,14 @@ pub mod tests {
FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}"))
}

/// Build a reproducible random tabular dataset for use in tests.
///
/// # Arguments
///
/// * `car` - passed to `random_tabular` as the cardinality (number of rows).
/// * `dim` - passed to `random_tabular` as the dimensionality of each row.
/// * `max` - `-max` and `max` are handed to `random_tabular` as the two
///   value bounds.
/// * `seed` - seed for the `StdRng` that drives the generation.
///
/// # Panics
///
/// Hits `unreachable!` if `FlatVec::new` rejects the generated rows.
#[allow(clippy::pedantic)]
pub fn gen_random_data<T>(car: usize, dim: usize, max: T, seed: u64) -> FlatVec<Vec<T>, usize>
where
    T: Float,
{
    let mut generator = rand::rngs::StdRng::seed_from_u64(seed);
    let rows = symagen::random_data::random_tabular(car, dim, -max, max, &mut generator);
    match FlatVec::new(rows) {
        Ok(flat) => flat,
        Err(e) => unreachable!("{e}"),
    }
}

/// Check the search results using the indices of hits.
#[allow(clippy::pedantic)]
pub fn check_search_by_index<I: Debug, T: Number, Me>(
Expand Down Expand Up @@ -106,14 +100,46 @@ pub mod tests {
true
}

/// Generate random data for testing.
#[allow(clippy::pedantic)]
pub fn gen_random_data<T: Float>(car: usize, dim: usize, max: T, seed: u64) -> FlatVec<Vec<T>, usize> {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let data = symagen::random_data::random_tabular(car, dim, -max, max, &mut rng);
FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}"))
/// Check a ranged search algorithm against the dataset's own ranged search.
///
/// The reference hits come from `data.rnn(query, radius, metric)`. The
/// algorithm under test is then run twice, once through `search` and once
/// through `par_search`, and each set of predicted hits is compared with the
/// reference using `check_search_by_index`.
///
/// # Arguments
///
/// * `root` - the root of the tree that the algorithm searches.
/// * `data` - the dataset being searched.
/// * `metric` - the distance metric shared by the reference and the algorithm.
/// * `query` - the query item.
/// * `radius` - the search radius.
/// * `alg` - the search algorithm under test.
///
/// # Returns
///
/// `true` only if `check_search_by_index` accepted both the sequential and the
/// parallel results.
///
/// # Panics
///
/// If either search returns a different number of hits than the reference.
pub fn check_rnn<I, U, Me, M, C, A>(
    root: &C,
    data: &FlatVec<I, Me>,
    metric: &M,
    query: &I,
    radius: U,
    alg: &A,
) -> bool
where
    I: Debug + Send + Sync,
    U: Number,
    Me: Debug + Send + Sync,
    M: ParMetric<I, U>,
    C: ParSearchable<I, U, FlatVec<I, Me>, M>,
    A: ParSearch<I, U, FlatVec<I, Me>, M, C, U>,
{
    let true_hits = data.rnn(query, radius, metric).collect::<Vec<_>>();

    let pred_hits = alg.search(data, metric, root, query, &radius);
    assert_eq!(pred_hits.len(), true_hits.len(), "Rnn search failed: {pred_hits:?}");
    // Keep the checker's verdict instead of discarding it.
    let seq_ok = check_search_by_index(true_hits.clone(), pred_hits, "RnnClustered", data);

    let pred_hits = alg.par_search(data, metric, root, query, &radius);
    assert_eq!(
        pred_hits.len(),
        true_hits.len(),
        "Parallel Rnn search failed: {pred_hits:?}"
    );
    let par_ok = check_search_by_index(true_hits, pred_hits, "Par RnnClustered", data);

    // Both variants are always exercised; the result reflects both checks.
    seq_ok && par_ok
}

/// Type alias for the algorithms and their correctness checkers.
///
/// Each entry pairs a search `Algorithm` with a checker function. The checker
/// takes two lists of `(usize, T)` hits, a `&str`, and a reference to the
/// `FlatVec` dataset, and returns a `bool`. `check_search_by_index` is one
/// such checker; its call sites pass the true hits first, then the predicted
/// hits, then a name for the search being checked.
pub type Algs<I, T, Me> = Vec<(
super::Algorithm<T>,
fn(Vec<(usize, T)>, Vec<(usize, T)>, &str, &FlatVec<I, Me>) -> bool,
)>;

/// Check the search results of the algorithms.
pub fn check_search<I, T, D, M, C, Me>(
algs: &Algs<I, T, Me>,
Expand All @@ -127,7 +153,7 @@ pub mod tests {
where
I: Send + Sync,
T: Number,
D: ParDataset<I>,
D: crate::dataset::ParDataset<I>,
M: ParMetric<I, T>,
C: ParSearchable<I, T, D, M>,
{
Expand All @@ -141,108 +167,108 @@ pub mod tests {
true
}

/// Runs every search algorithm on randomly generated `f32` vectors.
///
/// The test cases supply `car` and `dim`, which are forwarded to
/// `gen_random_data`. The same algorithm list is checked against four trees:
/// a `Ball` tree built with `new_tree`, one built with `par_new_tree`, and the
/// `PermutedBall` trees produced by `from_ball_tree` and `par_from_ball_tree`.
#[test_case(1_000, 10)]
#[test_case(10_000, 10)]
#[test_case(1_000, 100)]
#[test_case(10_000, 100)]
fn vectors(car: usize, dim: usize) {
// Ranged searches are checked by index; the k-NN variants are checked by distance.
let mut algs: Algs<Vec<f32>, f32, usize> = vec![];
for radius in [0.1, 1.0] {
algs.push((Algorithm::RnnClustered(radius), check_search_by_index));
}
for k in [1, 10, 100] {
algs.push((Algorithm::KnnRepeatedRnn(k, 2.0), check_search_by_distance));
algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance));
algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance));
}

// The same seed is used for data generation and, wrapped in `Some`, for tree building.
let seed = 42;
let data = gen_random_data(car, dim, 10.0, seed);
let metric = Euclidean;
let criteria = |c: &Ball<_>| c.cardinality() > 1;
let seed = Some(seed);
// The query is the all-zeros vector of length `dim`.
let query = &vec![0.0; dim];

// Tree built with `new_tree`.
let ball = Ball::new_tree(&data, &metric, &criteria, seed);
check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// Tree built with `par_new_tree`.
let ball = Ball::par_new_tree(&data, &metric, &criteria, seed);
check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// `from_ball_tree` consumes its arguments, so clones are passed to keep `ball` and `data` for the next step.
let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric);
check_search(&algs, &perm_data, &metric, &off_ball, query, "off_ball", &perm_data);

// Last use of `ball` and `data`, so they are moved into `par_from_ball_tree`.
let (par_off_ball, per_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric);
check_search(
&algs,
&per_perm_data,
&metric,
&par_off_ball,
query,
"par_off_ball",
&per_perm_data,
);
}

/// Runs every search algorithm on clumped random strings under `Levenshtein`.
///
/// The test cases supply `num_clumps`, `clump_size` and `clump_radius`, which
/// are forwarded to `symagen::random_edits::generate_clumped_data`. The seed
/// string used to generate the data is also the query. The same algorithm list
/// is checked against four trees: a `Ball` tree built with `new_tree`, one
/// built with `par_new_tree`, and the `PermutedBall` trees produced by
/// `from_ball_tree` and `par_from_ball_tree`.
///
/// # Errors
///
/// Propagates the error from `FlatVec::new` or `with_metadata`.
#[test_case::test_case(16, 16, 2)]
#[test_case::test_case(32, 16, 3)]
fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> {
// Ranged searches are checked by index; the k-NN variants are checked by distance.
let mut algs: Algs<String, u16, String> = vec![];
for radius in [4, 12] {
algs.push((Algorithm::RnnClustered(radius), check_search_by_index));
}
for k in [1, 10] {
algs.push((Algorithm::KnnRepeatedRnn(k, 2), check_search_by_distance));
algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance));
algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance));
}

// Generation parameters: a 30-character seed string over the alphabet "ACTGN".
let seed_length = 30;
let alphabet = "ACTGN".chars().collect::<Vec<_>>();
let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet);
let penalties = distances::strings::Penalties::default();
// The inter-clump distance range is 5x to 7x the clump radius.
let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7);
let len_delta = seed_length / 10;
// The generator yields pairs, which are split into metadata and data vectors.
let (metadata, data) = symagen::random_edits::generate_clumped_data(
&seed_string,
penalties,
&alphabet,
num_clumps,
clump_size,
clump_radius,
inter_clump_distance_range,
len_delta,
)
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
// The seed string itself is the query.
let query = &seed_string;

let metric = Levenshtein;
let data = FlatVec::new(data)?.with_metadata(&metadata)?;

let criteria = |c: &Ball<_>| c.cardinality() > 1;
let seed = Some(42);

// Tree built with `new_tree`.
let ball = Ball::new_tree(&data, &metric, &criteria, seed);
check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// Tree built with `par_new_tree`.
let ball = Ball::par_new_tree(&data, &metric, &criteria, seed);
check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// `from_ball_tree` consumes its arguments, so clones are passed to keep `ball` and `data` for the next step.
let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric);
check_search(&algs, &perm_data, &metric, &perm_ball, query, "off_ball", &perm_data);

// Last use of `ball` and `data`, so they are moved into `par_from_ball_tree`.
let (par_perm_ball, par_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric);
check_search(
&algs,
&par_perm_data,
&metric,
&par_perm_ball,
query,
"par_off_ball",
&par_perm_data,
);

Ok(())
}
// #[test_case(1_000, 10)]
// #[test_case(10_000, 10)]
// #[test_case(1_000, 100)]
// #[test_case(10_000, 100)]
// fn vectors(car: usize, dim: usize) {
// let mut algs: Algs<Vec<f32>, f32, usize> = vec![];
// for radius in [0.1, 1.0] {
// algs.push((Algorithm::RnnClustered(radius), check_search_by_index));
// }
// for k in [1, 10, 100] {
// algs.push((Algorithm::KnnRepeatedRnn(k, 2.0), check_search_by_distance));
// algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance));
// algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance));
// }

// let seed = 42;
// let data = gen_random_data(car, dim, 10.0, seed);
// let metric = Euclidean;
// let criteria = |c: &Ball<_>| c.cardinality() > 1;
// let seed = Some(seed);
// let query = &vec![0.0; dim];

// let ball = Ball::new_tree(&data, &metric, &criteria, seed);
// check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// let ball = Ball::par_new_tree(&data, &metric, &criteria, seed);
// check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// let (off_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric);
// check_search(&algs, &perm_data, &metric, &off_ball, query, "off_ball", &perm_data);

// let (par_off_ball, per_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric);
// check_search(
// &algs,
// &per_perm_data,
// &metric,
// &par_off_ball,
// query,
// "par_off_ball",
// &per_perm_data,
// );
// }

// #[test_case::test_case(16, 16, 2)]
// #[test_case::test_case(32, 16, 3)]
// fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> {
// let mut algs: Algs<String, u16, String> = vec![];
// for radius in [4, 12] {
// algs.push((Algorithm::RnnClustered(radius), check_search_by_index));
// }
// for k in [1, 10] {
// algs.push((Algorithm::KnnRepeatedRnn(k, 2), check_search_by_distance));
// algs.push((Algorithm::KnnBreadthFirst(k), check_search_by_distance));
// algs.push((Algorithm::KnnDepthFirst(k), check_search_by_distance));
// }

// let seed_length = 30;
// let alphabet = "ACTGN".chars().collect::<Vec<_>>();
// let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet);
// let penalties = distances::strings::Penalties::default();
// let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7);
// let len_delta = seed_length / 10;
// let (metadata, data) = symagen::random_edits::generate_clumped_data(
// &seed_string,
// penalties,
// &alphabet,
// num_clumps,
// clump_size,
// clump_radius,
// inter_clump_distance_range,
// len_delta,
// )
// .into_iter()
// .unzip::<_, _, Vec<_>, Vec<_>>();
// let query = &seed_string;

// let metric = Levenshtein;
// let data = FlatVec::new(data)?.with_metadata(&metadata)?;

// let criteria = |c: &Ball<_>| c.cardinality() > 1;
// let seed = Some(42);

// let ball = Ball::new_tree(&data, &metric, &criteria, seed);
// check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// let ball = Ball::par_new_tree(&data, &metric, &criteria, seed);
// check_search(&algs, &data, &metric, &ball, query, "ball", &data);

// let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric);
// check_search(&algs, &perm_data, &metric, &perm_ball, query, "off_ball", &perm_data);

// let (par_perm_ball, par_perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric);
// check_search(
// &algs,
// &par_perm_data,
// &metric,
// &par_perm_ball,
// query,
// "par_off_ball",
// &par_perm_data,
// );

// Ok(())
// }
}
Loading

0 comments on commit 16fca78

Please sign in to comment.