Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize to speed up distance calculation #1

Merged
merged 1 commit into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package search

import (
"math"
"sort"

"github.com/kelindar/search/internal/cosine/simd"
Expand Down Expand Up @@ -36,6 +37,8 @@ func NewIndex[T any]() *Index[T] {

// Add adds a new vector to the search index.
func (b *Index[T]) Add(vx Vector, item T) {
normalize(vx)

b.arr = append(b.arr, entry[T]{
Vector: vx,
Value: item,
Expand All @@ -48,10 +51,13 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
return nil
}

// Normalize the query vector so that the dot product below yields the cosine similarity
normalize(query)

var relevance float64
dst := make(minheap[T], 0, k)
for _, v := range b.arr {
simd.Cosine(&relevance, v.Vector, query)
simd.DotProduct(&relevance, query, v.Vector)
result := Result[T]{
entry: v,
Relevance: relevance,
Expand All @@ -73,6 +79,21 @@ func (b *Index[T]) Search(query Vector, k int) []Result[T] {
return dst
}

// normalize scales v in place so that it becomes a unit vector (L2 norm of 1).
// Once both operands are unit vectors, a plain dot product equals their cosine
// similarity, which avoids recomputing the norms for every comparison.
//
// A vector whose squared norm is zero (nil, empty, or all zeros) is left
// unchanged: dividing by a zero norm would turn every component into NaN and
// make any score computed from the vector NaN as well.
func normalize(v []float32) {
	// Sum of squares, accumulated in float32 like the vector itself.
	norm := float32(0)
	for _, x := range v {
		norm += x * x
	}

	// Nothing to scale; 0/0 would otherwise yield NaN components.
	if norm == 0 {
		return
	}

	norm = float32(math.Sqrt(float64(norm)))
	for i := range v {
		v[i] /= norm
	}
}

// --------------------------------- Heap ---------------------------------

// minheap is a min-heap of top values, ordered by relevance.
Expand Down
2 changes: 1 addition & 1 deletion bruteforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkIndex/search-24 4029 298055 ns/op 272 B/op 3 allocs/op
BenchmarkIndex/search-24 5366 217116 ns/op 272 B/op 3 allocs/op
*/
func BenchmarkIndex(b *testing.B) {
data, err := loadDataset()
Expand Down
13 changes: 12 additions & 1 deletion internal/cosine/cosine_apple.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

// f32_dot_product writes the dot product of the first `size` elements of the
// float arrays x and y to *result, widened to double. With size == 0 the
// stored value is 0.0.
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
    // Single-precision running total; only widened once, when it is stored.
    float acc = 0.0f;

    // The multiply and the add are kept in one expression so the compiler is
    // free to vectorize the loop as written.
    #pragma clang loop vectorize(enable) interleave(enable)
    for (uint64_t idx = 0; idx < size; ++idx) {
        acc += x[idx] * y[idx];
    }

    *result = (double)acc;
}
17 changes: 14 additions & 3 deletions internal/cosine/cosine_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u
#pragma clang loop vectorize(enable) interleave_count(2)
for (uint64_t i = 0; i < size; i++) {
sum_xy += x[i] * y[i]; // Sum of x * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
}

// Calculate the final result
Expand All @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

// f32_dot_product writes the dot product of the first `size` elements of the
// float arrays x and y to *result, widened to double. With size == 0 the
// stored value is 0.0.
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
    // Single-precision running total; only widened once, when it is stored.
    float acc = 0.0f;

    // The multiply and the add are kept in one expression so the compiler is
    // free to vectorize the loop as written.
    #pragma clang loop vectorize(enable) interleave(enable)
    for (uint64_t idx = 0; idx < size; ++idx) {
        acc += x[idx] * y[idx];
    }

    *result = (double)acc;
}
13 changes: 12 additions & 1 deletion internal/cosine/cosine_neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ void f32_cosine_distance(const float *x, const float *y, double *result, const u

double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
}

// f32_dot_product writes the dot product of the first `size` elements of the
// float arrays x and y to *result, widened to double. With size == 0 the
// stored value is 0.0.
void f32_dot_product(const float *x, const float *y, double *result, const uint64_t size) {
    // Single-precision running total; only widened once, when it is stored.
    float acc = 0.0f;

    // The multiply and the add are kept in one expression so the compiler is
    // free to vectorize the loop as written.
    #pragma clang loop vectorize(enable) interleave(enable)
    for (uint64_t idx = 0; idx < size; ++idx) {
        acc += x[idx] * y[idx];
    }

    *result = (double)acc;
}
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

// f32_cosine_distance is declared without a Go body, so its implementation
// must be supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the score; confirm against the
// implementation.
//
// NOTE(review): the directive below joins two names with a comma. The Go
// compiler documents //go:noescape and //go:nosplit as separate directives;
// confirm that this combined spelling is actually honored.
//
//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

// f32_dot_product is declared without a Go body, so its implementation must be
// supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the dot product; confirm against
// the implementation.
//
// NOTE(review): see the directive spelling; //go:noescape and //go:nosplit are
// documented as separate directives, so confirm this combined form is honored.
//
//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
62 changes: 62 additions & 0 deletions internal/cosine/simd/cosine_apple.s
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,65 @@ BB0_11:
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret

// f32_dot_product(x, y, result, size)
//
// Accumulates x[i]*y[i] over `size` 4-byte floats in single precision and
// stores the total, converted to double, at `result`.
// Frame is $0-32: no locals, four 8-byte arguments.
//
// NOTE: every branch below is a pre-encoded WORD carrying a fixed relative
// offset. The BBn_m labels therefore appear to be markers only; inserting or
// removing instruction words would silently break the branch targets.
TEXT ·f32_dot_product(SB), $0-32
// Load the Go arguments into x0..x3, as the encoded instructions expect.
MOVD x+0(FP), R0
MOVD y+8(FP), R1
MOVD result+16(FP), R2
MOVD size+24(FP), R3
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
WORD $0x910003fd // mov x29, sp
// size == 0: go store 0.0 and return.
WORD $0xb40000c3 // cbz x3, LBB1_3
// size >= 8: take the vector loop; otherwise fall through to the scalar loop
// with index 0 and a zeroed accumulator.
WORD $0xf100207f // cmp x3, #8
WORD $0x54000102 // b.hs LBB1_4
WORD $0xd2800008 // mov x8, #0
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0x14000018 // b LBB1_7

// Empty input: *result = 0.0.
BB1_3:
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret

// Vector setup: x8 = size rounded down to a multiple of 8, x9/x10 point 16
// bytes into x/y, v0/v1 are two zeroed 4-lane accumulators, x11 counts down.
BB1_4:
WORD $0x927df068 // and x8, x3, #0xfffffffffffffff8
WORD $0x91004009 // add x9, x0, #16
WORD $0x9100402a // add x10, x1, #16
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
WORD $0xaa0803eb // mov x11, x8
WORD $0x6f00e401 // movi.2d v1, #0000000000000000

// Vector loop: 8 floats from each input per iteration, fused multiply-add
// into the two accumulators, pointers advanced by 32 bytes.
BB1_5:
WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16]
WORD $0xad7f9544 // ldp q4, q5, [x10, #-16]
WORD $0x4e22cc80 // fmla.4s v0, v4, v2
WORD $0x4e23cca1 // fmla.4s v1, v5, v3
WORD $0x91008129 // add x9, x9, #32
WORD $0x9100814a // add x10, x10, #32
WORD $0xf100216b // subs x11, x11, #8
WORD $0x54ffff21 // b.ne LBB1_5
// Combine both accumulators and reduce the lanes to a single float in s0.
WORD $0x4e20d420 // fadd.4s v0, v1, v0
WORD $0x6e20d400 // faddp.4s v0, v0, v0
WORD $0x7e30d800 // faddp.2s s0, v0
// No remainder when the rounded-down count equals size.
WORD $0xeb03011f // cmp x8, x3
WORD $0x54000140 // b.eq LBB1_9

// Scalar setup: x9 = elements left, x8/x10 = y/x advanced by x8*4 bytes.
BB1_7:
WORD $0xcb080069 // sub x9, x3, x8
WORD $0xd37ef50a // lsl x10, x8, #2
WORD $0x8b0a0028 // add x8, x1, x10
WORD $0x8b0a000a // add x10, x0, x10

// Scalar loop: one float from each input per iteration, s0 += s2 * s1.
BB1_8:
WORD $0xbc404541 // ldr s1, [x10], #4
WORD $0xbc404502 // ldr s2, [x8], #4
WORD $0x1f010040 // fmadd s0, s2, s1, s0
WORD $0xf1000529 // subs x9, x9, #1
WORD $0x54ffff81 // b.ne LBB1_8

// Widen the float total to double, store it at result, restore and return.
BB1_9:
WORD $0x1e22c000 // fcvt d0, s0
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_avx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

// f32_cosine_distance is declared without a Go body, so its implementation
// must be supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the score; confirm against the
// implementation.
//
// NOTE(review): the directive below joins two names with a comma. The Go
// compiler documents //go:noescape and //go:nosplit as separate directives;
// confirm that this combined spelling is actually honored.
//
//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

// f32_dot_product is declared without a Go body, so its implementation must be
// supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the dot product; confirm against
// the implementation.
//
// NOTE(review): see the directive spelling; //go:noescape and //go:nosplit are
// documented as separate directives, so confirm this combined form is honored.
//
//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
71 changes: 71 additions & 0 deletions internal/cosine/simd/cosine_avx.s
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,74 @@ LBB0_9:
BYTE $0x5d // pop rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // ret

// f32_dot_product(x, y, result, size)
//
// Accumulates x[i]*y[i] over `size` 4-byte floats in single precision and
// stores the total, converted to double, at `result`.
// Frame is $0-32: no locals, four 8-byte arguments.
//
// Instructions are emitted as raw bytes (the trailing comment shows the
// disassembly); only the jumps use assembler mnemonics and the labels below.
TEXT ·f32_dot_product(SB), $0-32
// Load the Go arguments into rdi, rsi, rdx, rcx, as the encoded bytes expect.
MOVQ x+0(FP), DI
MOVQ y+8(FP), SI
MOVQ result+16(FP), DX
MOVQ size+24(FP), CX
BYTE $0x55 // push rbp
WORD $0x8948; BYTE $0xe5 // mov rbp, rsp
LONG $0xf8e48348 // and rsp, -8
// size == 0: go store 0.0 and return.
WORD $0x8548; BYTE $0xc9 // test rcx, rcx
JE LBB1_1
// size >= 32: take the vector loop; otherwise fall through to the scalar loop
// with a zeroed accumulator and index r8 = 0.
LONG $0x20f98348 // cmp rcx, 32
JAE LBB1_5
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
WORD $0x3145; BYTE $0xc0 // xor r8d, r8d
JMP LBB1_4

// Empty input: *result = 0.0.
LBB1_1:
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
WORD $0x8948; BYTE $0xec // mov rsp, rbp
BYTE $0x5d // pop rbp
BYTE $0xc3 // ret

// Vector setup: r8 = size rounded down to a multiple of 32, rax = element
// index, registers 0..3 are cleared to serve as four accumulators.
LBB1_5:
WORD $0x8949; BYTE $0xc8 // mov r8, rcx
LONG $0xe0e08349 // and r8, -32
LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0
WORD $0xc031 // xor eax, eax
LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1
LONG $0xd257e8c5 // vxorps xmm2, xmm2, xmm2
LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3

// Vector loop: 32 floats per iteration. Four 8-float loads from y (rsi) are
// each fused-multiply-added with the matching floats of x (rdi).
LBB1_6:
LONG $0x2410fcc5; BYTE $0x86 // vmovups ymm4, ymmword ptr [rsi + 4*rax]
LONG $0x6c10fcc5; WORD $0x2086 // vmovups ymm5, ymmword ptr [rsi + 4*rax + 32]
LONG $0x7410fcc5; WORD $0x4086 // vmovups ymm6, ymmword ptr [rsi + 4*rax + 64]
LONG $0x7c10fcc5; WORD $0x6086 // vmovups ymm7, ymmword ptr [rsi + 4*rax + 96]
LONG $0xb85de2c4; WORD $0x8704 // vfmadd231ps ymm0, ymm4, ymmword ptr [rdi + 4*rax]
LONG $0xb855e2c4; WORD $0x874c; BYTE $0x20 // vfmadd231ps ymm1, ymm5, ymmword ptr [rdi + 4*rax + 32]
LONG $0xb84de2c4; WORD $0x8754; BYTE $0x40 // vfmadd231ps ymm2, ymm6, ymmword ptr [rdi + 4*rax + 64]
LONG $0xb845e2c4; WORD $0x875c; BYTE $0x60 // vfmadd231ps ymm3, ymm7, ymmword ptr [rdi + 4*rax + 96]
LONG $0x20c08348 // add rax, 32
WORD $0x3949; BYTE $0xc0 // cmp r8, rax
JNE LBB1_6
// Fold the four accumulators together, then reduce the lanes to one float
// in the low element of xmm0.
LONG $0xc058f4c5 // vaddps ymm0, ymm1, ymm0
LONG $0xc058ecc5 // vaddps ymm0, ymm2, ymm0
LONG $0xc058e4c5 // vaddps ymm0, ymm3, ymm0
LONG $0x197de3c4; WORD $0x01c1 // vextractf128 xmm1, ymm0, 1
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
LONG $0x0579e3c4; WORD $0x01c8 // vpermilpd xmm1, xmm0, 1
LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1
LONG $0xc816fac5 // vmovshdup xmm1, xmm0
LONG $0xc158fac5 // vaddss xmm0, xmm0, xmm1
// No remainder when the rounded-down count equals size.
WORD $0x3949; BYTE $0xc8 // cmp r8, rcx
JE LBB1_8

// Scalar loop: one float per iteration starting at index r8,
// xmm0 += y[r8] * x[r8], until r8 reaches size.
LBB1_4:
LONG $0x107aa1c4; WORD $0x860c // vmovss xmm1, dword ptr [rsi + 4*r8]
LONG $0xb971a2c4; WORD $0x8704 // vfmadd231ss xmm0, xmm1, dword ptr [rdi + 4*r8]
WORD $0xff49; BYTE $0xc0 // inc r8
WORD $0x394c; BYTE $0xc1 // cmp rcx, r8
JNE LBB1_4

// Widen the float total to double, store it at result, restore and return.
LBB1_8:
LONG $0xc05afac5 // vcvtss2sd xmm0, xmm0, xmm0
LONG $0x0211fbc5 // vmovsd qword ptr [rdx], xmm0
WORD $0x8948; BYTE $0xec // mov rsp, rbp
BYTE $0x5d // pop rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // ret
3 changes: 3 additions & 0 deletions internal/cosine/simd/cosine_neon.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ import "unsafe"

// f32_cosine_distance is declared without a Go body, so its implementation
// must be supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the score; confirm against the
// implementation.
//
// NOTE(review): the directive below joins two names with a comma. The Go
// compiler documents //go:noescape and //go:nosplit as separate directives;
// confirm that this combined spelling is actually honored.
//
//go:noescape,nosplit
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)

// f32_dot_product is declared without a Go body, so its implementation must be
// supplied outside Go (presumably by the assembly file built for this
// platform). Presumably x and y point to float32 data of `size` elements and
// result points to a float64 that receives the dot product; confirm against
// the implementation.
//
// NOTE(review): see the directive spelling; //go:noescape and //go:nosplit are
// documented as separate directives, so confirm this combined form is honored.
//
//go:noescape,nosplit
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64)
62 changes: 62 additions & 0 deletions internal/cosine/simd/cosine_neon.s
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,65 @@ LBB0_11:
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret

// f32_dot_product(x, y, result, size)
//
// Accumulates x[i]*y[i] over `size` 4-byte floats in single precision and
// stores the total, converted to double, at `result`.
// Frame is $0-32: no locals, four 8-byte arguments.
//
// NOTE: every branch below is a pre-encoded WORD carrying a fixed relative
// offset. The LBBn_m labels therefore appear to be markers only; inserting or
// removing instruction words would silently break the branch targets.
TEXT ·f32_dot_product(SB), $0-32
// Load the Go arguments into x0..x3, as the encoded instructions expect.
MOVD x+0(FP), R0
MOVD y+8(FP), R1
MOVD result+16(FP), R2
MOVD size+24(FP), R3
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]!
WORD $0x910003fd // mov x29, sp
// size == 0: go store 0.0 and return.
WORD $0xb40000c3 // cbz x3, .LBB1_3
// size >= 8: take the vector loop; otherwise fall through to the scalar loop
// with a zeroed accumulator and index 0.
WORD $0xf100207f // cmp x3, #8
WORD $0x54000102 // b.hs .LBB1_4
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xaa1f03e8 // mov x8, xzr
WORD $0x14000018 // b .LBB1_7

// Empty input: *result = 0.0.
LBB1_3:
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret

// Vector setup: x8 = size rounded down to a multiple of 8, x9/x10 point 16
// bytes into x/y, v0/v1 are two zeroed 4-lane accumulators, x11 counts down.
LBB1_4:
WORD $0x927df068 // and x8, x3, #0xfffffffffffffff8
WORD $0x91004009 // add x9, x0, #16
WORD $0x6f00e400 // movi v0.2d, #0000000000000000
WORD $0x9100402a // add x10, x1, #16
WORD $0x6f00e401 // movi v1.2d, #0000000000000000
WORD $0xaa0803eb // mov x11, x8

// Vector loop: 8 floats from each input per iteration, fused multiply-add
// into the two accumulators, pointers advanced by 32 bytes.
LBB1_5:
WORD $0xad7f8d22 // ldp q2, q3, [x9, #-16]
WORD $0x91008129 // add x9, x9, #32
WORD $0xf100216b // subs x11, x11, #8
WORD $0xad7f9544 // ldp q4, q5, [x10, #-16]
WORD $0x9100814a // add x10, x10, #32
WORD $0x4e22cc80 // fmla v0.4s, v4.4s, v2.4s
WORD $0x4e23cca1 // fmla v1.4s, v5.4s, v3.4s
WORD $0x54ffff21 // b.ne .LBB1_5
// Combine both accumulators and reduce the lanes to a single float in s0;
// skip the scalar loop when the rounded-down count equals size.
WORD $0x4e20d420 // fadd v0.4s, v1.4s, v0.4s
WORD $0xeb03011f // cmp x8, x3
WORD $0x6e20d400 // faddp v0.4s, v0.4s, v0.4s
WORD $0x7e30d800 // faddp s0, v0.2s
WORD $0x54000140 // b.eq .LBB1_9

// Scalar setup: x9 = elements left, x8/x10 = y/x advanced by x8*4 bytes.
LBB1_7:
WORD $0xd37ef50a // lsl x10, x8, #2
WORD $0xcb080069 // sub x9, x3, x8
WORD $0x8b0a0028 // add x8, x1, x10
WORD $0x8b0a000a // add x10, x0, x10

// Scalar loop: one float from each input per iteration, s0 += s2 * s1.
LBB1_8:
WORD $0xbc404541 // ldr s1, [x10], #4
WORD $0xbc404502 // ldr s2, [x8], #4
WORD $0xf1000529 // subs x9, x9, #1
WORD $0x1f010040 // fmadd s0, s2, s1, s0
WORD $0x54ffff81 // b.ne .LBB1_8

// Widen the float total to double, store it at result, restore and return.
LBB1_9:
WORD $0x1e22c000 // fcvt d0, s0
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16
WORD $0xd65f03c0 // ret
5 changes: 4 additions & 1 deletion internal/cosine/simd/cosine_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ package simd

import "unsafe"

// f32_cosine_distance is a placeholder with the same signature as the native
// kernel. It ignores all of its arguments and always panics with
// "not implemented".
// NOTE(review): presumably this file is only compiled where no assembly
// implementation exists; the build constraints are not visible here.
func f32_cosine_distance(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64) {
panic("not implemented")
}

// f32_dot_product is a placeholder with the same signature as the native
// kernel. It ignores all of its arguments and always panics with
// "not implemented".
// NOTE(review): presumably this file is only compiled where no assembly
// implementation exists; the build constraints are not visible here.
func f32_dot_product(x unsafe.Pointer, y unsafe.Pointer, result unsafe.Pointer, size uint64) {
panic("not implemented")
}
Loading
Loading