From 08e456825d30258be38262caff0d88f70a394d07 Mon Sep 17 00:00:00 2001 From: Najib Ishaq Date: Mon, 2 Dec 2024 13:15:48 -0500 Subject: [PATCH] wip: adding dataset extension with search hints --- crates/abd-clam/src/cakes/dataset/data.rs | 217 +++++++++++++++++++ crates/abd-clam/src/cakes/dataset/mod.rs | 6 + crates/abd-clam/src/cakes/mod.rs | 2 + crates/abd-clam/src/cakes/search/mod.rs | 13 ++ crates/abd-clam/src/core/cluster/mod.rs | 16 ++ crates/abd-clam/src/core/dataset/flat_vec.rs | 1 - 6 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 crates/abd-clam/src/cakes/dataset/data.rs create mode 100644 crates/abd-clam/src/cakes/dataset/mod.rs diff --git a/crates/abd-clam/src/cakes/dataset/data.rs b/crates/abd-clam/src/cakes/dataset/data.rs new file mode 100644 index 00000000..6bd17ec0 --- /dev/null +++ b/crates/abd-clam/src/cakes/dataset/data.rs @@ -0,0 +1,217 @@ +//! A wrapper around any `Dataset` type to provide a dataset for storing +//! search hints. + +use std::collections::HashMap; + +use distances::Number; +use serde::{Deserialize, Serialize}; + +use crate::{ + cakes::{Algorithm, ParSearchable, Searchable}, + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDataset, Permutable}, + metric::ParMetric, + Cluster, Dataset, Metric, +}; + +/// A dataset which stores search hints for each item. +#[derive(Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "disk-io", derive(bitcode::Encode, bitcode::Decode))] +pub struct HintedDataset> { + /// The underlying dataset. + data: D, + /// The search hints. + hints: Vec>, + /// The name of the dataset. + name: String, + /// Ghosts in the machine. 
+ i_phantom: std::marker::PhantomData, +} + +impl> From for HintedDataset { + fn from(data: D) -> Self { + let hints = (0..data.cardinality()).map(|_| HashMap::new()).collect(); + let name = format!("Hinted({})", data.name()); + Self { + data, + hints, + name, + i_phantom: std::marker::PhantomData, + } + } +} + +impl> HintedDataset { + /// Get the search hints for the dataset. + pub fn hints(&self) -> &[HashMap] { + &self.hints + } + + /// Get the search hints for the dataset as mutable. + pub fn hints_mut(&mut self) -> &mut [HashMap] { + &mut self.hints + } + + /// Get the search hints for a specific item by index. + pub fn hints_for(&self, i: usize) -> &HashMap { + &self.hints[i] + } + + /// Get the search hints for a specific item by index as mutable. + pub fn hints_for_mut(&mut self, i: usize) -> &mut HashMap { + &mut self.hints[i] + } + + /// Add a hint for a specific item. + /// + /// # Arguments + /// + /// * `i` - The index of the item. + /// * `k` - The number of known neighbors. + /// * `d` - The distance to the farthest known neighbor. + pub fn add_hint(&mut self, i: usize, k: usize, d: T) { + self.hints_for_mut(i).insert(k, d); + } + + /// Add hints from a tree. + /// + /// For each cluster in the tree, this will add hints for cluster centers + /// using the cluster radius and cardinality. + pub fn add_from_tree>(&mut self, root: &C) { + self.add_hint(root.arg_center(), root.cardinality(), root.radius()); + if !root.is_leaf() { + root.child_clusters().for_each(|c| self.add_from_tree(c)); + } + } + + /// Add hints using a search algorithm. + /// + /// # Arguments + /// + /// * `metric` - The metric to use for the search. + /// * `root` - The root of the search tree. + /// * `alg` - The search algorithm to use. + /// * `i` - The index of the query item. 
+ pub fn add_by_search, C: Searchable>( + &mut self, + metric: &M, + root: &C, + alg: &Algorithm, + i: usize, + ) { + let (k, r) = alg.params(); + let (car, d) = { + let mut hits = alg.search(self, metric, root, self.get(i)); + hits.sort_unstable_by(|(_, a), (_, b)| a.total_cmp(b)); + hits.last().map_or((0, T::ZERO), |&(j, d)| (j, d)) + }; + if let Some(k) = k { + self.add_hint(i, k, d); + } else if let Some(r) = r { + self.add_hint(i, car, r); + } + } +} + +impl> HintedDataset { + /// Parallel version of `HintedDataset::add_by_search`. + pub fn par_add_by_search, C: ParSearchable>( + &mut self, + metric: &M, + root: &C, + alg: &Algorithm, + i: usize, + ) where + Self: Sized, + { + let (k, r) = alg.params(); + let (car, d) = { + let mut hits = alg.par_search(self, metric, root, self.get(i)); + hits.sort_unstable_by(|(_, a), (_, b)| a.total_cmp(b)); + hits.last().map_or((0, T::ZERO), |&(j, d)| (j, d)) + }; + if let Some(k) = k { + self.add_hint(i, k, d); + } else if let Some(r) = r { + self.add_hint(i, car, r); + } + } +} + +impl> Dataset for HintedDataset { + fn name(&self) -> &str { + &self.name + } + + fn with_name(mut self, name: &str) -> Self { + self.name = name.to_string(); + self + } + + fn cardinality(&self) -> usize { + self.data.cardinality() + } + + fn dimensionality_hint(&self) -> (usize, Option) { + self.data.dimensionality_hint() + } + + fn get(&self, index: usize) -> &I { + self.data.get(index) + } +} + +impl> ParDataset for HintedDataset {} + +impl> AssociatesMetadata for HintedDataset { + fn metadata(&self) -> &[Me] { + self.data.metadata() + } + + fn metadata_at(&self, index: usize) -> &Me { + self.data.metadata_at(index) + } +} + +impl, D: AssociatesMetadataMut> + AssociatesMetadataMut for HintedDataset +{ + fn metadata_mut(&mut self) -> &mut [Me] { + self.data.metadata_mut() + } + + fn metadata_at_mut(&mut self, index: usize) -> &mut Me { + self.data.metadata_at_mut(index) + } + + fn with_metadata(self, metadata: &[Met]) -> Result { + 
+ self.data.with_metadata(metadata) + } + + fn transform_metadata Met>(self, f: F) -> Det { + self.data.transform_metadata(f) + } +} + +impl + Permutable> Permutable for HintedDataset { + fn permutation(&self) -> Vec { + self.data.permutation() + } + + fn set_permutation(&mut self, permutation: &[usize]) { + self.data.set_permutation(permutation); + } + + fn swap_two(&mut self, i: usize, j: usize) { + self.data.swap_two(i, j); + self.hints.swap(i, j); + } +} + +#[cfg(feature = "disk-io")] +impl> crate::dataset::DatasetIO for HintedDataset {} + +#[cfg(feature = "disk-io")] +impl> crate::dataset::ParDatasetIO + for HintedDataset +{ +} diff --git a/crates/abd-clam/src/cakes/dataset/mod.rs b/crates/abd-clam/src/cakes/dataset/mod.rs new file mode 100644 index 00000000..752d7028 --- /dev/null +++ b/crates/abd-clam/src/cakes/dataset/mod.rs @@ -0,0 +1,6 @@ +//! `Dataset`s which store extra information to improve search performance. + +mod data; + +#[allow(clippy::module_name_repetitions)] +pub use data::HintedDataset; diff --git a/crates/abd-clam/src/cakes/mod.rs b/crates/abd-clam/src/cakes/mod.rs index 94e2bd6a..38574be8 100644 --- a/crates/abd-clam/src/cakes/mod.rs +++ b/crates/abd-clam/src/cakes/mod.rs @@ -1,9 +1,11 @@ //! Entropy Scaling Search mod cluster; +mod dataset; mod search; pub use cluster::{ParSearchable, PermutedBall, Searchable}; +pub use dataset::HintedDataset; pub use search::Algorithm; /// Tests for the `cakes` module. diff --git a/crates/abd-clam/src/cakes/search/mod.rs b/crates/abd-clam/src/cakes/search/mod.rs index 8ab89de1..f4254958 100644 --- a/crates/abd-clam/src/cakes/search/mod.rs +++ b/crates/abd-clam/src/cakes/search/mod.rs @@ -181,6 +181,19 @@ impl Algorithm { } } + /// Retrieve the `(k, radius)` parameters of the algorithm. + /// + /// Exactly one of the two values will be `None`. 
+ #[must_use] + pub const fn params(&self) -> (Option, Option) { + match self { + Self::RnnLinear(r) | Self::RnnClustered(r) => (None, Some(*r)), + Self::KnnLinear(k) | Self::KnnRepeatedRnn(k, _) | Self::KnnBreadthFirst(k) | Self::KnnDepthFirst(k) => { + (Some(*k), None) + } + } + } + /// Same variant of the algorithm with different parameters. #[must_use] pub const fn with_params(&self, radius: T, k: usize) -> Self { diff --git a/crates/abd-clam/src/core/cluster/mod.rs b/crates/abd-clam/src/core/cluster/mod.rs index 1193f373..e3dff48f 100644 --- a/crates/abd-clam/src/core/cluster/mod.rs +++ b/crates/abd-clam/src/core/cluster/mod.rs @@ -240,6 +240,22 @@ pub trait ParCluster: Cluster + Send + Sync { /// Parallel version of `Cluster::indices`. fn par_indices(&self) -> impl ParallelIterator; + /// Parallel version of `Cluster::child_clusters`. + fn par_child_clusters<'a>(&'a self) -> impl ParallelIterator + where + T: 'a, + { + self.children().par_iter().map(|(_, _, child)| child.as_ref()) + } + + /// Parallel version of `Cluster::child_clusters_mut`. + fn par_child_clusters_mut<'a>(&'a mut self) -> impl ParallelIterator + where + T: 'a, + { + self.children_mut().par_iter_mut().map(|(_, _, child)| child.as_mut()) + } + /// Parallel version of `Cluster::distance_to_other`. fn par_distance_to_other, M: ParMetric>( &self, diff --git a/crates/abd-clam/src/core/dataset/flat_vec.rs b/crates/abd-clam/src/core/dataset/flat_vec.rs index 91c38da8..be8ac459 100644 --- a/crates/abd-clam/src/core/dataset/flat_vec.rs +++ b/crates/abd-clam/src/core/dataset/flat_vec.rs @@ -12,7 +12,6 @@ use super::{AssociatesMetadata, AssociatesMetadataMut, Dataset, ParDataset, Perm /// - `Me`: The type of the metadata associated with the items. #[derive(Clone, Serialize, Deserialize)] #[cfg_attr(feature = "disk-io", derive(bitcode::Encode, bitcode::Decode))] -#[cfg_attr(feature = "disk-io", bitcode(recursive))] pub struct FlatVec { /// The items in the dataset. items: Vec,