Skip to content

Commit

Permalink
trie: iterate values pre-order and fix seek behavior (ethereum#27838)
Browse files Browse the repository at this point in the history
This pull request fixes the pre-order trie traversal by defining 
a more accurate iterator order and path comparison rule.

Co-authored-by: Gary Rong <garyrong0905@gmail.com>
  • Loading branch information
2 people authored and jorgemmsilva committed Jun 17, 2024
1 parent b6b1d15 commit b93edc9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 23 deletions.
81 changes: 64 additions & 17 deletions trie/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ type nodeIteratorState struct {
node node // Trie node being iterated
parent common.Hash // Hash of the first full ancestor node (nil if current is the root)
index int // Child to be processed next
pathlen int // Length of the path to this node
pathlen int // Length of the path to the parent node
}

type nodeIterator struct {
Expand All @@ -145,7 +145,7 @@ type nodeIterator struct {
err error // Failure set in case of an internal error in the iterator

resolver NodeResolver // optional node resolver for avoiding disk hits
pool []*nodeIteratorState // local pool for iteratorstates
pool []*nodeIteratorState // local pool for iterator states
}

// errIteratorEnd is stored in nodeIterator.err when iteration is done.
Expand Down Expand Up @@ -304,14 +304,15 @@ func (it *nodeIterator) seek(prefix []byte) error {
// The path we're looking for is the hex encoded key without terminator.
key := keybytesToHex(prefix)
key = key[:len(key)-1]

// Move forward until we're just before the closest match to key.
for {
state, parentIndex, path, err := it.peekSeek(key)
if err == errIteratorEnd {
return errIteratorEnd
} else if err != nil {
return seekError{prefix, err}
} else if bytes.Compare(path, key) >= 0 {
} else if reachedPath(path, key) {
return nil
}
it.push(state, parentIndex, path)
Expand Down Expand Up @@ -339,7 +340,6 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -372,7 +372,6 @@ func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []by
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -449,16 +448,18 @@ func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash)
state *nodeIteratorState
childPath []byte
)
for ; index < len(n.Children); index++ {
for ; index < len(n.Children); index = nextChildIndex(index) {
if n.Children[index] != nil {
child = n.Children[index]
hash, _ := child.cache()

state = it.getFromPool()
state.hash = common.BytesToHash(hash)
state.node = child
state.parent = ancestor
state.index = -1
state.pathlen = len(path)

childPath = append(childPath, path...)
childPath = append(childPath, byte(index))
return child, state, childPath, index
Expand All @@ -471,8 +472,8 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
switch node := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child.
if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
parent.index = index - 1
if child, state, path, index := it.findChild(node, nextChildIndex(parent.index), ancestor); child != nil {
parent.index = prevChildIndex(index)
return state, path, true
}
case *shortNode:
Expand All @@ -498,23 +499,23 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
switch n := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child before the desired key position
child, state, path, index := it.findChild(n, parent.index+1, ancestor)
child, state, path, index := it.findChild(n, nextChildIndex(parent.index), ancestor)
if child == nil {
// No more children in this fullnode
return parent, it.path, false
}
// If the child we found is already past the seek position, just return it.
if bytes.Compare(path, key) >= 0 {
parent.index = index - 1
if reachedPath(path, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// The child is before the seek position. Try advancing
for {
nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
nextChild, nextState, nextPath, nextIndex := it.findChild(n, nextChildIndex(index), ancestor)
// If we run out of children, or skipped past the target, return the
// previous one
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
parent.index = index - 1
if nextChild == nil || reachedPath(nextPath, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// We found a better child closer to the target
Expand All @@ -541,7 +542,7 @@ func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []
it.path = path
it.stack = append(it.stack, state)
if parentIndex != nil {
*parentIndex++
*parentIndex = nextChildIndex(*parentIndex)
}
}

Expand All @@ -550,8 +551,54 @@ func (it *nodeIterator) pop() {
it.path = it.path[:last.pathlen]
it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1]
// last is now unused
it.putInPool(last)

it.putInPool(last) // last is now unused
}

// reachedPath normalizes a path by truncating a terminator if present, and
// returns true if it is greater than or equal to the target. Using this,
// the path of a value node embedded a full node will compare less than the
// full node's children.
func reachedPath(path, target []byte) bool {
if hasTerm(path) {
path = path[:len(path)-1]
}
return bytes.Compare(path, target) >= 0
}

// A value embedded in a full node occupies the last slot (16) of the array of
// children. In order to produce a pre-order traversal when iterating children,
// we jump to this last slot first, then go back iterate the child nodes (and
// skip the last slot at the end):

// prevChildIndex returns the index of a child in a full node which precedes
// the given index when performing a pre-order traversal.
func prevChildIndex(index int) int {
switch index {
case 0: // We jumped back to iterate the children, from the value slot
return 16
case 16: // We jumped to the embedded value slot at the end, from the placeholder index
return -1
case 17: // We skipped the value slot after iterating all the children
return 15
default: // We are iterating the children in sequence
return index - 1
}
}

// nextChildIndex returns the index of a child in a full node which follows
// the given index when performing a pre-order traversal.
func nextChildIndex(index int) int {
switch index {
case -1: // Jump from the placeholder index to the embedded value slot
return 16
case 15: // Skip the value slot after iterating the children
return 17
case 16: // From the embedded value slot, jump back to iterate the children
return 0
default: // Iterate children in sequence
return index + 1
}
}

func compareNodes(a, b NodeIterator) int {
Expand Down
18 changes: 12 additions & 6 deletions trie/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) {
type kvs struct{ k, v string }

var testdata1 = []kvs{
{"bar", "b"},
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
}

var testdata2 = []kvs{
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestIteratorSeek(t *testing.T) {

// Seek to a non-existent key.
it = NewIterator(trie.MustNodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
if err := checkIteratorOrder(testdata1[2:], it); err != nil {
t.Fatal(err)
}

Expand All @@ -227,6 +227,12 @@ func TestIteratorSeek(t *testing.T) {
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}

// Seek to a key for which a prefixing key exists.
it = NewIterator(trie.MustNodeIterator([]byte("food")))
if err := checkIteratorOrder(testdata1[6:], it); err != nil {
t.Fatal(err)
}
}

func checkIteratorOrder(want []kvs, it *Iterator) error {
Expand Down Expand Up @@ -311,16 +317,16 @@ func TestUnionIterator(t *testing.T) {

all := []struct{ k, v string }{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"},
{"bars", "bb"},
{"bars", "be"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
{"jars", "d"},
}

Expand Down Expand Up @@ -512,7 +518,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin
rawdb.WriteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, barNodeBlob, triedb.Scheme())
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
if err := checkIteratorOrder(testdata1[3:], NewIterator(it)); err != nil {
t.Fatal(err)
}
}
Expand Down

0 comments on commit b93edc9

Please sign in to comment.