diff --git a/dot/core/digest.go b/dot/core/digest.go index 0758625204..f46fe41768 100644 --- a/dot/core/digest.go +++ b/dot/core/digest.go @@ -42,7 +42,7 @@ type DigestHandler struct { // block notification channels imported chan *types.Block importedID byte - finalised chan *types.Header + finalised chan *types.FinalisationInfo finalisedID byte // GRANDPA changes @@ -68,7 +68,7 @@ type resume struct { // NewDigestHandler returns a new DigestHandler func NewDigestHandler(blockState BlockState, epochState EpochState, grandpaState GrandpaState, babe BlockProducer, verifier Verifier) (*DigestHandler, error) { imported := make(chan *types.Block, 16) - finalised := make(chan *types.Header, 16) + finalised := make(chan *types.FinalisationInfo, 16) iid, err := blockState.RegisterImportedChannel(imported) if err != nil { return nil, err @@ -195,12 +195,12 @@ func (h *DigestHandler) handleBlockImport(ctx context.Context) { func (h *DigestHandler) handleBlockFinalisation(ctx context.Context) { for { select { - case header := <-h.finalised: - if header == nil { + case info := <-h.finalised: + if info == nil || info.Header == nil { continue } - err := h.handleGrandpaChangesOnFinalization(header.Number) + err := h.handleGrandpaChangesOnFinalization(info.Header.Number) if err != nil { logger.Error("failed to handle grandpa changes on block finalisation", "error", err) } diff --git a/dot/core/interface.go b/dot/core/interface.go index 04df6bdc77..bab25e3099 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -44,7 +44,7 @@ type BlockState interface { SetFinalizedHash(common.Hash, uint64, uint64) error RegisterImportedChannel(ch chan<- *types.Block) (byte, error) UnregisterImportedChannel(id byte) - RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) + RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalizedChannel(id byte) HighestCommonAncestor(a, b common.Hash) (common.Hash, error) SubChain(start, end common.Hash) ([]common.Hash, error) diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 4951294139..ce7356d1d6 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -33,7 +33,7 @@ type BlockAPI interface { GetJustification(hash common.Hash) ([]byte, error) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) UnregisterImportedChannel(id byte) - RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) + RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalizedChannel(id byte) SubChain(start, end common.Hash) ([]common.Hash, error) } diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index 7f7916545f..579cfbba88 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -114,7 +114,7 @@ func (l *BlockListener) Listen() { // BlockFinalizedListener to handle listening for finalised blocks type BlockFinalizedListener struct { - channel chan *types.Header + channel chan *types.FinalisationInfo wsconn WSConnAPI chanID byte subID uint @@ -122,11 +122,11 @@ type BlockFinalizedListener struct { // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockFinalizedListener) Listen() { - for header := range l.channel { - if header == nil { + for info := range l.channel { + if info == nil || info.Header == nil { continue } - head, err := modules.HeaderToJSON(*header) + head, err := modules.HeaderToJSON(*info.Header) if err != nil { logger.Error("failed to convert header to JSON", "error", err) } @@ -147,7 +147,7 @@ type ExtrinsicSubmitListener struct { importedChan chan *types.Block importedChanID byte importedHash common.Hash - finalisedChan chan *types.Header + finalisedChan chan *types.FinalisationInfo finalisedChanID byte } @@ -180,10 +180,10 @@ func (l *ExtrinsicSubmitListener) Listen() { // listen for finalised headers go func() { - for header := range l.finalisedChan { - if reflect.DeepEqual(l.importedHash, header.Hash()) { + for info := range l.finalisedChan { + if reflect.DeepEqual(l.importedHash, info.Header.Hash()) { resM := make(map[string]interface{}) - resM["finalised"] = header.Hash().String() + resM["finalised"] = info.Header.Hash().String() l.wsconn.safeSend(newSubscriptionResponse(AuthorExtrinsicUpdates, l.subID, resM)) } } diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index ebbe8d8119..13a20914bc 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -95,7 +95,7 @@ func TestBlockListener_Listen(t *testing.T) { } func TestBlockFinalizedListener_Listen(t *testing.T) { - notifyChan := make(chan *types.Header) + notifyChan := make(chan *types.FinalisationInfo) mockConnection := &MockWSConnAPI{} bfl := BlockFinalizedListener{ channel: notifyChan, @@ -113,14 +113,16 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { go bfl.Listen() - notifyChan <- header + notifyChan <- &types.FinalisationInfo{ + Header: header, + } time.Sleep(time.Millisecond * 10) require.Equal(t, expectedResponse, mockConnection.lastMessage) } func TestExtrinsicSubmitListener_Listen(t *testing.T) { notifyImportedChan := make(chan *types.Block) - notifyFinalizedChan := make(chan *types.Header) + notifyFinalizedChan := make(chan *types.FinalisationInfo) mockConnection := &MockWSConnAPI{} esl := ExtrinsicSubmitListener{ @@ -149,7 +151,9 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { time.Sleep(time.Millisecond * 10) require.Equal(t, expectedImportedRespones, mockConnection.lastMessage) - notifyFinalizedChan <- header + notifyFinalizedChan <- &types.FinalisationInfo{ + Header: header, + } time.Sleep(time.Millisecond * 10) resFinalised := map[string]interface{}{"finalised": block.Header.Hash().String()} expectedFinalizedRespones := newSubscriptionResponse(AuthorExtrinsicUpdates, esl.subID, resFinalised) diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index 17d6b737a7..40fd48bedc 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -236,7 +236,7 @@ func (c *WSConn) initBlockListener(reqID float64) (uint, error) { func (c *WSConn) initBlockFinalizedListener(reqID float64) (uint, error) { bfl := &BlockFinalizedListener{ - channel: make(chan *types.Header), + channel: make(chan *types.FinalisationInfo), wsconn: c, } @@ -271,7 +271,7 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (uint, er importedChan: make(chan *types.Block), wsconn: c, extrinsic: types.Extrinsic(extBytes), - finalisedChan: make(chan *types.Header), + finalisedChan: make(chan *types.FinalisationInfo), } if c.BlockAPI == nil { diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 43e189c39c..735df61fac 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -223,7 +223,7 @@ func (m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, er } func (m *MockBlockAPI) UnregisterImportedChannel(id byte) { } -func (m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) { +func (m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { return 0, nil } func (m *MockBlockAPI) UnregisterFinalizedChannel(id byte) {} diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go index 4cdd9e271d..e03a8d8e7d 100644 --- a/dot/rpc/websocket_test.go +++ b/dot/rpc/websocket_test.go @@ -120,7 +120,7 @@ func (m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, er } func (m *MockBlockAPI) UnregisterImportedChannel(id byte) { } -func (m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) { +func (m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { return 0, nil } func (m *MockBlockAPI) UnregisterFinalizedChannel(id byte) {} diff --git a/dot/state/block.go b/dot/state/block.go index 16406d066e..ea5d8c5267 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -47,7 +47,7 @@ type BlockState struct { // block notifiers imported map[byte]chan<- *types.Block - finalised map[byte]chan<- *types.Header + finalised map[byte]chan<- *types.FinalisationInfo importedLock sync.RWMutex finalisedLock sync.RWMutex @@ -65,7 +65,7 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), imported: make(map[byte]chan<- *types.Block), - finalised: make(map[byte]chan<- *types.Header), + finalised: make(map[byte]chan<- *types.FinalisationInfo), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), } @@ -85,7 +85,7 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), imported: make(map[byte]chan<- *types.Block), - finalised: make(map[byte]chan<- *types.Header), + finalised: make(map[byte]chan<- *types.FinalisationInfo), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), } @@ -424,7 +424,7 @@ func (bs *BlockState) SetFinalizedHash(hash common.Hash, round, setID uint64) er bs.Lock() defer bs.Unlock() - go bs.notifyFinalized(hash) + go bs.notifyFinalized(hash, round, setID) if round > 0 { err := bs.SetRound(round) if err != nil { diff --git a/dot/state/block_notify.go b/dot/state/block_notify.go index b4aa44bfc3..1fbcb214f6 100644 --- a/dot/state/block_notify.go +++ b/dot/state/block_notify.go @@ -51,7 +51,7 @@ func (bs *BlockState) RegisterImportedChannel(ch chan<- *types.Block) (byte, err // RegisterFinalizedChannel registers a channel for block notification upon block finalisation. // It returns the channel ID (used for unregistering the channel) -func (bs *BlockState) RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) { +func (bs *BlockState) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { bs.finalisedLock.RLock() if len(bs.finalised) == 256 { @@ -111,7 +111,7 @@ func (bs *BlockState) notifyImported(block *types.Block) { } } -func (bs *BlockState) notifyFinalized(hash common.Hash) { +func (bs *BlockState) notifyFinalized(hash common.Hash, round, setID uint64) { bs.finalisedLock.RLock() defer bs.finalisedLock.RUnlock() @@ -126,11 +126,16 @@ func (bs *BlockState) notifyFinalized(hash common.Hash) { } logger.Debug("notifying finalised block chans...", "chans", bs.finalised) + info := &types.FinalisationInfo{ + Header: header, + Round: round, + SetID: setID, + } for _, ch := range bs.finalised { - go func(ch chan<- *types.Header) { + go func(ch chan<- *types.FinalisationInfo) { select { - case ch <- header: + case ch <- info: default: } }(ch) diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index 3e8054a4fe..ace6e1067f 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -52,7 +52,7 @@ func TestImportChannel(t *testing.T) { func TestFinalizedChannel(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) - ch := make(chan *types.Header, 3) + ch := make(chan *types.FinalisationInfo, 3) id, err := bs.RegisterFinalizedChannel(ch) require.NoError(t, err) @@ -117,12 +117,12 @@ func TestFinalizedChannel_Multi(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) num := 5 - chs := make([]chan *types.Header, num) + chs := make([]chan *types.FinalisationInfo, num) ids := make([]byte, num) var err error for i := 0; i < num; i++ { - chs[i] = make(chan *types.Header) + chs[i] = make(chan *types.FinalisationInfo) ids[i], err = bs.RegisterFinalizedChannel(chs[i]) require.NoError(t, err) } @@ -134,7 +134,7 @@ func TestFinalizedChannel_Multi(t *testing.T) { for i, ch := range chs { - go func(i int, ch chan *types.Header) { + go func(i int, ch chan *types.FinalisationInfo) { select { case <-ch: case <-time.After(testMessageTimeout): diff --git a/dot/types/grandpa.go b/dot/types/grandpa.go index 65a9af80e2..dd3c6ee97f 100644 --- a/dot/types/grandpa.go +++ b/dot/types/grandpa.go @@ -165,3 +165,10 @@ func DecodeGrandpaVoters(r io.Reader) (GrandpaVoters, error) { return voters, nil } + +// FinalisationInfo represents information about what block was finalised in what round and setID +type FinalisationInfo struct { + Header *Header + Round uint64 + SetID uint64 +} diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index 4c0a1140d8..436cb32ad7 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -74,7 +74,10 @@ type Service struct { justification map[uint64][]*SignedPrecommit // map of round number -> precommit round justification // channels for communication with other services - in chan GrandpaMessage // only used to receive *VoteMessage + in chan GrandpaMessage // only used to receive *VoteMessage + finalisedCh chan *types.FinalisationInfo + finalisedChID byte + neighbourMessage *NeighbourMessage // cached neighbour message } // Config represents a GRANDPA service configuration @@ -133,6 +136,12 @@ func NewService(cfg *Config) (*Service, error) { return nil, err } + finalisedCh := make(chan *types.FinalisationInfo, 16) + fid, err := cfg.BlockState.RegisterFinalizedChannel(finalisedCh) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) s := &Service{ ctx: ctx, @@ -156,6 +165,8 @@ func NewService(cfg *Config) (*Service, error) { in: make(chan GrandpaMessage, 128), resumed: make(chan struct{}), network: cfg.Network, + finalisedCh: finalisedCh, + finalisedChID: fid, } s.messageHandler = NewMessageHandler(s, s.blockState) @@ -185,6 +196,7 @@ func (s *Service) Start() error { } }() + go s.sendNeighbourMessage() return nil } @@ -195,6 +207,9 @@ func (s *Service) Stop() error { s.cancel() + s.blockState.UnregisterFinalizedChannel(s.finalisedChID) + close(s.finalisedCh) + if !s.authority { return nil } diff --git a/lib/grandpa/network.go b/lib/grandpa/network.go index d52677db77..ba712d347d 100644 --- a/lib/grandpa/network.go +++ b/lib/grandpa/network.go @@ -18,6 +18,7 @@ package grandpa import ( "fmt" + "time" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/lib/common" @@ -28,8 +29,9 @@ import ( ) var ( - grandpaID protocol.ID = "/paritytech/grandpa/1" - messageID = network.ConsensusMsgType + grandpaID protocol.ID = "/paritytech/grandpa/1" + messageID = network.ConsensusMsgType + neighbourMessageInterval = time.Minute * 5 ) // Handshake is an alias for network.Handshake @@ -160,3 +162,29 @@ func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) ( return true, nil } + +func (s *Service) sendNeighbourMessage() { + for { + select { + case <-time.After(neighbourMessageInterval): + if s.neighbourMessage == nil { + continue + } + case info := <-s.finalisedCh: + s.neighbourMessage = &NeighbourMessage{ + Version: 1, + Round: info.Round, + SetID: info.SetID, + Number: uint32(info.Header.Number.Int64()), + } + } + + cm, err := s.neighbourMessage.ToConsensusMessage() + if err != nil { + logger.Warn("failed to convert NeighbourMessage to network message", "error", err) + continue + } + + s.network.SendMessage(cm) + } +} diff --git a/lib/grandpa/network_test.go b/lib/grandpa/network_test.go index 466e92f953..e41d4b7d0b 100644 --- a/lib/grandpa/network_test.go +++ b/lib/grandpa/network_test.go @@ -17,9 +17,12 @@ package grandpa import ( + "math/big" "testing" "time" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" ) @@ -80,3 +83,56 @@ func TestHandleNetworkMessage(t *testing.T) { require.NoError(t, err) require.False(t, propagate) } + +func TestSendNeighbourMessage(t *testing.T) { + gs, st := newTestService(t) + neighbourMessageInterval = time.Second + defer func() { + neighbourMessageInterval = time.Minute * 5 + }() + go gs.sendNeighbourMessage() + + block := &types.Block{ + Header: &types.Header{ + ParentHash: testGenesisHeader.Hash(), + Number: big.NewInt(1), + }, + Body: &types.Body{}, + } + + err := st.Block.AddBlock(block) + require.NoError(t, err) + + hash := block.Header.Hash() + round := uint64(7) + setID := uint64(33) + err = st.Block.SetFinalizedHash(hash, round, setID) + require.NoError(t, err) + + expected := &NeighbourMessage{ + Version: 1, + SetID: setID, + Round: round, + Number: 1, + } + + select { + case <-time.After(time.Second): + t.Fatal("did not send message") + case msg := <-gs.network.(*testNetwork).out: + nm, ok := msg.(*NeighbourMessage) + require.True(t, ok) + require.Equal(t, expected, nm) + } + + require.Equal(t, expected, gs.neighbourMessage) + + select { + case <-time.After(time.Second * 2): + t.Fatal("did not send message") + case msg := <-gs.network.(*testNetwork).out: + nm, ok := msg.(*NeighbourMessage) + require.True(t, ok) + require.Equal(t, expected, nm) + } +} diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index c9051d460f..b145d816e2 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -44,7 +44,7 @@ type BlockState interface { BlocktreeAsString() string RegisterImportedChannel(ch chan<- *types.Block) (byte, error) UnregisterImportedChannel(id byte) - RegisterFinalizedChannel(ch chan<- *types.Header) (byte, error) + RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalizedChannel(id byte) SetJustification(hash common.Hash, data []byte) error HasJustification(hash common.Hash) (bool, error)