Skip to content

Commit

Permalink
Merge branch 'main-dev' of https://github.com/ashvardanian/SimSIMD in…
Browse files Browse the repository at this point in the history
…to main-dev
  • Loading branch information
ashvardanian committed Nov 26, 2024
2 parents 05e29b3 + cd15779 commit c6ff9ea
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 16 deletions.
213 changes: 202 additions & 11 deletions include/simsimd/curved.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,6 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const *a, simsimd_f3
*result = sum;
}

// Placeholder definition: the body is empty, so nothing is computed and `result` is never written.
SIMSIMD_PUBLIC void simsimd_bilinear_f32c_neon(simsimd_f32c_t const *a, simsimd_f32c_t const *b,
simsimd_f32c_t const *c, simsimd_size_t n, simsimd_distance_t *result) {}

SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c,
simsimd_size_t n, simsimd_distance_t *result) {
float32x4_t sum_vec = vdupq_n_f32(0);
Expand Down Expand Up @@ -285,6 +282,58 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const *a, simsimd
*result = _simsimd_sqrt_f64_neon(sum);
}

/**
 *  @brief  Complex bilinear form over single-precision pairs, using NEON.
 *
 *  Accumulates `a[i] * c[i * n + j] * b[j]` (complex products) over all `i` and `j` in `[0, n)`.
 *  Columns are consumed four at a time with de-interleaving loads; the `n % 4` leftover
 *  columns are added afterwards in a separate scalar sweep over every row.
 *
 *  @param a        First vector, read at indices `[0, n)`.
 *  @param b        Second vector, read at indices `[0, n)`.
 *  @param c        Matrix, read at index `i * n + j`.
 *  @param n        Number of complex entries in `a` and `b`.
 *  @param results  Output pair: `results[0]` gets the real part, `results[1]` the imaginary part.
 */
SIMSIMD_PUBLIC void simsimd_bilinear_f32c_neon(simsimd_f32c_t const *a, simsimd_f32c_t const *b,
                                               simsimd_f32c_t const *c, simsimd_size_t n, simsimd_distance_t *results) {
    simsimd_size_t const remainder = n % 4;
    simsimd_size_t const vectorized_end = n - remainder;
    simsimd_f32_t total_real = 0;
    simsimd_f32_t total_imag = 0;

    // First sweep: for each row, the part of the row-by-`b` product covered by full 4-wide steps.
    for (simsimd_size_t row = 0; row != n; ++row) {
        simsimd_f32c_t const *c_row = c + row * n;
        float32x4_t acc_real = vdupq_n_f32(0);
        float32x4_t acc_imag = vdupq_n_f32(0);
        for (simsimd_size_t col = 0; col < vectorized_end; col += 4) {
            // `vld2q_f32` splits interleaved pairs: `val[0]` holds real parts, `val[1]` imaginary parts.
            float32x4x2_t b_pairs = vld2q_f32((simsimd_f32_t const *)(b + col));
            float32x4x2_t c_pairs = vld2q_f32((simsimd_f32_t const *)(c_row + col));

            // Complex multiply-accumulate: real += cr*br - ci*bi, imag += cr*bi + ci*br.
            acc_real = vfmaq_f32(acc_real, c_pairs.val[0], b_pairs.val[0]);
            acc_real = vfmsq_f32(acc_real, c_pairs.val[1], b_pairs.val[1]);
            acc_imag = vfmaq_f32(acc_imag, c_pairs.val[0], b_pairs.val[1]);
            acc_imag = vfmaq_f32(acc_imag, c_pairs.val[1], b_pairs.val[0]);
        }
        simsimd_f32c_t a_row = a[row];
        simsimd_f32c_t row_dot;
        row_dot.real = vaddvq_f32(acc_real);
        row_dot.imag = vaddvq_f32(acc_imag);
        total_real += a_row.real * row_dot.real - a_row.imag * row_dot.imag;
        total_imag += a_row.real * row_dot.imag + a_row.imag * row_dot.real;
    }

    // Second sweep: scalar handling of the last `n % 4` columns of every row.
    if (remainder != 0) {
        for (simsimd_size_t row = 0; row != n; ++row) {
            simsimd_f32c_t a_row = a[row];
            simsimd_f32c_t row_dot = {0, 0};
            for (simsimd_size_t col = vectorized_end; col != n; ++col) {
                simsimd_f32c_t b_col = b[col];
                simsimd_f32c_t c_cell = c[row * n + col];
                row_dot.real += b_col.real * c_cell.real - b_col.imag * c_cell.imag;
                row_dot.imag += b_col.real * c_cell.imag + b_col.imag * c_cell.real;
            }
            total_real += a_row.real * row_dot.real - a_row.imag * row_dot.imag;
            total_imag += a_row.real * row_dot.imag + a_row.imag * row_dot.real;
        }
    }

    results[0] = total_real;
    results[1] = total_imag;
}

#pragma clang attribute pop
#pragma GCC pop_options
#endif // SIMSIMD_TARGET_NEON
Expand Down Expand Up @@ -326,10 +375,6 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const *a, simsimd_f1
*result = sum;
}

// Placeholder definition: the body is empty, so nothing is computed and `results` is never written.
SIMSIMD_PUBLIC void simsimd_bilinear_f16c_neon(simsimd_f16c_t const *a, simsimd_f16c_t const *b,
simsimd_f16c_t const *c, simsimd_size_t n, simsimd_distance_t *results) {
}

SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c,
simsimd_size_t n, simsimd_distance_t *result) {
float32x4_t sum_vec = vdupq_n_f32(0);
Expand Down Expand Up @@ -370,6 +415,86 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd
*result = _simsimd_sqrt_f32_neon(sum);
}

/**
 *  @brief  Horizontally sums the eight half-precision lanes of @p vec.
 *
 *  All additions are carried out with half-precision intrinsics; only the final
 *  scalar is widened to single precision before being returned.
 */
SIMSIMD_INTERNAL simsimd_f32_t _simsimd_reduce_f16x8_neon(float16x8_t vec) {
    // Fold the upper four lanes onto the lower four: 8 -> 4 partial sums.
    float16x4_t folded = vadd_f16(vget_low_f16(vec), vget_high_f16(vec));

    // Two rounds of pairwise addition: 4 -> 2 -> 1, after which lane 0 holds the total.
    float16x4_t pairs = vpadd_f16(folded, folded);
    float16x4_t total = vpadd_f16(pairs, pairs);

    // Widen to f32 and extract the first lane.
    return vgetq_lane_f32(vcvt_f32_f16(total), 0);
}

/**
 *  @brief  Builds a de-interleaved pair of 8-lane half-precision vectors from the first
 *          @p n complex entries of @p x, padding the remaining lanes with zeros.
 *
 *  Row 0 of the staging buffer receives the real parts and row 1 the imaginary parts.
 *  The staging buffer has 8 lanes per row, so @p n must not exceed 8.
 */
SIMSIMD_INTERNAL float16x8x2_t _simsimd_partial_load_f16x8x2_neon(simsimd_f16c_t const *x, simsimd_size_t n) {
    union {
        float16x8x2_t vecs;
        simsimd_f16_t scalars[2][8];
    } staging;

    simsimd_size_t lane = 0;
    // Copy the entries that exist.
    while (lane < n) {
        staging.scalars[0][lane] = x[lane].real;
        staging.scalars[1][lane] = x[lane].imag;
        ++lane;
    }
    // Zero the rest so padded lanes contribute nothing to later multiply-adds.
    while (lane < 8) {
        staging.scalars[0][lane] = 0;
        staging.scalars[1][lane] = 0;
        ++lane;
    }
    return staging.vecs;
}

/**
 *  @brief  Complex bilinear form over half-precision pairs, using NEON f16 arithmetic.
 *
 *  Accumulates `a[i] * c[i * n + j] * b[j]` (complex products) over all `i` and `j` in `[0, n)`.
 *  For every row the product with `b` is accumulated in half-precision vectors, eight columns
 *  per step; the `n % 8` leftover columns go through a zero-padded partial load into the same
 *  accumulators. The per-row result is then combined with `a[i]` in single precision.
 *
 *  @param a        First vector, read at indices `[0, n)`.
 *  @param b        Second vector, read at indices `[0, n)`.
 *  @param c        Matrix, read at index `i * n + j`.
 *  @param n        Number of complex entries in `a` and `b`.
 *  @param results  Output pair: `results[0]` gets the real part, `results[1]` the imaginary part.
 */
SIMSIMD_PUBLIC void simsimd_bilinear_f16c_neon(simsimd_f16c_t const *a, simsimd_f16c_t const *b,
                                               simsimd_f16c_t const *c, simsimd_size_t n, simsimd_distance_t *results) {
    simsimd_size_t const remainder = n % 8;
    simsimd_size_t const vectorized_end = n - remainder;
    simsimd_f32_t total_real = 0;
    simsimd_f32_t total_imag = 0;

    for (simsimd_size_t row = 0; row != n; ++row) {
        simsimd_f32c_t a_row = {simsimd_f16_to_f32(&a[row].real), simsimd_f16_to_f32(&a[row].imag)};
        simsimd_f16c_t const *c_row = c + row * n;
        float16x8_t acc_real = vdupq_n_f16(0);
        float16x8_t acc_imag = vdupq_n_f16(0);

        // Full 8-wide steps.
        for (simsimd_size_t col = 0; col < vectorized_end; col += 8) {
            // `vld2q_f16` splits interleaved pairs: `val[0]` holds real parts, `val[1]` imaginary parts.
            float16x8x2_t b_pairs = vld2q_f16((float16_t const *)(b + col));
            float16x8x2_t c_pairs = vld2q_f16((float16_t const *)(c_row + col));

            // Complex multiply-accumulate: real += cr*br - ci*bi, imag += cr*bi + ci*br.
            acc_real = vfmaq_f16(acc_real, c_pairs.val[0], b_pairs.val[0]);
            acc_real = vfmsq_f16(acc_real, c_pairs.val[1], b_pairs.val[1]);
            acc_imag = vfmaq_f16(acc_imag, c_pairs.val[0], b_pairs.val[1]);
            acc_imag = vfmaq_f16(acc_imag, c_pairs.val[1], b_pairs.val[0]);
        }

        // Leftover columns of this row, loaded with zero padding.
        if (remainder != 0) {
            float16x8x2_t b_pairs = _simsimd_partial_load_f16x8x2_neon(b + vectorized_end, remainder);
            float16x8x2_t c_pairs = _simsimd_partial_load_f16x8x2_neon(c_row + vectorized_end, remainder);

            acc_real = vfmaq_f16(acc_real, c_pairs.val[0], b_pairs.val[0]);
            acc_real = vfmsq_f16(acc_real, c_pairs.val[1], b_pairs.val[1]);
            acc_imag = vfmaq_f16(acc_imag, c_pairs.val[0], b_pairs.val[1]);
            acc_imag = vfmaq_f16(acc_imag, c_pairs.val[1], b_pairs.val[0]);
        }

        // Collapse the lanes and fold the row into the running complex total.
        simsimd_f32c_t row_dot;
        row_dot.real = _simsimd_reduce_f16x8_neon(acc_real);
        row_dot.imag = _simsimd_reduce_f16x8_neon(acc_imag);
        total_real += a_row.real * row_dot.real - a_row.imag * row_dot.imag;
        total_imag += a_row.real * row_dot.imag + a_row.imag * row_dot.real;
    }

    results[0] = total_real;
    results[1] = total_imag;
}

#pragma clang attribute pop
#pragma GCC pop_options
#endif // SIMSIMD_TARGET_NEON_F16
Expand Down Expand Up @@ -410,10 +535,6 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
*result = sum;
}

// Forward declaration only: no body is provided here.
SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_neon(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b,
simsimd_bf16c_t const *c, simsimd_size_t n,
simsimd_distance_t *results);

SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b,
simsimd_bf16_t const *c, simsimd_size_t n,
simsimd_distance_t *result) {
Expand Down Expand Up @@ -474,6 +595,76 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
*result = _simsimd_sqrt_f32_neon(sum);
}

/**
 *  @brief  Builds a de-interleaved pair of 4-lane vectors from the first @p n complex
 *          `bf16` entries of @p x, padding the remaining lanes with zeros.
 *
 *  Row 0 of the staging buffer receives the real parts and row 1 the imaginary parts.
 *  The vectors are returned as `int16x4x2_t`; callers in this file reinterpret each half
 *  with `vreinterpret_bf16_s16`.
 *
 *  @param x  Source of interleaved (real, imag) pairs.
 *  @param n  Number of valid entries to copy; must not exceed 4 (callers pass `n % 4`).
 *  @return   Two 4-lane vectors: `val[0]` real parts, `val[1]` imaginary parts.
 */
SIMSIMD_INTERNAL int16x4x2_t _simsimd_partial_load_bf16x4x2_neon(simsimd_bf16c_t const *x, simsimd_size_t n) {
    union {
        int16x4x2_t vec;
        simsimd_bf16_t scalars[2][4];
    } result;
    simsimd_size_t i = 0;
    for (; i < n; ++i) result.scalars[0][i] = x[i].real, result.scalars[1][i] = x[i].imag;
    // Zero BOTH rows of the padding lanes. The previous code zeroed row 1 twice and left
    // row 0 (the real parts) uninitialized, which leaked garbage into the multiply-adds.
    for (; i < 4; ++i) result.scalars[0][i] = 0, result.scalars[1][i] = 0;
    return result.vec;
}

/**
 *  @brief  Complex bilinear form over `bf16` pairs, widened to f32 and accumulated with NEON.
 *
 *  Accumulates `a[i] * c[i * n + j] * b[j]` (complex products) over all `i` and `j` in `[0, n)`.
 *  Every row is processed four columns per step; the `n % 4` leftover columns go through a
 *  zero-padded partial load into the same single-precision accumulators.
 *
 *  @param a        First vector, read at indices `[0, n)`.
 *  @param b        Second vector, read at indices `[0, n)`.
 *  @param c        Matrix, read at index `i * n + j`.
 *  @param n        Number of complex entries in `a` and `b`.
 *  @param results  Output pair: `results[0]` gets the real part, `results[1]` the imaginary part.
 */
SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_neon(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b,
                                                simsimd_bf16c_t const *c, simsimd_size_t n,
                                                simsimd_distance_t *results) {
    simsimd_size_t const remainder = n % 4;
    simsimd_size_t const vectorized_end = n - remainder;
    simsimd_f32_t total_real = 0;
    simsimd_f32_t total_imag = 0;

    for (simsimd_size_t row = 0; row != n; ++row) {
        simsimd_f32c_t a_row = {simsimd_bf16_to_f32(&a[row].real), simsimd_bf16_to_f32(&a[row].imag)};
        simsimd_bf16c_t const *c_row = c + row * n;
        // Native `bf16` dot-product arithmetic would be nicer and could process 16 entries at
        // once, but per the original author's note it needs FMLA extensions (Arm v8.3 and later)
        // and was painful to compile, so the inputs are widened to f32 instead.
        float32x4_t acc_real = vdupq_n_f32(0);
        float32x4_t acc_imag = vdupq_n_f32(0);

        // Full 4-wide steps.
        for (simsimd_size_t col = 0; col < vectorized_end; col += 4) {
            // MSVC doesn't recognize `vld2_bf16`, so the pairs are loaded as same-sized signed
            // integers and reinterpreted with `vreinterpret_bf16_s16` before widening.
            int16x4x2_t b_raw = vld2_s16((short const *)(b + col));
            int16x4x2_t c_raw = vld2_s16((short const *)(c_row + col));
            float32x4_t b_re = vcvt_f32_bf16(vreinterpret_bf16_s16(b_raw.val[0]));
            float32x4_t b_im = vcvt_f32_bf16(vreinterpret_bf16_s16(b_raw.val[1]));
            float32x4_t c_re = vcvt_f32_bf16(vreinterpret_bf16_s16(c_raw.val[0]));
            float32x4_t c_im = vcvt_f32_bf16(vreinterpret_bf16_s16(c_raw.val[1]));

            // Complex multiply-accumulate: real += cr*br - ci*bi, imag += cr*bi + ci*br.
            acc_real = vfmaq_f32(acc_real, c_re, b_re);
            acc_real = vfmsq_f32(acc_real, c_im, b_im);
            acc_imag = vfmaq_f32(acc_imag, c_re, b_im);
            acc_imag = vfmaq_f32(acc_imag, c_im, b_re);
        }

        // Leftover columns of this row, loaded with zero padding.
        if (remainder != 0) {
            int16x4x2_t b_raw = _simsimd_partial_load_bf16x4x2_neon(b + vectorized_end, remainder);
            int16x4x2_t c_raw = _simsimd_partial_load_bf16x4x2_neon(c_row + vectorized_end, remainder);
            float32x4_t b_re = vcvt_f32_bf16(vreinterpret_bf16_s16(b_raw.val[0]));
            float32x4_t b_im = vcvt_f32_bf16(vreinterpret_bf16_s16(b_raw.val[1]));
            float32x4_t c_re = vcvt_f32_bf16(vreinterpret_bf16_s16(c_raw.val[0]));
            float32x4_t c_im = vcvt_f32_bf16(vreinterpret_bf16_s16(c_raw.val[1]));

            acc_real = vfmaq_f32(acc_real, c_re, b_re);
            acc_real = vfmsq_f32(acc_real, c_im, b_im);
            acc_imag = vfmaq_f32(acc_imag, c_re, b_im);
            acc_imag = vfmaq_f32(acc_imag, c_im, b_re);
        }

        // Collapse the lanes and fold the row into the running complex total.
        simsimd_f32c_t row_dot;
        row_dot.real = vaddvq_f32(acc_real);
        row_dot.imag = vaddvq_f32(acc_imag);
        total_real += a_row.real * row_dot.real - a_row.imag * row_dot.imag;
        total_imag += a_row.real * row_dot.imag + a_row.imag * row_dot.real;
    }

    results[0] = total_real;
    results[1] = total_imag;
}

#pragma clang attribute pop
#pragma GCC pop_options
#endif // SIMSIMD_TARGET_NEON_BF16
Expand Down
8 changes: 4 additions & 4 deletions include/simsimd/dot.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16c_t const *a_pairs, simsi
// Unpack the input arrays into real and imaginary parts.
// MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed
// integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards.
int16x4x2_t a_vec = vld2_s16((short *)a_pairs);
int16x4x2_t b_vec = vld2_s16((short *)b_pairs);
int16x4x2_t a_vec = vld2_s16((short const *)a_pairs);
int16x4x2_t b_vec = vld2_s16((short const *)b_pairs);
float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0]));
float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1]));
float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0]));
Expand Down Expand Up @@ -599,8 +599,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16c_t const *a_pairs, sims
// Unpack the input arrays into real and imaginary parts.
// MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed
// integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards.
int16x4x2_t a_vec = vld2_s16((short *)a_pairs);
int16x4x2_t b_vec = vld2_s16((short *)b_pairs);
int16x4x2_t a_vec = vld2_s16((short const *)a_pairs);
int16x4x2_t b_vec = vld2_s16((short const *)b_pairs);
float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0]));
float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1]));
float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0]));
Expand Down
2 changes: 1 addition & 1 deletion include/simsimd/simsimd.h
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ SIMSIMD_INTERNAL void _simsimd_find_kernel_punned_f32c(simsimd_capability_t v, s
if (v & simsimd_cap_neon_k) switch (k) {
case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32c_neon, *c = simsimd_cap_neon_f32_k; return;
case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32c_neon, *c = simsimd_cap_neon_k; return;
default: break;
}
#endif
Expand Down
3 changes: 3 additions & 0 deletions scripts/bench.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ int main(int argc, char **argv) {

curved_<f32_k>("bilinear_f32_neon", simsimd_bilinear_f32_neon, simsimd_bilinear_f32_accurate);
curved_<f32_k>("mahalanobis_f32_neon", simsimd_mahalanobis_f32_neon, simsimd_mahalanobis_f32_accurate);
curved_<f32c_k>("bilinear_f32c_neon", simsimd_bilinear_f32c_neon, simsimd_bilinear_f32c_accurate);

sparse_<u16_k>("intersect_u16_neon", simsimd_intersect_u16_neon, simsimd_intersect_u16_accurate);
sparse_<u32_k>("intersect_u32_neon", simsimd_intersect_u32_neon, simsimd_intersect_u32_accurate);
Expand All @@ -834,6 +835,7 @@ int main(int argc, char **argv) {

curved_<f16_k>("bilinear_f16_neon", simsimd_bilinear_f16_neon, simsimd_bilinear_f16_accurate);
curved_<f16_k>("mahalanobis_f16_neon", simsimd_mahalanobis_f16_neon, simsimd_mahalanobis_f16_accurate);
curved_<f16c_k>("bilinear_f16c_neon", simsimd_bilinear_f16c_neon, simsimd_bilinear_f16c_accurate);

fma_<f16_k>("fma_f16_neon", simsimd_fma_f16_neon, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate);
fma_<f16_k>("wsum_f16_neon", simsimd_wsum_f16_neon, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate);
Expand All @@ -856,6 +858,7 @@ int main(int argc, char **argv) {

curved_<bf16_k>("bilinear_bf16_neon", simsimd_bilinear_bf16_neon, simsimd_bilinear_bf16_accurate);
curved_<bf16_k>("mahalanobis_bf16_neon", simsimd_mahalanobis_bf16_neon, simsimd_mahalanobis_bf16_accurate);
curved_<bf16c_k>("bilinear_bf16c_neon", simsimd_bilinear_bf16c_neon, simsimd_bilinear_bf16c_accurate);

fma_<bf16_k>("fma_bf16_neon", simsimd_fma_bf16_neon, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate);
fma_<bf16_k>("wsum_bf16_neon", simsimd_wsum_bf16_neon, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate);
Expand Down

0 comments on commit c6ff9ea

Please sign in to comment.