Skip to content

Commit

Permalink
[hotfix] fix code hash conflicts (#4431)
Browse files Browse the repository at this point in the history
* fix code hash issue

* goimports
  • Loading branch information
GheisMohammadi authored May 12, 2023
1 parent 5195577 commit 6577b0b
Show file tree
Hide file tree
Showing 15 changed files with 104 additions and 95 deletions.
2 changes: 1 addition & 1 deletion core/state/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (s *DB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []byte)
addr := common.BytesToAddress(addrBytes)
obj := newObject(s, addr, data)
if !conf.SkipCode {
account.Code = obj.Code(s.db, false)
account.Code = obj.Code(s.db)
}
if !conf.SkipStorage {
account.Storage = make(map[common.Hash]string)
Expand Down
2 changes: 1 addition & 1 deletion core/state/prefeth.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (s *DB) prefetchWorker(job *prefetchJob, jobs chan *prefetchJob) {
addr := common.BytesToAddress(addrBytes)
obj := newObject(s, addr, data)
if data.CodeHash != nil {
obj.Code(s.db, false)
obj.Code(s.db)
}

// build account trie tree
Expand Down
125 changes: 68 additions & 57 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ import (
"github.com/harmony-one/harmony/staking"
)

var emptyCodeHash = crypto.Keccak256(nil)
var (
// EmptyRootHash is the known root hash of an empty trie.
EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")

// EmptyCodeHash is the known hash of the empty EVM bytecode.
EmptyCodeHash = crypto.Keccak256Hash(nil) // c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470
)

// Code ...
type Code []byte
Expand Down Expand Up @@ -101,7 +107,7 @@ type Object struct {

// empty returns whether the account is considered empty.
func (s *Object) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash)
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, EmptyCodeHash.Bytes())
}

// Account is the Ethereum consensus representation of accounts.
Expand All @@ -119,10 +125,10 @@ func newObject(db *DB, address common.Address, data types.StateAccount) *Object
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = types.EmptyCodeHash.Bytes()
data.CodeHash = EmptyCodeHash.Bytes()
}
if data.Root == (common.Hash{}) {
data.Root = types.EmptyRootHash
data.Root = EmptyRootHash
}
return &Object{
db: db,
Expand Down Expand Up @@ -169,7 +175,7 @@ func (s *Object) getTrie(db Database) (Trie, error) {
if s.trie == nil {
// Try fetching from prefetcher first
// We don't prefetch empty tries
if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
if s.data.Root != EmptyRootHash && s.db.prefetcher != nil {
// When the miner is creating the pending state, there is no
// prefetcher
s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
Expand Down Expand Up @@ -316,7 +322,7 @@ func (s *Object) finalise(prefetch bool) {
slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
}
}
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != EmptyRootHash {
s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch)
}
if len(s.dirtyStorage) > 0 {
Expand Down Expand Up @@ -475,18 +481,18 @@ func (s *Object) setBalance(amount *big.Int) {
func (s *Object) ReturnGas(gas *big.Int) {}

func (s *Object) deepCopy(db *DB) *Object {
Object := newObject(db, s.address, s.data)
stateObject := newObject(db, s.address, s.data)
if s.trie != nil {
Object.trie = db.db.CopyTrie(s.trie)
stateObject.trie = db.db.CopyTrie(s.trie)
}
Object.code = s.code
Object.dirtyStorage = s.dirtyStorage.Copy()
Object.originStorage = s.originStorage.Copy()
Object.pendingStorage = s.pendingStorage.Copy()
Object.suicided = s.suicided
Object.dirtyCode = s.dirtyCode
Object.deleted = s.deleted
return Object
stateObject.code = s.code
stateObject.dirtyStorage = s.dirtyStorage.Copy()
stateObject.originStorage = s.originStorage.Copy()
stateObject.pendingStorage = s.pendingStorage.Copy()
stateObject.suicided = s.suicided
stateObject.dirtyCode = s.dirtyCode
stateObject.deleted = s.deleted
return stateObject
}

//
Expand All @@ -499,72 +505,77 @@ func (s *Object) Address() common.Address {
}

// Code returns the contract/validator code associated with this object, if any.
func (s *Object) Code(db Database, isValidatorCode bool) []byte {
func (s *Object) Code(db Database) []byte {
if s.code != nil {
return s.code
}
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
if bytes.Equal(s.CodeHash(), EmptyCodeHash.Bytes()) {
return nil
}
if s.validatorWrapper || isValidatorCode {
code, err := db.ValidatorCode(s.addrHash, common.BytesToHash(s.CodeHash()))
if err != nil {
s.setError(
fmt.Errorf(
"can't load validator code for address %s hash %x: %v",
s.address.Hex(), s.CodeHash(), err,
),
)
var err error
code := []byte{}
// if it's not set for validator wrapper, then it may be either contract code or validator wrapper (old version of db
// don't have any prefix to differentiate between them)
// so, if it's not set for validator wrapper, we need to check contract code as well
if !s.validatorWrapper {
code, err = db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash()))
}
// if it couldn't load contract code or it is set to validator wrapper, then it tries to fetch validator wrapper code
if s.validatorWrapper || err != nil {
vCode, errVCode := db.ValidatorCode(s.addrHash, common.BytesToHash(s.CodeHash()))
if errVCode == nil && vCode != nil {
s.code = vCode
return vCode
}
if code != nil {
s.code = code
return code
if s.validatorWrapper {
s.setError(fmt.Errorf("can't load validator code hash %x for account address hash %x : %v", s.CodeHash(), s.addrHash, err))
} else {
s.setError(fmt.Errorf("can't load contract/validator code hash %x for account address hash %x : contract code error: %v, validator code error: %v",
s.CodeHash(), s.addrHash, err, errVCode))
}
}
code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash()))
if err != nil {
s.setError(
fmt.Errorf(
"can't load code for address %s hash %x: %v",
s.address.Hex(), s.CodeHash(), err,
),
)
}
if code != nil {
s.code = code
return code
}
return nil
s.code = code
return code
}

// CodeSize returns the size of the contract/validator code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
func (s *Object) CodeSize(db Database, isValidatorCode bool) int {
func (s *Object) CodeSize(db Database) int {
if s.code != nil {
return len(s.code)
}
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
if bytes.Equal(s.CodeHash(), EmptyCodeHash.Bytes()) {
return 0
}
if s.validatorWrapper || isValidatorCode {
size, err := db.ValidatorCodeSize(s.addrHash, common.BytesToHash(s.CodeHash()))
if err != nil {
s.setError(fmt.Errorf("can't load validator code size %x: %v", s.CodeHash(), err))
}
if size > 0 {
var err error
size := int(0)

// if it's not set for validator wrapper, then it may be either contract code or validator wrapper (old version of db
// don't have any prefix to differentiate between them)
// so, if it's not set for validator wrapper, we need to check contract code as well
if !s.validatorWrapper {
size, err = db.ContractCodeSize(s.addrHash, common.BytesToHash(s.CodeHash()))
}
// if it couldn't get contract code or it is set to validator wrapper, then it tries to retrieve validator wrapper code
if s.validatorWrapper || err != nil {
vcSize, errVCSize := db.ValidatorCodeSize(s.addrHash, common.BytesToHash(s.CodeHash()))
if errVCSize == nil && vcSize > 0 {
return size
}
}
size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.CodeHash()))
if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
if s.validatorWrapper {
s.setError(fmt.Errorf("can't load validator code size %x for account address hash %x : %v", s.CodeHash(), s.addrHash, err))
} else {
s.setError(fmt.Errorf("can't load contract/validator code size %x for account address hash %x : contract code size error: %v, validator code size error: %v",
s.CodeHash(), s.addrHash, err, errVCSize))
}
s.setError(fmt.Errorf("can't load code size %x (validator wrapper: %t): %v", s.CodeHash(), s.validatorWrapper, err))
}
return size
}

func (s *Object) SetCode(codeHash common.Hash, code []byte, isValidatorCode bool) {
prevcode := s.Code(s.db.db, isValidatorCode)
prevcode := s.Code(s.db.db)
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
Expand Down
2 changes: 1 addition & 1 deletion core/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestSnapshot2(t *testing.T) {
so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db, false)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)

Expand Down
16 changes: 7 additions & 9 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,18 +342,18 @@ func (db *DB) BlockHash() common.Hash {
return db.bhash
}

func (db *DB) GetCode(addr common.Address, isValidatorCode bool) []byte {
func (db *DB) GetCode(addr common.Address) []byte {
Object := db.getStateObject(addr)
if Object != nil {
return Object.Code(db.db, isValidatorCode)
return Object.Code(db.db)
}
return nil
}

func (db *DB) GetCodeSize(addr common.Address, isValidatorCode bool) int {
func (db *DB) GetCodeSize(addr common.Address) int {
Object := db.getStateObject(addr)
if Object != nil {
return Object.CodeSize(db.db, isValidatorCode)
return Object.CodeSize(db.db)
}
return 0
}
Expand Down Expand Up @@ -1241,13 +1241,11 @@ func (db *DB) ValidatorWrapper(
return copyValidatorWrapperIfNeeded(cached, sendOriginal, copyDelegations), nil
}

by := db.GetCode(addr, true)
by := db.GetCode(addr)
if len(by) == 0 {
by = db.GetCode(addr, false)
if len(by) == 0 {
return nil, ErrAddressNotPresent
}
return nil, ErrAddressNotPresent
}

val := stk.ValidatorWrapper{}
if err := rlp.DecodeBytes(by, &val); err != nil {
return nil, errors.Wrapf(
Expand Down
22 changes: 11 additions & 11 deletions core/state/statedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,9 @@ func (test *snapshotTest) checkEqual(state, checkstate *DB) error {
checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr, false), checkstate.GetCode(addr, false))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr, false), checkstate.GetCodeSize(addr, false))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
state.ForEachStorage(addr, func(key, value common.Hash) bool {
Expand Down Expand Up @@ -532,7 +532,7 @@ func TestCopyCommitCopy(t *testing.T) {
if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42)
}
if code := state.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := state.GetState(addr, skey); val != sval {
Expand All @@ -546,7 +546,7 @@ func TestCopyCommitCopy(t *testing.T) {
if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("first copy pre-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyOne.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("first copy pre-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyOne.GetState(addr, skey); val != sval {
Expand All @@ -560,7 +560,7 @@ func TestCopyCommitCopy(t *testing.T) {
if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("first copy post-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyOne.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("first copy post-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyOne.GetState(addr, skey); val != sval {
Expand All @@ -574,7 +574,7 @@ func TestCopyCommitCopy(t *testing.T) {
if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("second copy balance mismatch: have %v, want %v", balance, 42)
}
if code := copyTwo.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("second copy code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyTwo.GetState(addr, skey); val != sval {
Expand Down Expand Up @@ -604,7 +604,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42)
}
if code := state.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := state.GetState(addr, skey); val != sval {
Expand All @@ -618,7 +618,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("first copy balance mismatch: have %v, want %v", balance, 42)
}
if code := copyOne.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("first copy code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyOne.GetState(addr, skey); val != sval {
Expand All @@ -632,7 +632,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("second copy pre-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyTwo.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("second copy pre-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyTwo.GetState(addr, skey); val != sval {
Expand All @@ -645,7 +645,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("second copy post-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyTwo.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyTwo.GetState(addr, skey); val != sval {
Expand All @@ -659,7 +659,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if balance := copyThree.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("third copy balance mismatch: have %v, want %v", balance, 42)
}
if code := copyThree.GetCode(addr, false); !bytes.Equal(code, []byte("hello")) {
if code := copyThree.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("third copy code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyThree.GetState(addr, skey); val != sval {
Expand Down
2 changes: 1 addition & 1 deletion core/vm/contracts_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (c *crossShardXferPrecompile) RunWriteCapable(
return nil, err
}
// validate not a contract (toAddress can still be a contract)
if len(evm.StateDB.GetCode(fromAddress, false)) > 0 && !evm.IsValidator(evm.StateDB, fromAddress) {
if len(evm.StateDB.GetCode(fromAddress)) > 0 && !evm.IsValidator(evm.StateDB, fromAddress) {
return nil, errors.New("cross shard xfer not yet implemented for contracts")
}
// can't have too many shards
Expand Down
Loading

0 comments on commit 6577b0b

Please sign in to comment.