Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return new node on modification #3193

Merged
merged 12 commits into from
Sep 21, 2022
150 changes: 93 additions & 57 deletions db/trie/mptrie/branchnode.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,53 +22,67 @@ type branchNode struct {
}

func newBranchNode(
mpt *merklePatriciaTrie,
cli client,
children map[byte]node,
indices *SortedList,
) (node, error) {
if len(children) == 0 {
return nil, errors.New("branch node children cannot be empty")
}
if indices == nil {
indices = NewSortedList(children)
}
bnode := &branchNode{
cacheNode: cacheNode{
mpt: mpt,
dirty: true,
},
children: children,
indices: NewSortedList(children),
indices: indices,
}
bnode.cacheNode.serializable = bnode
if len(bnode.children) != 0 {
if !mpt.async {
return bnode.store()
if !cli.asyncMode() {
if err := bnode.store(cli); err != nil {
return nil, err
}
}
}
return bnode, nil
}

func newEmptyRootBranchNode(mpt *merklePatriciaTrie) *branchNode {
func newRootBranchNode(cli client, children map[byte]node, indices *SortedList, dirty bool) (branch, error) {
if indices == nil {
indices = NewSortedList(children)
}
bnode := &branchNode{
cacheNode: cacheNode{
mpt: mpt,
dirty: dirty,
},
children: make(map[byte]node),
indices: NewSortedList(nil),
children: children,
indices: indices,
isRoot: true,
}
bnode.cacheNode.serializable = bnode
return bnode
if len(bnode.children) != 0 {
if !cli.asyncMode() {
if err := bnode.store(cli); err != nil {
return nil, err
}
}
}
return bnode, nil
}

func newBranchNodeFromProtoPb(pb *triepb.BranchPb, mpt *merklePatriciaTrie, hashVal []byte) *branchNode {
func newBranchNodeFromProtoPb(pb *triepb.BranchPb, hashVal []byte) *branchNode {
bnode := &branchNode{
cacheNode: cacheNode{
mpt: mpt,
hashVal: hashVal,
dirty: false,
},
children: make(map[byte]node, len(pb.Branches)),
}
for _, n := range pb.Branches {
bnode.children[byte(n.Index)] = newHashNode(mpt, n.Path)
bnode.children[byte(n.Index)] = newHashNode(n.Path)
}
bnode.indices = NewSortedList(bnode.children)
bnode.cacheNode.serializable = bnode
Expand All @@ -87,24 +101,24 @@ func (b *branchNode) Children() []node {
return ret
}

func (b *branchNode) Delete(key keyType, offset uint8) (node, error) {
func (b *branchNode) Delete(cli client, key keyType, offset uint8) (node, error) {
offsetKey := key[offset]
child, err := b.child(offsetKey)
if err != nil {
return nil, err
}
newChild, err := child.Delete(key, offset+1)
newChild, err := child.Delete(cli, key, offset+1)
if err != nil {
return nil, err
}
if newChild != nil || b.isRoot {
return b.updateChild(offsetKey, newChild, false)
return b.updateChild(cli, offsetKey, newChild)
}
switch len(b.children) {
case 1:
panic("branch shouldn't have 0 child after deleting")
case 2:
if err := b.delete(); err != nil {
if err := b.delete(cli); err != nil {
return nil, err
}
var orphan node
Expand All @@ -120,65 +134,63 @@ func (b *branchNode) Delete(key keyType, offset uint8) (node, error) {
panic("unexpected branch status")
}
if hn, ok := orphan.(*hashNode); ok {
if orphan, err = hn.LoadNode(); err != nil {
if orphan, err = hn.LoadNode(cli); err != nil {
return nil, err
}
}
switch node := orphan.(type) {
case *extensionNode:
return node.updatePath(
cli,
append([]byte{orphanKey}, node.path...),
false,
)
case *leafNode:
return node, nil
default:
return newExtensionNode(b.mpt, []byte{orphanKey}, node)
return newExtensionNode(cli, []byte{orphanKey}, node)
}
default:
return b.updateChild(offsetKey, newChild, false)
return b.updateChild(cli, offsetKey, newChild)
}
}

func (b *branchNode) Upsert(key keyType, offset uint8, value []byte) (node, error) {
func (b *branchNode) Upsert(cli client, key keyType, offset uint8, value []byte) (node, error) {
var newChild node
offsetKey := key[offset]
child, err := b.child(offsetKey)
switch errors.Cause(err) {
case nil:
newChild, err = child.Upsert(key, offset+1, value) // look for next key offset
newChild, err = child.Upsert(cli, key, offset+1, value) // look for next key offset
case trie.ErrNotExist:
newChild, err = newLeafNode(b.mpt, key, value)
newChild, err = newLeafNode(cli, key, value)
}
if err != nil {
return nil, err
}

return b.updateChild(offsetKey, newChild, true)
return b.updateChild(cli, offsetKey, newChild)
}

func (b *branchNode) Search(key keyType, offset uint8) (node, error) {
func (b *branchNode) Search(cli client, key keyType, offset uint8) (node, error) {
child, err := b.child(key[offset])
if err != nil {
return nil, err
}
return child.Search(key, offset+1)
return child.Search(cli, key, offset+1)
}

func (b *branchNode) proto(flush bool) (proto.Message, error) {
func (b *branchNode) proto(cli client, flush bool) (proto.Message, error) {
nodes := []*triepb.BranchNodePb{}
for _, idx := range b.indices.List() {
c := b.children[idx]
if flush {
if sn, ok := c.(serializable); ok {
var err error
c, err = sn.store()
if err != nil {
if err := sn.store(cli); err != nil {
return nil, err
}
}
}
h, err := c.Hash()
h, err := c.Hash(cli)
if err != nil {
return nil, err
}
Expand All @@ -199,48 +211,72 @@ func (b *branchNode) child(key byte) (node, error) {
return c, nil
}

func (b *branchNode) Flush() error {
func (b *branchNode) Flush(cli client) error {
if !b.dirty {
return nil
}
for _, idx := range b.indices.List() {
if err := b.children[idx].Flush(); err != nil {
if err := b.children[idx].Flush(cli); err != nil {
return err
}
}
_, err := b.store()
return err

return b.store(cli)
}

func (b *branchNode) updateChild(key byte, child node, hashnode bool) (node, error) {
if err := b.delete(); err != nil {
func (b *branchNode) updateChild(cli client, key byte, child node) (node, error) {
if err := b.delete(cli); err != nil {
return nil, err
}
var indices *SortedList
// update branchnode with new child
children := make(map[byte]node, len(b.children))
for k, v := range b.children {
children[k] = v
}
if child == nil {
delete(b.children, key)
b.indices.Delete(key)
delete(children, key)
if b.indices.sorted {
indices = b.indices.Clone()
indices.Delete(key)
}
} else {
if _, exist := b.children[key]; !exist {
b.indices.Insert(key)
children[key] = child
if b.indices.sorted {
indices = b.indices.Clone()
indices.Insert(key)
}
b.children[key] = child
}
b.dirty = true
if len(b.children) != 0 {
if !b.mpt.async {
hn, err := b.store()
if err != nil {
return nil, err
}
if !b.isRoot && hashnode {
return hn, nil // return hashnode
}
}
} else {
if _, err := b.hash(false); err != nil {

if b.isRoot {
bn, err := newRootBranchNode(cli, children, indices, true)
if err != nil {
return nil, err
}
return bn, nil
}
return newBranchNode(cli, children, indices)
}

func (b *branchNode) Clone() (branch, error) {
children := make(map[byte]node, len(b.children))
for key, child := range b.children {
children[key] = child
}
hashVal := make([]byte, len(b.hashVal))
copy(hashVal, b.hashVal)
ser := make([]byte, len(b.ser))
copy(ser, b.ser)
clone := &branchNode{
cacheNode: cacheNode{
dirty: b.dirty,
hashVal: hashVal,
ser: ser,
},
children: children,
indices: b.indices.Clone(),
isRoot: b.isRoot,
}
return b, nil // return branchnode
clone.cacheNode.serializable = clone
return clone, nil
}
92 changes: 92 additions & 0 deletions db/trie/mptrie/branchnode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2020 IoTeX Foundation
// This is an alpha (internal) release and is not suitable for production. This source code is provided 'as is' and no
// warranties are given as to title or non-infringement, merchantability or fitness for purpose and, to the extent
// permitted by law, all liability for your use of the code is disclaimed. This source code is governed by Apache
// License 2.0 that can be found in the LICENSE file.

package mptrie

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"
)

func equals(bn *branchNode, clone *branchNode) bool {
if bn.isRoot != clone.isRoot {
return false
}
if bn.dirty != clone.dirty {
return false
}
if !bytes.Equal(bn.hashVal, clone.hashVal) || !bytes.Equal(bn.ser, clone.ser) {
return false
}
if len(bn.children) != len(clone.children) {
return false
}
for key, child := range clone.children {
if bn.children[key] != child {
return false
}
}
indices := bn.indices.List()
cloneIndices := clone.indices.List()
if len(indices) != len(cloneIndices) {
return false
}
for i, value := range cloneIndices {
if indices[i] != value {
return false
}
}
return true
}

func TestBranchNodeClone(t *testing.T) {
require := require.New(t)
t.Run("dirty empty root", func(t *testing.T) {
children := map[byte]node{}
indices := NewSortedList(children)
node, err := newRootBranchNode(nil, children, indices, true)
require.NoError(err)
bn, ok := node.(*branchNode)
require.True(ok)
clone, err := node.Clone()
require.NoError(err)
cbn, ok := clone.(*branchNode)
require.True(ok)
equals(bn, cbn)
})
t.Run("clean empty root", func(t *testing.T) {
children := map[byte]node{}
indices := NewSortedList(children)
node, err := newRootBranchNode(nil, children, indices, true)
require.NoError(err)
bn, ok := node.(*branchNode)
require.True(ok)
clone, err := node.Clone()
require.NoError(err)
cbn, ok := clone.(*branchNode)
require.True(ok)
equals(bn, cbn)
})
t.Run("normal branch node", func(t *testing.T) {
children := map[byte]node{}
children['a'] = &hashNode{hashVal: []byte("a")}
children['b'] = &hashNode{hashVal: []byte("b")}
children['c'] = &hashNode{hashVal: []byte("c")}
children['d'] = &hashNode{hashVal: []byte("d")}
indices := NewSortedList(children)
node, err := newBranchNode(&merklePatriciaTrie{async: true}, children, indices)
require.NoError(err)
bn, ok := node.(*branchNode)
require.True(ok)
clone, err := bn.Clone()
require.NoError(err)
cbn, ok := clone.(*branchNode)
require.True(ok)
equals(bn, cbn)
})
}
Loading