From 666795bac305d467e1cf37242d51c2cd1315106d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 3 Mar 2022 10:26:09 -0500 Subject: [PATCH] chore(dot/state): replace `sync.Map` with map+mutex (#2286) --- dot/state/block.go | 71 ++----- dot/state/block_finalisation.go | 20 +- dot/state/hashtoblockmap.go | 91 +++++++++ dot/state/hashtoblockmap_test.go | 327 +++++++++++++++++++++++++++++++ 4 files changed, 445 insertions(+), 64 deletions(-) create mode 100644 dot/state/hashtoblockmap.go create mode 100644 dot/state/hashtoblockmap_test.go diff --git a/dot/state/block.go b/dot/state/block.go index 59f3db4521..f505e654fd 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -58,7 +58,7 @@ type BlockState struct { sync.RWMutex genesisHash common.Hash lastFinalised common.Hash - unfinalisedBlocks *sync.Map // map[common.Hash]*types.Block + unfinalisedBlocks *hashToBlockMap tries *Tries // block notifiers @@ -78,7 +78,7 @@ func NewBlockState(db chaindb.Database, trs *Tries, telemetry telemetry.Client) dbPath: db.Path(), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), - unfinalisedBlocks: new(sync.Map), + unfinalisedBlocks: newHashToBlockMap(), tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), @@ -106,12 +106,12 @@ func NewBlockState(db chaindb.Database, trs *Tries, telemetry telemetry.Client) // NewBlockStateFromGenesis initialises a BlockState from a genesis header, // saving it to the database located at basePath func NewBlockStateFromGenesis(db chaindb.Database, trs *Tries, header *types.Header, - telemetryMailer telemetry.Client) (*BlockState, error) { // TODO CHECKTEST + telemetryMailer telemetry.Client) (*BlockState, error) { bs := &BlockState{ bt: blocktree.NewBlockTreeFromRoot(header), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), - unfinalisedBlocks: new(sync.Map), + unfinalisedBlocks: newHashToBlockMap(), tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), @@ -184,46 +184,9 @@ func (bs *BlockState) GenesisHash() common.Hash { return bs.genesisHash } -func (bs *BlockState) storeUnfinalisedBlock(block *types.Block) { - bs.unfinalisedBlocks.Store(block.Header.Hash(), block) -} - -func (bs *BlockState) hasUnfinalisedBlock(hash common.Hash) bool { - _, has := bs.unfinalisedBlocks.Load(hash) - return has -} - -func (bs *BlockState) getUnfinalisedHeader(hash common.Hash) (*types.Header, bool) { - block, has := bs.getUnfinalisedBlock(hash) - if !has { - return nil, false - } - - return &block.Header, true -} - -func (bs *BlockState) getUnfinalisedBlock(hash common.Hash) (*types.Block, bool) { - block, has := bs.unfinalisedBlocks.Load(hash) - if !has { - return nil, false - } - - // TODO: dot/core tx re-org test seems to abort here due to block body being invalid? - return block.(*types.Block), true -} - -func (bs *BlockState) getAndDeleteUnfinalisedBlock(hash common.Hash) (*types.Block, bool) { - block, has := bs.unfinalisedBlocks.LoadAndDelete(hash) - if !has { - return nil, false - } - - return block.(*types.Block), true -} - // HasHeader returns if the db contains a header with the given hash func (bs *BlockState) HasHeader(hash common.Hash) (bool, error) { - if bs.hasUnfinalisedBlock(hash) { + if bs.unfinalisedBlocks.getBlock(hash) != nil { return true, nil } @@ -231,9 +194,9 @@ func (bs *BlockState) HasHeader(hash common.Hash) (bool, error) { } // GetHeader returns a BlockHeader for a given hash -func (bs *BlockState) GetHeader(hash common.Hash) (*types.Header, error) { - header, has := bs.getUnfinalisedHeader(hash) - if has { +func (bs *BlockState) GetHeader(hash common.Hash) (header *types.Header, err error) { + header = bs.unfinalisedBlocks.getBlockHeader(hash) + if header != nil { return header, nil } @@ -313,8 +276,8 @@ func (bs *BlockState) GetBlockByHash(hash common.Hash) (*types.Block, error) { bs.RLock() defer bs.RUnlock() - block, has := bs.getUnfinalisedBlock(hash) - if has { + block := bs.unfinalisedBlocks.getBlock(hash) + if block != nil { return block, nil } @@ -346,7 +309,7 @@ func (bs *BlockState) HasBlockBody(hash common.Hash) (bool, error) { bs.RLock() defer bs.RUnlock() - if bs.hasUnfinalisedBlock(hash) { + if bs.unfinalisedBlocks.getBlock(hash) != nil { return true, nil } @@ -354,10 +317,10 @@ func (bs *BlockState) HasBlockBody(hash common.Hash) (bool, error) { } // GetBlockBody will return Body for a given hash -func (bs *BlockState) GetBlockBody(hash common.Hash) (*types.Body, error) { - block, has := bs.getUnfinalisedBlock(hash) - if has { - return &block.Body, nil +func (bs *BlockState) GetBlockBody(hash common.Hash) (body *types.Body, err error) { + body = bs.unfinalisedBlocks.getBlockBody(hash) + if body != nil { + return body, nil } data, err := bs.db.Get(blockBodyKey(hash)) @@ -417,7 +380,7 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti return err } - bs.storeUnfinalisedBlock(block) + bs.unfinalisedBlocks.store(block) go bs.notifyImported(block) return nil } @@ -433,7 +396,7 @@ func (bs *BlockState) AddBlockToBlockTree(block *types.Block) error { arrivalTime = time.Now() } - bs.storeUnfinalisedBlock(block) + bs.unfinalisedBlocks.store(block) return bs.bt.AddBlock(&block.Header, arrivalTime) } diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index f7fe569191..bbbcd4fd3f 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -146,14 +146,14 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er pruned := bs.bt.Prune(hash) for _, hash := range pruned { - block, has := bs.getAndDeleteUnfinalisedBlock(hash) - if !has { + blockHeader := bs.unfinalisedBlocks.delete(hash) + if blockHeader == nil { continue } - bs.tries.delete(block.Header.StateRoot) + bs.tries.delete(blockHeader.StateRoot) - logger.Tracef("pruned block number %s with hash %s", block.Header.Number, hash) + logger.Tracef("pruned block number %s with hash %s", blockHeader.Number, hash) } // if nothing was previously finalised, set the first slot of the network to the @@ -207,8 +207,8 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { continue } - block, has := bs.getUnfinalisedBlock(hash) - if !has { + block := bs.unfinalisedBlocks.getBlock(hash) + if block == nil { return fmt.Errorf("failed to find block in unfinalised block map, block=%s", hash) } @@ -234,14 +234,14 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { } // delete from the unfinalisedBlockMap and delete reference to in-memory trie - block, has = bs.getAndDeleteUnfinalisedBlock(hash) - if !has { + blockHeader := bs.unfinalisedBlocks.delete(hash) + if blockHeader == nil { continue } - bs.tries.delete(block.Header.StateRoot) + bs.tries.delete(blockHeader.StateRoot) - logger.Tracef("cleaned out finalised block from memory; block number %s with hash %s", block.Header.Number, hash) + logger.Tracef("cleaned out finalised block from memory; block number %s with hash %s", blockHeader.Number, hash) } return batch.Flush() diff --git a/dot/state/hashtoblockmap.go b/dot/state/hashtoblockmap.go new file mode 100644 index 0000000000..4fbcb306b4 --- /dev/null +++ b/dot/state/hashtoblockmap.go @@ -0,0 +1,91 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package state + +import ( + "sync" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" +) + +// hashToBlockMap implements a thread safe map of block header hashes +// to block pointers. It has helper methods to fit the needs of callers +// in this package. +type hashToBlockMap struct { + mutex sync.RWMutex + mapping map[common.Hash]*types.Block +} + +func newHashToBlockMap() *hashToBlockMap { + return &hashToBlockMap{ + mapping: make(map[common.Hash]*types.Block), + } +} + +// getBlock returns a pointer to the block stored at the hash given, +// or nil if not found. +// Note this returns a pointer to the block so modifying the returned value +// will modify the block stored in the map, potentially leading to data races +// or unwanted changes, so be careful. +func (h *hashToBlockMap) getBlock(hash common.Hash) (block *types.Block) { + h.mutex.RLock() + defer h.mutex.RUnlock() + return h.mapping[hash] +} + +// getBlockHeader returns a pointer to the header of the block stored at the +// hash given, or nil if not found. +// Note this returns a pointer to the header of the block so modifying the +// returned value will modify the header of the block stored in the map, +// potentially leading to data races or unwanted changes, so be careful. +func (h *hashToBlockMap) getBlockHeader(hash common.Hash) (header *types.Header) { + h.mutex.RLock() + defer h.mutex.RUnlock() + block := h.mapping[hash] + if block == nil { + return nil + } + return &block.Header +} + +// getBlockBody returns a pointer to the body of the block stored at the +// hash given, or nil if not found. +// Note this returns a pointer to the body of the block so modifying the +// returned value will modify the body of the block stored in the map, +// potentially leading to data races or unwanted changes, so be careful. +func (h *hashToBlockMap) getBlockBody(hash common.Hash) (body *types.Body) { + h.mutex.RLock() + defer h.mutex.RUnlock() + block := h.mapping[hash] + if block == nil { + return nil + } + return &block.Body +} + +// store stores a block and uses its header hash digest as key. +// Note the block is not deep copied so mutating the passed argument +// will lead to mutation for the block in the map and returned by this map. +// Also note this operation sets the hash field on the block header because of +// the call to block.Header.Hash(). +func (h *hashToBlockMap) store(block *types.Block) { + h.mutex.Lock() + defer h.mutex.Unlock() + h.mapping[block.Header.Hash()] = block +} + +// delete deletes the block stored at the hash given, and returns +// a pointer to the header of the block deleted from the map, +// or nil if the block is not found. +func (h *hashToBlockMap) delete(hash common.Hash) (deletedHeader *types.Header) { + h.mutex.Lock() + defer h.mutex.Unlock() + block := h.mapping[hash] + delete(h.mapping, hash) + if block == nil { + return nil + } + return &block.Header +} diff --git a/dot/state/hashtoblockmap_test.go b/dot/state/hashtoblockmap_test.go new file mode 100644 index 0000000000..307169b18d --- /dev/null +++ b/dot/state/hashtoblockmap_test.go @@ -0,0 +1,327 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package state + +import ( + "context" + "math/big" + "sync" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/assert" +) + +func Test_newHashToBlockMap(t *testing.T) { + t.Parallel() + + htb := newHashToBlockMap() + + expected := &hashToBlockMap{ + mapping: make(map[common.Hash]*types.Block), + } + assert.Equal(t, expected, htb) +} + +func Test_hashToBlockMap_getBlock(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + htb *hashToBlockMap + hash common.Hash + block *types.Block + }{ + "hash does not exist": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {4, 5, 6}: {}, + }, + }, + hash: common.Hash{1, 2, 3}, + }, + "hash exists": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {1, 2, 3}: {Header: types.Header{ParentHash: common.Hash{1}}}, + }, + }, + hash: common.Hash{1, 2, 3}, + block: &types.Block{Header: types.Header{ParentHash: common.Hash{1}}}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + block := testCase.htb.getBlock(testCase.hash) + + assert.Equal(t, testCase.block, block) + }) + } +} + +func Test_hashToBlockMap_getBlockHeader(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + htb *hashToBlockMap + hash common.Hash + header *types.Header + }{ + "hash does not exist": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {4, 5, 6}: {}, + }, + }, + hash: common.Hash{1, 2, 3}, + }, + "hash exists": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {1, 2, 3}: {Header: types.Header{ParentHash: common.Hash{1}}}, + }, + }, + hash: common.Hash{1, 2, 3}, + header: &types.Header{ParentHash: common.Hash{1}}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + header := testCase.htb.getBlockHeader(testCase.hash) + + assert.Equal(t, testCase.header, header) + }) + } +} + +func Test_hashToBlockMap_getBlockBody(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + htb *hashToBlockMap + hash common.Hash + body *types.Body + }{ + "hash does not exist": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {4, 5, 6}: {}, + }, + }, + hash: common.Hash{1, 2, 3}, + }, + "hash exists": { + htb: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {1, 2, 3}: {Body: types.Body{}}, + }, + }, + hash: common.Hash{1, 2, 3}, + body: &types.Body{}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + body := testCase.htb.getBlockBody(testCase.hash) + + assert.Equal(t, testCase.body, body) + }) + } +} + +func Test_hashToBlockMap_store(t *testing.T) { + t.Parallel() + + headerWithHash := func(header types.Header) types.Header { + header.Hash() + return header + } + + testCases := map[string]struct { + initialMap *hashToBlockMap + block *types.Block + expectedMap *hashToBlockMap + }{ + "override block": { + initialMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + { + 0x64, 0x43, 0xa0, 0xb4, 0x6e, 0x4, 0x12, 0xe6, + 0x26, 0x36, 0x30, 0x28, 0x11, 0x5a, 0x9f, 0x2c, + 0xf9, 0x63, 0xee, 0xed, 0x52, 0x6b, 0x8b, 0x33, + 0xe5, 0x31, 0x6f, 0x8, 0xb5, 0xd, 0xd, 0xc3, + }: {Header: types.Header{Number: big.NewInt(99)}}, + }, + }, + block: &types.Block{Header: types.Header{Number: big.NewInt(1)}}, + expectedMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + { + 0x64, 0x43, 0xa0, 0xb4, 0x6e, 0x4, 0x12, 0xe6, + 0x26, 0x36, 0x30, 0x28, 0x11, 0x5a, 0x9f, 0x2c, + 0xf9, 0x63, 0xee, 0xed, 0x52, 0x6b, 0x8b, 0x33, + 0xe5, 0x31, 0x6f, 0x8, 0xb5, 0xd, 0xd, 0xc3, + }: {Header: headerWithHash(types.Header{Number: big.NewInt(1)})}, + }, + }, + }, + "store new block": { + initialMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{}, + }, + block: &types.Block{Header: types.Header{Number: big.NewInt(1)}}, + expectedMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + { + 0x64, 0x43, 0xa0, 0xb4, 0x6e, 0x4, 0x12, 0xe6, + 0x26, 0x36, 0x30, 0x28, 0x11, 0x5a, 0x9f, 0x2c, + 0xf9, 0x63, 0xee, 0xed, 0x52, 0x6b, 0x8b, 0x33, + 0xe5, 0x31, 0x6f, 0x8, 0xb5, 0xd, 0xd, 0xc3, + }: {Header: headerWithHash(types.Header{Number: big.NewInt(1)})}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + htb := testCase.initialMap + + htb.store(testCase.block) + + assert.Equal(t, testCase.expectedMap, htb) + }) + } +} + +func Test_hashToBlockMap_delete(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + initialMap *hashToBlockMap + hash common.Hash + deletedHeader *types.Header + expectedMap *hashToBlockMap + }{ + "hash does not exist": { + initialMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{}, + }, + hash: common.Hash{1, 2, 3}, + expectedMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{}, + }, + }, + "hash deleted": { + initialMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{ + {1, 2, 3}: {Header: types.Header{ParentHash: common.Hash{1, 2, 3}}}, + }, + }, + hash: common.Hash{1, 2, 3}, + deletedHeader: &types.Header{ParentHash: common.Hash{1, 2, 3}}, + expectedMap: &hashToBlockMap{ + mapping: map[common.Hash]*types.Block{}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + htb := testCase.initialMap + + deletedHeader := htb.delete(testCase.hash) + + assert.Equal(t, testCase.deletedHeader, deletedHeader) + assert.Equal(t, testCase.expectedMap, htb) + }) + } +} + +func Test_hashToBlockMap_threadSafety(t *testing.T) { + // This test consists in checking for concurrent access + // using the -race detector. + t.Parallel() + + var startWg, endWg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + + const parallelism = 4 + const operations = 5 + const goroutines = parallelism * operations + startWg.Add(goroutines) + endWg.Add(goroutines) + + const testDuration = 50 * time.Millisecond + go func() { + timer := time.NewTimer(time.Hour) + startWg.Wait() + _ = timer.Reset(testDuration) + <-timer.C + cancel() + }() + + runInLoop := func(f func()) { + defer endWg.Done() + startWg.Done() + startWg.Wait() + for ctx.Err() == nil { + f() + } + } + + htb := newHashToBlockMap() + hash := common.Hash{ + 0x64, 0x43, 0xa0, 0xb4, 0x6e, 0x4, 0x12, 0xe6, + 0x26, 0x36, 0x30, 0x28, 0x11, 0x5a, 0x9f, 0x2c, + 0xf9, 0x63, 0xee, 0xed, 0x52, 0x6b, 0x8b, 0x33, + 0xe5, 0x31, 0x6f, 0x8, 0xb5, 0xd, 0xd, 0xc3, + } + block := &types.Block{ + Header: types.Header{Number: big.NewInt(1)}, + } + + for i := 0; i < parallelism; i++ { + go runInLoop(func() { + htb.getBlock(hash) + }) + + go runInLoop(func() { + htb.getBlockHeader(hash) + }) + + go runInLoop(func() { + htb.getBlockBody(hash) + }) + + go runInLoop(func() { + htb.store(block) + }) + + go runInLoop(func() { + _ = htb.delete(hash) + }) + } + + endWg.Wait() +}