Skip to content

Commit

Permalink
Normalize to speed up distance calc
Browse files Browse the repository at this point in the history
  • Loading branch information
kelindar committed Oct 27, 2024
1 parent 3586dc9 commit 0aef6d5
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 19 deletions.
23 changes: 22 additions & 1 deletion bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package search

import (
"math"
"sort"

"github.com/kelindar/search/internal/cosine/simd"
Expand Down Expand Up @@ -36,6 +37,8 @@ func NewIndex[T any]() *Index[T] {

// Add adds a new vector to the search index.
func (b *Index[T]) Add(vx Vector, item T) {
normalize(vx)

b.arr = append(b.arr, entry[T]{
Vector: vx,
Value: item,
Expand All @@ -48,10 +51,13 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
return nil
}

// Normalize and quantize the query vector
normalize(query)

var relevance float64
dst := make(minheap[T], 0, k)
for _, v := range b.arr {
simd.Cosine(&relevance, v.Vector, query)
simd.DotProduct(&relevance, query, v.Vector)
result := Result[T]{
entry: v,
Relevance: relevance,
Expand All @@ -73,6 +79,21 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
return dst
}

// Normalize normalizes the vector, resulting in a unit vector. This allows us
// to do a simple dot product to calculate the cosine similarity instead of
// the full cosine distance.
func normalize(v []float32) {
norm := float32(0)
for _, x := range v {
norm += x * x
}

norm = float32(math.Sqrt(float64(norm)))
for i := range v {
v[i] /= norm
}
}

// --------------------------------- Heap ---------------------------------

// minheap is a min-heap of top values, ordered by relevance.
Expand Down
2 changes: 1 addition & 1 deletion bruteforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkIndex/search-24 4029 298055 ns/op 272 B/op 3 allocs/op
BenchmarkIndex/search-24 5366 217116 ns/op 272 B/op 3 allocs/op
*/
func BenchmarkIndex(b *testing.B) {
data, err := loadDataset()
Expand Down
13 changes: 12 additions & 1 deletion internal/cosine/cosine_apple.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
float sum = 0.0f;

#pragma clang loop vectorize(enable) interleave(enable)
for (uint64_t i = 0; i < size; i++) {
sum += x[i] * y[i];
}

*result = (double)sum;
}
17 changes: 14 additions & 3 deletions internal/cosine/cosine_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
#pragma clang loop vectorize(enable) interleave_count(2)
for (uint64_t i = 0; i < size; i++) {
sum_xy += x[i] * y[i]; // Sum of x * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
}

// Calculate the final result
Expand All @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
float sum = 0.0f;

#pragma clang loop vectorize(enable) interleave(enable)
for (uint64_t i = 0; i < size; i++) {
sum += x[i] * y[i];
}

*result = (double)sum;
}
13 changes: 12 additions & 1 deletion internal/cosine/cosine_neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
float sum = 0.0f;

#pragma clang loop vectorize(enable) interleave(enable)
for (uint64_t i = 0; i < size; i++) {
sum += x[i] * y[i];
}

*result = (double)sum;
}
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
62 changes: 62 additions & 0 deletions internal/cosine/simd/cosine_apple.s
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,65 @@ BB0_11:
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret

TEXT ·f32_dot_product(SB), $0-32
MOVD x+0(FP), R0
MOVD y+8(FP), R1
MOVD result+16(FP), R2
MOVD size+24(FP), R3
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
WORD $0x910003fd // mov x29, sp
WORD $0xb40000c3 // cbz x3, LBB1_3
WORD $0xf100207f // cmp x3, #8
WORD $0x54000102 // b.hs LBB1_4
WORD $0xd2800008 // mov x8, #0
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0x14000018 // b LBB1_7

BB1_3:
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret

BB1_4:
WORD $0x927df068 // and x8, x3, #0xfffffffffffffff8
WORD $0x91004009 // add x9, x0, #16
WORD $0x9100402a // add x10, x1, #16
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
WORD $0xaa0803eb // mov x11, x8
WORD $0x6f00e401 // movi.2d v1, #0000000000000000

BB1_5:
WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16]
WORD $0xad7f9544 // ldp q4, q5, [x10, #-16]
WORD $0x4e22cc80 // fmla.4s v0, v4, v2
WORD $0x4e23cca1 // fmla.4s v1, v5, v3
WORD $0x91008129 // add x9, x9, #32
WORD $0x9100814a // add x10, x10, #32
WORD $0xf100216b // subs x11, x11, #8
WORD $0x54ffff21 // b.ne LBB1_5
WORD $0x4e20d420 // fadd.4s v0, v1, v0
WORD $0x6e20d400 // faddp.4s v0, v0, v0
WORD $0x7e30d800 // faddp.2s s0, v0
WORD $0xeb03011f // cmp x8, x3
WORD $0x54000140 // b.eq LBB1_9

BB1_7:
WORD $0xcb080069 // sub x9, x3, x8
WORD $0xd37ef50a // lsl x10, x8, #2
WORD $0x8b0a0028 // add x8, x1, x10
WORD $0x8b0a000a // add x10, x0, x10

BB1_8:
WORD $0xbc404541 // ldr s1, [x10], #4
WORD $0xbc404502 // ldr s2, [x8], #4
WORD $0x1f010040 // fmadd s0, s2, s1, s0
WORD $0xf1000529 // subs x9, x9, #1
WORD $0x54ffff81 // b.ne LBB1_8

BB1_9:
WORD $0x1e22c000 // fcvt d0, s0
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_avx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
71 changes: 71 additions & 0 deletions internal/cosine/simd/cosine_avx.s
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,74 @@ LBB0_9:
BYTE $0x5d // pop rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // ret

TEXT ·f32_dot_product(SB), $0-32
MOVQ x+0(FP), DI
MOVQ y+8(FP), SI
MOVQ result+16(FP), DX
MOVQ size+24(FP), CX
BYTE $0x55 // push rbp
WORD $0x8948; BYTE $0xe5 // mov rbp, rsp
LONG $0xf8e48348 // and rsp, -8
WORD $0x8548; BYTE $0xc9 // test rcx, rcx
JE LBB1_1
LONG $0x20f98348 // cmp rcx, 32
JAE LBB1_5
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
WORD $0x3145; BYTE $0xc0 // xor r8d, r8d
JMP LBB1_4

LBB1_1:
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
WORD $0x8948; BYTE $0xec // mov rsp, rbp
BYTE $0x5d // pop rbp
BYTE $0xc3 // ret

LBB1_5:
WORD $0x8949; BYTE $0xc8 // mov r8, rcx
LONG $0xe0e08349 // and r8, -32
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
WORD $0xc031 // xor eax, eax
LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1
LONG $0xd257e8c5 // vxorps xmm2, xmm2, xmm2
LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3

LBB1_6:
LONG $0x2410fcc5; BYTE $0x86 // vmovups ymm4, ymmword ptr [rsi + 4*rax]
LONG $0x6c10fcc5; WORD $0x2086 // vmovups ymm5, ymmword ptr [rsi + 4*rax + 32]
LONG $0x7410fcc5; WORD $0x4086 // vmovups ymm6, ymmword ptr [rsi + 4*rax + 64]
LONG $0x7c10fcc5; WORD $0x6086 // vmovups ymm7, ymmword ptr [rsi + 4*rax + 96]
LONG $0xb85de2c4; WORD $0x8704 // vfmadd231ps ymm0, ymm4, ymmword ptr [rdi + 4*rax]
LONG $0xb855e2c4; WORD $0x874c; BYTE $0x20 // vfmadd231ps ymm1, ymm5, ymmword ptr [rdi + 4*rax + 32]
LONG $0xb84de2c4; WORD $0x8754; BYTE $0x40 // vfmadd231ps ymm2, ymm6, ymmword ptr [rdi + 4*rax + 64]
LONG $0xb845e2c4; WORD $0x875c; BYTE $0x60 // vfmadd231ps ymm3, ymm7, ymmword ptr [rdi + 4*rax + 96]
LONG $0x20c08348 // add rax, 32
WORD $0x3949; BYTE $0xc0 // cmp r8, rax
JNE LBB1_6
LONG $0xc058f4c5 // vaddps ymm0, ymm1, ymm0
LONG $0xc058ecc5 // vaddps ymm0, ymm2, ymm0
LONG $0xc058e4c5 // vaddps ymm0, ymm3, ymm0
LONG $0x197de3c4; WORD $0x01c1 // vextractf128 xmm1, ymm0, 1
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
LONG $0x0579e3c4; WORD $0x01c8 // vpermilpd xmm1, xmm0, 1
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
LONG $0xc816fac5 // vmovshdup xmm1, xmm0
LONG $0xc158fac5 // vaddss xmm0, xmm0, xmm1
WORD $0x3949; BYTE $0xc8 // cmp r8, rcx
JE LBB1_8

LBB1_4:
LONG $0x107aa1c4; WORD $0x860c // vmovss xmm1, dword ptr [rsi + 4*r8]
LONG $0xb971a2c4; WORD $0x8704 // vfmadd231ss xmm0, xmm1, dword ptr [rdi + 4*r8]
WORD $0xff49; BYTE $0xc0 // inc r8
WORD $0x394c; BYTE $0xc1 // cmp rcx, r8
JNE LBB1_4

LBB1_8:
LONG $0xc05afac5 // vcvtss2sd xmm0, xmm0, xmm0
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
WORD $0x8948; BYTE $0xec // mov rsp, rbp
BYTE $0x5d // pop rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // ret
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_neon.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
62 changes: 62 additions & 0 deletions internal/cosine/simd/cosine_neon.s
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,65 @@ LBB0_11:
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret

TEXT ·f32_dot_product(SB), $0-32
MOVD x+0(FP), R0
MOVD y+8(FP), R1
MOVD result+16(FP), R2
MOVD size+24(FP), R3
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]!
WORD $0x910003fd // mov x29, sp
WORD $0xb40000c3 // cbz x3, .LBB1_3
WORD $0xf100207f // cmp x3, #8
WORD $0x54000102 // b.hs .LBB1_4
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xaa1f03e8 // mov x8, xzr
WORD $0x14000018 // b .LBB1_7

LBB1_3:
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret

LBB1_4:
WORD $0x927df068 // and x8, x3, #0xfffffffffffffff8
WORD $0x91004009 // add x9, x0, #16
WORD $0x6f00e400 // movi v0.2d, #0000000000000000
WORD $0x9100402a // add x10, x1, #16
WORD $0x6f00e401 // movi v1.2d, #0000000000000000
WORD $0xaa0803eb // mov x11, x8

LBB1_5:
WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16]
WORD $0x91008129 // add x9, x9, #32
WORD $0xf100216b // subs x11, x11, #8
WORD $0xad7f9544 // ldp q4, q5, [x10, #-16]
WORD $0x9100814a // add x10, x10, #32
WORD $0x4e22cc80 // fmla v0.4s, v4.4s, v2.4s
WORD $0x4e23cca1 // fmla v1.4s, v5.4s, v3.4s
WORD $0x54ffff21 // b.ne .LBB1_5
WORD $0x4e20d420 // fadd v0.4s, v1.4s, v0.4s
WORD $0xeb03011f // cmp x8, x3
WORD $0x6e20d400 // faddp v0.4s, v0.4s, v0.4s
WORD $0x7e30d800 // faddp s0, v0.2s
WORD $0x54000140 // b.eq .LBB1_9

LBB1_7:
WORD $0xd37ef50a // lsl x10, x8, #2
WORD $0xcb080069 // sub x9, x3, x8
WORD $0x8b0a0028 // add x8, x1, x10
WORD $0x8b0a000a // add x10, x0, x10

LBB1_8:
WORD $0xbc404541 // ldr s1, [x10], #4
WORD $0xbc404502 // ldr s2, [x8], #4
WORD $0xf1000529 // subs x9, x9, #1
WORD $0x1f010040 // fmadd s0, s2, s1, s0
WORD $0x54ffff81 // b.ne .LBB1_8

LBB1_9:
WORD $0x1e22c000 // fcvt d0, s0
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret
5 changes: 4 additions & 1 deletion internal/cosine/simd/cosine_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ package simd

import "unsafe"

// stub
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64) {
panic("not implemented")
}

func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64) {
panic("not implemented")
}
Loading

0 comments on commit 0aef6d5

Please sign in to comment.