From ab2905ef055ab20205cd8134f9810e4d26380d21 Mon Sep 17 00:00:00 2001 From: Paul Mach Date: Thu, 5 Jan 2023 10:46:32 -0800 Subject: [PATCH] quadtree: fix bad sort due to pointer allocation issue --- quadtree/maxheap.go | 33 ++++++++++++--------------------- quadtree/maxheap_test.go | 6 ++++-- quadtree/quadtree.go | 4 +++- quadtree/quadtree_test.go | 21 +++++++++++++++++++++ 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/quadtree/maxheap.go b/quadtree/maxheap.go index 5473a3e..56e0d08 100644 --- a/quadtree/maxheap.go +++ b/quadtree/maxheap.go @@ -6,7 +6,7 @@ import "github.com/paulmach/orb" // the furthest point from the query point in the list, hence maxHeap. // When we find a point closer than the furthest away, we remove // furthest and add the new point to the heap. -type maxHeap []*heapItem +type maxHeap []heapItem type heapItem struct { point orb.Pointer @@ -14,19 +14,10 @@ type heapItem struct { } func (h *maxHeap) Push(point orb.Pointer, distance float64) { - // Common usage is Push followed by a Pop if we have > k points. - // We're reusing the k+1 heapItem object to reduce memory allocations. - // First we manaully lengthen the slice, - // then we see if the last item has been allocated already. - prevLen := len(*h) *h = (*h)[:prevLen+1] - if (*h)[prevLen] == nil { - (*h)[prevLen] = &heapItem{point: point, distance: distance} - } else { - (*h)[prevLen].point = point - (*h)[prevLen].distance = distance - } + (*h)[prevLen].point = point + (*h)[prevLen].distance = distance i := len(*h) - 1 for i > 0 { @@ -53,21 +44,20 @@ func (h *maxHeap) Push(point orb.Pointer, distance float64) { // Pop returns the "greatest" item in the list. // The returned item should not be saved across push/pop operations. -func (h *maxHeap) Pop() *heapItem { - removed := (*h)[0] +func (h *maxHeap) Pop() { lastItem := (*h)[len(*h)-1] (*h) = (*h)[:len(*h)-1] mh := (*h) if len(mh) == 0 { - return removed + return } // move the last item to the top and reset the heap - mh[0] = lastItem + mh[0].point = lastItem.point + mh[0].distance = lastItem.distance i := 0 - current := mh[i] for { right := (i + 1) << 1 left := right - 1 @@ -92,11 +82,12 @@ func (h *maxHeap) Pop() *heapItem { } // swap the nodes - mh[i] = child - mh[childIndex] = current + mh[i].point = child.point + mh[i].distance = child.distance + + mh[childIndex].point = lastItem.point + mh[childIndex].distance = lastItem.distance i = childIndex } - - return removed } diff --git a/quadtree/maxheap_test.go b/quadtree/maxheap_test.go index 885438a..d48545c 100644 --- a/quadtree/maxheap_test.go +++ b/quadtree/maxheap_test.go @@ -14,9 +14,11 @@ func TestMaxHeap(t *testing.T) { h.Push(nil, r.Float64()) } - current := h.Pop().distance + current := h[0].distance + h.Pop() for len(h) > 0 { - next := h.Pop().distance + next := h[0].distance + h.Pop() if next > current { t.Errorf("incorrect") } diff --git a/quadtree/quadtree.go b/quadtree/quadtree.go index d50a66e..847a0af 100644 --- a/quadtree/quadtree.go +++ b/quadtree/quadtree.go @@ -113,6 +113,7 @@ func (q *Quadtree) add(n *node, p orb.Pointer, point orb.Point, left, right, bot // Remove will remove the pointer from the quadtree. By default it'll match // using the points, but a FilterFunc can be provided for a more specific test // if there are elements with the same point value in the tree. For example: +// // func(pointer orb.Pointer) { // return pointer.(*MyType).ID == lookingFor.ID // } @@ -273,7 +274,8 @@ func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f Fil } for i := len(v.maxHeap) - 1; i >= 0; i-- { - buf[i] = v.maxHeap.Pop().point + buf[i] = v.maxHeap[0].point + v.maxHeap.Pop() } return buf diff --git a/quadtree/quadtree_test.go b/quadtree/quadtree_test.go index a4dca3d..e443615 100644 --- a/quadtree/quadtree_test.go +++ b/quadtree/quadtree_test.go @@ -483,6 +483,27 @@ func TestQuadtreeKNearest_sorted(t *testing.T) { } } +func TestQuadtreeKNearest_sorted2(t *testing.T) { + q := New(orb.Bound{Max: orb.Point{8, 8}}) + q.Add(orb.Point{0, 0}) + q.Add(orb.Point{1, 1}) + q.Add(orb.Point{2, 2}) + q.Add(orb.Point{3, 3}) + q.Add(orb.Point{4, 4}) + q.Add(orb.Point{5, 5}) + q.Add(orb.Point{6, 6}) + q.Add(orb.Point{7, 7}) + + nearest := q.KNearest(nil, orb.Point{5.25, 5.25}, 3) + + expected := []orb.Point{{5, 5}, {6, 6}, {4, 4}} + for i, p := range expected { + if n := nearest[i].Point(); !n.Equal(p) { + t.Errorf("incorrect point %d: %v", i, n) + } + } +} + func TestQuadtreeKNearest_DistanceLimit(t *testing.T) { type dataPointer struct { orb.Pointer