From ce568f208c38fafa7d3f6303a5c1646ebb72a28b Mon Sep 17 00:00:00 2001 From: Oba Date: Thu, 1 Aug 2024 17:28:20 +0200 Subject: [PATCH] feat: add code hash to storage account (#1309) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Time spent on this PR: ## Pull request type Please check the type of change your PR introduces: - [ ] Bugfix - [x] Feature - [ ] Code style update (formatting, renaming) - [ ] Refactoring (no functional changes, no api changes) - [ ] Build related changes - [ ] Documentation content changes - [ ] Other (please describe): ## What is the current behavior? Resolves #1267 ## What is the new behavior? - code hash is added to model.Account - it is computed only once when creating a contract account - - - This change is [Reviewable](https://reviewable.io/reviews/kkrt-labs/kakarot/1309) --------- Co-authored-by: Mathieu <60658558+enitrat@users.noreply.github.com> --- docs/general/accounts.md | 1 + src/backend/starknet.cairo | 1 + src/kakarot/account.cairo | 84 +++++++++++++++++-- src/kakarot/accounts/account_contract.cairo | 17 ++++ src/kakarot/accounts/library.cairo | 17 ++++ .../environmental_information.cairo | 19 +---- .../instructions/system_operations.cairo | 3 +- src/kakarot/interfaces/interfaces.cairo | 6 ++ src/kakarot/library.cairo | 2 + src/kakarot/model.cairo | 1 + src/kakarot/state.cairo | 11 ++- .../accounts/test_account_contract.cairo | 14 ++++ .../kakarot/accounts/test_account_contract.py | 24 +++++- .../test_environmental_information.py | 26 ++++-- tests/src/kakarot/test_account.cairo | 50 +++++++++-- tests/src/kakarot/test_account.py | 28 ++++++- tests/src/kakarot/test_kakarot.py | 6 ++ tests/src/kakarot/test_state.cairo | 20 ++++- tests/src/kakarot/test_state.py | 7 ++ 19 files changed, 292 insertions(+), 45 deletions(-) diff --git a/docs/general/accounts.md b/docs/general/accounts.md index 1ef4225d7..71ae76aaa 100644 --- a/docs/general/accounts.md +++ b/docs/general/accounts.md @@ -40,6 +40,7 @@ Account contracts store the following information: - `is_initialized`: A boolean indicating whether the account has been initialized, used to prevent reinitializing an already initialized account. - `evm_address`: The Ethereum address associated with this Starknet account. +- `code_hash`: The hash of the EVM contract account bytecode. ## Account entrypoints diff --git a/src/backend/starknet.cairo b/src/backend/starknet.cairo index 5af952789..7278972a6 100644 --- a/src/backend/starknet.cairo +++ b/src/backend/starknet.cairo @@ -221,6 +221,7 @@ namespace Internals { Internals._save_valid_jumpdests( starknet_address, self.valid_jumpdests_start, self.valid_jumpdests ); + IAccount.set_code_hash(starknet_address, [self.code_hash]); return (); } diff --git a/src/kakarot/account.cairo b/src/kakarot/account.cairo index ae6f01e4e..e09382bff 100644 --- a/src/kakarot/account.cairo +++ b/src/kakarot/account.cairo @@ -21,18 +21,20 @@ from starkware.cairo.common.hash_state import ( ) from starkware.starknet.common.storage import normalize_address from starkware.starknet.common.syscalls import get_contract_address - +from starkware.cairo.lang.compiler.lib.registers import get_ap from kakarot.constants import Constants from kakarot.storages import ( Kakarot_uninitialized_account_class_hash, Kakarot_native_token_address, Kakarot_account_contract_class_hash, + Kakarot_cairo1_helpers_class_hash, ) -from kakarot.interfaces.interfaces import IAccount, IERC20 +from kakarot.interfaces.interfaces import IAccount, IERC20, ICairo1Helpers from kakarot.model import model from kakarot.storages import Kakarot_evm_to_starknet_address from utils.dict import default_dict_copy from utils.utils import Helpers +from utils.bytes import bytes_to_bytes8_little_endian namespace Account { // @notice Create a new account @@ -44,7 +46,12 @@ namespace Account { // @return The updated state // @return The account func init( - address: model.Address*, code_len: felt, code: felt*, nonce: felt, balance: Uint256* + address: model.Address*, + code_len: felt, + code: felt*, + code_hash: Uint256*, + nonce: felt, + balance: Uint256*, ) -> model.Account* { let (storage_start) = default_dict_new(0); let (transient_storage_start) = default_dict_new(0); @@ -53,6 +60,7 @@ namespace Account { address=address, code_len=code_len, code=code, + code_hash=code_hash, storage_start=storage_start, storage=storage_start, transient_storage_start=transient_storage_start, @@ -82,6 +90,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=storage_start, storage=storage, transient_storage_start=transient_storage_start, @@ -114,8 +123,17 @@ namespace Account { tempvar address = new model.Address(starknet=starknet_address, evm=evm_address); let balance = fetch_balance(address); assert balance_ptr = new Uint256(balance.low, balance.high); + // empty code hash see https://eips.ethereum.org/EIPS/eip-1052 + tempvar code_hash_ptr = new Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); let account = Account.init( - address=address, code_len=0, code=bytecode, nonce=0, balance=balance_ptr + address=address, + code_len=0, + code=bytecode, + code_hash=code_hash_ptr, + nonce=0, + balance=balance_ptr, ); return account; } @@ -126,14 +144,21 @@ namespace Account { let (bytecode_len, bytecode) = IAccount.bytecode(contract_address=starknet_address); let (nonce) = IAccount.get_nonce(contract_address=starknet_address); + IAccount.get_code_hash(contract_address=starknet_address); + let (ap_val) = get_ap(); + let code_hash = cast(ap_val - 2, Uint256*); // CAs are instantiated with their actual nonce - EOAs are instantiated with the nonce=1 // that is set when they're deployed. - // If an account was created-selfdestructed in the same tx, its nonce is 0, thus // it is considered as a new account as per the `has_code_or_nonce` rule. let account = Account.init( - address=address, code_len=bytecode_len, code=bytecode, nonce=nonce, balance=balance_ptr + address=address, + code_len=bytecode_len, + code=bytecode, + code_hash=code_hash, + nonce=nonce, + balance=balance_ptr, ); return account; } @@ -161,6 +186,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=storage, transient_storage_start=self.transient_storage_start, @@ -200,6 +226,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=storage, transient_storage_start=self.transient_storage_start, @@ -229,6 +256,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=storage, transient_storage_start=self.transient_storage_start, @@ -258,6 +286,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -296,6 +325,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -319,14 +349,19 @@ namespace Account { // @param code_len The len of the code // @param code The code array // @return The updated Account with the code and valid jumpdests set - func set_code{range_check_ptr}( + func set_code{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( self: model.Account*, code_len: felt, code: felt* ) -> model.Account* { + alloc_locals; + compute_code_hash(code_len, code); + let (ap_val) = get_ap(); + let code_hash = cast(ap_val - 2, Uint256*); let (valid_jumpdests_start, valid_jumpdests) = Helpers.initialize_jumpdests(code_len, code); return new model.Account( address=self.address, code_len=code_len, code=code, + code_hash=code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -348,6 +383,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -367,6 +403,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -421,6 +458,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -442,6 +480,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -578,6 +617,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=self.storage, transient_storage_start=self.transient_storage_start, @@ -603,6 +643,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=storage, transient_storage_start=self.transient_storage_start, @@ -637,6 +678,7 @@ namespace Account { address=self.address, code_len=self.code_len, code=self.code, + code_hash=self.code_hash, storage_start=self.storage_start, storage=storage_ptr, transient_storage_start=self.transient_storage_start, @@ -650,6 +692,34 @@ namespace Account { ); return self; } + + func compute_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + code_len: felt, code: felt* + ) -> Uint256 { + alloc_locals; + if (code_len == 0) { + // see https://eips.ethereum.org/EIPS/eip-1052 + let empty_code_hash = Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); + return empty_code_hash; + } + + let (local dst: felt*) = alloc(); + let (dst_len, last_word, last_word_num_bytes) = bytes_to_bytes8_little_endian( + dst, code_len, code + ); + + let (implementation) = Kakarot_cairo1_helpers_class_hash.read(); + let (code_hash) = ICairo1Helpers.library_call_keccak( + class_hash=implementation, + words_len=dst_len, + words=dst, + last_input_word=last_word, + last_input_num_bytes=last_word_num_bytes, + ); + return code_hash; + } } namespace Internals { diff --git a/src/kakarot/accounts/account_contract.cairo b/src/kakarot/accounts/account_contract.cairo index c42ac3640..e0fcfb6af 100644 --- a/src/kakarot/accounts/account_contract.cairo +++ b/src/kakarot/accounts/account_contract.cairo @@ -287,6 +287,23 @@ func is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_che return (is_valid=is_valid); } +@view +func get_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( + code_hash: Uint256 +) { + let code_hash = AccountContract.get_code_hash(); + return (code_hash,); +} + +@external +func set_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + code_hash: Uint256 +) { + Ownable.assert_only_owner(); + AccountContract.set_code_hash(code_hash); + return (); +} + // @notice Authorizes a pre-eip155 transaction by message hash. // @param message_hash The hash of the message. @external diff --git a/src/kakarot/accounts/library.cairo b/src/kakarot/accounts/library.cairo index 99c6b944e..4b4aea064 100644 --- a/src/kakarot/accounts/library.cairo +++ b/src/kakarot/accounts/library.cairo @@ -66,6 +66,10 @@ func Account_jumpdests_initialized() -> (initialized: felt) { func Account_authorized_message_hashes(hash: Uint256) -> (res: felt) { } +@storage_var +func Account_code_hash() -> (code_hash: Uint256) { +} + @event func transaction_executed(response_len: felt, response: felt*, success: felt, gas_used: felt) { } @@ -436,6 +440,19 @@ namespace AccountContract { ); return is_valid_jumpdest(index=index); } + + func get_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + ) -> Uint256 { + let (code_hash) = Account_code_hash.read(); + return code_hash; + } + + func set_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + code_hash: Uint256 + ) { + Account_code_hash.write(code_hash); + return (); + } } namespace Internals { diff --git a/src/kakarot/instructions/environmental_information.cairo b/src/kakarot/instructions/environmental_information.cairo index ef06a17fb..3ffb65644 100644 --- a/src/kakarot/instructions/environmental_information.cairo +++ b/src/kakarot/instructions/environmental_information.cairo @@ -11,7 +11,7 @@ from starkware.cairo.common.math_cmp import is_not_zero, is_nn from starkware.cairo.common.uint256 import Uint256, uint256_le from kakarot.account import Account -from kakarot.interfaces.interfaces import ICairo1Helpers + from kakarot.evm import EVM from kakarot.errors import Errors from kakarot.gas import Gas @@ -19,7 +19,6 @@ from kakarot.memory import Memory from kakarot.model import model from kakarot.stack import Stack from kakarot.state import State -from kakarot.storages import Kakarot_cairo1_helpers_class_hash from utils.array import slice from utils.bytes import bytes_to_bytes8_little_endian from utils.uint256 import uint256_to_uint160, uint256_add, uint256_eq @@ -480,21 +479,7 @@ namespace EnvironmentalInformation { return evm; } - let (local dst: felt*) = alloc(); - let (dst_len, last_word, last_word_num_bytes) = bytes_to_bytes8_little_endian( - dst, account.code_len, account.code - ); - - let (implementation) = Kakarot_cairo1_helpers_class_hash.read(); - let (code_hash) = ICairo1Helpers.library_call_keccak( - class_hash=implementation, - words_len=dst_len, - words=dst, - last_input_word=last_word, - last_input_num_bytes=last_word_num_bytes, - ); - - Stack.push_uint256(code_hash); + Stack.push_uint256([account.code_hash]); return evm; } diff --git a/src/kakarot/instructions/system_operations.cairo b/src/kakarot/instructions/system_operations.cairo index 2570462bd..cb6b6f21b 100644 --- a/src/kakarot/instructions/system_operations.cairo +++ b/src/kakarot/instructions/system_operations.cairo @@ -7,13 +7,12 @@ from starkware.cairo.common.bool import TRUE, FALSE from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin from starkware.cairo.common.math import split_felt, unsigned_div_rem from starkware.cairo.common.math_cmp import is_nn, is_not_zero -from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.uint256 import Uint256, uint256_lt, uint256_le from starkware.cairo.common.default_dict import default_dict_new from starkware.cairo.common.dict_access import DictAccess from kakarot.account import Account -from kakarot.interfaces.interfaces import IAccount, ICairo1Helpers +from kakarot.interfaces.interfaces import ICairo1Helpers from kakarot.constants import Constants from kakarot.errors import Errors from kakarot.evm import EVM diff --git a/src/kakarot/interfaces/interfaces.cairo b/src/kakarot/interfaces/interfaces.cairo index 0273feda5..ae9beef6f 100644 --- a/src/kakarot/interfaces/interfaces.cairo +++ b/src/kakarot/interfaces/interfaces.cairo @@ -79,6 +79,12 @@ namespace IAccount { to: felt, function_selector: felt, calldata_len: felt, calldata: felt* ) -> (retdata_len: felt, retdata: felt*, success: felt) { } + + func get_code_hash() -> (code_hash: Uint256) { + } + + func set_code_hash(code_hash: Uint256) { + } } @contract_interface diff --git a/src/kakarot/library.cairo b/src/kakarot/library.cairo index 20610aaaa..f658769c4 100644 --- a/src/kakarot/library.cairo +++ b/src/kakarot/library.cairo @@ -325,6 +325,8 @@ namespace Kakarot { alloc_locals; let starknet_address = Account.get_starknet_address(evm_address); IAccount.write_bytecode(starknet_address, bytecode_len, bytecode); + let code_hash = Account.compute_code_hash(bytecode_len, bytecode); + IAccount.set_code_hash(starknet_address, code_hash); return (); } diff --git a/src/kakarot/model.cairo b/src/kakarot/model.cairo index 3ddae87b5..308b5f730 100644 --- a/src/kakarot/model.cairo +++ b/src/kakarot/model.cairo @@ -69,6 +69,7 @@ namespace model { address: model.Address*, code_len: felt, code: felt*, + code_hash: Uint256*, storage_start: DictAccess*, storage: DictAccess*, transient_storage_start: DictAccess*, diff --git a/src/kakarot/state.cairo b/src/kakarot/state.cairo index ba5af9e99..6b8159d5e 100644 --- a/src/kakarot/state.cairo +++ b/src/kakarot/state.cairo @@ -464,8 +464,17 @@ namespace Internals { let balance = Account.fetch_balance(address); tempvar balance_ptr = new Uint256(balance.low, balance.high); let (bytecode) = alloc(); + // empty code hash see https://eips.ethereum.org/EIPS/eip-1052 + tempvar code_hash_ptr = new Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); let account = Account.init( - address=address, code_len=0, code=bytecode, nonce=0, balance=balance_ptr + address=address, + code_len=0, + code=bytecode, + code_hash=code_hash_ptr, + nonce=0, + balance=balance_ptr, ); dict_write{dict_ptr=accounts_ptr}(key=address.evm, new_value=cast(account, felt)); return (); diff --git a/tests/src/kakarot/accounts/test_account_contract.cairo b/tests/src/kakarot/accounts/test_account_contract.cairo index 84247212f..0f7b1420f 100644 --- a/tests/src/kakarot/accounts/test_account_contract.cairo +++ b/tests/src/kakarot/accounts/test_account_contract.cairo @@ -16,6 +16,7 @@ from kakarot.accounts.account_contract import ( set_nonce, set_authorized_pre_eip155_tx, execute_starknet_call, + set_code_hash, ) func test__initialize{ @@ -162,3 +163,16 @@ func test__is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran return is_valid; } + +func test__set_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { + alloc_locals; + local code_hash: Uint256; + %{ + ids.code_hash.low = program_input["code_hash"][0] + ids.code_hash.high = program_input["code_hash"][1] + %} + + set_code_hash(code_hash); + + return (); +} diff --git a/tests/src/kakarot/accounts/test_account_contract.py b/tests/src/kakarot/accounts/test_account_contract.py index b71bc93dc..982ec0a5f 100644 --- a/tests/src/kakarot/accounts/test_account_contract.py +++ b/tests/src/kakarot/accounts/test_account_contract.py @@ -7,7 +7,7 @@ from eth_account.account import Account from eth_utils import keccak from hypothesis import given, settings -from hypothesis.strategies import binary +from hypothesis.strategies import binary, integers from starkware.starknet.public.abi import ( get_selector_from_name, get_storage_var_address, @@ -258,6 +258,28 @@ def test__should_return_if_jumpdest_valid_when_not_stored( SyscallHandler.mock_storage.assert_has_calls(expected_read_calls) SyscallHandler.mock_storage.assert_has_calls(expected_write_calls) + class TestCodeHash: + @given(code_hash=integers(min_value=0, max_value=2**256 - 1)) + @SyscallHandler.patch("Ownable_owner", 0xDEAD) + def test_should_assert_only_owner(self, cairo_run, code_hash): + with cairo_error(message="Ownable: caller is not the owner"): + cairo_run("test__set_code_hash", code_hash=int_to_uint256(code_hash)) + + @given(code_hash=integers(min_value=0, max_value=2**256 - 1)) + @SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address) + def test__should_set_code_hash(self, cairo_run, code_hash): + with patch.object(SyscallHandler, "mock_storage") as mock_storage: + low, high = int_to_uint256(code_hash) + cairo_run("test__set_code_hash", code_hash=(low, high)) + code_hash_address = get_storage_var_address("Account_code_hash") + ownable_address = get_storage_var_address("Ownable_owner") + calls = [ + call(address=ownable_address), + call(address=code_hash_address, value=low), + call(address=code_hash_address + 1, value=high), + ] + mock_storage.assert_has_calls(calls) + class TestSetAuthorizedPreEIP155Transactions: def test_should_assert_only_owner(self, cairo_run): with cairo_error(message="Ownable: caller is not the owner"): diff --git a/tests/src/kakarot/instructions/test_environmental_information.py b/tests/src/kakarot/instructions/test_environmental_information.py index 3456d26e8..f990ac5a1 100644 --- a/tests/src/kakarot/instructions/test_environmental_information.py +++ b/tests/src/kakarot/instructions/test_environmental_information.py @@ -3,8 +3,8 @@ import pytest from Crypto.Hash import keccak -from tests.utils.constants import CAIRO1_HELPERS_CLASS_HASH from tests.utils.syscall_handler import SyscallHandler +from tests.utils.uint256 import int_to_uint256 EXISTING_ACCOUNT = 0xABDE1 EXISTING_ACCOUNT_SN_ADDR = 0x1234 @@ -43,6 +43,9 @@ class TestExtCodeSize: @SyscallHandler.patch( "Kakarot_evm_to_starknet_address", EXISTING_ACCOUNT, 0x1234 ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) def test_extcodesize_should_push_code_size(self, cairo_run, bytecode, address): with SyscallHandler.patch( "IAccount.bytecode", lambda addr, data: [len(bytecode), *bytecode] @@ -66,6 +69,9 @@ class TestExtCodeCopy: @SyscallHandler.patch( "Kakarot_evm_to_starknet_address", EXISTING_ACCOUNT, 0x1234 ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) def test_extcodecopy_should_copy_code( self, cairo_run, size, offset, dest_offset, bytecode, address ): @@ -106,6 +112,9 @@ def test_extcodecopy_should_copy_code( @SyscallHandler.patch( "Kakarot_evm_to_starknet_address", EXISTING_ACCOUNT, 0x1234 ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) def test_extcodecopy_offset_high_zellic_issue_1258( self, cairo_run, size, bytecode, address ): @@ -206,15 +215,18 @@ class TestExtCodeHash: EXISTING_ACCOUNT, EXISTING_ACCOUNT_SN_ADDR, ) - @SyscallHandler.patch( - "Kakarot_cairo1_helpers_class_hash", - CAIRO1_HELPERS_CLASS_HASH, - ) def test_extcodehash__should_push_hash( self, cairo_run, bytecode, bytecode_hash, address ): - with SyscallHandler.patch( - "IAccount.bytecode", lambda sn_addr, data: [len(bytecode), *bytecode] + with ( + SyscallHandler.patch( + "IAccount.bytecode", + lambda sn_addr, data: [len(bytecode), *bytecode], + ), + SyscallHandler.patch( + "IAccount.get_code_hash", + lambda sn_addr, data: [*int_to_uint256(bytecode_hash)], + ), ): output = cairo_run("test__exec_extcodehash", address=address) diff --git a/tests/src/kakarot/test_account.cairo b/tests/src/kakarot/test_account.cairo index 659244a4f..98d83d5df 100644 --- a/tests/src/kakarot/test_account.cairo +++ b/tests/src/kakarot/test_account.cairo @@ -20,12 +20,16 @@ func test__init__should_return_account_with_default_dict_as_storage{ local evm_address: felt; local code_len: felt; let (code) = alloc(); + let (code_hash_ptr) = alloc(); local nonce: felt; local balance_low: felt; %{ + from tests.utils.uint256 import int_to_uint256 + ids.evm_address = program_input["evm_address"] ids.code_len = len(program_input["code"]) segments.write_arg(ids.code, program_input["code"]) + segments.write_arg(ids.code_hash_ptr, int_to_uint256(program_input["code_hash"])) ids.nonce = program_input["nonce"] ids.balance_low = program_input["balance_low"] %} @@ -35,7 +39,9 @@ func test__init__should_return_account_with_default_dict_as_storage{ tempvar balance = new Uint256(balance_low, 0); // When - let account = Account.init(address, code_len, code, nonce, balance); + let account = Account.init( + address, code_len, code, cast(code_hash_ptr, Uint256*), nonce, balance + ); // Then assert account.address = address; @@ -58,12 +64,16 @@ func test__copy__should_return_new_account_with_same_attributes{ local evm_address: felt; local code_len: felt; let (code) = alloc(); + let (code_hash_ptr) = alloc(); local nonce: felt; local balance_low: felt; %{ + from tests.utils.uint256 import int_to_uint256 + ids.evm_address = program_input["evm_address"] ids.code_len = len(program_input["code"]) segments.write_arg(ids.code, program_input["code"]) + segments.write_arg(ids.code_hash_ptr, int_to_uint256(program_input["code_hash"])) ids.nonce = program_input["nonce"] ids.balance_low = program_input["balance_low"] %} @@ -71,7 +81,9 @@ func test__copy__should_return_new_account_with_same_attributes{ let starknet_address = Account.compute_starknet_address(evm_address); tempvar address = new model.Address(starknet=starknet_address, evm=evm_address); tempvar balance = new Uint256(balance_low, 0); - let account = Account.init(address, code_len, code, nonce, balance); + let account = Account.init( + address, code_len, code, cast(code_hash_ptr, Uint256*), nonce, balance + ); tempvar key = new Uint256(1, 2); tempvar value = new Uint256(3, 4); let account = Account.write_storage(account, key, value); @@ -124,8 +136,9 @@ func test__write_storage__should_store_value_at_key{ let starknet_address = Account.compute_starknet_address(0); tempvar address = new model.Address(starknet_address, 0); let (local code: felt*) = alloc(); + tempvar code_hash = new Uint256(0, 0); tempvar balance = new Uint256(0, 0); - let account = Account.init(address, 0, code, 0, balance); + let account = Account.init(address, 0, code, code_hash, 0, balance); // When let account = Account.write_storage(account, key, value); @@ -160,8 +173,11 @@ func test__fetch_original_storage__state_modified{ let starknet_address = Account.compute_starknet_address(evm_address); tempvar address = new model.Address(starknet_address, evm_address); let (local code: felt*) = alloc(); + tempvar code_hash = new Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); tempvar balance = new Uint256(0, 0); - let account = Account.init(address, 0, code, 0, balance); + let account = Account.init(address, 0, code, code_hash, 0, balance); // When let account = Account.write_storage(account, key, value); @@ -177,17 +193,23 @@ func test__has_code_or_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran // Given local code_len: felt; let (code) = alloc(); + let (code_hash_ptr) = alloc(); local nonce: felt; %{ + from tests.utils.uint256 import int_to_uint256 + ids.code_len = len(program_input["code"]) segments.write_arg(ids.code, program_input["code"]) + segments.write_arg(ids.code_hash_ptr, int_to_uint256(program_input["code_hash"])) ids.nonce = program_input["nonce"] %} let starknet_address = Account.compute_starknet_address(0); tempvar address = new model.Address(starknet_address, 0); tempvar balance = new Uint256(0, 0); - let account = Account.init(address, code_len, code, nonce, balance); + let account = Account.init( + address, code_len, code, cast(code_hash_ptr, Uint256*), nonce, balance + ); // When let result = Account.has_code_or_nonce(account); @@ -195,3 +217,21 @@ func test__has_code_or_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran // Then return result; } + +func test__compute_code_hash{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + ) -> Uint256 { + alloc_locals; + // Given + local code_len: felt; + let (code) = alloc(); + %{ + ids.code_len = len(program_input["code"]) + segments.write_arg(ids.code, program_input["code"]) + %} + + // When + let result = Account.compute_code_hash(code_len, code); + + // Then + return result; +} diff --git a/tests/src/kakarot/test_account.py b/tests/src/kakarot/test_account.py index 65e01281a..b9ebdcf67 100644 --- a/tests/src/kakarot/test_account.py +++ b/tests/src/kakarot/test_account.py @@ -1,4 +1,7 @@ import pytest +from eth_utils import keccak +from hypothesis import given +from hypothesis.strategies import binary from tests.utils.syscall_handler import SyscallHandler from tests.utils.uint256 import int_to_uint256 @@ -13,12 +16,15 @@ class TestInit: def test_should_return_account_with_default_dict_as_storage( self, cairo_run, address, code, nonce, balance ): + code_hash_bytes = keccak(bytes(code)) + code_hash = int.from_bytes(code_hash_bytes, "big") cairo_run( "test__init__should_return_account_with_default_dict_as_storage", evm_address=address, code=code, nonce=nonce, balance_low=balance, + code_hash=code_hash, ) class TestCopy: @@ -29,12 +35,15 @@ class TestCopy: def test_should_return_new_account_with_same_attributes( self, cairo_run, address, code, nonce, balance ): + code_hash_bytes = keccak(bytes(code)) + code_hash = int.from_bytes(code_hash_bytes, "big") cairo_run( "test__copy__should_return_new_account_with_same_attributes", evm_address=address, code=code, nonce=nonce, balance_low=balance, + code_hash=code_hash, ) class TestWriteStorage: @@ -88,8 +97,25 @@ class TestHasCodeOrNonce: (1, [1], True), ), ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) def test_should_return_true_when_nonce( self, cairo_run, nonce, code, expected_result ): - output = cairo_run("test__has_code_or_nonce", nonce=nonce, code=code) + code_hash_bytes = keccak(bytes(code)) + code_hash = int.from_bytes(code_hash_bytes, "big") + output = cairo_run( + "test__has_code_or_nonce", nonce=nonce, code=code, code_hash=code_hash + ) assert output == expected_result + + class TestComputeCodeHash: + @given(bytecode=binary(min_size=0, max_size=400)) + def test_should_compute_code_hash(self, cairo_run, bytecode): + output = cairo_run( + "test__compute_code_hash", + code=bytecode, + ) + code_hash = int.from_bytes(keccak(bytecode), byteorder="big") + assert int(output, 16) == code_hash diff --git a/tests/src/kakarot/test_kakarot.py b/tests/src/kakarot/test_kakarot.py index 5487952e4..3b29a7326 100644 --- a/tests/src/kakarot/test_kakarot.py +++ b/tests/src/kakarot/test_kakarot.py @@ -297,6 +297,9 @@ class TestEthCall: "IAccount.is_valid_jumpdest", lambda addr, data: [1], ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) async def test_erc20_transfer(self, get_contract): erc20 = await get_contract("Solmate", "ERC20") amount = int(1e18) @@ -321,6 +324,9 @@ async def test_erc20_transfer(self, get_contract): "IAccount.is_valid_jumpdest", lambda addr, data: [1], ) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) async def test_erc721_transfer(self, get_contract): erc721 = await get_contract("Solmate", "ERC721") token_id = 1337 diff --git a/tests/src/kakarot/test_state.cairo b/tests/src/kakarot/test_state.cairo index a44b09ebc..9b0ee80b9 100644 --- a/tests/src/kakarot/test_state.cairo +++ b/tests/src/kakarot/test_state.cairo @@ -111,18 +111,24 @@ func test__is_account_alive__account_alive_in_state{ local balance_low: felt; local code_len: felt; let (code) = alloc(); + let (code_hash_ptr) = alloc(); %{ + from tests.utils.uint256 import int_to_uint256 + ids.nonce = program_input["nonce"] ids.balance_low = program_input["balance_low"] ids.code_len = len(program_input["code"]) - segments.write_arg(ids.code, program_input["code"]); + segments.write_arg(ids.code, program_input["code"]) + segments.write_arg(ids.code_hash_ptr, int_to_uint256(program_input["code_hash"])) %} let evm_address = 'alive'; let starknet_address = Account.compute_starknet_address(evm_address); tempvar address = new model.Address(starknet_address, evm_address); tempvar balance = new Uint256(balance_low, 0); - let account = Account.init(address, code_len, code, nonce, balance); + let account = Account.init( + address, code_len, code, cast(code_hash_ptr, Uint256*), nonce, balance + ); let state = State.init(); with state { @@ -166,7 +172,10 @@ func test___copy_accounts__should_handle_null_pointers{range_check_ptr}() { tempvar address = new model.Address(1, 2); tempvar balance = new Uint256(1, 0); let (code) = alloc(); - let account = Account.init(address, 0, code, 1, balance); + tempvar code_hash = new Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); + let account = Account.init(address, 0, code, code_hash, 1, balance); dict_write{dict_ptr=accounts}(address.evm, cast(account, felt)); let empty_address = 'empty address'; dict_read{dict_ptr=accounts}(empty_address); @@ -193,7 +202,10 @@ func test__is_account_warm__account_in_state{ tempvar address = new model.Address(starknet_address, evm_address); tempvar balance = new Uint256(1, 0); let (code) = alloc(); - let account = Account.init(address, 0, code, 1, balance); + tempvar code_hash = new Uint256( + 304396909071904405792975023732328604784, 262949717399590921288928019264691438528 + ); + let account = Account.init(address, 0, code, code_hash, 1, balance); tempvar state = State.init(); with state { diff --git a/tests/src/kakarot/test_state.py b/tests/src/kakarot/test_state.py index 3836c317c..1da73f048 100644 --- a/tests/src/kakarot/test_state.py +++ b/tests/src/kakarot/test_state.py @@ -1,4 +1,5 @@ import pytest +from eth_utils import keccak from ethereum.shanghai.transactions import ( TX_ACCESS_LIST_ADDRESS_COST, TX_ACCESS_LIST_STORAGE_KEY_COST, @@ -33,10 +34,13 @@ class TestIsAccountAlive: def test_should_return_true_when_existing_account_cached( self, cairo_run, nonce, code, balance_low, expected_result ): + code_hash_bytes = keccak(bytes(code)) + code_hash = int.from_bytes(code_hash_bytes, "big") is_alive = cairo_run( "test__is_account_alive__account_alive_in_state", nonce=nonce, code=code, + code_hash=code_hash, balance_low=balance_low, ) assert is_alive == expected_result @@ -45,6 +49,9 @@ def test_should_return_true_when_existing_account_cached( @SyscallHandler.patch("IERC20.balanceOf", lambda addr, data: [0, 1]) @SyscallHandler.patch("IAccount.get_nonce", lambda addr, data: [1]) @SyscallHandler.patch("Kakarot_evm_to_starknet_address", 0xABDE1, 0x1234) + @SyscallHandler.patch( + "IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1] + ) def test_should_return_true_when_existing_account_not_cached(self, cairo_run): cairo_run( "test__is_account_alive__account_alive_not_in_state",