From fcfbabe657b848fb4d376cb1d611a474c0786f5b Mon Sep 17 00:00:00 2001 From: noot <36753753+noot@users.noreply.github.com> Date: Thu, 28 Oct 2021 21:00:14 -0400 Subject: [PATCH] fix(dot/sync): fix block request and response logic (#1907) --- dot/network/notifications.go | 2 +- dot/state/block.go | 10 + dot/sync/chain_sync.go | 30 ++- dot/sync/chain_sync_test.go | 24 +++ dot/sync/errors.go | 8 +- dot/sync/interface.go | 3 + dot/sync/message.go | 365 +++++++++++++++++++++++++++++----- dot/sync/message_test.go | 348 ++++++++++++++++++++++++++++++-- dot/sync/mocks/block_state.go | 67 +++++++ dot/sync/syncer.go | 6 + dot/sync/tip_syncer.go | 4 + 11 files changed, 791 insertions(+), 76 deletions(-) diff --git a/dot/network/notifications.go b/dot/network/notifications.go index b236157d4e..7ba62745e7 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -349,7 +349,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc err := s.host.writeToStream(hsData.stream, msg) if err != nil { - logger.Trace("failed to send message to peer", "peer", peer, "error", err) + logger.Debug("failed to send message to peer", "peer", peer, "error", err) } } diff --git a/dot/state/block.go b/dot/state/block.go index 355bee3f72..46fc641b46 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -444,6 +444,16 @@ func (bs *BlockState) AddBlockToBlockTree(header *types.Header) error { return bs.bt.AddBlock(header, arrivalTime) } +// GetAllBlocksAtNumber returns all unfinalised blocks with the given number +func (bs *BlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { + header, err := bs.GetHeaderByNumber(num) + if err != nil { + return nil, err + } + + return bs.GetAllBlocksAtDepth(header.ParentHash), nil +} + // GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { return bs.bt.GetAllBlocksAtNumber(hash) diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 38a87ef4be..e7d4d5c9b4 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -646,9 +646,7 @@ func (cs *chainSync) doSync(req *network.BlockRequestMessage) *workerError { if req.Direction == network.Descending { // reverse blocks before pre-validating and placing in ready queue - for i, j := 0, len(resp.BlockData)-1; i < j; i, j = i+1, j-1 { - resp.BlockData[i], resp.BlockData[j] = resp.BlockData[j], resp.BlockData[i] - } + reverseBlockData(resp.BlockData) } // perform some pre-validation of response, error if failure @@ -897,10 +895,18 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { } else { // in tip-syncing mode, we know the hash of the block on the fork we wish to sync start, _ = variadic.NewUint64OrHash(w.startHash) + + // if we're doing descending requests and not at the last (highest starting) request, + // then use number as start block + if w.direction == network.Descending && i != numRequests-1 { + start = variadic.MustNewUint64OrHash(startNumber) + } } var end *common.Hash - if !w.targetHash.IsEmpty() { + if !w.targetHash.IsEmpty() && i == numRequests-1 { + // if we're on our last request (which should contain the target hash), + // then add it end = &w.targetHash } @@ -911,7 +917,21 @@ func workerToRequests(w *worker) ([]*network.BlockRequestMessage, error) { Direction: w.direction, Max: &max, } - startNumber += maxResponseSize + + switch w.direction { + case network.Ascending: + startNumber += maxResponseSize + case network.Descending: + startNumber -= maxResponseSize + } + } + + // if our direction is descending, we want to send out the request with the lowest + // startNumber first + if w.direction == network.Descending { + for i, j := 0, len(reqs)-1; i < j; i, j = i+1, j-1 { + reqs[i], reqs[j] = reqs[j], reqs[i] + } } return reqs, nil diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index 9d3334089a..aab2ffdfd1 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -405,6 +405,30 @@ func TestWorkerToRequests(t *testing.T) { }, }, }, + { + w: &worker{ + startNumber: big.NewInt(1 + maxResponseSize + (maxResponseSize / 2)), + targetNumber: big.NewInt(1), + direction: network.Descending, + requestData: bootstrapRequestData, + }, + expected: []*network.BlockRequestMessage{ + { + RequestedData: network.RequestedDataHeader + network.RequestedDataBody + network.RequestedDataJustification, + StartingBlock: *variadic.MustNewUint64OrHash(1 + (maxResponseSize / 2)), + EndBlockHash: nil, + Direction: network.Descending, + Max: &max64, + }, + { + RequestedData: bootstrapRequestData, + StartingBlock: *variadic.MustNewUint64OrHash(1 + maxResponseSize + (maxResponseSize / 2)), + EndBlockHash: nil, + Direction: network.Descending, + Max: &max128, + }, + }, + }, } for i, tc := range testCases { diff --git a/dot/sync/errors.go b/dot/sync/errors.go index 53d2f2458a..6bfc7f6e93 100644 --- a/dot/sync/errors.go +++ b/dot/sync/errors.go @@ -40,7 +40,10 @@ var ( ErrInvalidBlock = errors.New("could not verify block") // ErrInvalidBlockRequest is returned when an invalid block request is received - ErrInvalidBlockRequest = errors.New("invalid block request") + ErrInvalidBlockRequest = errors.New("invalid block request") + errInvalidRequestDirection = errors.New("invalid request direction") + errRequestStartTooHigh = errors.New("request start number is higher than our best block") + errFailedToGetEndHashAncestor = errors.New("failed to get ancestor of end block") // chainSync errors errEmptyBlockData = errors.New("empty block data") @@ -57,6 +60,9 @@ var ( errUnknownParent = errors.New("parent of first block in block response is unknown") errUnknownBlockForJustification = errors.New("received justification for unknown block") errFailedToGetParent = errors.New("failed to get parent header") + errNilDescendantNumber = errors.New("descendant number is nil") + errStartAndEndMismatch = errors.New("request start and end hash are not on the same chain") + errFailedToGetDescendant = errors.New("failed to find descendant block") ) // ErrNilChannel is returned if a channel is nil diff --git a/dot/sync/interface.go b/dot/sync/interface.go index 2ba8c5b027..a9746523c3 100644 --- a/dot/sync/interface.go +++ b/dot/sync/interface.go @@ -57,6 +57,9 @@ type BlockState interface { StoreRuntime(common.Hash, runtime.Instance) GetHighestFinalisedHeader() (*types.Header, error) GetFinalisedNotifierChannel() chan *types.FinalisationInfo + GetHeaderByNumber(num *big.Int) (*types.Header, error) + GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) + IsDescendantOf(parent, child common.Hash) (bool, error) } // StorageState is the interface for the storage state diff --git a/dot/sync/message.go b/dot/sync/message.go index 107ce787e4..5b3d324895 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -17,7 +17,6 @@ package sync import ( - "errors" "fmt" "math/big" @@ -32,108 +31,370 @@ const ( ) // CreateBlockResponse creates a block response message from a block request message -func (s *Service) CreateBlockResponse(blockRequest *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { +func (s *Service) CreateBlockResponse(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { + switch req.Direction { + case network.Ascending: + return s.handleAscendingRequest(req) + case network.Descending: + return s.handleDescendingRequest(req) + default: + return nil, errInvalidRequestDirection + } +} + +func (s *Service) handleAscendingRequest(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { var ( - startHash, endHash common.Hash - startHeader, endHeader *types.Header - err error - respSize uint32 + startHash *common.Hash + endHash = req.EndBlockHash + startNumber, endNumber uint64 + max uint32 = maxResponseSize ) - if blockRequest.Max != nil { - respSize = *blockRequest.Max - if respSize > maxResponseSize { - respSize = maxResponseSize - } - } else { - respSize = maxResponseSize + // determine maximum response size + if req.Max != nil && *req.Max < maxResponseSize { + max = *req.Max } - switch startBlock := blockRequest.StartingBlock.Value().(type) { + switch startBlock := req.StartingBlock.Value().(type) { case uint64: if startBlock == 0 { startBlock = 1 } - block, err := s.blockState.GetBlockByNumber(big.NewInt(0).SetUint64(startBlock)) //nolint + bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { - return nil, fmt.Errorf("failed to get start block %d for request: %w", startBlock, err) + return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) + } + + // if request start is higher than our best block, return error + if bestBlockNumber.Uint64() < startBlock { + return nil, errRequestStartTooHigh } - startHeader = &block.Header - startHash = block.Header.Hash() + startNumber = startBlock + + if endHash != nil { + // TODO: end hash is provided but start hash isn't, so we need to determine a start block + // that is an ancestor of the end block + sh, err := s.blockState.GetHashByNumber(big.NewInt(int64(startNumber))) + if err != nil { + return nil, fmt.Errorf("failed to get start block %d for request: %w", startNumber, err) + } + + is, err := s.blockState.IsDescendantOf(sh, *endHash) + if err != nil { + return nil, err + } + + if !is { + return nil, fmt.Errorf("%w: hash=%s", errFailedToGetEndHashAncestor, *endHash) + } + + startHash = &sh + } case common.Hash: - startHash = startBlock - startHeader, err = s.blockState.GetHeader(startHash) + startHash = &startBlock + + // make sure we actually have the starting block + header, err := s.blockState.GetHeader(*startHash) if err != nil { return nil, fmt.Errorf("failed to get start block %s for request: %w", startHash, err) } + + startNumber = header.Number.Uint64() default: return nil, ErrInvalidBlockRequest } - if blockRequest.EndBlockHash != nil { - endHash = *blockRequest.EndBlockHash - endHeader, err = s.blockState.GetHeader(endHash) + if endHash == nil { + endNumber = startNumber + uint64(max) - 1 + bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { - return nil, fmt.Errorf("failed to get end block %s for request: %w", endHash, err) + return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) + } + + if endNumber > bestBlockNumber.Uint64() { + endNumber = bestBlockNumber.Uint64() } } else { - endNumber := big.NewInt(0).Add(startHeader.Number, big.NewInt(int64(respSize-1))) + header, err := s.blockState.GetHeader(*endHash) + if err != nil { + return nil, fmt.Errorf("failed to get end block %s: %w", *endHash, err) + } + + endNumber = header.Number.Uint64() + } + + // start hash provided, need to determine end hash that is descendant of start hash + if startHash != nil { + eh, err := s.checkOrGetDescendantHash(*startHash, endHash, big.NewInt(int64(endNumber))) + if err != nil { + return nil, err + } + + endHash = &eh + } + + if startHash == nil || endHash == nil { + logger.Debug("handling BlockRequestMessage", + "start", startNumber, + "end", endNumber, + "direction", req.Direction, + ) + return s.handleAscendingByNumber(startNumber, endNumber, req.RequestedData) + } + + logger.Debug("handling BlockRequestMessage", + "start", *startHash, + "end", *endHash, + "direction", req.Direction, + ) + return s.handleChainByHash(*startHash, *endHash, max, req.RequestedData, req.Direction) +} + +func (s *Service) handleDescendingRequest(req *network.BlockRequestMessage) (*network.BlockResponseMessage, error) { + var ( + startHash *common.Hash + endHash = req.EndBlockHash + startNumber, endNumber uint64 + max uint32 = maxResponseSize + ) + + // determine maximum response size + if req.Max != nil && *req.Max < maxResponseSize { + max = *req.Max + } + + switch startBlock := req.StartingBlock.Value().(type) { + case uint64: bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) } - if endNumber.Cmp(bestBlockNumber) == 1 { - endNumber = bestBlockNumber + // if request start is higher than our best block, only return blocks from our best block and below + if bestBlockNumber.Uint64() < startBlock { + startNumber = bestBlockNumber.Uint64() + } else { + startNumber = startBlock } + case common.Hash: + startHash = &startBlock - endBlock, err := s.blockState.GetBlockByNumber(endNumber) + // make sure we actually have the starting block + header, err := s.blockState.GetHeader(*startHash) if err != nil { - return nil, fmt.Errorf("failed to get end block %d for request: %w", endNumber, err) + return nil, fmt.Errorf("failed to get start block %s for request: %w", startHash, err) } - endHeader = &endBlock.Header - endHash = endHeader.Hash() + + startNumber = header.Number.Uint64() + default: + return nil, ErrInvalidBlockRequest } - logger.Debug("handling BlockRequestMessage", "start", startHeader.Number, "end", endHeader.Number, "startHash", startHash, "endHash", endHash) + // end hash provided, need to determine start hash that is descendant of end hash + if endHash != nil { + sh, err := s.checkOrGetDescendantHash(*endHash, startHash, big.NewInt(int64(startNumber))) + startHash = &sh + if err != nil { + return nil, err + } + } - responseData := []*types.BlockData{} + // end hash is not provided, calculate end by number + if endHash == nil { + if startNumber <= uint64(max+1) { + endNumber = 1 + } else { + endNumber = startNumber - uint64(max) + 1 + } - switch blockRequest.Direction { - case network.Ascending: - for i := startHeader.Number.Int64(); i <= endHeader.Number.Int64(); i++ { - blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) + if startHash != nil { + // need to get blocks by subchain if start hash is provided, get end hash + endHeader, err := s.blockState.GetHeaderByNumber(big.NewInt(int64(endNumber))) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get end block %d for request: %w", endNumber, err) } - responseData = append(responseData, blockData) + + hash := endHeader.Hash() + endHash = &hash } - case network.Descending: - for i := endHeader.Number.Int64(); i >= startHeader.Number.Int64(); i-- { - blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) - if err != nil { - return nil, err + } + + if startHash == nil || endHash == nil { + logger.Debug("handling BlockRequestMessage", + "start", startNumber, + "end", endNumber, + "direction", req.Direction, + ) + return s.handleDescendingByNumber(startNumber, endNumber, req.RequestedData) + } + + logger.Debug("handling BlockRequestMessage", + "start", *startHash, + "end", *endHash, + "direction", req.Direction, + ) + return s.handleChainByHash(*endHash, *startHash, max, req.RequestedData, req.Direction) +} + +// checkOrGetDescendantHash checks if the provided `descendant` is on the same chain as the `ancestor`, if it's provided, +// otherwise, it sets `descendant` to a block with number=`descendantNumber` that is a descendant of the ancestor +// if used with an Ascending request, ancestor is the start block and descendant is the end block +// if used with an Descending request, ancestor is the end block and descendant is the start block +func (s *Service) checkOrGetDescendantHash(ancestor common.Hash, descendant *common.Hash, descendantNumber *big.Int) (common.Hash, error) { + if descendantNumber == nil { + return common.Hash{}, errNilDescendantNumber + } + + // if `descendant` was provided, check that it's a descendant of `ancestor` + if descendant != nil { + header, err := s.blockState.GetHeader(ancestor) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to get descendant %s: %w", *descendant, err) + } + + // if descendant number is lower than ancestor number, this is an error + if header.Number.Cmp(descendantNumber) > 0 { + return common.Hash{}, fmt.Errorf("invalid request, descendant number %d is higher than ancestor %d", header.Number, descendantNumber) + } + + // check if provided start hash is descendant of provided descendant hash + is, err := s.blockState.IsDescendantOf(ancestor, *descendant) + if err != nil { + return common.Hash{}, err + } + + if !is { + return common.Hash{}, errStartAndEndMismatch + } + + return *descendant, nil + } + + // otherwise, get block on canonical chain by descendantNumber + hash, err := s.blockState.GetHashByNumber(descendantNumber) + if err != nil { + return common.Hash{}, err + } + + // check if it's a descendant of the provided ancestor hash + is, err := s.blockState.IsDescendantOf(ancestor, hash) + if err != nil { + return common.Hash{}, err + } + + if !is { + // if it's not a descendant, search for a block that has number=descendantNumber that is + hashes, err := s.blockState.GetAllBlocksAtNumber(descendantNumber) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to get blocks at number %d: %w", descendantNumber, err) + } + + for _, hash := range hashes { + is, err := s.blockState.IsDescendantOf(ancestor, hash) + if err != nil || !is { + continue } - responseData = append(responseData, blockData) + + // this sets the descendant hash to whatever the first block we find with descendantNumber + // is, however there might be multiple blocks that fit this criteria + h := common.Hash{} + copy(h[:], hash[:]) + descendant = &h + break } - default: - return nil, errors.New("invalid BlockRequest direction") + + if descendant == nil { + return common.Hash{}, fmt.Errorf("%w with number %d", errFailedToGetDescendant, descendantNumber) + } + } else { + // if it is, set descendant hash to our block w/ descendantNumber + descendant = &hash + } + + logger.Trace("determined descendant", + "ancestor", ancestor, + "descendant", *descendant, + "number", descendantNumber, + ) + return *descendant, nil +} + +func (s *Service) handleAscendingByNumber(start, end uint64, requestedData byte) (*network.BlockResponseMessage, error) { + var err error + data := make([]*types.BlockData, (end-start)+1) + + for i := uint64(0); start+i <= end; i++ { + blockNumber := start + i + data[i], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) + if err != nil { + return nil, err + } + } + + return &network.BlockResponseMessage{ + BlockData: data, + }, nil +} + +func (s *Service) handleDescendingByNumber(start, end uint64, requestedData byte) (*network.BlockResponseMessage, error) { + var err error + data := make([]*types.BlockData, (start-end)+1) + + for i := uint64(0); start-i >= end; i++ { + blockNumber := start - i + data[i], err = s.getBlockDataByNumber(big.NewInt(int64(blockNumber)), requestedData) + if err != nil { + return nil, err + } + } + + return &network.BlockResponseMessage{ + BlockData: data, + }, nil +} + +func (s *Service) handleChainByHash(ancestor, descendant common.Hash, max uint32, requestedData byte, direction network.SyncDirection) (*network.BlockResponseMessage, error) { + subchain, err := s.blockState.SubChain(ancestor, descendant) + if err != nil { + return nil, err + } + + if uint32(len(subchain)) > max { + subchain = subchain[:max] + } + + data := make([]*types.BlockData, len(subchain)) + + for i, hash := range subchain { + data[i], err = s.getBlockData(hash, requestedData) + if err != nil { + return nil, err + } + } + + // reverse BlockData, if descending request + if direction == network.Descending { + reverseBlockData(data) } - logger.Debug("sending BlockResponseMessage", "start", startHeader.Number, "end", endHeader.Number) return &network.BlockResponseMessage{ - BlockData: responseData, + BlockData: data, }, nil } -func (s *Service) getBlockData(num *big.Int, requestedData byte) (*types.BlockData, error) { +func (s *Service) getBlockDataByNumber(num *big.Int, requestedData byte) (*types.BlockData, error) { hash, err := s.blockState.GetHashByNumber(num) if err != nil { return nil, err } + return s.getBlockData(hash, requestedData) +} + +func (s *Service) getBlockData(hash common.Hash, requestedData byte) (*types.BlockData, error) { + var err error blockData := &types.BlockData{ Hash: hash, } @@ -145,14 +406,14 @@ func (s *Service) getBlockData(num *big.Int, requestedData byte) (*types.BlockDa if (requestedData & network.RequestedDataHeader) == 1 { blockData.Header, err = s.blockState.GetHeader(hash) if err != nil { - logger.Debug("failed to get header for block", "number", num, "hash", hash, "error", err) + logger.Debug("failed to get header for block", "hash", hash, "error", err) } } if (requestedData&network.RequestedDataBody)>>1 == 1 { blockData.Body, err = s.blockState.GetBlockBody(hash) if err != nil { - logger.Debug("failed to get body for block", "number", num, "hash", hash, "error", err) + logger.Debug("failed to get body for block", "hash", hash, "error", err) } } diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index e04091d232..7d94110ee4 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/lib/trie" @@ -37,8 +38,9 @@ func addTestBlocksToState(t *testing.T, depth int, blockState BlockState) { func TestService_CreateBlockResponse_MaxSize(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize*2), s.blockState) + // test ascending start, err := variadic.NewUint64OrHash(uint64(1)) require.NoError(t, err) @@ -46,7 +48,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, Max: nil, } @@ -61,7 +63,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, Max: &max, } @@ -70,12 +72,79 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { require.Equal(t, int(maxResponseSize), len(resp.BlockData)) require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) + + max = uint32(16) + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Ascending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(max), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(16), resp.BlockData[15].Number()) + + // test descending + start, err = variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(128), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) + + max = uint32(maxResponseSize + 100) + start, err = variadic.NewUint64OrHash(uint64(256)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(129), resp.BlockData[127].Number()) + + max = uint32(16) + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: &max, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(max), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(241), resp.BlockData[15].Number()) } func TestService_CreateBlockResponse_StartHash(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize*2), s.blockState) + // test ascending with nil endBlockHash startHash, err := s.blockState.GetHashByNumber(big.NewInt(1)) require.NoError(t, err) @@ -86,7 +155,150 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: nil, - Direction: 0, + Direction: network.Ascending, + Max: nil, + } + + resp, err := s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) + + endHash, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + // test ascending with non-nil endBlockHash + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &endHash, + Direction: network.Ascending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(1), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(16), resp.BlockData[15].Number()) + + // test descending with nil endBlockHash + startHash, err = s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(16), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[15].Number()) + + // test descending with non-nil endBlockHash + endHash, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &endHash, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(16), len(resp.BlockData)) + require.Equal(t, big.NewInt(16), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[15].Number()) + + // test descending with nil endBlockHash and start > maxResponseSize + startHash, err = s.blockState.GetHashByNumber(big.NewInt(256)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(256), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(129), resp.BlockData[127].Number()) + + startHash, err = s.blockState.GetHashByNumber(big.NewInt(128)) + require.NoError(t, err) + + start, err = variadic.NewUint64OrHash(startHash) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: nil, + Direction: network.Descending, + Max: nil, + } + + resp, err = s.CreateBlockResponse(req) + require.NoError(t, err) + require.Equal(t, int(maxResponseSize), len(resp.BlockData)) + require.Equal(t, big.NewInt(128), resp.BlockData[0].Number()) + require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) +} + +func TestService_CreateBlockResponse_Ascending_EndHash(t *testing.T) { + t.Parallel() + s := newTestSyncer(t) + addTestBlocksToState(t, int(maxResponseSize+1), s.blockState) + + // should error if end < start + start, err := variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + end, err := s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req := &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Ascending, + Max: nil, + } + + _, err = s.CreateBlockResponse(req) + require.Error(t, err) + + // base case + start, err = variadic.NewUint64OrHash(uint64(1)) + require.NoError(t, err) + + end, err = s.blockState.GetHashByNumber(big.NewInt(128)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Ascending, Max: nil, } @@ -97,21 +309,40 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) } -func TestService_CreateBlockResponse_Descending(t *testing.T) { +func TestService_CreateBlockResponse_Descending_EndHash(t *testing.T) { s := newTestSyncer(t) - addTestBlocksToState(t, int(maxResponseSize), s.blockState) + addTestBlocksToState(t, int(maxResponseSize+1), s.blockState) - startHash, err := s.blockState.GetHashByNumber(big.NewInt(1)) + // should error if start < end + start, err := variadic.NewUint64OrHash(uint64(1)) require.NoError(t, err) - start, err := variadic.NewUint64OrHash(startHash) + end, err := s.blockState.GetHashByNumber(big.NewInt(128)) require.NoError(t, err) req := &network.BlockRequestMessage{ RequestedData: 3, StartingBlock: *start, - EndBlockHash: nil, - Direction: 1, + EndBlockHash: &end, + Direction: network.Descending, + Max: nil, + } + + _, err = s.CreateBlockResponse(req) + require.Error(t, err) + + // base case + start, err = variadic.NewUint64OrHash(uint64(128)) + require.NoError(t, err) + + end, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + req = &network.BlockRequestMessage{ + RequestedData: 3, + StartingBlock: *start, + EndBlockHash: &end, + Direction: network.Descending, Max: nil, } @@ -122,8 +353,91 @@ func TestService_CreateBlockResponse_Descending(t *testing.T) { require.Equal(t, big.NewInt(1), resp.BlockData[127].Number()) } -// tests the ProcessBlockRequestMessage method -func TestService_CreateBlockResponse(t *testing.T) { +func TestService_checkOrGetDescendantHash(t *testing.T) { + t.Parallel() + s := newTestSyncer(t) + branches := map[int]int{ + 8: 1, + } + state.AddBlocksToStateWithFixedBranches(t, s.blockState.(*state.BlockState), 16, branches, 1) + + // base case + ancestor, err := s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + descendant, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + descendantNumber := big.NewInt(16) + + res, err := s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.NoError(t, err) + require.Equal(t, descendant, res) + + // supply descendant that's not on canonical chain + leaves := s.blockState.(*state.BlockState).Leaves() + require.Equal(t, 2, len(leaves)) + + ancestor, err = s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + descendant, err = s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + for _, leaf := range leaves { + if !leaf.Equal(descendant) { + descendant = leaf + break + } + } + + res, err = s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.NoError(t, err) + require.Equal(t, descendant, res) + + // supply descedant that's not on same chain as ancestor + ancestor, err = s.blockState.GetHashByNumber(big.NewInt(9)) + require.NoError(t, err) + res, err = s.checkOrGetDescendantHash(ancestor, &descendant, descendantNumber) + require.Error(t, err) + + // don't supply descendant, should return block on canonical chain + // as ancestor is on canonical chain + expected, err := s.blockState.GetHashByNumber(big.NewInt(16)) + require.NoError(t, err) + + res, err = s.checkOrGetDescendantHash(ancestor, nil, descendantNumber) + require.NoError(t, err) + require.Equal(t, expected, res) + + // don't supply descendant and provide ancestor not on canonical chain + // should return descendant block also not on canonical chain + block9s, err := s.blockState.GetAllBlocksAtNumber(big.NewInt(9)) + require.NoError(t, err) + canonical, err := s.blockState.GetHashByNumber(big.NewInt(9)) + require.NoError(t, err) + + // set ancestor to non-canonical block 9 + for _, block := range block9s { + if !canonical.Equal(block) { + ancestor = block + break + } + } + + // expected is non-canonical block 16 + for _, leaf := range leaves { + is, err := s.blockState.IsDescendantOf(ancestor, leaf) //nolint + require.NoError(t, err) + if is { + expected = leaf + break + } + } + + res, err = s.checkOrGetDescendantHash(ancestor, nil, descendantNumber) + require.NoError(t, err) + require.Equal(t, expected, res) +} + +func TestService_CreateBlockResponse_Fields(t *testing.T) { s := newTestSyncer(t) addTestBlocksToState(t, 2, s.blockState) @@ -171,7 +485,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 3, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -190,7 +504,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 1, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -209,7 +523,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 4, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ @@ -229,7 +543,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 8, StartingBlock: *start, EndBlockHash: &endHash, - Direction: 0, + Direction: network.Ascending, Max: nil, }, expectedMsgValue: &network.BlockResponseMessage{ diff --git a/dot/sync/mocks/block_state.go b/dot/sync/mocks/block_state.go index 31c7dce2aa..2a3dcd60bc 100644 --- a/dot/sync/mocks/block_state.go +++ b/dot/sync/mocks/block_state.go @@ -122,6 +122,29 @@ func (_m *BlockState) CompareAndSetBlockData(bd *types.BlockData) error { return r0 } +// GetAllBlocksAtNumber provides a mock function with given fields: num +func (_m *BlockState) GetAllBlocksAtNumber(num *big.Int) ([]common.Hash, error) { + ret := _m.Called(num) + + var r0 []common.Hash + if rf, ok := ret.Get(0).(func(*big.Int) []common.Hash); ok { + r0 = rf(num) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*big.Int) error); ok { + r1 = rf(num) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetBlockBody provides a mock function with given fields: _a0 func (_m *BlockState) GetBlockBody(_a0 common.Hash) (*types.Body, error) { ret := _m.Called(_a0) @@ -253,6 +276,29 @@ func (_m *BlockState) GetHeader(_a0 common.Hash) (*types.Header, error) { return r0, r1 } +// GetHeaderByNumber provides a mock function with given fields: num +func (_m *BlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { + ret := _m.Called(num) + + var r0 *types.Header + if rf, ok := ret.Get(0).(func(*big.Int) *types.Header); ok { + r0 = rf(num) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Header) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*big.Int) error); ok { + r1 = rf(num) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetHighestFinalisedHeader provides a mock function with given fields: func (_m *BlockState) GetHighestFinalisedHeader() (*types.Header, error) { ret := _m.Called() @@ -410,6 +456,27 @@ func (_m *BlockState) HasHeader(hash common.Hash) (bool, error) { return r0, r1 } +// IsDescendantOf provides a mock function with given fields: parent, child +func (_m *BlockState) IsDescendantOf(parent common.Hash, child common.Hash) (bool, error) { + ret := _m.Called(parent, child) + + var r0 bool + if rf, ok := ret.Get(0).(func(common.Hash, common.Hash) bool); ok { + r0 = rf(parent, child) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(common.Hash, common.Hash) error); ok { + r1 = rf(parent, child) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // SetFinalisedHash provides a mock function with given fields: hash, round, setID func (_m *BlockState) SetFinalisedHash(hash common.Hash, round uint64, setID uint64) error { ret := _m.Called(hash, round, setID) diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index 9c4330f45b..9df49c15b5 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -133,3 +133,9 @@ func (s *Service) HandleBlockAnnounce(from peer.ID, msg *network.BlockAnnounceMe func (s *Service) IsSynced() bool { return s.chainSync.syncState() == tip } + +func reverseBlockData(data []*types.BlockData) { + for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { + data[i], data[j] = data[j], data[i] + } +} diff --git a/dot/sync/tip_syncer.go b/dot/sync/tip_syncer.go index 45987f56b1..a20f68551a 100644 --- a/dot/sync/tip_syncer.go +++ b/dot/sync/tip_syncer.go @@ -157,6 +157,8 @@ func (*tipSyncer) hasCurrentWorker(w *worker, workers map[uint64]*worker) bool { // handleTick traverses the pending blocks set to find which forks still need to be requested func (s *tipSyncer) handleTick() ([]*worker, error) { + logger.Debug("handling tick...", "pending blocks count", s.pendingBlocks.size()) + if s.pendingBlocks.size() == 0 { return nil, nil } @@ -181,6 +183,8 @@ func (s *tipSyncer) handleTick() ([]*worker, error) { continue } + logger.Trace("handling pending block", "hash", block.hash, "number", block.number) + if block.header == nil { // case 1 workers = append(workers, &worker{