Skip to content

Commit

Permalink
feat(migration): fix bug when deleting a trie node not originating fr…
Browse files Browse the repository at this point in the history
…om DeleteStorage or DeleteAccount
  • Loading branch information
0xbenyun committed Dec 10, 2024
1 parent f97780b commit 8d10579
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 35 deletions.
5 changes: 2 additions & 3 deletions migration/migrator_newstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (m *StateMigrator) applyAccountChanges(tr *trie.StateTrie, bn uint64, root
}
// if set is nil, it means there are no changes, so we skip verification in that case.
if set != nil {
if err := m.verifyStorage(storageTr, id, addr, set, bn); err != nil {
if err := m.validateStorage(storageTr, id, addr, set, bn); err != nil {
return err
}
}
Expand Down Expand Up @@ -158,10 +158,9 @@ func (m *StateMigrator) applyNewStateTransition(headNumber uint64) error {
if err != nil {
return err
}

// if set is nil, it means there are no changes, so we skip verification in that case.
if set != nil {
if err := m.verifyState(tr, set, prevRoot, i); err != nil {
if err := m.validateState(tr, set, prevRoot, i); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions migration/migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestApplyNewStateTransition(t *testing.T) {
accounts[addr2] = acc
storages := make(map[common.Address]map[common.Hash][]byte)
accStorage := make(map[common.Hash][]byte)
for i := 0; i < 30000; i++ {
for i := 0; i < 20000; i++ {
key := common.BigToHash(big.NewInt(int64(i)))
val := big.NewInt(int64(rand.Uint32()) + 1).Bytes()
accStorage[key] = val
Expand All @@ -74,7 +74,7 @@ func TestApplyNewStateTransition(t *testing.T) {
storages = make(map[common.Address]map[common.Hash][]byte)
accStorage = make(map[common.Hash][]byte)

for i := 0; i < 15000; i++ {
for i := 0; i < 10000; i++ {
key := common.BigToHash(big.NewInt(int64(i)))
val := big.NewInt(int64(0)).Bytes() // to result in DeleteStorage()
accStorage[key] = val
Expand Down
135 changes: 105 additions & 30 deletions migration/migrator_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ import (
"github.com/ethereum/go-ethereum/trie/trienode"
)

type Slot struct {
key []byte
value []byte
}

type StateAccount struct {
addr common.Address
value *types.StateAccount
}

func (m *StateMigrator) ValidateMigratedState(mptRoot common.Hash, zkRoot common.Hash) error {
var accounts atomic.Uint64
var slots atomic.Uint64
Expand Down Expand Up @@ -141,8 +151,8 @@ func (m *StateMigrator) ValidateMigratedState(mptRoot common.Hash, zkRoot common
return nil
}

// if verification succeeds, it returns nil
func (m *StateMigrator) verifyState(tr *trie.StateTrie, set *trienode.NodeSet, prevRoot common.Hash, bn uint64) error {
// if validation succeeds, it returns nil
func (m *StateMigrator) validateState(tr *trie.StateTrie, set *trienode.NodeSet, prevRoot common.Hash, bn uint64) error {
if set == nil {
return nil
}
Expand Down Expand Up @@ -174,6 +184,9 @@ func (m *StateMigrator) verifyState(tr *trie.StateTrie, set *trienode.NodeSet, p
if err != nil {
return err
}

var deletedLeaves []*StateAccount
var updatedLeaves []*StateAccount
for path, node := range set.Nodes {
if node.IsDeleted() {
blob, _, err := originTrie.GetNode(trie.HexToCompact(path))
Expand All @@ -190,10 +203,7 @@ func (m *StateMigrator) verifyState(tr *trie.StateTrie, set *trienode.NodeSet, p
return fmt.Errorf("failed to get preimage for hashKey: %x", hk)
}
addr := common.BytesToAddress(preimage)
err = parentZkTrie.DeleteAccount(addr)
if err != nil {
return err
}
deletedLeaves = append(deletedLeaves, &StateAccount{addr, nil})
}
} else {
if trie.IsLeafNode(node.Blob) {
Expand All @@ -218,13 +228,24 @@ func (m *StateMigrator) verifyState(tr *trie.StateTrie, set *trienode.NodeSet, p
} else {
acc.Root = zkAcc.Root
}
err = parentZkTrie.UpdateAccount(addr, acc)
if err != nil {
return err
}
updatedLeaves = append(updatedLeaves, &StateAccount{addr, acc})
}
}
}

for _, leave := range deletedLeaves {
err = parentZkTrie.DeleteAccount(leave.addr)
if err != nil {
return err
}
}
for _, leave := range updatedLeaves {
err = parentZkTrie.UpdateAccount(leave.addr, leave.value)
if err != nil {
return err
}
}

zktRoot, _, err := parentZkTrie.Commit(false)
if err != nil {
return err
Expand All @@ -235,7 +256,7 @@ func (m *StateMigrator) verifyState(tr *trie.StateTrie, set *trienode.NodeSet, p
return nil
}

func (m *StateMigrator) verifyStorage(tr *trie.StateTrie, id *trie.ID, addr common.Address, set *trienode.NodeSet, bn uint64) error {
func (m *StateMigrator) validateStorage(tr *trie.StateTrie, id *trie.ID, addr common.Address, set *trienode.NodeSet, bn uint64) error {
if set == nil {
return nil
}
Expand Down Expand Up @@ -271,12 +292,16 @@ func (m *StateMigrator) verifyStorage(tr *trie.StateTrie, id *trie.ID, addr comm
if err != nil {
return err
}

var deletedLeaves []*Slot
var updatedLeaves []*Slot
for path, node := range set.Nodes {
if node.IsDeleted() {
blob, _, err := originTrie.GetNode(trie.HexToCompact(path))
if err != nil {
return err
}

if blob != nil && trie.IsLeafNode(blob) {
hk, err := trie.GetKeyFromPath(originRootNode, m.db, []byte(path))
if err != nil {
Expand All @@ -286,14 +311,14 @@ func (m *StateMigrator) verifyStorage(tr *trie.StateTrie, id *trie.ID, addr comm
if preimage == nil {
return fmt.Errorf("failed to get preimage for hashKey: %x", hk)
}
slot := common.BytesToHash(preimage)
err = parentZkt.DeleteStorage(common.Address{}, slot.Bytes())
if err != nil {
return err
}
if bytes.Compare(slot.Bytes(), hexutils.HexToBytes("1db7b1394727b4ec83580f945ddf3bdf76fcb71f6c6109779c749ee7ec003004")) == 0 {
log.Info(fmt.Sprintf("[DeleteStorage] slot : %x", slot))
slot := common.BytesToHash(preimage).Bytes()
deletedLeaves = append(deletedLeaves, &Slot{slot, nil})
if addr.Cmp(common.HexToAddress("0x51901916b0a8A67b18299bb6fA16da4D7428f9cA")) == 0 {
if bytes.Compare(slot, hexutils.HexToBytes("5be7a1449cff78980bfa293037675a1770f220327e9386d49ba469c7210536a4")) == 0 {

Check failure on line 317 in migration/migrator_validate.go

View workflow job for this annotation

GitHub Actions / Lint check (ubuntu-latest)

S1004: should use bytes.Equal(slot, hexutils.HexToBytes("5be7a1449cff78980bfa293037675a1770f220327e9386d49ba469c7210536a4")) instead (gosimple)
fmt.Printf("[DeleteStorage]\n")
}
}

}
} else {
if trie.IsLeafNode(node.Blob) {
Expand All @@ -310,16 +335,29 @@ func (m *StateMigrator) verifyStorage(tr *trie.StateTrie, id *trie.ID, addr comm
if err != nil {
return err
}
if bytes.Compare(slot, hexutils.HexToBytes("1db7b1394727b4ec83580f945ddf3bdf76fcb71f6c6109779c749ee7ec003004")) == 0 {
log.Info(fmt.Sprintf("[UpdateStorage] slot : %x , value : %x", slot, val))
}
err = parentZkt.UpdateStorage(common.Address{}, slot, val)
if err != nil {
return err
updatedLeaves = append(updatedLeaves, &Slot{slot, val})
if addr.Cmp(common.HexToAddress("0x51901916b0a8A67b18299bb6fA16da4D7428f9cA")) == 0 {
if bytes.Compare(slot, hexutils.HexToBytes("5be7a1449cff78980bfa293037675a1770f220327e9386d49ba469c7210536a4")) == 0 {

Check failure on line 340 in migration/migrator_validate.go

View workflow job for this annotation

GitHub Actions / Lint check (ubuntu-latest)

S1004: should use bytes.Equal(slot, hexutils.HexToBytes("5be7a1449cff78980bfa293037675a1770f220327e9386d49ba469c7210536a4")) instead (gosimple)
fmt.Printf("[UpdateStorage] path %x slot %x\n", []byte(path), slot)
}
}
}
}
}

for _, leave := range deletedLeaves {
err = parentZkt.DeleteStorage(common.Address{}, leave.key)
if err != nil {
return err
}
}
for _, leave := range updatedLeaves {
err = parentZkt.UpdateStorage(common.Address{}, leave.key, leave.value)
if err != nil {
return err
}
}

zktRoot, _, err := parentZkt.Commit(false)
if err != nil {
return err
Expand All @@ -338,29 +376,34 @@ func (m *StateMigrator) verifyStorage(tr *trie.StateTrie, id *trie.ID, addr comm
return fmt.Errorf("account doesn't exist: %s", addr.Hex())
} else {
if zktRoot.Cmp(zkAcc.Root) != 0 {
err := m.printStoragesForDebug(zkAcc.Root, parentZkt)
if err != nil {
panic(err)
}
//err := m.printStoragesForDebug(zkAcc.Root, parentZkt)
//if err != nil {
// panic(err)
//}
return fmt.Errorf("invalid migrated storage of account: %s", addr.Hex())
}
}
return nil
}

// TODO(Ben) : this func should be removed before this branch is merged
func (m *StateMigrator) printStoragesForDebug(expectedStorageRoot common.Hash, actualZkt *trie.ZkMerkleStateTrie) error {

Check failure on line 390 in migration/migrator_validate.go

View workflow job for this annotation

GitHub Actions / Lint check (ubuntu-latest)

func `(*StateMigrator).printStoragesForDebug` is unused (unused)
expectedZkt, err := trie.NewZkMerkleStateTrie(expectedStorageRoot, m.zktdb)
if err != nil {
return err
}
log.Info(fmt.Sprintf("expectedStorageRoot : %x\n", expectedStorageRoot))
log.Info(fmt.Sprintf("actualtorageRoot : %x\n", actualZkt.Hash()))

nodeIt, err := expectedZkt.NodeIterator(nil)
if err != nil {
return fmt.Errorf("failed to open node iterator (root: %s): %w", expectedZkt.Hash(), err)
}
iter := trie.NewIterator(nodeIt)
storageNum := 0
actualStorages := make(map[common.Hash]bool)
for iter.Next() {

storageNum++
hk := trie.IteratorKeyToHash(iter.Key, true)
preimage, err := m.readZkPreimage(*hk)
if err != nil {
Expand All @@ -374,19 +417,51 @@ func (m *StateMigrator) printStoragesForDebug(expectedStorageRoot common.Hash, a
log.Error("Failed to get storage value in MPT", "err", err)
return err
}
actualStorages[common.BytesToHash(slot)] = true
if !bytes.Equal(actualVal, zktVal) {

log.Warn(fmt.Sprintf("expected - slot : %x, val : %x\n", slot, zktVal))
log.Warn(fmt.Sprintf("actual - slot : %x, val : %x\n", slot, actualVal))
} else {
log.Info(fmt.Sprintf("passed validation - slot : %x, val : %x\n", slot, actualVal))
}

if err != nil {
return err
}
}
log.Info(fmt.Sprintf("expected storageNum : %d\n", storageNum))
if iter.Err != nil {
return fmt.Errorf("failed to traverse state trie (root: %s): %w", actualZkt.Hash(), iter.Err)
}

actualStorageNum := 0
{
nodeIt, err := actualZkt.NodeIterator(nil)
if err != nil {
return fmt.Errorf("failed to open actual zkt node iterator (root: %s): %w", actualZkt.Hash(), err)
}
iter := trie.NewIterator(nodeIt)
for iter.Next() {
actualStorageNum++
hk := trie.IteratorKeyToHash(iter.Key, true)
preimage, err := m.readZkPreimage(*hk)
if err != nil {
return err
}
slot := common.BytesToHash(preimage).Bytes()
zktVal := common.BytesToHash(iter.Value).Bytes()

if _, ok := actualStorages[common.BytesToHash(slot)]; !ok {
log.Warn("actualStorage has a slot doesn't exist in expectedStorage")
log.Warn(fmt.Sprintf("actual - slot : %x, val : %x\n", slot, zktVal))
}
}
log.Info(fmt.Sprintf("actual storageNum : %d\n", actualStorageNum))
}

if storageNum != actualStorageNum {
log.Warn(fmt.Sprintf("storage num is not equal : expected %d, actual %d\n", storageNum, actualStorageNum))
}

return nil
}

0 comments on commit 8d10579

Please sign in to comment.