diff --git a/nodedb.go b/nodedb.go index 54933c2e6..2065a8907 100644 --- a/nodedb.go +++ b/nodedb.go @@ -511,6 +511,15 @@ func (ndb *nodeDB) HasVersion(version int64) (bool, error) { return ndb.db.Has(nodeKeyFormat.Key(version, []byte{1})) } +func isReferenceToRoot(bz []byte) bool { + if bz[0] == nodeKeyFormat.Prefix()[0] { + if len(bz) == 13 { + return true + } + } + return false +} + // GetRoot gets the nodeKey of the root for the specific version. func (ndb *nodeDB) GetRoot(version int64) (*NodeKey, error) { val, err := ndb.db.Get(nodeKeyFormat.Key(version, []byte{1})) @@ -523,7 +532,7 @@ func (ndb *nodeDB) GetRoot(version int64) (*NodeKey, error) { if len(val) == 0 { // empty root return nil, nil } - if val[0] == nodeKeyFormat.Prefix()[0] { // point to the prev root + if isReferenceToRoot(val) { // point to the prev root var ( version int64 nonce int32 @@ -702,6 +711,56 @@ func (ndb *nodeDB) traverseOrphans(version int64, fn func(*Node) error) error { return nil } +// traverseStateChanges iterate the range of versions, compare each version to it's predecessor to extract the state changes of it. +// endVersion is exclusive, set to `math.MaxInt64` to cover the latest version. +func (ndb *nodeDB) traverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error { + firstVersion, err := ndb.getFirstVersion() + if err != nil { + return err + } + if startVersion < firstVersion { + startVersion = firstVersion + } + latestVersion, err := ndb.getLatestVersion() + if err != nil { + return err + } + if endVersion > latestVersion { + endVersion = latestVersion + } + + prevVersion := startVersion - 1 + prevRoot, err := ndb.GetRoot(prevVersion) + if err != nil && err != ErrVersionDoesNotExist { + return err + } + + for version := startVersion; version <= endVersion; version++ { + root, err := ndb.GetRoot(version) + if err != nil { + return err + } + + var changeSet ChangeSet + receiveKVPair := func(pair *KVPair) error { + changeSet.Pairs = append(changeSet.Pairs, *pair) + return nil + } + + if err := ndb.extractStateChanges(prevVersion, prevRoot, root, receiveKVPair); err != nil { + return err + } + + if err := fn(version, &changeSet); err != nil { + return err + } + prevVersion = version + prevRoot = root + } + + return nil +} + // Utility and test functions func (ndb *nodeDB) leafNodes() ([]*Node, error) { @@ -768,32 +827,28 @@ func (ndb *nodeDB) size() int { } func (ndb *nodeDB) traverseNodes(fn func(node *Node) error) error { - ndb.resetLatestVersion(0) - latest, err := ndb.getLatestVersion() - if err != nil { - return err - } - nodes := []*Node{} - for version := int64(1); version <= latest; version++ { - if err := ndb.traverseRange(nodeKeyFormat.Key(version), nodeKeyFormat.Key(version+1), func(key, value []byte) error { - var ( - version int64 - nonce int32 - ) - nodeKeyFormat.Scan(key, &version, &nonce) - node, err := MakeNode(&NodeKey{ - version: version, - nonce: nonce, - }, value) - if err != nil { - return err - } - nodes = append(nodes, node) + + if err := ndb.traversePrefix(nodeKeyFormat.Key(), func(key, value []byte) error { + if isReferenceToRoot(value) { return nil - }); err != nil { + } + var ( + version int64 + nonce int32 + ) + nodeKeyFormat.Scan(key, &version, &nonce) + node, err := MakeNode(&NodeKey{ + version: version, + nonce: nonce, + }, value) + if err != nil { return err } + nodes = append(nodes, node) + return nil + }); err != nil { + return err } sort.Slice(nodes, func(i, j int) bool { @@ -808,56 +863,6 @@ func (ndb *nodeDB) traverseNodes(fn func(node *Node) error) error { return nil } -// traverseStateChanges iterate the range of versions, compare each version to it's predecessor to extract the state changes of it. -// endVersion is exclusive, set to `math.MaxInt64` to cover the latest version. -func (ndb *nodeDB) traverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error { - firstVersion, err := ndb.getFirstVersion() - if err != nil { - return err - } - if startVersion < firstVersion { - startVersion = firstVersion - } - latestVersion, err := ndb.getLatestVersion() - if err != nil { - return err - } - if endVersion > latestVersion { - endVersion = latestVersion - } - - prevVersion := startVersion - 1 - prevRoot, err := ndb.GetRoot(prevVersion) - if err != nil && err != ErrVersionDoesNotExist { - return err - } - - for version := startVersion; version <= endVersion; version++ { - root, err := ndb.GetRoot(version) - if err != nil { - return err - } - - var changeSet ChangeSet - receiveKVPair := func(pair *KVPair) error { - changeSet.Pairs = append(changeSet.Pairs, *pair) - return nil - } - - if err := ndb.extractStateChanges(prevVersion, prevRoot, root, receiveKVPair); err != nil { - return err - } - - if err := fn(version, &changeSet); err != nil { - return err - } - prevVersion = version - prevRoot = root - } - - return nil -} - func (ndb *nodeDB) String() (string, error) { buf := bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buf) diff --git a/nodedb_test.go b/nodedb_test.go index ba33c2a79..381985de8 100644 --- a/nodedb_test.go +++ b/nodedb_test.go @@ -253,6 +253,38 @@ func TestIsFastStorageEnabled_False(t *testing.T) { require.NoError(t, err) } +func TestTraverseNodes(t *testing.T) { + tree, _ := getTestTree(0) + // version 1 + for i := 0; i < 20; i++ { + _, err := tree.Set([]byte{byte(i)}, []byte{byte(i)}) + require.NoError(t, err) + } + _, _, err := tree.SaveVersion() + require.NoError(t, err) + // version 2, no commit + _, _, err = tree.SaveVersion() + require.NoError(t, err) + // version 3 + for i := 20; i < 30; i++ { + _, err := tree.Set([]byte{byte(i)}, []byte{byte(i)}) + require.NoError(t, err) + } + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + count := 0 + err = tree.ndb.traverseNodes(func(node *Node) error { + t.Log(node) + if node.isLeaf() { + count++ + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 30, count) +} + func assertOrphansAndBranches(t *testing.T, ndb *nodeDB, version int64, branches int, orphanKeys [][]byte) { var branchCount, orphanIndex int err := ndb.traverseOrphans(version, func(node *Node) error {