From 592b619eb45a29e02ff0908e1d13be2b6bb22c94 Mon Sep 17 00:00:00 2001 From: lmittmann Date: Fri, 31 May 2024 16:03:43 +0200 Subject: [PATCH 1/6] w3vm: added `ApplyWithSnapshot` --- w3vm/vm.go | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/w3vm/vm.go b/w3vm/vm.go index efa1c408..ad1d2847 100644 --- a/w3vm/vm.go +++ b/w3vm/vm.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "math/big" + "slices" "testing" "time" @@ -40,6 +41,7 @@ type VM struct { txIndex uint64 db *gethState.StateDB + snaps []int } // New creates a new VM, that is configured with the given options. @@ -77,7 +79,8 @@ func New(opts ...Option) (*VM, error) { // Apply the given message to the VM and return its receipt. Multiple tracing hooks // can be given to trace the execution of the message. func (vm *VM) Apply(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { - return vm.apply(msg, false, joinHooks(hooks)) + receipt, _, err := vm.apply(msg, false, true, joinHooks(hooks)) + return receipt, err } // ApplyTx is like [VM.Apply], but takes a transaction instead of a message. @@ -89,15 +92,22 @@ func (vm *VM) ApplyTx(tx *types.Transaction, hooks ...*tracing.Hooks) (*Receipt, return vm.Apply(msg, hooks...) } -func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Receipt, error) { +// ApplyWithSnapshot is like [VM.Apply], but also returns a state snapshot of the messages pre-state. +// The VM's state can be rolled back to the snapshot using [VM.Rollback]. A snapshot is invalidated +// after a call to [VM.Apply] or [VM.ApplyTx]. +func (vm *VM) ApplyWithSnapshot(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, int, error) { + return vm.apply(msg, false, false, joinHooks(hooks)) +} + +func (v *VM) apply(msg *w3types.Message, isCall, finalize bool, hooks *tracing.Hooks) (*Receipt, int, error) { if v.db.Error() != nil { - return nil, ErrFetch + return nil, -1, ErrFetch } v.db.SetLogger(hooks) coreMsg, txCtx, err := v.buildMessage(msg, isCall) if err != nil { - return nil, err + return nil, -1, err } var txHash common.Hash @@ -116,7 +126,7 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re // apply the message to the evm result, err := core.ApplyMessage(evm, coreMsg, gp) if err != nil { - return nil, err + return nil, -1, err } // build receipt @@ -144,16 +154,22 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re if isCall && !result.Failed() { v.db.RevertToSnapshot(snap) } - v.db.Finalise(false) + if finalize { + v.db.Finalise(false) + v.snaps = v.snaps[:0] // clear snapshots + } else if !isCall { + v.snaps = append(v.snaps, snap) + } - return receipt, receipt.Err + return receipt, snap, receipt.Err } // Call calls the given message on the VM and returns a receipt. Any state changes // of a call are reverted. Multiple tracing hooks can be passed to trace the execution // of the message. func (vm *VM) Call(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { - return vm.apply(msg, true, joinHooks(hooks)) + receipt, _, err := vm.apply(msg, true, false, joinHooks(hooks)) + return receipt, err } // CallFunc is a utility function for [VM.Call] that calls the given function @@ -229,6 +245,20 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err return val, nil } +// Rollback the state of the VM to the given snapshot. +// Snapshots +func (vm *VM) Rollback(snapshot int) error { + // validate snapshot + i := slices.Index(vm.snaps, snapshot) + if i < 0 { + return fmt.Errorf("invalid snapshot %d", snapshot) + } + + vm.db.RevertToSnapshot(snapshot) + vm.snaps = vm.snaps[:i] // clear snapshot, and all snapshots after it + return nil +} + func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) { nonce := msg.Nonce if !skipAccChecks && nonce == 0 && msg.From != w3.Addr0 { From 13b2e4ad4258e5e7d48791c07acb5384778c0362 Mon Sep 17 00:00:00 2001 From: lmittmann Date: Mon, 3 Jun 2024 10:57:24 +0200 Subject: [PATCH 2/6] changed snapshot logic --- w3vm/db.go | 13 ++++++++---- w3vm/fetcher.go | 10 ++++----- w3vm/vm.go | 56 ++++++++++++++----------------------------------- 3 files changed, 30 insertions(+), 49 deletions(-) diff --git a/w3vm/db.go b/w3vm/db.go index b67d82d3..07287eba 100644 --- a/w3vm/db.go +++ b/w3vm/db.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/common" - gethState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" @@ -28,13 +28,18 @@ func newDB(fetcher Fetcher) *db { // state.Database methods ////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -func (db *db) OpenTrie(root common.Hash) (gethState.Trie, error) { return db, nil } +func (db *db) OpenTrie(root common.Hash) (state.Trie, error) { return db, nil } -func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie gethState.Trie) (gethState.Trie, error) { +func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie state.Trie) (state.Trie, error) { return db, nil } -func (*db) CopyTrie(gethState.Trie) gethState.Trie { panic("not implemented") } +func (*db) CopyTrie(trie state.Trie) state.Trie { + if db, ok := trie.(*db); ok { + return db + } + panic("not implemented") +} func (db *db) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { if db.fetcher == nil { diff --git a/w3vm/fetcher.go b/w3vm/fetcher.go index 107ed930..a31c4890 100644 --- a/w3vm/fetcher.go +++ b/w3vm/fetcher.go @@ -213,7 +213,7 @@ func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, bloc var ( globalStateStoreMux sync.RWMutex - globalStateStore = make(map[string]*state) + globalStateStore = make(map[string]*testdataState) ) func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error { @@ -224,7 +224,7 @@ func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error { fmt.Sprintf("%d_%v.json", chainID, f.blockNumber), ) - var s *state + var s *testdataState // check if the state has already been loaded globalStateStoreMux.RLock() @@ -307,7 +307,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { defer f.mux2.RUnlock() defer f.mux3.RUnlock() - s := &state{ + s := &testdataState{ Accounts: make(map[common.Address]*account, len(f.accounts)), HeaderHashes: make(map[hexutil.Uint64]common.Hash, len(f.headerHashes)), } @@ -382,7 +382,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { return nil } -type state struct { +type testdataState struct { Accounts map[common.Address]*account `json:"accounts"` HeaderHashes map[hexutil.Uint64]common.Hash `json:"headerHashes,omitempty"` } @@ -396,7 +396,7 @@ type account struct { // mergeStates merges the source state into the destination state and returns // whether the destination state has been modified. -func mergeStates(dst, src *state) (modified bool) { +func mergeStates(dst, src *testdataState) (modified bool) { // merge accounts for addr, acc := range src.Accounts { if dstAcc, ok := dst.Accounts[addr]; !ok { diff --git a/w3vm/vm.go b/w3vm/vm.go index ad1d2847..572d9bf3 100644 --- a/w3vm/vm.go +++ b/w3vm/vm.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "slices" "testing" "time" @@ -17,7 +16,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" - gethState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -40,8 +39,7 @@ type VM struct { opts *options txIndex uint64 - db *gethState.StateDB - snaps []int + db *state.StateDB } // New creates a new VM, that is configured with the given options. @@ -60,7 +58,7 @@ func New(opts ...Option) (*VM, error) { // set DB db := newDB(vm.opts.fetcher) - vm.db, _ = gethState.New(w3.Hash0, db, nil) + vm.db, _ = state.New(w3.Hash0, db, nil) for addr, acc := range vm.opts.preState { vm.db.SetNonce(addr, acc.Nonce) if acc.Balance != nil { @@ -79,8 +77,7 @@ func New(opts ...Option) (*VM, error) { // Apply the given message to the VM and return its receipt. Multiple tracing hooks // can be given to trace the execution of the message. func (vm *VM) Apply(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { - receipt, _, err := vm.apply(msg, false, true, joinHooks(hooks)) - return receipt, err + return vm.apply(msg, false, joinHooks(hooks)) } // ApplyTx is like [VM.Apply], but takes a transaction instead of a message. @@ -92,22 +89,15 @@ func (vm *VM) ApplyTx(tx *types.Transaction, hooks ...*tracing.Hooks) (*Receipt, return vm.Apply(msg, hooks...) } -// ApplyWithSnapshot is like [VM.Apply], but also returns a state snapshot of the messages pre-state. -// The VM's state can be rolled back to the snapshot using [VM.Rollback]. A snapshot is invalidated -// after a call to [VM.Apply] or [VM.ApplyTx]. -func (vm *VM) ApplyWithSnapshot(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, int, error) { - return vm.apply(msg, false, false, joinHooks(hooks)) -} - -func (v *VM) apply(msg *w3types.Message, isCall, finalize bool, hooks *tracing.Hooks) (*Receipt, int, error) { +func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Receipt, error) { if v.db.Error() != nil { - return nil, -1, ErrFetch + return nil, ErrFetch } v.db.SetLogger(hooks) coreMsg, txCtx, err := v.buildMessage(msg, isCall) if err != nil { - return nil, -1, err + return nil, err } var txHash common.Hash @@ -126,7 +116,7 @@ func (v *VM) apply(msg *w3types.Message, isCall, finalize bool, hooks *tracing.H // apply the message to the evm result, err := core.ApplyMessage(evm, coreMsg, gp) if err != nil { - return nil, -1, err + return nil, err } // build receipt @@ -154,22 +144,16 @@ func (v *VM) apply(msg *w3types.Message, isCall, finalize bool, hooks *tracing.H if isCall && !result.Failed() { v.db.RevertToSnapshot(snap) } - if finalize { - v.db.Finalise(false) - v.snaps = v.snaps[:0] // clear snapshots - } else if !isCall { - v.snaps = append(v.snaps, snap) - } + v.db.Finalise(false) - return receipt, snap, receipt.Err + return receipt, receipt.Err } // Call calls the given message on the VM and returns a receipt. Any state changes // of a call are reverted. Multiple tracing hooks can be passed to trace the execution // of the message. func (vm *VM) Call(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { - receipt, _, err := vm.apply(msg, true, false, joinHooks(hooks)) - return receipt, err + return vm.apply(msg, true, joinHooks(hooks)) } // CallFunc is a utility function for [VM.Call] that calls the given function @@ -245,23 +229,15 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err return val, nil } -// Rollback the state of the VM to the given snapshot. -// Snapshots -func (vm *VM) Rollback(snapshot int) error { - // validate snapshot - i := slices.Index(vm.snaps, snapshot) - if i < 0 { - return fmt.Errorf("invalid snapshot %d", snapshot) - } +// Snapshot the current state of the VM. +func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() } - vm.db.RevertToSnapshot(snapshot) - vm.snaps = vm.snaps[:i] // clear snapshot, and all snapshots after it - return nil -} +// Rollback the state of the VM to the given snapshot. +func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot } func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) { nonce := msg.Nonce - if !skipAccChecks && nonce == 0 && msg.From != w3.Addr0 { + if !skipAccChecks && nonce == 0 { var err error nonce, err = v.Nonce(msg.From) if err != nil { From 3e4aeceeb14c4d4a9e66b92f8d3f35fc6ee0fd8a Mon Sep 17 00:00:00 2001 From: lmittmann Date: Mon, 3 Jun 2024 11:34:35 +0200 Subject: [PATCH 3/6] added test --- w3vm/vm_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/w3vm/vm_test.go b/w3vm/vm_test.go index 75930882..1e1eb6f9 100644 --- a/w3vm/vm_test.go +++ b/w3vm/vm_test.go @@ -14,7 +14,7 @@ import ( "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" - coreState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" @@ -214,6 +214,58 @@ func TestVMApply(t *testing.T) { } } +func TestVMSnapshot(t *testing.T) { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("100 ether")}, + }), + ) + + depositMsg := &w3types.Message{ + From: addr0, + To: &addrWETH, + Value: w3.I("1 ether"), + } + + getBalanceOf := func(t *testing.T, token, acc common.Address) *big.Int { + t.Helper() + + var balance *big.Int + if err := vm.CallFunc(token, funcBalanceOf, acc).Returns(&balance); err != nil { + t.Fatalf("Failed to call balanceOf: %v", err) + } + return balance + } + + if got := getBalanceOf(t, addrWETH, addr0); got.Sign() != 0 { + t.Fatalf("Balance: want 0 WETH, got %s WETH", w3.FromWei(got, 18)) + } + + var snap *state.StateDB + for i := range 100 { + if i == 42 { + snap = vm.Snapshot() + } + + if _, err := vm.Apply(depositMsg); err != nil { + t.Fatalf("Failed to apply deposit msg: %v", err) + } + + want := w3.I(strconv.Itoa(i+1) + " ether") + if got := getBalanceOf(t, addrWETH, addr0); want.Cmp(got) != 0 { + t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18)) + } + } + + vm.Rollback(snap) + + want := w3.I("42 ether") + if got := getBalanceOf(t, addrWETH, addr0); got.Cmp(want) != 0 { + t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18)) + } +} + func TestVMCall(t *testing.T) { tests := []struct { PreState w3types.State @@ -506,7 +558,7 @@ func BenchmarkTransferWETH9(b *testing.B) { }) b.Run("geth", func(b *testing.B) { - stateDB, _ := coreState.New(common.Hash{}, coreState.NewDatabase(rawdb.NewMemoryDatabase()), nil) + stateDB, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) stateDB.SetCode(addrWETH, codeWETH) stateDB.SetState(addrWETH, w3vm.WETHBalanceSlot(addr0), common.BigToHash(w3.I("1 ether"))) From 780f7ec1c785b702868be1e11abf59bf339a0542 Mon Sep 17 00:00:00 2001 From: lmittmann Date: Mon, 3 Jun 2024 11:41:40 +0200 Subject: [PATCH 4/6] updated testdata --- w3vm/testdata/w3vm/1_17034867.json | 5 ----- 1 file changed, 5 deletions(-) diff --git a/w3vm/testdata/w3vm/1_17034867.json b/w3vm/testdata/w3vm/1_17034867.json index edc67154..90aa1b3e 100644 --- a/w3vm/testdata/w3vm/1_17034867.json +++ b/w3vm/testdata/w3vm/1_17034867.json @@ -1,10 +1,5 @@ { "accounts": { - "0x0000000000000000000000000000000000000000": { - "nonce": "0x0", - "balance": "0x272392e2b6e127d35e3", - "code": "0x" - }, "0x0000000000000000000000000000000000000001": { "nonce": "0x0", "balance": "0x2e58c20c74febd3b7", From 0ed6db86b408131882ad6214dd8d74643a82df83 Mon Sep 17 00:00:00 2001 From: lmittmann Date: Mon, 3 Jun 2024 11:58:30 +0200 Subject: [PATCH 5/6] added benchmark --- w3vm/bench_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++ w3vm/fetcher.go | 4 ++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/w3vm/bench_test.go b/w3vm/bench_test.go index fe4d4f32..36952b79 100644 --- a/w3vm/bench_test.go +++ b/w3vm/bench_test.go @@ -1,6 +1,7 @@ package w3vm_test import ( + "fmt" "math/big" "testing" @@ -102,3 +103,57 @@ func BenchmarkVM(b *testing.B) { dur := b.Elapsed() b.ReportMetric(float64(gasSimulated)/dur.Seconds(), "gas/s") } + +func BenchmarkVMSnapshot(b *testing.B) { + depositMsg := &w3types.Message{ + From: addr0, + To: &addrWETH, + Value: w3.I("1 ether"), + } + + runs := 2 + b.Run(fmt.Sprintf("re-run %d", runs), func(b *testing.B) { + for range b.N { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("2 ether")}, + }), + ) + + for i := 0; i < runs; i++ { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + } + } + }) + + b.Run(fmt.Sprintf("snapshot %d", runs), func(b *testing.B) { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("2 ether")}, + }), + ) + + for i := 0; i < runs-1; i++ { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + } + + snap := vm.Snapshot() + + for range b.N { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + + vm.Rollback(snap.Copy()) + } + }) +} diff --git a/w3vm/fetcher.go b/w3vm/fetcher.go index a31c4890..f3797a65 100644 --- a/w3vm/fetcher.go +++ b/w3vm/fetcher.go @@ -363,12 +363,12 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { // create directory, if it does not exist dirPath := filepath.Dir(fn) if _, err := os.Stat(dirPath); errors.Is(err, os.ErrNotExist) { - if err := os.MkdirAll(dirPath, 0775); err != nil { + if err := os.MkdirAll(dirPath, 0o775); err != nil { return err } } - file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0664) + file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664) if err != nil { return err } From 65743e62676243265de32c98dcf8ce2755b02dbc Mon Sep 17 00:00:00 2001 From: lmittmann Date: Mon, 3 Jun 2024 12:34:29 +0200 Subject: [PATCH 6/6] improved doc --- w3vm/vm.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/w3vm/vm.go b/w3vm/vm.go index 572d9bf3..af82647a 100644 --- a/w3vm/vm.go +++ b/w3vm/vm.go @@ -229,7 +229,8 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err return val, nil } -// Snapshot the current state of the VM. +// Snapshot the current state of the VM. The returned state can only be rolled +// back to once. Use [state.StateDB.Copy] if you need to rollback multiple times. func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() } // Rollback the state of the VM to the given snapshot.