diff --git a/CHANGELOG.md b/CHANGELOG.md index 6beb5bd..6de7aa3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ file is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- `HDbscan` now accepts `f32`, in addition to `f64`, as the element type of the + input matrix. + ## [0.9.0] - 2024-08-09 ### Changed @@ -111,6 +118,7 @@ Versioning](https://semver.org/spec/v2.0.0.html). - The [OPTICS](https://en.wikipedia.org/wiki/OPTICS_algorithm) clustering algorithm. +[Unreleased]: https://github.com/petabi/petal-clustering/compare/0.9.0...main [0.9.0]: https://github.com/petabi/petal-clustering/compare/0.8.0...0.9.0 [0.8.0]: https://github.com/petabi/petal-clustering/compare/0.7.0...0.8.0 [0.7.0]: https://github.com/petabi/petal-clustering/compare/0.6.0...0.7.0 diff --git a/src/hdbscan.rs b/src/hdbscan.rs index 524d215..2100922 100644 --- a/src/hdbscan.rs +++ b/src/hdbscan.rs @@ -47,8 +47,7 @@ where impl Fit, (HashMap>, Vec)> for HDbscan where - A: AddAssign + DivAssign + FloatCore + FromPrimitive + Sync + Send + TryFrom, - >::Error: Debug, + A: AddAssign + DivAssign + FloatCore + FromPrimitive + Sync + Send, S: Data, M: Metric + Clone + Sync + Send, { @@ -276,12 +275,9 @@ fn condense_mst( result } -fn get_stability>( +fn get_stability( condensed_tree: &ArrayView1<(usize, usize, A, usize)>, -) -> HashMap -where - >::Error: Debug, -{ +) -> HashMap { let mut births: HashMap<_, _> = condensed_tree.iter().fold(HashMap::new(), |mut births, v| { let entry = births.entry(v.1).or_insert(v.2); if *entry > v.2 { @@ -304,19 +300,15 @@ where |mut stability, (parent, _child, lambda, size)| { let entry = stability.entry(*parent).or_insert_with(A::zero); let birth = births.get(parent).expect("invalid child node."); - *entry += (*lambda - *birth) - * A::try_from(u32::try_from(*size).expect("out of bound")).expect("out of bound"); + *entry += 
(*lambda - *birth) * A::from_usize(*size).expect("infallible"); stability }, ) } -fn find_clusters>( +fn find_clusters( condensed_tree: &ArrayView1<(usize, usize, A, usize)>, -) -> (HashMap>, Vec) -where - >::Error: Debug, -{ +) -> (HashMap>, Vec) { let mut stability = get_stability(condensed_tree); let mut nodes: Vec<_> = stability.keys().copied().collect(); nodes.sort_unstable(); @@ -918,15 +910,45 @@ impl Components { } mod test { + #[test] + fn hdbscan32() { + use ndarray::{array, Array2}; + use petal_neighbors::distance::Euclidean; + + use crate::Fit; + + let data: Array2 = array![ + [1.0, 2.0], + [1.1, 2.2], + [0.9, 1.9], + [1.0, 2.1], + [-2.0, 3.0], + [-2.2, 3.1], + ]; + let mut hdbscan = super::HDbscan { + eps: 0.5, + alpha: 1., + min_samples: 2, + min_cluster_size: 2, + metric: Euclidean::default(), + boruvka: false, + }; + let (clusters, outliers) = hdbscan.fit(&data); + assert_eq!(clusters.len(), 2); + assert_eq!( + outliers.len(), + data.nrows() - clusters.values().fold(0, |acc, v| acc + v.len()) + ); + } #[test] - fn hdbscan() { - use ndarray::array; + fn hdbscan64() { + use ndarray::{array, Array2}; use petal_neighbors::distance::Euclidean; use crate::Fit; - let data = array![ + let data: Array2 = array![ [1.0, 2.0], [1.1, 2.2], [0.9, 1.9],