diff --git a/db/trie/mptrie/branchnode.go b/db/trie/mptrie/branchnode.go index 4a1a1497f7..4460a25601 100644 --- a/db/trie/mptrie/branchnode.go +++ b/db/trie/mptrie/branchnode.go @@ -22,53 +22,67 @@ type branchNode struct { } func newBranchNode( - mpt *merklePatriciaTrie, + cli client, children map[byte]node, + indices *SortedList, ) (node, error) { if len(children) == 0 { return nil, errors.New("branch node children cannot be empty") } + if indices == nil { + indices = NewSortedList(children) + } bnode := &branchNode{ cacheNode: cacheNode{ - mpt: mpt, dirty: true, }, children: children, - indices: NewSortedList(children), + indices: indices, } bnode.cacheNode.serializable = bnode if len(bnode.children) != 0 { - if !mpt.async { - return bnode.store() + if !cli.asyncMode() { + if err := bnode.store(cli); err != nil { + return nil, err + } } } return bnode, nil } -func newEmptyRootBranchNode(mpt *merklePatriciaTrie) *branchNode { +func newRootBranchNode(cli client, children map[byte]node, indices *SortedList, dirty bool) (branch, error) { + if indices == nil { + indices = NewSortedList(children) + } bnode := &branchNode{ cacheNode: cacheNode{ - mpt: mpt, + dirty: dirty, }, - children: make(map[byte]node), - indices: NewSortedList(nil), + children: children, + indices: indices, isRoot: true, } bnode.cacheNode.serializable = bnode - return bnode + if len(bnode.children) != 0 { + if !cli.asyncMode() { + if err := bnode.store(cli); err != nil { + return nil, err + } + } + } + return bnode, nil } -func newBranchNodeFromProtoPb(pb *triepb.BranchPb, mpt *merklePatriciaTrie, hashVal []byte) *branchNode { +func newBranchNodeFromProtoPb(pb *triepb.BranchPb, hashVal []byte) *branchNode { bnode := &branchNode{ cacheNode: cacheNode{ - mpt: mpt, hashVal: hashVal, dirty: false, }, children: make(map[byte]node, len(pb.Branches)), } for _, n := range pb.Branches { - bnode.children[byte(n.Index)] = newHashNode(mpt, n.Path) + bnode.children[byte(n.Index)] = newHashNode(n.Path) } bnode.indices = NewSortedList(bnode.children) bnode.cacheNode.serializable = bnode @@ -87,24 +101,24 @@ func (b *branchNode) Children() []node { return ret } -func (b *branchNode) Delete(key keyType, offset uint8) (node, error) { +func (b *branchNode) Delete(cli client, key keyType, offset uint8) (node, error) { offsetKey := key[offset] child, err := b.child(offsetKey) if err != nil { return nil, err } - newChild, err := child.Delete(key, offset+1) + newChild, err := child.Delete(cli, key, offset+1) if err != nil { return nil, err } if newChild != nil || b.isRoot { - return b.updateChild(offsetKey, newChild, false) + return b.updateChild(cli, offsetKey, newChild) } switch len(b.children) { case 1: panic("branch shouldn't have 0 child after deleting") case 2: - if err := b.delete(); err != nil { + if err := b.delete(cli); err != nil { return nil, err } var orphan node @@ -120,65 +134,63 @@ func (b *branchNode) Delete(key keyType, offset uint8) (node, error) { panic("unexpected branch status") } if hn, ok := orphan.(*hashNode); ok { - if orphan, err = hn.LoadNode(); err != nil { + if orphan, err = hn.LoadNode(cli); err != nil { return nil, err } } switch node := orphan.(type) { case *extensionNode: return node.updatePath( + cli, append([]byte{orphanKey}, node.path...), - false, ) case *leafNode: return node, nil default: - return newExtensionNode(b.mpt, []byte{orphanKey}, node) + return newExtensionNode(cli, []byte{orphanKey}, node) } default: - return b.updateChild(offsetKey, newChild, false) + return b.updateChild(cli, offsetKey, newChild) } } -func (b *branchNode) Upsert(key keyType, offset uint8, value []byte) (node, error) { +func (b *branchNode) Upsert(cli client, key keyType, offset uint8, value []byte) (node, error) { var newChild node offsetKey := key[offset] child, err := b.child(offsetKey) switch errors.Cause(err) { case nil: - newChild, err = child.Upsert(key, offset+1, value) // look for next key offset + newChild, err = child.Upsert(cli, key, offset+1, value) // look for next key offset case trie.ErrNotExist: - newChild, err = newLeafNode(b.mpt, key, value) + newChild, err = newLeafNode(cli, key, value) } if err != nil { return nil, err } - return b.updateChild(offsetKey, newChild, true) + return b.updateChild(cli, offsetKey, newChild) } -func (b *branchNode) Search(key keyType, offset uint8) (node, error) { +func (b *branchNode) Search(cli client, key keyType, offset uint8) (node, error) { child, err := b.child(key[offset]) if err != nil { return nil, err } - return child.Search(key, offset+1) + return child.Search(cli, key, offset+1) } -func (b *branchNode) proto(flush bool) (proto.Message, error) { +func (b *branchNode) proto(cli client, flush bool) (proto.Message, error) { nodes := []*triepb.BranchNodePb{} for _, idx := range b.indices.List() { c := b.children[idx] if flush { if sn, ok := c.(serializable); ok { - var err error - c, err = sn.store() - if err != nil { + if err := sn.store(cli); err != nil { return nil, err } } } - h, err := c.Hash() + h, err := c.Hash(cli) if err != nil { return nil, err } @@ -199,48 +211,72 @@ func (b *branchNode) child(key byte) (node, error) { return c, nil } -func (b *branchNode) Flush() error { +func (b *branchNode) Flush(cli client) error { if !b.dirty { return nil } for _, idx := range b.indices.List() { - if err := b.children[idx].Flush(); err != nil { + if err := b.children[idx].Flush(cli); err != nil { return err } } - _, err := b.store() - return err + + return b.store(cli) } -func (b *branchNode) updateChild(key byte, child node, hashnode bool) (node, error) { - if err := b.delete(); err != nil { +func (b *branchNode) updateChild(cli client, key byte, child node) (node, error) { + if err := b.delete(cli); err != nil { return nil, err } + var indices *SortedList // update branchnode with new child + children := make(map[byte]node, len(b.children)) + for k, v := range b.children { + children[k] = v + } if child == nil { - delete(b.children, key) - b.indices.Delete(key) + delete(children, key) + if b.indices.sorted { + indices = b.indices.Clone() + indices.Delete(key) + } } else { - if _, exist := b.children[key]; !exist { - b.indices.Insert(key) + children[key] = child + if b.indices.sorted { + indices = b.indices.Clone() + indices.Insert(key) } - b.children[key] = child } - b.dirty = true - if len(b.children) != 0 { - if !b.mpt.async { - hn, err := b.store() - if err != nil { - return nil, err - } - if !b.isRoot && hashnode { - return hn, nil // return hashnode - } - } - } else { - if _, err := b.hash(false); err != nil { + + if b.isRoot { + bn, err := newRootBranchNode(cli, children, indices, true) + if err != nil { return nil, err } + return bn, nil + } + return newBranchNode(cli, children, indices) +} + +func (b *branchNode) Clone() (branch, error) { + children := make(map[byte]node, len(b.children)) + for key, child := range b.children { + children[key] = child + } + hashVal := make([]byte, len(b.hashVal)) + copy(hashVal, b.hashVal) + ser := make([]byte, len(b.ser)) + copy(ser, b.ser) + clone := &branchNode{ + cacheNode: cacheNode{ + dirty: b.dirty, + hashVal: hashVal, + ser: ser, + }, + children: children, + indices: b.indices.Clone(), + isRoot: b.isRoot, } - return b, nil // return branchnode + clone.cacheNode.serializable = clone + return clone, nil } diff --git a/db/trie/mptrie/branchnode_test.go b/db/trie/mptrie/branchnode_test.go new file mode 100644 index 0000000000..ff693b4fa7 --- /dev/null +++ b/db/trie/mptrie/branchnode_test.go @@ -0,0 +1,92 @@ +// Copyright (c) 2020 IoTeX Foundation +// This is an alpha (internal) release and is not suitable for production. This source code is provided 'as is' and no +// warranties are given as to title or non-infringement, merchantability or fitness for purpose and, to the extent +// permitted by law, all liability for your use of the code is disclaimed. This source code is governed by Apache +// License 2.0 that can be found in the LICENSE file. + +package mptrie + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func equals(bn *branchNode, clone *branchNode) bool { + if bn.isRoot != clone.isRoot { + return false + } + if bn.dirty != clone.dirty { + return false + } + if !bytes.Equal(bn.hashVal, clone.hashVal) || !bytes.Equal(bn.ser, clone.ser) { + return false + } + if len(bn.children) != len(clone.children) { + return false + } + for key, child := range clone.children { + if bn.children[key] != child { + return false + } + } + indices := bn.indices.List() + cloneIndices := clone.indices.List() + if len(indices) != len(cloneIndices) { + return false + } + for i, value := range cloneIndices { + if indices[i] != value { + return false + } + } + return true +} + +func TestBranchNodeClone(t *testing.T) { + require := require.New(t) + t.Run("dirty empty root", func(t *testing.T) { + children := map[byte]node{} + indices := NewSortedList(children) + node, err := newRootBranchNode(nil, children, indices, true) + require.NoError(err) + bn, ok := node.(*branchNode) + require.True(ok) + clone, err := node.Clone() + require.NoError(err) + cbn, ok := clone.(*branchNode) + require.True(ok) + equals(bn, cbn) + }) + t.Run("clean empty root", func(t *testing.T) { + children := map[byte]node{} + indices := NewSortedList(children) + node, err := newRootBranchNode(nil, children, indices, true) + require.NoError(err) + bn, ok := node.(*branchNode) + require.True(ok) + clone, err := node.Clone() + require.NoError(err) + cbn, ok := clone.(*branchNode) + require.True(ok) + equals(bn, cbn) + }) + t.Run("normal branch node", func(t *testing.T) { + children := map[byte]node{} + children['a'] = &hashNode{hashVal: []byte("a")} + children['b'] = &hashNode{hashVal: []byte("b")} + children['c'] = &hashNode{hashVal: []byte("c")} + children['d'] = &hashNode{hashVal: []byte("d")} + indices := NewSortedList(children) + node, err := newBranchNode(&merklePatriciaTrie{async: true}, children, indices) + require.NoError(err) + bn, ok := node.(*branchNode) + require.True(ok) + clone, err := bn.Clone() + require.NoError(err) + cbn, ok := clone.(*branchNode) + require.True(ok) + equals(bn, cbn) + }) +} diff --git a/db/trie/mptrie/cachenode.go b/db/trie/mptrie/cachenode.go index bf0b09e57c..edf23fe4eb 100644 --- a/db/trie/mptrie/cachenode.go +++ b/db/trie/mptrie/cachenode.go @@ -13,20 +13,19 @@ import ( type cacheNode struct { dirty bool serializable - mpt *merklePatriciaTrie hashVal []byte ser []byte } -func (cn *cacheNode) Hash() ([]byte, error) { - return cn.hash(false) +func (cn *cacheNode) Hash(cli client) ([]byte, error) { + return cn.hash(cli, false) } -func (cn *cacheNode) hash(flush bool) ([]byte, error) { - if cn.hashVal != nil { +func (cn *cacheNode) hash(cli client, flush bool) ([]byte, error) { + if len(cn.hashVal) != 0 { return cn.hashVal, nil } - pb, err := cn.proto(flush) + pb, err := cn.proto(cli, flush) if err != nil { return nil, err } @@ -36,18 +35,18 @@ func (cn *cacheNode) hash(flush bool) ([]byte, error) { } cn.ser = ser - cn.hashVal = cn.mpt.hashFunc(ser) + cn.hashVal = cli.hash(ser) return cn.hashVal, nil } -func (cn *cacheNode) delete() error { +func (cn *cacheNode) delete(cli client) error { if !cn.dirty { - h, err := cn.hash(false) + h, err := cn.hash(cli, false) if err != nil { return err } - if err := cn.mpt.deleteNode(h); err != nil { + if err := cli.deleteNode(h); err != nil { return err } } @@ -57,16 +56,18 @@ func (cn *cacheNode) delete() error { return nil } -func (cn *cacheNode) store() (node, error) { - h, err := cn.hash(true) +func (cn *cacheNode) store(cli client) error { + if !cn.dirty { + return nil + } + h, err := cn.hash(cli, true) if err != nil { - return nil, err + return err } - if cn.dirty { - if err := cn.mpt.putNode(h, cn.ser); err != nil { - return nil, err - } - cn.dirty = false + if err := cli.putNode(h, cn.ser); err != nil { + return err } - return newHashNode(cn.mpt, h), nil + cn.dirty = false + + return nil } diff --git a/db/trie/mptrie/extensionnode.go b/db/trie/mptrie/extensionnode.go index dfce572c36..9bf28d1eae 100644 --- a/db/trie/mptrie/extensionnode.go +++ b/db/trie/mptrie/extensionnode.go @@ -21,13 +21,12 @@ type extensionNode struct { } func newExtensionNode( - mpt *merklePatriciaTrie, + cli client, path []byte, child node, ) (node, error) { e := &extensionNode{ cacheNode: cacheNode{ - mpt: mpt, dirty: true, }, path: path, @@ -35,80 +34,82 @@ func newExtensionNode( } e.cacheNode.serializable = e - if !mpt.async { - return e.store() + if !cli.asyncMode() { + if err := e.store(cli); err != nil { + return nil, err + } } return e, nil } -func newExtensionNodeFromProtoPb(pb *triepb.ExtendPb, mpt *merklePatriciaTrie, hashVal []byte) *extensionNode { +func newExtensionNodeFromProtoPb(pb *triepb.ExtendPb, hashVal []byte) *extensionNode { e := &extensionNode{ cacheNode: cacheNode{ - mpt: mpt, hashVal: hashVal, dirty: false, }, path: pb.Path, - child: newHashNode(mpt, pb.Value), + child: newHashNode(pb.Value), } e.cacheNode.serializable = e return e } -func (e *extensionNode) Delete(key keyType, offset uint8) (node, error) { +func (e *extensionNode) Delete(cli client, key keyType, offset uint8) (node, error) { matched := e.commonPrefixLength(key[offset:]) if matched != uint8(len(e.path)) { return nil, trie.ErrNotExist } - newChild, err := e.child.Delete(key, offset+matched) + newChild, err := e.child.Delete(cli, key, offset+matched) if err != nil { return nil, err } if newChild == nil { - return nil, e.delete() + return nil, e.delete(cli) } if hn, ok := newChild.(*hashNode); ok { - if newChild, err = hn.LoadNode(); err != nil { + if newChild, err = hn.LoadNode(cli); err != nil { return nil, err } } switch node := newChild.(type) { case *extensionNode: - return node.updatePath(append(e.path, node.path...), false) + return node.updatePath(cli, append(e.path, node.path...)) case *branchNode: - return e.updateChild(node, false) + return e.updateChild(cli, node) default: - if err := e.delete(); err != nil { + if err := e.delete(cli); err != nil { return nil, err } return node, nil } } -func (e *extensionNode) Upsert(key keyType, offset uint8, value []byte) (node, error) { +func (e *extensionNode) Upsert(cli client, key keyType, offset uint8, value []byte) (node, error) { matched := e.commonPrefixLength(key[offset:]) if matched == uint8(len(e.path)) { - newChild, err := e.child.Upsert(key, offset+matched, value) + newChild, err := e.child.Upsert(cli, key, offset+matched, value) if err != nil { return nil, err } - return e.updateChild(newChild, true) + return e.updateChild(cli, newChild) } eb := e.path[matched] - enode, err := e.updatePath(e.path[matched+1:], true) + enode, err := e.updatePath(cli, e.path[matched+1:]) if err != nil { return nil, err } - lnode, err := newLeafNode(e.mpt, key, value) + lnode, err := newLeafNode(cli, key, value) if err != nil { return nil, err } bnode, err := newBranchNode( - e.mpt, + cli, map[byte]node{ eb: enode, key[offset+matched]: lnode, }, + nil, ) if err != nil { return nil, err @@ -116,28 +117,27 @@ func (e *extensionNode) Upsert(key keyType, offset uint8, value []byte) (node, e if matched == 0 { return bnode, nil } - return newExtensionNode(e.mpt, key[offset:offset+matched], bnode) + return newExtensionNode(cli, key[offset:offset+matched], bnode) } -func (e *extensionNode) Search(key keyType, offset uint8) (node, error) { +func (e *extensionNode) Search(cli client, key keyType, offset uint8) (node, error) { matched := e.commonPrefixLength(key[offset:]) if matched != uint8(len(e.path)) { return nil, trie.ErrNotExist } - return e.child.Search(key, offset+matched) + return e.child.Search(cli, key, offset+matched) } -func (e *extensionNode) proto(flush bool) (proto.Message, error) { +func (e *extensionNode) proto(cli client, flush bool) (proto.Message, error) { if flush { if sn, ok := e.child.(serializable); ok { - _, err := sn.store() - if err != nil { + if err := sn.store(cli); err != nil { return nil, err } } } - h, err := e.child.Hash() + h, err := e.child.Hash(cli) if err != nil { return nil, err } @@ -159,52 +159,30 @@ func (e *extensionNode) commonPrefixLength(key []byte) uint8 { return commonPrefixLength(e.path, key) } -func (e *extensionNode) Flush() error { +func (e *extensionNode) Flush(cli client) error { if !e.dirty { return nil } - if err := e.child.Flush(); err != nil { + if err := e.child.Flush(cli); err != nil { return err } - _, err := e.store() - return err + + return e.store(cli) } -func (e *extensionNode) updatePath(path []byte, hashnode bool) (node, error) { - if err := e.delete(); err != nil { +func (e *extensionNode) updatePath(cli client, path []byte) (node, error) { + if err := e.delete(cli); err != nil { return nil, err } - e.path = path - e.dirty = true - - if !e.mpt.async { - hn, err := e.store() - if err != nil { - return nil, err - } - if hashnode { - return hn, nil - } - } - return e, nil + return newExtensionNode(cli, path, e.child) } -func (e *extensionNode) updateChild(newChild node, hashnode bool) (node, error) { - err := e.delete() +func (e *extensionNode) updateChild(cli client, newChild node) (node, error) { + err := e.delete(cli) if err != nil { return nil, err } - e.child = newChild - e.dirty = true - - if !e.mpt.async { - hn, err := e.store() - if err != nil { - return nil, err - } - if hashnode { - return hn, nil - } - } - return e, nil + path := make([]byte, len(e.path)) + copy(path, e.path) + return newExtensionNode(cli, path, newChild) } diff --git a/db/trie/mptrie/hashnode.go b/db/trie/mptrie/hashnode.go index 86b79c966c..32c9fbbf5b 100644 --- a/db/trie/mptrie/hashnode.go +++ b/db/trie/mptrie/hashnode.go @@ -8,53 +8,52 @@ package mptrie type hashNode struct { node - mpt *merklePatriciaTrie hashVal []byte } -func newHashNode(mpt *merklePatriciaTrie, ha []byte) *hashNode { - return &hashNode{mpt: mpt, hashVal: ha} +func newHashNode(ha []byte) *hashNode { + return &hashNode{hashVal: ha} } -func (h *hashNode) Flush() error { +func (h *hashNode) Flush(_ client) error { return nil } -func (h *hashNode) Delete(key keyType, offset uint8) (node, error) { - n, err := h.loadNode() +func (h *hashNode) Delete(cli client, key keyType, offset uint8) (node, error) { + n, err := h.loadNode(cli) if err != nil { return nil, err } - return n.Delete(key, offset) + return n.Delete(cli, key, offset) } -func (h *hashNode) Upsert(key keyType, offset uint8, value []byte) (node, error) { - n, err := h.loadNode() +func (h *hashNode) Upsert(cli client, key keyType, offset uint8, value []byte) (node, error) { + n, err := h.loadNode(cli) if err != nil { return nil, err } - return n.Upsert(key, offset, value) + return n.Upsert(cli, key, offset, value) } -func (h *hashNode) Search(key keyType, offset uint8) (node, error) { - node, err := h.loadNode() +func (h *hashNode) Search(cli client, key keyType, offset uint8) (node, error) { + node, err := h.loadNode(cli) if err != nil { return nil, err } - return node.Search(key, offset) + return node.Search(cli, key, offset) } -func (h *hashNode) LoadNode() (node, error) { - return h.loadNode() +func (h *hashNode) LoadNode(cli client) (node, error) { + return h.loadNode(cli) } -func (h *hashNode) loadNode() (node, error) { - return h.mpt.loadNode(h.hashVal) +func (h *hashNode) loadNode(cli client) (node, error) { + return cli.loadNode(h.hashVal) } -func (h *hashNode) Hash() ([]byte, error) { +func (h *hashNode) Hash(_ client) ([]byte, error) { return h.hashVal, nil } diff --git a/db/trie/mptrie/leafiterator.go b/db/trie/mptrie/leafiterator.go index 920bd74cb7..4bedb86ccb 100644 --- a/db/trie/mptrie/leafiterator.go +++ b/db/trie/mptrie/leafiterator.go @@ -14,7 +14,7 @@ import ( // LeafIterator defines an iterator to go through all the leaves under given node type LeafIterator struct { - mpt *merklePatriciaTrie + cli client stack []node } @@ -26,7 +26,7 @@ func NewLeafIterator(tr trie.Trie) (trie.Iterator, error) { } stack := []node{mpt.root} - return &LeafIterator{mpt: mpt, stack: stack}, nil + return &LeafIterator{cli: mpt, stack: stack}, nil } // Next moves iterator to next node @@ -36,7 +36,7 @@ func (li *LeafIterator) Next() ([]byte, []byte, error) { node := li.stack[size-1] li.stack = li.stack[:size-1] if hn, ok := node.(*hashNode); ok { - node, err := hn.LoadNode() + node, err := hn.LoadNode(li.cli) if err != nil { return nil, nil, err } diff --git a/db/trie/mptrie/leafiterator_test.go b/db/trie/mptrie/leafiterator_test.go index b5bbd07054..3c77e58f45 100644 --- a/db/trie/mptrie/leafiterator_test.go +++ b/db/trie/mptrie/leafiterator_test.go @@ -30,12 +30,10 @@ func TestIterator(t *testing.T) { mpt, err := New(KVStoreOption(memStore), KeyLengthOption(5), AsyncOption()) require.NoError(err) - err = mpt.Start(context.Background()) - require.NoError(err) + require.NoError(mpt.Start(context.Background())) for _, item := range items { - err = mpt.Upsert([]byte(item.k), []byte(item.v)) - require.NoError(err) + require.NoError(mpt.Upsert([]byte(item.k), []byte(item.v))) } iter, err := NewLeafIterator(mpt) diff --git a/db/trie/mptrie/leafnode.go b/db/trie/mptrie/leafnode.go index b56fca4f58..49647ff11f 100644 --- a/db/trie/mptrie/leafnode.go +++ b/db/trie/mptrie/leafnode.go @@ -22,29 +22,29 @@ type leafNode struct { } func newLeafNode( - mpt *merklePatriciaTrie, + cli client, key keyType, value []byte, ) (node, error) { l := &leafNode{ cacheNode: cacheNode{ - mpt: mpt, dirty: true, }, key: key, value: value, } l.cacheNode.serializable = l - if !mpt.async { - return l.store() + if !cli.asyncMode() { + if err := l.store(cli); err != nil { + return nil, err + } } return l, nil } -func newLeafNodeFromProtoPb(pb *triepb.LeafPb, mpt *merklePatriciaTrie, hashVal []byte) *leafNode { +func newLeafNodeFromProtoPb(pb *triepb.LeafPb, hashVal []byte) *leafNode { l := &leafNode{ cacheNode: cacheNode{ - mpt: mpt, hashVal: hashVal, dirty: false, }, @@ -63,38 +63,39 @@ func (l *leafNode) Value() []byte { return l.value } -func (l *leafNode) Delete(key keyType, offset uint8) (node, error) { +func (l *leafNode) Delete(cli client, key keyType, offset uint8) (node, error) { if !bytes.Equal(l.key[offset:], key[offset:]) { return nil, trie.ErrNotExist } - return nil, l.delete() + return nil, l.delete(cli) } -func (l *leafNode) Upsert(key keyType, offset uint8, value []byte) (node, error) { +func (l *leafNode) Upsert(cli client, key keyType, offset uint8, value []byte) (node, error) { matched := commonPrefixLength(l.key[offset:], key[offset:]) if offset+matched == uint8(len(key)) { - return l.updateValue(value) + if err := l.delete(cli); err != nil { + return nil, err + } + return newLeafNode(cli, key, value) } // split into another leaf node and create branch/extension node - newl, err := newLeafNode(l.mpt, key, value) + newl, err := newLeafNode(cli, key, value) if err != nil { return nil, err } - var oldLeaf node - if !l.mpt.async { - oldLeaf, err = l.store() - if err != nil { + oldLeaf := l + if !cli.asyncMode() { + if err := l.store(cli); err != nil { return nil, err } - } else { - oldLeaf = l } bnode, err := newBranchNode( - l.mpt, + cli, map[byte]node{ key[offset+matched]: newl, l.key[offset+matched]: oldLeaf, }, + nil, ) if err != nil { return nil, err @@ -103,10 +104,10 @@ func (l *leafNode) Upsert(key keyType, offset uint8, value []byte) (node, error) return bnode, nil } - return newExtensionNode(l.mpt, l.key[offset:offset+matched], bnode) + return newExtensionNode(cli, l.key[offset:offset+matched], bnode) } -func (l *leafNode) Search(key keyType, offset uint8) (node, error) { +func (l *leafNode) Search(_ client, key keyType, offset uint8) (node, error) { if !bytes.Equal(l.key[offset:], key[offset:]) { return nil, trie.ErrNotExist } @@ -114,7 +115,7 @@ func (l *leafNode) Search(key keyType, offset uint8) (node, error) { return l, nil } -func (l *leafNode) proto(_ bool) (proto.Message, error) { +func (l *leafNode) proto(_ client, _ bool) (proto.Message, error) { return &triepb.NodePb{ Node: &triepb.NodePb_Leaf{ Leaf: &triepb.LeafPb{ @@ -125,22 +126,6 @@ func (l *leafNode) proto(_ bool) (proto.Message, error) { }, nil } -func (l *leafNode) Flush() error { - if !l.dirty { - return nil - } - _, err := l.store() - return err -} - -func (l *leafNode) updateValue(value []byte) (node, error) { - if err := l.delete(); err != nil { - return nil, err - } - l.value = value - l.dirty = true - if !l.mpt.async { - return l.store() - } - return l, nil +func (l *leafNode) Flush(cli client) error { + return l.store(cli) } diff --git a/db/trie/mptrie/merklepatriciatrie.go b/db/trie/mptrie/merklepatriciatrie.go index 6cad155c14..fde787dc3b 100644 --- a/db/trie/mptrie/merklepatriciatrie.go +++ b/db/trie/mptrie/merklepatriciatrie.go @@ -109,7 +109,11 @@ func (mpt *merklePatriciaTrie) Start(ctx context.Context) error { mpt.mutex.Lock() defer mpt.mutex.Unlock() - emptyRootHash, err := newEmptyRootBranchNode(mpt).Hash() + emptyRoot, err := newRootBranchNode(mpt, nil, nil, false) + if err != nil { + return err + } + emptyRootHash, err := emptyRoot.Hash(mpt) if err != nil { return err } @@ -127,10 +131,10 @@ func (mpt *merklePatriciaTrie) Stop(_ context.Context) error { func (mpt *merklePatriciaTrie) RootHash() ([]byte, error) { if mpt.async { - if err := mpt.root.Flush(); err != nil { + if err := mpt.root.Flush(mpt); err != nil { return nil, err } - h, err := mpt.root.Hash() + h, err := mpt.root.Hash(mpt) if err != nil { return nil, err } @@ -165,7 +169,7 @@ func (mpt *merklePatriciaTrie) Get(key []byte) ([]byte, error) { if err != nil { return nil, err } - t, err := mpt.root.Search(kt, 0) + t, err := mpt.root.Search(mpt, kt, 0) if err != nil { return nil, err } @@ -184,12 +188,25 @@ func (mpt *merklePatriciaTrie) Delete(key []byte) error { if err != nil { return err } - newRoot, err := mpt.root.Delete(kt, 0) + newRoot, err := mpt.root.Delete(mpt, kt, 0) if err != nil { return errors.Wrapf(trie.ErrNotExist, "key %x does not exist", kt) } - bn, ok := newRoot.(branch) - if !ok { + var bn branch + switch n := newRoot.(type) { + case branch: + bn = n + case *hashNode: + newRoot, err = n.LoadNode(mpt) + if err != nil { + return err + } + var ok bool + bn, ok = newRoot.(branch) + if !ok { + panic("unexpected new root") + } + default: panic("unexpected new root") } @@ -204,7 +221,7 @@ func (mpt *merklePatriciaTrie) Upsert(key []byte, value []byte) error { if err != nil { return err } - newRoot, err := mpt.root.Upsert(kt, 0, value) + newRoot, err := mpt.root.Upsert(mpt, kt, 0, value) if err != nil { return err } @@ -222,7 +239,10 @@ func (mpt *merklePatriciaTrie) isEmptyRootHash(h []byte) bool { func (mpt *merklePatriciaTrie) setRootHash(rootHash []byte) error { if len(rootHash) == 0 || mpt.isEmptyRootHash(rootHash) { - emptyRoot := newEmptyRootBranchNode(mpt) + emptyRoot, err := newRootBranchNode(mpt, nil, nil, false) + if err != nil { + return err + } return mpt.resetRoot(emptyRoot, mpt.emptyRootHash) } node, err := mpt.loadNode(rootHash) @@ -245,7 +265,7 @@ func (mpt *merklePatriciaTrie) resetRoot(newRoot branch, rootHash []byte) error } if rootHash == nil { var err error - rootHash, err = newRoot.Hash() + rootHash, err = newRoot.Hash(mpt) if err != nil { return err } @@ -256,6 +276,10 @@ func (mpt *merklePatriciaTrie) resetRoot(newRoot branch, rootHash []byte) error return nil } +func (mpt *merklePatriciaTrie) asyncMode() bool { + return mpt.async +} + func (mpt *merklePatriciaTrie) checkKeyType(key []byte) (keyType, error) { if len(key) != mpt.keyLength { return nil, errors.Errorf("invalid key length %d", len(key)) @@ -266,6 +290,10 @@ func (mpt *merklePatriciaTrie) checkKeyType(key []byte) (keyType, error) { return kt, nil } +func (mpt *merklePatriciaTrie) hash(key []byte) []byte { + return mpt.hashFunc(key) +} + func (mpt *merklePatriciaTrie) deleteNode(key []byte) error { return mpt.kvStore.Delete(key) } @@ -284,14 +312,38 @@ func (mpt *merklePatriciaTrie) loadNode(key []byte) (node, error) { return nil, err } if pbBranch := pb.GetBranch(); pbBranch != nil { - return newBranchNodeFromProtoPb(pbBranch, mpt, key), nil + return newBranchNodeFromProtoPb(pbBranch, key), nil } if pbLeaf := pb.GetLeaf(); pbLeaf != nil { - return newLeafNodeFromProtoPb(pbLeaf, mpt, key), nil + return newLeafNodeFromProtoPb(pbLeaf, key), nil } if pbExtend := pb.GetExtend(); pbExtend != nil { - return newExtensionNodeFromProtoPb(pbExtend, mpt, key), nil + return newExtensionNodeFromProtoPb(pbExtend, key), nil } return nil, errors.New("invalid node type") } + +func (mpt *merklePatriciaTrie) Clone(kvStore trie.KVStore) (trie.Trie, error) { + mpt.mutex.RLock() + defer mpt.mutex.RUnlock() + root, err := mpt.root.Clone() + if err != nil { + return nil, err + } + rh := make([]byte, len(mpt.rootHash)) + copy(rh, mpt.rootHash) + erh := make([]byte, len(mpt.emptyRootHash)) + copy(erh, mpt.emptyRootHash) + + return &merklePatriciaTrie{ + keyLength: mpt.keyLength, + root: root, + rootHash: rh, + rootKey: mpt.rootKey, + kvStore: kvStore, + hashFunc: mpt.hashFunc, + async: mpt.async, + emptyRootHash: erh, + }, nil +} diff --git a/db/trie/mptrie/merklepatriciatrie_test.go b/db/trie/mptrie/merklepatriciatrie_test.go index d1950cbf84..c78735940d 100644 --- a/db/trie/mptrie/merklepatriciatrie_test.go +++ b/db/trie/mptrie/merklepatriciatrie_test.go @@ -539,7 +539,13 @@ func Test4kEntries(t *testing.T) { func test4kEntries(t *testing.T, enableAsync bool) { require := require.New(t) - tr, err := New(KeyLengthOption(4), AsyncOption()) + var tr trie.Trie + var err error + if enableAsync { + tr, err = New(KeyLengthOption(4), AsyncOption()) + } else { + tr, err = New(KeyLengthOption(4)) + } require.NoError(err) require.NoError(tr.Start(context.Background())) root, err := tr.RootHash() diff --git a/db/trie/mptrie/node.go b/db/trie/mptrie/node.go index 15b552279f..b759982e25 100644 --- a/db/trie/mptrie/node.go +++ b/db/trie/mptrie/node.go @@ -17,20 +17,28 @@ var ErrNoData = errors.New("no data in hash node") type ( keyType []byte + client interface { + asyncMode() bool + hash([]byte) []byte + loadNode([]byte) (node, error) + deleteNode([]byte) error + putNode([]byte, []byte) error + } + node interface { - Search(keyType, uint8) (node, error) - Delete(keyType, uint8) (node, error) - Upsert(keyType, uint8, []byte) (node, error) - Hash() ([]byte, error) - Flush() error + Search(client, keyType, uint8) (node, error) + Delete(client, keyType, uint8) (node, error) + Upsert(client, keyType, uint8, []byte) (node, error) + Hash(client) ([]byte, error) + Flush(client) error } serializable interface { node - hash(flush bool) ([]byte, error) - proto(flush bool) (proto.Message, error) - delete() error - store() (node, error) + hash(client, bool) ([]byte, error) + proto(client, bool) (proto.Message, error) + delete(client) error + store(client) error } leaf interface { @@ -50,6 +58,7 @@ type ( node Children() []node MarkAsRoot() + Clone() (branch, error) } ) diff --git a/db/trie/mptrie/sortedlist.go b/db/trie/mptrie/sortedlist.go index 4e952f1381..c857f1b18b 100644 --- a/db/trie/mptrie/sortedlist.go +++ b/db/trie/mptrie/sortedlist.go @@ -6,16 +6,16 @@ import ( // SortedList is a data structure where elements are in ascending order type SortedList struct { - li []uint8 - isSorted bool // lazy-initilization + li []uint8 + sorted bool } // NewSortedList create SortedList from keys in the children map func NewSortedList(children map[byte]node) *SortedList { if len(children) == 0 { return &SortedList{ - li: make([]uint8, 0), - isSorted: true, + li: make([]uint8, 0), + sorted: true, } } li := make([]uint8, 0, len(children)) @@ -23,14 +23,23 @@ func NewSortedList(children map[byte]node) *SortedList { li = append(li, k) } return &SortedList{ - li: li, - isSorted: false, + li: li, } } +func (sl *SortedList) sort() { + if sl.sorted { + return + } + sort.Slice(sl.li, func(i, j int) bool { + return sl.li[i] < sl.li[j] + }) + sl.sorted = true +} + // Insert insert key into sortedlist func (sl *SortedList) Insert(key uint8) { - sl.sortIfNeed() + sl.sort() i := sort.Search(len(sl.li), func(i int) bool { return sl.li[i] >= uint8(key) }) @@ -47,13 +56,13 @@ func (sl *SortedList) Insert(key uint8) { // List returns sorted indices func (sl *SortedList) List() []uint8 { - sl.sortIfNeed() + sl.sort() return sl.li } // Delete deletes key in the sortedlist func (sl *SortedList) Delete(key uint8) { - sl.sortIfNeed() + sl.sort() i := sort.Search(len(sl.li), func(i int) bool { return sl.li[i] >= key }) @@ -65,11 +74,9 @@ func (sl *SortedList) Delete(key uint8) { } } -func (sl *SortedList) sortIfNeed() { - if !sl.isSorted { - sort.Slice(sl.li, func(i, j int) bool { - return sl.li[i] < sl.li[j] - }) - sl.isSorted = true - } +// Clone clones a sorted list +func (sl *SortedList) Clone() *SortedList { + li := make([]uint8, len(sl.li)) + copy(li, sl.li) + return &SortedList{li: li, sorted: sl.sorted} } diff --git a/db/trie/trie.go b/db/trie/trie.go index fb2824c9e6..c484224c38 100644 --- a/db/trie/trie.go +++ b/db/trie/trie.go @@ -47,6 +47,8 @@ type ( SetRootHash([]byte) error // IsEmpty returns true is this is an empty trie IsEmpty() bool + // Clone clones a trie with a new kvstore + Clone(KVStore) (Trie, error) } // TwoLayerTrie is a trie data structure with two layers TwoLayerTrie interface { diff --git a/test/mock/mock_trie/mock_trie.go b/test/mock/mock_trie/mock_trie.go index c71f83a5af..0d051c84a5 100644 --- a/test/mock/mock_trie/mock_trie.go +++ b/test/mock/mock_trie/mock_trie.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + trie "github.com/iotexproject/iotex-core/db/trie" ) // MockIterator is a mock of Iterator interface. @@ -73,6 +74,21 @@ func (m *MockTrie) EXPECT() *MockTrieMockRecorder { return m.recorder } +// Clone mocks base method. +func (m *MockTrie) Clone(arg0 trie.KVStore) (trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clone", arg0) + ret0, _ := ret[0].(trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Clone indicates an expected call of Clone. +func (mr *MockTrieMockRecorder) Clone(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockTrie)(nil).Clone), arg0) +} + // Delete mocks base method. func (m *MockTrie) Delete(arg0 []byte) error { m.ctrl.T.Helper() @@ -210,6 +226,21 @@ func (m *MockTwoLayerTrie) EXPECT() *MockTwoLayerTrieMockRecorder { return m.recorder } +// Clone mocks base method. +func (m *MockTwoLayerTrie) Clone(arg0 trie.KVStore) (trie.TwoLayerTrie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clone", arg0) + ret0, _ := ret[0].(trie.TwoLayerTrie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Clone indicates an expected call of Clone. +func (mr *MockTwoLayerTrieMockRecorder) Clone(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockTwoLayerTrie)(nil).Clone), arg0) +} + // Delete mocks base method. func (m *MockTwoLayerTrie) Delete(arg0, arg1 []byte) error { m.ctrl.T.Helper()