diff --git a/cmd/gossamer/commands/import_state.go b/cmd/gossamer/commands/import_state.go index a12f137ba9..43d24a4588 100644 --- a/cmd/gossamer/commands/import_state.go +++ b/cmd/gossamer/commands/import_state.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/ChainSafe/gossamer/dot" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/lib/utils" "github.com/spf13/cobra" ) @@ -14,6 +15,10 @@ import ( func init() { ImportStateCmd.Flags().String("chain", "", "Chain id used to load default configuration for specified chain") ImportStateCmd.Flags().String("state-file", "", "Path to JSON file consisting of key-value pairs") + ImportStateCmd.Flags().Uint32("state-version", + uint32(trie.DefaultStateVersion), + "State version to use when importing state", + ) ImportStateCmd.Flags().String("header-file", "", "Path to JSON file of block header corresponding to the given state") ImportStateCmd.Flags().Uint64("first-slot", 0, "The first BABE slot of the network") } @@ -26,7 +31,8 @@ var ImportStateCmd = &cobra.Command{ in the form of key-value pairs to be imported. Input can be generated by using the RPC function state_getPairs. Example: - gossamer import-state --state-file state.json --header-file header.json --first-slot `, + gossamer import-state --state-file state.json --state-version 1 --header-file header.json + --first-slot `, RunE: func(cmd *cobra.Command, args []string) error { return execImportState(cmd) }, @@ -54,6 +60,15 @@ func execImportState(cmd *cobra.Command) error { return fmt.Errorf("state-file must be specified") } + stateVersion, err := cmd.Flags().GetUint32("state-version") + if err != nil { + return fmt.Errorf("failed to get state-version: %s", err) + } + stateTrieVersion, err := trie.ParseVersion(stateVersion) + if err != nil { + return fmt.Errorf("invalid state version") + } + headerFile, err := cmd.Flags().GetString("header-file") if err != nil { return fmt.Errorf("failed to get header-file: %s", err) @@ -64,5 +79,5 @@ func execImportState(cmd *cobra.Command) error { basePath = utils.ExpandDir(basePath) - return dot.ImportState(basePath, stateFile, headerFile, firstSlot) + return dot.ImportState(basePath, stateFile, headerFile, stateTrieVersion, firstSlot) } diff --git a/cmd/gossamer/commands/import_state_test.go b/cmd/gossamer/commands/import_state_test.go new file mode 100644 index 0000000000..6406023a42 --- /dev/null +++ b/cmd/gossamer/commands/import_state_test.go @@ -0,0 +1,57 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package commands + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestImportStateMissingStateFile(t *testing.T) { + rootCmd, err := NewRootCommand() + require.NoError(t, err) + rootCmd.AddCommand(ImportStateCmd) + + rootCmd.SetArgs([]string{ImportStateCmd.Name()}) + err = rootCmd.Execute() + assert.ErrorContains(t, err, "state-file must be specified") +} + +func TestImportStateInvalidFirstSlot(t *testing.T) { + rootCmd, err := NewRootCommand() + require.NoError(t, err) + rootCmd.AddCommand(ImportStateCmd) + + rootCmd.SetArgs([]string{ImportStateCmd.Name(), "--first-slot", "wrong"}) + err = rootCmd.Execute() + assert.ErrorContains(t, err, "invalid argument \"wrong\"") +} + +func TestImportStateEmptyHeaderFile(t *testing.T) { + rootCmd, err := NewRootCommand() + require.NoError(t, err) + rootCmd.AddCommand(ImportStateCmd) + + rootCmd.SetArgs([]string{ImportStateCmd.Name(), + "--state-file", "test", + "--header-file", "", + }) + err = rootCmd.Execute() + assert.ErrorContains(t, err, "header-file must be specified") +} + +func TestImportStateErrorImportingState(t *testing.T) { + rootCmd, err := NewRootCommand() + require.NoError(t, err) + rootCmd.AddCommand(ImportStateCmd) + + rootCmd.SetArgs([]string{ImportStateCmd.Name(), + "--state-file", "test", + "--header-file", "test", + }) + err = rootCmd.Execute() + assert.ErrorContains(t, err, "no such file or directory") +} diff --git a/dot/core/helpers_test.go b/dot/core/helpers_test.go index 53a2af4d78..4c47371437 100644 --- a/dot/core/helpers_test.go +++ b/dot/core/helpers_test.go @@ -57,7 +57,7 @@ func createTestService(t *testing.T, genesisFilePath string, require.NoError(t, err) genesisHeader := &types.Header{ - StateRoot: genesisTrie.MustHash(), + StateRoot: trie.V0.MustHash(genesisTrie), Number: 0, } @@ -271,7 +271,7 @@ func newWestendLocalWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := trie.V0.MustHash(genesisTrie) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/core/service.go b/dot/core/service.go index 07204a59e8..a7dc445a3e 100644 --- a/dot/core/service.go +++ b/dot/core/service.go @@ -20,6 +20,7 @@ import ( rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero" "github.com/ChainSafe/gossamer/lib/transaction" + "github.com/ChainSafe/gossamer/lib/trie" cscale "github.com/centrifuge/go-substrate-rpc-client/v4/scale" ctypes "github.com/centrifuge/go-substrate-rpc-client/v4/types" @@ -116,6 +117,21 @@ func (s *Service) Stop() error { return nil } +func (s *Service) getCurrentStateTrieVersion() (trie.TrieLayout, error) { + bestBlockHash := s.blockState.BestBlockHash() + rt, err := s.blockState.GetRuntime(bestBlockHash) + if err != nil { + return trie.NoVersion, err + } + + runtimeVersion, err := rt.Version() + if err != nil { + return trie.NoVersion, err + } + + return trie.ParseVersion(runtimeVersion.StateVersion) +} + // StorageRoot returns the hash of the storage root func (s *Service) StorageRoot() (common.Hash, error) { ts, err := s.storageState.TrieState(nil) @@ -123,7 +139,12 @@ func (s *Service) StorageRoot() (common.Hash, error) { return common.Hash{}, err } - return ts.Root() + stateTrieVersion, err := s.getCurrentStateTrieVersion() + if err != nil { + return common.Hash{}, err + } + + return stateTrieVersion.Hash(ts.Trie()) } // HandleBlockImport handles a block that was imported via the network @@ -226,7 +247,7 @@ func (s *Service) handleBlock(block *types.Block, state *rtstorage.TrieState) er } logger.Debugf("imported block %s and stored state trie with root %s", - block.Header.Hash(), state.MustRoot()) + block.Header.Hash(), state.MustRoot(trie.NoMaxInlineValueSize)) parentRuntimeInstance, err := s.blockState.GetRuntime(block.Header.ParentHash) if err != nil { diff --git a/dot/core/service_test.go b/dot/core/service_test.go index 629ee0a089..a21b516779 100644 --- a/dot/core/service_test.go +++ b/dot/core/service_test.go @@ -137,6 +137,7 @@ func Test_Service_StorageRoot(t *testing.T) { retErr error expErr error expErrMsg string + stateVersion uint32 }{ { name: "storage trie state error", @@ -147,12 +148,22 @@ func Test_Service_StorageRoot(t *testing.T) { trieStateCall: true, }, { - name: "storage trie state ok", + name: "storage trie state ok v0", service: &Service{}, exp: common.Hash{0x3, 0x17, 0xa, 0x2e, 0x75, 0x97, 0xb7, 0xb7, 0xe3, 0xd8, 0x4c, 0x5, 0x39, 0x1d, 0x13, 0x9a, 0x62, 0xb1, 0x57, 0xe7, 0x87, 0x86, 0xd8, 0xc0, 0x82, 0xf2, 0x9d, 0xcf, 0x4c, 0x11, 0x13, 0x14}, retTrieState: ts, trieStateCall: true, + stateVersion: 0, + }, + { + name: "storage trie state ok v1", + service: &Service{}, + exp: common.Hash{0x3, 0x17, 0xa, 0x2e, 0x75, 0x97, 0xb7, 0xb7, 0xe3, 0xd8, 0x4c, 0x5, 0x39, 0x1d, 0x13, 0x9a, + 0x62, 0xb1, 0x57, 0xe7, 0x87, 0x86, 0xd8, 0xc0, 0x82, 0xf2, 0x9d, 0xcf, 0x4c, 0x11, 0x13, 0x14}, + retTrieState: ts, + trieStateCall: true, + stateVersion: 1, }, } for _, tt := range tests { @@ -164,7 +175,23 @@ func Test_Service_StorageRoot(t *testing.T) { ctrl := gomock.NewController(t) mockStorageState := NewMockStorageState(ctrl) mockStorageState.EXPECT().TrieState(nil).Return(tt.retTrieState, tt.retErr) + service.storageState = mockStorageState + + if tt.retErr == nil { + mockRuntimeVersion := runtime.Version{ + StateVersion: tt.stateVersion, + } + + mockRuntime := NewMockInstance(ctrl) + mockRuntime.EXPECT().Version().Return(mockRuntimeVersion, nil) + + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().BestBlockHash().Return(common.Hash{}) + mockBlockState.EXPECT().GetRuntime(gomock.Any()).Return(mockRuntime, nil) + + service.blockState = mockBlockState + } } res, err := service.StorageRoot() diff --git a/dot/digest/helpers_test.go b/dot/digest/helpers_test.go index 10298488a9..8637ccbaf2 100644 --- a/dot/digest/helpers_test.go +++ b/dot/digest/helpers_test.go @@ -28,7 +28,10 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + + // We are using state trie V0 since we are using the genesis trie where v0 is used + stateRoot := trie.V0.MustHash(genesisTrie) + extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/helpers_test.go b/dot/helpers_test.go index 88a22bbb37..a4539a9299 100644 --- a/dot/helpers_test.go +++ b/dot/helpers_test.go @@ -42,7 +42,10 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + + // We are using state trie V0 since we are using the genesis trie where v0 is used + stateRoot := trie.V0.MustHash(genesisTrie) + extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/import.go b/dot/import.go index c5d9c85f4e..208d648765 100644 --- a/dot/import.go +++ b/dot/import.go @@ -20,7 +20,7 @@ import ( ) // ImportState imports the state in the given files to the database with the given path. -func ImportState(basepath, stateFP, headerFP string, firstSlot uint64) error { +func ImportState(basepath, stateFP, headerFP string, stateTrieVersion trie.TrieLayout, firstSlot uint64) error { tr, err := newTrieFromPairs(stateFP) if err != nil { return err @@ -38,7 +38,7 @@ func ImportState(basepath, stateFP, headerFP string, firstSlot uint64) error { LogLevel: log.Info, } srv := state.NewService(config) - return srv.Import(header, tr, firstSlot) + return srv.Import(header, tr, stateTrieVersion, firstSlot) } func newTrieFromPairs(filename string) (*trie.Trie, error) { diff --git a/dot/import_integration_test.go b/dot/import_integration_test.go index 534f50b4cf..436e4c2204 100644 --- a/dot/import_integration_test.go +++ b/dot/import_integration_test.go @@ -13,6 +13,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,10 +23,11 @@ func Test_newTrieFromPairs(t *testing.T) { t.Parallel() tests := []struct { - name string - filename string - want common.Hash - err error + name string + filename string + want common.Hash + stateVersion trie.TrieLayout + err error }{ { name: "no_arguments", @@ -33,9 +35,16 @@ func Test_newTrieFromPairs(t *testing.T) { want: common.Hash{}, }, { - name: "working example", - filename: setupStateFile(t), - want: common.MustHexToHash("0x09f9ca28df0560c2291aa16b56e15e07d1e1927088f51356d522722aa90ca7cb"), + name: "working example", + filename: setupStateFile(t), + want: common.MustHexToHash("0x09f9ca28df0560c2291aa16b56e15e07d1e1927088f51356d522722aa90ca7cb"), + stateVersion: trie.V0, + }, + { + name: "working example", + filename: setupStateFile(t), + want: common.MustHexToHash("0xcc25fe024a58297658e576e2e4c33691fe3a9fe5a7cdd2e55534164a0fcc0782"), + stateVersion: trie.V1, }, } for _, tt := range tests { @@ -52,7 +61,7 @@ func Test_newTrieFromPairs(t *testing.T) { if tt.want.IsEmpty() { assert.Nil(t, got) } else { - assert.Equal(t, tt.want, got.MustHash()) + assert.Equal(t, tt.want, tt.stateVersion.MustHash(*got)) } }) } @@ -93,7 +102,7 @@ func TestImportState_Integration(t *testing.T) { headerFP := setupHeaderFile(t) const firstSlot = uint64(262493679) - err = ImportState(config.BasePath, stateFP, headerFP, firstSlot) + err = ImportState(config.BasePath, stateFP, headerFP, trie.V0, firstSlot) require.NoError(t, err) // confirm data is imported into db stateConfig := state.Config{ @@ -124,10 +133,11 @@ func TestImportState(t *testing.T) { headerFP := setupHeaderFile(t) type args struct { - basepath string - stateFP string - headerFP string - firstSlot uint64 + basepath string + stateFP string + headerFP string + stateVersion trie.TrieLayout + firstSlot uint64 } tests := []struct { name string @@ -141,10 +151,11 @@ func TestImportState(t *testing.T) { { name: "working_example", args: args{ - basepath: config.BasePath, - stateFP: stateFP, - headerFP: headerFP, - firstSlot: 262493679, + basepath: config.BasePath, + stateFP: stateFP, + headerFP: headerFP, + stateVersion: trie.V0, + firstSlot: 262493679, }, }, } @@ -153,7 +164,7 @@ func TestImportState(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - err := ImportState(tt.args.basepath, tt.args.stateFP, tt.args.headerFP, tt.args.firstSlot) + err := ImportState(tt.args.basepath, tt.args.stateFP, tt.args.headerFP, tt.args.stateVersion, tt.args.firstSlot) if tt.err != nil { assert.EqualError(t, err, tt.err.Error()) } else { diff --git a/dot/node_integration_test.go b/dot/node_integration_test.go index 19e048b92e..aad81b1331 100644 --- a/dot/node_integration_test.go +++ b/dot/node_integration_test.go @@ -388,7 +388,7 @@ func TestInitNode_LoadStorageRoot(t *testing.T) { expected, err := trie.LoadFromMap(gen.GenesisFields().Raw["top"]) require.NoError(t, err) - expectedRoot, err := expected.Hash() + expectedRoot, err := trie.V0.Hash(&expected) // Since we are using a runtime with state trie V0 require.NoError(t, err) coreServiceInterface := node.ServiceRegistry.Get(&core.Service{}) diff --git a/dot/rpc/helpers_test.go b/dot/rpc/helpers_test.go index 0e981800d2..af7835518a 100644 --- a/dot/rpc/helpers_test.go +++ b/dot/rpc/helpers_test.go @@ -28,7 +28,7 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := trie.V0.MustHash(genesisTrie) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/rpc/modules/childstate_integration_test.go b/dot/rpc/modules/childstate_integration_test.go index c9e7481051..74c7ae53e0 100644 --- a/dot/rpc/modules/childstate_integration_test.go +++ b/dot/rpc/modules/childstate_integration_test.go @@ -255,7 +255,7 @@ func setupChildStateStorage(t *testing.T) (*ChildStateModule, common.Hash) { err = tr.SetChild([]byte(":child_storage_key"), childTr) require.NoError(t, err) - stateRoot, err := tr.Root() + stateRoot, err := tr.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) bb, err := st.Block.BestBlock() diff --git a/dot/rpc/modules/childstate_test.go b/dot/rpc/modules/childstate_test.go index c84df630ce..9fe353fe49 100644 --- a/dot/rpc/modules/childstate_test.go +++ b/dot/rpc/modules/childstate_test.go @@ -35,7 +35,7 @@ func createTestTrieState(t *testing.T) (*trie.Trie, common.Hash) { err := tr.SetChild([]byte(":child_storage_key"), childTr) require.NoError(t, err) - stateRoot, err := tr.Root() + stateRoot, err := tr.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) return tr.Trie(), stateRoot diff --git a/dot/rpc/modules/helpers_test.go b/dot/rpc/modules/helpers_test.go index fad1310ce0..bce60ba019 100644 --- a/dot/rpc/modules/helpers_test.go +++ b/dot/rpc/modules/helpers_test.go @@ -35,7 +35,7 @@ func newWestendLocalGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := genesisTrie.MustHash(trie.NoMaxInlineValueSize) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/rpc/modules/state_integration_test.go b/dot/rpc/modules/state_integration_test.go index 04cd189554..7347aab6c5 100644 --- a/dot/rpc/modules/state_integration_test.go +++ b/dot/rpc/modules/state_integration_test.go @@ -17,6 +17,7 @@ import ( "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -574,7 +575,7 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { err = ts.SetChildStorage([]byte(`:child1`), []byte(`:key1`), []byte(`:childValue1`)) require.NoError(t, err) - sr1, err := ts.Root() + sr1, err := ts.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) err = chain.Storage.StoreTrie(ts, nil) require.NoError(t, err) diff --git a/dot/rpc/modules/system_integration_test.go b/dot/rpc/modules/system_integration_test.go index 453ac8bb72..ad9ae3ae8c 100644 --- a/dot/rpc/modules/system_integration_test.go +++ b/dot/rpc/modules/system_integration_test.go @@ -330,7 +330,7 @@ func setupSystemModule(t *testing.T) *SystemModule { Header: types.Header{ Number: 3, ParentHash: chain.Block.BestBlockHash(), - StateRoot: ts.MustRoot(), + StateRoot: ts.MustRoot(trie.NoMaxInlineValueSize), Digest: digest, }, Body: types.Body{}, diff --git a/dot/state/base_test.go b/dot/state/base_test.go index a6bedcde74..637ee44bb9 100644 --- a/dot/state/base_test.go +++ b/dot/state/base_test.go @@ -29,15 +29,15 @@ func TestTrie_StoreAndLoadFromDB(t *testing.T) { err := tt.WriteDirty(db) require.NoError(t, err) - encroot, err := tt.Hash() + encroot, err := tt.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) - expected := tt.MustHash() + expected := tt.MustHash(trie.NoMaxInlineValueSize) tt = trie.NewEmptyTrie() err = tt.Load(db, encroot) require.NoError(t, err) - require.Equal(t, expected, tt.MustHash()) + require.Equal(t, expected, tt.MustHash(trie.NoMaxInlineValueSize)) } func TestStoreAndLoadGenesisData(t *testing.T) { diff --git a/dot/state/db_getter_mocks_test.go b/dot/state/db_getter_mocks_test.go deleted file mode 100644 index 9493704a06..0000000000 --- a/dot/state/db_getter_mocks_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ChainSafe/gossamer/lib/trie (interfaces: DBGetter) - -// Package state is a generated GoMock package. -package state - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockDBGetter is a mock of DBGetter interface. -type MockDBGetter struct { - ctrl *gomock.Controller - recorder *MockDBGetterMockRecorder -} - -// MockDBGetterMockRecorder is the mock recorder for MockDBGetter. -type MockDBGetterMockRecorder struct { - mock *MockDBGetter -} - -// NewMockDBGetter creates a new mock instance. -func NewMockDBGetter(ctrl *gomock.Controller) *MockDBGetter { - mock := &MockDBGetter{ctrl: ctrl} - mock.recorder = &MockDBGetterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDBGetter) EXPECT() *MockDBGetterMockRecorder { - return m.recorder -} - -// Get mocks base method. -func (m *MockDBGetter) Get(arg0 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockDBGetterMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDBGetter)(nil).Get), arg0) -} diff --git a/dot/state/helpers_test.go b/dot/state/helpers_test.go index 44abfc4de5..10029e9f01 100644 --- a/dot/state/helpers_test.go +++ b/dot/state/helpers_test.go @@ -102,7 +102,7 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := genesisTrie.MustHash(trie.NoMaxInlineValueSize) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/state/interfaces.go b/dot/state/interfaces.go index f57caec2db..55b3f426fe 100644 --- a/dot/state/interfaces.go +++ b/dot/state/interfaces.go @@ -29,10 +29,11 @@ type GetPutter interface { Putter } -// GetNewBatcher has methods to get values and create a +// GetterPutterNewBatcher has methods to get values and create a // new batch. -type GetNewBatcher interface { +type GetterPutterNewBatcher interface { Getter + Putter NewBatcher } diff --git a/dot/state/mocks_database_test.go b/dot/state/mocks_database_test.go new file mode 100644 index 0000000000..079cb5578f --- /dev/null +++ b/dot/state/mocks_database_test.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/lib/trie/db (interfaces: Database) + +// Package state is a generated GoMock package. +package state + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockDatabase is a mock of Database interface. +type MockDatabase struct { + ctrl *gomock.Controller + recorder *MockDatabaseMockRecorder +} + +// MockDatabaseMockRecorder is the mock recorder for MockDatabase. +type MockDatabaseMockRecorder struct { + mock *MockDatabase +} + +// NewMockDatabase creates a new mock instance. +func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { + mock := &MockDatabase{ctrl: ctrl} + mock.recorder = &MockDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockDatabase) Get(arg0 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockDatabaseMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDatabase)(nil).Get), arg0) +} + +// Put mocks base method. +func (m *MockDatabase) Put(arg0, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put. +func (mr *MockDatabaseMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockDatabase)(nil).Put), arg0, arg1) +} diff --git a/dot/state/mocks_generate_test.go b/dot/state/mocks_generate_test.go index e52212d053..10d9c46059 100644 --- a/dot/state/mocks_generate_test.go +++ b/dot/state/mocks_generate_test.go @@ -7,4 +7,4 @@ package state //go:generate mockgen -destination=mocks_runtime_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/runtime Instance //go:generate mockgen -destination=mock_gauge_test.go -package $GOPACKAGE github.com/prometheus/client_golang/prometheus Gauge //go:generate mockgen -destination=mock_counter_test.go -package $GOPACKAGE github.com/prometheus/client_golang/prometheus Counter -//go:generate mockgen -destination=db_getter_mocks_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/lib/trie DBGetter +//go:generate mockgen -destination=mocks_database_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/db Database diff --git a/dot/state/service.go b/dot/state/service.go index eb17ca25ea..7066664ee4 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -256,7 +256,7 @@ func (s *Service) Stop() error { // Import imports the given state corresponding to the given header and sets the head of the chain // to it. Additionally, it uses the first slot to correctly set the epoch number of the block. -func (s *Service) Import(header *types.Header, t *trie.Trie, firstSlot uint64) error { +func (s *Service) Import(header *types.Header, t *trie.Trie, stateTrieVersion trie.TrieLayout, firstSlot uint64) error { var err error // initialise database using data directory if !s.isMemDB { @@ -301,7 +301,7 @@ func (s *Service) Import(header *types.Header, t *trie.Trie, firstSlot uint64) e return err } - root := t.MustHash() + root := stateTrieVersion.MustHash(*t) if root != header.StateRoot { return fmt.Errorf("trie state root does not equal header state root") } diff --git a/dot/state/service_integration_test.go b/dot/state/service_integration_test.go index d5b51186e7..8fb9bc5de5 100644 --- a/dot/state/service_integration_test.go +++ b/dot/state/service_integration_test.go @@ -90,7 +90,7 @@ func TestService_Initialise(t *testing.T) { require.NoError(t, err) genesisHeaderPtr := types.NewHeader(common.NewHash([]byte{77}), - genTrie.MustHash(), trie.EmptyHash, 0, types.NewDigest()) + genTrie.MustHash(trie.NoMaxInlineValueSize), trie.EmptyHash, 0, types.NewDigest()) err = state.Initialise(&genData, genesisHeaderPtr, genTrieCopy) require.NoError(t, err) @@ -287,7 +287,7 @@ func TestService_PruneStorage(t *testing.T) { copiedTrie := trieState.Trie().DeepCopy() var rootHash common.Hash - rootHash, err = copiedTrie.Hash() + rootHash, err = copiedTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) prunedArr = append(prunedArr, prunedBlock{hash: block.Header.StateRoot, dbKey: rootHash[:]}) @@ -400,13 +400,13 @@ func TestService_Import(t *testing.T) { require.NoError(t, err) header := &types.Header{ Number: 77, - StateRoot: tr.MustHash(), + StateRoot: tr.MustHash(trie.NoMaxInlineValueSize), Digest: digest, } firstSlot := uint64(100) - err = serv.Import(header, tr, firstSlot) + err = serv.Import(header, tr, trie.V0, firstSlot) require.NoError(t, err) err = serv.Start() @@ -440,7 +440,7 @@ func generateBlockWithRandomTrie(t *testing.T, serv *Service, err = trieState.Put(key, value) require.NoError(t, err) - trieStateRoot, err := trieState.Root() + trieStateRoot, err := trieState.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) if parent == nil { diff --git a/dot/state/storage.go b/dot/state/storage.go index b7104b2c68..ed8f8377a0 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -33,7 +33,7 @@ type StorageState struct { blockState *BlockState tries *Tries - db GetNewBatcher + db GetterPutterNewBatcher sync.RWMutex // change notifiers @@ -59,7 +59,7 @@ func NewStorageState(db database.Database, blockState *BlockState, // StoreTrie stores the given trie in the StorageState and writes it to the database func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { - root := ts.MustRoot() + root := ts.MustRoot(trie.NoMaxInlineValueSize) s.tries.softSet(root, ts.Trie()) @@ -106,7 +106,7 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error } s.tries.softSet(*root, t) - } else if t.MustHash() != *root { + } else if t.MustHash(trie.NoMaxInlineValueSize) != *root { panic("trie does not have expected root") } @@ -119,13 +119,13 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error // LoadFromDB loads an encoded trie from the DB where the key is `root` func (s *StorageState) LoadFromDB(root common.Hash) (*trie.Trie, error) { - t := trie.NewEmptyTrie() + t := trie.NewTrie(nil, s.db) err := t.Load(s.db, root) if err != nil { return nil, err } - s.tries.softSet(t.MustHash(), t) + s.tries.softSet(t.MustHash(trie.NoMaxInlineValueSize), t) return t, nil } diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index d77cc03e5e..a654e3ea82 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -35,7 +35,7 @@ func TestStorage_StoreAndLoadTrie(t *testing.T) { ts, err := storage.TrieState(&trie.EmptyHash) require.NoError(t, err) - root, err := ts.Root() + root, err := ts.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) err = storage.StoreTrie(ts, nil) require.NoError(t, err) @@ -46,7 +46,8 @@ func TestStorage_StoreAndLoadTrie(t *testing.T) { require.NoError(t, err) ts2 := runtime.NewTrieState(trie) newSnapshot := ts2.Snapshot() - require.Equal(t, ts.Trie(), newSnapshot) + + require.True(t, ts.Trie().Equal(newSnapshot)) } func TestStorage_GetStorageByBlockHash(t *testing.T) { @@ -58,7 +59,7 @@ func TestStorage_GetStorageByBlockHash(t *testing.T) { value := []byte("testvalue") ts.Put(key, value) - root, err := ts.Root() + root, err := ts.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) err = storage.StoreTrie(ts, nil) require.NoError(t, err) @@ -90,7 +91,7 @@ func TestStorage_TrieState(t *testing.T) { require.NoError(t, err) ts.Put([]byte("noot"), []byte("washere")) - root, err := ts.Root() + root, err := ts.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) err = storage.StoreTrie(ts, nil) require.NoError(t, err) @@ -101,7 +102,7 @@ func TestStorage_TrieState(t *testing.T) { storage.blockState.tries.delete(root) ts3, err := storage.TrieState(&root) require.NoError(t, err) - require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash()) + require.Equal(t, ts.Trie().MustHash(trie.NoMaxInlineValueSize), ts3.Trie().MustHash(trie.NoMaxInlineValueSize)) } func TestStorage_LoadFromDB(t *testing.T) { @@ -117,13 +118,14 @@ func TestStorage_LoadFromDB(t *testing.T) { {[]byte("key1"), []byte("value1")}, {[]byte("key2"), []byte("value2")}, {[]byte("xyzKey1"), []byte("xyzValue1")}, + {[]byte("long"), []byte("newvaluewithmorethan32byteslength")}, } for _, kv := range trieKV { ts.Put(kv.key, kv.value) } - root, err := ts.Root() + root, err := ts.Root(trie.NoMaxInlineValueSize) require.NoError(t, err) // Write trie to disk. @@ -147,7 +149,7 @@ func TestStorage_LoadFromDB(t *testing.T) { entries, err := storage.Entries(&root) require.NoError(t, err) - require.Equal(t, 4, len(entries)) + require.Equal(t, 5, len(entries)) } func TestStorage_StoreTrie_NotSyncing(t *testing.T) { @@ -178,15 +180,15 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { genHeader.Hash(), "0", )) - dbGetter := NewMockDBGetter(ctrl) - dbGetter.EXPECT().Get(gomock.Any()).Times(0) + trieDB := NewMockDatabase(ctrl) + trieDB.EXPECT().Get(gomock.Any()).Times(0) trieRoot := &node.Node{ PartialKey: []byte{1, 2}, StorageValue: []byte{3, 4}, Dirty: true, } - testChildTrie := trie.NewTrie(trieRoot, dbGetter) + testChildTrie := trie.NewTrie(trieRoot, trieDB) testChildTrie.Put([]byte("keyInsidechild"), []byte("voila")) @@ -203,13 +205,13 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { trieState := runtime.NewTrieState(&genTrie) - header := types.NewHeader(blockState.GenesisHash(), trieState.MustRoot(), + header := types.NewHeader(blockState.GenesisHash(), trieState.MustRoot(trie.NoMaxInlineValueSize), common.Hash{}, 1, types.NewDigest()) err = storage.StoreTrie(trieState, header) require.NoError(t, err) - rootHash, err := genTrie.Hash() + rootHash, err := genTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) _, err = storage.GetStorageChild(&rootHash, []byte("keyToChild")) diff --git a/dot/state/tries.go b/dot/state/tries.go index 34a8484a7d..d7c87bc625 100644 --- a/dot/state/tries.go +++ b/dot/state/tries.go @@ -59,8 +59,8 @@ func (t *Tries) SetEmptyTrie() { } // SetTrie sets the trie at its root hash in the tries map. -func (t *Tries) SetTrie(trie *trie.Trie) { - t.softSet(trie.MustHash(), trie) +func (t *Tries) SetTrie(tr *trie.Trie) { + t.softSet(tr.MustHash(trie.NoMaxInlineValueSize), tr) } // softSet sets the given trie at the given root hash diff --git a/dot/state/tries_test.go b/dot/state/tries_test.go index f97c4934e9..40ef5626bd 100644 --- a/dot/state/tries_test.go +++ b/dot/state/tries_test.go @@ -49,17 +49,17 @@ func Test_Tries_SetEmptyTrie(t *testing.T) { func Test_Tries_SetTrie(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - dbGetter := NewMockDBGetter(ctrl) - dbGetter.EXPECT().Get(gomock.Any()).Times(0) + db := NewMockDatabase(ctrl) + db.EXPECT().Get(gomock.Any()).Times(0) - tr := trie.NewTrie(&node.Node{PartialKey: []byte{1}}, dbGetter) + tr := trie.NewTrie(&node.Node{PartialKey: []byte{1}}, db) tries := NewTries() tries.SetTrie(tr) expectedTries := &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ - tr.MustHash(): tr, + tr.MustHash(trie.NoMaxInlineValueSize): tr, }, triesGauge: triesGauge, setCounter: setCounter, @@ -192,8 +192,8 @@ func Test_Tries_delete(t *testing.T) { func Test_Tries_get(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - dbGetter := NewMockDBGetter(ctrl) - dbGetter.EXPECT().Get(gomock.Any()).Times(0) + db := NewMockDatabase(ctrl) + db.EXPECT().Get(gomock.Any()).Times(0) testCases := map[string]struct { tries *Tries @@ -206,14 +206,14 @@ func Test_Tries_get(t *testing.T) { {1, 2, 3}: trie.NewTrie(&node.Node{ PartialKey: []byte{1, 2, 3}, StorageValue: []byte{1}, - }, dbGetter), + }, db), }, }, root: common.Hash{1, 2, 3}, trie: trie.NewTrie(&node.Node{ PartialKey: []byte{1, 2, 3}, StorageValue: []byte{1}, - }, dbGetter), + }, db), }, "not_found_in_map": { // similar to not found in database diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 6457a72e0b..68da2fa83c 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -26,6 +26,7 @@ import ( "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/common/variadic" + "github.com/ChainSafe/gossamer/lib/trie" ) var _ ChainSync = (*chainSync)(nil) @@ -950,7 +951,7 @@ func (cs *chainSync) handleBlock(block *types.Block, announceImportedBlock bool) return err } - root := ts.MustRoot() + root := ts.MustRoot(trie.NoMaxInlineValueSize) if !bytes.Equal(parent.StateRoot[:], root[:]) { panic("parent state root does not match snapshot state root") } diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index 0344ab43f6..fb932a8c60 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -65,9 +65,10 @@ func Test_chainSync_onBlockAnnounce(t *testing.T) { errTest := errors.New("test error") emptyTrieState := storage.NewTrieState(nil) - block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.MustRoot(), + block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.MustRoot(trie.NoMaxInlineValueSize), common.Hash{}, 1, scale.VaryingDataTypeSlice{}) - block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), emptyTrieState.MustRoot(), + block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), + emptyTrieState.MustRoot(trie.NoMaxInlineValueSize), common.Hash{}, 2, scale.VaryingDataTypeSlice{}) testCases := map[string]struct { @@ -242,9 +243,10 @@ func Test_chainSync_onBlockAnnounceHandshake_tipModeNeedToCatchup(t *testing.T) const somePeer = peer.ID("abc") emptyTrieState := storage.NewTrieState(nil) - block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.MustRoot(), + block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.MustRoot(trie.NoMaxInlineValueSize), common.Hash{}, 1, scale.VaryingDataTypeSlice{}) - block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), emptyTrieState.MustRoot(), + block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), + emptyTrieState.MustRoot(trie.NoMaxInlineValueSize), common.Hash{}, 130, scale.VaryingDataTypeSlice{}) blockStateMock := NewMockBlockState(ctrl) @@ -1249,7 +1251,7 @@ func createSuccesfullBlockResponse(t *testing.T, parentHeader common.Hash, response.BlockData = make([]*types.BlockData, numBlocks) emptyTrieState := storage.NewTrieState(nil) - tsRoot := emptyTrieState.MustRoot() + tsRoot := emptyTrieState.MustRoot(trie.NoMaxInlineValueSize) firstHeader := types.NewHeader(parentHeader, tsRoot, common.Hash{}, uint(startingAt), scale.VaryingDataTypeSlice{}) diff --git a/dot/sync/syncer_integration_test.go b/dot/sync/syncer_integration_test.go index b12ed27363..9333486bce 100644 --- a/dot/sync/syncer_integration_test.go +++ b/dot/sync/syncer_integration_test.go @@ -99,7 +99,7 @@ func newTestSyncer(t *testing.T) *Service { stateSrvc.Block.StoreRuntime(block.Header.Hash(), instance) logger.Debugf("imported block %s and stored state trie with root %s", - block.Header.Hash(), ts.MustRoot()) + block.Header.Hash(), ts.MustRoot(trie.NoMaxInlineValueSize)) return nil }).AnyTimes() cfg.BlockImportHandler = blockImportHandler @@ -137,7 +137,7 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := genesisTrie.MustHash(trie.NoMaxInlineValueSize) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/dot/utils_integration_test.go b/dot/utils_integration_test.go index 7f4e0a67a3..27e0d7c738 100644 --- a/dot/utils_integration_test.go +++ b/dot/utils_integration_test.go @@ -36,13 +36,13 @@ func TestTrieSnapshot(t *testing.T) { newTrie := tri.Snapshot() // Get the Trie root hash for all the 3 tries. - tHash, err := tri.Hash() + tHash, err := tri.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) - dcTrieHash, err := deepCopyTrie.Hash() + dcTrieHash, err := deepCopyTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) - newTrieHash, err := newTrie.Hash() + newTrieHash, err := newTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) // Root hash for the 3 tries should be equal. @@ -54,13 +54,13 @@ func TestTrieSnapshot(t *testing.T) { newTrie.Put(key, value) // Get the updated root hash of all tries. - tHash, err = tri.Hash() + tHash, err = tri.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) - dcTrieHash, err = deepCopyTrie.Hash() + dcTrieHash, err = deepCopyTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) - newTrieHash, err = newTrie.Hash() + newTrieHash, err = newTrie.Hash(trie.NoMaxInlineValueSize) require.NoError(t, err) // Only the current trie should have a different root hash since it is updated. diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 75d00cc6d8..d9b8e7b5a6 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -18,10 +18,10 @@ type encodingAsyncResult struct { err error } -func runEncodeChild(child *Node, index int, +func runEncodeChild(child *Node, index, maxInlineValue int, results chan<- encodingAsyncResult, rateLimit <-chan struct{}) { buffer := bytes.NewBuffer(nil) - err := encodeChild(child, buffer) + err := encodeChild(child, maxInlineValue, buffer) results <- encodingAsyncResult{ index: index, @@ -44,7 +44,7 @@ var parallelEncodingRateLimit = make(chan struct{}, parallelLimit) // goroutines IF they are less than the parallelLimit number of goroutines already // running. This is designed to limit the total number of goroutines in order to // avoid using too much memory on the stack. -func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (err error) { +func encodeChildrenOpportunisticParallel(children []*Node, maxInlineValue int, buffer io.Writer) (err error) { // Buffered channels since children might be encoded in this // goroutine or another one. resultsCh := make(chan encodingAsyncResult, ChildrenCapacity) @@ -56,7 +56,7 @@ func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (er } if child.Kind() == Leaf { - runEncodeChild(child, i, resultsCh, nil) + runEncodeChild(child, i, maxInlineValue, resultsCh, nil) continue } @@ -65,11 +65,11 @@ func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (er case parallelEncodingRateLimit <- struct{}{}: // We have a goroutine available to encode // the branch in parallel. - go runEncodeChild(child, i, resultsCh, parallelEncodingRateLimit) + go runEncodeChild(child, i, maxInlineValue, resultsCh, parallelEncodingRateLimit) default: // we reached the maximum parallel goroutines // so encode this branch in this goroutine - runEncodeChild(child, i, resultsCh, nil) + runEncodeChild(child, i, maxInlineValue, resultsCh, nil) } } @@ -116,24 +116,10 @@ func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (er return err } -func encodeChildrenSequentially(children []*Node, buffer io.Writer) (err error) { - for i, child := range children { - if child == nil { - continue - } - - err = encodeChild(child, buffer) - if err != nil { - return fmt.Errorf("encoding child at index %d: %w", i, err) - } - } - return nil -} - // encodeChild computes the Merkle value of the node // and then SCALE encodes it to the given buffer. -func encodeChild(child *Node, buffer io.Writer) (err error) { - merkleValue, err := child.CalculateMerkleValue() +func encodeChild(child *Node, maxInlineValue int, buffer io.Writer) (err error) { + merkleValue, err := child.CalculateMerkleValue(maxInlineValue) if err != nil { return fmt.Errorf("computing %s Merkle value: %w", child.Kind(), err) } diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index 7ef2ec4c0f..e21b315423 100644 --- a/internal/trie/node/branch_encode_test.go +++ b/internal/trie/node/branch_encode_test.go @@ -23,7 +23,7 @@ func Benchmark_encodeChildrenOpportunisticParallel(b *testing.B) { b.Run("", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = encodeChildrenOpportunisticParallel(children, io.Discard) + _ = encodeChildrenOpportunisticParallel(children, NoMaxInlineValueSize, io.Discard) } }) } @@ -139,7 +139,7 @@ func Test_encodeChildrenOpportunisticParallel(t *testing.T) { previousCall = call } - err := encodeChildrenOpportunisticParallel(testCase.children, buffer) + err := encodeChildrenOpportunisticParallel(testCase.children, NoMaxInlineValueSize, buffer) if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr) @@ -164,7 +164,7 @@ func Test_encodeChildrenOpportunisticParallel(t *testing.T) { // Note this may run in parallel or not depending on other tests // running in parallel. - err := encodeChildrenOpportunisticParallel(children, buffer) + err := encodeChildrenOpportunisticParallel(children, NoMaxInlineValueSize, buffer) require.NoError(t, err) expectedBytes := []byte{ @@ -180,100 +180,6 @@ func Test_encodeChildrenOpportunisticParallel(t *testing.T) { }) } -func Test_encodeChildrenSequentially(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - children []*Node - writes []writeCall - wrappedErr error - errMessage string - }{ - "no_children": {}, - "first_child_not_nil": { - children: []*Node{ - {PartialKey: []byte{1}, StorageValue: []byte{2}}, - }, - writes: []writeCall{ - {written: []byte{16}}, - {written: []byte{65, 1, 4, 2}}, - }, - }, - "last_child_not_nil": { - children: []*Node{ - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - {PartialKey: []byte{1}, StorageValue: []byte{2}}, - }, - writes: []writeCall{ - {written: []byte{16}}, - {written: []byte{65, 1, 4, 2}}, - }, - }, - "first_two_children_not_nil": { - children: []*Node{ - {PartialKey: []byte{1}, StorageValue: []byte{2}}, - {PartialKey: []byte{3}, StorageValue: []byte{4}}, - }, - writes: []writeCall{ - {written: []byte{16}}, - {written: []byte{65, 1, 4, 2}}, - {written: []byte{16}}, - {written: []byte{65, 3, 4, 4}}, - }, - }, - "encoding_error": { - children: []*Node{ - nil, nil, nil, nil, - nil, nil, nil, nil, - nil, nil, nil, - {PartialKey: []byte{1}, StorageValue: []byte{2}}, - nil, nil, nil, nil, - }, - writes: []writeCall{ - { - written: []byte{16}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "encoding child at index 11: " + - "scale encoding Merkle value: test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeChildrenSequentially(testCase.children, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - func Test_encodeChild(t *testing.T) { t.Parallel() @@ -349,7 +255,7 @@ func Test_encodeChild(t *testing.T) { previousCall = call } - err := encodeChild(testCase.child, buffer) + err := encodeChild(testCase.child, NoMaxInlineValueSize, buffer) if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr) diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index c8f0d8f001..f622419456 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -90,7 +90,7 @@ func (n *Node) Copy(settings CopySettings) *Node { if settings.CopyStorageValue && n.StorageValue != nil { cpy.StorageValue = make([]byte, len(n.StorageValue)) copy(cpy.StorageValue, n.StorageValue) - cpy.HashedValue = n.HashedValue + cpy.IsHashedValue = n.IsHashedValue } if settings.CopyMerkleValue { diff --git a/internal/trie/node/copy_test.go b/internal/trie/node/copy_test.go index c9a2268cf2..cf080b4ca8 100644 --- a/internal/trie/node/copy_test.go +++ b/internal/trie/node/copy_test.go @@ -112,14 +112,14 @@ func Test_Node_Copy(t *testing.T) { }, "deep_copy_branch_with_hashed_values": { node: &Node{ - PartialKey: []byte{1, 2}, - StorageValue: []byte{3, 4}, - HashedValue: true, + PartialKey: []byte{1, 2}, + StorageValue: []byte{3, 4}, + IsHashedValue: true, Children: padRightChildren([]*Node{ nil, nil, { - PartialKey: []byte{9}, - StorageValue: []byte{1}, - HashedValue: true, + PartialKey: []byte{9}, + StorageValue: []byte{1}, + IsHashedValue: true, }, }), Dirty: true, @@ -127,14 +127,14 @@ func Test_Node_Copy(t *testing.T) { }, settings: DeepCopySettings, expectedNode: &Node{ - PartialKey: []byte{1, 2}, - StorageValue: []byte{3, 4}, - HashedValue: true, + PartialKey: []byte{1, 2}, + StorageValue: []byte{3, 4}, + IsHashedValue: true, Children: padRightChildren([]*Node{ nil, nil, { - PartialKey: []byte{9}, - StorageValue: []byte{1}, - HashedValue: true, + PartialKey: []byte{9}, + StorageValue: []byte{1}, + IsHashedValue: true, }, }), Dirty: true, @@ -172,19 +172,19 @@ func Test_Node_Copy(t *testing.T) { }, "deep_copy_leaf_with_hashed_value": { node: &Node{ - PartialKey: []byte{1, 2}, - StorageValue: []byte{3, 4}, - HashedValue: true, - Dirty: true, - MerkleValue: []byte{5}, + PartialKey: []byte{1, 2}, + StorageValue: []byte{3, 4}, + IsHashedValue: true, + Dirty: true, + MerkleValue: []byte{5}, }, settings: DeepCopySettings, expectedNode: &Node{ - PartialKey: []byte{1, 2}, - StorageValue: []byte{3, 4}, - HashedValue: true, - Dirty: true, - MerkleValue: []byte{5}, + PartialKey: []byte{1, 2}, + StorageValue: []byte{3, 4}, + IsHashedValue: true, + Dirty: true, + MerkleValue: []byte{5}, }, }, } diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 1e2cf5c614..2c7c22f553 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -96,7 +96,7 @@ func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) ( return nil, err } node.StorageValue = hashedValue - node.HashedValue = true + node.IsHashedValue = true default: // Ignored } @@ -150,7 +150,7 @@ func decodeLeaf(reader io.Reader, variant variant, partialKeyLength uint16) (nod return nil, err } node.StorageValue = hashedValue - node.HashedValue = true + node.IsHashedValue = true return node, nil } diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 56983161b7..49066d641e 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -109,9 +109,9 @@ func Test_Decode(t *testing.T) { hashedValue.ToBytes(), })), n: &Node{ - PartialKey: []byte{9}, - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + PartialKey: []byte{9}, + StorageValue: hashedValue.ToBytes(), + IsHashedValue: true, }, }, "leaf_with_hashed_value_fail_too_short": { @@ -131,10 +131,10 @@ func Test_Decode(t *testing.T) { hashedValue.ToBytes(), })), n: &Node{ - PartialKey: []byte{9}, - Children: make([]*Node, ChildrenCapacity), - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + PartialKey: []byte{9}, + Children: make([]*Node, ChildrenCapacity), + StorageValue: hashedValue.ToBytes(), + IsHashedValue: true, }, }, "branch_with_hashed_value_fail_too_short": { diff --git a/internal/trie/node/encode.go b/internal/trie/node/encode.go index 90eb7b7d94..1e5810ea34 100644 --- a/internal/trie/node/encode.go +++ b/internal/trie/node/encode.go @@ -4,7 +4,6 @@ package node import ( - "errors" "fmt" "github.com/ChainSafe/gossamer/internal/trie/codec" @@ -12,14 +11,12 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -var ErrEncodeHashedValueTooShort = errors.New("hashed storage value too short") - // Encode encodes the node to the buffer given. // The encoding format is documented in the README.md // of this package, and specified in the Polkadot spec at // https://spec.polkadot.network/#sect-state-storage -func (n *Node) Encode(buffer Buffer) (err error) { - err = encodeHeader(n, buffer) +func (n *Node) Encode(buffer Buffer, maxInlineValue int) (err error) { + err = encodeHeader(n, maxInlineValue, buffer) if err != nil { return fmt.Errorf("cannot encode header: %w", err) } @@ -49,11 +46,12 @@ func (n *Node) Encode(buffer Buffer) (err error) { // even if it is empty. Do not encode if the branch is without value. // Note leaves and branches with value cannot have a `nil` storage value. if n.StorageValue != nil { - if n.HashedValue { - if len(n.StorageValue) != common.HashLength { - return fmt.Errorf("%w: expected %d, got: %d", ErrEncodeHashedValueTooShort, common.HashLength, len(n.StorageValue)) + if len(n.StorageValue) > maxInlineValue { + hashedValue, err := common.Blake2bHash(n.StorageValue) + if err != nil { + return fmt.Errorf("hashing storage value: %w", err) } - _, err := buffer.Write(n.StorageValue) + _, err = buffer.Write(hashedValue.ToBytes()) if err != nil { return fmt.Errorf("encoding hashed storage value: %w", err) } @@ -67,7 +65,7 @@ func (n *Node) Encode(buffer Buffer) (err error) { } if nodeIsBranch { - err = encodeChildrenOpportunisticParallel(n.Children, buffer) + err = encodeChildrenOpportunisticParallel(n.Children, maxInlineValue, buffer) if err != nil { return fmt.Errorf("cannot encode children of branch: %w", err) } diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index 71974ff907..8eed629bae 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -118,7 +118,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { buffer := bytes.NewBuffer(nil) - err := testCase.branchToEncode.Encode(buffer) + err := testCase.branchToEncode.Encode(buffer, NoMaxInlineValueSize) require.NoError(t, err) nodeVariant, partialKeyLength, err := decodeHeader(buffer) diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index 53707e5a3f..aea9d02196 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -5,6 +5,7 @@ package node import ( "errors" + "math" "testing" "github.com/ChainSafe/gossamer/lib/common" @@ -13,6 +14,8 @@ import ( "github.com/stretchr/testify/require" ) +const NoMaxInlineValueSize = math.MaxInt + type writeCall struct { written []byte n int // number of bytes @@ -24,17 +27,19 @@ var errTest = errors.New("test error") func Test_Node_Encode(t *testing.T) { t.Parallel() - hashedValue, err := common.Blake2bHash([]byte("test")) - assert.NoError(t, err) + largeValue := []byte("newvaluewithmorethan32byteslength") + hashedLargeValue := common.MustBlake2bHash(largeValue).ToBytes() testCases := map[string]struct { - node *Node - writes []writeCall - wrappedErr error - errMessage string + node *Node + maxInlineValueSize int + writes []writeCall + wrappedErr error + errMessage string }{ "nil_node": { - node: nil, + node: nil, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { written: []byte{emptyVariant.bits}, @@ -45,6 +50,7 @@ func Test_Node_Encode(t *testing.T) { node: &Node{ PartialKey: make([]byte, 1), }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { written: []byte{leafVariant.bits | 1}, @@ -59,6 +65,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: []byte{1}, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { written: []byte{leafVariant.bits | 3}, // partial key length 3 @@ -76,6 +83,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: []byte{4, 5, 6}, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { written: []byte{leafVariant.bits | 3}, // partial key length 3 @@ -96,6 +104,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: []byte{4, 5, 6}, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { written: []byte{leafVariant.bits | 3}, // partial key length 3 @@ -110,6 +119,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: []byte{}, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ {written: []byte{leafVariant.bits | 3}}, // partial key length 3 {written: []byte{0x01, 0x23}}, // partial key @@ -117,26 +127,27 @@ func Test_Node_Encode(t *testing.T) { {written: []byte{}}, // node storage value }, }, - "leaf_with_hashed_value_success": { + "leaf_with_value_gt_max_success": { node: &Node{ PartialKey: []byte{1, 2, 3}, - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + StorageValue: largeValue, }, + maxInlineValueSize: 32, writes: []writeCall{ { written: []byte{leafWithHashedValueVariant.bits | 3}, }, {written: []byte{0x01, 0x23}}, - {written: hashedValue.ToBytes()}, + {written: hashedLargeValue}, }, }, - "leaf_with_hashed_value_fail": { + "leaf_with_value_gt_max_fail": { node: &Node{ - PartialKey: []byte{1, 2, 3}, - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + PartialKey: []byte{1, 2, 3}, + StorageValue: largeValue, + IsHashedValue: true, }, + maxInlineValueSize: 32, writes: []writeCall{ { written: []byte{leafWithHashedValueVariant.bits | 3}, @@ -145,28 +156,13 @@ func Test_Node_Encode(t *testing.T) { written: []byte{0x01, 0x23}, }, { - written: hashedValue.ToBytes(), + written: hashedLargeValue, err: errTest, }, }, wrappedErr: errTest, errMessage: "encoding hashed storage value: test error", }, - "leaf_with_hashed_value_fail_too_short": { - node: &Node{ - PartialKey: []byte{1, 2, 3}, - StorageValue: []byte("tooshort"), - HashedValue: true, - }, - writes: []writeCall{ - { - written: []byte{leafWithHashedValueVariant.bits | 3}, - }, - {written: []byte{0x01, 0x23}}, - }, - wrappedErr: ErrEncodeHashedValueTooShort, - errMessage: "hashed storage value too short: expected 32, got: 8", - }, "branch_header_encoding_error": { node: &Node{ Children: make([]*Node, ChildrenCapacity), @@ -187,6 +183,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: []byte{100}, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 @@ -208,6 +205,7 @@ func Test_Node_Encode(t *testing.T) { nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 @@ -232,6 +230,7 @@ func Test_Node_Encode(t *testing.T) { nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 @@ -259,6 +258,7 @@ func Test_Node_Encode(t *testing.T) { nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 @@ -291,6 +291,7 @@ func Test_Node_Encode(t *testing.T) { nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 @@ -320,6 +321,7 @@ func Test_Node_Encode(t *testing.T) { nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: NoMaxInlineValueSize, writes: []writeCall{ { // header written: []byte{branchVariant.bits | 3}, // partial key length 3 @@ -338,16 +340,16 @@ func Test_Node_Encode(t *testing.T) { }, }, }, - "branch_with_hashed_value_success": { + "branch_with_value_gt_max_success": { node: &Node{ PartialKey: []byte{1, 2, 3}, - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + StorageValue: largeValue, Children: []*Node{ nil, nil, nil, {PartialKey: []byte{9}, StorageValue: []byte{1}}, nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, }, }, + maxInlineValueSize: 32, writes: []writeCall{ { // header written: []byte{branchWithHashedValueVariant.bits | 3}, // partial key length 3 @@ -359,7 +361,7 @@ func Test_Node_Encode(t *testing.T) { written: []byte{136, 0}, }, { - written: hashedValue.ToBytes(), + written: hashedLargeValue, }, { // first children written: []byte{16, 65, 9, 4, 1}, @@ -369,30 +371,6 @@ func Test_Node_Encode(t *testing.T) { }, }, }, - "branch_with_hashed_value_fail_too_short": { - node: &Node{ - PartialKey: []byte{1, 2, 3}, - StorageValue: []byte("tooshort"), - HashedValue: true, - Children: []*Node{ - nil, nil, nil, {PartialKey: []byte{9}, StorageValue: []byte{1}}, - nil, nil, nil, {PartialKey: []byte{11}, StorageValue: []byte{1}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{branchWithHashedValueVariant.bits | 3}, // partial key length 3 - }, - { // key LE - written: []byte{0x01, 0x23}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - }, - wrappedErr: ErrEncodeHashedValueTooShort, - errMessage: "hashed storage value too short: expected 32, got: 8", - }, } for name, testCase := range testCases { @@ -414,7 +392,7 @@ func Test_Node_Encode(t *testing.T) { previousCall = call } - err := testCase.node.Encode(buffer) + err := testCase.node.Encode(buffer, testCase.maxInlineValueSize) if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr) diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 595e06fc0f..1dfbca098d 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -55,12 +55,12 @@ func hashEncoding(encoding []byte, writer io.Writer) (err error) { } // CalculateMerkleValue returns the Merkle value of the non-root node. -func (n *Node) CalculateMerkleValue() (merkleValue []byte, err error) { +func (n *Node) CalculateMerkleValue(maxInlineValue int) (merkleValue []byte, err error) { if !n.Dirty && n.MerkleValue != nil { return n.MerkleValue, nil } - _, merkleValue, err = n.EncodeAndHash() + _, merkleValue, err = n.EncodeAndHash(maxInlineValue) if err != nil { return nil, fmt.Errorf("encoding and hashing node: %w", err) } @@ -69,13 +69,13 @@ func (n *Node) CalculateMerkleValue() (merkleValue []byte, err error) { } // CalculateRootMerkleValue returns the Merkle value of the root node. -func (n *Node) CalculateRootMerkleValue() (merkleValue []byte, err error) { +func (n *Node) CalculateRootMerkleValue(maxInlineValue int) (merkleValue []byte, err error) { const rootMerkleValueLength = 32 if !n.Dirty && len(n.MerkleValue) == rootMerkleValueLength { return n.MerkleValue, nil } - _, merkleValue, err = n.EncodeAndHashRoot() + _, merkleValue, err = n.EncodeAndHashRoot(maxInlineValue) if err != nil { return nil, fmt.Errorf("encoding and hashing root node: %w", err) } @@ -89,9 +89,9 @@ func (n *Node) CalculateRootMerkleValue() (merkleValue []byte, err error) { // TODO change this function to write to an encoding writer // and a merkle value writer, such that buffer sync pools can be used // by the caller. -func (n *Node) EncodeAndHash() (encoding, merkleValue []byte, err error) { +func (n *Node) EncodeAndHash(maxInlineValue int) (encoding, merkleValue []byte, err error) { encodingBuffer := bytes.NewBuffer(nil) - err = n.Encode(encodingBuffer) + err = n.Encode(encodingBuffer, maxInlineValue) if err != nil { return nil, nil, fmt.Errorf("encoding node: %w", err) } @@ -115,9 +115,9 @@ func (n *Node) EncodeAndHash() (encoding, merkleValue []byte, err error) { // TODO change this function to write to an encoding writer // and a merkle value writer, such that buffer sync pools can be used // by the caller. -func (n *Node) EncodeAndHashRoot() (encoding, merkleValue []byte, err error) { +func (n *Node) EncodeAndHashRoot(maxInlineValue int) (encoding, merkleValue []byte, err error) { encodingBuffer := bytes.NewBuffer(nil) - err = n.Encode(encodingBuffer) + err = n.Encode(encodingBuffer, maxInlineValue) if err != nil { return nil, nil, fmt.Errorf("encoding node: %w", err) } diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go index 7c19b9d5de..adbe06326e 100644 --- a/internal/trie/node/hash_test.go +++ b/internal/trie/node/hash_test.go @@ -197,7 +197,7 @@ func Test_Node_CalculateMerkleValue(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - merkleValue, err := testCase.node.CalculateMerkleValue() + merkleValue, err := testCase.node.CalculateMerkleValue(NoMaxInlineValueSize) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -259,7 +259,7 @@ func Test_Node_CalculateRootMerkleValue(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - merkleValue, err := testCase.node.CalculateRootMerkleValue() + merkleValue, err := testCase.node.CalculateRootMerkleValue(NoMaxInlineValueSize) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -346,7 +346,7 @@ func Test_Node_EncodeAndHash(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, hash, err := testCase.node.EncodeAndHash() + encoding, hash, err := testCase.node.EncodeAndHash(NoMaxInlineValueSize) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -400,7 +400,7 @@ func Test_Node_EncodeAndHashRoot(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, hash, err := testCase.node.EncodeAndHashRoot() + encoding, hash, err := testCase.node.EncodeAndHashRoot(NoMaxInlineValueSize) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go index 8863b410a8..30f910be5e 100644 --- a/internal/trie/node/header.go +++ b/internal/trie/node/header.go @@ -10,7 +10,7 @@ import ( ) // encodeHeader writes the encoded header for the node. -func encodeHeader(node *Node, writer io.Writer) (err error) { +func encodeHeader(node *Node, maxInlineValueSize int, writer io.Writer) (err error) { if node == nil { _, err = writer.Write([]byte{emptyVariant.bits}) return err @@ -21,17 +21,19 @@ func encodeHeader(node *Node, writer io.Writer) (err error) { panic(fmt.Sprintf("partial key length is too big: %d", partialKeyLength)) } + isHashedValue := len(node.StorageValue) > maxInlineValueSize + // Merge variant byte and partial key length together var nodeVariant variant if node.Kind() == Leaf { - if node.HashedValue { + if isHashedValue { nodeVariant = leafWithHashedValueVariant } else { nodeVariant = leafVariant } } else if node.StorageValue == nil { nodeVariant = branchVariant - } else if node.HashedValue { + } else if isHashedValue { nodeVariant = branchWithHashedValueVariant } else { nodeVariant = branchWithValueVariant diff --git a/internal/trie/node/header_test.go b/internal/trie/node/header_test.go index 483ffb2730..5b57174628 100644 --- a/internal/trie/node/header_test.go +++ b/internal/trie/node/header_test.go @@ -10,7 +10,6 @@ import ( "sort" "testing" - "github.com/ChainSafe/gossamer/lib/common" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,14 +18,14 @@ import ( func Test_encodeHeader(t *testing.T) { t.Parallel() - hashedValue, err := common.Blake2bHash([]byte("test")) - assert.NoError(t, err) + largeValue := []byte("newvaluewithmorethan32byteslength") testCases := map[string]struct { - node *Node - writes []writeCall - errWrapped error - errMessage string + node *Node + writes []writeCall + maxInlineValueSize int + errWrapped error + errMessage string }{ "branch_with_no_key": { node: &Node{ @@ -47,10 +46,10 @@ func Test_encodeHeader(t *testing.T) { }, "branch_with_hashed_value": { node: &Node{ - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + StorageValue: largeValue, Children: make([]*Node, ChildrenCapacity), }, + maxInlineValueSize: 32, writes: []writeCall{ {written: []byte{branchWithHashedValueVariant.bits}}, }, @@ -126,9 +125,9 @@ func Test_encodeHeader(t *testing.T) { }, "leaf_with_hashed_value": { node: &Node{ - StorageValue: hashedValue.ToBytes(), - HashedValue: true, + StorageValue: largeValue, }, + maxInlineValueSize: 32, writes: []writeCall{ {written: []byte{leafWithHashedValueVariant.bits}}, }, @@ -138,6 +137,7 @@ func Test_encodeHeader(t *testing.T) { writes: []writeCall{ {written: []byte{leafVariant.bits}}, }, + maxInlineValueSize: 32, }, "leaf_with_key_of_length_30": { node: &Node{ @@ -240,7 +240,7 @@ func Test_encodeHeader(t *testing.T) { previousCall = call } - err := encodeHeader(testCase.node, writer) + err := encodeHeader(testCase.node, testCase.maxInlineValueSize, writer) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -258,7 +258,7 @@ func Test_encodeHeader(t *testing.T) { } assert.PanicsWithValue(t, "partial key length is too big: 65536", func() { - _ = encodeHeader(node, io.Discard) + _ = encodeHeader(node, 0, io.Discard) }) }) } @@ -293,7 +293,7 @@ func Test_encodeHeader_At_Maximum(t *testing.T) { PartialKey: make([]byte, keyLength), } - err := encodeHeader(node, buffer) + err := encodeHeader(node, NoMaxInlineValueSize, buffer) require.NoError(t, err) assert.Equal(t, expectedBytes, buffer.Bytes()) diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index c5a7c83ce0..4510fabd89 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -7,6 +7,7 @@ package node import ( "fmt" + "strconv" "github.com/qdm12/gotree" ) @@ -16,8 +17,8 @@ type Node struct { // PartialKey is the partial key bytes in nibbles (0 to f in hexadecimal) PartialKey []byte StorageValue []byte - // HashedValue is true when the StorageValue is a blake2b hash - HashedValue bool + // IsHashedValue is true when the StorageValue is a blake2b hash + IsHashedValue bool // Generation is incremented on every trie Snapshot() call. // Each node also contain a certain Generation number, // which is updated to match the trie Generation once they are @@ -57,6 +58,7 @@ func (n *Node) StringNode() (stringNode *gotree.Node) { stringNode.Appendf("Dirty: %t", n.Dirty) stringNode.Appendf("Key: " + bytesToString(n.PartialKey)) stringNode.Appendf("Storage value: " + bytesToString(n.StorageValue)) + stringNode.Appendf("IsHashed: " + strconv.FormatBool(n.IsHashedValue)) if n.Descendants > 0 { // must be a branch stringNode.Appendf("Descendants: %d", n.Descendants) } diff --git a/internal/trie/node/node_test.go b/internal/trie/node/node_test.go index af4f2269b8..d24aec011e 100644 --- a/internal/trie/node/node_test.go +++ b/internal/trie/node/node_test.go @@ -27,6 +27,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0304 +├── IsHashed: false └── Merkle value: nil`, }, "leaf_with_storage_value_higher_than_1024": { @@ -40,6 +41,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0000000000000000...0000000000000000 +├── IsHashed: false └── Merkle value: nil`, }, "branch_with_storage_value_smaller_than_1024": { @@ -66,6 +68,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0304 +├── IsHashed: false ├── Descendants: 3 ├── Merkle value: nil ├── Child 3 @@ -74,6 +77,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | └── Merkle value: nil ├── Child 7 | └── Branch @@ -81,6 +85,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | ├── Descendants: 1 | ├── Merkle value: nil | └── Child 0 @@ -89,6 +94,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | └── Merkle value: nil └── Child 11 └── Leaf @@ -96,6 +102,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: nil + ├── IsHashed: false └── Merkle value: nil`, }, "branch_with_storage_value_higher_than_1024": { @@ -122,6 +129,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0000000000000000...0000000000000000 +├── IsHashed: false ├── Descendants: 3 ├── Merkle value: nil ├── Child 3 @@ -130,6 +138,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | └── Merkle value: nil ├── Child 7 | └── Branch @@ -137,6 +146,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | ├── Descendants: 1 | ├── Merkle value: nil | └── Child 0 @@ -145,6 +155,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── IsHashed: false | └── Merkle value: nil └── Child 11 └── Leaf @@ -152,6 +163,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: nil + ├── IsHashed: false └── Merkle value: nil`, }, } diff --git a/lib/babe/helpers_test.go b/lib/babe/helpers_test.go index d7b5804b2c..0aa192bcb9 100644 --- a/lib/babe/helpers_test.go +++ b/lib/babe/helpers_test.go @@ -375,7 +375,7 @@ func newWestendLocalGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) genesisHeader = *types.NewHeader(common.NewHash([]byte{0}), - genesisTrie.MustHash(), trie.EmptyHash, 0, types.NewDigest()) + genesisTrie.MustHash(trie.NoMaxInlineValueSize), trie.EmptyHash, 0, types.NewDigest()) return gen, genesisTrie, genesisHeader } @@ -394,7 +394,7 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( require.NoError(t, err) genesisHeader = *types.NewHeader(common.NewHash([]byte{0}), - genesisTrie.MustHash(), trie.EmptyHash, 0, types.NewDigest()) + genesisTrie.MustHash(trie.NoMaxInlineValueSize), trie.EmptyHash, 0, types.NewDigest()) return gen, genesisTrie, genesisHeader } diff --git a/lib/common/hash.go b/lib/common/hash.go index 3f2ebaf99d..28798afb46 100644 --- a/lib/common/hash.go +++ b/lib/common/hash.go @@ -18,6 +18,8 @@ const ( HashLength = 32 ) +var EmptyHash = Hash{} + // Hash used to store a blake2b hash type Hash [32]byte @@ -40,7 +42,7 @@ func HashValidator(field reflect.Value) interface{} { // Try to convert to hash type. if valuer, ok := field.Interface().(Hash); ok { // Check if the hash is empty. - if valuer == (Hash{}) { + if valuer == (EmptyHash) { return "" } return valuer.ToBytes() @@ -50,7 +52,7 @@ func HashValidator(field reflect.Value) interface{} { // IsEmpty returns true if the hash is empty, false otherwise. func (h Hash) IsEmpty() bool { //skipcq: GO-W1029 - return h == Hash{} + return h == EmptyHash } // String returns the hex string for the hash @@ -79,7 +81,7 @@ func ReadHash(r io.Reader) (Hash, error) { buf := make([]byte, 32) _, err := r.Read(buf) if err != nil { - return Hash{}, err + return EmptyHash, err } h := [32]byte{} copy(h[:], buf) diff --git a/lib/grandpa/helpers_integration_test.go b/lib/grandpa/helpers_integration_test.go index ed65b1b582..b0105ff546 100644 --- a/lib/grandpa/helpers_integration_test.go +++ b/lib/grandpa/helpers_integration_test.go @@ -203,7 +203,7 @@ func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( assert.NoError(t, err) parentHash := common.NewHash([]byte{0}) - stateRoot := genesisTrie.MustHash() + stateRoot := genesisTrie.MustHash(trie.NoMaxInlineValueSize) extrinsicRoot := trie.EmptyHash const number = 0 digest := types.NewDigest() diff --git a/lib/runtime/interfaces.go b/lib/runtime/interfaces.go index 2f813a21f6..131b23cad4 100644 --- a/lib/runtime/interfaces.go +++ b/lib/runtime/interfaces.go @@ -13,7 +13,7 @@ import ( type Storage interface { Put(key []byte, value []byte) (err error) Get(key []byte) []byte - Root() (common.Hash, error) + Root(maxInlineValueSize int) (common.Hash, error) SetChild(keyToChild []byte, child *trie.Trie) error SetChildStorage(keyToChild, key, value []byte) error GetChildStorage(keyToChild, key []byte) ([]byte, error) diff --git a/lib/runtime/storage/trie.go b/lib/runtime/storage/trie.go index b88e8fb56a..2a21377222 100644 --- a/lib/runtime/storage/trie.go +++ b/lib/runtime/storage/trie.go @@ -84,13 +84,13 @@ func (s *TrieState) Get(key []byte) []byte { } // MustRoot returns the trie's root hash. It panics if it fails to compute the root. -func (s *TrieState) MustRoot() common.Hash { - return s.t.MustHash() +func (s *TrieState) MustRoot(maxInlineValue int) common.Hash { + return s.t.MustHash(maxInlineValue) } // Root returns the trie's root hash -func (s *TrieState) Root() (common.Hash, error) { - return s.t.Hash() +func (s *TrieState) Root(maxInlineValue int) (common.Hash, error) { + return s.t.Hash(maxInlineValue) } // Has returns whether or not a key exists diff --git a/lib/runtime/storage/trie_test.go b/lib/runtime/storage/trie_test.go index 01e0e2f732..e913b89435 100644 --- a/lib/runtime/storage/trie_test.go +++ b/lib/runtime/storage/trie_test.go @@ -40,6 +40,49 @@ func TestTrieState_SetGet(t *testing.T) { testFunc(ts) } +func TestTrieState_SetGetChildStorage(t *testing.T) { + ts := &TrieState{t: trie.NewEmptyTrie()} + + for _, tc := range testCases { + childTrie := trie.NewEmptyTrie() + err := ts.SetChild([]byte(tc), childTrie) + require.NoError(t, err) + + err = ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) + require.NoError(t, err) + } + + for _, tc := range testCases { + res, err := ts.GetChildStorage([]byte(tc), []byte(tc)) + require.NoError(t, err) + require.Equal(t, []byte(tc), res) + } +} + +func TestTrieState_SetAndClearFromChild(t *testing.T) { + testFunc := func(ts *TrieState) { + for _, tc := range testCases { + childTrie := trie.NewEmptyTrie() + err := ts.SetChild([]byte(tc), childTrie) + require.NoError(t, err) + + err = ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) + require.NoError(t, err) + } + + for _, tc := range testCases { + err := ts.ClearChildStorage([]byte(tc), []byte(tc)) + require.NoError(t, err) + + _, err = ts.GetChildStorage([]byte(tc), []byte(tc)) + require.ErrorContains(t, err, "child trie does not exist at key") + } + } + + ts := &TrieState{t: trie.NewEmptyTrie()} + testFunc(ts) +} + func TestTrieState_Delete(t *testing.T) { testFunc := func(ts *TrieState) { for _, tc := range testCases { @@ -61,8 +104,8 @@ func TestTrieState_Root(t *testing.T) { ts.Put([]byte(tc), []byte(tc)) } - expected := ts.MustRoot() - require.Equal(t, expected, ts.MustRoot()) + expected := ts.MustRoot(trie.NoMaxInlineValueSize) + require.Equal(t, expected, ts.MustRoot(trie.NoMaxInlineValueSize)) } ts := &TrieState{t: trie.NewEmptyTrie()} diff --git a/lib/runtime/wazero/imports.go b/lib/runtime/wazero/imports.go index 73a1295ee3..fa4a18fa78 100644 --- a/lib/runtime/wazero/imports.go +++ b/lib/runtime/wazero/imports.go @@ -33,8 +33,9 @@ var ( log.AddContext("module", "wazero"), ) - noneEncoded []byte = []byte{0x00} - allZeroesBytes = [32]byte{} + emptyByteVectorEncoded []byte = scale.MustMarshal([]byte{}) + noneEncoded []byte = []byte{0x00} + allZeroesBytes = [32]byte{} ) const ( @@ -781,33 +782,34 @@ func ext_crypto_finish_batch_verify_version_1(ctx context.Context, m api.Module) } func ext_trie_blake2_256_root_version_1(ctx context.Context, m api.Module, dataSpan uint64) uint32 { + return ext_trie_blake2_256_root_version_2(ctx, m, dataSpan, 0) +} + +func ext_trie_blake2_256_root_version_2(ctx context.Context, m api.Module, dataSpan uint64, version uint32) uint32 { rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) if rtCtx == nil { panic("nil runtime context") } - data := read(m, dataSpan) - - t := trie.NewEmptyTrie() - - type kv struct { - Key, Value []byte + stateVersion, err := trie.ParseVersion(version) + if err != nil { + logger.Errorf("failed parsing state version: %s", err) + return 0 } + data := read(m, dataSpan) + // this function is expecting an array of (key, value) tuples - var kvs []kv - if err := scale.Unmarshal(data, &kvs); err != nil { + var entries trie.Entries + if err := scale.Unmarshal(data, &entries); err != nil { logger.Errorf("failed scale decoding data: %s", err) return 0 } - for _, kv := range kvs { - err := t.Put(kv.Key, kv.Value) - if err != nil { - logger.Errorf("failed putting key 0x%x and value 0x%x into trie: %s", - kv.Key, kv.Value, err) - return 0 - } + hash, err := stateVersion.Root(entries) + if err != nil { + logger.Errorf("failed computing trie Merkle root hash: %s", err) + return 0 } // allocate memory for value and copy value to memory @@ -817,18 +819,17 @@ func ext_trie_blake2_256_root_version_1(ctx context.Context, m api.Module, dataS return 0 } - hash, err := t.Hash() - if err != nil { - logger.Errorf("failed computing trie Merkle root hash: %s", err) - return 0 - } - logger.Debugf("root hash is %s", hash) m.Memory().Write(ptr, hash[:]) return ptr } func ext_trie_blake2_256_ordered_root_version_1(ctx context.Context, m api.Module, dataSpan uint64) uint32 { + return ext_trie_blake2_256_ordered_root_version_2(ctx, m, dataSpan, 0) +} + +func ext_trie_blake2_256_ordered_root_version_2( + ctx context.Context, m api.Module, dataSpan uint64, version uint32) uint32 { rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) if rtCtx == nil { panic("nil runtime context") @@ -836,30 +837,29 @@ func ext_trie_blake2_256_ordered_root_version_1(ctx context.Context, m api.Modul data := read(m, dataSpan) - t := trie.NewEmptyTrie() + stateVersion, err := trie.ParseVersion(version) + if err != nil { + logger.Errorf("failed parsing state version: %s", err) + return 0 + } + var values [][]byte - err := scale.Unmarshal(data, &values) + err = scale.Unmarshal(data, &values) if err != nil { logger.Errorf("failed scale decoding data: %s", err) return 0 } + var entries trie.Entries + for i, value := range values { key, err := scale.Marshal(big.NewInt(int64(i))) if err != nil { logger.Errorf("failed scale encoding value index %d: %s", i, err) return 0 } - logger.Tracef( - "put key=0x%x and value=0x%x", - key, value) - err = t.Put(key, value) - if err != nil { - logger.Errorf("failed putting key 0x%x and value 0x%x into trie: %s", - key, value, err) - return 0 - } + entries = append(entries, trie.Entry{Key: key, Value: value}) } // allocate memory for value and copy value to memory @@ -869,7 +869,7 @@ func ext_trie_blake2_256_ordered_root_version_1(ctx context.Context, m api.Modul return 0 } - hash, err := t.Hash() + hash, err := stateVersion.Root(entries) if err != nil { logger.Errorf("failed computing trie Merkle root hash: %s", err) return 0 @@ -880,12 +880,6 @@ func ext_trie_blake2_256_ordered_root_version_1(ctx context.Context, m api.Modul return ptr } -func ext_trie_blake2_256_ordered_root_version_2( - ctx context.Context, m api.Module, dataSpan uint64, version uint32) uint32 { - // TODO: update to use state trie version 1 (#2418) - return ext_trie_blake2_256_ordered_root_version_1(ctx, m, dataSpan) -} - func ext_trie_blake2_256_verify_proof_version_1( ctx context.Context, m api.Module, rootSpan uint32, proofSpan, keySpan, valueSpan uint64) uint32 { rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) @@ -898,7 +892,45 @@ func ext_trie_blake2_256_verify_proof_version_1( err := scale.Unmarshal(toDecProofs, &encodedProofNodes) if err != nil { logger.Errorf("failed scale decoding proof data: %s", err) - return uint32(0) + return 0 + } + + key := read(m, keySpan) + value := read(m, valueSpan) + + trieRoot, ok := m.Memory().Read(rootSpan, 32) + if !ok { + panic("read overflow") + } + + err = proof.Verify(encodedProofNodes, trieRoot, key, value) + if err != nil { + logger.Errorf("failed proof verification: %s", err) + return 0 + } + + return 1 +} + +func ext_trie_blake2_256_verify_proof_version_2( + ctx context.Context, m api.Module, rootSpan uint32, proofSpan, keySpan, valueSpan uint64, version uint32) uint32 { + rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) + if rtCtx == nil { + panic("nil runtime context") + } + + _, err := trie.ParseVersion(version) + if err != nil { + logger.Errorf("failed parsing state version: %s", err) + return 0 + } + + toDecProofs := read(m, proofSpan) + var encodedProofNodes [][]byte + err = scale.Unmarshal(toDecProofs, &encodedProofNodes) + if err != nil { + logger.Errorf("failed scale decoding proof data: %s", err) + return 0 } key := read(m, keySpan) @@ -1210,7 +1242,7 @@ func ext_default_child_storage_root_version_1( return 0 } - childRoot, err := child.Hash() + childRoot, err := trie.V0.Hash(child) if err != nil { logger.Errorf("failed to encode child root: %s", err) return 0 @@ -1226,9 +1258,37 @@ func ext_default_child_storage_root_version_1( //export ext_default_child_storage_root_version_2 func ext_default_child_storage_root_version_2(ctx context.Context, m api.Module, childStorageKey uint64, - stateVersion uint32) (ptrSize uint64) { - // TODO: Implement this after we have storage trie version 1 implemented #2418 - return ext_default_child_storage_root_version_1(ctx, m, childStorageKey) + version uint32) (ptrSize uint64) { //skipcq: RVV-B0012 + rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) + if rtCtx == nil { + panic("nil runtime context") + } + storage := rtCtx.Storage + key := read(m, childStorageKey) + child, err := storage.GetChild(key) + if err != nil { + logger.Errorf("failed to retrieve child: %s", err) + return mustWrite(m, rtCtx.Allocator, emptyByteVectorEncoded) + } + + stateVersion, err := trie.ParseVersion(version) + if err != nil { + logger.Errorf("failed parsing state version: %s", err) + return 0 + } + + childRoot, err := stateVersion.Hash(child) + if err != nil { + logger.Errorf("failed to encode child root: %s", err) + return mustWrite(m, rtCtx.Allocator, emptyByteVectorEncoded) + } + childRootSlice := childRoot[:] + + ret, err := write(m, rtCtx.Allocator, scale.MustMarshal(&childRootSlice)) + if err != nil { + panic(err) + } + return ret } func ext_default_child_storage_storage_kill_version_1(ctx context.Context, m api.Module, childStorageKeySpan uint64) { @@ -2183,7 +2243,7 @@ func ext_storage_root_version_1(ctx context.Context, m api.Module) uint64 { } storage := rtCtx.Storage - root, err := storage.Root() + root, err := storage.Root(trie.V0.MaxInlineValue()) if err != nil { logger.Errorf("failed to get storage root: %s", err) panic(err) @@ -2200,8 +2260,32 @@ func ext_storage_root_version_1(ctx context.Context, m api.Module) uint64 { } func ext_storage_root_version_2(ctx context.Context, m api.Module, version uint32) uint64 { //skipcq: RVV-B0012 - // TODO: update to use state trie version 1 (#2418) - return ext_storage_root_version_1(ctx, m) + rtCtx := ctx.Value(runtimeContextKey).(*runtime.Context) + if rtCtx == nil { + panic("nil runtime context") + } + storage := rtCtx.Storage + + stateVersion, err := trie.ParseVersion(version) + if err != nil { + logger.Errorf("failed parsing state version: %s", err) + return mustWrite(m, rtCtx.Allocator, emptyByteVectorEncoded) + } + + root, err := storage.Root(stateVersion.MaxInlineValue()) + if err != nil { + logger.Errorf("failed to get storage root: %s", err) + panic(err) + } + + logger.Debugf("root hash is: %s", root) + + rootSpan, err := write(m, rtCtx.Allocator, root[:]) + if err != nil { + logger.Errorf("failed to allocate: %s", err) + panic(err) + } + return rootSpan } func ext_storage_set_version_1(ctx context.Context, m api.Module, keySpan, valueSpan uint64) { diff --git a/lib/runtime/wazero/imports_test.go b/lib/runtime/wazero/imports_test.go index a1a2cc9430..5ecd82b41d 100644 --- a/lib/runtime/wazero/imports_test.go +++ b/lib/runtime/wazero/imports_test.go @@ -554,7 +554,9 @@ func Test_ext_trie_blake2_256_root_version_1(t *testing.T) { require.NoError(t, err) encInput[0] = encInput[0] >> 1 - res, err := inst.Exec("rtm_ext_trie_blake2_256_root_version_1", encInput) + data := encInput + + res, err := inst.Exec("rtm_ext_trie_blake2_256_root_version_1", data) require.NoError(t, err) var hash []byte @@ -565,7 +567,39 @@ func Test_ext_trie_blake2_256_root_version_1(t *testing.T) { tt.Put([]byte("noot"), []byte("was")) tt.Put([]byte("here"), []byte("??")) - expected := tt.MustHash() + expected := tt.MustHash(trie.NoMaxInlineValueSize) + require.Equal(t, expected[:], hash) +} + +func Test_ext_trie_blake2_256_root_version_2(t *testing.T) { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + testinput := []string{"dimartiro", "was", "here", "??"} + encInput, err := scale.Marshal(testinput) + require.NoError(t, err) + encInput[0] = encInput[0] >> 1 + + stateVersion := trie.V1 + + stateVersionInt := uint32(stateVersion) + encVersion, err := scale.Marshal(stateVersionInt) + require.NoError(t, err) + + data := append([]byte{}, encInput...) + data = append(data, encVersion...) + + res, err := inst.Exec("rtm_ext_trie_blake2_256_root_version_2", data) + require.NoError(t, err) + + var hash []byte + err = scale.Unmarshal(res, &hash) + require.NoError(t, err) + + tt := trie.NewEmptyTrie() + tt.Put([]byte("dimartiro"), []byte("was")) + tt.Put([]byte("here"), []byte("??")) + + expected := tt.MustHash(stateVersion.MaxInlineValue()) require.Equal(t, expected[:], hash) } @@ -587,15 +621,43 @@ func Test_ext_trie_blake2_256_ordered_root_version_1(t *testing.T) { require.Equal(t, expected[:], hash) } +func Test_ext_trie_blake2_256_ordered_root_version_2(t *testing.T) { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + testvalues := []string{"static", "even-keeled", "Future-proofed"} + encValues, err := scale.Marshal(testvalues) + require.NoError(t, err) + + stateVersion := uint32(trie.V1) + encVersion, err := scale.Marshal(stateVersion) + require.NoError(t, err) + + data := append([]byte{}, encValues...) + data = append(data, encVersion...) + + res, err := inst.Exec("rtm_ext_trie_blake2_256_ordered_root_version_2", data) + require.NoError(t, err) + + var hash []byte + err = scale.Unmarshal(res, &hash) + require.NoError(t, err) + + expected := common.MustHexToHash("0xd847b86d0219a384d11458e829e9f4f4cce7e3cc2e6dcd0e8a6ad6f12c64a737") + require.Equal(t, expected[:], hash) +} + func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { tmp := t.TempDir() memdb, err := database.NewPebble(tmp, true) require.NoError(t, err) + // Since this is Test_ext_trie_blake2_256_verify_proof_version_1, we use trie.V0 + stateVersion := trie.V0 + otherTrie := trie.NewEmptyTrie() otherTrie.Put([]byte("simple"), []byte("cat")) - otherHash, err := otherTrie.Hash() + otherHash, err := stateVersion.Hash(otherTrie) require.NoError(t, err) tr := trie.NewEmptyTrie() @@ -608,7 +670,7 @@ func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { err = tr.WriteDirty(memdb) require.NoError(t, err) - hash, err := tr.Hash() + hash, err := stateVersion.Hash(tr) require.NoError(t, err) keys := [][]byte{ @@ -679,6 +741,106 @@ func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { } } +func Test_ext_trie_blake2_256_verify_proof_version_2(t *testing.T) { + tmp := t.TempDir() + memdb, err := database.NewPebble(tmp, true) + require.NoError(t, err) + + stateVersion := trie.V1 + + stateVersionInt := uint32(stateVersion) + encVersion, err := scale.Marshal(stateVersionInt) + require.NoError(t, err) + + otherTrie := trie.NewEmptyTrie() + otherTrie.Put([]byte("simple"), []byte("cat")) + + otherHash, err := stateVersion.Hash(otherTrie) + require.NoError(t, err) + + tr := trie.NewEmptyTrie() + tr.Put([]byte("do"), []byte("verb")) + tr.Put([]byte("domain"), []byte("website")) + tr.Put([]byte("other"), []byte("random")) + tr.Put([]byte("otherwise"), []byte("randomstuff")) + tr.Put([]byte("cat"), []byte("another animal")) + + err = tr.WriteDirty(memdb) + require.NoError(t, err) + + hash, err := stateVersion.Hash(tr) + require.NoError(t, err) + + keys := [][]byte{ + []byte("do"), + []byte("domain"), + []byte("other"), + []byte("otherwise"), + []byte("cat"), + } + + root := hash.ToBytes() + otherRoot := otherHash.ToBytes() + + allProofs, err := proof.Generate(root, keys, memdb) + require.NoError(t, err) + + testcases := map[string]struct { + root, key, value []byte + proof [][]byte + expect bool + }{ + "Proof_should_be_true": { + root: root, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: true}, + "Root_empty,_proof_should_be_false": { + root: []byte{}, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: false}, + "Other_root,_proof_should_be_false": { + root: otherRoot, key: []byte("do"), proof: allProofs, value: []byte("verb"), expect: false}, + "Value_empty,_proof_should_be_true": { + root: root, key: []byte("do"), proof: allProofs, value: nil, expect: true}, + "Unknow_key,_proof_should_be_false": { + root: root, key: []byte("unknow"), proof: allProofs, value: nil, expect: false}, + "Key_and_value_unknow,_proof_should_be_false": { + root: root, key: []byte("unknow"), proof: allProofs, value: []byte("unknow"), expect: false}, + "Empty_proof,_should_be_false": { + root: root, key: []byte("do"), proof: [][]byte{}, value: nil, expect: false}, + } + + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + for name, testcase := range testcases { + testcase := testcase + t.Run(name, func(t *testing.T) { + hashEnc, err := scale.Marshal(testcase.root) + require.NoError(t, err) + + args := hashEnc + + encProof, err := scale.Marshal(testcase.proof) + require.NoError(t, err) + args = append(args, encProof...) + + keyEnc, err := scale.Marshal(testcase.key) + require.NoError(t, err) + args = append(args, keyEnc...) + + valueEnc, err := scale.Marshal(testcase.value) + require.NoError(t, err) + args = append(args, valueEnc...) + + args = append(args, encVersion...) + + res, err := inst.Exec("rtm_ext_trie_blake2_256_verify_proof_version_2", args) + require.NoError(t, err) + + var got bool + err = scale.Unmarshal(res, &got) + require.NoError(t, err) + require.Equal(t, testcase.expect, got) + }) + } +} + func Test_ext_misc_runtime_version_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) @@ -1128,7 +1290,9 @@ func Test_ext_default_child_storage_root_version_1(t *testing.T) { child, err := inst.Context.Storage.GetChild(testChildKey) require.NoError(t, err) - rootHash, err := child.Hash() + stateVersion := trie.V0 + + rootHash, err := stateVersion.Hash(child) require.NoError(t, err) encChildKey, err := scale.Marshal(testChildKey) @@ -1148,6 +1312,45 @@ func Test_ext_default_child_storage_root_version_1(t *testing.T) { require.Equal(t, rootHash, actualValue) } +func Test_ext_default_child_storage_root_version_2(t *testing.T) { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + stateVersion := trie.V1 + + err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + require.NoError(t, err) + + child, err := inst.Context.Storage.GetChild(testChildKey) + require.NoError(t, err) + + rootHash, err := stateVersion.Hash(child) + require.NoError(t, err) + + encChildKey, err := scale.Marshal(testChildKey) + require.NoError(t, err) + + stateVersionInt := uint32(stateVersion) + encVersion, err := scale.Marshal(stateVersionInt) + require.NoError(t, err) + + data := append([]byte{}, encChildKey...) + data = append(data, encVersion...) + + ret, err := inst.Exec("rtm_ext_default_child_storage_root_version_2", data) + require.NoError(t, err) + + var hash []byte + err = scale.Unmarshal(ret, &hash) + require.NoError(t, err) + + // Convert decoded interface to common Hash + actualValue := common.BytesToHash(hash) + require.Equal(t, rootHash, actualValue) +} + func Test_ext_default_child_storage_storage_kill_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) @@ -1989,6 +2192,24 @@ func Test_ext_storage_root_version_1(t *testing.T) { require.Equal(t, expected[:], hash) } +func Test_ext_storage_root_version_2(t *testing.T) { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + stateVersion := uint32(trie.V1) + encVersion, err := scale.Marshal(stateVersion) + require.NoError(t, err) + + ret, err := inst.Exec("rtm_ext_storage_root_version_2", encVersion) + require.NoError(t, err) + + var hash []byte + err = scale.Unmarshal(ret, &hash) + require.NoError(t, err) + + expected := trie.EmptyHash + require.Equal(t, expected[:], hash) +} + func Test_ext_storage_set_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) diff --git a/lib/runtime/wazero/instance.go b/lib/runtime/wazero/instance.go index 6358c36d92..8a793691a9 100644 --- a/lib/runtime/wazero/instance.go +++ b/lib/runtime/wazero/instance.go @@ -208,9 +208,7 @@ func NewInstance(code []byte, cfg Config) (instance *Instance, err error) { WithFunc(ext_trie_blake2_256_root_version_1). Export("ext_trie_blake2_256_root_version_1"). NewFunctionBuilder(). - WithFunc(func(a int64, v int32) int32 { - panic("ext_trie_blake2_256_root_version_2 unimplemented") - }). + WithFunc(ext_trie_blake2_256_root_version_2). Export("ext_trie_blake2_256_root_version_2"). NewFunctionBuilder(). WithFunc(ext_trie_blake2_256_ordered_root_version_1). @@ -222,9 +220,7 @@ func NewInstance(code []byte, cfg Config) (instance *Instance, err error) { WithFunc(ext_trie_blake2_256_verify_proof_version_1). Export("ext_trie_blake2_256_verify_proof_version_1"). NewFunctionBuilder(). - WithFunc(func(a int32, b int64, c int64, d int64, v int32) int32 { - panic("ext_trie_blake2_256_verify_proof_version_2 unimplemented") - }). + WithFunc(ext_trie_blake2_256_verify_proof_version_2). Export("ext_trie_blake2_256_verify_proof_version_2"). NewFunctionBuilder(). WithFunc(ext_misc_print_hex_version_1). diff --git a/lib/runtime/wazero/instance_test.go b/lib/runtime/wazero/instance_test.go index 511dfc3ed8..8a45a2df3d 100644 --- a/lib/runtime/wazero/instance_test.go +++ b/lib/runtime/wazero/instance_test.go @@ -230,7 +230,7 @@ func TestWestendRuntime_ValidateTransaction(t *testing.T) { genesisHeader := &types.Header{ Number: 0, - StateRoot: genTrie.MustHash(), + StateRoot: trie.V0.MustHash(genTrie), // Get right state version from runtime } extHex := runtime.NewTestExtrinsic(t, rt, genesisHeader.Hash(), genesisHeader.Hash(), @@ -435,7 +435,7 @@ func TestInstance_BadSignature_WestendBlock8077850(t *testing.T) { genesisHeader := &types.Header{ Number: 0, - StateRoot: genTrie.MustHash(), + StateRoot: trie.V0.MustHash(genTrie), // Use right version from runtime } header := &types.Header{ @@ -461,7 +461,7 @@ func TestInstance_BadSignature_WestendBlock8077850(t *testing.T) { genesisHeader := &types.Header{ Number: 0, - StateRoot: genTrie.MustHash(), + StateRoot: trie.V0.MustHash(genTrie), // Use right version from runtime } header := &types.Header{ @@ -634,7 +634,7 @@ func TestInstance_ApplyExtrinsic_WestendRuntime(t *testing.T) { genesisHeader := &types.Header{ Number: 0, - StateRoot: genTrie.MustHash(), + StateRoot: trie.V0.MustHash(genTrie), // Use right version from runtime } header := &types.Header{ ParentHash: genesisHeader.Hash(), @@ -675,7 +675,7 @@ func TestInstance_ExecuteBlock_PolkadotRuntime_PolkadotBlock1(t *testing.T) { require.NoError(t, err) expectedGenesisRoot := common.MustHexToHash("0x29d0d972cd27cbc511e9589fcb7a4506d5eb6a9e8df205f00472e5ab354a4e17") - require.Equal(t, expectedGenesisRoot, genTrie.MustHash()) + require.Equal(t, expectedGenesisRoot, trie.V0.MustHash(genTrie)) // set state to genesis state genState := storage.NewTrieState(&genTrie) @@ -725,7 +725,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1(t *testing.T) { require.NoError(t, err) expectedGenesisRoot := common.MustHexToHash("0xb0006203c3a6e6bd2c6a17b1d4ae8ca49a31da0f4579da950b127774b44aef6b") - require.Equal(t, expectedGenesisRoot, genTrie.MustHash()) + require.Equal(t, expectedGenesisRoot, trie.V0.MustHash(genTrie)) // set state to genesis state genState := storage.NewTrieState(&genTrie) @@ -771,7 +771,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1(t *testing.T) { func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock3784(t *testing.T) { gossTrie3783 := newTrieFromPairs(t, "../test_data/kusama/block3783.out") expectedRoot := common.MustHexToHash("0x948338bc0976aee78879d559a1f42385407e5a481b05a91d2a9386aa7507e7a0") - require.Equal(t, expectedRoot, gossTrie3783.MustHash()) + require.Equal(t, expectedRoot, trie.V0.MustHash(*gossTrie3783)) // set state to genesis state state3783 := storage.NewTrieState(gossTrie3783) @@ -817,7 +817,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock3784(t *testing.T) { func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock901442(t *testing.T) { ksmTrie901441 := newTrieFromPairs(t, "../test_data/kusama/block901441.out") expectedRoot := common.MustHexToHash("0x3a2ef7ee032f5810160bb8f3ffe3e3377bb6f2769ee9f79a5425973347acd504") - require.Equal(t, expectedRoot, ksmTrie901441.MustHash()) + require.Equal(t, expectedRoot, trie.V0.MustHash(*ksmTrie901441)) // set state to genesis state state901441 := storage.NewTrieState(ksmTrie901441) @@ -863,7 +863,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock901442(t *testing.T) { func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1377831(t *testing.T) { ksmTrie := newTrieFromPairs(t, "../test_data/kusama/block1377830.out") expectedRoot := common.MustHexToHash("0xe4de6fecda9e9e35f937d159665cf984bc1a68048b6c78912de0aeb6bd7f7e99") - require.Equal(t, expectedRoot, ksmTrie.MustHash()) + require.Equal(t, expectedRoot, trie.V0.MustHash(*ksmTrie)) // set state to genesis state state := storage.NewTrieState(ksmTrie) @@ -909,7 +909,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1377831(t *testing.T) { func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1482003(t *testing.T) { ksmTrie := newTrieFromPairs(t, "../test_data/kusama/block1482002.out") expectedRoot := common.MustHexToHash("0x09f9ca28df0560c2291aa16b56e15e07d1e1927088f51356d522722aa90ca7cb") - require.Equal(t, expectedRoot, ksmTrie.MustHash()) + require.Equal(t, expectedRoot, trie.V0.MustHash(*ksmTrie)) // set state to genesis state state := storage.NewTrieState(ksmTrie) @@ -956,7 +956,7 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1482003(t *testing.T) { func TestInstance_ExecuteBlock_PolkadotBlock1089328(t *testing.T) { dotTrie := newTrieFromPairs(t, "../test_data/polkadot/block1089327.json") expectedRoot := common.MustHexToHash("0x87ed9ebe7fb645d3b5b0255cc16e78ed022d9fbb52486105436e15a74557535b") - require.Equal(t, expectedRoot, dotTrie.MustHash()) + require.Equal(t, expectedRoot, trie.V0.MustHash(*dotTrie)) // set state to genesis state state := storage.NewTrieState(dotTrie) diff --git a/lib/trie/child_storage.go b/lib/trie/child_storage.go index 00dd28d20b..2f3cd008da 100644 --- a/lib/trie/child_storage.go +++ b/lib/trie/child_storage.go @@ -19,7 +19,7 @@ var ErrChildTrieDoesNotExist = errors.New("child trie does not exist") // A child trie is added as a node (K, V) in the main trie. K is the child storage key // associated to the child trie, and V is the root hash of the child trie. func (t *Trie) SetChild(keyToChild []byte, child *Trie) error { - childHash, err := child.Hash() + childHash, err := child.Hash(NoMaxInlineValueSize) if err != nil { return err } @@ -62,7 +62,7 @@ func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { } } - origChildHash, err := child.Hash() + origChildHash, err := child.Hash(NoMaxInlineValueSize) if err != nil { return err } @@ -116,7 +116,7 @@ func (t *Trie) ClearFromChild(keyToChild, key []byte) error { return fmt.Errorf("%w at key 0x%x%x", ErrChildTrieDoesNotExist, ChildStorageKeyPrefix, keyToChild) } - origChildHash, err := child.Hash() + origChildHash, err := child.Hash(NoMaxInlineValueSize) if err != nil { return err } diff --git a/lib/trie/child_storage_test.go b/lib/trie/child_storage_test.go index eb922f102f..d75dda0eaf 100644 --- a/lib/trie/child_storage_test.go +++ b/lib/trie/child_storage_test.go @@ -4,11 +4,10 @@ package trie import ( - "bytes" "encoding/binary" - "reflect" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,18 +17,46 @@ func TestPutAndGetChild(t *testing.T) { parentTrie := NewEmptyTrie() err := parentTrie.SetChild(childKey, childTrie) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) childTrieRes, err := parentTrie.GetChild(childKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - if !reflect.DeepEqual(childTrie, childTrieRes) { - t.Fatalf("Fail: got %v expected %v", childTrieRes, childTrie) - } + assert.Equal(t, childTrie, childTrieRes) +} + +func TestPutAndDeleteChild(t *testing.T) { + childKey := []byte("default") + childTrie := buildSmallTrie() + parentTrie := NewEmptyTrie() + + err := parentTrie.SetChild(childKey, childTrie) + assert.NoError(t, err) + + err = parentTrie.DeleteChild(childKey) + assert.NoError(t, err) + + _, err = parentTrie.GetChild(childKey) + assert.ErrorContains(t, err, "child trie does not exist at key") +} + +func TestPutAndClearFromChild(t *testing.T) { + childKey := []byte("default") + keyInChild := []byte{0x01, 0x35} + childTrie := buildSmallTrie() + parentTrie := NewEmptyTrie() + + err := parentTrie.SetChild(childKey, childTrie) + assert.NoError(t, err) + + err = parentTrie.ClearFromChild(childKey, keyInChild) + assert.NoError(t, err) + + childTrie, err = parentTrie.GetChild(childKey) + assert.NoError(t, err) + + value := childTrie.Get(keyInChild) + assert.Equal(t, []uint8(nil), value) } func TestPutAndGetFromChild(t *testing.T) { @@ -38,46 +65,32 @@ func TestPutAndGetFromChild(t *testing.T) { parentTrie := NewEmptyTrie() err := parentTrie.SetChild(childKey, childTrie) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) testKey := []byte("child_key") testValue := []byte("child_value") err = parentTrie.PutIntoChild(childKey, testKey, testValue) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) valueRes, err := parentTrie.GetFromChild(childKey, testKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - if !bytes.Equal(valueRes, testValue) { - t.Fatalf("Fail: got %x expected %x", valueRes, testValue) - } + assert.Equal(t, valueRes, testValue) testKey = []byte("child_key_again") testValue = []byte("child_value_again") err = parentTrie.PutIntoChild(childKey, testKey, testValue) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) valueRes, err = parentTrie.GetFromChild(childKey, testKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - if !bytes.Equal(valueRes, testValue) { - t.Fatalf("Fail: got %x expected %x", valueRes, testValue) - } + assert.Equal(t, valueRes, testValue) } func TestChildTrieHashAfterClear(t *testing.T) { trieThatHoldsAChildTrie := NewEmptyTrie() - originalEmptyHash := trieThatHoldsAChildTrie.MustHash() + originalEmptyHash := V0.MustHash(*trieThatHoldsAChildTrie) keyToChild := []byte("crowdloan") keyInChild := []byte("account-alice") @@ -90,7 +103,7 @@ func TestChildTrieHashAfterClear(t *testing.T) { // the parent trie hash SHOULT NOT BE EQUAL to the original // empty hash since it contains a value - require.NotEqual(t, originalEmptyHash, trieThatHoldsAChildTrie.MustHash()) + require.NotEqual(t, originalEmptyHash, V0.MustHash(*trieThatHoldsAChildTrie)) // ensure the value is inside the child trie valueStored, err := trieThatHoldsAChildTrie.GetFromChild(keyToChild, keyInChild) @@ -103,6 +116,6 @@ func TestChildTrieHashAfterClear(t *testing.T) { // the parent trie hash SHOULD BE EQUAL to the original // empty hash since now it does not have any other value in it - require.Equal(t, originalEmptyHash, trieThatHoldsAChildTrie.MustHash()) + require.Equal(t, originalEmptyHash, V0.MustHash(*trieThatHoldsAChildTrie)) } diff --git a/lib/trie/database.go b/lib/trie/database.go index 7e5c07bef4..667ccec5df 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -11,18 +11,9 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/db" ) -// DBGetter gets a value corresponding to the given key. -type DBGetter interface { - Get(key []byte) (value []byte, err error) -} - -// DBPutter puts a value at the given key and returns an error. -type DBPutter interface { - Put(key []byte, value []byte) error -} - // NewBatcher creates a new database batch. type NewBatcher interface { NewBatch() database.Batch @@ -30,7 +21,7 @@ type NewBatcher interface { // Load reconstructs the trie from the database from the given root hash. // It is used when restarting the node to load the current state trie. -func (t *Trie) Load(db DBGetter, rootHash common.Hash) error { +func (t *Trie) Load(db db.DBGetter, rootHash common.Hash) error { if rootHash == EmptyHash { t.root = nil return nil @@ -54,7 +45,7 @@ func (t *Trie) Load(db DBGetter, rootHash common.Hash) error { return t.loadNode(db, t.root) } -func (t *Trie) loadNode(db DBGetter, n *Node) error { +func (t *Trie) loadNode(db db.DBGetter, n *Node) error { if n.Kind() != node.Branch { return nil } @@ -70,7 +61,7 @@ func (t *Trie) loadNode(db DBGetter, n *Node) error { if len(merkleValue) < 32 { // node has already been loaded inline // just set its encoding - _, err := child.CalculateMerkleValue() + _, err := child.CalculateMerkleValue(NoMaxInlineValueSize) if err != nil { return fmt.Errorf("merkle value: %w", err) } @@ -118,7 +109,7 @@ func (t *Trie) loadNode(db DBGetter, n *Node) error { return fmt.Errorf("failed to load child trie with root hash=%s: %w", rootHash, err) } - hash, err := childTrie.Hash() + hash, err := childTrie.Hash(NoMaxInlineValueSize) if err != nil { return fmt.Errorf("cannot hash chilld trie at key 0x%x: %w", key, err) } @@ -196,7 +187,7 @@ func recordAllDeleted(n *Node, recorder DeltaRecorder) { // It recursively descends into the trie using the database starting // from the root node until it reaches the node with the given key. // It then reads the value from the database. -func GetFromDB(db DBGetter, rootHash common.Hash, key []byte) ( +func GetFromDB(db db.DBGetter, rootHash common.Hash, key []byte) ( value []byte, err error) { if rootHash == EmptyHash { return nil, nil @@ -222,7 +213,7 @@ func GetFromDB(db DBGetter, rootHash common.Hash, key []byte) ( // for the value corresponding to a key. // Note it does not copy the value so modifying the value bytes // slice will modify the value of the node in the trie. -func getFromDBAtNode(db DBGetter, n *Node, key []byte) ( +func getFromDBAtNode(db db.DBGetter, n *Node, key []byte) ( value []byte, err error) { if n.Kind() == node.Leaf { if bytes.Equal(n.PartialKey, key) { @@ -289,16 +280,18 @@ func (t *Trie) WriteDirty(db NewBatcher) error { return batch.Flush() } -func (t *Trie) writeDirtyNode(db DBPutter, n *Node) (err error) { +func (t *Trie) writeDirtyNode(db db.DBPutter, n *Node) (err error) { if n == nil || !n.Dirty { return nil } var encoding, merkleValue []byte + // TODO: I'm sure we don't need to store the encoded now, we can try storing the (key,value) only but it needs + // some refactor and testing. In the meantime we can store the encoded node using the v0 encoding if n == t.root { - encoding, merkleValue, err = n.EncodeAndHashRoot() + encoding, merkleValue, err = n.EncodeAndHashRoot(V0.MaxInlineValue()) } else { - encoding, merkleValue, err = n.EncodeAndHash() + encoding, merkleValue, err = n.EncodeAndHash(V0.MaxInlineValue()) } if err != nil { return fmt.Errorf( @@ -373,9 +366,9 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, nodeHashes map[common.Hash]s var merkleValue []byte if n == t.root { - merkleValue, err = n.CalculateRootMerkleValue() + merkleValue, err = n.CalculateRootMerkleValue(NoMaxInlineValueSize) } else { - merkleValue, err = n.CalculateMerkleValue() + merkleValue, err = n.CalculateMerkleValue(NoMaxInlineValueSize) } if err != nil { return fmt.Errorf("calculating Merkle value: %w", err) diff --git a/lib/trie/database_test.go b/lib/trie/database_test.go index 037218d237..0e2b6bb4b5 100644 --- a/lib/trie/database_test.go +++ b/lib/trie/database_test.go @@ -24,7 +24,7 @@ func Test_Trie_Store_Load(t *testing.T) { const size = 1000 trie, _ := makeSeededTrie(t, size) - rootHash := trie.MustHash() + rootHash := V0.MustHash(*trie) db := newTestDB(t) err := trie.WriteDirty(db) @@ -36,6 +36,15 @@ func Test_Trie_Store_Load(t *testing.T) { assert.Equal(t, trie.String(), trieFromDB.String()) } +func Test_Trie_Load_EmptyHash(t *testing.T) { + t.Parallel() + + db := newTestDB(t) + trieFromDB := NewEmptyTrie() + err := trieFromDB.Load(db, EmptyHash) + require.NoError(t, err) +} + func Test_Trie_WriteDirty_Put(t *testing.T) { t.Parallel() @@ -56,7 +65,7 @@ func Test_Trie_WriteDirty_Put(t *testing.T) { err := trie.WriteDirty(db) require.NoError(t, err) - rootHash := trie.MustHash() + rootHash := V0.MustHash(*trie) valueFromDB, err := GetFromDB(db, rootHash, key) require.NoError(t, err) assert.Equalf(t, value, valueFromDB, "for key=%x", key) @@ -76,7 +85,7 @@ func Test_Trie_WriteDirty_Put(t *testing.T) { err = trie.WriteDirty(db) require.NoError(t, err) - rootHash := trie.MustHash() + rootHash := V0.MustHash(*trie) // Verify the trie in database is also modified. trieFromDB := NewEmptyTrie() @@ -110,7 +119,7 @@ func Test_Trie_WriteDirty_Delete(t *testing.T) { deletedKeys[string(keyToDelete)] = struct{}{} } - rootHash := trie.MustHash() + rootHash := V0.MustHash(*trie) trieFromDB := NewEmptyTrie() err = trieFromDB.Load(db, rootHash) @@ -148,7 +157,7 @@ func Test_Trie_WriteDirty_ClearPrefix(t *testing.T) { require.NoError(t, err) } - rootHash := trie.MustHash() + rootHash := V0.MustHash(*trie) trieFromDB := NewEmptyTrie() err = trieFromDB.Load(db, rootHash) @@ -268,7 +277,7 @@ func Test_GetFromDB(t *testing.T) { err := trie.WriteDirty(db) require.NoError(t, err) - root := trie.MustHash() + root := V0.MustHash(*trie) for keyString, expectedValue := range keyValues { key := []byte(keyString) @@ -278,6 +287,16 @@ func Test_GetFromDB(t *testing.T) { } } +func Test_GetFromDB_EmptyHash(t *testing.T) { + t.Parallel() + + db := newTestDB(t) + + value, err := GetFromDB(db, EmptyHash, []byte("test")) + assert.NoError(t, err) + assert.Nil(t, value) +} + func Test_Trie_PutChild_Store_Load(t *testing.T) { t.Parallel() @@ -306,7 +325,7 @@ func Test_Trie_PutChild_Store_Load(t *testing.T) { require.NoError(t, err) trieFromDB := NewEmptyTrie() - err = trieFromDB.Load(db, trie.MustHash()) + err = trieFromDB.Load(db, V0.MustHash(*trie)) require.NoError(t, err) assert.Equal(t, trie.childTries, trieFromDB.childTries) diff --git a/lib/trie/db/db.go b/lib/trie/db/db.go index 2c04b28e05..2be63db075 100644 --- a/lib/trie/db/db.go +++ b/lib/trie/db/db.go @@ -4,12 +4,35 @@ package db import ( "fmt" + "sync" "github.com/ChainSafe/gossamer/lib/common" ) +type Database interface { + DBGetter + DBPutter +} + +// DBGetter gets a value corresponding to the given key. +type DBGetter interface { + Get(key []byte) (value []byte, err error) +} + +// DBPutter puts a value at the given key and returns an error. +type DBPutter interface { + Put(key []byte, value []byte) error +} + type MemoryDB struct { - data map[common.Hash][]byte + data map[common.Hash][]byte + mutex sync.RWMutex +} + +func NewEmptyMemoryDB() *MemoryDB { + return &MemoryDB{ + data: make(map[common.Hash][]byte), + } } func NewMemoryDBFromProof(encodedNodes [][]byte) (*MemoryDB, error) { @@ -30,16 +53,45 @@ func NewMemoryDBFromProof(encodedNodes [][]byte) (*MemoryDB, error) { } -func (mdb *MemoryDB) Get(key []byte) (value []byte, err error) { - if len(key) < common.HashLength { - return nil, fmt.Errorf("expected %d bytes length key, given %d (%x)", common.HashLength, len(key), value) +func (mdb *MemoryDB) Copy() Database { + newDB := NewEmptyMemoryDB() + copyData := make(map[common.Hash][]byte, len(mdb.data)) + + for k, v := range mdb.data { + copyData[k] = v } - var hash common.Hash - copy(hash[:], key) - if value, found := mdb.data[hash]; found { + newDB.data = copyData + return newDB +} + +func (mdb *MemoryDB) Get(key []byte) ([]byte, error) { + if len(key) != common.HashLength { + return nil, fmt.Errorf("expected %d bytes length key, given %d (%x)", common.HashLength, len(key), key) + } + hashedKey := common.Hash(key) + + mdb.mutex.RLock() + defer mdb.mutex.RUnlock() + + if value, found := mdb.data[hashedKey]; found { return value, nil } return nil, nil } + +func (mdb *MemoryDB) Put(key, value []byte) error { + if len(key) != common.HashLength { + return fmt.Errorf("expected %d bytes length key, given %d (%x)", common.HashLength, len(key), key) + } + + var hash common.Hash + copy(hash[:], key) + + mdb.mutex.Lock() + defer mdb.mutex.Unlock() + + mdb.data[hash] = value + return nil +} diff --git a/lib/trie/db_getter_mocks_test.go b/lib/trie/db_getter_mocks_test.go index e5a31ea71d..ecf5a52dbb 100644 --- a/lib/trie/db_getter_mocks_test.go +++ b/lib/trie/db_getter_mocks_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ChainSafe/gossamer/lib/trie (interfaces: DBGetter) +// Source: github.com/ChainSafe/gossamer/lib/trie/db (interfaces: DBGetter) // Package trie is a generated GoMock package. package trie diff --git a/lib/trie/genesis.go b/lib/trie/genesis.go index c4e6e21737..d2205208b1 100644 --- a/lib/trie/genesis.go +++ b/lib/trie/genesis.go @@ -12,7 +12,7 @@ import ( // GenesisBlock creates a genesis block from the trie. func (t *Trie) GenesisBlock() (genesisHeader types.Header, err error) { - rootHash, err := t.Hash() + rootHash, err := t.Hash(NoMaxInlineValueSize) if err != nil { return genesisHeader, fmt.Errorf("root hashing trie: %w", err) } diff --git a/lib/trie/layout.go b/lib/trie/layout.go new file mode 100644 index 0000000000..5758a7b23a --- /dev/null +++ b/lib/trie/layout.go @@ -0,0 +1,117 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "errors" + "fmt" + "math" + "strings" + + "github.com/ChainSafe/gossamer/lib/common" +) + +const ( + // NoMaxInlineValueSize is the numeric representation used to indicate that there is no max value size. + NoMaxInlineValueSize = math.MaxInt + // V1MaxInlineValueSize is the maximum size of a value to be hashed in state trie version 1. + V1MaxInlineValueSize = 32 +) + +// TrieLayout is the state trie version which dictates how a +// Merkle root should be constructed. It is defined in +// https://spec.polkadot.network/#defn-state-version +type TrieLayout uint8 + +const ( + // V0 is the state trie version 0 where the values of the keys are + // inserted into the trie directly. + // TODO set to iota once CI passes + V0 TrieLayout = iota + V1 +) + +var NoVersion = TrieLayout(math.MaxUint8) + +// ErrParseVersion is returned when parsing a state trie version fails. +var ErrParseVersion = errors.New("parsing version failed") + +// DefaultStateVersion sets the state version we should use as default +// See https://github.com/paritytech/substrate/blob/5e76587825b9a9d52d8cb02ba38828adf606157b/primitives/storage/src/lib.rs#L435-L439 +const DefaultStateVersion = V1 + +// Entry is a key-value pair used to build a trie +type Entry struct{ Key, Value []byte } + +// Entries is a list of entry used to build a trie +type Entries []Entry + +// String returns a string representation of trie version +func (v TrieLayout) String() string { + switch v { + case V0: + return "v0" + case V1: + return "v1" + default: + panic(fmt.Sprintf("unknown version %d", v)) + } +} + +// MaxInlineValue returns the maximum size of a value to be inlined in the trie node +func (v TrieLayout) MaxInlineValue() int { + switch v { + case V0: + return NoMaxInlineValueSize + case V1: + return V1MaxInlineValueSize + default: + panic(fmt.Sprintf("unknown version %d", v)) + } +} + +// Root returns the root hash of the trie built using the given entries +func (v TrieLayout) Root(entries Entries) (common.Hash, error) { + t := NewEmptyTrie() + + for _, kv := range entries { + err := t.Put(kv.Key, kv.Value) + if err != nil { + return common.EmptyHash, err + } + } + + return t.Hash(v.MaxInlineValue()) +} + +// Hash returns the root hash of the trie built using the given entries +func (v TrieLayout) Hash(t *Trie) (common.Hash, error) { + return t.Hash(v.MaxInlineValue()) +} + +// MustHash returns the root hash of the trie built using the given entries or panics if it fails +func (v TrieLayout) MustHash(t Trie) common.Hash { + return t.MustHash(v.MaxInlineValue()) +} + +// ParseVersion parses a state trie version string. +func ParseVersion[T string | uint32](v T) (version TrieLayout, err error) { + var s string + switch value := any(v).(type) { + case string: + s = value + case uint32: + s = fmt.Sprintf("V%d", value) + } + + switch { + case strings.EqualFold(s, V0.String()): + return V0, nil + case strings.EqualFold(s, V1.String()): + return V1, nil + default: + return version, fmt.Errorf("%w: %q must be one of [%s, %s]", + ErrParseVersion, s, V0, V1) + } +} diff --git a/lib/trie/layout_test.go b/lib/trie/layout_test.go new file mode 100644 index 0000000000..b2562ad7ab --- /dev/null +++ b/lib/trie/layout_test.go @@ -0,0 +1,205 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "testing" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/assert" +) + +func Test_Version_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + version TrieLayout + versionString string + panicMessage string + }{ + "v0": { + version: V0, + versionString: "v0", + }, + "invalid": { + version: TrieLayout(99), + panicMessage: "unknown version 99", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _ = testCase.version.String() + }) + return + } + + versionString := testCase.version.String() + assert.Equal(t, testCase.versionString, versionString) + }) + } +} + +func Test_ParseVersion(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + v any + version TrieLayout + errWrapped error + errMessage string + }{ + "v0": { + v: "v0", + version: V0, + }, + "V0": { + v: "V0", + version: V0, + }, + "0": { + v: uint32(0), + version: V0, + }, + "v1": { + v: "v1", + version: V1, + }, + "V1": { + v: "V1", + version: V1, + }, + "1": { + v: uint32(1), + version: V1, + }, + "invalid": { + v: "xyz", + errWrapped: ErrParseVersion, + errMessage: "parsing version failed: \"xyz\" must be one of [v0, v1]", + }, + "invalid_uint32": { + v: uint32(999), + errWrapped: ErrParseVersion, + errMessage: "parsing version failed: \"V999\" must be one of [v0, v1]", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + var version TrieLayout + + var err error + switch typed := testCase.v.(type) { + case string: + version, err = ParseVersion(typed) + case uint32: + version, err = ParseVersion(typed) + default: + t.Fail() + } + + assert.Equal(t, testCase.version, version) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_Version_MaxInlineValue(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + version TrieLayout + max int + panicMessage string + }{ + "v0": { + version: V0, + max: NoMaxInlineValueSize, + }, + "v1": { + version: V1, + max: V1MaxInlineValueSize, + }, + "invalid": { + version: TrieLayout(99), + max: 0, + panicMessage: "unknown version 99", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _ = testCase.version.MaxInlineValue() + }) + return + } + + maxInline := testCase.version.MaxInlineValue() + assert.Equal(t, testCase.max, maxInline) + }) + } +} + +func Test_Version_Root(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + version TrieLayout + input Entries + expected common.Hash + }{ + "v0": { + version: V0, + input: Entries{ + Entry{Key: []byte("key1"), Value: []byte("value1")}, + Entry{Key: []byte("key2"), Value: []byte("value2")}, + Entry{Key: []byte("key3"), Value: []byte("verylargevaluewithmorethan32byteslength")}, + }, + expected: common.Hash{ + 0x71, 0x5, 0x2d, 0x48, 0x70, 0x46, 0x58, 0xa8, 0x43, 0x5f, 0xb9, 0xcb, 0xc7, 0xef, 0x69, 0xc7, 0x5d, + 0xad, 0x2f, 0x64, 0x0, 0x1c, 0xb3, 0xb, 0xfa, 0x1, 0xf, 0x7d, 0x60, 0x9e, 0x26, 0x57, + }, + }, + "v1": { + version: V1, + input: Entries{ + Entry{Key: []byte("key1"), Value: []byte("value1")}, + Entry{Key: []byte("key2"), Value: []byte("value2")}, + Entry{Key: []byte("key3"), Value: []byte("verylargevaluewithmorethan32byteslength")}, + }, + expected: common.Hash{ + 0x6a, 0x4a, 0x73, 0x27, 0x57, 0x26, 0x3b, 0xf2, 0xbc, 0x4e, 0x3, 0xa3, 0x41, 0xe3, 0xf8, 0xea, 0x63, + 0x5f, 0x78, 0x99, 0x6e, 0xc0, 0x6a, 0x6a, 0x96, 0x5d, 0x50, 0x97, 0xa2, 0x91, 0x1c, 0x29, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + maxInline, err := testCase.version.Root(testCase.input) + assert.NoError(t, err) + assert.Equal(t, testCase.expected, maxInline) + }) + } +} diff --git a/lib/trie/mocks_generate_test.go b/lib/trie/mocks_generate_test.go index 767f1a8aed..f101c2239c 100644 --- a/lib/trie/mocks_generate_test.go +++ b/lib/trie/mocks_generate_test.go @@ -3,4 +3,4 @@ package trie -//go:generate mockgen -destination=db_getter_mocks_test.go -package=$GOPACKAGE . DBGetter +//go:generate mockgen -destination=db_getter_mocks_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/db DBGetter diff --git a/lib/trie/print_test.go b/lib/trie/print_test.go index ff71d5123f..7610b34697 100644 --- a/lib/trie/print_test.go +++ b/lib/trie/print_test.go @@ -32,6 +32,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: 0x010203 ├── Storage value: 0x030405 +├── IsHashed: false └── Merkle value: nil`, }, "branch_root": { @@ -60,6 +61,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: 0x0102 +├── IsHashed: false ├── Descendants: 2 ├── Merkle value: nil ├── Child 0 @@ -68,6 +70,7 @@ func Test_Trie_String(t *testing.T) { | ├── Dirty: false | ├── Key: 0x010203 | ├── Storage value: 0x030405 +| ├── IsHashed: false | └── Merkle value: nil └── Child 3 └── Leaf @@ -75,6 +78,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: 0x010203 ├── Storage value: 0x030405 + ├── IsHashed: false └── Merkle value: nil`, }, } diff --git a/lib/trie/proof/database_mocks_test.go b/lib/trie/proof/database_mocks_test.go index 69262dc315..7a4a6d31a8 100644 --- a/lib/trie/proof/database_mocks_test.go +++ b/lib/trie/proof/database_mocks_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ChainSafe/gossamer/lib/trie/proof (interfaces: Database) +// Source: github.com/ChainSafe/gossamer/lib/trie/db (interfaces: DBGetter) // Package proof is a generated GoMock package. package proof @@ -10,31 +10,31 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockDatabase is a mock of Database interface. -type MockDatabase struct { +// MockDBGetter is a mock of DBGetter interface. +type MockDBGetter struct { ctrl *gomock.Controller - recorder *MockDatabaseMockRecorder + recorder *MockDBGetterMockRecorder } -// MockDatabaseMockRecorder is the mock recorder for MockDatabase. -type MockDatabaseMockRecorder struct { - mock *MockDatabase +// MockDBGetterMockRecorder is the mock recorder for MockDBGetter. +type MockDBGetterMockRecorder struct { + mock *MockDBGetter } -// NewMockDatabase creates a new mock instance. -func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { - mock := &MockDatabase{ctrl: ctrl} - mock.recorder = &MockDatabaseMockRecorder{mock} +// NewMockDBGetter creates a new mock instance. +func NewMockDBGetter(ctrl *gomock.Controller) *MockDBGetter { + mock := &MockDBGetter{ctrl: ctrl} + mock.recorder = &MockDBGetterMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { +func (m *MockDBGetter) EXPECT() *MockDBGetterMockRecorder { return m.recorder } // Get mocks base method. -func (m *MockDatabase) Get(arg0 []byte) ([]byte, error) { +func (m *MockDBGetter) Get(arg0 []byte) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", arg0) ret0, _ := ret[0].([]byte) @@ -43,7 +43,7 @@ func (m *MockDatabase) Get(arg0 []byte) ([]byte, error) { } // Get indicates an expected call of Get. -func (mr *MockDatabaseMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockDBGetterMockRecorder) Get(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDatabase)(nil).Get), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDBGetter)(nil).Get), arg0) } diff --git a/lib/trie/proof/generate.go b/lib/trie/proof/generate.go index 2fd30d9ddd..0daca83e52 100644 --- a/lib/trie/proof/generate.go +++ b/lib/trie/proof/generate.go @@ -13,23 +13,18 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/lib/trie/db" ) var ( ErrKeyNotFound = errors.New("key not found") ) -// Database defines a key value Get method used -// for proof generation. -type Database interface { - Get(key []byte) (value []byte, err error) -} - // Generate generates and deduplicates the encoded proof nodes // for the trie corresponding to the root hash given, and for // the slice of (Little Endian) full keys given. The database given // is used to load the trie using the root hash given. -func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( +func Generate(rootHash []byte, fullKeys [][]byte, database db.DBGetter) ( encodedProofNodes [][]byte, err error) { trie := trie.NewEmptyTrie() if err := trie.Load(database, common.BytesToHash(rootHash)); err != nil { @@ -86,7 +81,7 @@ func walkRoot(root *node.Node, fullKey []byte) ( // Note we do not use sync.Pool buffers since we would have // to copy it so it persists in encodedProofNodes. encodingBuffer := bytes.NewBuffer(nil) - err = root.Encode(encodingBuffer) + err = root.Encode(encodingBuffer, trie.NoMaxInlineValueSize) if err != nil { return nil, fmt.Errorf("encode node: %w", err) } @@ -131,7 +126,7 @@ func walk(parent *node.Node, fullKey []byte) ( // Note we do not use sync.Pool buffers since we would have // to copy it so it persists in encodedProofNodes. encodingBuffer := bytes.NewBuffer(nil) - err = parent.Encode(encodingBuffer) + err = parent.Encode(encodingBuffer, trie.NoMaxInlineValueSize) if err != nil { return nil, fmt.Errorf("encode node: %w", err) } diff --git a/lib/trie/proof/generate_test.go b/lib/trie/proof/generate_test.go index 49f519b3c0..1390b1d08a 100644 --- a/lib/trie/proof/generate_test.go +++ b/lib/trie/proof/generate_test.go @@ -10,6 +10,7 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/lib/trie/db" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,15 +32,15 @@ func Test_Generate(t *testing.T) { testCases := map[string]struct { rootHash []byte fullKeysNibbles [][]byte - databaseBuilder func(ctrl *gomock.Controller) Database + databaseBuilder func(ctrl *gomock.Controller) db.DBGetter encodedProofNodes [][]byte errWrapped error errMessage string }{ "failed_loading_trie": { rootHash: someHash, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) mockDatabase.EXPECT().Get(someHash). Return(nil, errTest) return mockDatabase @@ -53,8 +54,8 @@ func Test_Generate(t *testing.T) { "walk_error": { rootHash: someHash, fullKeysNibbles: [][]byte{{1}}, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) encodedRoot := encodeNode(t, node.Node{ PartialKey: []byte{1}, StorageValue: []byte{2}, @@ -69,8 +70,8 @@ func Test_Generate(t *testing.T) { "leaf_root": { rootHash: someHash, fullKeysNibbles: [][]byte{{}}, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) encodedRoot := encodeNode(t, node.Node{ PartialKey: []byte{1}, StorageValue: []byte{2}, @@ -89,8 +90,8 @@ func Test_Generate(t *testing.T) { "branch_root": { rootHash: someHash, fullKeysNibbles: [][]byte{{}}, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) encodedRoot := encodeNode(t, node.Node{ PartialKey: []byte{1}, StorageValue: []byte{2}, @@ -125,8 +126,8 @@ func Test_Generate(t *testing.T) { fullKeysNibbles: [][]byte{ {1, 2, 3, 4}, }, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) rootNode := node.Node{ PartialKey: []byte{1, 2}, @@ -174,8 +175,8 @@ func Test_Generate(t *testing.T) { {1, 2, 4, 4}, {1, 2, 5, 5}, }, - databaseBuilder: func(ctrl *gomock.Controller) Database { - mockDatabase := NewMockDatabase(ctrl) + databaseBuilder: func(ctrl *gomock.Controller) db.DBGetter { + mockDatabase := NewMockDBGetter(ctrl) rootNode := node.Node{ PartialKey: []byte{1, 2}, diff --git a/lib/trie/proof/helpers_test.go b/lib/trie/proof/helpers_test.go index de3d6fe25a..58ec6bb417 100644 --- a/lib/trie/proof/helpers_test.go +++ b/lib/trie/proof/helpers_test.go @@ -10,6 +10,7 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/require" ) @@ -23,7 +24,7 @@ func padRightChildren(slice []*node.Node) (paddedSlice []*node.Node) { func encodeNode(t *testing.T, node node.Node) (encoded []byte) { t.Helper() buffer := bytes.NewBuffer(nil) - err := node.Encode(buffer) + err := node.Encode(buffer, trie.NoMaxInlineValueSize) require.NoError(t, err) return buffer.Bytes() } diff --git a/lib/trie/proof/mocks_generate_test.go b/lib/trie/proof/mocks_generate_test.go index 314b3490aa..adde5e4e89 100644 --- a/lib/trie/proof/mocks_generate_test.go +++ b/lib/trie/proof/mocks_generate_test.go @@ -3,4 +3,4 @@ package proof -//go:generate mockgen -destination=database_mocks_test.go -package=$GOPACKAGE . Database +//go:generate mockgen -destination=database_mocks_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/db DBGetter diff --git a/lib/trie/proof/proof_test.go b/lib/trie/proof/proof_test.go index 40b68ed155..a45e4ffde3 100644 --- a/lib/trie/proof/proof_test.go +++ b/lib/trie/proof/proof_test.go @@ -25,19 +25,19 @@ func Test_Generate_Verify(t *testing.T) { "doguinho", } - trie := trie.NewEmptyTrie() + tr := trie.NewEmptyTrie() for i, key := range keys { value := fmt.Sprintf("%x-%d", key, i) - trie.Put([]byte(key), []byte(value)) + tr.Put([]byte(key), []byte(value)) } - rootHash, err := trie.Hash() + rootHash, err := trie.V0.Hash(tr) require.NoError(t, err) db, err := database.NewPebble("", true) require.NoError(t, err) - err = trie.WriteDirty(db) + err = tr.WriteDirty(db) require.NoError(t, err) for i, key := range keys { @@ -96,3 +96,37 @@ func TestParachainHeaderStateProof(t *testing.T) { require.NoError(t, err) } + +func TestTrieProof(t *testing.T) { + key, err := hex.DecodeString("f0c365c3cf59d671eb72da0e7a4113c49f1f0515f462cdcf84e0f1d6045dfcbb") + if err != nil { + panic(err) + } + root, err := hex.DecodeString("dc4887669c2a6b3462e9557aa3105a66a02b6ec3b21784613de78c95dc3cbbe0") + if err != nil { + panic(err) + } + proof1, err := hex.DecodeString("80fffd8028b54b9a0a90d41b7941c43e6a0597d5914e3b62bdcb244851b9fc806c28ea2480d5ba6d50586692888b0c2f5b3c3fc345eb3a2405996f025ed37982ca396f5ed580bd281c12f20f06077bffd56b2f8b6431ee6c9fd11fed9c22db86cea849aeff2280afa1e1b5ce72ea1675e5e69be85e98fbfb660691a76fee9229f758a75315f2bc80aafc60caa3519d4b861e6b8da226266a15060e2071bba4184e194da61dfb208e809d3f6ae8f655009551de95ae1ef863f6771522fd5c0475a50ff53c5c8169b5888024a760a8f6c27928ae9e2fed9968bc5f6e17c3ae647398d8a615e5b2bb4b425f8085a0da830399f25fca4b653de654ffd3c92be39f3ae4f54e7c504961b5bd00cf80c2d44d371e5fc1f50227d7491ad65ad049630361cefb4ab1844831237609f08380c644938921d14ae611f3a90991af8b7f5bdb8fa361ee2c646c849bca90f491e6806e729ad43a591cd1321762582782bbe4ed193c6f583ec76013126f7f786e376280509bb016f2887d12137e73d26d7ddcd7f9c8ff458147cb9d309494655fe68de180009f8697d760fbe020564b07f407e6aad58ba9451b3d2d88b3ee03e12db7c47480952dcc0804e1120508a1753f1de4aa5b7481026a3320df8b48e918f0cecbaed3803360bf948fddc403d345064082e8393d7a1aad7a19081f6d02d94358f242b86c") //nolint:lll + if err != nil { + panic(err) + } + proof2, err := hex.DecodeString("9ec365c3cf59d671eb72da0e7a4113c41002505f0e7b9012096b41c4eb3aaf947f6ea429080000685f0f1f0515f462cdcf84e0f1d6045dfcbb20865c4a2b7f010000") //nolint:lll + if err != nil { + panic(err) + } + proof3, err := hex.DecodeString("8005088076c66e2871b4fe037d112ebffb3bfc8bd83a4ec26047f58ee2df7be4e9ebe3d680c1638f702aaa71e4b78cc8538ecae03e827bb494cc54279606b201ec071a5e24806d2a1e6d5236e1e13c5a5c84831f5f5383f97eba32df6f9faf80e32cf2f129bc") //nolint:lll + if err != nil { + panic(err) + } + + proof := [][]byte{proof1, proof2, proof3} + proofDB, err := db.NewMemoryDBFromProof(proof) + + require.NoError(t, err) + + trie, err := buildTrie(proof, root, proofDB) + require.NoError(t, err) + value := trie.Get(key) + + require.Equal(t, []byte{0x86, 0x5c, 0x4a, 0x2b, 0x7f, 0x1, 0x0, 0x0}, value) +} diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go index 1e8457f04f..97e439fdec 100644 --- a/lib/trie/proof/verify.go +++ b/lib/trie/proof/verify.go @@ -62,7 +62,7 @@ var ( ) // buildTrie sets a partial trie based on the proof slice of encoded nodes. -func buildTrie(encodedProofNodes [][]byte, rootHash []byte, db Database) (t *trie.Trie, err error) { +func buildTrie(encodedProofNodes [][]byte, rootHash []byte, db db.Database) (t *trie.Trie, err error) { if len(encodedProofNodes) == 0 { return nil, fmt.Errorf("%w: for Merkle root hash 0x%x", ErrEmptyProof, rootHash) diff --git a/lib/trie/proof/verify_test.go b/lib/trie/proof/verify_test.go index db969e1619..78e0e5dd50 100644 --- a/lib/trie/proof/verify_test.go +++ b/lib/trie/proof/verify_test.go @@ -141,7 +141,7 @@ func Test_buildTrie(t *testing.T) { encodedProofNodes [][]byte rootHash []byte expectedTrie *trie.Trie - db Database + db db.Database errWrapped error errMessage string } diff --git a/lib/trie/trie.go b/lib/trie/trie.go index ea2d1dd4ae..99152be798 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -6,11 +6,13 @@ package trie import ( "bytes" "fmt" + "reflect" "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/internal/trie/tracking" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/db" ) // EmptyHash is the empty trie hash. @@ -21,7 +23,7 @@ type Trie struct { generation uint64 root *Node childTries map[common.Hash]*Trie - db DBGetter + db db.Database // deltas stores trie deltas since the last trie snapshot. // For example node hashes that were deleted since // the last snapshot. These are used by the online @@ -32,11 +34,11 @@ type Trie struct { // NewEmptyTrie creates a trie with a nil root func NewEmptyTrie() *Trie { - return NewTrie(nil, nil) + return NewTrie(nil, db.NewEmptyMemoryDB()) } // NewTrie creates a trie with an existing root node -func NewTrie(root *Node, db DBGetter) *Trie { +func NewTrie(root *Node, db db.Database) *Trie { return &Trie{ root: root, childTries: make(map[common.Hash]*Trie), @@ -46,6 +48,20 @@ func NewTrie(root *Node, db DBGetter) *Trie { } } +// Equal is to compare one trie with other, this method will ignore the shared db instance +func (t *Trie) Equal(other *Trie) bool { + if t == nil && other == nil { + return true + } + + if t == nil || other == nil { + return false + } + + return t.generation == other.generation && reflect.DeepEqual(t.root, other.root) && + reflect.DeepEqual(t.childTries, other.childTries) && reflect.DeepEqual(t.deltas, other.deltas) +} + // Snapshot creates a copy of the trie. // Note it does not deep copy the trie, but will // copy on write as modifications are done on this new trie. @@ -66,6 +82,7 @@ func (t *Trie) Snapshot() (newTrie *Trie) { return &Trie{ generation: t.generation + 1, root: t.root, + db: t.db, childTries: childTries, deltas: tracking.New(), } @@ -139,6 +156,7 @@ func (t *Trie) DeepCopy() (trieCopy *Trie) { trieCopy = &Trie{ generation: t.generation, + db: t.db, } if t.deltas != nil { @@ -171,8 +189,8 @@ func (t *Trie) RootNode() *Node { // MustHash returns the hashed root of the trie. // It panics if it fails to hash the root node. -func (t *Trie) MustHash() common.Hash { - h, err := t.Hash() +func (t *Trie) MustHash(maxInlineValue int) common.Hash { + h, err := t.Hash(maxInlineValue) if err != nil { panic(err) } @@ -181,12 +199,12 @@ func (t *Trie) MustHash() common.Hash { } // Hash returns the hashed root of the trie. -func (t *Trie) Hash() (rootHash common.Hash, err error) { +func (t *Trie) Hash(maxInlineValue int) (rootHash common.Hash, err error) { if t.root == nil { return EmptyHash, nil } - merkleValue, err := t.root.CalculateRootMerkleValue() + merkleValue, err := t.root.CalculateRootMerkleValue(maxInlineValue) if err != nil { return rootHash, err } @@ -194,71 +212,39 @@ func (t *Trie) Hash() (rootHash common.Hash, err error) { return rootHash, nil } -// EntriesList returns all the key-value pairs in the trie as a slice of key value -// where the keys are encoded in Little Endian. The slice starts with root node. -func (t *Trie) EntriesList() [][2][]byte { - list := make([][2][]byte, 0) - entriesList(t.root, nil, &list) - return list -} - -func entriesList(parent *Node, prefix []byte, list *[][2][]byte) { - if parent == nil { - return - } - - if parent.Kind() == node.Leaf { - parentKey := parent.PartialKey - fullKeyNibbles := concatenateSlices(prefix, parentKey) - keyLE := codec.NibblesToKeyLE(fullKeyNibbles) - *list = append(*list, [2][]byte{keyLE, parent.StorageValue}) - return - } - - branch := parent - if branch.StorageValue != nil { - fullKeyNibbles := concatenateSlices(prefix, branch.PartialKey) - keyLE := codec.NibblesToKeyLE(fullKeyNibbles) - *list = append(*list, [2][]byte{keyLE, parent.StorageValue}) - } - - for i, child := range branch.Children { - childPrefix := concatenateSlices(prefix, branch.PartialKey, intToByteSlice(i)) - entriesList(child, childPrefix, list) - } -} - // Entries returns all the key-value pairs in the trie as a map of keys to values // where the keys are encoded in Little Endian. func (t *Trie) Entries() (keyValueMap map[string][]byte) { keyValueMap = make(map[string][]byte) - entries(t.root, nil, keyValueMap) + t.buildEntriesMap(t.root, nil, keyValueMap) return keyValueMap } -func entries(parent *Node, prefix []byte, kv map[string][]byte) { - if parent == nil { +func (t *Trie) buildEntriesMap(currentNode *Node, prefix []byte, kv map[string][]byte) { + if currentNode == nil { return } - if parent.Kind() == node.Leaf { - parentKey := parent.PartialKey - fullKeyNibbles := concatenateSlices(prefix, parentKey) - keyLE := string(codec.NibblesToKeyLE(fullKeyNibbles)) - kv[keyLE] = parent.StorageValue + // Leaf + if currentNode.Kind() == node.Leaf { + key := currentNode.PartialKey + fullKeyNibbles := concatenateSlices(prefix, key) + keyLE := codec.NibblesToKeyLE(fullKeyNibbles) + kv[string(keyLE)] = t.Get(keyLE) return } - branch := parent + // Branch + branch := currentNode if branch.StorageValue != nil { fullKeyNibbles := concatenateSlices(prefix, branch.PartialKey) - keyLE := string(codec.NibblesToKeyLE(fullKeyNibbles)) - kv[keyLE] = branch.StorageValue + keyLE := codec.NibblesToKeyLE(fullKeyNibbles) + kv[string(keyLE)] = t.Get(keyLE) } for i, child := range branch.Children { childPrefix := concatenateSlices(prefix, branch.PartialKey, intToByteSlice(i)) - entries(child, childPrefix, kv) + t.buildEntriesMap(child, childPrefix, kv) } } @@ -372,6 +358,7 @@ func (t *Trie) insertKeyLE(keyLE, value []byte, // is no value. value = []byte{} } + root, _, _, err := t.insert(t.root, nibblesKey, value, pendingDeltas) if err != nil { return err @@ -382,8 +369,7 @@ func (t *Trie) insertKeyLE(keyLE, value []byte, // insert inserts a value in the trie at the key specified. // It may create one or more new nodes or update an existing node. -func (t *Trie) insert(parent *Node, key, value []byte, - pendingDeltas DeltaRecorder) (newParent *Node, +func (t *Trie) insert(parent *Node, key, value []byte, pendingDeltas DeltaRecorder) (newParent *Node, mutated bool, nodesCreated uint32, err error) { if parent == nil { mutated = true @@ -477,6 +463,7 @@ func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte, if len(parentLeaf.PartialKey) == commonPrefixLength { // the key of the parent leaf is at this new branch newBranchParent.StorageValue = parentLeaf.StorageValue + newBranchParent.IsHashedValue = parentLeaf.IsHashedValue } else { // make the leaf a child of the new branch copySettings := node.DefaultCopySettings @@ -644,29 +631,6 @@ func LoadFromMap(data map[string]string) (trie Trie, err error) { return trie, nil } -// LoadFromEntries loads the given slice of key values into a new empty trie. -// The keys are in hexadecimal little Endian encoding and the values -// are hexadecimal encoded. -func LoadFromEntries(entries [][2][]byte) (trie *Trie, err error) { - trie = NewEmptyTrie() - - pendingDeltas := tracking.New() - defer func() { - trie.handleTrackedDeltas(err == nil, pendingDeltas) - }() - - for _, keyValue := range entries { - keyLE := keyValue[0] - value := keyValue[1] - err := trie.insertKeyLE(keyLE, value, pendingDeltas) - if err != nil { - return nil, err - } - } - - return trie, nil -} - // GetKeysWithPrefix returns all keys in little Endian // format from nodes in the trie that have the given little // Endian formatted prefix in their key. @@ -777,7 +741,7 @@ func (t *Trie) Get(keyLE []byte) (value []byte) { return retrieve(t.db, t.root, keyNibbles) } -func retrieve(db DBGetter, parent *Node, key []byte) (value []byte) { +func retrieve(db db.DBGetter, parent *Node, key []byte) (value []byte) { if parent == nil { return nil } @@ -788,9 +752,9 @@ func retrieve(db DBGetter, parent *Node, key []byte) (value []byte) { return retrieveFromBranch(db, parent, key) } -func retrieveFromLeaf(db DBGetter, leaf *Node, key []byte) (value []byte) { +func retrieveFromLeaf(db db.DBGetter, leaf *Node, key []byte) (value []byte) { if bytes.Equal(leaf.PartialKey, key) { - if leaf.HashedValue { + if leaf.IsHashedValue { // We get the node value, err := db.Get(leaf.StorageValue) if err != nil { @@ -803,7 +767,7 @@ func retrieveFromLeaf(db DBGetter, leaf *Node, key []byte) (value []byte) { return nil } -func retrieveFromBranch(db DBGetter, branch *Node, key []byte) (value []byte) { +func retrieveFromBranch(db db.DBGetter, branch *Node, key []byte) (value []byte) { if len(key) == 0 || bytes.Equal(branch.PartialKey, key) { return branch.StorageValue } @@ -1399,10 +1363,11 @@ func (t *Trie) handleDeletion(branch *Node, key []byte, if child.Kind() == node.Leaf { newLeafKey := concatenateSlices(branch.PartialKey, intToByteSlice(childIndex), child.PartialKey) return &Node{ - PartialKey: newLeafKey, - StorageValue: child.StorageValue, - Dirty: true, - Generation: branch.Generation, + PartialKey: newLeafKey, + StorageValue: child.StorageValue, + IsHashedValue: child.IsHashedValue, + Dirty: true, + Generation: branch.Generation, }, branchChildMerged, nil } @@ -1441,12 +1406,12 @@ func (t *Trie) ensureMerkleValueIsCalculated(parent *Node) (err error) { } if parent == t.root { - _, err = parent.CalculateRootMerkleValue() + _, err = parent.CalculateRootMerkleValue(NoMaxInlineValueSize) if err != nil { return fmt.Errorf("calculating Merkle value of root node: %w", err) } } else { - _, err = parent.CalculateMerkleValue() + _, err = parent.CalculateMerkleValue(NoMaxInlineValueSize) if err != nil { return fmt.Errorf("calculating Merkle value of node: %w", err) } diff --git a/lib/trie/trie_endtoend_test.go b/lib/trie/trie_endtoend_test.go index 2fe0fa9b72..1cccd342be 100644 --- a/lib/trie/trie_endtoend_test.go +++ b/lib/trie/trie_endtoend_test.go @@ -103,7 +103,7 @@ func TestPutAndGetOddKeyLengths(t *testing.T) { func Fuzz_Trie_PutAndGet_Single(f *testing.F) { f.Fuzz(func(t *testing.T, key, value []byte) { - trie := NewEmptyTrie() + trie := NewTrie(nil, nil) trie.Put(key, value) retrievedValue := trie.Get(key) assert.Equal(t, value, retrievedValue) @@ -349,13 +349,13 @@ func TestDelete(t *testing.T) { ssTrie := trie.Snapshot() // Get the Trie root hash for all the 3 tries. - tHash, err := trie.Hash() + tHash, err := DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err := dcTrie.Hash() + dcTrieHash, err := DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err := ssTrie.Hash() + ssTrieHash, err := DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Root hash for all the 3 tries should be equal. @@ -376,13 +376,13 @@ func TestDelete(t *testing.T) { } // Get the updated root hash of all tries. - tHash, err = trie.Hash() + tHash, err = DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err = dcTrie.Hash() + dcTrieHash, err = DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err = ssTrie.Hash() + ssTrieHash, err = DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Only the current trie should have a different root hash since it is updated. @@ -432,13 +432,13 @@ func TestClearPrefix(t *testing.T) { ssTrie := trie.Snapshot() // Get the Trie root hash for all the 3 tries. - tHash, err := trie.Hash() + tHash, err := DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err := dcTrie.Hash() + dcTrieHash, err := DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err := ssTrie.Hash() + ssTrieHash, err := DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Root hash for all the 3 tries should be equal. @@ -464,13 +464,13 @@ func TestClearPrefix(t *testing.T) { } // Get the updated root hash of all tries. - tHash, err = trie.Hash() + tHash, err = DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err = dcTrie.Hash() + dcTrieHash, err = DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err = ssTrie.Hash() + ssTrieHash, err = DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Only the current trie should have a different root hash since it is updated. @@ -489,13 +489,13 @@ func TestClearPrefix_Small(t *testing.T) { ssTrie := trie.Snapshot() // Get the Trie root hash for all the 3 tries. - tHash, err := trie.Hash() + tHash, err := DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err := dcTrie.Hash() + dcTrieHash, err := DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err := ssTrie.Hash() + ssTrieHash, err := DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Root hash for all the 3 tries should be equal. @@ -522,13 +522,13 @@ func TestClearPrefix_Small(t *testing.T) { require.Equal(t, expectedRoot, ssTrie.root) // Get the updated root hash of all tries. - tHash, err = trie.Hash() + tHash, err = DefaultStateVersion.Hash(trie) require.NoError(t, err) - dcTrieHash, err = dcTrie.Hash() + dcTrieHash, err = DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err = ssTrie.Hash() + ssTrieHash, err = DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) require.Equal(t, tHash, dcTrieHash) @@ -600,7 +600,10 @@ func TestTrie_ClearPrefixVsDelete(t *testing.T) { trieClearPrefix.ClearPrefix(prefix) - require.Equal(t, trieClearPrefix.MustHash(), trieDelete.MustHash()) + trieClearPrefixHash := DefaultStateVersion.MustHash(*trieClearPrefix) + trieDeleteHash := DefaultStateVersion.MustHash(*trieDelete) + + require.Equal(t, trieClearPrefixHash, trieDeleteHash) } } } @@ -633,8 +636,12 @@ func TestSnapshot(t *testing.T) { newTrie := parentTrie.Snapshot() newTrie.Put(tests[0].key, tests[0].value) - require.Equal(t, expectedTrie.MustHash(), newTrie.MustHash()) - require.NotEqual(t, parentTrie.MustHash(), newTrie.MustHash()) + expectedTrieHash := DefaultStateVersion.MustHash(*expectedTrie) + newTrieHash := DefaultStateVersion.MustHash(*newTrie) + parentTrieHash := DefaultStateVersion.MustHash(*parentTrie) + + require.Equal(t, expectedTrieHash, newTrieHash) + require.NotEqual(t, parentTrieHash, newTrieHash) } func Test_Trie_NextKey_Random(t *testing.T) { @@ -686,7 +693,7 @@ func Benchmark_Trie_Hash(b *testing.B) { } b.StartTimer() - _, err := trie.Hash() + _, err := DefaultStateVersion.Hash(trie) b.StopTimer() require.NoError(b, err) @@ -767,9 +774,11 @@ func TestTrie_ConcurrentSnapshotWrites(t *testing.T) { finishWg.Wait() for i := 0; i < workers; i++ { - assert.Equal(t, - expectedTries[i].MustHash(), - snapshotedTries[i].MustHash()) + assert.Equal( + t, + DefaultStateVersion.MustHash(*expectedTries[i]), + DefaultStateVersion.MustHash(*snapshotedTries[i]), + ) } } @@ -950,13 +959,13 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { ssTrie := trieClearPrefix.Snapshot() // Get the Trie root hash for all the 3 tries. - tHash, err := trieClearPrefix.Hash() + tHash, err := DefaultStateVersion.Hash(trieClearPrefix) require.NoError(t, err) - dcTrieHash, err := dcTrie.Hash() + dcTrieHash, err := DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err := ssTrie.Hash() + ssTrieHash, err := DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // Root hash for all the 3 tries should be equal. @@ -992,13 +1001,13 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { } // Get the updated root hash of all tries. - tHash, err = trieClearPrefix.Hash() + tHash, err = DefaultStateVersion.Hash(trieClearPrefix) require.NoError(t, err) - dcTrieHash, err = dcTrie.Hash() + dcTrieHash, err = DefaultStateVersion.Hash(dcTrie) require.NoError(t, err) - ssTrieHash, err = ssTrie.Hash() + ssTrieHash, err = DefaultStateVersion.Hash(ssTrie) require.NoError(t, err) // If node got deleted then root hash must be updated else it has same root hash. @@ -1033,7 +1042,7 @@ func Test_encodeRoot_fuzz(t *testing.T) { assert.Equal(t, value, retrievedValue) } buffer := bytes.NewBuffer(nil) - err := trie.root.Encode(buffer) + err := trie.root.Encode(buffer, DefaultStateVersion.MaxInlineValue()) require.NoError(t, err) require.NotEmpty(t, buffer.Bytes()) } diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index d484fe2be5..6dfbb19064 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -12,6 +12,7 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/internal/trie/tracking" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/db" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,6 +34,7 @@ func Test_NewEmptyTrie(t *testing.T) { expectedTrie := &Trie{ childTries: make(map[common.Hash]*Trie), deltas: tracking.New(), + db: db.NewEmptyMemoryDB(), } trie := NewEmptyTrie() assert.Equal(t, expectedTrie, trie) @@ -277,8 +279,8 @@ func Test_Trie_registerDeletedNodeHash(t *testing.T) { testCases := map[string]struct { trie Trie node *Node - pendingDeltas DeltaRecorder - expectedPendingDeltas DeltaRecorder + pendingDeltas *tracking.Deltas + expectedPendingDeltas *tracking.Deltas expectedTrie Trie }{ "dirty_node_not_registered": { @@ -461,7 +463,7 @@ func Test_Trie_MustHash(t *testing.T) { var trie Trie - hash := trie.MustHash() + hash := V0.MustHash(trie) expectedHash := common.Hash{ 0x3, 0x17, 0xa, 0x2e, 0x75, 0x97, 0xb7, 0xb7, @@ -558,7 +560,7 @@ func Test_Trie_Hash(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - hash, err := testCase.trie.Hash() + hash, err := V0.Hash(&testCase.trie) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -604,12 +606,12 @@ func Test_Trie_Entries(t *testing.T) { t.Parallel() root := &Node{ - PartialKey: []byte{0xa}, + PartialKey: []byte{0x0, 0xa}, StorageValue: []byte("root"), Descendants: 2, Children: padRightChildren([]*Node{ { // index 0 - PartialKey: []byte{2, 0xb}, + PartialKey: []byte{0xb}, StorageValue: []byte("leaf"), }, nil, @@ -626,7 +628,7 @@ func Test_Trie_Entries(t *testing.T) { expectedEntries := map[string][]byte{ string([]byte{0x0a}): []byte("root"), - string([]byte{0xa0, 0x2b}): []byte("leaf"), + string([]byte{0x0a, 0xb}): []byte("leaf"), string([]byte{0x0a, 0x2b}): []byte("leaf"), } @@ -690,12 +692,13 @@ func Test_Trie_Entries(t *testing.T) { entriesMatch(t, expectedEntries, entries) }) - t.Run("end_to_end", func(t *testing.T) { + t.Run("end_to_end_v0", func(t *testing.T) { t.Parallel() trie := Trie{ root: nil, childTries: make(map[common.Hash]*Trie), + db: db.NewEmptyMemoryDB(), } kv := map[string][]byte{ @@ -713,6 +716,31 @@ func Test_Trie_Entries(t *testing.T) { assert.Equal(t, kv, entries) }) + + t.Run("end_to_end_v1", func(t *testing.T) { + t.Parallel() + + trie := Trie{ + root: nil, + childTries: make(map[common.Hash]*Trie), + db: db.NewEmptyMemoryDB(), + } + + kv := map[string][]byte{ + "ab": []byte("pen"), + "abc": []byte("penguin"), + "hy": []byte("feather"), + "long": []byte("newvaluewithmorethan32byteslength"), + } + + for k, v := range kv { + trie.Put([]byte(k), v) + } + + entries := trie.Entries() + + assert.Equal(t, kv, entries) + }) } func Test_Trie_NextKey(t *testing.T) { @@ -1655,6 +1683,7 @@ func Test_LoadFromMap(t *testing.T) { expectedTrie: Trie{ childTries: map[common.Hash]*Trie{}, deltas: newDeltas(), + db: db.NewEmptyMemoryDB(), }, }, "empty_data": { @@ -1662,6 +1691,7 @@ func Test_LoadFromMap(t *testing.T) { expectedTrie: Trie{ childTries: map[common.Hash]*Trie{}, deltas: newDeltas(), + db: db.NewEmptyMemoryDB(), }, }, "bad_key": { @@ -1695,6 +1725,7 @@ func Test_LoadFromMap(t *testing.T) { }, childTries: map[common.Hash]*Trie{}, deltas: newDeltas(), + db: db.NewEmptyMemoryDB(), }, }, "load_key_values": { @@ -1725,6 +1756,7 @@ func Test_LoadFromMap(t *testing.T) { }, childTries: map[common.Hash]*Trie{}, deltas: newDeltas(), + db: db.NewEmptyMemoryDB(), }, }, } @@ -2088,7 +2120,7 @@ func Test_retrieve(t *testing.T) { parent *Node key []byte value []byte - db DBGetter + db db.DBGetter }{ "nil_parent": { key: []byte{1}, @@ -2187,9 +2219,9 @@ func Test_retrieve(t *testing.T) { Children: padRightChildren([]*Node{ nil, nil, nil, nil, { // full key 1, 2, 3, 4, 5 - PartialKey: []byte{5}, - StorageValue: hashedValue, - HashedValue: true, + PartialKey: []byte{5}, + StorageValue: hashedValue, + IsHashedValue: true, }, }), }, @@ -2197,7 +2229,7 @@ func Test_retrieve(t *testing.T) { }, key: []byte{1, 2, 3, 4, 5}, value: hashedValueResult, - db: func() DBGetter { + db: func() db.DBGetter { defaultDBGetterMock := NewMockDBGetter(ctrl) defaultDBGetterMock.EXPECT().Get(gomock.Any()).Return(hashedValueResult, nil).Times(1) diff --git a/lib/trie/version.go b/lib/trie/version.go deleted file mode 100644 index 23527890c8..0000000000 --- a/lib/trie/version.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2022 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "errors" - "fmt" - "strings" -) - -// Version is the state trie version which dictates how a -// Merkle root should be constructed. It is defined in -// https://spec.polkadot.network/#defn-state-version -type Version uint8 - -const ( - // V0 is the state trie version 0 where the values of the keys are - // inserted into the trie directly. - // TODO set to iota once CI passes - V0 Version = 1 -) - -func (v Version) String() string { - switch v { - case V0: - return "v0" - default: - panic(fmt.Sprintf("unknown version %d", v)) - } -} - -var ErrParseVersion = errors.New("parsing version failed") - -// ParseVersion parses a state trie version string. -func ParseVersion(s string) (version Version, err error) { - switch { - case strings.EqualFold(s, V0.String()): - return V0, nil - default: - return version, fmt.Errorf("%w: %q must be %s", - ErrParseVersion, s, V0) - } -} diff --git a/lib/trie/version_test.go b/lib/trie/version_test.go deleted file mode 100644 index ab2ac03ebe..0000000000 --- a/lib/trie/version_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2022 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_Version_String(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - version Version - versionString string - panicMessage string - }{ - "v0": { - version: V0, - versionString: "v0", - }, - "invalid": { - version: Version(99), - panicMessage: "unknown version 99", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - if testCase.panicMessage != "" { - assert.PanicsWithValue(t, testCase.panicMessage, func() { - _ = testCase.version.String() - }) - return - } - - versionString := testCase.version.String() - assert.Equal(t, testCase.versionString, versionString) - }) - } -} - -func Test_ParseVersion(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - s string - version Version - errWrapped error - errMessage string - }{ - "v0": { - s: "v0", - version: V0, - }, - "V0": { - s: "V0", - version: V0, - }, - "invalid": { - s: "xyz", - errWrapped: ErrParseVersion, - errMessage: "parsing version failed: \"xyz\" must be v0", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - version, err := ParseVersion(testCase.s) - - assert.Equal(t, testCase.version, version) - assert.ErrorIs(t, err, testCase.errWrapped) - if testCase.errWrapped != nil { - assert.EqualError(t, err, testCase.errMessage) - } - }) - } -}