diff --git a/lib/grandpa/commits_tracker.go b/lib/grandpa/commits_tracker.go new file mode 100644 index 0000000000..14addd5a86 --- /dev/null +++ b/lib/grandpa/commits_tracker.go @@ -0,0 +1,108 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "container/list" + + "github.com/ChainSafe/gossamer/lib/common" +) + +// commitsTracker tracks vote messages that could +// not be processed, and removes the oldest ones once +// its maximum capacity is reached. +// It is NOT THREAD SAFE to use. +type commitsTracker struct { + // map of commit block hash to linked list commit message. + mapping map[common.Hash]*list.Element + // double linked list of commit messages + // to track the order commit messages were added in. + linkedList *list.List + capacity int +} + +// newCommitsTracker creates a new commit messages tracker +// with the capacity specified. +func newCommitsTracker(capacity int) commitsTracker { + return commitsTracker{ + mapping: make(map[common.Hash]*list.Element, capacity), + linkedList: list.New(), + capacity: capacity, + } +} + +// add adds a commit message to the commit message tracker. +// If the commit message tracker capacity is reached, +// the oldest commit message is removed. +func (ct *commitsTracker) add(commitMessage *CommitMessage) { + blockHash := commitMessage.Vote.Hash + + listElement, has := ct.mapping[blockHash] + if has { + // commit already exists so override the commit message in the linked list; + // do not move the list element in the linked list to avoid + // someone re-sending the same commit message and going at the + // front of the list, hence erasing other possible valid commit messages + // in the tracker. + listElement.Value = commitMessage + return + } + + // add new block hash in tracker + ct.cleanup() + listElement = ct.linkedList.PushFront(commitMessage) + ct.mapping[blockHash] = listElement +} + +// cleanup removes the oldest commit message from the tracker +// if the number of commit messages is at the tracker capacity. +// This method is designed to be called automatically from the +// add method and should not be called elsewhere. +func (ct *commitsTracker) cleanup() { + if ct.linkedList.Len() < ct.capacity { + return + } + + oldestElement := ct.linkedList.Back() + ct.linkedList.Remove(oldestElement) + + oldestCommitMessage := oldestElement.Value.(*CommitMessage) + oldestBlockHash := oldestCommitMessage.Vote.Hash + delete(ct.mapping, oldestBlockHash) +} + +// delete deletes all the vote messages for a particular +// block hash from the vote messages tracker. +func (ct *commitsTracker) delete(blockHash common.Hash) { + listElement, has := ct.mapping[blockHash] + if !has { + return + } + + ct.linkedList.Remove(listElement) + delete(ct.mapping, blockHash) +} + +// message returns a pointer to the +// commit message for a particular block hash from +// the tracker. It returns nil if the block hash +// does not exist in the tracker +func (ct *commitsTracker) message(blockHash common.Hash) ( + message *CommitMessage) { + listElement, ok := ct.mapping[blockHash] + if !ok { + return nil + } + + return listElement.Value.(*CommitMessage) +} + +// forEach runs the function `f` on each +// commit message stored in the tracker. +func (ct *commitsTracker) forEach(f func(message *CommitMessage)) { + for _, data := range ct.mapping { + message := data.Value.(*CommitMessage) + f(message) + } +} diff --git a/lib/grandpa/commits_tracker_test.go b/lib/grandpa/commits_tracker_test.go new file mode 100644 index 0000000000..d5d28bb588 --- /dev/null +++ b/lib/grandpa/commits_tracker_test.go @@ -0,0 +1,321 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "bytes" + "container/list" + "crypto/rand" + "sort" + "testing" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildCommitMessage creates a test commit message +// using the given block hash. +func buildCommitMessage(blockHash common.Hash) *CommitMessage { + return &CommitMessage{ + Vote: Vote{ + Hash: blockHash, + }, + } +} + +func assertCommitsMapping(t *testing.T, + mapping map[common.Hash]*list.Element, + expected map[common.Hash]*CommitMessage) { + t.Helper() + + require.Len(t, mapping, len(expected), "mapping does not have the expected length") + for expectedBlockHash, expectedCommitMessage := range expected { + listElement, ok := mapping[expectedBlockHash] + assert.Truef(t, ok, "block hash %s not found in mapping", expectedBlockHash) + assert.Equalf(t, expectedCommitMessage, listElement.Value.(*CommitMessage), + "commit message for block hash %s is not as expected", + expectedBlockHash) + } +} + +func Test_newCommitsTracker(t *testing.T) { + t.Parallel() + + const capacity = 1 + expected := commitsTracker{ + mapping: make(map[common.Hash]*list.Element, capacity), + linkedList: list.New(), + capacity: capacity, + } + vt := newCommitsTracker(capacity) + + assert.Equal(t, expected, vt) +} + +// We cannot really unit test each method independently +// due to the dependency on the double linked list from +// the standard package `list` which has private fields +// which cannot be set. +// For example we cannot assert the commits tracker mapping +// entirely due to the linked list elements unexported fields. + +func Test_commitsTracker_cleanup(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + blockHashC := common.Hash{0xc} + + messageBlockA := buildCommitMessage(blockHashA) + messageBlockB := buildCommitMessage(blockHashB) + messageBlockC := buildCommitMessage(blockHashC) + + tracker.add(messageBlockA) + tracker.add(messageBlockB) + // Add third message for block C. + // This triggers a cleanup removing the oldest message + // which is the message for block A. + tracker.add(messageBlockC) + assertCommitsMapping(t, tracker.mapping, map[common.Hash]*CommitMessage{ + blockHashB: messageBlockB, + blockHashC: messageBlockC, + }) +} + +// This test verifies overidding a value does not affect the +// input order for which each message was added. +func Test_commitsTracker_overriding(t *testing.T) { + t.Parallel() + + t.Run("override oldest", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + blockHashC := common.Hash{0xc} + + messageBlockA := buildCommitMessage(blockHashA) + messageBlockB := buildCommitMessage(blockHashB) + messageBlockC := buildCommitMessage(blockHashC) + + tracker.add(messageBlockA) + tracker.add(messageBlockB) + tracker.add(messageBlockA) // override oldest + tracker.add(messageBlockC) + + assertCommitsMapping(t, tracker.mapping, map[common.Hash]*CommitMessage{ + blockHashB: messageBlockB, + blockHashC: messageBlockC, + }) + }) + + t.Run("override newest", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + blockHashC := common.Hash{0xc} + + messageBlockA := buildCommitMessage(blockHashA) + messageBlockB := buildCommitMessage(blockHashB) + messageBlockC := buildCommitMessage(blockHashC) + + tracker.add(messageBlockA) + tracker.add(messageBlockB) + tracker.add(messageBlockB) // override newest + tracker.add(messageBlockC) + + assertCommitsMapping(t, tracker.mapping, map[common.Hash]*CommitMessage{ + blockHashB: messageBlockB, + blockHashC: messageBlockC, + }) + }) +} + +func Test_commitsTracker_delete(t *testing.T) { + t.Parallel() + + t.Run("non existing block hash", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + messageBlockA := buildCommitMessage(blockHashA) + + tracker.add(messageBlockA) + tracker.delete(blockHashB) + + assertCommitsMapping(t, tracker.mapping, map[common.Hash]*CommitMessage{ + blockHashA: messageBlockA, + }) + }) + + t.Run("existing block hash", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + messageBlockA := buildCommitMessage(blockHashA) + messageBlockB := buildCommitMessage(blockHashB) + + tracker.add(messageBlockA) + tracker.add(messageBlockB) + tracker.delete(blockHashB) + + assertCommitsMapping(t, tracker.mapping, map[common.Hash]*CommitMessage{ + blockHashA: messageBlockA, + }) + }) +} + +func Test_commitsTracker_message(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + commitsTracker *commitsTracker + blockHash common.Hash + message *CommitMessage + }{ + "non existing block hash": { + commitsTracker: &commitsTracker{ + mapping: map[common.Hash]*list.Element{ + {1}: {}, + }, + }, + blockHash: common.Hash{2}, + }, + "existing block hash": { + commitsTracker: &commitsTracker{ + mapping: map[common.Hash]*list.Element{ + {1}: { + Value: &CommitMessage{Round: 1}, + }, + }, + }, + blockHash: common.Hash{1}, + message: &CommitMessage{Round: 1}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + vt := testCase.commitsTracker + message := vt.message(testCase.blockHash) + + assert.Equal(t, testCase.message, message) + }) + } +} + +func Test_commitsTracker_forEach(t *testing.T) { + t.Parallel() + + const capacity = 10 + ct := newCommitsTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + blockHashC := common.Hash{0xc} + + messageBlockA := buildCommitMessage(blockHashA) + messageBlockB := buildCommitMessage(blockHashB) + messageBlockC := buildCommitMessage(blockHashC) + + ct.add(messageBlockA) + ct.add(messageBlockB) + ct.add(messageBlockC) + + var results []*CommitMessage + ct.forEach(func(message *CommitMessage) { + results = append(results, message) + }) + + // Predictable messages order for assertion. + // Sort by block hash then authority id then peer ID. + sort.Slice(results, func(i, j int) bool { + return bytes.Compare(results[i].Vote.Hash[:], + results[j].Vote.Hash[:]) < 0 + }) + + expectedResults := []*CommitMessage{ + messageBlockA, + messageBlockB, + messageBlockC, + } + + assert.Equal(t, expectedResults, results) +} + +func Benchmark_ForEachVsSlice(b *testing.B) { + getMessages := func(ct *commitsTracker) (messages []*CommitMessage) { + messages = make([]*CommitMessage, 0, len(ct.mapping)) + for _, data := range ct.mapping { + messages = append(messages, data.Value.(*CommitMessage)) + } + return messages + } + + f := func(message *CommitMessage) { + message.Round++ + message.SetID++ + } + + const trackerSize = 10e4 + makeSeededTracker := func() (ct *commitsTracker) { + ct = &commitsTracker{ + mapping: make(map[common.Hash]*list.Element), + } + for i := 0; i < trackerSize; i++ { + hashBytes := make([]byte, 32) + _, _ = rand.Read(hashBytes) + var blockHash common.Hash + copy(blockHash[:], hashBytes) + ct.mapping[blockHash] = &list.Element{ + Value: &CommitMessage{ + Round: uint64(i), + SetID: uint64(i), + }, + } + } + return ct + } + + b.Run("forEach", func(b *testing.B) { + tracker := makeSeededTracker() + for i := 0; i < b.N; i++ { + tracker.forEach(f) + } + }) + + b.Run("get messages for iterate", func(b *testing.B) { + tracker := makeSeededTracker() + for i := 0; i < b.N; i++ { + messages := getMessages(tracker) + for _, message := range messages { + f(message) + } + } + }) +} diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index 00e7ef801a..1c380bb175 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -8,7 +8,6 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/types" - "github.com/ChainSafe/gossamer/lib/common" "github.com/libp2p/go-libp2p-core/peer" ) @@ -19,12 +18,10 @@ type tracker struct { blockState BlockState handler *MessageHandler votes votesTracker - - // map of commit block hash to commit message - commitMessages map[common.Hash]*CommitMessage - mapLock sync.Mutex - in chan *types.Block // receive imported block from BlockState - stopped chan struct{} + commits commitsTracker + mapLock sync.Mutex + in chan *types.Block // receive imported block from BlockState + stopped chan struct{} catchUpResponseMessageMutex sync.Mutex // round(uint64) is used as key and *CatchUpResponse as value @@ -32,12 +29,15 @@ type tracker struct { } func newTracker(bs BlockState, handler *MessageHandler) *tracker { - const votesCapacity = 1000 + const ( + votesCapacity = 1000 + commitsCapacity = 1000 + ) return &tracker{ blockState: bs, handler: handler, votes: newVotesTracker(votesCapacity), - commitMessages: make(map[common.Hash]*CommitMessage), + commits: newCommitsTracker(commitsCapacity), mapLock: sync.Mutex{}, in: bs.GetImportedBlockNotifierChannel(), stopped: make(chan struct{}), @@ -68,7 +68,7 @@ func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) { func (t *tracker) addCommit(cm *CommitMessage) { t.mapLock.Lock() defer t.mapLock.Unlock() - t.commitMessages[cm.Vote.Hash] = cm + t.commits.add(cm) } func (t *tracker) addCatchUpResponse(_ *CatchUpResponse) { @@ -116,13 +116,14 @@ func (t *tracker) handleBlock(b *types.Block) { // delete block hash that may or may not be in the tracker. t.votes.delete(h) - if cm, has := t.commitMessages[h]; has { + cm := t.commits.message(h) + if cm != nil { _, err := t.handler.handleMessage("", cm) if err != nil { logger.Warnf("failed to handle commit message %v: %s", cm, err) } - delete(t.commitMessages, h) + t.commits.delete(h) } } @@ -144,13 +145,16 @@ func (t *tracker) handleTick() { } } - for _, cm := range t.commitMessages { + t.commits.forEach(func(cm *CommitMessage) { _, err := t.handler.handleMessage("", cm) if err != nil { logger.Debugf("failed to handle commit message %v: %s", cm, err) - continue + return } - delete(t.commitMessages, cm.Vote.Hash) - } + // deleting while iterating is safe to do since + // each block hash has at most 1 commit message we + // just handled above. + t.commits.delete(cm.Vote.Hash) + }) }