From ad2d85e29190f3db501b7cb4e1816f51a0c8c31c Mon Sep 17 00:00:00 2001 From: Edward Mack Date: Fri, 24 Sep 2021 14:16:30 -0400 Subject: [PATCH] fix: confirm block import notifier is closed properly (#1736) * add TODOs to identify where block imported channel is handled * added comments for imported channels * create constructor for listeners * added close channel to defer in listen * move imported chan to block_notify * remove comments, lint * handle lint issues * replace imported channel map with sync.Map * fix mocks in listeners test * fix mock functions for new imported notification channel * fix deep source issues * add debugging printf * remove sync.Pool, and sync.Map * handle channel closing * add sleep before close * remove channel close * run go imported * defer importedLock unlock * wrap notifier channel in struct * store channel by interface{} key * update storage key for imported block listeners * refacter GetImportedBlockNotifierChannel arugments * GetImportedBlockNotifierChannel doesn't return error, fixed related test * remove un-needed comments * remove close for FinalisedChannel listener * added test for free imported channel * add mocks paths to .deepsource exclude_patterns --- .deepsource.toml | 4 +- dot/core/interface.go | 4 +- dot/core/mocks/block_state.go | 61 +++++--------- dot/digest/digest.go | 12 +-- dot/digest/interface.go | 4 +- dot/rpc/http.go | 3 +- dot/rpc/modules/api.go | 4 +- dot/rpc/modules/api_mocks.go | 8 +- .../mocks/{BlockAPI.go => block_api.go} | 79 +++++++++---------- dot/rpc/modules/system_test.go | 2 +- dot/rpc/subscription/listeners.go | 40 +++++++--- dot/rpc/subscription/listeners_test.go | 16 ++-- dot/rpc/subscription/websocket.go | 39 ++------- dot/rpc/subscription/websocket_test.go | 29 ++----- dot/state/block.go | 11 +-- dot/state/block_notify.go | 39 ++++----- dot/state/block_notify_test.go | 27 +++---- lib/grandpa/grandpa.go | 13 +-- lib/grandpa/message_tracker.go | 15 +--- lib/grandpa/message_tracker_test.go | 9 +-- lib/grandpa/state.go | 4 +- lib/grandpa/vote_message_test.go | 6 +- 22 files changed, 174 insertions(+), 255 deletions(-) rename dot/rpc/modules/mocks/{BlockAPI.go => block_api.go} (74%) diff --git a/.deepsource.toml b/.deepsource.toml index 1c41320658..9053d32f39 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -11,7 +11,9 @@ exclude_patterns = [ "dot/config/**/*", "dot/rpc/modules/test_data", "lib/runtime/test_data", - "**/*_test.go" + "**/*_test.go", + "**/mocks/*", + "**/mock_*" ] [[analyzers]] diff --git a/dot/core/interface.go b/dot/core/interface.go index 3cdeb0c326..dd017a96d5 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -42,8 +42,8 @@ type BlockState interface { GetSlotForBlock(common.Hash) (uint64, error) GetFinalisedHeader(uint64, uint64) (*types.Header, error) GetFinalisedHash(uint64, uint64) (common.Hash, error) - RegisterImportedChannel(ch chan<- *types.Block) (byte, error) - UnregisterImportedChannel(id byte) + GetImportedBlockNotifierChannel() chan *types.Block + FreeImportedBlockNotifierChannel(ch chan *types.Block) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalisedChannel(id byte) HighestCommonAncestor(a, b common.Hash) (common.Hash, error) diff --git a/dot/core/mocks/block_state.go b/dot/core/mocks/block_state.go index 7686c23dab..65cc19ed21 100644 --- a/dot/core/mocks/block_state.go +++ b/dot/core/mocks/block_state.go @@ -143,6 +143,11 @@ func (_m *MockBlockState) BestBlockStateRoot() (common.Hash, error) { return r0, r1 } +// FreeImportedBlockNotifierChannel provides a mock function with given fields: ch +func (_m *MockBlockState) FreeImportedBlockNotifierChannel(ch chan *types.Block) { + _m.Called(ch) +} + // GenesisHash provides a mock function with given fields: func (_m *MockBlockState) GenesisHash() common.Hash { ret := _m.Called() @@ -267,6 +272,22 @@ func (_m *MockBlockState) GetFinalisedHeader(_a0 uint64, _a1 uint64) (*types.Hea return r0, r1 } +// GetImportedBlockNotifierChannel provides a mock function with given fields: +func (_m *MockBlockState) GetImportedBlockNotifierChannel() chan *types.Block { + ret := _m.Called() + + var r0 chan *types.Block + if rf, ok := ret.Get(0).(func() chan *types.Block); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(chan *types.Block) + } + } + + return r0 +} + // GetRuntime provides a mock function with given fields: _a0 func (_m *MockBlockState) GetRuntime(_a0 *common.Hash) (runtime.Instance, error) { ret := _m.Called(_a0) @@ -369,41 +390,6 @@ func (_m *MockBlockState) RegisterFinalizedChannel(ch chan<- *types.Finalisation return r0, r1 } -// RegisterImportedChannel provides a mock function with given fields: ch -func (_m *MockBlockState) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { - ret := _m.Called(ch) - - var r0 byte - if rf, ok := ret.Get(0).(func(chan<- *types.Block) byte); ok { - r0 = rf(ch) - } else { - r0 = ret.Get(0).(byte) - } - - var r1 error - if rf, ok := ret.Get(1).(func(chan<- *types.Block) error); ok { - r1 = rf(ch) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SetFinalisedHash provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockBlockState) SetFinalisedHash(_a0 common.Hash, _a1 uint64, _a2 uint64) error { - ret := _m.Called(_a0, _a1, _a2) - - var r0 error - if rf, ok := ret.Get(0).(func(common.Hash, uint64, uint64) error); ok { - r0 = rf(_a0, _a1, _a2) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // StoreRuntime provides a mock function with given fields: _a0, _a1 func (_m *MockBlockState) StoreRuntime(_a0 common.Hash, _a1 runtime.Instance) { _m.Called(_a0, _a1) @@ -436,8 +422,3 @@ func (_m *MockBlockState) SubChain(start common.Hash, end common.Hash) ([]common func (_m *MockBlockState) UnregisterFinalisedChannel(id byte) { _m.Called(id) } - -// UnregisterImportedChannel provides a mock function with given fields: id -func (_m *MockBlockState) UnregisterImportedChannel(id byte) { - _m.Called(id) -} diff --git a/dot/digest/digest.go b/dot/digest/digest.go index 34343d243c..30c6c492a2 100644 --- a/dot/digest/digest.go +++ b/dot/digest/digest.go @@ -48,7 +48,6 @@ type Handler struct { // block notification channels imported chan *types.Block - importedID byte finalised chan *types.FinalisationInfo finalisedID byte @@ -74,12 +73,9 @@ type resume struct { // NewHandler returns a new Handler func NewHandler(blockState BlockState, epochState EpochState, grandpaState GrandpaState) (*Handler, error) { - imported := make(chan *types.Block, 16) + imported := blockState.GetImportedBlockNotifierChannel() + finalised := make(chan *types.FinalisationInfo, 16) - iid, err := blockState.RegisterImportedChannel(imported) - if err != nil { - return nil, err - } fid, err := blockState.RegisterFinalizedChannel(finalised) if err != nil { @@ -95,7 +91,6 @@ func NewHandler(blockState BlockState, epochState EpochState, grandpaState Grand epochState: epochState, grandpaState: grandpaState, imported: imported, - importedID: iid, finalised: finalised, finalisedID: fid, }, nil @@ -111,9 +106,8 @@ func (h *Handler) Start() error { // Stop stops the Handler func (h *Handler) Stop() error { h.cancel() - h.blockState.UnregisterImportedChannel(h.importedID) + h.blockState.FreeImportedBlockNotifierChannel(h.imported) h.blockState.UnregisterFinalisedChannel(h.finalisedID) - close(h.imported) close(h.finalised) return nil } diff --git a/dot/digest/interface.go b/dot/digest/interface.go index fbc31cd350..f467b0a7d0 100644 --- a/dot/digest/interface.go +++ b/dot/digest/interface.go @@ -26,8 +26,8 @@ import ( // BlockState interface for block state methods type BlockState interface { BestBlockHeader() (*types.Header, error) - RegisterImportedChannel(ch chan<- *types.Block) (byte, error) - UnregisterImportedChannel(id byte) + GetImportedBlockNotifierChannel() chan *types.Block + FreeImportedBlockNotifierChannel(ch chan *types.Block) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalisedChannel(id byte) } diff --git a/dot/rpc/http.go b/dot/rpc/http.go index a7675d17e8..a7d109f426 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -194,8 +194,7 @@ func (h *HTTPServer) Stop() error { case *subscription.StorageObserver: h.serverConfig.StorageAPI.UnregisterStorageObserver(v) case *subscription.BlockListener: - h.serverConfig.BlockAPI.UnregisterImportedChannel(v.ChanID) - close(v.Channel) + h.serverConfig.BlockAPI.FreeImportedBlockNotifierChannel(v.Channel) } } diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index c9d4da2e2c..febd9dac80 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -35,8 +35,8 @@ type BlockAPI interface { GetHighestFinalisedHash() (common.Hash, error) HasJustification(hash common.Hash) (bool, error) GetJustification(hash common.Hash) ([]byte, error) - RegisterImportedChannel(ch chan<- *types.Block) (byte, error) - UnregisterImportedChannel(id byte) + GetImportedBlockNotifierChannel() chan *types.Block + FreeImportedBlockNotifierChannel(ch chan *types.Block) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalisedChannel(id byte) SubChain(start, end common.Hash) ([]common.Hash, error) diff --git a/dot/rpc/modules/api_mocks.go b/dot/rpc/modules/api_mocks.go index fa1d6c732a..b9f9802eaa 100644 --- a/dot/rpc/modules/api_mocks.go +++ b/dot/rpc/modules/api_mocks.go @@ -2,6 +2,7 @@ package modules import ( modulesmocks "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" + "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" "github.com/stretchr/testify/mock" @@ -21,15 +22,16 @@ func NewMockStorageAPI() *modulesmocks.MockStorageAPI { } // NewMockBlockAPI creates and return an rpc BlockAPI interface mock -func NewMockBlockAPI() *modulesmocks.BlockAPI { - m := new(modulesmocks.BlockAPI) +func NewMockBlockAPI() *modulesmocks.MockBlockAPI { + m := new(modulesmocks.MockBlockAPI) m.On("GetHeader", mock.AnythingOfType("common.Hash")).Return(nil, nil) m.On("BestBlockHash").Return(common.Hash{}) m.On("GetBlockByHash", mock.AnythingOfType("common.Hash")).Return(nil, nil) m.On("GetBlockHash", mock.AnythingOfType("*big.Int")).Return(nil, nil) m.On("GetFinalisedHash", mock.AnythingOfType("uint64"), mock.AnythingOfType("uint64")).Return(common.Hash{}, nil) m.On("GetHighestFinalisedHash").Return(common.Hash{}, nil) - m.On("RegisterImportedChannel", mock.AnythingOfType("chan<- *types.Block")).Return(byte(0), nil) + m.On("GetImportedBlockNotifierChannel").Return(make(chan *types.Block, 5)) + m.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) m.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) m.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")).Return(byte(0), nil) m.On("UnregisterFinalizedChannel", mock.AnythingOfType("uint8")) diff --git a/dot/rpc/modules/mocks/BlockAPI.go b/dot/rpc/modules/mocks/block_api.go similarity index 74% rename from dot/rpc/modules/mocks/BlockAPI.go rename to dot/rpc/modules/mocks/block_api.go index dd6e6db19b..d343fe1627 100644 --- a/dot/rpc/modules/mocks/BlockAPI.go +++ b/dot/rpc/modules/mocks/block_api.go @@ -1,4 +1,4 @@ -// Code generated by mockery v0.0.0-dev. DO NOT EDIT. +// Code generated by mockery v2.8.0. DO NOT EDIT. package mocks @@ -13,13 +13,13 @@ import ( types "github.com/ChainSafe/gossamer/dot/types" ) -// BlockAPI is an autogenerated mock type for the BlockAPI type -type BlockAPI struct { +// MockBlockAPI is an autogenerated mock type for the BlockAPI type +type MockBlockAPI struct { mock.Mock } // BestBlockHash provides a mock function with given fields: -func (_m *BlockAPI) BestBlockHash() common.Hash { +func (_m *MockBlockAPI) BestBlockHash() common.Hash { ret := _m.Called() var r0 common.Hash @@ -34,8 +34,13 @@ func (_m *BlockAPI) BestBlockHash() common.Hash { return r0 } +// FreeImportedBlockNotifierChannel provides a mock function with given fields: ch +func (_m *MockBlockAPI) FreeImportedBlockNotifierChannel(ch chan *types.Block) { + _m.Called(ch) +} + // GetBlockByHash provides a mock function with given fields: hash -func (_m *BlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { +func (_m *MockBlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { ret := _m.Called(hash) var r0 *types.Block @@ -58,7 +63,7 @@ func (_m *BlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { } // GetBlockHash provides a mock function with given fields: blockNumber -func (_m *BlockAPI) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { +func (_m *MockBlockAPI) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { ret := _m.Called(blockNumber) var r0 common.Hash @@ -81,7 +86,7 @@ func (_m *BlockAPI) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { } // GetFinalisedHash provides a mock function with given fields: _a0, _a1 -func (_m *BlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error) { +func (_m *MockBlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error) { ret := _m.Called(_a0, _a1) var r0 common.Hash @@ -104,7 +109,7 @@ func (_m *BlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error } // GetHeader provides a mock function with given fields: hash -func (_m *BlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { +func (_m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { ret := _m.Called(hash) var r0 *types.Header @@ -127,7 +132,7 @@ func (_m *BlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { } // GetHighestFinalisedHash provides a mock function with given fields: -func (_m *BlockAPI) GetHighestFinalisedHash() (common.Hash, error) { +func (_m *MockBlockAPI) GetHighestFinalisedHash() (common.Hash, error) { ret := _m.Called() var r0 common.Hash @@ -149,8 +154,24 @@ func (_m *BlockAPI) GetHighestFinalisedHash() (common.Hash, error) { return r0, r1 } +// GetImportedBlockNotifierChannel provides a mock function with given fields: +func (_m *MockBlockAPI) GetImportedBlockNotifierChannel() chan *types.Block { + ret := _m.Called() + + var r0 chan *types.Block + if rf, ok := ret.Get(0).(func() chan *types.Block); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(chan *types.Block) + } + } + + return r0 +} + // GetJustification provides a mock function with given fields: hash -func (_m *BlockAPI) GetJustification(hash common.Hash) ([]byte, error) { +func (_m *MockBlockAPI) GetJustification(hash common.Hash) ([]byte, error) { ret := _m.Called(hash) var r0 []byte @@ -173,7 +194,7 @@ func (_m *BlockAPI) GetJustification(hash common.Hash) ([]byte, error) { } // HasJustification provides a mock function with given fields: hash -func (_m *BlockAPI) HasJustification(hash common.Hash) (bool, error) { +func (_m *MockBlockAPI) HasJustification(hash common.Hash) (bool, error) { ret := _m.Called(hash) var r0 bool @@ -194,7 +215,7 @@ func (_m *BlockAPI) HasJustification(hash common.Hash) (bool, error) { } // RegisterFinalizedChannel provides a mock function with given fields: ch -func (_m *BlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { +func (_m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { ret := _m.Called(ch) var r0 byte @@ -214,29 +235,8 @@ func (_m *BlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) return r0, r1 } -// RegisterImportedChannel provides a mock function with given fields: ch -func (_m *BlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { - ret := _m.Called(ch) - - var r0 byte - if rf, ok := ret.Get(0).(func(chan<- *types.Block) byte); ok { - r0 = rf(ch) - } else { - r0 = ret.Get(0).(byte) - } - - var r1 error - if rf, ok := ret.Get(1).(func(chan<- *types.Block) error); ok { - r1 = rf(ch) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // RegisterRuntimeUpdatedChannel provides a mock function with given fields: ch -func (_m *BlockAPI) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) { +func (_m *MockBlockAPI) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) { ret := _m.Called(ch) var r0 uint32 @@ -257,7 +257,7 @@ func (_m *BlockAPI) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (ui } // SubChain provides a mock function with given fields: start, end -func (_m *BlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, error) { +func (_m *MockBlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, error) { ret := _m.Called(start, end) var r0 []common.Hash @@ -280,17 +280,12 @@ func (_m *BlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, } // UnregisterFinalisedChannel provides a mock function with given fields: id -func (_m *BlockAPI) UnregisterFinalisedChannel(id byte) { - _m.Called(id) -} - -// UnregisterImportedChannel provides a mock function with given fields: id -func (_m *BlockAPI) UnregisterImportedChannel(id byte) { +func (_m *MockBlockAPI) UnregisterFinalisedChannel(id byte) { _m.Called(id) } // UnregisterRuntimeUpdatedChannel provides a mock function with given fields: id -func (_m *BlockAPI) UnregisterRuntimeUpdatedChannel(id uint32) bool { +func (_m *MockBlockAPI) UnregisterRuntimeUpdatedChannel(id uint32) bool { ret := _m.Called(id) var r0 bool diff --git a/dot/rpc/modules/system_test.go b/dot/rpc/modules/system_test.go index efcd0c7318..7239d317ca 100644 --- a/dot/rpc/modules/system_test.go +++ b/dot/rpc/modules/system_test.go @@ -363,7 +363,7 @@ func TestSyncState(t *testing.T) { Number: big.NewInt(int64(49)), } - blockapiMock := new(mocks.BlockAPI) + blockapiMock := new(mocks.MockBlockAPI) blockapiMock.On("BestBlockHash").Return(fakeCommonHash) blockapiMock.On("GetHeader", fakeCommonHash).Return(fakeHeader, nil).Once() diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index 7528f71630..6299d7599c 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -105,7 +105,7 @@ func (s *StorageObserver) GetFilter() map[string][]byte { } // Listen to satisfy Listener interface (but is no longer used by StorageObserver) -func (s *StorageObserver) Listen() {} +func (*StorageObserver) Listen() {} // Stop to satisfy Listener interface (but is no longer used by StorageObserver) func (s *StorageObserver) Stop() error { @@ -117,18 +117,28 @@ func (s *StorageObserver) Stop() error { type BlockListener struct { Channel chan *types.Block wsconn *WSConn - ChanID byte subID uint32 done chan struct{} cancel chan struct{} cancelTimeout time.Duration } +// NewBlockListener constructor for creating BlockListener +func NewBlockListener(conn *WSConn) *BlockListener { + bl := &BlockListener{ + wsconn: conn, + cancel: make(chan struct{}, 1), + cancelTimeout: defaultCancelTimeout, + done: make(chan struct{}, 1), + } + return bl +} + // Listen implementation of Listen interface to listen for importedChan changes func (l *BlockListener) Listen() { go func() { defer func() { - l.wsconn.BlockAPI.UnregisterImportedChannel(l.ChanID) + l.wsconn.BlockAPI.FreeImportedBlockNotifierChannel(l.Channel) close(l.done) }() @@ -221,7 +231,6 @@ type AllBlocksListener struct { wsconn *WSConn finalizedChanID byte - importedChanID byte subID uint32 done chan struct{} cancel chan struct{} @@ -235,7 +244,6 @@ func newAllBlockListener(conn *WSConn) *AllBlocksListener { cancelTimeout: defaultCancelTimeout, wsconn: conn, finalizedChan: make(chan *types.FinalisationInfo, DEFAULT_BUFFER_SIZE), - importedChan: make(chan *types.Block, DEFAULT_BUFFER_SIZE), } } @@ -243,10 +251,9 @@ func newAllBlockListener(conn *WSConn) *AllBlocksListener { func (l *AllBlocksListener) Listen() { go func() { defer func() { - l.wsconn.BlockAPI.UnregisterImportedChannel(l.importedChanID) + l.wsconn.BlockAPI.FreeImportedBlockNotifierChannel(l.importedChan) l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.finalizedChanID) - close(l.importedChan) close(l.finalizedChan) close(l.done) }() @@ -304,7 +311,6 @@ type ExtrinsicSubmitListener struct { subID uint32 extrinsic types.Extrinsic importedChan chan *types.Block - importedChanID byte importedHash common.Hash finalisedChan chan *types.FinalisationInfo finalisedChanID byte @@ -313,14 +319,28 @@ type ExtrinsicSubmitListener struct { cancelTimeout time.Duration } +// NewExtrinsicSubmitListener constructor to build new ExtrinsicSubmitListener +func NewExtrinsicSubmitListener(conn *WSConn, extBytes []byte) *ExtrinsicSubmitListener { + esl := &ExtrinsicSubmitListener{ + wsconn: conn, + extrinsic: types.Extrinsic(extBytes), + finalisedChan: make(chan *types.FinalisationInfo), + cancel: make(chan struct{}, 1), + done: make(chan struct{}, 1), + cancelTimeout: defaultCancelTimeout, + } + return esl +} + // Listen implementation of Listen interface to listen for importedChan changes func (l *ExtrinsicSubmitListener) Listen() { // listen for imported blocks with extrinsic go func() { defer func() { - l.wsconn.BlockAPI.UnregisterImportedChannel(l.importedChanID) + l.wsconn.BlockAPI.FreeImportedBlockNotifierChannel(l.importedChan) l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.finalisedChanID) close(l.done) + close(l.finalisedChan) }() for { @@ -430,7 +450,7 @@ func (l *RuntimeVersionListener) GetChannelID() uint32 { // Stop to runtimeVersionListener not implemented yet because the listener // does not need to be stoped -func (l *RuntimeVersionListener) Stop() error { return nil } +func (*RuntimeVersionListener) Stop() error { return nil } // GrandpaJustificationListener struct has the finalisedCh and the context to stop the goroutines type GrandpaJustificationListener struct { diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index 589b044751..0e01832ada 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -96,8 +96,8 @@ func TestBlockListener_Listen(t *testing.T) { wsconn, ws, cancel := setupWSConn(t) defer cancel() - BlockAPI := new(mocks.BlockAPI) - BlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI := new(mocks.MockBlockAPI) + BlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) wsconn.BlockAPI = BlockAPI @@ -118,7 +118,7 @@ func TestBlockListener_Listen(t *testing.T) { defer func() { require.NoError(t, bl.Stop()) time.Sleep(time.Millisecond * 10) - BlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) }() notifyChan <- &block @@ -144,7 +144,7 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { wsconn, ws, cancel := setupWSConn(t) defer cancel() - BlockAPI := new(mocks.BlockAPI) + BlockAPI := new(mocks.MockBlockAPI) BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) wsconn.BlockAPI = BlockAPI @@ -195,8 +195,8 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { notifyImportedChan := make(chan *types.Block, 100) notifyFinalizedChan := make(chan *types.FinalisationInfo, 100) - BlockAPI := new(mocks.BlockAPI) - BlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI := new(mocks.MockBlockAPI) + BlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) wsconn.BlockAPI = BlockAPI @@ -226,7 +226,7 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { require.NoError(t, esl.Stop()) time.Sleep(time.Millisecond * 10) - BlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) BlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) }() @@ -270,7 +270,7 @@ func TestGrandpaJustification_Listen(t *testing.T) { mockedJustBytes, err := mockedJust.Encode() require.NoError(t, err) - blockStateMock := new(mocks.BlockAPI) + blockStateMock := new(mocks.MockBlockAPI) blockStateMock.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) blockStateMock.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) wsconn.BlockAPI = blockStateMock diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index e9dac52573..ff8742407f 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -213,25 +213,14 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L } func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, error) { - bl := &BlockListener{ - Channel: make(chan *types.Block, DEFAULT_BUFFER_SIZE), - wsconn: c, - cancel: make(chan struct{}, 1), - cancelTimeout: defaultCancelTimeout, - done: make(chan struct{}, 1), - } + bl := NewBlockListener(c) if c.BlockAPI == nil { c.safeSendError(reqID, nil, "error BlockAPI not set") return nil, fmt.Errorf("error BlockAPI not set") } - var err error - bl.ChanID, err = c.BlockAPI.RegisterImportedChannel(bl.Channel) - - if err != nil { - return nil, err - } + bl.Channel = c.BlockAPI.GetImportedBlockNotifierChannel() c.mu.Lock() @@ -286,13 +275,9 @@ func (c *WSConn) initAllBlocksListerner(reqID float64, _ interface{}) (Listener, return nil, fmt.Errorf("error BlockAPI not set") } - var err error - listener.importedChanID, err = c.BlockAPI.RegisterImportedChannel(listener.importedChan) - if err != nil { - c.safeSendError(reqID, nil, "could not register imported channel") - return nil, fmt.Errorf("could not register imported channel") - } + listener.importedChan = c.BlockAPI.GetImportedBlockNotifierChannel() + var err error listener.finalizedChanID, err = c.BlockAPI.RegisterFinalizedChannel(listener.finalizedChan) if err != nil { c.safeSendError(reqID, nil, "could not register finalised channel") @@ -316,23 +301,13 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener } // listen for built blocks - esl := &ExtrinsicSubmitListener{ - importedChan: make(chan *types.Block, DEFAULT_BUFFER_SIZE), - wsconn: c, - extrinsic: types.Extrinsic(extBytes), - finalisedChan: make(chan *types.FinalisationInfo), - cancel: make(chan struct{}, 1), - done: make(chan struct{}, 1), - cancelTimeout: defaultCancelTimeout, - } + esl := NewExtrinsicSubmitListener(c, extBytes) if c.BlockAPI == nil { return nil, fmt.Errorf("error BlockAPI not set") } - esl.importedChanID, err = c.BlockAPI.RegisterImportedChannel(esl.importedChan) - if err != nil { - return nil, err - } + + esl.importedChan = c.BlockAPI.GetImportedBlockNotifierChannel() esl.finalisedChanID, err = c.BlockAPI.RegisterFinalizedChannel(esl.finalisedChan) if err != nil { diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 0f87143adc..3ce27444e7 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -231,7 +231,7 @@ func TestWSConn_HandleComm(t *testing.T) { mockedJustBytes, err := mockedJust.Encode() require.NoError(t, err) - BlockAPI := new(modulesmocks.BlockAPI) + BlockAPI := new(modulesmocks.MockBlockAPI) BlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). Run(func(args mock.Arguments) { ch := args.Get(0).(chan<- *types.FinalisationInfo) @@ -287,20 +287,11 @@ func TestSubscribeAllHeads(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte(`{"jsonrpc":"2.0","error":{"code":null,"message":"error BlockAPI not set"},"id":1}`+"\n"), msg) - mockBlockAPI := new(mocks.BlockAPI) - mockBlockAPI.On("RegisterImportedChannel", mock.AnythingOfType("chan<- *types.Block")). - Return(uint8(0), errors.New("some mocked error")).Once() + mockBlockAPI := new(mocks.MockBlockAPI) wsconn.BlockAPI = mockBlockAPI - _, err = wsconn.initAllBlocksListerner(1, nil) - require.Error(t, err, "could not register imported channel") - - _, msg, err = c.ReadMessage() - require.NoError(t, err) - require.Equal(t, []byte(`{"jsonrpc":"2.0","error":{"code":null,"message":"could not register imported channel"},"id":1}`+"\n"), msg) - mockBlockAPI.On("RegisterImportedChannel", mock.AnythingOfType("chan<- *types.Block")). - Return(uint8(10), nil).Once() + mockBlockAPI.On("GetImportedBlockNotifierChannel").Return(make(chan *types.Block)).Once() mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). Return(uint8(0), errors.New("failed")).Once() @@ -308,17 +299,11 @@ func TestSubscribeAllHeads(t *testing.T) { require.Error(t, err, "could not register finalised channel") c.ReadMessage() - importedChanID := uint8(10) finalizedChanID := uint8(11) var fCh chan<- *types.FinalisationInfo - var iCh chan<- *types.Block - - mockBlockAPI.On("RegisterImportedChannel", mock.AnythingOfType("chan<- *types.Block")). - Run(func(args mock.Arguments) { - ch := args.Get(0).(chan<- *types.Block) - iCh = ch - }).Return(importedChanID, nil).Once() + iCh := make(chan *types.Block) + mockBlockAPI.On("GetImportedBlockNotifierChannel").Return(iCh).Once() mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). Run(func(args mock.Arguments) { @@ -383,10 +368,10 @@ func TestSubscribeAllHeads(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte(expected+"\n"), msg) - mockBlockAPI.On("UnregisterImportedChannel", importedChanID) + mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) mockBlockAPI.On("UnregisterFinalisedChannel", finalizedChanID) require.NoError(t, l.Stop()) - mockBlockAPI.AssertCalled(t, "UnregisterImportedChannel", importedChanID) + mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", finalizedChanID) } diff --git a/dot/state/block.go b/dot/state/block.go index 1be58b1cb8..7ba4c83c7d 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -51,11 +51,10 @@ type BlockState struct { lastFinalised common.Hash // block notifiers - imported map[byte]chan<- *types.Block + imported map[chan *types.Block]struct{} finalised map[byte]chan<- *types.FinalisationInfo - importedLock sync.RWMutex finalisedLock sync.RWMutex - importedBytePool *common.BytePool + importedLock sync.RWMutex finalisedBytePool *common.BytePool runtimeUpdateSubscriptionsLock sync.RWMutex runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version @@ -74,7 +73,7 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e dbPath: db.Path(), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), - imported: make(map[byte]chan<- *types.Block), + imported: make(map[chan *types.Block]struct{}), finalised: make(map[byte]chan<- *types.FinalisationInfo), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), @@ -91,7 +90,6 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e return nil, fmt.Errorf("failed to get last finalised hash: %w", err) } - bs.importedBytePool = common.NewBytePool256() bs.finalisedBytePool = common.NewBytePool256() return bs, nil } @@ -102,7 +100,7 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block bt: blocktree.NewBlockTreeFromRoot(header, db), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), - imported: make(map[byte]chan<- *types.Block), + imported: make(map[chan *types.Block]struct{}), finalised: make(map[byte]chan<- *types.FinalisationInfo), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), @@ -135,7 +133,6 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block return nil, err } - bs.importedBytePool = common.NewBytePool256() bs.finalisedBytePool = common.NewBytePool256() return bs, nil } diff --git a/dot/state/block_notify.go b/dot/state/block_notify.go index 34e674c200..ce58e0304e 100644 --- a/dot/state/block_notify.go +++ b/dot/state/block_notify.go @@ -26,22 +26,17 @@ import ( "github.com/google/uuid" ) -// RegisterImportedChannel registers a channel for block notification upon block import. -// It returns the channel ID (used for unregistering the channel) -func (bs *BlockState) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { - bs.importedLock.RLock() - - id, err := bs.importedBytePool.Get() - if err != nil { - return 0, err - } - - bs.importedLock.RUnlock() +// DEFAULT_BUFFER_SIZE buffer size for channels +const DEFAULT_BUFFER_SIZE = 100 +// GetImportedBlockNotifierChannel function to retrieve a imported block notifier channel +func (bs *BlockState) GetImportedBlockNotifierChannel() chan *types.Block { bs.importedLock.Lock() - bs.imported[id] = ch - bs.importedLock.Unlock() - return id, nil + defer bs.importedLock.Unlock() + + ch := make(chan *types.Block, DEFAULT_BUFFER_SIZE) + bs.imported[ch] = struct{}{} + return ch } // RegisterFinalizedChannel registers a channel for block notification upon block finalisation. @@ -62,17 +57,11 @@ func (bs *BlockState) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo return id, nil } -// UnregisterImportedChannel removes the block import notification channel with the given ID. -// A channel must be unregistered before closing it. -func (bs *BlockState) UnregisterImportedChannel(id byte) { +// FreeImportedBlockNotifierChannel to free imported block notifier channel +func (bs *BlockState) FreeImportedBlockNotifierChannel(ch chan *types.Block) { bs.importedLock.Lock() defer bs.importedLock.Unlock() - - delete(bs.imported, id) - err := bs.importedBytePool.Put(id) - if err != nil { - logger.Error("failed to unregister imported channel", "error", err) - } + delete(bs.imported, ch) } // UnregisterFinalisedChannel removes the block finalisation notification channel with the given ID. @@ -97,8 +86,8 @@ func (bs *BlockState) notifyImported(block *types.Block) { } logger.Trace("notifying imported block chans...", "chans", bs.imported) - for _, ch := range bs.imported { - go func(ch chan<- *types.Block) { + for ch := range bs.imported { + go func(ch chan *types.Block) { select { case ch <- block: default: diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index 6135fbaad4..cd9bcafe6c 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -32,12 +32,9 @@ var testMessageTimeout = time.Second * 3 func TestImportChannel(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) + ch := bs.GetImportedBlockNotifierChannel() - ch := make(chan *types.Block, 3) - id, err := bs.RegisterImportedChannel(ch) - require.NoError(t, err) - - defer bs.UnregisterImportedChannel(id) + defer bs.FreeImportedBlockNotifierChannel(ch) AddBlocksToState(t, bs, 3) @@ -50,6 +47,15 @@ func TestImportChannel(t *testing.T) { } } +func TestFreeImportedBlockNotifierChannel(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + ch := bs.GetImportedBlockNotifierChannel() + require.Equal(t, 1, len(bs.imported)) + + bs.FreeImportedBlockNotifierChannel(ch) + require.Equal(t, 0, len(bs.imported)) +} + func TestFinalizedChannel(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) @@ -79,13 +85,9 @@ func TestImportChannel_Multi(t *testing.T) { num := 5 chs := make([]chan *types.Block, num) - ids := make([]byte, num) - var err error for i := 0; i < num; i++ { - chs[i] = make(chan *types.Block) - ids[i], err = bs.RegisterImportedChannel(chs[i]) - require.NoError(t, err) + chs[i] = bs.GetImportedBlockNotifierChannel() } var wg sync.WaitGroup @@ -93,7 +95,7 @@ func TestImportChannel_Multi(t *testing.T) { for i, ch := range chs { - go func(i int, ch chan *types.Block) { + go func(i int, ch <-chan *types.Block) { select { case b := <-ch: require.Equal(t, big.NewInt(1), b.Header.Number) @@ -109,9 +111,6 @@ func TestImportChannel_Multi(t *testing.T) { AddBlocksToState(t, bs, 1) wg.Wait() - for _, id := range ids { - bs.UnregisterImportedChannel(id) - } } func TestFinalizedChannel_Multi(t *testing.T) { diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index f144919732..1098e91943 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -326,10 +326,7 @@ func (s *Service) initiateRound() error { s.precommits = new(sync.Map) s.pvEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote) s.pcEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote) - s.tracker, err = newTracker(s.blockState, s.messageHandler) - if err != nil { - return err - } + s.tracker = newTracker(s.blockState, s.messageHandler) s.tracker.start() logger.Trace("started message tracker") s.roundLock.Unlock() @@ -376,13 +373,9 @@ func (s *Service) initiate() error { } func (s *Service) waitForFirstBlock() error { - ch := make(chan *types.Block) - id, err := s.blockState.RegisterImportedChannel(ch) - if err != nil { - return err - } + ch := s.blockState.GetImportedBlockNotifierChannel() - defer s.blockState.UnregisterImportedChannel(id) + defer s.blockState.FreeImportedBlockNotifierChannel(ch) // loop until block 1 for { diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index da0ae28cd2..e66cac24db 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -33,16 +33,11 @@ type tracker struct { commitMessages map[common.Hash]*CommitMessage // map of commit block hash to commit message mapLock sync.Mutex in chan *types.Block // receive imported block from BlockState - chanID byte // BlockState channel ID stopped chan struct{} } -func newTracker(bs BlockState, handler *MessageHandler) (*tracker, error) { - in := make(chan *types.Block, 16) - id, err := bs.RegisterImportedChannel(in) - if err != nil { - return nil, err - } +func newTracker(bs BlockState, handler *MessageHandler) *tracker { + in := bs.GetImportedBlockNotifierChannel() return &tracker{ blockState: bs, @@ -51,9 +46,8 @@ func newTracker(bs BlockState, handler *MessageHandler) (*tracker, error) { commitMessages: make(map[common.Hash]*CommitMessage), mapLock: sync.Mutex{}, in: in, - chanID: id, stopped: make(chan struct{}), - }, nil + } } func (t *tracker) start() { @@ -62,8 +56,7 @@ func (t *tracker) start() { func (t *tracker) stop() { close(t.stopped) - t.blockState.UnregisterImportedChannel(t.chanID) - close(t.in) + t.blockState.FreeImportedBlockNotifierChannel(t.in) } func (t *tracker) addVote(v *networkVoteMessage) { diff --git a/lib/grandpa/message_tracker_test.go b/lib/grandpa/message_tracker_test.go index 9d1e3644c3..b4d3671a46 100644 --- a/lib/grandpa/message_tracker_test.go +++ b/lib/grandpa/message_tracker_test.go @@ -35,8 +35,7 @@ func TestMessageTracker_ValidateMessage(t *testing.T) { gs, _, _, _ := setupGrandpa(t, kr.Bob().(*ed25519.Keypair)) state.AddBlocksToState(t, gs.blockState.(*state.BlockState), 3) - gs.tracker, err = newTracker(gs.blockState, gs.messageHandler) - require.NoError(t, err) + gs.tracker = newTracker(gs.blockState, gs.messageHandler) fake := &types.Header{ Number: big.NewInt(77), @@ -62,8 +61,7 @@ func TestMessageTracker_SendMessage(t *testing.T) { gs, in, _, _ := setupGrandpa(t, kr.Bob().(*ed25519.Keypair)) state.AddBlocksToState(t, gs.blockState.(*state.BlockState), 3) - gs.tracker, err = newTracker(gs.blockState, gs.messageHandler) - require.NoError(t, err) + gs.tracker = newTracker(gs.blockState, gs.messageHandler) gs.tracker.start() defer gs.tracker.stop() @@ -156,8 +154,7 @@ func TestMessageTracker_MapInsideMap(t *testing.T) { gs, _, _, _ := setupGrandpa(t, kr.Bob().(*ed25519.Keypair)) state.AddBlocksToState(t, gs.blockState.(*state.BlockState), 3) - gs.tracker, err = newTracker(gs.blockState, gs.messageHandler) - require.NoError(t, err) + gs.tracker = newTracker(gs.blockState, gs.messageHandler) header := &types.Header{ Number: big.NewInt(77), diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index 65dc644273..6ff29d9940 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -41,8 +41,8 @@ type BlockState interface { BestBlockHash() common.Hash Leaves() []common.Hash BlocktreeAsString() string - RegisterImportedChannel(ch chan<- *types.Block) (byte, error) - UnregisterImportedChannel(id byte) + GetImportedBlockNotifierChannel() chan *types.Block + FreeImportedBlockNotifierChannel(ch chan *types.Block) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalisedChannel(id byte) SetJustification(hash common.Hash, data []byte) error diff --git a/lib/grandpa/vote_message_test.go b/lib/grandpa/vote_message_test.go index 44287da798..d2f0b7e929 100644 --- a/lib/grandpa/vote_message_test.go +++ b/lib/grandpa/vote_message_test.go @@ -337,8 +337,7 @@ func TestValidateMessage_BlockDoesNotExist(t *testing.T) { gs, err := NewService(cfg) require.NoError(t, err) state.AddBlocksToState(t, st.Block, 3) - gs.tracker, err = newTracker(st.Block, gs.messageHandler) - require.NoError(t, err) + gs.tracker = newTracker(st.Block, gs.messageHandler) fake := &types.Header{ Number: big.NewInt(77), @@ -371,8 +370,7 @@ func TestValidateMessage_IsNotDescendant(t *testing.T) { gs, err := NewService(cfg) require.NoError(t, err) - gs.tracker, err = newTracker(gs.blockState, gs.messageHandler) - require.NoError(t, err) + gs.tracker = newTracker(gs.blockState, gs.messageHandler) var branches []*types.Header for {