Skip to content

Commit

Permalink
Fix AVX and NEON L2 distance computation. (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Jan 28, 2023
1 parent 26f77cb commit 972ced7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ clap = { version = "4.1.1", features = ["derive"] }
criterion = { version = "0.4", features = ["async", "async_tokio"] }
pprof = { version = "0.11", features = ["flamegraph", "criterion"] }
tempfile = "3.3.0"
approx = "0.5.1"

[features]
cli = ["clap"]
Expand Down
70 changes: 63 additions & 7 deletions rust/src/utils/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ unsafe fn euclidean_distance_fma(from: &[f32], to: &[f32]) -> f32 {
sums = _mm256_fmadd_ps(sub, sub, sums);
}
// Shift and add vector, until only 1 value left.
for _ in 0..3 {
let shifted = _mm256_permute2f128_ps(sums, sums, 81);
sums = _mm256_add_ps(sums, shifted);
}
// sums = [x0-x7], shift = [x4-x7]
let mut shift = _mm256_permute2f128_ps(sums, sums, 1);
// [x0+x4, x1+x5, ..]
sums = _mm256_add_ps(sums, shift);
shift = _mm256_permute_ps(sums, 14);
sums = _mm256_add_ps(sums, shift);
sums = _mm256_hadd_ps(sums, sums);
let mut results: [f32; 8] = [0f32; 8];
_mm256_storeu_ps(results.as_mut_ptr(), sums);
results[7]
results[0]
}

/// Calculate L2 distance directly using Arrow compute kernels.
Expand Down Expand Up @@ -84,7 +87,7 @@ unsafe fn l2_distance_neon(from: &[f32], to: &[f32]) -> f32 {
let left = vld1q_f32(from.as_ptr().add(i));
let right = vld1q_f32(to.as_ptr().add(i));
let sub = vsubq_f32(left, right);
sum = vfmaq_laneq_f32(sum, sub, sub, 1);
sum = vfmaq_f32(sum, sub, sub);
}
vaddvq_f32(sum)
}
Expand Down Expand Up @@ -198,9 +201,10 @@ pub fn l2_distance(from: &Float32Array, to: &FixedSizeListArray) -> Result<Arc<F

#[cfg(test)]
mod tests {

use super::*;
use crate::arrow::FixedSizeListArrayExt;

use approx::assert_relative_eq;
use arrow_array::types::Float32Type;

#[test]
Expand Down Expand Up @@ -234,4 +238,56 @@ mod tests {

assert_eq!(scores.as_ref(), &Float32Array::from(vec![20.0]));
}

#[test]
fn test_l2_distance_cases() {
let values: Float32Array = vec![
0.25335717, 0.24663818, 0.26330215, 0.14988247, 0.06042378, 0.21077952, 0.26687378,
0.22145681, 0.18319066, 0.18688454, 0.05216244, 0.11470364, 0.10554603, 0.19964123,
0.06387895, 0.18992095, 0.00123718, 0.13500804, 0.09516747, 0.19508345, 0.2582458,
0.1211653, 0.21121833, 0.24809816, 0.04078768, 0.19586588, 0.16496408, 0.14766085,
0.04898421, 0.14728612, 0.21263947, 0.16763233,
]
.into();
let vectors = FixedSizeListArray::try_new(values, 32).unwrap();

let q: Float32Array = vec![
0.18549609,
0.29954708,
0.28318876,
0.05424477,
0.093134984,
0.21580857,
0.2951282,
0.19866848,
0.13868214,
0.19819534,
0.23271298,
0.047727287,
0.14394054,
0.023316395,
0.18589257,
0.037315924,
0.07037327,
0.32609823,
0.07344752,
0.020155912,
0.18485495,
0.32763934,
0.14296658,
0.04498596,
0.06254237,
0.24348071,
0.16009757,
0.053892266,
0.05918874,
0.040363103,
0.19913352,
0.14545348,
]
.into();

let d = l2_distance(&q, &vectors).unwrap();
assert_relative_eq!(0.31935785197341404, d.value(0));
}
}

0 comments on commit 972ced7

Please sign in to comment.