Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trie: iterate values pre-order and fix seek behavior #27838

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 64 additions & 17 deletions trie/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ type nodeIteratorState struct {
node node // Trie node being iterated
parent common.Hash // Hash of the first full ancestor node (nil if current is the root)
index int // Child to be processed next
pathlen int // Length of the path to this node
pathlen int // Length of the path to the parent node
}

type nodeIterator struct {
Expand All @@ -145,7 +145,7 @@ type nodeIterator struct {
err error // Failure set in case of an internal error in the iterator

resolver NodeResolver // optional node resolver for avoiding disk hits
pool []*nodeIteratorState // local pool for iteratorstates
pool []*nodeIteratorState // local pool for iterator states
}

// errIteratorEnd is stored in nodeIterator.err when iteration is done.
Expand Down Expand Up @@ -304,14 +304,15 @@ func (it *nodeIterator) seek(prefix []byte) error {
// The path we're looking for is the hex encoded key without terminator.
key := keybytesToHex(prefix)
key = key[:len(key)-1]

// Move forward until we're just before the closest match to key.
for {
state, parentIndex, path, err := it.peekSeek(key)
if err == errIteratorEnd {
return errIteratorEnd
} else if err != nil {
return seekError{prefix, err}
} else if bytes.Compare(path, key) >= 0 {
} else if reachedPath(path, key) {
return nil
}
it.push(state, parentIndex, path)
Expand Down Expand Up @@ -339,7 +340,6 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -372,7 +372,6 @@ func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []by
// If we're skipping children, pop the current node first
it.pop()
}

// Continue iteration to the next child
for len(it.stack) > 0 {
parent := it.stack[len(it.stack)-1]
Expand Down Expand Up @@ -449,16 +448,18 @@ func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash)
state *nodeIteratorState
childPath []byte
)
for ; index < len(n.Children); index++ {
for ; index < len(n.Children); index = nextChildIndex(index) {
if n.Children[index] != nil {
child = n.Children[index]
hash, _ := child.cache()

state = it.getFromPool()
state.hash = common.BytesToHash(hash)
state.node = child
state.parent = ancestor
state.index = -1
state.pathlen = len(path)

childPath = append(childPath, path...)
childPath = append(childPath, byte(index))
return child, state, childPath, index
Expand All @@ -471,8 +472,8 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
switch node := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child.
if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
parent.index = index - 1
if child, state, path, index := it.findChild(node, nextChildIndex(parent.index), ancestor); child != nil {
parent.index = prevChildIndex(index)
return state, path, true
}
case *shortNode:
Expand All @@ -498,23 +499,23 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
switch n := parent.node.(type) {
case *fullNode:
// Full node, move to the first non-nil child before the desired key position
child, state, path, index := it.findChild(n, parent.index+1, ancestor)
child, state, path, index := it.findChild(n, nextChildIndex(parent.index), ancestor)
if child == nil {
// No more children in this fullnode
return parent, it.path, false
}
// If the child we found is already past the seek position, just return it.
if bytes.Compare(path, key) >= 0 {
parent.index = index - 1
if reachedPath(path, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// The child is before the seek position. Try advancing
for {
nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
nextChild, nextState, nextPath, nextIndex := it.findChild(n, nextChildIndex(index), ancestor)
// If we run out of children, or skipped past the target, return the
// previous one
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
parent.index = index - 1
if nextChild == nil || reachedPath(nextPath, key) {
parent.index = prevChildIndex(index)
return state, path, true
}
// We found a better child closer to the target
Expand All @@ -541,7 +542,7 @@ func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []
it.path = path
it.stack = append(it.stack, state)
if parentIndex != nil {
*parentIndex++
*parentIndex = nextChildIndex(*parentIndex)
}
}

Expand All @@ -550,8 +551,54 @@ func (it *nodeIterator) pop() {
it.path = it.path[:last.pathlen]
it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1]
// last is now unused
it.putInPool(last)

it.putInPool(last) // last is now unused
}

// reachedPath reports whether path has reached or passed target, i.e. whether
// it compares greater than or equal to target. Before comparing, a trailing
// terminator is stripped from path if one is present. As a result, the path of
// a value node embedded in a full node compares less than the paths of that
// full node's children.
func reachedPath(path, target []byte) bool {
	normalized := path
	if hasTerm(normalized) {
		// Drop the terminator so only the key nibbles take part in the comparison.
		normalized = normalized[:len(normalized)-1]
	}
	return bytes.Compare(normalized, target) >= 0
}

// A value embedded in a full node occupies the last slot (16) of the array of
// children. In order to produce a pre-order traversal when iterating children,
// we jump to this last slot first, then go back to iterate the child nodes
// (skipping the last slot when we reach it again at the end):

// prevChildIndex is the inverse step of the pre-order child traversal of a
// full node: given the index of a child, it returns the index that is visited
// immediately before it.
//
// The traversal order is: placeholder (-1), embedded value slot (16),
// children 0 through 15, then the past-the-end index (17).
func prevChildIndex(index int) int {
	const (
		placeholder = -1 // index before any child has been visited
		valueSlot   = 16 // slot holding the value embedded in the full node
		pastEnd     = 17 // index reached once every child has been visited
	)
	if index == 0 {
		// The first child is visited right after the embedded value slot.
		return valueSlot
	}
	if index == valueSlot {
		// The value slot is visited first, straight from the placeholder.
		return placeholder
	}
	if index == pastEnd {
		// The value slot is skipped at the end, so the last child precedes.
		return valueSlot - 1
	}
	// Ordinary children are visited in sequence.
	return index - 1
}

// nextChildIndex advances one step in the pre-order child traversal of a full
// node: given the index of a child, it returns the index that is visited
// immediately after it.
//
// The traversal order is: placeholder (-1), embedded value slot (16),
// children 0 through 15, then the past-the-end index (17).
func nextChildIndex(index int) int {
	const (
		placeholder = -1 // index before any child has been visited
		valueSlot   = 16 // slot holding the value embedded in the full node
		pastEnd     = 17 // index reached once every child has been visited
	)
	if index == placeholder {
		// Start with the embedded value so that it is emitted before the children.
		return valueSlot
	}
	if index == valueSlot {
		// After the value, go back and walk the children from the first one.
		return 0
	}
	if index == valueSlot-1 {
		// The value slot was already visited; step over it.
		return pastEnd
	}
	// Ordinary children are visited in sequence.
	return index + 1
}

func compareNodes(a, b NodeIterator) int {
Expand Down
18 changes: 12 additions & 6 deletions trie/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) {
// kvs is a string key/value pair used to describe test fixture entries.
type kvs struct {
	k string
	v string
}

var testdata1 = []kvs{
{"bar", "b"},
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
}

var testdata2 = []kvs{
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestIteratorSeek(t *testing.T) {

// Seek to a non-existent key.
it = NewIterator(trie.MustNodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
if err := checkIteratorOrder(testdata1[2:], it); err != nil {
t.Fatal(err)
}

Expand All @@ -227,6 +227,12 @@ func TestIteratorSeek(t *testing.T) {
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}

// Seek to a key for which a prefixing key exists.
it = NewIterator(trie.MustNodeIterator([]byte("food")))
if err := checkIteratorOrder(testdata1[6:], it); err != nil {
t.Fatal(err)
}
}

func checkIteratorOrder(want []kvs, it *Iterator) error {
Expand Down Expand Up @@ -311,16 +317,16 @@ func TestUnionIterator(t *testing.T) {

all := []struct{ k, v string }{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"},
{"bars", "bb"},
{"bars", "be"},
{"bar", "b"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
{"jars", "d"},
}

Expand Down Expand Up @@ -511,7 +517,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin
rawdb.WriteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, barNodeBlob, triedb.Scheme())
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
if err := checkIteratorOrder(testdata1[3:], NewIterator(it)); err != nil {
t.Fatal(err)
}
}
Expand Down
Loading