diff --git a/w3vm/bench_test.go b/w3vm/bench_test.go index fe4d4f32..36952b79 100644 --- a/w3vm/bench_test.go +++ b/w3vm/bench_test.go @@ -1,6 +1,7 @@ package w3vm_test import ( + "fmt" "math/big" "testing" @@ -102,3 +103,57 @@ func BenchmarkVM(b *testing.B) { dur := b.Elapsed() b.ReportMetric(float64(gasSimulated)/dur.Seconds(), "gas/s") } + +func BenchmarkVMSnapshot(b *testing.B) { + depositMsg := &w3types.Message{ + From: addr0, + To: &addrWETH, + Value: w3.I("1 ether"), + } + + runs := 2 + b.Run(fmt.Sprintf("re-run %d", runs), func(b *testing.B) { + for range b.N { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("2 ether")}, + }), + ) + + for i := 0; i < runs; i++ { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + } + } + }) + + b.Run(fmt.Sprintf("snapshot %d", runs), func(b *testing.B) { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("2 ether")}, + }), + ) + + for i := 0; i < runs-1; i++ { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + } + + snap := vm.Snapshot() + + for range b.N { + _, err := vm.Apply(depositMsg) + if err != nil { + b.Fatalf("Failed to deposit: %v", err) + } + + vm.Rollback(snap.Copy()) + } + }) +} diff --git a/w3vm/db.go b/w3vm/db.go index b67d82d3..07287eba 100644 --- a/w3vm/db.go +++ b/w3vm/db.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/common" - gethState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" @@ -28,13 +28,18 @@ func newDB(fetcher Fetcher) *db { // state.Database methods ////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -func (db *db) OpenTrie(root common.Hash) (gethState.Trie, error) { return db, nil } +func (db *db) OpenTrie(root common.Hash) (state.Trie, error) { return db, nil } -func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie gethState.Trie) (gethState.Trie, error) { +func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie state.Trie) (state.Trie, error) { return db, nil } -func (*db) CopyTrie(gethState.Trie) gethState.Trie { panic("not implemented") } +func (*db) CopyTrie(trie state.Trie) state.Trie { + if db, ok := trie.(*db); ok { + return db + } + panic("not implemented") +} func (db *db) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) { if db.fetcher == nil { diff --git a/w3vm/fetcher.go b/w3vm/fetcher.go index 107ed930..f3797a65 100644 --- a/w3vm/fetcher.go +++ b/w3vm/fetcher.go @@ -213,7 +213,7 @@ func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, bloc var ( globalStateStoreMux sync.RWMutex - globalStateStore = make(map[string]*state) + globalStateStore = make(map[string]*testdataState) ) func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error { @@ -224,7 +224,7 @@ func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error { fmt.Sprintf("%d_%v.json", chainID, f.blockNumber), ) - var s *state + var s *testdataState // check if the state has already been loaded globalStateStoreMux.RLock() @@ -307,7 +307,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { defer f.mux2.RUnlock() defer f.mux3.RUnlock() - s := &state{ + s := &testdataState{ Accounts: make(map[common.Address]*account, len(f.accounts)), HeaderHashes: make(map[hexutil.Uint64]common.Hash, len(f.headerHashes)), } @@ -363,12 +363,12 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { // create directory, if it does not exist dirPath := filepath.Dir(fn) if _, err := os.Stat(dirPath); errors.Is(err, os.ErrNotExist) { - if err := os.MkdirAll(dirPath, 0775); err != nil { + if err := os.MkdirAll(dirPath, 0o775); err != nil { return err } } - file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0664) + file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664) if err != nil { return err } @@ -382,7 +382,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error { return nil } -type state struct { +type testdataState struct { Accounts map[common.Address]*account `json:"accounts"` HeaderHashes map[hexutil.Uint64]common.Hash `json:"headerHashes,omitempty"` } @@ -396,7 +396,7 @@ type account struct { // mergeStates merges the source state into the destination state and returns // whether the destination state has been modified. -func mergeStates(dst, src *state) (modified bool) { +func mergeStates(dst, src *testdataState) (modified bool) { // merge accounts for addr, acc := range src.Accounts { if dstAcc, ok := dst.Accounts[addr]; !ok { diff --git a/w3vm/testdata/w3vm/1_17034867.json b/w3vm/testdata/w3vm/1_17034867.json index edc67154..90aa1b3e 100644 --- a/w3vm/testdata/w3vm/1_17034867.json +++ b/w3vm/testdata/w3vm/1_17034867.json @@ -1,10 +1,5 @@ { "accounts": { - "0x0000000000000000000000000000000000000000": { - "nonce": "0x0", - "balance": "0x272392e2b6e127d35e3", - "code": "0x" - }, "0x0000000000000000000000000000000000000001": { "nonce": "0x0", "balance": "0x2e58c20c74febd3b7", diff --git a/w3vm/vm.go b/w3vm/vm.go index efa1c408..af82647a 100644 --- a/w3vm/vm.go +++ b/w3vm/vm.go @@ -16,7 +16,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" - gethState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -39,7 +39,7 @@ type VM struct { opts *options txIndex uint64 - db *gethState.StateDB + db *state.StateDB } // New creates a new VM, that is configured with the given options. @@ -58,7 +58,7 @@ func New(opts ...Option) (*VM, error) { // set DB db := newDB(vm.opts.fetcher) - vm.db, _ = gethState.New(w3.Hash0, db, nil) + vm.db, _ = state.New(w3.Hash0, db, nil) for addr, acc := range vm.opts.preState { vm.db.SetNonce(addr, acc.Nonce) if acc.Balance != nil { @@ -229,9 +229,16 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err return val, nil } +// Snapshot the current state of the VM. The returned state can only be rolled +// back to once. Use [state.StateDB.Copy] if you need to rollback multiple times. +func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() } + +// Rollback the state of the VM to the given snapshot. +func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot } + func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) { nonce := msg.Nonce - if !skipAccChecks && nonce == 0 && msg.From != w3.Addr0 { + if !skipAccChecks && nonce == 0 { var err error nonce, err = v.Nonce(msg.From) if err != nil { diff --git a/w3vm/vm_test.go b/w3vm/vm_test.go index 75930882..1e1eb6f9 100644 --- a/w3vm/vm_test.go +++ b/w3vm/vm_test.go @@ -14,7 +14,7 @@ import ( "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" - coreState "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" @@ -214,6 +214,58 @@ func TestVMApply(t *testing.T) { } } +func TestVMSnapshot(t *testing.T) { + vm, _ := w3vm.New( + w3vm.WithState(w3types.State{ + addrWETH: {Code: codeWETH}, + addr0: {Balance: w3.I("100 ether")}, + }), + ) + + depositMsg := &w3types.Message{ + From: addr0, + To: &addrWETH, + Value: w3.I("1 ether"), + } + + getBalanceOf := func(t *testing.T, token, acc common.Address) *big.Int { + t.Helper() + + var balance *big.Int + if err := vm.CallFunc(token, funcBalanceOf, acc).Returns(&balance); err != nil { + t.Fatalf("Failed to call balanceOf: %v", err) + } + return balance + } + + if got := getBalanceOf(t, addrWETH, addr0); got.Sign() != 0 { + t.Fatalf("Balance: want 0 WETH, got %s WETH", w3.FromWei(got, 18)) + } + + var snap *state.StateDB + for i := range 100 { + if i == 42 { + snap = vm.Snapshot() + } + + if _, err := vm.Apply(depositMsg); err != nil { + t.Fatalf("Failed to apply deposit msg: %v", err) + } + + want := w3.I(strconv.Itoa(i+1) + " ether") + if got := getBalanceOf(t, addrWETH, addr0); want.Cmp(got) != 0 { + t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18)) + } + } + + vm.Rollback(snap) + + want := w3.I("42 ether") + if got := getBalanceOf(t, addrWETH, addr0); got.Cmp(want) != 0 { + t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18)) + } +} + func TestVMCall(t *testing.T) { tests := []struct { PreState w3types.State @@ -506,7 +558,7 @@ func BenchmarkTransferWETH9(b *testing.B) { }) b.Run("geth", func(b *testing.B) { - stateDB, _ := coreState.New(common.Hash{}, coreState.NewDatabase(rawdb.NewMemoryDatabase()), nil) + stateDB, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) stateDB.SetCode(addrWETH, codeWETH) stateDB.SetState(addrWETH, w3vm.WETHBalanceSlot(addr0), common.BigToHash(w3.I("1 ether")))