Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use generic to allow 32 and 64 uint keys #30

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func metaMatchEmpty(m *metadata) bitset {
return hasZeroByte(castUint64(m) ^ hiBits)
}

func nextMatch(b *bitset) uint32 {
s := uint32(bits.TrailingZeros64(uint64(*b)))
func nextMatch[S Size](b *bitset) S {
s := S(bits.TrailingZeros64(uint64(*b)))
*b &= ^(1 << s) // clear bit |s|
return s >> 3 // div by 8
}
Expand Down
8 changes: 4 additions & 4 deletions bits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestMatchMetadata(t *testing.T) {
for _, x := range meta {
mask := metaMatchH2(&meta, h2(x))
assert.NotZero(t, mask)
assert.Equal(t, uint32(x), nextMatch(&mask))
assert.Equal(t, uint32(x), nextMatch[uint32](&mask))
}
})
t.Run("metaMatchEmpty", func(t *testing.T) {
Expand All @@ -42,7 +42,7 @@ func TestMatchMetadata(t *testing.T) {
meta[i] = empty
mask = metaMatchEmpty(&meta)
assert.NotZero(t, mask)
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
meta[i] = int8(i)
}
})
Expand All @@ -51,14 +51,14 @@ func TestMatchMetadata(t *testing.T) {
meta = newEmptyMetadata()
mask := metaMatchEmpty(&meta)
for i := range meta {
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
}
for i := 0; i < len(meta); i += 2 {
meta[i] = int8(42)
}
mask = metaMatchH2(&meta, h2(42))
for i := 0; i < len(meta); i += 2 {
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
}
})
}
Expand Down
101 changes: 57 additions & 44 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@ const (
maxLoadFactor = float32(maxAvgGroupLoad) / float32(groupSize)
)

type Size interface {
uint32 | uint64
}

// Map is an open-addressing hash map
// based on Abseil's flat_hash_map.
type Map[K comparable, V any] struct {
type Map[K comparable, V any, S Size] struct {
jtarchie marked this conversation as resolved.
Show resolved Hide resolved
ctrl []metadata
groups []group[K, V]
hash maphash.Hasher[K]
resident uint32
dead uint32
limit uint32
resident S
dead S
limit S
}

// metadata is the h2 metadata array for a group.
Expand Down Expand Up @@ -58,9 +62,18 @@ type h1 uint64
type h2 int8

// NewMap constructs a Map.
func NewMap[K comparable, V any](sz uint32) (m *Map[K, V]) {
func NewMap[K comparable, V any](sz uint32) (m *Map[K, V, uint32]) {
return newMap[K, V, uint32](sz)
}

// NewMap constructs a Map.
func NewMap64[K comparable, V any](sz uint64) (m *Map[K, V, uint64]) {
return newMap[K, V, uint64](sz)
}

func newMap[K comparable, V any, S Size](sz S) (m *Map[K, V, S]) {
groups := numGroups(sz)
m = &Map[K, V]{
m = &Map[K, V, S]{
ctrl: make([]metadata, groups),
groups: make([]group[K, V], groups),
hash: maphash.NewHasher[K](),
Expand All @@ -73,13 +86,13 @@ func NewMap[K comparable, V any](sz uint32) (m *Map[K, V]) {
}

// Has returns true if |key| is present in |m|.
func (m *Map[K, V]) Has(key K) (ok bool) {
func (m *Map[K, V, S]) Has(key K) (ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
ok = true
return
Expand All @@ -93,20 +106,20 @@ func (m *Map[K, V]) Has(key K) (ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Get returns the |value| mapped by |key| if one exists.
func (m *Map[K, V]) Get(key K) (value V, ok bool) {
func (m *Map[K, V, S]) Get(key K) (value V, ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
value, ok = m.groups[g].values[s], true
return
Expand All @@ -120,23 +133,23 @@ func (m *Map[K, V]) Get(key K) (value V, ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Put attempts to insert |key| and |value|
func (m *Map[K, V]) Put(key K, value V) {
func (m *Map[K, V, S]) Put(key K, value V) {
if m.resident >= m.limit {
m.rehash(m.nextSize())
}
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] { // update
m.groups[g].keys[s] = key
m.groups[g].values[s] = value
Expand All @@ -147,28 +160,28 @@ func (m *Map[K, V]) Put(key K, value V) {
// stop probing if we see an empty slot
matches = metaMatchEmpty(&m.ctrl[g])
if matches != 0 { // insert
s := nextMatch(&matches)
s := nextMatch[S](&matches)
m.groups[g].keys[s] = key
m.groups[g].values[s] = value
m.ctrl[g][s] = int8(lo)
m.resident++
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Delete attempts to remove |key|, returns true successful.
func (m *Map[K, V]) Delete(key K) (ok bool) {
func (m *Map[K, V, S]) Delete(key K) (ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for {
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
ok = true
// optimization: if |m.ctrl[g]| contains any empty
Expand Down Expand Up @@ -200,7 +213,7 @@ func (m *Map[K, V]) Delete(key K) (ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
Expand All @@ -211,12 +224,12 @@ func (m *Map[K, V]) Delete(key K) (ok bool) {
// for un-mutated Maps, every key will be visited once. If the Map is
// Mutated during iteration, mutations will be reflected on return from
// Iter, but the set of keys visited by Iter is non-deterministic.
func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) {
func (m *Map[K, V, S]) Iter(cb func(k K, v V) (stop bool)) {
// take a consistent view of the table in case
// we rehash during iteration
ctrl, groups := m.ctrl, m.groups
// pick a random starting group
g := randIntN(len(groups))
g := S(randIntN(len(groups)))
for n := 0; n < len(groups); n++ {
for s, c := range ctrl[g] {
if c == empty || c == tombstone {
Expand All @@ -228,14 +241,14 @@ func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) {
}
}
g++
if g >= uint32(len(groups)) {
if g >= S(len(groups)) {
g = 0
}
}
}

// Clear removes all elements from the Map.
func (m *Map[K, V]) Clear() {
func (m *Map[K, V, S]) Clear() {
for i, c := range m.ctrl {
for j := range c {
m.ctrl[i][j] = empty
Expand All @@ -254,24 +267,24 @@ func (m *Map[K, V]) Clear() {
}

// Count returns the number of elements in the Map.
func (m *Map[K, V]) Count() int {
func (m *Map[K, V, S]) Count() int {
return int(m.resident - m.dead)
}

// Capacity returns the number of additional elements
// the can be added to the Map before resizing.
func (m *Map[K, V]) Capacity() int {
func (m *Map[K, V, S]) Capacity() int {
return int(m.limit - m.resident)
}

// find returns the location of |key| if present, or its insertion location if absent.
// for performance, find is manually inlined into public methods.
func (m *Map[K, V]) find(key K, hi h1, lo h2) (g, s uint32, ok bool) {
g = probeStart(hi, len(m.groups))
func (m *Map[K, V, S]) find(key K, hi h1, lo h2) (g, s S, ok bool) {
g = probeStart[S](hi, len(m.groups))
for {
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s = nextMatch(&matches)
s = nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
return g, s, true
}
Expand All @@ -280,25 +293,25 @@ func (m *Map[K, V]) find(key K, hi h1, lo h2) (g, s uint32, ok bool) {
// stop probing if we see an empty slot
matches = metaMatchEmpty(&m.ctrl[g])
if matches != 0 {
s = nextMatch(&matches)
s = nextMatch[S](&matches)
return g, s, false
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

func (m *Map[K, V]) nextSize() (n uint32) {
n = uint32(len(m.groups)) * 2
func (m *Map[K, V, S]) nextSize() (n S) {
n = S(len(m.groups)) * 2
if m.dead >= (m.resident / 2) {
n = uint32(len(m.groups))
n = S(len(m.groups))
}
return
}

func (m *Map[K, V]) rehash(n uint32) {
func (m *Map[K, V, S]) rehash(n S) {
groups, ctrl := m.groups, m.ctrl
m.groups = make([]group[K, V], n)
m.ctrl = make([]metadata, n)
Expand All @@ -319,13 +332,13 @@ func (m *Map[K, V]) rehash(n uint32) {
}
}

func (m *Map[K, V]) loadFactor() float32 {
func (m *Map[K, V, S]) loadFactor() float32 {
slots := float32(len(m.groups) * groupSize)
return float32(m.resident-m.dead) / slots
}

// numGroups returns the minimum number of groups needed to store |n| elems.
func numGroups(n uint32) (groups uint32) {
func numGroups[S Size](n S) (groups S) {
groups = (n + maxAvgGroupLoad - 1) / maxAvgGroupLoad
if groups == 0 {
groups = 1
Expand All @@ -344,11 +357,11 @@ func splitHash(h uint64) (h1, h2) {
return h1((h & h1Mask) >> 7), h2(h & h2Mask)
}

func probeStart(hi h1, groups int) uint32 {
return fastModN(uint32(hi), uint32(groups))
func probeStart[S Size](hi h1, groups int) S {
return fastModN(S(hi), S(groups))
}

// lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func fastModN(x, n uint32) uint32 {
return uint32((uint64(x) * uint64(n)) >> 32)
func fastModN[S Size](x, n S) S {
return S((uint64(x) * uint64(n)) >> 32)
}
2 changes: 1 addition & 1 deletion map_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type hasher struct {
seed uintptr
}

func setConstSeed[K comparable, V any](m *Map[K, V], seed uintptr) {
func setConstSeed[K comparable, V any, S Size](m *Map[K, V, S], seed uintptr) {
h := (*hasher)((unsafe.Pointer)(&m.hash))
h.seed = seed
}
Loading