Skip to content

Commit

Permalink
wip: adding dataset extension with search hints
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Dec 2, 2024
1 parent 6b11ed8 commit 08e4568
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 1 deletion.
217 changes: 217 additions & 0 deletions crates/abd-clam/src/cakes/dataset/data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
//! A wrapper around any `Dataset` type to provide a dataset for storing
//! search hints.
use std::collections::HashMap;

use distances::Number;
use serde::{Deserialize, Serialize};

use crate::{
cakes::{Algorithm, ParSearchable, Searchable},
dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDataset, Permutable},
metric::ParMetric,
Cluster, Dataset, Metric,
};

/// A dataset which stores search hints for each item.
#[derive(Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "disk-io", derive(bitcode::Encode, bitcode::Decode))]
pub struct HintedDataset<I, T: Number, D: Dataset<I>> {
/// The underlying dataset.
data: D,
/// The search hints.
hints: Vec<HashMap<usize, T>>,
/// The name of the dataset.
name: String,
/// Ghosts in the machine.
i_phantom: std::marker::PhantomData<I>,
}

impl<I, T: Number, D: Dataset<I>> From<D> for HintedDataset<I, T, D> {
fn from(data: D) -> Self {
let hints = (0..data.cardinality()).map(|_| HashMap::new()).collect();
let name = format!("Hinted({})", data.name());
Self {
data,
hints,
name,
i_phantom: std::marker::PhantomData,
}
}
}

impl<I, T: Number, D: Dataset<I>> HintedDataset<I, T, D> {
/// Get the search hints for the dataset.
pub fn hints(&self) -> &[HashMap<usize, T>] {
&self.hints
}

/// Get the search hints for the dataset as mutable.
pub fn hints_mut(&mut self) -> &mut [HashMap<usize, T>] {
&mut self.hints
}

/// Get the search hints for a specific item by index.
pub fn hints_for(&self, i: usize) -> &HashMap<usize, T> {
&self.hints[i]
}

/// Get the search hints for a specific item by index as mutable.
pub fn hints_for_mut(&mut self, i: usize) -> &mut HashMap<usize, T> {
&mut self.hints[i]
}

/// Add a hint for a specific item.
///
/// # Arguments
///
/// * `i` - The index of the item.
/// * `k` - The number of known neighbors.
/// * `d` - The distance to the farthest known neighbor.
pub fn add_hint(&mut self, i: usize, k: usize, d: T) {
self.hints_for_mut(i).insert(k, d);
}

/// Add hints from a tree.
///
/// For each cluster in the tree, this will add hints for cluster centers
/// using the cluster radius and cardinality.
pub fn add_from_tree<C: Cluster<T>>(&mut self, root: &C) {
self.add_hint(root.arg_center(), root.cardinality(), root.radius());
if !root.is_leaf() {
root.child_clusters().for_each(|c| self.add_from_tree(c));
}
}

/// Add hints using a search algorithm.
///
/// # Arguments
///
/// * `metric` - The metric to use for the search.
/// * `root` - The root of the search tree.
/// * `alg` - The search algorithm to use.
/// * `q` - The index of the query item.
pub fn add_by_search<M: Metric<I, T>, C: Searchable<I, T, Self, M>>(
&mut self,
metric: &M,
root: &C,
alg: &Algorithm<T>,
i: usize,
) {
let (k, r) = alg.params();
let (car, d) = {
let mut hits = alg.search(self, metric, root, self.get(i));
hits.sort_unstable_by(|(_, a), (_, b)| a.total_cmp(b));
hits.last().map_or((0, T::ZERO), |&(j, d)| (j, d))
};
if let Some(k) = k {
self.add_hint(i, k, d);
} else if let Some(r) = r {
self.add_hint(i, car, r);
}
}
}

impl<I: Send + Sync, T: Number, D: ParDataset<I>> HintedDataset<I, T, D> {
/// Parallel version of `HintedDataset::add_by_search`.
pub fn par_add_by_search<M: ParMetric<I, T>, C: ParSearchable<I, T, Self, M>>(
&mut self,
metric: &M,
root: &C,
alg: &Algorithm<T>,
i: usize,
) where
Self: Sized,
{
let (k, r) = alg.params();
let (car, d) = {
let mut hits = alg.par_search(self, metric, root, self.get(i));
hits.sort_unstable_by(|(_, a), (_, b)| a.total_cmp(b));
hits.last().map_or((0, T::ZERO), |&(j, d)| (j, d))
};
if let Some(k) = k {
self.add_hint(i, k, d);
} else if let Some(r) = r {
self.add_hint(i, car, r);
}
}
}

impl<I, T: Number, D: Dataset<I>> Dataset<I> for HintedDataset<I, T, D> {
fn name(&self) -> &str {
&self.name
}

fn with_name(mut self, name: &str) -> Self {
self.name = name.to_string();
self
}

fn cardinality(&self) -> usize {
self.data.cardinality()
}

fn dimensionality_hint(&self) -> (usize, Option<usize>) {
self.data.dimensionality_hint()
}

fn get(&self, index: usize) -> &I {
self.data.get(index)
}
}

impl<I: Send + Sync, T: Number, D: ParDataset<I>> ParDataset<I> for HintedDataset<I, T, D> {}

impl<I, T: Number, Me, D: AssociatesMetadata<I, Me>> AssociatesMetadata<I, Me> for HintedDataset<I, T, D> {
fn metadata(&self) -> &[Me] {
self.data.metadata()
}

fn metadata_at(&self, index: usize) -> &Me {
self.data.metadata_at(index)
}
}

impl<I, T: Number, Me, Met: Clone, Det: AssociatesMetadata<I, Met>, D: AssociatesMetadataMut<I, Me, Met, Det>>
AssociatesMetadataMut<I, Me, Met, Det> for HintedDataset<I, T, D>
{
fn metadata_mut(&mut self) -> &mut [Me] {
self.data.metadata_mut()
}

fn metadata_at_mut(&mut self, index: usize) -> &mut Me {
self.data.metadata_at_mut(index)
}

fn with_metadata(self, metadata: &[Met]) -> Result<Det, String> {
self.data.with_metadata(metadata)
}

fn transform_metadata<F: Fn(&Me) -> Met>(self, f: F) -> Det {
self.data.transform_metadata(f)
}
}

impl<I, T: Number, D: Dataset<I> + Permutable> Permutable for HintedDataset<I, T, D> {
fn permutation(&self) -> Vec<usize> {
self.data.permutation()
}

fn set_permutation(&mut self, permutation: &[usize]) {
self.data.set_permutation(permutation);
}

fn swap_two(&mut self, i: usize, j: usize) {
self.data.swap_two(i, j);
self.hints.swap(i, j);
}
}

#[cfg(feature = "disk-io")]
impl<I, T: Number, D: crate::dataset::DatasetIO<I>> crate::dataset::DatasetIO<I> for HintedDataset<I, T, D> {}

#[cfg(feature = "disk-io")]
impl<I: Send + Sync, T: Number, D: crate::dataset::ParDatasetIO<I>> crate::dataset::ParDatasetIO<I>
for HintedDataset<I, T, D>
{
}
6 changes: 6 additions & 0 deletions crates/abd-clam/src/cakes/dataset/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! `Dataset`s which store extra information to improve search performance.
mod data;

#[allow(clippy::module_name_repetitions)]
pub use data::HintedDataset;
2 changes: 2 additions & 0 deletions crates/abd-clam/src/cakes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Entropy Scaling Search
mod cluster;
mod dataset;
mod search;

pub use cluster::{ParSearchable, PermutedBall, Searchable};
pub use dataset::HintedDataset;
pub use search::Algorithm;

/// Tests for the `cakes` module.
Expand Down
13 changes: 13 additions & 0 deletions crates/abd-clam/src/cakes/search/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ impl<T: Number> Algorithm<T> {
}
}

/// Retrieve the parameters of the algorithm.
///
/// Exactly one of the two values will be `None`.
#[must_use]
pub const fn params(&self) -> (Option<usize>, Option<T>) {
match self {
Self::RnnLinear(r) | Self::RnnClustered(r) => (None, Some(*r)),
Self::KnnLinear(k) | Self::KnnRepeatedRnn(k, _) | Self::KnnBreadthFirst(k) | Self::KnnDepthFirst(k) => {
(Some(*k), None)
}
}
}

/// Same variant of the algorithm with different parameters.
#[must_use]
pub const fn with_params(&self, radius: T, k: usize) -> Self {
Expand Down
16 changes: 16 additions & 0 deletions crates/abd-clam/src/core/cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,22 @@ pub trait ParCluster<T: Number>: Cluster<T> + Send + Sync {
/// Parallel version of `Cluster::indices`.
fn par_indices(&self) -> impl ParallelIterator<Item = usize>;

/// Parallel version of `Cluster::child_clusters`.
fn par_child_clusters<'a>(&'a self) -> impl ParallelIterator<Item = &'a Self>
where
T: 'a,
{
self.children().par_iter().map(|(_, _, child)| child.as_ref())
}

/// Parallel version of `Cluster::child_clusters_mut`.
fn par_child_clusters_mut<'a>(&'a mut self) -> impl ParallelIterator<Item = &'a mut Self>
where
T: 'a,
{
self.children_mut().par_iter_mut().map(|(_, _, child)| child.as_mut())
}

/// Parallel version of `Cluster::distance_to_other`.
fn par_distance_to_other<I: Send + Sync, D: ParDataset<I>, M: ParMetric<I, T>>(
&self,
Expand Down
1 change: 0 additions & 1 deletion crates/abd-clam/src/core/dataset/flat_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use super::{AssociatesMetadata, AssociatesMetadataMut, Dataset, ParDataset, Perm
/// - `Me`: The type of the metadata associated with the items.
#[derive(Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "disk-io", derive(bitcode::Encode, bitcode::Decode))]
#[cfg_attr(feature = "disk-io", bitcode(recursive))]
pub struct FlatVec<I, Me> {
/// The items in the dataset.
items: Vec<I>,
Expand Down

0 comments on commit 08e4568

Please sign in to comment.