diff --git a/go.mod b/go.mod index 076a3394..0299eeca 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/lucasjones/reggen v0.0.0-20180717132126-cdb49ff09d77 github.com/mitchellh/mapstructure v1.3.3 github.com/pkg/errors v0.9.1 // indirect + github.com/segmentio/fasthash v1.0.3 github.com/stretchr/objx v0.1.1 // indirect github.com/stretchr/testify v1.6.1 github.com/tidwall/gjson v1.6.3 diff --git a/go.sum b/go.sum index 152d070d..a5b7b9f5 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ github.com/rs/cors v0.0.0-20160617231935-a62a804a8a00/go.mod h1:gFx+x8UowdsKA9Ac github.com/rs/xhandler v0.0.0-20160618193221-ed27b6fd6521/go.mod h1:RvLn4FgxWubrpZHtQLnOf6EwhN2hEMusxZOhcW9H3UQ= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/shirou/gopsutil v2.20.5+incompatible h1:tYH07UPoQt0OCQdgWWMgYHy3/a9bcxNpBIysykNIP7I= github.com/shirou/gopsutil v2.20.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= diff --git a/storage/badger_storage_test.go b/storage/badger_storage_test.go index 804552b7..740f69de 100644 --- a/storage/badger_storage_test.go +++ b/storage/badger_storage_test.go @@ -309,7 +309,7 @@ func TestBadgerTrain_Limit(t *testing.T) { namespace, newDir, dictionaryPath, - 10, + 50, []*CompressorEntry{}, ) assert.NoError(t, err) diff --git a/utils/priority_mutex_map.go b/utils/mutex_map.go similarity index 82% rename from utils/priority_mutex_map.go rename to utils/mutex_map.go index c0c0a704..4def4e34 100644 --- a/utils/priority_mutex_map.go +++ b/utils/mutex_map.go @@ -18,6 +18,10 @@ import ( "sync" ) +const ( + unlockPriority = true +) + // MutexMap is a struct that allows for // acquiring a *PriorityMutex via a string identifier // or for acquiring a global mutex that blocks @@ -26,8 +30,7 @@ import ( // This is useful for coordinating concurrent, non-overlapping // writes in the storage package. type MutexMap struct { - entries map[string]*mutexMapEntry - mutex sync.Mutex + entries *ShardedMap globalMutex sync.RWMutex } @@ -39,9 +42,9 @@ type mutexMapEntry struct { } // NewMutexMap returns a new *MutexMap. -func NewMutexMap() *MutexMap { +func NewMutexMap(shards int) *MutexMap { return &MutexMap{ - entries: map[string]*mutexMapEntry{}, + entries: NewShardedMap(shards), } } @@ -70,20 +73,23 @@ func (m *MutexMap) Lock(identifier string, priority bool) { // We acquire m when adding items to m.table // so that we don't accidentally overwrite // lock created by another goroutine. - m.mutex.Lock() - l, ok := m.entries[identifier] + data := m.entries.Lock(identifier, priority) + raw, ok := data[identifier] + var entry *mutexMapEntry if !ok { - l = &mutexMapEntry{ + entry = &mutexMapEntry{ lock: new(PriorityMutex), } - m.entries[identifier] = l + data[identifier] = entry + } else { + entry = raw.(*mutexMapEntry) } - l.count++ - m.mutex.Unlock() + entry.count++ + m.entries.Unlock(identifier) // Once we have a m.globalMutex.RLock, it is // safe to acquire an identifier lock. - l.lock.Lock(priority) + entry.lock.Lock(priority) } // Unlock releases a lock held for a particular identifier. @@ -92,15 +98,15 @@ func (m *MutexMap) Unlock(identifier string) { // exist by the time we unlock, otherwise // it would not have been possible to get // the lock to begin with. - m.mutex.Lock() - entry := m.entries[identifier] + data := m.entries.Lock(identifier, unlockPriority) + entry := data[identifier].(*mutexMapEntry) if entry.count <= 1 { // this should never be < 0 - delete(m.entries, identifier) + delete(data, identifier) } else { entry.count-- entry.lock.Unlock() } - m.mutex.Unlock() + m.entries.Unlock(identifier) // We release the globalMutex after unlocking // the identifier lock, otherwise it would be possible diff --git a/utils/priority_mutex_map_test.go b/utils/mutex_map_test.go similarity index 75% rename from utils/priority_mutex_map_test.go rename to utils/mutex_map_test.go index aec5da6b..5c2a8d38 100644 --- a/utils/priority_mutex_map_test.go +++ b/utils/mutex_map_test.go @@ -25,7 +25,7 @@ import ( func TestMutexMap(t *testing.T) { arr := []string{} - m := NewMutexMap() + m := NewMutexMap(DefaultShards) g, _ := errgroup.WithContext(context.Background()) // Lock while adding all locks @@ -47,7 +47,8 @@ func TestMutexMap(t *testing.T) { g.Go(func() error { m.Lock("a", false) - assert.Equal(t, m.entries["a"].count, 1) + entry := m.entries.shards[m.entries.shardIndex("a")].entries["a"].(*mutexMapEntry) + assert.Equal(t, entry.count, 1) <-a arr = append(arr, "a") close(b) @@ -57,7 +58,8 @@ func TestMutexMap(t *testing.T) { g.Go(func() error { m.Lock("b", false) - assert.Equal(t, m.entries["b"].count, 1) + entry := m.entries.shards[m.entries.shardIndex("b")].entries["b"].(*mutexMapEntry) + assert.Equal(t, entry.count, 1) close(a) <-b arr = append(arr, "b") @@ -68,7 +70,9 @@ func TestMutexMap(t *testing.T) { time.Sleep(1 * time.Second) // Ensure number of expected locks is correct - assert.Len(t, m.entries, 0) + totalKeys := len(m.entries.shards[m.entries.shardIndex("a")].entries) + + len(m.entries.shards[m.entries.shardIndex("b")].entries) + assert.Equal(t, totalKeys, 0) arr = append(arr, "global-a") m.GUnlock() assert.NoError(t, g.Wait()) @@ -83,5 +87,7 @@ func TestMutexMap(t *testing.T) { }, arr) // Ensure lock is no longer occupied - assert.Len(t, m.entries, 0) + totalKeys = len(m.entries.shards[m.entries.shardIndex("a")].entries) + + len(m.entries.shards[m.entries.shardIndex("b")].entries) + assert.Equal(t, totalKeys, 0) } diff --git a/utils/sharded_map.go b/utils/sharded_map.go new file mode 100644 index 00000000..a2d91c02 --- /dev/null +++ b/utils/sharded_map.go @@ -0,0 +1,85 @@ +// Copyright 2020 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/segmentio/fasthash/fnv1a" +) + +const ( + // DefaultShards is the default number of shards + // to use in ShardedMap. + DefaultShards = 256 +) + +// shardMapEntry governs access to the shard of +// the map contained at a particular index. +type shardMapEntry struct { + mutex *PriorityMutex + entries map[string]interface{} +} + +// ShardedMap allows concurrent writes +// to a map by sharding the map into some +// number of independently locked subsections. +type ShardedMap struct { + shards []*shardMapEntry +} + +// NewShardedMap creates a new *ShardedMap +// with some number of shards. The larger the +// number provided for shards, the less lock +// contention there will be. +// +// As a rule of thumb, shards should usually +// be set to the concurrency of the caller. +func NewShardedMap(shards int) *ShardedMap { + m := &ShardedMap{ + shards: make([]*shardMapEntry, shards), + } + + for i := 0; i < shards; i++ { + m.shards[i] = &shardMapEntry{ + entries: map[string]interface{}{}, + mutex: new(PriorityMutex), + } + } + + return m +} + +// shardIndex returns the index of the shard +// that could contain the key. +func (m *ShardedMap) shardIndex(key string) int { + return int(fnv1a.HashString32(key) % uint32(len(m.shards))) +} + +// Lock acquires the lock for a shard that could contain +// the key. This syntax allows the caller to perform multiple +// operations while holding the lock for a single shard. +func (m *ShardedMap) Lock(key string, priority bool) map[string]interface{} { + shardIndex := m.shardIndex(key) + shard := m.shards[shardIndex] + shard.mutex.Lock(priority) + return shard.entries +} + +// Unlock releases the lock for a shard that could contain +// the key. +func (m *ShardedMap) Unlock(key string) { + shardIndex := m.shardIndex(key) + shard := m.shards[shardIndex] + shard.mutex.Unlock() +} diff --git a/utils/sharded_map_test.go b/utils/sharded_map_test.go new file mode 100644 index 00000000..a9815bd1 --- /dev/null +++ b/utils/sharded_map_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" +) + +func TestShardedMap(t *testing.T) { + m := NewShardedMap(2) + g, _ := errgroup.WithContext(context.Background()) + + // To test locking, we use channels + // that will cause deadlock if not executed + // concurrently. + a := make(chan struct{}) + b := make(chan struct{}) + + g.Go(func() error { + s := m.Lock("a", false) + assert.Len(t, s, 0) + s["test"] = "a" + <-a + close(b) + m.Unlock("a") + return nil + }) + + g.Go(func() error { + s := m.Lock("b", false) + assert.Len(t, s, 0) + s["test"] = "b" + close(a) + <-b + m.Unlock("b") + return nil + }) + + time.Sleep(1 * time.Second) + assert.NoError(t, g.Wait()) + + // Ensure keys set correctly + s := m.Lock("a", false) + assert.Len(t, s, 1) + assert.Equal(t, s["test"], "a") + m.Unlock("a") + + s = m.Lock("b", false) + assert.Len(t, s, 1) + assert.Equal(t, s["test"], "b") + m.Unlock("b") +}