Skip to content

Commit

Permalink
opti: charge_gas (#1323)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR:

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [x] Other (please describe): opti

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves #1261

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

### Fix async tests
Tests using `get_contract` were skipped in local and CI due to not being
able to run this function as `get_solidity_contract` was not awaited and
`get_contract` not marked as async. With those changes, they now run in
local and CI. They are using a sync version of `get_contract`.

### chage_gas consume less step
A gain of 9.5% 
* Before: 156433
* After with inlining: 141777

Test run: `poetry run pytest tests/src/kakarot -k
test_loop_profiling\[10] --profile-cairo`

Before: 156433

![before](https://github.com/user-attachments/assets/b3288cd3-fea2-4fcc-a1f0-aff599366092)

After with inlining: 141777

![profile004](https://github.com/user-attachments/assets/61600142-7362-4784-930c-1b25fd20a65a)

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1323)
<!-- Reviewable:end -->

---------

Co-authored-by: Mathieu <60658558+enitrat@users.noreply.github.com>
  • Loading branch information
obatirou and enitrat authored Aug 7, 2024
1 parent 67fbb3d commit 7112b36
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 40 deletions.
28 changes: 28 additions & 0 deletions kakarot_scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,34 @@ async def get_contract(
return contract


def get_contract_sync(
contract_app: str,
contract_name: str,
address=None,
caller_eoa: Optional[Account] = None,
) -> Web3Contract:

artifacts = get_solidity_artifacts(contract_app, contract_name)

contract = cast(
Web3Contract,
WEB3.eth.contract(
address=to_checksum_address(address) if address is not None else address,
abi=artifacts["abi"],
bytecode=artifacts["bytecode"]["object"],
),
)
contract.bytecode_runtime = HexBytes(artifacts["bytecode_runtime"]["object"])

try:
for fun in contract.functions:
setattr(contract, fun, MethodType(_wrap_kakarot(fun, caller_eoa), contract))
except NoABIFunctionsFound:
pass
contract.events.parse_events = MethodType(_parse_events, contract.events)
return contract


@alru_cache()
async def get_or_deploy_library(library_app: str, library_name: str) -> str:
"""
Expand Down
56 changes: 40 additions & 16 deletions src/kakarot/evm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import TRUE, FALSE
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.math_cmp import is_nn, is_le_felt
from starkware.cairo.common.math_cmp import is_nn, is_le_felt, RC_BOUND
from starkware.cairo.common.math import assert_le_felt
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.registers import get_label_location
from starkware.cairo.common.uint256 import Uint256
Expand Down Expand Up @@ -140,36 +141,59 @@ namespace EVM {

// @notice Subtracts `amount` from `evm.gas_left`.
// @dev The gas left is decremented by the given amount.
// Use code adapted from is_nn.
// Assumption: gas_left < 2 ** 128
// @param self The pointer to the current execution context.
// @param amount The amount of gas the current operation requires.
// @return EVM The pointer to the updated execution context.
func charge_gas{range_check_ptr}(self: model.EVM*, amount: felt) -> model.EVM* {
let out_of_gas = is_le_felt(self.gas_left + 1, amount);
// This is equivalent to is_nn(self.gas_left - amount)
tempvar a = self.gas_left - amount; // a is necessary for using the whitelisted hint
%{ memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1 %}
jmp out_of_range if [ap] != 0, ap++;
[range_check_ptr] = a;
ap += 20;
tempvar range_check_ptr = range_check_ptr + 1;
jmp enough_gas;

if (out_of_gas != 0) {
let (revert_reason_len, revert_reason) = Errors.outOfGas(self.gas_left, amount);
return new model.EVM(
message=self.message,
return_data_len=revert_reason_len,
return_data=revert_reason,
program_counter=self.program_counter,
stopped=TRUE,
gas_left=0,
gas_refund=self.gas_refund,
reverted=Errors.EXCEPTIONAL_HALT,
);
}
out_of_range:
%{ memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1 %}
jmp need_felt_comparison if [ap] != 0, ap++;
assert [range_check_ptr] = (-a) - 1;
ap += 17;
tempvar range_check_ptr = range_check_ptr + 1;
jmp not_enough_gas;

need_felt_comparison:
assert_le_felt(RC_BOUND, a);
jmp not_enough_gas;

enough_gas:
let range_check_ptr = [ap - 1];
return new model.EVM(
message=self.message,
return_data_len=self.return_data_len,
return_data=self.return_data,
program_counter=self.program_counter,
stopped=self.stopped,
gas_left=self.gas_left - amount,
gas_left=a,
gas_refund=self.gas_refund,
reverted=self.reverted,
);

not_enough_gas:
let range_check_ptr = [ap - 1];
let (revert_reason_len, revert_reason) = Errors.outOfGas(self.gas_left, amount);
return new model.EVM(
message=self.message,
return_data_len=revert_reason_len,
return_data=revert_reason,
program_counter=self.program_counter,
stopped=TRUE,
gas_left=0,
gas_refund=self.gas_refund,
reverted=Errors.EXCEPTIONAL_HALT,
);
}

func halt_validation_failed{range_check_ptr}(self: model.EVM*) -> model.EVM* {
Expand Down
13 changes: 13 additions & 0 deletions tests/src/kakarot/test_evm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,16 @@ func test__is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran

return result;
}

func test__charge_gas{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
felt, felt
) {
alloc_locals;
local amount;

%{ ids.amount = program_input["amount"] %}
let evm = TestHelpers.init_evm();
let result = EVM.charge_gas(evm, amount);

return (result.gas_left, result.stopped);
}
16 changes: 16 additions & 0 deletions tests/src/kakarot/test_evm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from hypothesis import given
from hypothesis.strategies import integers

from tests.utils.syscall_handler import SyscallHandler

Expand Down Expand Up @@ -80,3 +82,17 @@ def test_should_return_non_cached_valid_jumpdest(
)
== expected
)

# 1000000 is the default value for the init_evm test helper
@given(amount=integers(min_value=0, max_value=1000000))
def test_should_return_gas_left(self, cairo_run, amount):
gas_left, stopped = cairo_run("test__charge_gas", amount=amount)
assert gas_left == 1000000 - amount
assert stopped == 0

# 1000000 is the default value for the init_evm test helper
@given(amount=integers(min_value=1000001, max_value=2**248 - 1))
def test_should_return_not_enough_gas(self, cairo_run, amount):
gas_left, stopped = cairo_run("test__charge_gas", amount=amount)
assert gas_left == 0
assert stopped == 1
39 changes: 15 additions & 24 deletions tests/src/kakarot/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@pytest.fixture(scope="module")
def get_contract(cairo_run):
from kakarot_scripts.utils.kakarot import get_contract as get_solidity_contract
from kakarot_scripts.utils.kakarot import get_contract_sync as get_solidity_contract

def _factory(contract_app, contract_name):
def _wrap_cairo_run(fun):
Expand Down Expand Up @@ -293,15 +293,10 @@ def test_register_account_should_fail_caller_not_resolved_address(
class TestEthCall:
@pytest.mark.slow
@pytest.mark.SolmateERC20
@SyscallHandler.patch(
"IAccount.is_valid_jumpdest",
lambda addr, data: [1],
)
@SyscallHandler.patch(
"IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1]
)
async def test_erc20_transfer(self, get_contract):
erc20 = await get_contract("Solmate", "ERC20")
@SyscallHandler.patch("IAccount.is_valid_jumpdest", lambda addr, data: [1])
@SyscallHandler.patch("IAccount.get_code_hash", lambda addr, data: [0x1, 0x1])
def test_erc20_transfer(self, get_contract):
erc20 = get_contract("Solmate", "ERC20")
amount = int(1e18)
initial_state = {
CONTRACT_ADDRESS: {
Expand All @@ -320,15 +315,10 @@ async def test_erc20_transfer(self, get_contract):

@pytest.mark.slow
@pytest.mark.SolmateERC721
@SyscallHandler.patch(
"IAccount.is_valid_jumpdest",
lambda addr, data: [1],
)
@SyscallHandler.patch(
"IAccount.get_code_hash", lambda sn_addr, data: [0x1, 0x1]
)
async def test_erc721_transfer(self, get_contract):
erc721 = await get_contract("Solmate", "ERC721")
@SyscallHandler.patch("IAccount.is_valid_jumpdest", lambda addr, data: [1])
@SyscallHandler.patch("IAccount.get_code_hash", lambda addr, data: [0x1, 0x1])
def test_erc721_transfer(self, get_contract):
erc721 = get_contract("Solmate", "ERC721")
token_id = 1337
initial_state = {
CONTRACT_ADDRESS: {
Expand Down Expand Up @@ -356,7 +346,7 @@ async def test_erc721_transfer(self, get_contract):
"ef_blockchain_test",
EF_TESTS_PARSED_DIR.glob("*walletConstruction_d0g1v0_Cancun*.json"),
)
async def test_case(
def test_case(
self,
cairo_run,
ef_blockchain_test,
Expand Down Expand Up @@ -394,7 +384,7 @@ async def test_case(
assert gas_used == int(block["blockHeader"]["gasUsed"], 16)

@pytest.mark.skip
async def test_failing_contract(self, cairo_run):
def test_failing_contract(self, cairo_run):
initial_state = {
CONTRACT_ADDRESS: {
"code": bytes.fromhex("ADDC0DE1"),
Expand All @@ -418,10 +408,11 @@ async def test_failing_contract(self, cairo_run):
class TestLoopProfiling:
@pytest.mark.slow
@pytest.mark.NoCI
@pytest.mark.parametrize("steps", [10, 50, 100, 200])
@SyscallHandler.patch("IAccount.is_valid_jumpdest", lambda addr, data: [1])
async def test_loop_profiling(self, get_contract, steps):
plain_opcodes = await get_contract("PlainOpcodes", "PlainOpcodes")
@SyscallHandler.patch("IAccount.get_code_hash", lambda addr, data: [0x1, 0x1])
@pytest.mark.parametrize("steps", [10, 50, 100, 200])
def test_loop_profiling(self, get_contract, steps):
plain_opcodes = get_contract("PlainOpcodes", "PlainOpcodes")
initial_state = {
CONTRACT_ADDRESS: {
"code": list(plain_opcodes.bytecode_runtime),
Expand Down

0 comments on commit 7112b36

Please sign in to comment.