diff --git a/dot/network/sync.go b/dot/network/sync.go index d0136ab1f9..e8fb87add2 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -119,6 +119,7 @@ type syncQueue struct { cancel context.CancelFunc peerScore *sync.Map // map[peer.ID]int; peers we have successfully synced from before -> their score; score increases on successful response + requestDataByHash *sync.Map // map[common.Hash]requestData; caching requestData by hash requestData *sync.Map // map[uint64]requestData; map of start # of request -> requestData justificationRequestData *sync.Map // map[common.Hash]requestData; map of requests of justifications -> requestData requestCh chan *syncRequest @@ -144,6 +145,7 @@ func newSyncQueue(s *Service) *syncQueue { cancel: cancel, peerScore: new(sync.Map), requestData: new(sync.Map), + requestDataByHash: new(sync.Map), justificationRequestData: new(sync.Map), requestCh: make(chan *syncRequest, blockRequestBufferSize), responses: []*types.BlockData{}, @@ -457,6 +459,7 @@ func (q *syncQueue) pushResponse(resp *BlockResponseMessage, pid peer.ID) error } startHash := resp.BlockData[0].Hash + if _, has := q.justificationRequestData.Load(startHash); has && !resp.BlockData[0].Header.Exists() { numJustifications := 0 justificationResponses := []*types.BlockData{} @@ -500,11 +503,18 @@ func (q *syncQueue) pushResponse(resp *BlockResponseMessage, pid peer.ID) error // update peer's score q.updatePeerScore(pid, 1) - q.requestData.Store(uint64(start), requestData{ + + reqdata := requestData{ sent: true, received: true, from: pid, - }) + } + + if _, has := q.requestDataByHash.Load(startHash); has { + q.requestDataByHash.Store(startHash, reqdata) + } else { + q.requestData.Store(uint64(start), reqdata) + } q.responseLock.Lock() defer q.responseLock.Unlock() @@ -522,6 +532,28 @@ func (q *syncQueue) pushResponse(resp *BlockResponseMessage, pid peer.ID) error return nil } +func (q *syncQueue) isRequestDataCached(startingBlock *variadic.Uint64OrHash) (*requestData, bool) { + if startingBlock == nil { + return nil, false + } + + if startingBlock.IsHash() { + if d, has := q.requestDataByHash.Load(startingBlock.Hash()); has { + data := d.(requestData) + return &data, true + } + } + + if startingBlock.IsUint64() { + if d, has := q.requestData.Load(startingBlock.Uint64()); has { + data := d.(requestData) + return &data, true + } + } + + return nil, false +} + func (q *syncQueue) processBlockRequests() { for { select { @@ -530,16 +562,15 @@ func (q *syncQueue) processBlockRequests() { continue } - if !req.req.StartingBlock.IsUint64() { + reqData, ok := q.isRequestDataCached(req.req.StartingBlock) + + if !ok { q.trySync(req) continue } - if d, has := q.requestData.Load(req.req.StartingBlock.Uint64()); has { - data := d.(requestData) - if data.sent && data.received { - continue - } + if reqData.sent && reqData.received { + continue } q.trySync(req) @@ -599,10 +630,13 @@ func (q *syncQueue) trySync(req *syncRequest) { received: false, }) } else if req.req.StartingBlock.IsHash() && (req.req.RequestedData&RequestedDataHeader) == 0 { - q.justificationRequestData.Store(req.req.StartingBlock.Hash(), requestData{ + startingBlockHash := req.req.StartingBlock.Hash() + reqdata := requestData{ sent: true, received: false, - }) + } + + q.justificationRequestData.Store(startingBlockHash, reqdata) } req.to = "" diff --git a/dot/network/sync_test.go b/dot/network/sync_test.go index 812ec04ca0..4bb96712d7 100644 --- a/dot/network/sync_test.go +++ b/dot/network/sync_test.go @@ -26,6 +26,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common/optional" + "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/lib/utils" "github.com/ChainSafe/chaindb" @@ -421,6 +422,53 @@ func TestSyncQueue_processBlockResponses(t *testing.T) { require.Equal(t, blockRequestBufferSize, len(q.requestCh)) } +func TestSyncQueue_isRequestDataCached(t *testing.T) { + q := newTestSyncQueue(t) + q.stop() + + reqdata := requestData{ + sent: true, + received: false, + } + + // generate hash or uint64 + hashtrack := variadic.NewUint64OrHashFromBytes([]byte{0, 0, 0, 1}) + uinttrack := variadic.NewUint64OrHashFromBytes([]byte{1, 0, 0, 1}) + othertrack := variadic.NewUint64OrHashFromBytes([]byte{1, 2, 3, 1}) + + tests := []struct { + variadic *variadic.Uint64OrHash + reqMessage BlockRequestMessage + expectedOk bool + expectedData *requestData + }{ + { + variadic: hashtrack, + expectedOk: true, + expectedData: &reqdata, + }, + { + variadic: uinttrack, + expectedOk: true, + expectedData: &reqdata, + }, + { + variadic: othertrack, + expectedOk: false, + expectedData: nil, + }, + } + + q.requestDataByHash.Store(hashtrack.Hash(), reqdata) + q.requestData.Store(uinttrack.Uint64(), reqdata) + + for _, test := range tests { + data, ok := q.isRequestDataCached(test.variadic) + require.Equal(t, test.expectedOk, ok) + require.Equal(t, test.expectedData, data) + } +} + func TestSyncQueue_SyncAtHead(t *testing.T) { q := newTestSyncQueue(t) q.stop()