Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Calculate, Hamerly impls for [T; N] #56

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "kmeans_colors"
version = "0.6.0"
authors = ["okaneco <47607823+okaneco@users.noreply.github.com>"]
edition = "2018"
exclude = ["test", "gfx", ".github"]
exclude = ["test", "gfx", ".github", "examples"]
homepage = "https://github.com/okaneco/kmeans-colors"
repository = "https://github.com/okaneco/kmeans-colors"
readme = "README.md"
Expand All @@ -25,7 +25,7 @@ app = [
]

# Enable `palette` color types
palette_color = ["palette", "num-traits"]
palette_color = ["palette"]

[dependencies.image]
version = "0.24.6"
Expand All @@ -43,7 +43,6 @@ optional = true
version = "0.2.16"
default-features = false
features = ["std"]
optional = true

[dependencies.rand]
version = "0.8.5"
Expand Down
3 changes: 1 addition & 2 deletions src/colors/kmeans.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#[cfg(feature = "palette_color")]
use num_traits::{Float, FromPrimitive, Zero};
#[cfg(feature = "palette_color")]
use palette::{rgb::Rgb, rgb::Rgba, Lab};

use num_traits::{Float, FromPrimitive, Zero};
use rand::Rng;

use crate::kmeans::{Calculate, Hamerly, HamerlyCentroids, HamerlyPoint};
Expand Down
2 changes: 2 additions & 0 deletions src/kmeans.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use rand::{Rng, SeedableRng};

mod array;

/// A trait for enabling k-means calculation of a data type.
pub trait Calculate: Sized {
/// Find a points's nearest centroid, index the point with that centroid.
Expand Down
216 changes: 216 additions & 0 deletions src/kmeans/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
use core::ops::{AddAssign, DivAssign};

use num_traits::{Float, FromPrimitive};

use crate::{Calculate, Hamerly};

impl<T, const N: usize> Calculate for [T; N]
where
T: Float + FromPrimitive + AddAssign + DivAssign + Default,
[T; N]: Default,
{
fn get_closest_centroid(buffer: &[Self], centroids: &[Self], indices: &mut Vec<u8>) {
indices.extend(buffer.iter().map(|color| {
let mut index = 0;
let mut diff;
let mut min = core::f32::MAX;
for (idx, cent) in centroids.iter().enumerate() {
diff = Self::difference(color, cent);
if diff < min {
min = diff;
index = idx;
}
}
index as u8
}))
}

fn recalculate_centroids(
rng: &mut impl rand::Rng,
buf: &[Self],
centroids: &mut [Self],
indices: &[u8],
) {
for (idx, cent) in centroids.iter_mut().enumerate() {
let mut acc = Self::default();
let mut counter: u64 = 0;
for (&jdx, &color) in indices.iter().zip(buf) {
if jdx as usize == idx {
acc.iter_mut().zip(color.iter()).for_each(|(t, &c)| *t += c);
counter += 1;
}
}
if counter != 0 {
let counter_float = T::from_f64(counter as f64).unwrap();
acc.iter_mut().for_each(|t| *t /= counter_float);
*cent = acc;
} else {
*cent = Self::create_random(rng);
}
}
}

fn check_loop(centroids: &[Self], old_centroids: &[Self]) -> f32 {
let mut acc = Self::default();
for (new_cent, old_cent) in centroids.iter().zip(old_centroids) {
acc.iter_mut()
.zip(new_cent.iter())
.zip(old_cent.iter())
.for_each(|((t, &new), &old)| *t += new - old);
}

acc.iter()
.fold(T::default(), |accum, t| accum + t.powi(2))
.to_f32()
.unwrap_or(f32::MAX)
}

// 2023-08 TODO: create_random should take a min and max from a builder object
fn create_random(rng: &mut impl rand::Rng) -> Self {
let mut out = Self::default();
out.iter_mut()
.for_each(|o| *o = T::from_f64(rng.gen_range(0.0..=1.0)).unwrap());
out
}

fn difference(c1: &Self, c2: &Self) -> f32 {
c1.iter()
.zip(c2.iter())
.fold(T::default(), |acc, (&l, &r)| acc + (l - r).powi(2))
.to_f32()
.unwrap_or(f32::MAX)
}
}

impl<T, const N: usize> Hamerly for [T; N]
where
T: Float + FromPrimitive + AddAssign + DivAssign + Default,
[T; N]: Default,
{
fn compute_half_distances(centroids: &mut crate::HamerlyCentroids<Self>) {
// Find each center's closest center
for ((i, ci), half_dist) in centroids
.centroids
.iter()
.enumerate()
.zip(centroids.half_distances.iter_mut())
{
let mut diff;
let mut min = f32::MAX;
for (j, cj) in centroids.centroids.iter().enumerate() {
// Don't compare centroid to itself
if i == j {
continue;
}
diff = Self::difference(ci, cj);
if diff < min {
min = diff;
}
}
*half_dist = min.sqrt() * 0.5;
}
}

fn get_closest_centroid_hamerly(
buffer: &[Self],
centroids: &crate::HamerlyCentroids<Self>,
indices: &mut [crate::HamerlyPoint],
) {
for (val, point) in buffer.iter().zip(indices.iter_mut()) {
// Assign max of lower bound and half distance to z
let z = centroids
.half_distances
.get(point.index as usize)
.unwrap()
.max(point.lower_bound);

if point.upper_bound <= z {
continue;
}

// Tighten upper bound
point.upper_bound =
Self::difference(val, centroids.centroids.get(point.index as usize).unwrap())
.sqrt();

if point.upper_bound <= z {
continue;
}

// Find the two closest centers to current point and their distances
if centroids.centroids.len() < 2 {
continue;
}

let mut min1 = Self::difference(val, centroids.centroids.get(0).unwrap());
let mut min2 = f32::MAX;
let mut c1 = 0;
for j in 1..centroids.centroids.len() {
let diff = Self::difference(val, centroids.centroids.get(j).unwrap());
if diff < min1 {
min2 = min1;
min1 = diff;
c1 = j;
continue;
}
if diff < min2 {
min2 = diff;
}
}

if c1 as u8 != point.index {
point.index = c1 as u8;
point.upper_bound = min1.sqrt();
}
point.lower_bound = min2.sqrt();
}
}

fn recalculate_centroids_hamerly(
rng: &mut impl rand::Rng,
buf: &[Self],
centroids: &mut crate::HamerlyCentroids<Self>,
points: &[crate::HamerlyPoint],
) {
for ((idx, cent), delta) in centroids
.centroids
.iter_mut()
.enumerate()
.zip(centroids.deltas.iter_mut())
{
let mut acc = Self::default();
let mut counter: u64 = 0;
for (point, &color) in points.iter().zip(buf) {
if point.index as usize == idx {
acc.iter_mut().zip(color.iter()).for_each(|(t, &c)| *t += c);
counter += 1;
}
}
if counter != 0 {
let counter_float = T::from_f64(counter as f64).unwrap();
acc.iter_mut().for_each(|t| *t /= counter_float);
} else {
acc = Self::create_random(rng);
}
*delta = Self::difference(cent, &acc).sqrt();
*cent = acc;
}
}

fn update_bounds(
centroids: &crate::HamerlyCentroids<Self>,
points: &mut [crate::HamerlyPoint],
) {
let mut delta_p = 0.0;
for &c in centroids.deltas.iter() {
if c > delta_p {
delta_p = c;
}
}

points.iter_mut().for_each(|point| {
point.upper_bound += centroids.deltas.get(point.index as usize).unwrap();
point.lower_bound -= delta_p;
})
}
}