Skip to content

Commit

Permalink
heapq: add support for intrusion
Browse files Browse the repository at this point in the history
Add an Update method to the queue that allows the caller to set a callback that
will be invoked whenever the offset of an item in the queue is updated. This
allows the caller to keep track of the offsets of items in the queue as they
are moved by the addition and removal of items.
  • Loading branch information
creachadair committed Feb 9, 2024
1 parent d8e0023 commit 21d3bb4
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 7 deletions.
48 changes: 41 additions & 7 deletions heapq/heapq.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ package heapq
type Queue[T any] struct {
data []T
cmp func(a, b T) int
move func(T, int)
}

// nmove is a no-op move function used by default in a queue on which no update
// function has been set.
func nmove[T any](T, int) {}

// New constructs an empty Queue with the given comparison function, where
// cmp(a, b) must be <0 if a < b, =0 if a == b, and >0 if a > b.
func New[T any](cmp func(a, b T) int) *Queue[T] { return &Queue[T]{cmp: cmp} }
func New[T any](cmp func(a, b T) int) *Queue[T] { return &Queue[T]{cmp: cmp, move: nmove[T]} }

// NewWithData constructs an empty Queue with the given comparison function
// that uses the given slice as storage. This allows the caller to initialize
Expand All @@ -34,13 +39,30 @@ func New[T any](cmp func(a, b T) int) *Queue[T] { return &Queue[T]{cmp: cmp} }
// access the contents data after the call unless the queue will no longer be
// used.
func NewWithData[T any](cmp func(a, b T) int, data []T) *Queue[T] {
q := &Queue[T]{data: data, cmp: cmp}
q := &Queue[T]{data: data, cmp: cmp, move: nmove[T]}
for i := len(q.data) / 2; i >= 0; i-- {
q.pushDown(i)
}
return q
}

// Update sets u as the update function on q. This function is called whenever
// an element of the queue is moved to a new position, giving the value and its
// new position. If u == nil, an existing update function is removed. Update
// returns q to allow chaining.
//
// Setting an update function makes q intrusive, allowing values in the queue
// to keep track of their current offset in the queue as items are added and
// removed. By default location information is not reported.
func (q *Queue[T]) Update(u func(T, int)) *Queue[T] {
if u == nil {
q.move = nmove[T]
} else {
q.move = u
}
return q
}

// Len reports the number of elements in the queue. This is a constant-time operation.
func (q *Queue[T]) Len() int { return len(q.data) }

Expand Down Expand Up @@ -86,6 +108,7 @@ func (q *Queue[T]) Pop() (T, bool) {
func (q *Queue[T]) Add(v T) int {
n := len(q.data)
q.data = append(q.data, v)
q.move(q.data[n], n)
return q.pushUp(n)
}

Expand All @@ -105,8 +128,8 @@ func (q *Queue[T]) Remove(n int) (T, bool) {

// Set replaces the contents of q with the specified values. Any previous
// values in the queue are discarded. This operation takes time proportional to
// len(vs) to restore heap order.
func (q *Queue[T]) Set(vs []T) {
// len(vs) to restore heap order. Set returns q to allow chaining.
func (q *Queue[T]) Set(vs []T) *Queue[T] {
// Copy the values so we do not alias the original slice.
// If the existing buffer already has enough space, reslice it; otherwise
// allocate a fresh one.
Expand All @@ -116,9 +139,11 @@ func (q *Queue[T]) Set(vs []T) {
q.data = q.data[:len(vs)]
}
copy(q.data, vs)
for i := len(q.data) / 2; i >= 0; i-- {
for i := len(q.data) - 1; i >= 0; i-- {
q.move(q.data[i], i)
q.pushDown(i)
}
return q
}

// Reorder replaces the ordering function for q with a new function. This
Expand Down Expand Up @@ -155,6 +180,7 @@ func (q *Queue[T]) pop(i int) T {
q.data = q.data[:0]
} else {
q.data[i], q.data[n] = q.data[n], out
q.move(q.data[i], i) // N.B. we do not report a move of out.
q.data = q.data[:n]
q.pushDown(i)
}
Expand All @@ -169,7 +195,7 @@ func (q *Queue[T]) pushUp(i int) int {
if q.cmp(q.data[i], q.data[par]) >= 0 {
break
}
q.data[i], q.data[par] = q.data[par], q.data[i]
q.swap(i, par)
i = par
}
return i
Expand All @@ -190,12 +216,20 @@ func (q *Queue[T]) pushDown(i int) int {
if min == i {
break // no more work to do
}
q.data[i], q.data[min] = q.data[min], q.data[i]
q.swap(i, min)
i, lc = min, 2*min+1
}
return i
}

// swap exchanges the elements at positions i and j of the heap, invoking the
// update function as needed.
func (q *Queue[T]) swap(i, j int) {
q.data[i], q.data[j] = q.data[j], q.data[i]
q.move(q.data[i], i)
q.move(q.data[j], j)
}

// Sort reorders the contents of vs in-place using the heap-sort algorithm, in
// non-decreasing order by the comparison function provided.
func Sort[T any](cmp func(a, b T) int, vs []T) {
Expand Down
53 changes: 53 additions & 0 deletions heapq/heapq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,59 @@ func TestSort(t *testing.T) {
}
}

func TestUpdate(t *testing.T) {
m := make(map[string]int) // tracks the offsets of strings in the queue
up := func(s string, p int) { m[s] = p } // update the offsets map
q := heapq.New(stdcmp.Compare[string]).Update(up)

// Verify that all the elements know their current offset correctly.
check := func() {
for i := 0; i < q.Len(); i++ {
s, _ := q.Peek(i)
if m[s] != i {
t.Errorf("At pos %d: %s is at %d instead", i, s, m[s])
}
}
}

check() // empty

// Check that Set assigns positions to the elements added.
q.Set([]string{"m", "z", "t", "a", "k", "b"})
check()

// Check that Add updates positions correctly.
q.Add("c")
check()

// Check that we can add an element and remove it by its assigned position.
q.Add("j")
check()

oldp := m["j"]
t.Logf("Added j at pos=%d", oldp)
q.Remove(oldp)
check()

// After removal, the element retains its last position.
if m["j"] != oldp {
t.Errorf("After Remove j: p=%d, want %d", m["j"], oldp)
}

var got []string
for !q.IsEmpty() {
s, _ := q.Pop()
got = append(got, s)
if m[s] != 0 {
t.Errorf("Pop: got %q at p=%d, want p=0", s, m[s])
}

}
if diff := cmp.Diff(got, []string{"a", "b", "c", "k", "m", "t", "z"}); diff != "" {
t.Errorf("Values (-got, +want):\n%s", diff)
}
}

func extract[T any](q *heapq.Queue[T]) []T {
all := make([]T, 0, q.Len())
for !q.IsEmpty() {
Expand Down

0 comments on commit 21d3bb4

Please sign in to comment.