diff --git a/dot/core/interface.go b/dot/core/interface.go index dd017a96d5..daa6d38a0e 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -38,6 +38,7 @@ type BlockState interface { AddBlock(*types.Block) error GetAllBlocksAtDepth(hash common.Hash) []common.Hash GetBlockByHash(common.Hash) (*types.Block, error) + GetBlockStateRoot(bhash common.Hash) (common.Hash, error) GenesisHash() common.Hash GetSlotForBlock(common.Hash) (uint64, error) GetFinalisedHeader(uint64, uint64) (*types.Header, error) @@ -62,6 +63,7 @@ type StorageState interface { StoreTrie(*rtstorage.TrieState, *types.Header) error GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) GetStorage(root *common.Hash, key []byte) ([]byte, error) + GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) sync.Locker } diff --git a/dot/core/mocks/block_state.go b/dot/core/mocks/block_state.go index 65cc19ed21..efb499f96f 100644 --- a/dot/core/mocks/block_state.go +++ b/dot/core/mocks/block_state.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.8.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package mocks @@ -226,6 +226,29 @@ func (_m *MockBlockState) GetBlockByHash(_a0 common.Hash) (*types.Block, error) return r0, r1 } +// GetBlockStateRoot provides a mock function with given fields: bhash +func (_m *MockBlockState) GetBlockStateRoot(bhash common.Hash) (common.Hash, error) { + ret := _m.Called(bhash) + + var r0 common.Hash + if rf, ok := ret.Get(0).(func(common.Hash) common.Hash); ok { + r0 = rf(bhash) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(common.Hash) error); ok { + r1 = rf(bhash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetFinalisedHash provides a mock function with given fields: _a0, _a1 func (_m *MockBlockState) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error) { ret := _m.Called(_a0, _a1) diff --git a/dot/core/mocks/storage_state.go b/dot/core/mocks/storage_state.go new file mode 100644 index 0000000000..654e1e8ec9 --- /dev/null +++ b/dot/core/mocks/storage_state.go @@ -0,0 +1,180 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package mocks + +import ( + common "github.com/ChainSafe/gossamer/lib/common" + + mock "github.com/stretchr/testify/mock" + + storage "github.com/ChainSafe/gossamer/lib/runtime/storage" + + types "github.com/ChainSafe/gossamer/dot/types" +) + +// MockStorageState is an autogenerated mock type for the StorageState type +type MockStorageState struct { + mock.Mock +} + +// GenerateTrieProof provides a mock function with given fields: stateRoot, keys +func (_m *MockStorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) { + ret := _m.Called(stateRoot, keys) + + var r0 [][]byte + if rf, ok := ret.Get(0).(func(common.Hash, [][]byte) [][]byte); ok { + r0 = rf(stateRoot, keys) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([][]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(common.Hash, [][]byte) error); ok { + r1 = rf(stateRoot, keys) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetStateRootFromBlock provides a mock function with given fields: bhash +func (_m *MockStorageState) GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) { + ret := _m.Called(bhash) + + var r0 *common.Hash + if rf, ok := ret.Get(0).(func(*common.Hash) *common.Hash); ok { + r0 = rf(bhash) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*common.Hash) error); ok { + r1 = rf(bhash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetStorage provides a mock function with given fields: root, key +func (_m *MockStorageState) GetStorage(root *common.Hash, key []byte) ([]byte, error) { + ret := _m.Called(root, key) + + var r0 []byte + if rf, ok := ret.Get(0).(func(*common.Hash, []byte) []byte); ok { + r0 = rf(root, key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*common.Hash, []byte) error); ok { + r1 = rf(root, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LoadCode provides a mock function with given fields: root +func (_m *MockStorageState) LoadCode(root *common.Hash) ([]byte, error) { + ret := _m.Called(root) + + var r0 []byte + if rf, ok := ret.Get(0).(func(*common.Hash) []byte); ok { + r0 = rf(root) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*common.Hash) error); ok { + r1 = rf(root) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LoadCodeHash provides a mock function with given fields: root +func (_m *MockStorageState) LoadCodeHash(root *common.Hash) (common.Hash, error) { + ret := _m.Called(root) + + var r0 common.Hash + if rf, ok := ret.Get(0).(func(*common.Hash) common.Hash); ok { + r0 = rf(root) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*common.Hash) error); ok { + r1 = rf(root) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Lock provides a mock function with given fields: +func (_m *MockStorageState) Lock() { + _m.Called() +} + +// StoreTrie provides a mock function with given fields: _a0, _a1 +func (_m *MockStorageState) StoreTrie(_a0 *storage.TrieState, _a1 *types.Header) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(*storage.TrieState, *types.Header) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TrieState provides a mock function with given fields: root +func (_m *MockStorageState) TrieState(root *common.Hash) (*storage.TrieState, error) { + ret := _m.Called(root) + + var r0 *storage.TrieState + if rf, ok := ret.Get(0).(func(*common.Hash) *storage.TrieState); ok { + r0 = rf(root) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.TrieState) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*common.Hash) error); ok { + r1 = rf(root) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Unlock provides a mock function with given fields: +func (_m *MockStorageState) Unlock() { + _m.Called() +} diff --git a/dot/core/service.go b/dot/core/service.go index b0890a4e76..b466aa3712 100644 --- a/dot/core/service.go +++ b/dot/core/service.go @@ -627,3 +627,23 @@ func (s *Service) tryQueryStorage(block common.Hash, keys ...string) (QueryKeyVa return changes, nil } + +// GetReadProofAt will return an array with the proofs for the keys passed as params +// based on the block hash passed as param as well, if block hash is nil then the current state will take place +func (s *Service) GetReadProofAt(block common.Hash, keys [][]byte) (common.Hash, [][]byte, error) { + if common.EmptyHash.Equal(block) { + block = s.blockState.BestBlockHash() + } + + stateRoot, err := s.blockState.GetBlockStateRoot(block) + if err != nil { + return common.EmptyHash, nil, err + } + + proofForKeys, err := s.storageState.GenerateTrieProof(stateRoot, keys) + if err != nil { + return common.EmptyHash, nil, err + } + + return block, proofForKeys, nil +} diff --git a/dot/core/service_test.go b/dot/core/service_test.go index 5988da2393..9774857e25 100644 --- a/dot/core/service_test.go +++ b/dot/core/service_test.go @@ -26,7 +26,6 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/core/mocks" - coremocks "github.com/ChainSafe/gossamer/dot/core/mocks" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/sync" @@ -107,7 +106,7 @@ func TestStartService(t *testing.T) { } func TestAnnounceBlock(t *testing.T) { - net := new(coremocks.MockNetwork) + net := new(mocks.MockNetwork) cfg := &Config{ Network: net, } @@ -829,3 +828,76 @@ func TestDecodeSessionKeys_WhenGetRuntimeReturnError(t *testing.T) { require.Error(t, err, "problems") require.Nil(t, b) } + +func TestGetReadProofAt(t *testing.T) { + keysToProof := [][]byte{[]byte("first_key"), []byte("another_key")} + mockedProofs := [][]byte{[]byte("proof01"), []byte("proof02")} + + t.Run("When Has Block Is Empty", func(t *testing.T) { + mockedStateRootHash := common.NewHash([]byte("state root hash")) + expectedBlockHash := common.NewHash([]byte("expected block hash")) + + mockBlockState := new(mocks.MockBlockState) + mockBlockState.On("BestBlockHash").Return(expectedBlockHash) + mockBlockState.On("GetBlockStateRoot", expectedBlockHash). + Return(mockedStateRootHash, nil) + + mockStorageStage := new(mocks.MockStorageState) + mockStorageStage.On("GenerateTrieProof", mockedStateRootHash, keysToProof). + Return(mockedProofs, nil) + + s := &Service{ + blockState: mockBlockState, + storageState: mockStorageStage, + } + + b, p, err := s.GetReadProofAt(common.EmptyHash, keysToProof) + require.NoError(t, err) + require.Equal(t, p, mockedProofs) + require.Equal(t, expectedBlockHash, b) + + mockBlockState.AssertCalled(t, "BestBlockHash") + mockBlockState.AssertCalled(t, "GetBlockStateRoot", expectedBlockHash) + mockStorageStage.AssertCalled(t, "GenerateTrieProof", mockedStateRootHash, keysToProof) + }) + + t.Run("When GetStateRoot fails", func(t *testing.T) { + mockedBlockHash := common.NewHash([]byte("fake block hash")) + + mockBlockState := new(mocks.MockBlockState) + mockBlockState.On("GetBlockStateRoot", mockedBlockHash). + Return(common.EmptyHash, errors.New("problems while getting state root")) + + s := &Service{ + blockState: mockBlockState, + } + + b, p, err := s.GetReadProofAt(mockedBlockHash, keysToProof) + require.Equal(t, common.EmptyHash, b) + require.Nil(t, p) + require.Error(t, err) + }) + + t.Run("When GenerateTrieProof fails", func(t *testing.T) { + mockedBlockHash := common.NewHash([]byte("fake block hash")) + mockedStateRootHash := common.NewHash([]byte("state root hash")) + + mockBlockState := new(mocks.MockBlockState) + mockBlockState.On("GetBlockStateRoot", mockedBlockHash). + Return(mockedStateRootHash, nil) + + mockStorageStage := new(mocks.MockStorageState) + mockStorageStage.On("GenerateTrieProof", mockedStateRootHash, keysToProof). + Return(nil, errors.New("problems to generate trie proof")) + + s := &Service{ + blockState: mockBlockState, + storageState: mockStorageStage, + } + + b, p, err := s.GetReadProofAt(mockedBlockHash, keysToProof) + require.Equal(t, common.EmptyHash, b) + require.Nil(t, p) + require.Error(t, err) + }) +} diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 8812b09969..65c4646654 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -86,6 +86,7 @@ type CoreAPI interface { GetMetadata(bhash *common.Hash) ([]byte, error) QueryStorage(from, to common.Hash, keys ...string) (map[common.Hash]core.QueryKeyValueChanges, error) DecodeSessionKeys(enc []byte) ([]byte, error) + GetReadProofAt(block common.Hash, keys [][]byte) (common.Hash, [][]byte, error) } // RPCAPI is the interface for methods related to RPC service diff --git a/dot/rpc/modules/mocks/core_api.go b/dot/rpc/modules/mocks/core_api.go index da382dbda6..fe20861ab2 100644 --- a/dot/rpc/modules/mocks/core_api.go +++ b/dot/rpc/modules/mocks/core_api.go @@ -66,6 +66,38 @@ func (_m *MockCoreAPI) GetMetadata(bhash *common.Hash) ([]byte, error) { return r0, r1 } +// GetReadProofAt provides a mock function with given fields: block, keys +func (_m *MockCoreAPI) GetReadProofAt(block common.Hash, keys [][]byte) (common.Hash, [][]byte, error) { + ret := _m.Called(block, keys) + + var r0 common.Hash + if rf, ok := ret.Get(0).(func(common.Hash, [][]byte) common.Hash); ok { + r0 = rf(block, keys) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + var r1 [][]byte + if rf, ok := ret.Get(1).(func(common.Hash, [][]byte) [][]byte); ok { + r1 = rf(block, keys) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([][]byte) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func(common.Hash, [][]byte) error); ok { + r2 = rf(block, keys) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // GetRuntimeVersion provides a mock function with given fields: bhash func (_m *MockCoreAPI) GetRuntimeVersion(bhash *common.Hash) (runtime.Version, error) { ret := _m.Called(bhash) diff --git a/dot/rpc/modules/state.go b/dot/rpc/modules/state.go index f301b4e8ac..cb9c8dcba8 100644 --- a/dot/rpc/modules/state.go +++ b/dot/rpc/modules/state.go @@ -28,6 +28,12 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) +//StateGetReadProofRequest json fields +type StateGetReadProofRequest struct { + Keys []string + Hash common.Hash +} + // StateCallRequest holds json fields type StateCallRequest struct { Method string `json:"method"` @@ -128,6 +134,12 @@ type StateStorageKeysResponse []string //TODO: Determine actual type type StateMetadataResponse string +//StateGetReadProofResponse holds the response format +type StateGetReadProofResponse struct { + At common.Hash `json:"at"` + Proof []string `json:"proof"` +} + // StorageChangeSetResponse is the struct that holds the block and changes type StorageChangeSetResponse struct { Block *common.Hash `json:"block"` @@ -168,7 +180,7 @@ func NewStateModule(net NetworkAPI, storage StorageAPI, core CoreAPI) *StateModu } // GetPairs returns the keys with prefix, leave empty to get all the keys. -func (sm *StateModule) GetPairs(r *http.Request, req *StatePairRequest, res *StatePairResponse) error { +func (sm *StateModule) GetPairs(_ *http.Request, req *StatePairRequest, res *StatePairResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) var ( stateRootHash *common.Hash @@ -209,38 +221,38 @@ func (sm *StateModule) GetPairs(r *http.Request, req *StatePairRequest, res *Sta } // Call isn't implemented properly yet. -func (sm *StateModule) Call(r *http.Request, req *StateCallRequest, res *StateCallResponse) error { +func (sm *StateModule) Call(_ *http.Request, _ *StateCallRequest, _ *StateCallResponse) error { _ = sm.networkAPI _ = sm.storageAPI return nil } // GetChildKeys isn't implemented properly yet. -func (sm *StateModule) GetChildKeys(r *http.Request, req *StateChildStorageRequest, res *StateKeysResponse) error { +func (*StateModule) GetChildKeys(_ *http.Request, _ *StateChildStorageRequest, _ *StateKeysResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) return nil } // GetChildStorage isn't implemented properly yet. -func (sm *StateModule) GetChildStorage(r *http.Request, req *StateChildStorageRequest, res *StateStorageDataResponse) error { +func (*StateModule) GetChildStorage(_ *http.Request, _ *StateChildStorageRequest, _ *StateStorageDataResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) return nil } // GetChildStorageHash isn't implemented properly yet. -func (sm *StateModule) GetChildStorageHash(r *http.Request, req *StateChildStorageRequest, res *StateChildStorageResponse) error { +func (*StateModule) GetChildStorageHash(_ *http.Request, _ *StateChildStorageRequest, _ *StateChildStorageResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) return nil } // GetChildStorageSize isn't implemented properly yet. -func (sm *StateModule) GetChildStorageSize(r *http.Request, req *StateChildStorageRequest, res *StateChildStorageSizeResponse) error { +func (*StateModule) GetChildStorageSize(_ *http.Request, _ *StateChildStorageRequest, _ *StateChildStorageSizeResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) return nil } // GetKeysPaged Returns the keys with prefix with pagination support. -func (sm *StateModule) GetKeysPaged(r *http.Request, req *StateStorageKeyRequest, res *StateStorageKeysResponse) error { +func (sm *StateModule) GetKeysPaged(_ *http.Request, req *StateStorageKeyRequest, res *StateStorageKeysResponse) error { if req.Prefix == "" { req.Prefix = "0x" } @@ -266,7 +278,7 @@ func (sm *StateModule) GetKeysPaged(r *http.Request, req *StateStorageKeyRequest } // GetMetadata calls runtime Metadata_metadata function -func (sm *StateModule) GetMetadata(r *http.Request, req *StateRuntimeMetadataQuery, res *StateMetadataResponse) error { +func (sm *StateModule) GetMetadata(_ *http.Request, req *StateRuntimeMetadataQuery, res *StateMetadataResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) metadata, err := sm.coreAPI.GetMetadata(req.Bhash) if err != nil { @@ -279,10 +291,40 @@ func (sm *StateModule) GetMetadata(r *http.Request, req *StateRuntimeMetadataQue return err } +// GetReadProof returns the proof to the received storage keys +func (sm *StateModule) GetReadProof(_ *http.Request, req *StateGetReadProofRequest, res *StateGetReadProofResponse) error { + keys := make([][]byte, len(req.Keys)) + for i, hexKey := range req.Keys { + bKey, err := common.HexToBytes(hexKey) + if err != nil { + return err + } + + keys[i] = bKey + } + + block, proofs, err := sm.coreAPI.GetReadProofAt(req.Hash, keys) + if err != nil { + return err + } + + var decProof []string + for _, p := range proofs { + decProof = append(decProof, common.BytesToHex(p)) + } + + *res = StateGetReadProofResponse{ + At: block, + Proof: decProof, + } + + return nil +} + // GetRuntimeVersion Get the runtime version at a given block. // If no block hash is provided, the latest version gets returned. // TODO currently only returns latest version, add functionality to lookup runtime by block hash (see issue #834) -func (sm *StateModule) GetRuntimeVersion(r *http.Request, req *StateRuntimeVersionRequest, res *StateRuntimeVersionResponse) error { +func (sm *StateModule) GetRuntimeVersion(_ *http.Request, req *StateRuntimeVersionRequest, res *StateRuntimeVersionResponse) error { rtVersion, err := sm.coreAPI.GetRuntimeVersion(req.Bhash) if err != nil { return err @@ -300,7 +342,7 @@ func (sm *StateModule) GetRuntimeVersion(r *http.Request, req *StateRuntimeVersi } // GetStorage Returns a storage entry at a specific block's state. If not block hash is provided, the latest value is returned. -func (sm *StateModule) GetStorage(r *http.Request, req *StateStorageRequest, res *StateStorageResponse) error { +func (sm *StateModule) GetStorage(_ *http.Request, req *StateStorageRequest, res *StateStorageResponse) error { var ( item []byte err error @@ -330,7 +372,7 @@ func (sm *StateModule) GetStorage(r *http.Request, req *StateStorageRequest, res // GetStorageHash returns the hash of a storage entry at a block's state. // If no block hash is provided, the latest value is returned. // TODO implement change storage trie so that block hash parameter works (See issue #834) -func (sm *StateModule) GetStorageHash(r *http.Request, req *StateStorageHashRequest, res *StateStorageHashResponse) error { +func (sm *StateModule) GetStorageHash(_ *http.Request, req *StateStorageHashRequest, res *StateStorageHashResponse) error { var ( item []byte err error @@ -360,7 +402,7 @@ func (sm *StateModule) GetStorageHash(r *http.Request, req *StateStorageHashRequ // GetStorageSize returns the size of a storage entry at a block's state. // If no block hash is provided, the latest value is used. // TODO implement change storage trie so that block hash parameter works (See issue #834) -func (sm *StateModule) GetStorageSize(r *http.Request, req *StateStorageSizeRequest, res *StateStorageSizeResponse) error { +func (sm *StateModule) GetStorageSize(_ *http.Request, req *StateStorageSizeRequest, res *StateStorageSizeResponse) error { var ( item []byte err error @@ -388,7 +430,7 @@ func (sm *StateModule) GetStorageSize(r *http.Request, req *StateStorageSizeRequ } // QueryStorage isn't implemented properly yet. -func (sm *StateModule) QueryStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *[]StorageChangeSetResponse) error { +func (sm *StateModule) QueryStorage(_ *http.Request, req *StateStorageQueryRangeRequest, res *[]StorageChangeSetResponse) error { if req.StartBlock == common.EmptyHash { return errors.New("the start block hash cannot be an empty value") } @@ -419,7 +461,7 @@ func (sm *StateModule) QueryStorage(r *http.Request, req *StateStorageQueryRange // SubscribeRuntimeVersion isn't implemented properly yet. // TODO make this actually a subscription that pushes data -func (sm *StateModule) SubscribeRuntimeVersion(r *http.Request, req *StateStorageQueryRangeRequest, res *StateRuntimeVersionResponse) error { +func (sm *StateModule) SubscribeRuntimeVersion(r *http.Request, _ *StateStorageQueryRangeRequest, res *StateRuntimeVersionResponse) error { // TODO implement change storage trie so that block hash parameter works (See issue #834) return sm.GetRuntimeVersion(r, nil, res) } @@ -427,7 +469,7 @@ func (sm *StateModule) SubscribeRuntimeVersion(r *http.Request, req *StateStorag // SubscribeStorage Storage subscription. If storage keys are specified, it creates a message for each block which // changes the specified storage keys. If none are specified, then it creates a message for every block. // This endpoint communicates over the Websocket protocol, but this func should remain here so it's added to rpc_methods list -func (sm *StateModule) SubscribeStorage(r *http.Request, req *StateStorageQueryRangeRequest, res *StorageChangeSetResponse) error { +func (*StateModule) SubscribeStorage(_ *http.Request, _ *StateStorageQueryRangeRequest, _ *StorageChangeSetResponse) error { return nil } diff --git a/dot/rpc/modules/state_test.go b/dot/rpc/modules/state_test.go index c5cf4b169f..5bbbc40e1a 100644 --- a/dot/rpc/modules/state_test.go +++ b/dot/rpc/modules/state_test.go @@ -476,6 +476,53 @@ func TestStateModule_GetKeysPaged(t *testing.T) { } } +func TestGetReadProof_WhenCoreAPIReturnsError(t *testing.T) { + coreAPIMock := new(mocks.MockCoreAPI) + coreAPIMock. + On("GetReadProofAt", mock.AnythingOfType("common.Hash"), mock.AnythingOfType("[][]uint8")). + Return(common.EmptyHash, nil, errors.New("mocked error")) + + sm := new(StateModule) + sm.coreAPI = coreAPIMock + + req := &StateGetReadProofRequest{ + Keys: []string{}, + Hash: common.EmptyHash, + } + err := sm.GetReadProof(nil, req, nil) + require.Error(t, err, "mocked error") +} + +func TestGetReadProof_WhenReturnsProof(t *testing.T) { + expectedBlock := common.BytesToHash([]byte("random hash")) + mockedProof := [][]byte{[]byte("proof-1"), []byte("proof-2")} + + coreAPIMock := new(mocks.MockCoreAPI) + coreAPIMock. + On("GetReadProofAt", mock.AnythingOfType("common.Hash"), mock.AnythingOfType("[][]uint8")). + Return(expectedBlock, mockedProof, nil) + + sm := new(StateModule) + sm.coreAPI = coreAPIMock + + req := &StateGetReadProofRequest{ + Keys: []string{}, + Hash: common.EmptyHash, + } + + res := new(StateGetReadProofResponse) + err := sm.GetReadProof(nil, req, res) + require.NoError(t, err) + require.Equal(t, res.At, expectedBlock) + + expectedProof := []string{ + common.BytesToHex([]byte("proof-1")), + common.BytesToHex([]byte("proof-2")), + } + + require.Equal(t, res.Proof, expectedProof) +} + func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { // setup service net := newNetworkService(t) diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 6e5340ade0..e87605c288 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" - modulesmocks "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/dot/rpc/modules" "github.com/ChainSafe/gossamer/dot/types" @@ -232,7 +231,7 @@ func TestWSConn_HandleComm(t *testing.T) { mockedJustBytes, err := scale.Marshal(mockedJust) require.NoError(t, err) - BlockAPI := new(modulesmocks.MockBlockAPI) + BlockAPI := new(mocks.MockBlockAPI) BlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). Run(func(args mock.Arguments) { ch := args.Get(0).(chan<- *types.FinalisationInfo) diff --git a/dot/state/block.go b/dot/state/block.go index 7ba4c83c7d..dc980d19d0 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -546,6 +546,16 @@ func (bs *BlockState) BestBlockStateRoot() (common.Hash, error) { return header.StateRoot, nil } +// GetBlockStateRoot returns the state root of the given block hash +func (bs *BlockState) GetBlockStateRoot(bhash common.Hash) (common.Hash, error) { + header, err := bs.GetHeader(bhash) + if err != nil { + return common.EmptyHash, err + } + + return header.StateRoot, nil +} + // BestBlockNumber returns the block number of the current head of the chain func (bs *BlockState) BestBlockNumber() (*big.Int, error) { header, err := bs.GetHeader(bs.BestBlockHash()) diff --git a/dot/state/storage.go b/dot/state/storage.go index fd26ac24c0..96e9643ca7 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -265,7 +265,7 @@ func (s *StorageState) StorageRoot() (common.Hash, error) { } // EnumeratedTrieRoot not implemented -func (s *StorageState) EnumeratedTrieRoot(values [][]byte) { +func (*StorageState) EnumeratedTrieRoot(_ [][]byte) { //TODO panic("not implemented") } @@ -325,6 +325,11 @@ func (s *StorageState) LoadCodeHash(hash *common.Hash) (common.Hash, error) { return common.Blake2bHash(code) } +// GenerateTrieProof returns the proofs related to the keys on the state root trie +func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) { + return trie.GenerateProof(stateRoot[:], keys, s.db) +} + // GetBalance gets the balance for an account with the given public key func (s *StorageState) GetBalance(hash *common.Hash, key [32]byte) (uint64, error) { skey, err := common.BalanceKey(key) diff --git a/go.mod b/go.mod index 61d56ed401..9d22db6e33 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,8 @@ require ( golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 - golang.org/x/tools v0.1.5 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.6-0.20210908145159-e5f719fbe6d5 // indirect google.golang.org/appengine v1.6.6 // indirect google.golang.org/protobuf v1.26.0-rc.1 honnef.co/go/tools v0.2.0 // indirect diff --git a/go.sum b/go.sum index 43d7cfe4cc..9d58898dbb 100644 --- a/go.sum +++ b/go.sum @@ -1089,7 +1089,7 @@ github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6Ut github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -1221,7 +1221,6 @@ golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210220033124-5f55cee0dc0d/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d h1:20cMwl2fHAzkJMEA+8J4JgqBQcQGzbisXo31MIeenXI= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -1297,12 +1296,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210415045647-66c3f260301c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420205809-ac73e9fd8988/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912 h1:uCLL3g5wH2xjxVREVuAbP9JM5PPKjRbXKRa6IBjkzmU= golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= @@ -1315,8 +1313,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1363,8 +1362,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= -golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.6-0.20210908145159-e5f719fbe6d5 h1:VsF0qTw/aGuPGfWp6z2XrkKtnBt6pKLSsCo2CWW1yL0= +golang.org/x/tools v0.1.6-0.20210908145159-e5f719fbe6d5/go.mod h1:YD9qOF0M9xpSpdWTBbzEl5e/RnCefISl8E5Noe10jFM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/lib/babe/build_test.go b/lib/babe/build_test.go index 7977ddc081..cb0dd91468 100644 --- a/lib/babe/build_test.go +++ b/lib/babe/build_test.go @@ -305,9 +305,6 @@ func TestApplyExtrinsic(t *testing.T) { } func TestBuildAndApplyExtrinsic(t *testing.T) { - // TODO (ed) currently skipping this because it's failing on github with error: - // failed to sign with subkey: fork/exec /Users/runner/.local/bin/subkey: exec format error - t.Skip() cfg := &ServiceConfig{ TransactionState: state.NewTransactionState(), LogLvl: log.LvlInfo, @@ -331,7 +328,7 @@ func TestBuildAndApplyExtrinsic(t *testing.T) { rawMeta, err := rt.Metadata() require.NoError(t, err) var decoded []byte - err = scale.Unmarshal(rawMeta, []byte{}) + err = scale.Unmarshal(rawMeta, &decoded) require.NoError(t, err) meta := &ctypes.Metadata{} diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go new file mode 100644 index 0000000000..dd8600963e --- /dev/null +++ b/lib/trie/lookup.go @@ -0,0 +1,87 @@ +package trie + +import ( + "bytes" + "errors" + + "github.com/ChainSafe/chaindb" +) + +var ( + // ErrProofNodeNotFound when a needed proof node is not in the database + ErrProofNodeNotFound = errors.New("cannot find a trie node in the database") +) + +// lookup struct holds the state root and database reference +// used to retrieve trie information from database +type lookup struct { + // root to start the lookup + root []byte + db chaindb.Database +} + +// newLookup returns a Lookup to helps the proof generator +func newLookup(rootHash []byte, db chaindb.Database) *lookup { + lk := &lookup{db: db} + lk.root = make([]byte, len(rootHash)) + copy(lk.root, rootHash) + + return lk +} + +// find will return the desired value or nil if key cannot be found and will record visited nodes +func (l *lookup) find(key []byte, recorder *recorder) ([]byte, error) { + partial := key + hash := l.root + + for { + nodeData, err := l.db.Get(hash) + if err != nil { + return nil, ErrProofNodeNotFound + } + + nodeHash := make([]byte, len(hash)) + copy(nodeHash, hash) + + recorder.record(nodeHash, nodeData) + + decoded, err := decodeBytes(nodeData) + if err != nil { + return nil, err + } + + switch currNode := decoded.(type) { + case nil: + return nil, nil + + case *leaf: + if bytes.Equal(currNode.key, partial) { + return currNode.value, nil + } + return nil, nil + + case *branch: + switch len(partial) { + case 0: + return currNode.value, nil + default: + if !bytes.HasPrefix(partial, currNode.key) { + return nil, nil + } + + if bytes.Equal(partial, currNode.key) { + return currNode.value, nil + } + + length := lenCommonPrefix(currNode.key, partial) + switch child := currNode.children[partial[length]].(type) { + case nil: + return nil, nil + default: + partial = partial[length+1:] + copy(hash, child.getHash()) + } + } + } + } +} diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 537b4e7229..7668b69df8 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -17,9 +17,7 @@ package trie import ( - "bytes" "errors" - "fmt" "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" @@ -28,117 +26,38 @@ import ( var ( // ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root ErrEmptyTrieRoot = errors.New("provided trie must have a root") - - // ErrEmptyNibbles occurs when trying to prove or valid a proof to an empty key - ErrEmptyNibbles = errors.New("empty nibbles provided from key") ) -// GenerateProof constructs the merkle-proof for key. The result contains all encoded nodes -// on the path to the key. Returns the amount of nodes of the path and error if could not found the key -func (t *Trie) GenerateProof(key []byte, db chaindb.Writer) (int, error) { - key = keyToNibbles(key) - if len(key) == 0 { - return 0, ErrEmptyNibbles - } - - var nodes []node - currNode := t.root - -proveLoop: - for { - switch n := currNode.(type) { - case nil: - return 0, errors.New("no more paths to follow") - - case *leaf: - nodes = append(nodes, n) - - if bytes.Equal(n.key, key) { - break proveLoop - } - - return 0, errors.New("leaf node doest not match the key") +// GenerateProof receive the keys to proof, the trie root and a reference to database +// will +func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) { + trackedProofs := make(map[string][]byte) - case *branch: - nodes = append(nodes, n) - if bytes.Equal(n.key, key) || len(key) == 0 { - break proveLoop - } - - length := lenCommonPrefix(n.key, key) - currNode = n.children[key[length]] - key = key[length+1:] - } - } + for _, k := range keys { + nk := keyToNibbles(k) - for _, n := range nodes { - var ( - hashNode []byte - encHashNode []byte - err error - ) + lookup := newLookup(root, db) + recorder := new(recorder) - if encHashNode, hashNode, err = n.encodeAndHash(); err != nil { - return 0, fmt.Errorf("problems while encoding and hashing the node: %w", err) + _, err := lookup.find(nk, recorder) + if err != nil { + return nil, err } - if err = db.Put(hashNode, encHashNode); err != nil { - return len(nodes), err + for !recorder.isEmpty() { + recNode := recorder.next() + nodeHashHex := common.BytesToHex(recNode.hash) + if _, ok := trackedProofs[nodeHashHex]; !ok { + trackedProofs[nodeHashHex] = recNode.rawData + } } } - return len(nodes), nil -} + proofs := make([][]byte, 0) -// VerifyProof checks merkle proofs given an proof -func VerifyProof(rootHash common.Hash, key []byte, db chaindb.Reader) (bool, error) { - key = keyToNibbles(key) - if len(key) == 0 { - return false, ErrEmptyNibbles + for _, p := range trackedProofs { + proofs = append(proofs, p) } - var wantedHash []byte - wantedHash = rootHash.ToBytes() - - for { - enc, err := db.Get(wantedHash) - if errors.Is(err, chaindb.ErrKeyNotFound) { - return false, nil - } else if err != nil { - return false, nil - } - - currNode, err := decodeBytes(enc) - if err != nil { - return false, fmt.Errorf("could not decode node bytes: %w", err) - } - - switch n := currNode.(type) { - case nil: - return false, nil - case *leaf: - if bytes.Equal(n.key, key) { - return true, nil - } - - return false, nil - case *branch: - if bytes.Equal(n.key, key) { - return true, nil - } - - if len(key) == 0 { - return false, nil - } - - length := lenCommonPrefix(n.key, key) - next := n.children[key[length]] - if next == nil { - return false, nil - } - - key = key[length+1:] - wantedHash = next.getHash() - } - } + return proofs, nil } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 68127e3165..9129d503c2 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -17,237 +17,39 @@ package trie import ( - crand "crypto/rand" "io/ioutil" - "math/rand" - "os" - "sync" "testing" "github.com/ChainSafe/chaindb" - "github.com/ChainSafe/gossamer/lib/common" "github.com/stretchr/testify/require" ) -func inMemoryChainDB(t *testing.T) (*chaindb.BadgerDB, func()) { - t.Helper() - - tmpdir, err := ioutil.TempDir("", "trie-chaindb-*") +func TestProofGeneration(t *testing.T) { + tmp, err := ioutil.TempDir("", "*-test-trie") require.NoError(t, err) - db, err := chaindb.NewBadgerDB(&chaindb.Config{ + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ InMemory: true, - DataDir: tmpdir, + DataDir: tmp, }) require.NoError(t, err) - clear := func() { - err = db.Close() - require.NoError(t, err) - - err = os.RemoveAll(tmpdir) - require.NoError(t, err) - } - - return db, clear -} - -func TestVerifyProof(t *testing.T) { - trie, entries := randomTrie(t, 200) - root, err := trie.Hash() - require.NoError(t, err) - - amount := make(chan struct{}, 15) - wg := new(sync.WaitGroup) - - for _, entry := range entries { - wg.Add(1) - go func(kv *kv) { - defer func() { - wg.Done() - <-amount - }() - - amount <- struct{}{} - - proof, clear := inMemoryChainDB(t) - defer clear() - - _, err := trie.GenerateProof(kv.k, proof) - require.NoError(t, err) - v, err := VerifyProof(root, kv.k, proof) - - require.NoError(t, err) - require.True(t, v) - }(entry) - } - - wg.Wait() -} - -func TestVerifyProofOneElement(t *testing.T) { trie := NewEmptyTrie() - key := randBytes(32) - trie.Put(key, []byte("V")) - - rootHash, err := trie.Hash() - require.NoError(t, err) + trie.Put([]byte("cat"), rand32Bytes()) + trie.Put([]byte("catapulta"), rand32Bytes()) + trie.Put([]byte("catapora"), rand32Bytes()) + trie.Put([]byte("dog"), rand32Bytes()) + trie.Put([]byte("doguinho"), rand32Bytes()) - proof, clear := inMemoryChainDB(t) - defer clear() - - _, err = trie.GenerateProof(key, proof) + err = trie.Store(memdb) require.NoError(t, err) - val, err := VerifyProof(rootHash, key, proof) + hash, err := trie.Hash() require.NoError(t, err) - require.True(t, val) -} - -func TestVerifyProof_BadProof(t *testing.T) { - trie, entries := randomTrie(t, 200) - rootHash, err := trie.Hash() + proof, err := GenerateProof(hash.ToBytes(), [][]byte{[]byte("catapulta"), []byte("catapora")}, memdb) require.NoError(t, err) - amount := make(chan struct{}, 15) - wg := new(sync.WaitGroup) - - for _, entry := range entries { - wg.Add(1) - - go func(kv *kv) { - defer func() { - wg.Done() - <-amount - }() - - amount <- struct{}{} - proof, clear := inMemoryChainDB(t) - defer clear() - - nLen, err := trie.GenerateProof(kv.k, proof) - require.Greater(t, nLen, 0) - require.NoError(t, err) - - it := proof.NewIterator() - for i, d := 0, rand.Intn(nLen); i <= d; i++ { - it.Next() - } - key := it.Key() - val, _ := proof.Get(key) - proof.Del(key) - it.Release() - - newhash, err := common.Keccak256(val) - require.NoError(t, err) - proof.Put(newhash.ToBytes(), val) - - v, err := VerifyProof(rootHash, kv.k, proof) - require.NoError(t, err) - require.False(t, v) - }(entry) - } - - wg.Wait() -} - -func TestGenerateProofEmptyNibbles(t *testing.T) { - k := []byte{} - trie := NewEmptyTrie() - _, err := trie.GenerateProof(k, nil) - require.Error(t, err, ErrEmptyNibbles) -} - -func TestGenerateProofNilRoot(t *testing.T) { - k := []byte{0xff, 0xff} - trie := NewEmptyTrie() - _, err := trie.GenerateProof(k, nil) - - require.Error(t, err, "no more paths to follow") -} - -func TestGenerateProofMissingKey(t *testing.T) { - trie := NewEmptyTrie() - - parentKey, parentVal := randBytes(32), randBytes(20) - chieldKey, chieldValue := modifyLastBytes(parentKey), modifyLastBytes(parentVal) - gransonKey, gransonValue := modifyLastBytes(chieldKey), modifyLastBytes(chieldValue) - - trie.Put(parentKey, parentVal) - trie.Put(chieldKey, chieldValue) - trie.Put(gransonKey, gransonValue) - - proof, clear := inMemoryChainDB(t) - defer clear() - - searchfor := make([]byte, len(gransonKey)) - copy(searchfor[:], gransonKey[:]) - - // keep the path til the key but modify the last element - searchfor[len(searchfor)-1] = searchfor[len(searchfor)-1] + byte(0xff) - - _, err := trie.GenerateProof(searchfor, proof) - require.Error(t, err, "leaf node doest not match the key") -} - -func TestGenerateProofNoMorePathToFollow(t *testing.T) { - trie := NewEmptyTrie() - - parentKey, parentVal := randBytes(32), randBytes(20) - chieldKey, chieldValue := modifyLastBytes(parentKey), modifyLastBytes(parentVal) - gransonKey, gransonValue := modifyLastBytes(chieldKey), modifyLastBytes(chieldValue) - - trie.Put(parentKey, parentVal) - trie.Put(chieldKey, chieldValue) - trie.Put(gransonKey, gransonValue) - - proof, clear := inMemoryChainDB(t) - defer clear() - - searchfor := make([]byte, len(parentKey)) - copy(searchfor[:], parentKey[:]) - - // the keys are equals until the byte number 20 so we modify the byte number 20 to another - // value and the branch node will no be able to found the right slot - searchfor[20] = searchfor[20] + byte(0xff) - - _, err := trie.GenerateProof(searchfor, proof) - require.Error(t, err, "no more paths to follow") -} - -type kv struct { - k []byte - v []byte -} - -func randomTrie(t *testing.T, n int) (*Trie, map[string]*kv) { - t.Helper() - - trie := NewEmptyTrie() - vals := make(map[string]*kv) - - for i := 0; i < n; i++ { - v := &kv{randBytes(32), randBytes(20)} - trie.Put(v.k, v.v) - vals[string(v.k)] = v - } - - return trie, vals -} - -func randBytes(n int) []byte { - r := make([]byte, n) - crand.Read(r) - return r -} - -func modifyLastBytes(b []byte) []byte { - newB := make([]byte, len(b)) - copy(newB[:], b) - - rb := randBytes(12) - copy(newB[20:], rb) - - return newB + // TODO: use the verify_proof function to assert the tests + require.Equal(t, 5, len(proof)) } diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go new file mode 100644 index 0000000000..7c2b9a40c9 --- /dev/null +++ b/lib/trie/recorder.go @@ -0,0 +1,31 @@ +package trie + +// nodeRecord represets a record of a visited node +type nodeRecord struct { + rawData []byte + hash []byte +} + +// Recorder keeps the list of nodes find by Lookup.Find +type recorder []nodeRecord + +// Record insert a node insede the recorded list +func (r *recorder) record(h, rd []byte) { + *r = append(*r, nodeRecord{rawData: rd, hash: h}) +} + +// Next returns the current item the cursor is on and increment the cursor by 1 +func (r *recorder) next() *nodeRecord { + if !r.isEmpty() { + n := (*r)[0] + *r = (*r)[1:] + return &n + } + + return nil +} + +// IsEmpty returns bool if there is data inside the slice +func (r *recorder) isEmpty() bool { + return len(*r) <= 0 +} diff --git a/lib/trie/test_utils.go b/lib/trie/test_utils.go index bb8a87c1be..907e17a0e8 100644 --- a/lib/trie/test_utils.go +++ b/lib/trie/test_utils.go @@ -71,3 +71,9 @@ func generateRandomTest(t testing.TB, kv map[string][]byte) Test { } } } + +func rand32Bytes() []byte { + r := make([]byte, 32) + rand.Read(r) //nolint + return r +}