diff --git a/evm/src/executor/inspector/cheatcodes/expect.rs b/evm/src/executor/inspector/cheatcodes/expect.rs index a657dfd86960..9371f6b6f831 100644 --- a/evm/src/executor/inspector/cheatcodes/expect.rs +++ b/evm/src/executor/inspector/cheatcodes/expect.rs @@ -168,16 +168,26 @@ pub fn handle_expect_emit(state: &mut Cheatcodes, log: RawLog, address: &Address #[derive(Clone, Debug, Default)] pub struct ExpectedCallData { - /// The expected calldata - pub calldata: Bytes, /// The expected value sent in the call pub value: Option, /// The expected gas supplied to the call pub gas: Option, /// The expected *minimum* gas supplied to the call pub min_gas: Option, - /// The number of times the call is expected to be made - pub count: Option, + /// The number of times the call is expected to be made. + /// If the type of call is `NonCount`, this is the lower bound for the number of calls + /// that must be seen. + /// If the type of call is `Count`, this is the exact number of calls that must be seen. + pub count: u64, + /// The type of call + pub call_type: ExpectedCallType, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub enum ExpectedCallType { + #[default] + Count, + NonCount, } #[derive(Clone, Debug, Default, PartialEq, Eq)] @@ -220,6 +230,71 @@ fn expect_safe_memory(state: &mut Cheatcodes, start: u64, end: u64, depth: u64) Ok(Bytes::new()) } +/// Handles expected calls specified by the `vm.expectCall` cheatcode. +/// +/// It can handle calls in two ways: +/// - If the cheatcode was used with a `count` argument, it will expect the call to be made exactly +/// `count` times. +/// e.g. `vm.expectCall(address(0xc4f3), abi.encodeWithSelector(0xd34db33f), 4)` will expect the +/// call to address(0xc4f3) with selector `0xd34db33f` to be made exactly 4 times. If the amount of +/// calls is less or more than 4, the test will fail. Note that the `count` argument cannot be +/// overwritten with another `vm.expectCall`. If this is attempted, `expectCall` will revert. +/// - If the cheatcode was used without a `count` argument, it will expect the call to be made at +/// least the amount of times the cheatcode +/// was called. This means that `vm.expectCall` without a count argument can be called many times, +/// but cannot be called with a `count` argument after it was called without one. If the latter +/// happens, `expectCall` will revert. e.g `vm.expectCall(address(0xc4f3), +/// abi.encodeWithSelector(0xd34db33f))` will expect the call to address(0xc4f3) and selector +/// `0xd34db33f` to be made at least once. If the amount of calls is 0, the test will fail. If the +/// call is made more than once, the test will pass. +#[allow(clippy::too_many_arguments)] +fn expect_call( + state: &mut Cheatcodes, + target: H160, + calldata: Vec, + value: Option, + gas: Option, + min_gas: Option, + count: u64, + call_type: ExpectedCallType, +) -> Result { + match call_type { + ExpectedCallType::Count => { + // Get the expected calls for this target. + let expecteds = state.expected_calls.entry(target).or_default(); + // In this case, as we're using counted expectCalls, we should not be able to set them + // more than once. + ensure!( + !expecteds.contains_key(&calldata), + "Counted expected calls can only bet set once." + ); + expecteds + .insert(calldata, (ExpectedCallData { value, gas, min_gas, count, call_type }, 0)); + Ok(Bytes::new()) + } + ExpectedCallType::NonCount => { + let expecteds = state.expected_calls.entry(target).or_default(); + // Check if the expected calldata exists. + // If it does, increment the count by one as we expect to see it one more time. + if let Some(expected) = expecteds.get_mut(&calldata) { + // Ensure we're not overwriting a counted expectCall. + ensure!( + expected.0.call_type == ExpectedCallType::NonCount, + "Cannot overwrite a counted expectCall with a non-counted expectCall." + ); + expected.0.count += 1; + } else { + // If it does not exist, then create it. + expecteds.insert( + calldata, + (ExpectedCallData { value, gas, min_gas, count, call_type }, 0), + ); + } + Ok(Bytes::new()) + } + } +} + pub fn apply( state: &mut Cheatcodes, data: &mut EVMData<'_, DB>, @@ -267,125 +342,113 @@ pub fn apply( }); Ok(Bytes::new()) } - HEVMCalls::ExpectCall0(inner) => { - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.1.to_vec().into(), - value: None, - gas: None, - min_gas: None, - count: None, - }, - 0, - )); - Ok(Bytes::new()) - } - HEVMCalls::ExpectCall1(inner) => { - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.1.to_vec().into(), - value: None, - gas: None, - min_gas: None, - count: Some(inner.2), - }, - 0, - )); - Ok(Bytes::new()) - } - HEVMCalls::ExpectCall2(inner) => { - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.2.to_vec().into(), - value: Some(inner.1), - gas: None, - min_gas: None, - count: None, - }, - 0, - )); - Ok(Bytes::new()) - } - HEVMCalls::ExpectCall3(inner) => { - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.2.to_vec().into(), - value: Some(inner.1), - gas: None, - min_gas: None, - count: Some(inner.3), - }, - 0, - )); - Ok(Bytes::new()) - } + HEVMCalls::ExpectCall0(inner) => expect_call( + state, + inner.0, + inner.1.to_vec(), + None, + None, + None, + 1, + ExpectedCallType::NonCount, + ), + HEVMCalls::ExpectCall1(inner) => expect_call( + state, + inner.0, + inner.1.to_vec(), + None, + None, + None, + inner.2, + ExpectedCallType::Count, + ), + HEVMCalls::ExpectCall2(inner) => expect_call( + state, + inner.0, + inner.2.to_vec(), + Some(inner.1), + None, + None, + 1, + ExpectedCallType::NonCount, + ), + HEVMCalls::ExpectCall3(inner) => expect_call( + state, + inner.0, + inner.2.to_vec(), + Some(inner.1), + None, + None, + inner.3, + ExpectedCallType::Count, + ), HEVMCalls::ExpectCall4(inner) => { let value = inner.1; - // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: Some(inner.2 + positive_value_cost_stipend), - min_gas: None, - count: None, - }, - 0, - )); - Ok(Bytes::new()) + expect_call( + state, + inner.0, + inner.3.to_vec(), + Some(value), + Some(inner.2 + positive_value_cost_stipend), + None, + 1, + ExpectedCallType::NonCount, + ) } HEVMCalls::ExpectCall5(inner) => { let value = inner.1; + // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas + // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: Some(inner.2 + positive_value_cost_stipend), - min_gas: None, - count: Some(inner.4), - }, - 0, - )); - Ok(Bytes::new()) + + expect_call( + state, + inner.0, + inner.3.to_vec(), + Some(value), + Some(inner.2 + positive_value_cost_stipend), + None, + inner.4, + ExpectedCallType::Count, + ) } HEVMCalls::ExpectCallMinGas0(inner) => { let value = inner.1; - // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: None, - min_gas: Some(inner.2 + positive_value_cost_stipend), - count: None, - }, - 0, - )); - Ok(Bytes::new()) + expect_call( + state, + inner.0, + inner.3.to_vec(), + Some(value), + None, + Some(inner.2 + positive_value_cost_stipend), + 1, + ExpectedCallType::NonCount, + ) } HEVMCalls::ExpectCallMinGas1(inner) => { let value = inner.1; + // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas + // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(( - ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: None, - min_gas: Some(inner.2 + positive_value_cost_stipend), - count: Some(inner.4), - }, - 0, - )); - Ok(Bytes::new()) + + expect_call( + state, + inner.0, + inner.3.to_vec(), + Some(value), + None, + Some(inner.2 + positive_value_cost_stipend), + inner.4, + ExpectedCallType::Count, + ) } HEVMCalls::MockCall0(inner) => { // TODO: Does this increase gas usage? diff --git a/evm/src/executor/inspector/cheatcodes/mod.rs b/evm/src/executor/inspector/cheatcodes/mod.rs index ce824d935a09..3a1869d42948 100644 --- a/evm/src/executor/inspector/cheatcodes/mod.rs +++ b/evm/src/executor/inspector/cheatcodes/mod.rs @@ -1,6 +1,6 @@ use self::{ env::Broadcast, - expect::{handle_expect_emit, handle_expect_revert}, + expect::{handle_expect_emit, handle_expect_revert, ExpectedCallType}, util::{check_if_fixed_gas_limit, process_create, BroadcastableTransactions}, }; use crate::{ @@ -69,6 +69,15 @@ mod error; pub(crate) use error::{bail, ensure, err}; pub use error::{Error, Result}; +/// Tracks the expected calls per address. +/// For each address, we track the expected calls per call data. We track it in such manner +/// so that we don't mix together calldatas that only contain selectors and calldatas that contain +/// selector and arguments (partial and full matches). +/// This then allows us to customize the matching behavior for each call data on the +/// `ExpectedCallData` struct and track how many times we've actually seen the call on the second +/// element of the tuple. +pub type ExpectedCallTracker = BTreeMap, (ExpectedCallData, u64)>>; + /// An inspector that handles calls to various cheatcodes, each with their own behavior. /// /// Cheatcodes can be called by contracts during execution to modify the VM environment, such as @@ -126,7 +135,7 @@ pub struct Cheatcodes { pub mocked_calls: BTreeMap>, /// Expected calls - pub expected_calls: BTreeMap>, + pub expected_calls: ExpectedCallTracker, /// Expected emits pub expected_emits: Vec, @@ -565,15 +574,29 @@ where } } else if call.contract != h160_to_b160(HARDHAT_CONSOLE_ADDRESS) { // Handle expected calls - if let Some(expecteds) = self.expected_calls.get_mut(&(b160_to_h160(call.contract))) { - if let Some((_, count)) = expecteds.iter_mut().find(|(expected, _)| { - expected.calldata.len() <= call.input.len() && - expected.calldata == call.input[..expected.calldata.len()] && - expected.value.map_or(true, |value| value == call.transfer.value.into()) && + + // Grab the different calldatas expected. + if let Some(expected_calls_for_target) = + self.expected_calls.get_mut(&(b160_to_h160(call.contract))) + { + // Match every partial/full calldata + for (calldata, (expected, actual_count)) in expected_calls_for_target.iter_mut() { + // Increment actual times seen if... + // The calldata is at most, as big as this call's input, and + if calldata.len() <= call.input.len() && + // Both calldata match, taking the length of the assumed smaller one (which will have at least the selector), and + *calldata == call.input[..calldata.len()] && + // The value matches, if provided + expected + .value + .map_or(true, |value| value == call.transfer.value.into()) && + // The gas matches, if provided expected.gas.map_or(true, |gas| gas == call.gas_limit) && + // The minimum gas matches, if provided expected.min_gas.map_or(true, |min_gas| min_gas <= call.gas_limit) - }) { - *count += 1; + { + *actual_count += 1; + } } } @@ -772,43 +795,67 @@ where // If the depth is 0, then this is the root call terminating if data.journaled_state.depth() == 0 { - for (address, expecteds) in &self.expected_calls { - for (expected, actual_count) in expecteds { - let ExpectedCallData { calldata, gas, min_gas, value, count } = expected; - let calldata = calldata.clone(); - let expected_values = [ - Some(format!("data {calldata}")), - value.map(|v| format!("value {v}")), - gas.map(|g| format!("gas {g}")), - min_gas.map(|g| format!("minimum gas {g}")), - ] - .into_iter() - .flatten() - .join(" and "); - if count.is_none() { - if *actual_count == 0 { - return ( - InstructionResult::Revert, - remaining_gas, - format!("Expected at least one call to {address:?} with {expected_values}, but got none") + // Match expected calls + for (address, calldatas) in &self.expected_calls { + // Loop over each address, and for each address, loop over each calldata it expects. + for (calldata, (expected, actual_count)) in calldatas { + // Grab the values we expect to see + let ExpectedCallData { gas, min_gas, value, count, call_type } = expected; + let calldata = Bytes::from(calldata.clone()); + + // We must match differently depending on the type of call we expect. + match call_type { + // If the cheatcode was called with a `count` argument, + // we must check that the EVM performed a CALL with this calldata exactly + // `count` times. + ExpectedCallType::Count => { + if *count != *actual_count { + let expected_values = [ + Some(format!("data {calldata}")), + value.map(|v| format!("value {v}")), + gas.map(|g| format!("gas {g}")), + min_gas.map(|g| format!("minimum gas {g}")), + ] + .into_iter() + .flatten() + .join(" and "); + return ( + InstructionResult::Revert, + remaining_gas, + format!( + "Expected call to {address:?} with {expected_values} to be called {count} time(s), but was called {actual_count} time(s)" + ) .encode() .into(), - ) + ) + } + } + // If the cheatcode was called without a `count` argument, + // we must check that the EVM performed a CALL with this calldata at least + // `count` times. The amount of times to check was + // the amount of time the cheatcode was called. + ExpectedCallType::NonCount => { + if *count > *actual_count { + let expected_values = [ + Some(format!("data {calldata}")), + value.map(|v| format!("value {v}")), + gas.map(|g| format!("gas {g}")), + min_gas.map(|g| format!("minimum gas {g}")), + ] + .into_iter() + .flatten() + .join(" and "); + return ( + InstructionResult::Revert, + remaining_gas, + format!( + "Expected call to {address:?} with {expected_values} to be called at least {count} time(s), but was called {actual_count} time(s)" + ) + .encode() + .into(), + ) + } } - } else if *count != Some(*actual_count) { - return ( - InstructionResult::Revert, - remaining_gas, - format!( - "Expected call to {:?} with {} to be made {} time(s), but was called {} time(s)", - address, - expected_values, - count.unwrap(), - actual_count, - ) - .encode() - .into(), - ) } } } diff --git a/testdata/cheats/ExpectCall.t.sol b/testdata/cheats/ExpectCall.t.sol index 6d088155a7e1..783ce296e350 100644 --- a/testdata/cheats/ExpectCall.t.sol +++ b/testdata/cheats/ExpectCall.t.sol @@ -62,6 +62,33 @@ contract ExpectCallTest is DSTest { target.add(1, 2); } + function testExpectMultipleCallsWithDataAdditive() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + target.add(1, 2); + target.add(1, 2); + } + + function testExpectMultipleCallsWithDataAdditiveLowerBound() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + target.add(1, 2); + target.add(1, 2); + target.add(1, 2); + } + + function testFailExpectMultipleCallsWithDataAdditive() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + // Not enough calls to satisfy the additive expectCall, which expects 3 calls. + target.add(1, 2); + target.add(1, 2); + } + function testFailExpectCallWithData() public { Contract target = new Contract(); cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); @@ -312,3 +339,54 @@ contract ExpectCallCountTest is DSTest { target.addHardGasLimit(); } } + +contract ExpectCallMixedTest is DSTest { + Cheats constant cheats = Cheats(HEVM_ADDRESS); + + function testFailOverrideNoCountWithCount() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + // You should not be able to overwrite a expectCall that had no count with some count. + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2); + target.add(1, 2); + target.add(1, 2); + } + + function testFailOverrideCountWithCount() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2); + // You should not be able to overwrite a expectCall that had a count with some count. + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 1); + target.add(1, 2); + target.add(1, 2); + } + + function testFailOverrideCountWithNoCount() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2); + // You should not be able to overwrite a expectCall that had a count with no count. + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + target.add(1, 2); + target.add(1, 2); + } + + function testExpectMatchPartialAndFull() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector), 2); + // Even if a partial match is speciifed, you should still be able to look for full matches + // as one does not override the other. + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2)); + target.add(1, 2); + target.add(1, 2); + } + + function testExpectMatchPartialAndFullFlipped() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector)); + // Even if a partial match is speciifed, you should still be able to look for full matches + // as one does not override the other. + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2); + target.add(1, 2); + target.add(1, 2); + } +}