Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev: prepare parallelization of end to end tests #1382

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions kakarot_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,11 @@ def __init__(self, relayers: List[Dict[str, int]]):
)
for relayer in relayers
]
self._index = 0
self.index = 0

def __next__(self) -> Account:
relayer = self.relayer_accounts[self._index]
self._index = (self._index + 1) % len(self.relayer_accounts)
relayer = self.relayer_accounts[self.index]
self.index = (self.index + 1) % len(self.relayer_accounts)
return relayer


Expand Down
5 changes: 4 additions & 1 deletion kakarot_scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ async def deploy_and_fund_evm_address(evm_address: str, amount: float):
await fund_address(evm_address, amount - account_balance)
if not await _contract_exists(starknet_address):
await _invoke_starknet(
"kakarot", "deploy_externally_owned_account", int(evm_address, 16)
"kakarot",
"deploy_externally_owned_account",
int(evm_address, 16),
account=next(NETWORK["relayers"]),
)
return starknet_address

Expand Down
4 changes: 2 additions & 2 deletions kakarot_scripts/utils/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ async def fund_address(
else:
logger.info(f"{amount / 1e18} ETH minted to {hex(address)}")
else:
account = funding_account or await get_starknet_account()
eth_contract = token_contract or await get_eth_contract()
account = funding_account or next(NETWORK["relayers"])
eth_contract = token_contract or await get_eth_contract(account)
balance = await get_balance(account.address, eth_contract)
if balance < amount:
raise ValueError(
Expand Down
15 changes: 14 additions & 1 deletion tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starknet_py.contract import Contract
from starknet_py.net.account.account import Account

from kakarot_scripts.constants import RPC_CLIENT, NetworkType
from kakarot_scripts.constants import NETWORK, RPC_CLIENT, NetworkType
from kakarot_scripts.utils.kakarot import eth_balance_of
from kakarot_scripts.utils.kakarot import get_contract as get_solidity_contract
from kakarot_scripts.utils.kakarot import get_eoa
Expand Down Expand Up @@ -182,3 +182,16 @@ async def _factory(block_number: Optional[Union[int, str]] = "latest"):
).block_hash

return _factory


@pytest.fixture(autouse=True, scope="session")
def relayers(worker_id):
"""
Override NETWORK["relayers"] to use the worker_id as the index and avoid nonce issues.
"""
try:
logger.info(f"Setting relayer index to {int(worker_id[2:])}")
NETWORK["relayers"].index = int(worker_id[2:])
except ValueError:
logger.info(f"Error while setting relayer index to {worker_id}")
return
22 changes: 2 additions & 20 deletions tests/fixtures/starknet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import math
import shutil
from hashlib import md5
from pathlib import Path
from time import perf_counter, time_ns
Expand All @@ -24,9 +23,9 @@
from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager

from tests.utils.constants import Opcodes
from tests.utils.coverage import VmWithCoverage, report_runs
from tests.utils.coverage import VmWithCoverage
from tests.utils.hints import debug_info
from tests.utils.reporting import dump_coverage, profile_from_tracer_data
from tests.utils.reporting import profile_from_tracer_data
from tests.utils.serde import Serde
from tests.utils.syscall_handler import SyscallHandler

Expand All @@ -38,23 +37,6 @@
logger = logging.getLogger()


@pytest.fixture(scope="session", autouse=True)
async def coverage(worker_id, request):

output_dir = Path("coverage")
shutil.rmtree(output_dir, ignore_errors=True)

yield

output_dir.mkdir(exist_ok=True, parents=True)
files = report_runs(excluded_file={"site-packages", "tests"})

if worker_id == "master":
dump_coverage(output_dir, files)
else:
dump_coverage(output_dir / worker_id, files)


def cairo_compile(path):
module_reader = get_module_reader(cairo_path=["src"])

Expand Down
24 changes: 24 additions & 0 deletions tests/src/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import shutil
from pathlib import Path

import pytest

from tests.utils.coverage import report_runs
from tests.utils.reporting import dump_coverage


@pytest.fixture(scope="session", autouse=True)
async def coverage(worker_id):

output_dir = Path("coverage")
shutil.rmtree(output_dir, ignore_errors=True)

yield

output_dir.mkdir(exist_ok=True, parents=True)
files = report_runs(excluded_file={"site-packages", "tests"})

if worker_id == "master":
dump_coverage(output_dir, files)
else:
dump_coverage(output_dir / worker_id, files)
Loading