Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
Dori/add rstest (#660)
Browse files Browse the repository at this point in the history
* Add dep

Signed-off-by: Dori Medini <dori@starkware.co>

* Parametrize test

Signed-off-by: Dori Medini <dori@starkware.co>

* Fixturize

Signed-off-by: Dori Medini <dori@starkware.co>
  • Loading branch information
dorimedini-starkware authored Jun 26, 2023
1 parent d11ebd2 commit 68fb4d2
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 60 deletions.
62 changes: 62 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ num-bigint = "0.4"
num-integer = "0.1.45"
num-traits = "0.2"
ouroboros = "0.15.6"
rstest = "0.17.0"

# IMPORTANT: next upgrade should delete replaced classes table handling.
# https://github.com/starkware-libs/blockifier/blob/54002da4b11c3c839a1221122cc18330854f563c/crates/native_blockifier/src/storage.rs#L145-L164
Expand Down
1 change: 1 addition & 0 deletions crates/blockifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ ctor.workspace = true
[dev-dependencies]
assert_matches.workspace = true
pretty_assertions.workspace = true
rstest.workspace = true
test-case.workspace = true
134 changes: 74 additions & 60 deletions crates/blockifier/src/transaction/account_transactions_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use rstest::{fixture, rstest};
use starknet_api::core::{calculate_contract_address, ClassHash, ContractAddress};
use starknet_api::hash::StarkFelt;
use starknet_api::transaction::{
Expand All @@ -26,11 +27,21 @@ struct TestInitData {
pub account_address: ContractAddress,
pub contract_address: ContractAddress,
pub nonce_manager: NonceManager,
pub block_context: BlockContext,
}

fn create_state() -> CachedState<DictStateReader> {
let block_context = BlockContext::create_for_account_testing();
#[fixture]
fn max_fee() -> Fee {
Fee(MAX_FEE)
}

#[fixture]
fn block_context() -> BlockContext {
BlockContext::create_for_account_testing()
}

#[fixture]
fn create_state(block_context: BlockContext) -> CachedState<DictStateReader> {
// Declare all the needed contracts.
let test_account_class_hash = ClassHash(stark_felt!(TEST_ACCOUNT_CONTRACT_CLASS_HASH));
let test_erc20_class_hash = ClassHash(stark_felt!(TEST_ERC20_CONTRACT_CLASS_HASH));
Expand All @@ -49,8 +60,12 @@ fn create_state() -> CachedState<DictStateReader> {
})
}

fn create_test_state(max_fee: Fee, block_context: &BlockContext) -> TestInitData {
let mut state = create_state();
#[fixture]
fn create_test_init_data(
max_fee: Fee,
block_context: BlockContext,
#[from(create_state)] mut state: CachedState<DictStateReader>,
) -> TestInitData {
let mut nonce_manager = NonceManager::default();

// Deploy an account contract.
Expand All @@ -74,7 +89,7 @@ fn create_test_state(max_fee: Fee, block_context: &BlockContext) -> TestInitData
);

let account_tx = AccountTransaction::DeployAccount(deploy_account_tx);
account_tx.execute(&mut state, block_context).unwrap();
account_tx.execute(&mut state, &block_context).unwrap();

// Declare a contract.
let contract_class = ContractClassV0::from_file(TEST_CONTRACT_PATH).into();
Expand All @@ -89,7 +104,7 @@ fn create_test_state(max_fee: Fee, block_context: &BlockContext) -> TestInitData
)
.unwrap(),
);
account_tx.execute(&mut state, block_context).unwrap();
account_tx.execute(&mut state, &block_context).unwrap();

// Deploy a contract using syscall deploy.
let entry_point_selector = selector_from_name("deploy_contract");
Expand All @@ -110,7 +125,7 @@ fn create_test_state(max_fee: Fee, block_context: &BlockContext) -> TestInitData
nonce: nonce_manager.next(account_address),
..tx
}));
account_tx.execute(&mut state, block_context).unwrap();
account_tx.execute(&mut state, &block_context).unwrap();

// Calculate the newly deployed contract address
let contract_address = calculate_contract_address(
Expand All @@ -121,15 +136,18 @@ fn create_test_state(max_fee: Fee, block_context: &BlockContext) -> TestInitData
)
.unwrap();

TestInitData { state, account_address, contract_address, nonce_manager }
TestInitData { state, account_address, contract_address, nonce_manager, block_context }
}

#[test]
fn test_account_flow_test() {
let max_fee = Fee(MAX_FEE);
let block_context = &BlockContext::create_for_account_testing();
let TestInitData { mut state, account_address, contract_address, mut nonce_manager } =
create_test_state(max_fee, block_context);
#[rstest]
fn test_account_flow_test(max_fee: Fee, #[from(create_test_init_data)] init_data: TestInitData) {
let TestInitData {
mut state,
account_address,
contract_address,
mut nonce_manager,
block_context,
} = init_data;

// Invoke a function from the newly deployed contract.
let entry_point_selector = selector_from_name("return_result");
Expand All @@ -144,75 +162,71 @@ fn test_account_flow_test() {
nonce: nonce_manager.next(account_address),
..tx
}));
account_tx.execute(&mut state, block_context).unwrap();
account_tx.execute(&mut state, &block_context).unwrap();
}

#[test]
fn test_infinite_recursion() {
let max_fee = Fee(MAX_FEE);
let mut block_context = BlockContext::create_for_account_testing();

#[rstest]
#[case(true, true)]
#[case(true, false)]
#[case(false, true)]
#[case(false, false)]
fn test_infinite_recursion(
#[case] success: bool,
#[case] normal_recurse: bool,
#[from(create_state)] state: CachedState<DictStateReader>,
max_fee: Fee,
mut block_context: BlockContext,
) {
// Limit the number of execution steps (so we quickly hit the limit).
block_context.invoke_tx_max_n_steps = 1000;

let TestInitData { mut state, account_address, contract_address, mut nonce_manager } =
create_test_state(max_fee, &block_context);
let TestInitData {
mut state,
account_address,
contract_address,
mut nonce_manager,
block_context,
} = create_test_init_data(max_fee, block_context, state);

// Two types of recursion: one "normal" recursion, and one that uses the `call_contract`
// syscall.
let raw_contract_address = *contract_address.0.key();
let raw_normal_entry_point_selector = selector_from_name("recurse").0;
let raw_syscall_entry_point_selector = selector_from_name("recursive_syscall").0;
let raw_entry_point_selector =
selector_from_name(if normal_recurse { "recurse" } else { "recursive_syscall" }).0;

let recursion_depth = if success { 3_u32 } else { 1000_u32 };

let normal_calldata = |recursion_depth: u32| -> Calldata {
let execute_calldata = if normal_recurse {
calldata![
raw_contract_address,
raw_normal_entry_point_selector,
raw_entry_point_selector,
stark_felt!(1_u8),
stark_felt!(recursion_depth)
]
};
let syscall_calldata = |recursion_depth: u32| -> Calldata {
} else {
calldata![
raw_contract_address,
raw_syscall_entry_point_selector,
raw_entry_point_selector,
stark_felt!(3_u8), // Calldata length.
raw_contract_address,
raw_syscall_entry_point_selector,
raw_entry_point_selector,
stark_felt!(recursion_depth)
]
};

// Try two runs for each recursion type: one short run (success), and one that reverts due to
// step limit.
let first_valid_nonce = nonce_manager.next(account_address);
let second_valid_nonce = nonce_manager.next(account_address);
let third_valid_nonce = nonce_manager.next(account_address);
[
(1_u32, true, true, first_valid_nonce),
(1000_u32, false, true, second_valid_nonce),
(3_u32, true, false, second_valid_nonce), // Use same nonce, since previous tx should fail.
(1000_u32, false, false, third_valid_nonce),
]
.into_iter()
.map(|(recursion_depth, should_be_ok, use_normal_calldata, nonce)| {
let execute_calldata = if use_normal_calldata {
normal_calldata(recursion_depth)
} else {
syscall_calldata(recursion_depth)
};
let tx = invoke_tx(execute_calldata, account_address, max_fee, None);
let account_tx =
AccountTransaction::Invoke(InvokeTransaction::V1(InvokeTransactionV1 { nonce, ..tx }));
let result = account_tx.execute(&mut state, &block_context);
if should_be_ok {
result.unwrap();
} else {
assert!(
format!("{:?}", result.unwrap_err())
.contains("RunResources has no remaining steps.")
);
}
})
.for_each(drop);
let tx = invoke_tx(execute_calldata, account_address, max_fee, None);
let account_tx = AccountTransaction::Invoke(InvokeTransaction::V1(InvokeTransactionV1 {
nonce: nonce_manager.next(account_address),
..tx
}));
let result = account_tx.execute(&mut state, &block_context);
if success {
result.unwrap();
} else {
assert!(
format!("{:?}", result.unwrap_err()).contains("RunResources has no remaining steps.")
);
}
}

0 comments on commit 68fb4d2

Please sign in to comment.