Skip to content

Commit

Permalink
Add mockFunction cheatcode and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
grandizzy committed Sep 5, 2024
1 parent 82e3f68 commit 0699b9d
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 0 deletions.
20 changes: 20 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,15 @@ interface Vm {
function mockCallRevert(address callee, uint256 msgValue, bytes calldata data, bytes calldata revertData)
external;

/// Whenever a call is made to `callee` with calldata `data`, this cheatcode instead calls
/// `target` with the same calldata. This functionality is similar to a delegate call made to
/// `target` contract from `callee`.
/// Can be used to substitute a call to a function with another implementation that captures
/// the primary logic of the original function but is easier to reason about.
/// If calldata is not a strict match then partial match by selector is attempted.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockFunction(address callee, address target, bytes calldata data) external;

// --- Impersonation (pranks) ---

/// Sets the *next* call's `msg.sender` to be the input address.
Expand Down
9 changes: 9 additions & 0 deletions crates/cheatcodes/src/evm/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ impl Cheatcode for mockCallRevert_1Call {
}
}

impl Cheatcode for mockFunctionCall {
fn apply(&self, state: &mut Cheatcodes) -> Result {
let Self { callee, target, data } = self;
state.mocked_functions.entry(*callee).or_default().insert(data.clone(), *target);

Ok(Default::default())
}
}

#[allow(clippy::ptr_arg)] // Not public API, doesn't matter
fn mock_call(
state: &mut Cheatcodes,
Expand Down
5 changes: 5 additions & 0 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ pub struct Cheatcodes {
// **Note**: inner must a BTreeMap because of special `Ord` impl for `MockCallDataContext`
pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, MockCallReturnData>>,

/// Mocked functions.
/// Maps contract address to be mocked to (calldata, mock address).
pub mocked_functions: HashMap<Address, BTreeMap<Bytes, Address>>,

/// Expected calls
pub expected_calls: ExpectedCallTracker,
/// Expected emits
Expand Down Expand Up @@ -398,6 +402,7 @@ impl Cheatcodes {
recorded_account_diffs_stack: Default::default(),
recorded_logs: Default::default(),
mocked_calls: Default::default(),
mocked_functions: Default::default(),
expected_calls: Default::default(),
expected_emits: Default::default(),
allowed_mem_writes: Default::default(),
Expand Down
1 change: 1 addition & 0 deletions crates/cheatcodes/tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod assume;
mod expect;
mod gas_metering;
mod mock;
mod random;
94 changes: 94 additions & 0 deletions crates/cheatcodes/tests/mock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//! Contains various tests for `mock*` cheatcodes.

use foundry_test_utils::{forgetest_init, str};

// Kontrol `mockFunction` cheatcode tests ported.
forgetest_init!(test_mock_function_kontrol, |prj, cmd| {
prj.wipe_contracts();
prj.insert_ds_test();
prj.insert_vm();
prj.clear();

prj.add_source(
"MockFunction.t.sol",
r#"pragma solidity 0.8.24;
import {Vm} from "./Vm.sol";
import {DSTest} from "./test.sol";
contract MockFunctionContract {
uint256 public a;
function mocked_function() public {
a = 321;
}
function mocked_args_function(uint256 x) public {
a = 321 + x;
}
}
contract ModelMockFunctionContract {
uint256 public a;
function mocked_function() public {
a = 123;
}
function mocked_args_function(uint256 x) public {
a = 123 + x;
}
}
contract MockFunctionTest is DSTest {
MockFunctionContract my_contract;
ModelMockFunctionContract model_contract;
Vm vm = Vm(HEVM_ADDRESS);
function setUp() public {
my_contract = new MockFunctionContract();
model_contract = new ModelMockFunctionContract();
}
function test_mock_function() public {
vm.mockFunction(
address(my_contract),
address(model_contract),
abi.encodeWithSelector(MockFunctionContract.mocked_function.selector)
);
my_contract.mocked_function();
assertEq(my_contract.a(), 123);
}
function test_mock_function_concrete_args() public {
vm.mockFunction(
address(my_contract),
address(model_contract),
abi.encodeWithSelector(MockFunctionContract.mocked_args_function.selector, 456)
);
my_contract.mocked_args_function(456);
assertEq(my_contract.a(), 123 + 456);
my_contract.mocked_args_function(567);
assertEq(my_contract.a(), 321 + 567);
}
function test_mock_function_all_args() public {
vm.mockFunction(
address(my_contract),
address(model_contract),
abi.encodeWithSelector(MockFunctionContract.mocked_args_function.selector)
);
my_contract.mocked_args_function(678);
assertEq(my_contract.a(), 123 + 678);
my_contract.mocked_args_function(789);
assertEq(my_contract.a(), 123 + 789);
}
}
"#,
)
.unwrap();

cmd.args(["test"]).assert_success().stdout_eq(str![[r#"
...
[PASS] test_mock_function() ([GAS])
[PASS] test_mock_function_all_args() ([GAS])
[PASS] test_mock_function_concrete_args() ([GAS])
...
"#]]);
});
12 changes: 12 additions & 0 deletions crates/evm/evm/src/inspectors/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,18 @@ impl<'a, DB: DatabaseExt> Inspector<DB> for InspectorStackRefMut<'a> {

ecx.journaled_state.depth += self.in_inner_context as usize;
if let Some(cheatcodes) = self.cheatcodes.as_deref_mut() {
// Handle mocked functions, replace bytecode address with mock if matched.
if let Some(mocks) = cheatcodes.mocked_functions.get(&call.target_address) {
if let Some(target) = mocks.get(&call.input) {
call.bytecode_address = *target;
} else {
// Check if we have a catch all mock set for selector.
if let Some(target) = mocks.get(&call.input.slice(..4)) {
call.bytecode_address = *target;
}
}
}

if let Some(output) = cheatcodes.call_with_executor(ecx, call, self.inner) {
if output.result.result != InstructionResult::Continue {
ecx.journaled_state.depth -= self.in_inner_context as usize;
Expand Down
1 change: 1 addition & 0 deletions testdata/cheats/Vm.sol

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0699b9d

Please sign in to comment.