use generic to allow 32 and 64 uint keys #30

Open · wants to merge 2 commits into main
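For context, here is a minimal usage sketch of the API this change introduces (not part of the diff; the import path github.com/dolthub/swiss is assumed). NewMap keeps its existing uint32 signature, while NewMap64 is the new 64-bit constructor added in map.go below:

```go
package main

import (
	"fmt"

	"github.com/dolthub/swiss" // import path assumed
)

func main() {
	// Existing constructor: unchanged public API, 32-bit sizes.
	m := swiss.NewMap[string, int](16)
	m.Put("a", 1)

	// New 64-bit constructor, for tables whose capacity may exceed uint32.
	big := swiss.NewMap64[uint64, string](1 << 20)
	big.Put(42, "answer")
	if v, ok := big.Get(42); ok {
		fmt.Println(v) // prints "answer"
	}
}
```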
6 changes: 4 additions & 2 deletions bits.go
@@ -19,6 +19,8 @@ package swiss
import (
"math/bits"
"unsafe"

"golang.org/x/exp/constraints"
)

const (
@@ -40,8 +42,8 @@ func metaMatchEmpty(m *metadata) bitset {
return hasZeroByte(castUint64(m) ^ hiBits)
}

func nextMatch(b *bitset) uint32 {
s := uint32(bits.TrailingZeros64(uint64(*b)))
func nextMatch[S constraints.Unsigned](b *bitset) S {
s := S(bits.TrailingZeros64(uint64(*b)))
*b &= ^(1 << s) // clear bit |s|
return s >> 3 // div by 8
}
8 changes: 4 additions & 4 deletions bits_test.go
@@ -32,7 +32,7 @@ func TestMatchMetadata(t *testing.T) {
for _, x := range meta {
mask := metaMatchH2(&meta, h2(x))
assert.NotZero(t, mask)
assert.Equal(t, uint32(x), nextMatch(&mask))
assert.Equal(t, uint32(x), nextMatch[uint32](&mask))
}
})
t.Run("metaMatchEmpty", func(t *testing.T) {
@@ -42,7 +42,7 @@ func TestMatchMetadata(t *testing.T) {
meta[i] = empty
mask = metaMatchEmpty(&meta)
assert.NotZero(t, mask)
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
meta[i] = int8(i)
}
})
@@ -51,14 +51,14 @@ func TestMatchMetadata(t *testing.T) {
meta = newEmptyMetadata()
mask := metaMatchEmpty(&meta)
for i := range meta {
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
}
for i := 0; i < len(meta); i += 2 {
meta[i] = int8(42)
}
mask = metaMatchH2(&meta, h2(42))
for i := 0; i < len(meta); i += 2 {
assert.Equal(t, uint32(i), nextMatch(&mask))
assert.Equal(t, uint32(i), nextMatch[uint32](&mask))
}
})
}
1 change: 1 addition & 0 deletions go.mod
@@ -10,5 +10,6 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
@@ -12,6 +12,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 h1:vpzMC/iZhYFAjJzHU0Cfuq+w1vLLsF2vLkDrPjzKYck=
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
98 changes: 54 additions & 44 deletions map.go
@@ -16,6 +16,7 @@ package swiss

import (
"github.com/dolthub/maphash"
"golang.org/x/exp/constraints"
)

const (
@@ -24,13 +25,13 @@ const (

// Map is an open-addressing hash map
// based on Abseil's flat_hash_map.
type Map[K comparable, V any] struct {
type Map[K comparable, V any, S constraints.Unsigned] struct {
ctrl []metadata
groups []group[K, V]
hash maphash.Hasher[K]
resident uint32
dead uint32
limit uint32
resident S
dead S
limit S
}

// metadata is the h2 metadata array for a group.
@@ -58,9 +59,18 @@ type h1 uint64
type h2 int8

// NewMap constructs a Map.
func NewMap[K comparable, V any](sz uint32) (m *Map[K, V]) {
func NewMap[K comparable, V any](sz uint32) (m *Map[K, V, uint32]) {
return newMap[K, V, uint32](sz)
}

// NewMap64 constructs a Map that uses 64-bit sizes and counters.
func NewMap64[K comparable, V any](sz uint64) (m *Map[K, V, uint64]) {
return newMap[K, V, uint64](sz)
}

func newMap[K comparable, V any, S constraints.Unsigned](sz S) (m *Map[K, V, S]) {
groups := numGroups(sz)
m = &Map[K, V]{
m = &Map[K, V, S]{
ctrl: make([]metadata, groups),
groups: make([]group[K, V], groups),
hash: maphash.NewHasher[K](),
@@ -73,13 +83,13 @@ func NewMap[K comparable, V any](sz uint32) (m *Map[K, V]) {
}

// Has returns true if |key| is present in |m|.
func (m *Map[K, V]) Has(key K) (ok bool) {
func (m *Map[K, V, S]) Has(key K) (ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
ok = true
return
@@ -93,20 +103,20 @@ func (m *Map[K, V]) Has(key K) (ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Get returns the |value| mapped by |key| if one exists.
func (m *Map[K, V]) Get(key K) (value V, ok bool) {
func (m *Map[K, V, S]) Get(key K) (value V, ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
value, ok = m.groups[g].values[s], true
return
@@ -120,23 +130,23 @@ func (m *Map[K, V]) Get(key K) (value V, ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Put attempts to insert |key| and |value|
func (m *Map[K, V]) Put(key K, value V) {
func (m *Map[K, V, S]) Put(key K, value V) {
if m.resident >= m.limit {
m.rehash(m.nextSize())
}
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for { // inlined find loop
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] { // update
m.groups[g].keys[s] = key
m.groups[g].values[s] = value
@@ -147,28 +157,28 @@ func (m *Map[K, V]) Put(key K, value V) {
// stop probing if we see an empty slot
matches = metaMatchEmpty(&m.ctrl[g])
if matches != 0 { // insert
s := nextMatch(&matches)
s := nextMatch[S](&matches)
m.groups[g].keys[s] = key
m.groups[g].values[s] = value
m.ctrl[g][s] = int8(lo)
m.resident++
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

// Delete attempts to remove |key|, returns true if successful.
func (m *Map[K, V]) Delete(key K) (ok bool) {
func (m *Map[K, V, S]) Delete(key K) (ok bool) {
hi, lo := splitHash(m.hash.Hash(key))
g := probeStart(hi, len(m.groups))
g := probeStart[S](hi, len(m.groups))
for {
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s := nextMatch(&matches)
s := nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
ok = true
// optimization: if |m.ctrl[g]| contains any empty
@@ -200,7 +210,7 @@ func (m *Map[K, V]) Delete(key K) (ok bool) {
return
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
@@ -211,12 +221,12 @@ func (m *Map[K, V]) Delete(key K) (ok bool) {
// for un-mutated Maps, every key will be visited once. If the Map is
// Mutated during iteration, mutations will be reflected on return from
// Iter, but the set of keys visited by Iter is non-deterministic.
func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) {
func (m *Map[K, V, S]) Iter(cb func(k K, v V) (stop bool)) {
// take a consistent view of the table in case
// we rehash during iteration
ctrl, groups := m.ctrl, m.groups
// pick a random starting group
g := randIntN(len(groups))
g := S(randIntN(len(groups)))
for n := 0; n < len(groups); n++ {
for s, c := range ctrl[g] {
if c == empty || c == tombstone {
@@ -228,14 +238,14 @@ func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) {
}
}
g++
if g >= uint32(len(groups)) {
if g >= S(len(groups)) {
g = 0
}
}
}

// Clear removes all elements from the Map.
func (m *Map[K, V]) Clear() {
func (m *Map[K, V, S]) Clear() {
for i, c := range m.ctrl {
for j := range c {
m.ctrl[i][j] = empty
@@ -254,24 +264,24 @@ func (m *Map[K, V]) Clear() {
}

// Count returns the number of elements in the Map.
func (m *Map[K, V]) Count() int {
func (m *Map[K, V, S]) Count() int {
return int(m.resident - m.dead)
}

// Capacity returns the number of additional elements
// that can be added to the Map before resizing.
func (m *Map[K, V]) Capacity() int {
func (m *Map[K, V, S]) Capacity() int {
return int(m.limit - m.resident)
}

// find returns the location of |key| if present, or its insertion location if absent.
// for performance, find is manually inlined into public methods.
func (m *Map[K, V]) find(key K, hi h1, lo h2) (g, s uint32, ok bool) {
g = probeStart(hi, len(m.groups))
func (m *Map[K, V, S]) find(key K, hi h1, lo h2) (g, s S, ok bool) {
g = probeStart[S](hi, len(m.groups))
for {
matches := metaMatchH2(&m.ctrl[g], lo)
for matches != 0 {
s = nextMatch(&matches)
s = nextMatch[S](&matches)
if key == m.groups[g].keys[s] {
return g, s, true
}
@@ -280,25 +290,25 @@ func (m *Map[K, V]) find(key K, hi h1, lo h2) (g, s uint32, ok bool) {
// stop probing if we see an empty slot
matches = metaMatchEmpty(&m.ctrl[g])
if matches != 0 {
s = nextMatch(&matches)
s = nextMatch[S](&matches)
return g, s, false
}
g += 1 // linear probing
if g >= uint32(len(m.groups)) {
if g >= S(len(m.groups)) {
g = 0
}
}
}

func (m *Map[K, V]) nextSize() (n uint32) {
n = uint32(len(m.groups)) * 2
func (m *Map[K, V, S]) nextSize() (n S) {
n = S(len(m.groups)) * 2
if m.dead >= (m.resident / 2) {
n = uint32(len(m.groups))
n = S(len(m.groups))
}
return
}

func (m *Map[K, V]) rehash(n uint32) {
func (m *Map[K, V, S]) rehash(n S) {
groups, ctrl := m.groups, m.ctrl
m.groups = make([]group[K, V], n)
m.ctrl = make([]metadata, n)
@@ -319,13 +329,13 @@ func (m *Map[K, V]) rehash(n uint32) {
}
}

func (m *Map[K, V]) loadFactor() float32 {
func (m *Map[K, V, S]) loadFactor() float32 {
slots := float32(len(m.groups) * groupSize)
return float32(m.resident-m.dead) / slots
}

// numGroups returns the minimum number of groups needed to store |n| elems.
func numGroups(n uint32) (groups uint32) {
func numGroups[S constraints.Unsigned](n S) (groups S) {
groups = (n + maxAvgGroupLoad - 1) / maxAvgGroupLoad
if groups == 0 {
groups = 1
@@ -344,11 +354,11 @@ func splitHash(h uint64) (h1, h2) {
return h1((h & h1Mask) >> 7), h2(h & h2Mask)
}

func probeStart(hi h1, groups int) uint32 {
return fastModN(uint32(hi), uint32(groups))
func probeStart[S constraints.Unsigned](hi h1, groups int) S {
return fastModN(S(hi), S(groups))
}

// lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func fastModN(x, n uint32) uint32 {
return uint32((uint64(x) * uint64(n)) >> 32)
func fastModN[S constraints.Unsigned](x, n S) S {
return S((uint64(x) * uint64(n)) >> 32)
Comment on lines 361 to +363

Reviewer: This is the part that will not work with uint64: it will always return values in 32-bit range, so those buckets will be unused.

Author: I see. Would returning unsafe.Sizeof(S) help here? Or is this a bit-op trick that can not be rescued?

Reviewer: I don't think it can be rescued without having to do the usual modulo division. See the link above that explains how the trick works.

Author: This writeup shows a possible version to support uint64 ranges.

Perhaps there's a method here for assigning the correct fastmodN for the map instantiation. It, however, wouldn't be inlined by the compiler.

Reviewer: Thank you for the link! It's always great to learn something new.

Maybe it would make sense to just have a fork of this repo (swiss64?) with the 64-bit version? It should be quite easy to keep it up to date (I guess there shouldn't be many or frequent incoming changes), the API would be clean (swiss.Map vs swiss64.Map) and the code would be easier and faster to follow. WDYT?

}
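For reference, the 64-bit reduction discussed in the thread above could look like the sketch below (not part of this PR; fastModN64 is a hypothetical name). It follows the extension described in the linked writeup: take the high 64 bits of the full 128-bit product with math/bits.Mul64, which maps x into [0, n) without a division.

```go
package swiss

import "math/bits"

// fastModN64 returns floor(x*n / 2^64), i.e. the high 64 bits of the
// 128-bit product x*n. For x drawn uniformly from [0, 2^64) the result
// is uniform in [0, n) — the 64-bit analogue of fastModN.
func fastModN64(x, n uint64) uint64 {
	hi, _ := bits.Mul64(x, n)
	return hi
}
```

Whether the compiler still inlines this through a generic instantiation (the inlining concern raised above) would need to be checked separately, e.g. with `go build -gcflags=-m`.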
3 changes: 2 additions & 1 deletion map_fuzz_test.go
@@ -19,6 +19,7 @@ import (
"unsafe"

"github.com/stretchr/testify/assert"
"golang.org/x/exp/constraints"
)

func FuzzStringMap(f *testing.F) {
@@ -93,7 +94,7 @@ type hasher struct {
seed uintptr
}

func setConstSeed[K comparable, V any](m *Map[K, V], seed uintptr) {
func setConstSeed[K comparable, V any, S constraints.Unsigned](m *Map[K, V, S], seed uintptr) {
h := (*hasher)((unsafe.Pointer)(&m.hash))
h.seed = seed
}