Skip to content

Commit

Permalink
Optimize KNN allocs
Browse files Browse the repository at this point in the history
  • Loading branch information
minkezhang committed Nov 19, 2021
1 parent e68174a commit 95a2f8d
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions internal/knn/knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func path(n *node.N, v vector.V) []*node.N {
// slice.
func KNN(n *node.N, p vector.V, k int) []*node.N {
q := pq.New(k)
knn(n, p, q)
knn(n, p, q, make([]float64, p.Dimension()))

// q.Pop() returns furthest distance first, so we need to reverse the
// queue for the KNN output.
Expand All @@ -50,15 +50,34 @@ func KNN(n *node.N, p vector.V, k int) []*node.N {
return ns
}

// sub uses the scratch space to calculate the difference between two vectors.
//
// Care should be taken that the scratch space is not used concurrently, as
// subsequent calls will modify the slice.
func sub(v vector.V, u vector.V, scratch []float64) vector.V {
for i := vector.D(0); i < v.Dimension(); i++ {
scratch[i] = v.X(i) - u.X(i)
}
return vector.V(scratch)
}

// knn recursively searches for the k-nearest neighbors of a node. The priority
// queue input q in effect tracks the searched space.
func knn(n *node.N, p vector.V, q *pq.Q) {
//
// The scratch input is used to reduce the amount of vector.Sub operations, as
// the returned vector is copy-constructed.
//
// This does not seem have a significant impact relative to the overall
// execution time, but we do see the overall allocs cut in half.
//
// TODO(minkezhang): Determine if this optimization is actually useful.
func knn(n *node.N, p vector.V, q *pq.Q, scratch []float64) {
if n == nil {
return
}

for _, n := range path(n, p) {
if d := vector.Magnitude(vector.Sub(p, n.P())); !q.Full() || d < q.Priority() {
if d := vector.Magnitude(sub(p, n.P(), scratch)); !q.Full() || d < q.Priority() {
q.Push(n, d)
}

Expand All @@ -84,7 +103,7 @@ func knn(n *node.N, p vector.V, q *pq.Q) {
c = n.L()
}

knn(c, p, q)
knn(c, p, q, scratch)
}
}
}

0 comments on commit 95a2f8d

Please sign in to comment.