diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 9ed24a0d1c..c114486cc8 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -91,8 +91,7 @@ func (t *Trie) prepForMutation(currentNode *Node, // update the node generation. newNode = currentNode } else { - isRoot := currentNode == t.root - err = registerDeletedMerkleValue(currentNode, isRoot, + err = t.registerDeletedMerkleValue(currentNode, pendingDeletedMerkleValues) if err != nil { return nil, fmt.Errorf("registering deleted node: %w", err) @@ -104,8 +103,9 @@ func (t *Trie) prepForMutation(currentNode *Node, return newNode, nil } -func registerDeletedMerkleValue(node *Node, isRoot bool, +func (t *Trie) registerDeletedMerkleValue(node *Node, pendingDeletedMerkleValues map[string]struct{}) (err error) { + isRoot := node == t.root err = ensureMerkleValueIsCalculated(node, isRoot) if err != nil { return fmt.Errorf("ensuring Merkle value is calculated: %w", err) @@ -810,8 +810,7 @@ func (t *Trie) clearPrefixLimitAtNode(parent *Node, prefix []byte, // TODO check this is the same behaviour as in substrate const allDeleted = true if bytes.HasPrefix(parent.Key, prefix) { - isRoot := parent == t.root - err = registerDeletedMerkleValue(parent, isRoot, deletedMerkleValues) + err = t.registerDeletedMerkleValue(parent, deletedMerkleValues) if err != nil { return nil, 0, 0, false, fmt.Errorf("registering deleted Merkle value: %w", err) @@ -877,7 +876,7 @@ func (t *Trie) clearPrefixLimitBranch(branch *Node, prefix []byte, limit uint32, branch.Children[childIndex] = child branch.Descendants -= nodesRemoved - newParent, branchChildMerged, err := handleDeletion(branch, prefix, deletedMerkleValues) + newParent, branchChildMerged, err := t.handleDeletion(branch, prefix, deletedMerkleValues) if err != nil { return nil, 0, 0, false, fmt.Errorf("handling deletion: %w", err) } @@ -925,7 +924,7 @@ func (t *Trie) clearPrefixLimitChild(branch *Node, prefix []byte, limit uint32, branch.Children[childIndex] = child branch.Descendants -= nodesRemoved - newParent, branchChildMerged, err := handleDeletion(branch, prefix, deletedMerkleValues) + newParent, branchChildMerged, err := t.handleDeletion(branch, prefix, deletedMerkleValues) if err != nil { return nil, 0, 0, false, fmt.Errorf("handling deletion: %w", err) } @@ -952,8 +951,7 @@ func (t *Trie) deleteNodesLimit(parent *Node, limit uint32, } if parent.Kind() == node.Leaf { - isRoot := parent == t.root - err = registerDeletedMerkleValue(parent, isRoot, deletedMerkleValues) + err = t.registerDeletedMerkleValue(parent, deletedMerkleValues) if err != nil { return nil, 0, 0, fmt.Errorf("registering deleted merkle value: %w", err) } @@ -998,7 +996,7 @@ func (t *Trie) deleteNodesLimit(parent *Node, limit uint32, nodesRemoved += newNodesRemoved branch.Descendants -= newNodesRemoved - newParent, branchChildMerged, err = handleDeletion(branch, branch.Key, deletedMerkleValues) + newParent, branchChildMerged, err = t.handleDeletion(branch, branch.Key, deletedMerkleValues) if err != nil { return nil, 0, 0, fmt.Errorf("handling deletion: %w", err) } @@ -1103,8 +1101,7 @@ func (t *Trie) clearPrefixAtNode(parent *Node, prefix []byte, return nil, 0, fmt.Errorf("preparing branch for mutation: %w", err) } - const isRoot = false // child so it cannot be the root - err = registerDeletedMerkleValue(child, isRoot, deletedMerkleValues) + err = t.registerDeletedMerkleValue(child, deletedMerkleValues) if err != nil { return nil, 0, fmt.Errorf("registering deleted merkle value for child: %w", err) } @@ -1112,7 +1109,7 @@ func (t *Trie) clearPrefixAtNode(parent *Node, prefix []byte, branch.Children[childIndex] = nil branch.Descendants -= nodesRemoved var branchChildMerged bool - newParent, branchChildMerged, err = handleDeletion(branch, prefix, deletedMerkleValues) + newParent, branchChildMerged, err = t.handleDeletion(branch, prefix, deletedMerkleValues) if err != nil { return nil, 0, fmt.Errorf("handling deletion: %w", err) } @@ -1151,7 +1148,7 @@ func (t *Trie) clearPrefixAtNode(parent *Node, prefix []byte, branch.Descendants -= nodesRemoved branch.Children[childIndex] = child - newParent, branchChildMerged, err := handleDeletion(branch, prefix, deletedMerkleValues) + newParent, branchChildMerged, err := t.handleDeletion(branch, prefix, deletedMerkleValues) if err != nil { return nil, 0, fmt.Errorf("handling deletion: %w", err) } @@ -1221,8 +1218,7 @@ func (t *Trie) deleteLeaf(parent *Node, key []byte, newParent = nil - isRoot := parent == t.root - err = registerDeletedMerkleValue(parent, isRoot, deletedMerkleValues) + err = t.registerDeletedMerkleValue(parent, deletedMerkleValues) if err != nil { return nil, fmt.Errorf("registering deleted merkle value: %w", err) } @@ -1246,7 +1242,7 @@ func (t *Trie) deleteBranch(branch *Node, key []byte, branch.SubValue = nil deleted = true var branchChildMerged bool - newParent, branchChildMerged, err = handleDeletion(branch, key, deletedMerkleValues) + newParent, branchChildMerged, err = t.handleDeletion(branch, key, deletedMerkleValues) if err != nil { return nil, false, 0, fmt.Errorf("handling deletion: %w", err) } @@ -1287,7 +1283,7 @@ func (t *Trie) deleteBranch(branch *Node, key []byte, branch.Descendants -= nodesRemoved branch.Children[childIndex] = newChild - newParent, branchChildMerged, err := handleDeletion(branch, key, deletedMerkleValues) + newParent, branchChildMerged, err := t.handleDeletion(branch, key, deletedMerkleValues) if err != nil { return nil, false, 0, fmt.Errorf("handling deletion: %w", err) } @@ -1305,7 +1301,7 @@ func (t *Trie) deleteBranch(branch *Node, key []byte, // In this first case, branchChildMerged is returned as true to keep track of the removal // of one node in callers. // If the branch has a value and no child, it will be changed into a leaf. -func handleDeletion(branch *Node, key []byte, +func (t *Trie) handleDeletion(branch *Node, key []byte, deletedMerkleValues map[string]struct{}) ( newNode *Node, branchChildMerged bool, err error) { childrenCount := 0 @@ -1343,8 +1339,7 @@ func handleDeletion(branch *Node, key []byte, const branchChildMerged = true childIndex := firstChildIndex child := branch.Children[firstChildIndex] - const isRoot = false // child so it cannot be the root node - err = registerDeletedMerkleValue(child, isRoot, deletedMerkleValues) + err = t.registerDeletedMerkleValue(child, deletedMerkleValues) if err != nil { return nil, false, fmt.Errorf("registering deleted merkle value: %w", err) } diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 60c7b7cf09..e25a52c4ee 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -279,7 +279,7 @@ func Test_Trie_prepForMutation(t *testing.T) { } } -func Test_registerDeletedMerkleValue(t *testing.T) { +func Test_Trie_registerDeletedMerkleValue(t *testing.T) { t.Parallel() someSmallNode := &Node{ @@ -288,21 +288,34 @@ func Test_registerDeletedMerkleValue(t *testing.T) { } testCases := map[string]struct { + trie Trie node *Node - isRoot bool pendingDeletedMerkleValues map[string]struct{} expectedPendingDeletedMerkleValues map[string]struct{} + expectedTrie Trie }{ "dirty node not registered": { node: &Node{Dirty: true}, }, "clean root node registered": { node: someSmallNode, - isRoot: true, + trie: Trie{root: someSmallNode}, pendingDeletedMerkleValues: map[string]struct{}{}, expectedPendingDeletedMerkleValues: map[string]struct{}{ "`Qm\v\xb6\xe1\xbb\xfb\x12\x93\xf1\xb2v\xea\x95\x05\xe9\xf4\xa4\xe7ُb\r\x05\x11^\v\x85'J\xe1": {}, }, + expectedTrie: Trie{ + root: &Node{ + Key: []byte{1}, + SubValue: []byte{2}, + MerkleValue: []byte{ + 0x60, 0x51, 0x6d, 0x0b, 0xb6, 0xe1, 0xbb, 0xfb, + 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x05, + 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0x0d, + 0x05, 0x11, 0x5e, 0x0b, 0x85, 0x27, 0x4a, 0xe1}, + Encoding: []byte{0x41, 0x01, 0x04, 0x02}, + }, + }, }, "clean node with inlined Merkle value not registered": { node: &Node{ @@ -331,10 +344,14 @@ func Test_registerDeletedMerkleValue(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - err := registerDeletedMerkleValue(testCase.node, testCase.isRoot, testCase.pendingDeletedMerkleValues) + trie := testCase.trie + + err := trie.registerDeletedMerkleValue(testCase.node, + testCase.pendingDeletedMerkleValues) require.NoError(t, err) assert.Equal(t, testCase.expectedPendingDeletedMerkleValues, testCase.pendingDeletedMerkleValues) + assert.Equal(t, testCase.expectedTrie, trie) }) } } @@ -4005,10 +4022,11 @@ func Test_Trie_deleteAtNode(t *testing.T) { } } -func Test_handleDeletion(t *testing.T) { +func Test_Trie_handleDeletion(t *testing.T) { t.Parallel() testCases := map[string]struct { + trie Trie branch *Node deletedKey []byte deletedMerkleValues map[string]struct{} @@ -4130,7 +4148,10 @@ func Test_handleDeletion(t *testing.T) { copy(expectedKey, testCase.deletedKey) } - newNode, branchChildMerged, err := handleDeletion( + trie := testCase.trie + expectedTrie := *trie.DeepCopy() + + newNode, branchChildMerged, err := trie.handleDeletion( testCase.branch, testCase.deletedKey, testCase.deletedMerkleValues) assert.ErrorIs(t, err, testCase.errSentinel) @@ -4142,6 +4163,7 @@ func Test_handleDeletion(t *testing.T) { assert.Equal(t, testCase.branchChildMerged, branchChildMerged) assert.Equal(t, expectedKey, testCase.deletedKey) assert.Equal(t, testCase.expectedDeletedMerkleValues, testCase.deletedMerkleValues) + assert.Equal(t, expectedTrie, trie) }) } }