diff --git a/cache/lru_cache.go b/cache/lru_cache.go index f35804ac448..bab5e9549a5 100644 --- a/cache/lru_cache.go +++ b/cache/lru_cache.go @@ -92,7 +92,9 @@ func (c *LRU[K, _]) evict(key K) { } func (c *LRU[K, V]) flush() { - c.elements = linked.NewHashmap[K, V]() + if c.elements != nil { + c.elements.Clear() + } } func (c *LRU[_, _]) len() int { diff --git a/cache/lru_sized_cache.go b/cache/lru_sized_cache.go index 592674cb222..e8c8b0c76e7 100644 --- a/cache/lru_sized_cache.go +++ b/cache/lru_sized_cache.go @@ -113,7 +113,7 @@ func (c *sizedLRU[K, _]) evict(key K) { } func (c *sizedLRU[K, V]) flush() { - c.elements = linked.NewHashmap[K, V]() + c.elements.Clear() c.currentSize = 0 } diff --git a/utils/linked/hashmap.go b/utils/linked/hashmap.go index b17b7b60972..968775adf2c 100644 --- a/utils/linked/hashmap.go +++ b/utils/linked/hashmap.go @@ -63,14 +63,25 @@ func (lh *Hashmap[K, V]) Get(key K) (V, bool) { func (lh *Hashmap[K, V]) Delete(key K) bool { e, ok := lh.entryMap[key] if ok { - lh.entryList.Remove(e) - delete(lh.entryMap, key) - e.Value = keyValue[K, V]{} // Free the key value pair - lh.freeList = append(lh.freeList, e) + lh.remove(e) } return ok } +func (lh *Hashmap[K, V]) Clear() { + for _, e := range lh.entryMap { + lh.remove(e) + } +} + +// remove assumes that [e] is currently in the Hashmap. +func (lh *Hashmap[K, V]) remove(e *ListElement[keyValue[K, V]]) { + delete(lh.entryMap, e.Value.key) + lh.entryList.Remove(e) + e.Value = keyValue[K, V]{} // Free the key value pair + lh.freeList = append(lh.freeList, e) +} + func (lh *Hashmap[K, V]) Len() int { return len(lh.entryMap) } diff --git a/utils/linked/hashmap_test.go b/utils/linked/hashmap_test.go index 1920180b180..25131888dcb 100644 --- a/utils/linked/hashmap_test.go +++ b/utils/linked/hashmap_test.go @@ -95,6 +95,23 @@ func TestHashmap(t *testing.T) { require.Equal(1, val1, "wrong value") } +func TestHashmapClear(t *testing.T) { + require := require.New(t) + + lh := NewHashmap[int, int]() + lh.Put(1, 1) + lh.Put(2, 2) + + lh.Clear() + + require.Empty(lh.entryMap) + require.Zero(lh.entryList.Len()) + require.Len(lh.freeList, 2) + for _, e := range lh.freeList { + require.Zero(e.Value) // Make sure the value is cleared + } +} + func TestIterator(t *testing.T) { require := require.New(t) id1, id2, id3 := ids.GenerateTestID(), ids.GenerateTestID(), ids.GenerateTestID()