diff --git a/server/v2/stf/branch/changeset.go b/server/v2/stf/branch/changeset.go index 13c016725130..c409b1b7becf 100644 --- a/server/v2/stf/branch/changeset.go +++ b/server/v2/stf/branch/changeset.go @@ -5,8 +5,6 @@ import ( "errors" "github.com/tidwall/btree" - - "cosmossdk.io/core/store" ) const ( @@ -55,7 +53,7 @@ func (bt changeSet) delete(key []byte) { // iterator returns a new iterator over the key-value pairs in the changeSet // that have keys greater than or equal to the start key and less than the end key. -func (bt changeSet) iterator(start, end []byte) (store.Iterator, error) { +func (bt changeSet) iterator(start, end []byte) (*memIterator, error) { if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { return nil, errKeyEmpty } @@ -65,7 +63,7 @@ func (bt changeSet) iterator(start, end []byte) (store.Iterator, error) { // reverseIterator returns a new iterator that iterates over the key-value pairs in reverse order // within the specified range [start, end) in the changeSet's tree. // If start or end is an empty byte slice, it returns an error indicating that the key is empty. -func (bt changeSet) reverseIterator(start, end []byte) (store.Iterator, error) { +func (bt changeSet) reverseIterator(start, end []byte) (*memIterator, error) { if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { return nil, errKeyEmpty } @@ -158,6 +156,7 @@ func (mi *memIterator) Domain() (start, end []byte) { // Close releases any resources held by the iterator. func (mi *memIterator) Close() error { mi.iter.Release() + mi.valid = false return nil } diff --git a/server/v2/stf/branch/changeset_test.go b/server/v2/stf/branch/changeset_test.go new file mode 100644 index 000000000000..a77713242ec2 --- /dev/null +++ b/server/v2/stf/branch/changeset_test.go @@ -0,0 +1,28 @@ +package branch + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_memIterator(t *testing.T) { + t.Run("iter is invalid after close", func(t *testing.T) { + cs := newChangeSet() + for i := byte(0); i < 32; i++ { + cs.set([]byte{0, i}, []byte{i}) + } + + it, err := cs.iterator(nil, nil) + if err != nil { + t.Fatal(err) + } + + err = it.Close() + if err != nil { + t.Fatal(err) + } + + require.False(t, it.Valid()) + }) +}