Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

w3vm: Added VM Snapshotting #147

Merged
merged 6 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions w3vm/bench_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package w3vm_test

import (
"fmt"
"math/big"
"testing"

Expand Down Expand Up @@ -102,3 +103,57 @@ func BenchmarkVM(b *testing.B) {
dur := b.Elapsed()
b.ReportMetric(float64(gasSimulated)/dur.Seconds(), "gas/s")
}

func BenchmarkVMSnapshot(b *testing.B) {
depositMsg := &w3types.Message{
From: addr0,
To: &addrWETH,
Value: w3.I("1 ether"),
}

runs := 2
b.Run(fmt.Sprintf("re-run %d", runs), func(b *testing.B) {
for range b.N {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("2 ether")},
}),
)

for i := 0; i < runs; i++ {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}
}
}
})

b.Run(fmt.Sprintf("snapshot %d", runs), func(b *testing.B) {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("2 ether")},
}),
)

for i := 0; i < runs-1; i++ {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}
}

snap := vm.Snapshot()

for range b.N {
_, err := vm.Apply(depositMsg)
if err != nil {
b.Fatalf("Failed to deposit: %v", err)
}

vm.Rollback(snap.Copy())
}
})
}
13 changes: 9 additions & 4 deletions w3vm/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"errors"

"github.com/ethereum/go-ethereum/common"
gethState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
Expand All @@ -28,13 +28,18 @@ func newDB(fetcher Fetcher) *db {
// state.Database methods //////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

func (db *db) OpenTrie(root common.Hash) (gethState.Trie, error) { return db, nil }
func (db *db) OpenTrie(root common.Hash) (state.Trie, error) { return db, nil }

func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie gethState.Trie) (gethState.Trie, error) {
func (db *db) OpenStorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash, trie state.Trie) (state.Trie, error) {
return db, nil
}

func (*db) CopyTrie(gethState.Trie) gethState.Trie { panic("not implemented") }
func (*db) CopyTrie(trie state.Trie) state.Trie {
if db, ok := trie.(*db); ok {
return db
}
panic("not implemented")
}

func (db *db) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) {
if db.fetcher == nil {
Expand Down
14 changes: 7 additions & 7 deletions w3vm/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, bloc

var (
globalStateStoreMux sync.RWMutex
globalStateStore = make(map[string]*state)
globalStateStore = make(map[string]*testdataState)
)

func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error {
Expand All @@ -224,7 +224,7 @@ func (f *rpcFetcher) loadTestdataState(tb testing.TB, chainID uint64) error {
fmt.Sprintf("%d_%v.json", chainID, f.blockNumber),
)

var s *state
var s *testdataState

// check if the state has already been loaded
globalStateStoreMux.RLock()
Expand Down Expand Up @@ -307,7 +307,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
defer f.mux2.RUnlock()
defer f.mux3.RUnlock()

s := &state{
s := &testdataState{
Accounts: make(map[common.Address]*account, len(f.accounts)),
HeaderHashes: make(map[hexutil.Uint64]common.Hash, len(f.headerHashes)),
}
Expand Down Expand Up @@ -363,12 +363,12 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
// create directory, if it does not exist
dirPath := filepath.Dir(fn)
if _, err := os.Stat(dirPath); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(dirPath, 0775); err != nil {
if err := os.MkdirAll(dirPath, 0o775); err != nil {
return err
}
}

file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0664)
file, err := os.OpenFile(fn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664)
if err != nil {
return err
}
Expand All @@ -382,7 +382,7 @@ func (f *rpcFetcher) storeTestdataState(tb testing.TB, chainID uint64) error {
return nil
}

type state struct {
type testdataState struct {
Accounts map[common.Address]*account `json:"accounts"`
HeaderHashes map[hexutil.Uint64]common.Hash `json:"headerHashes,omitempty"`
}
Expand All @@ -396,7 +396,7 @@ type account struct {

// mergeStates merges the source state into the destination state and returns
// whether the destination state has been modified.
func mergeStates(dst, src *state) (modified bool) {
func mergeStates(dst, src *testdataState) (modified bool) {
// merge accounts
for addr, acc := range src.Accounts {
if dstAcc, ok := dst.Accounts[addr]; !ok {
Expand Down
5 changes: 0 additions & 5 deletions w3vm/testdata/w3vm/1_17034867.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
{
"accounts": {
"0x0000000000000000000000000000000000000000": {
"nonce": "0x0",
"balance": "0x272392e2b6e127d35e3",
"code": "0x"
},
"0x0000000000000000000000000000000000000001": {
"nonce": "0x0",
"balance": "0x2e58c20c74febd3b7",
Expand Down
15 changes: 11 additions & 4 deletions w3vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
gethState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
Expand All @@ -39,7 +39,7 @@ type VM struct {
opts *options

txIndex uint64
db *gethState.StateDB
db *state.StateDB
}

// New creates a new VM, that is configured with the given options.
Expand All @@ -58,7 +58,7 @@ func New(opts ...Option) (*VM, error) {

// set DB
db := newDB(vm.opts.fetcher)
vm.db, _ = gethState.New(w3.Hash0, db, nil)
vm.db, _ = state.New(w3.Hash0, db, nil)
for addr, acc := range vm.opts.preState {
vm.db.SetNonce(addr, acc.Nonce)
if acc.Balance != nil {
Expand Down Expand Up @@ -229,9 +229,16 @@ func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, err
return val, nil
}

// Snapshot the current state of the VM. The returned state can only be rolled
// back to once. Use [state.StateDB.Copy] if you need to rollback multiple times.
func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() }

// Rollback the state of the VM to the given snapshot.
func (vm *VM) Rollback(snapshot *state.StateDB) { vm.db = snapshot }

func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, *vm.TxContext, error) {
nonce := msg.Nonce
if !skipAccChecks && nonce == 0 && msg.From != w3.Addr0 {
if !skipAccChecks && nonce == 0 {
var err error
nonce, err = v.Nonce(msg.From)
if err != nil {
Expand Down
56 changes: 54 additions & 2 deletions w3vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
coreState "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -214,6 +214,58 @@ func TestVMApply(t *testing.T) {
}
}

func TestVMSnapshot(t *testing.T) {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
addrWETH: {Code: codeWETH},
addr0: {Balance: w3.I("100 ether")},
}),
)

depositMsg := &w3types.Message{
From: addr0,
To: &addrWETH,
Value: w3.I("1 ether"),
}

getBalanceOf := func(t *testing.T, token, acc common.Address) *big.Int {
t.Helper()

var balance *big.Int
if err := vm.CallFunc(token, funcBalanceOf, acc).Returns(&balance); err != nil {
t.Fatalf("Failed to call balanceOf: %v", err)
}
return balance
}

if got := getBalanceOf(t, addrWETH, addr0); got.Sign() != 0 {
t.Fatalf("Balance: want 0 WETH, got %s WETH", w3.FromWei(got, 18))
}

var snap *state.StateDB
for i := range 100 {
if i == 42 {
snap = vm.Snapshot()
}

if _, err := vm.Apply(depositMsg); err != nil {
t.Fatalf("Failed to apply deposit msg: %v", err)
}

want := w3.I(strconv.Itoa(i+1) + " ether")
if got := getBalanceOf(t, addrWETH, addr0); want.Cmp(got) != 0 {
t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18))
}
}

vm.Rollback(snap)

want := w3.I("42 ether")
if got := getBalanceOf(t, addrWETH, addr0); got.Cmp(want) != 0 {
t.Fatalf("Balance: want %s WETH, got %s WETH", w3.FromWei(want, 18), w3.FromWei(got, 18))
}
}

func TestVMCall(t *testing.T) {
tests := []struct {
PreState w3types.State
Expand Down Expand Up @@ -506,7 +558,7 @@ func BenchmarkTransferWETH9(b *testing.B) {
})

b.Run("geth", func(b *testing.B) {
stateDB, _ := coreState.New(common.Hash{}, coreState.NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDB, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDB.SetCode(addrWETH, codeWETH)
stateDB.SetState(addrWETH, w3vm.WETHBalanceSlot(addr0), common.BigToHash(w3.I("1 ether")))

Expand Down