diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index d8f8d21df67b..54556043ac9d 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -5051,6 +5051,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "expectRevert_10", + "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address that match the revert data.", + "declaration": "function expectRevert(bytes4 revertData, address reverter, uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes4,address,uint64)", + "selector": "0xb0762d73", + "selectorBytes": [ + 176, + 118, + 45, + 115 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_11", + "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address that exactly match the revert data.", + "declaration": "function expectRevert(bytes calldata revertData, address reverter, uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes,address,uint64)", + "selector": "0xd345fb1f", + "selectorBytes": [ + 211, + 69, + 251, + 31 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "expectRevert_2", @@ -5131,6 +5171,86 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "expectRevert_6", + "description": "Expects a `count` number of reverts from the upcoming calls with any revert data or reverter.", + "declaration": "function expectRevert(uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(uint64)", + "selector": "0x4ee38244", + "selectorBytes": [ + 78, + 227, + 130, + 68 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_7", + "description": "Expects a `count` number of reverts from the upcoming calls that match the revert data.", + "declaration": "function expectRevert(bytes4 revertData, uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes4,uint64)", + "selector": "0xe45ca72d", + "selectorBytes": [ + 228, + 92, + 167, + 45 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_8", + "description": "Expects a `count` number of reverts from the upcoming calls that exactly match the revert data.", + "declaration": "function expectRevert(bytes calldata revertData, uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(bytes,uint64)", + "selector": "0x4994c273", + "selectorBytes": [ + 73, + 148, + 194, + 115 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectRevert_9", + "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address.", + "declaration": "function expectRevert(address reverter, uint64 count) external;", + "visibility": "external", + "mutability": "", + "signature": "expectRevert(address,uint64)", + "selector": "0x1ff5f952", + "selectorBytes": [ + 31, + 245, + 249, + 82 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "expectSafeMemory", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 6b66d31ddb50..6954fd1e526d 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -1019,6 +1019,30 @@ interface Vm { #[cheatcode(group = Testing, safety = Unsafe)] function expectRevert(bytes calldata revertData, address reverter) external; + /// Expects a `count` number of reverts from the upcoming calls with any revert data or reverter. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(uint64 count) external; + + /// Expects a `count` number of reverts from the upcoming calls that match the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes4 revertData, uint64 count) external; + + /// Expects a `count` number of reverts from the upcoming calls that exactly match the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes calldata revertData, uint64 count) external; + + /// Expects a `count` number of reverts from the upcoming calls from the reverter address. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(address reverter, uint64 count) external; + + /// Expects a `count` number of reverts from the upcoming calls from the reverter address that match the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes4 revertData, address reverter, uint64 count) external; + + /// Expects a `count` number of reverts from the upcoming calls from the reverter address that exactly match the revert data. + #[cheatcode(group = Testing, safety = Unsafe)] + function expectRevert(bytes calldata revertData, address reverter, uint64 count) external; + /// Expects an error on next call that starts with the revert data. #[cheatcode(group = Testing, safety = Unsafe)] function expectPartialRevert(bytes4 revertData) external; diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 4e2b91b9dd2d..329b89d0ebc2 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -754,16 +754,23 @@ where { if ecx.journaled_state.depth() <= expected_revert.depth && matches!(expected_revert.kind, ExpectedRevertKind::Default) { - let expected_revert = std::mem::take(&mut self.expected_revert).unwrap(); - return match expect::handle_expect_revert( + let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap(); + let handler_result = expect::handle_expect_revert( false, true, - &expected_revert, + &mut expected_revert, outcome.result.result, outcome.result.output.clone(), &self.config.available_artifacts, - ) { + ); + + return match handler_result { Ok((address, retdata)) => { + expected_revert.actual_count += 1; + if expected_revert.actual_count < expected_revert.count { + self.expected_revert = Some(expected_revert.clone()); + } + outcome.result.result = InstructionResult::Return; outcome.result.output = retdata; outcome.address = address; @@ -1302,6 +1309,14 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { expected_revert.reverted_by.is_none() { expected_revert.reverted_by = Some(call.target_address); + } else if outcome.result.is_revert() && + expected_revert.reverter.is_some() && + expected_revert.reverted_by.is_some() && + expected_revert.count > 1 + { + // If we're expecting more than one revert, we need to reset the reverted_by address + // to latest reverter. + expected_revert.reverted_by = Some(call.target_address); } if ecx.journaled_state.depth() <= expected_revert.depth { @@ -1315,15 +1330,20 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { }; if needs_processing { - let expected_revert = std::mem::take(&mut self.expected_revert).unwrap(); - return match expect::handle_expect_revert( + // Only `remove` the expected revert from state if `expected_revert.count` == + // `expected_revert.actual_count` + let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap(); + + let handler_result = expect::handle_expect_revert( cheatcode_call, false, - &expected_revert, + &mut expected_revert, outcome.result.result, outcome.result.output.clone(), &self.config.available_artifacts, - ) { + ); + + return match handler_result { Err(error) => { trace!(expected=?expected_revert, ?error, status=?outcome.result.result, "Expected revert mismatch"); outcome.result.result = InstructionResult::Revert; @@ -1331,6 +1351,10 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { outcome } Ok((_, retdata)) => { + expected_revert.actual_count += 1; + if expected_revert.actual_count < expected_revert.count { + self.expected_revert = Some(expected_revert.clone()); + } outcome.result.result = InstructionResult::Return; outcome.result.output = retdata; outcome diff --git a/crates/cheatcodes/src/test/expect.rs b/crates/cheatcodes/src/test/expect.rs index 3ee58407a7e6..a3ddef8c16c9 100644 --- a/crates/cheatcodes/src/test/expect.rs +++ b/crates/cheatcodes/src/test/expect.rs @@ -87,6 +87,10 @@ pub struct ExpectedRevert { pub reverter: Option
, /// Actual reverter of the call. pub reverted_by: Option
, + /// Number of times this revert is expected. + pub count: u64, + /// Actual number of times this revert has been seen. + pub actual_count: u64, } #[derive(Clone, Debug)] @@ -295,7 +299,7 @@ impl Cheatcode for expectEmitAnonymous_3Call { impl Cheatcode for expectRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self {} = self; - expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None) + expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None, 1) } } @@ -309,6 +313,7 @@ impl Cheatcode for expectRevert_1Call { false, false, None, + 1, ) } } @@ -323,6 +328,7 @@ impl Cheatcode for expectRevert_2Call { false, false, None, + 1, ) } } @@ -337,6 +343,7 @@ impl Cheatcode for expectRevert_3Call { false, false, Some(*reverter), + 1, ) } } @@ -351,6 +358,7 @@ impl Cheatcode for expectRevert_4Call { false, false, Some(*reverter), + 1, ) } } @@ -365,6 +373,89 @@ impl Cheatcode for expectRevert_5Call { false, false, Some(*reverter), + 1, + ) + } +} + +impl Cheatcode for expectRevert_6Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { count } = self; + expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None, *count) + } +} + +impl Cheatcode for expectRevert_7Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, count } = self; + expect_revert( + ccx.state, + Some(revertData.as_ref()), + ccx.ecx.journaled_state.depth(), + false, + false, + None, + *count, + ) + } +} + +impl Cheatcode for expectRevert_8Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, count } = self; + expect_revert( + ccx.state, + Some(revertData), + ccx.ecx.journaled_state.depth(), + false, + false, + None, + *count, + ) + } +} + +impl Cheatcode for expectRevert_9Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { reverter, count } = self; + expect_revert( + ccx.state, + None, + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + *count, + ) + } +} + +impl Cheatcode for expectRevert_10Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, reverter, count } = self; + expect_revert( + ccx.state, + Some(revertData.as_ref()), + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + *count, + ) + } +} + +impl Cheatcode for expectRevert_11Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { revertData, reverter, count } = self; + expect_revert( + ccx.state, + Some(revertData), + ccx.ecx.journaled_state.depth(), + false, + false, + Some(*reverter), + *count, ) } } @@ -379,6 +470,7 @@ impl Cheatcode for expectPartialRevert_0Call { false, true, None, + 1, ) } } @@ -393,13 +485,14 @@ impl Cheatcode for expectPartialRevert_1Call { false, true, Some(*reverter), + 1, ) } } impl Cheatcode for _expectCheatcodeRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { - expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None) + expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None, 1) } } @@ -413,6 +506,7 @@ impl Cheatcode for _expectCheatcodeRevert_1Call { true, false, None, + 1, ) } } @@ -427,6 +521,7 @@ impl Cheatcode for _expectCheatcodeRevert_2Call { true, false, None, + 1, ) } } @@ -662,6 +757,7 @@ fn expect_revert( cheatcode: bool, partial_match: bool, reverter: Option
, + count: u64, ) -> Result { ensure!( state.expected_revert.is_none(), @@ -678,6 +774,8 @@ fn expect_revert( partial_match, reverter, reverted_by: None, + count, + actual_count: 0, }); Ok(Default::default()) } @@ -685,7 +783,7 @@ fn expect_revert( pub(crate) fn handle_expect_revert( is_cheatcode: bool, is_create: bool, - expected_revert: &ExpectedRevert, + expected_revert: &mut ExpectedRevert, status: InstructionResult, retdata: Bytes, known_contracts: &Option, @@ -698,72 +796,117 @@ pub(crate) fn handle_expect_revert( } }; - ensure!(!matches!(status, return_ok!()), "next call did not revert as expected"); - - // If expected reverter address is set then check it matches the actual reverter. - if let (Some(expected_reverter), Some(actual_reverter)) = - (expected_revert.reverter, expected_revert.reverted_by) - { - if expected_reverter != actual_reverter { - return Err(fmt_err!( - "Reverter != expected reverter: {} != {}", - actual_reverter, - expected_reverter - )); + let stringify = |data: &[u8]| { + if let Ok(s) = String::abi_decode(data, true) { + return s; } - } - - let expected_reason = expected_revert.reason.as_deref(); - // If None, accept any revert. - let Some(expected_reason) = expected_reason else { - return Ok(success_return()); + if data.is_ascii() { + return std::str::from_utf8(data).unwrap().to_owned(); + } + hex::encode_prefixed(data) }; - if !expected_reason.is_empty() && retdata.is_empty() { - bail!("call reverted as expected, but without data"); - } - - let mut actual_revert: Vec = retdata.into(); + if expected_revert.count == 0 { + if expected_revert.reverter.is_none() && expected_revert.reason.is_none() { + ensure!( + matches!(status, return_ok!()), + "call reverted when it was expected not to revert" + ); + return Ok(success_return()); + } - // Compare only the first 4 bytes if partial match. - if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) { - return Ok(success_return()) - } + // Flags to track if the reason and reverter match. + let mut reason_match = expected_revert.reason.as_ref().map(|_| false); + let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false); - // Try decoding as known errors. - if matches!( - actual_revert.get(..4).map(|s| s.try_into().unwrap()), - Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR) - ) { - if let Ok(decoded) = Vec::::abi_decode(&actual_revert[4..], false) { - actual_revert = decoded; + // Reverter check + if let (Some(expected_reverter), Some(actual_reverter)) = + (expected_revert.reverter, expected_revert.reverted_by) + { + if expected_reverter == actual_reverter { + reverter_match = Some(true); + } } - } - if actual_revert == expected_reason || - (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some()) - { - Ok(success_return()) + // Reason check + let expected_reason = expected_revert.reason.as_deref(); + if let Some(expected_reason) = expected_reason { + let mut actual_revert: Vec = retdata.into(); + actual_revert = decode_revert(actual_revert); + + if actual_revert == expected_reason { + reason_match = Some(true); + } + }; + + match (reason_match, reverter_match) { + (Some(true), Some(true)) => Err(fmt_err!( + "expected 0 reverts with reason: {}, from address: {}, but got one", + &stringify(expected_reason.unwrap_or_default()), + expected_revert.reverter.unwrap() + )), + (Some(true), None) => Err(fmt_err!( + "expected 0 reverts with reason: {}, but got one", + &stringify(expected_reason.unwrap_or_default()) + )), + (None, Some(true)) => Err(fmt_err!( + "expected 0 reverts from address: {}, but got one", + expected_revert.reverter.unwrap() + )), + _ => Ok(success_return()), + } } else { - let (actual, expected) = if let Some(contracts) = known_contracts { - let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi)); - ( - &decoder.decode(actual_revert.as_slice(), Some(status)), - &decoder.decode(expected_reason, Some(status)), - ) + ensure!(!matches!(status, return_ok!()), "next call did not revert as expected"); + + // If expected reverter address is set then check it matches the actual reverter. + if let (Some(expected_reverter), Some(actual_reverter)) = + (expected_revert.reverter, expected_revert.reverted_by) + { + if expected_reverter != actual_reverter { + return Err(fmt_err!( + "Reverter != expected reverter: {} != {}", + actual_reverter, + expected_reverter + )); + } + } + + let expected_reason = expected_revert.reason.as_deref(); + // If None, accept any revert. + let Some(expected_reason) = expected_reason else { + return Ok(success_return()); + }; + + if !expected_reason.is_empty() && retdata.is_empty() { + bail!("call reverted as expected, but without data"); + } + + let mut actual_revert: Vec = retdata.into(); + + // Compare only the first 4 bytes if partial match. + if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) { + return Ok(success_return()) + } + + // Try decoding as known errors. + actual_revert = decode_revert(actual_revert); + + if actual_revert == expected_reason || + (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some()) + { + Ok(success_return()) } else { - let stringify = |data: &[u8]| { - if let Ok(s) = String::abi_decode(data, true) { - return s; - } - if data.is_ascii() { - return std::str::from_utf8(data).unwrap().to_owned(); - } - hex::encode_prefixed(data) + let (actual, expected) = if let Some(contracts) = known_contracts { + let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi)); + ( + &decoder.decode(actual_revert.as_slice(), Some(status)), + &decoder.decode(expected_reason, Some(status)), + ) + } else { + (&stringify(&actual_revert), &stringify(expected_reason)) }; - (&stringify(&actual_revert), &stringify(expected_reason)) - }; - Err(fmt_err!("Error != expected error: {} != {}", actual, expected,)) + Err(fmt_err!("Error != expected error: {} != {}", actual, expected,)) + } } } @@ -774,3 +917,15 @@ fn expect_safe_memory(state: &mut Cheatcodes, start: u64, end: u64, depth: u64) offsets.push(start..end); Ok(Default::default()) } + +fn decode_revert(revert: Vec) -> Vec { + if matches!( + revert.get(..4).map(|s| s.try_into().unwrap()), + Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR) + ) { + if let Ok(decoded) = Vec::::abi_decode(&revert[4..], false) { + return decoded; + } + } + revert +} diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index f6f66969f99f..b3746b6bc319 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -246,10 +246,16 @@ interface Vm { function expectPartialRevert(bytes4 revertData, address reverter) external; function expectRevert() external; function expectRevert(bytes4 revertData) external; + function expectRevert(bytes4 revertData, address reverter, uint64 count) external; + function expectRevert(bytes calldata revertData, address reverter, uint64 count) external; function expectRevert(bytes calldata revertData) external; function expectRevert(address reverter) external; function expectRevert(bytes4 revertData, address reverter) external; function expectRevert(bytes calldata revertData, address reverter) external; + function expectRevert(uint64 count) external; + function expectRevert(bytes4 revertData, uint64 count) external; + function expectRevert(bytes calldata revertData, uint64 count) external; + function expectRevert(address reverter, uint64 count) external; function expectSafeMemory(uint64 min, uint64 max) external; function expectSafeMemoryCall(uint64 min, uint64 max) external; function fee(uint256 newBasefee) external; diff --git a/testdata/default/cheats/ExpectRevert.t.sol b/testdata/default/cheats/ExpectRevert.t.sol index 18a90bac6e29..fef4ebaf5790 100644 --- a/testdata/default/cheats/ExpectRevert.t.sol +++ b/testdata/default/cheats/ExpectRevert.t.sol @@ -30,6 +30,10 @@ contract Reverter { revert(message); } + function callThenNoRevert(Dummy dummy) public pure { + dummy.callMe(); + } + function revertWithoutReason() public pure { revert(); } @@ -188,7 +192,7 @@ contract ExpectRevertTest is DSTest { } function testexpectCheatcodeRevert() public { - vm._expectCheatcodeRevert("JSON value at \".a\" is not an object"); + vm._expectCheatcodeRevert('JSON value at ".a" is not an object'); vm.parseJsonKeys('{"a": "b"}', ".a"); } @@ -351,3 +355,211 @@ contract ExpectRevertWithReverterTest is DSTest { aContract.callAndRevert(); } } + +contract ExpectRevertCount is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + function testRevertCountAny() public { + uint64 count = 3; + Reverter reverter = new Reverter(); + vm.expectRevert(count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("revert2"); + reverter.revertWithMessage("revert3"); + + vm.expectRevert("revert"); + reverter.revertWithMessage("revert"); + } + + function testFailRevertCountAny() public { + uint64 count = 3; + Reverter reverter = new Reverter(); + vm.expectRevert(count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("revert2"); + } + + function testNoRevert() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert(count); + reverter.doNotRevert(); + } + + function testFailNoRevert() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert(count); + reverter.revertWithMessage("revert"); + } + + function testRevertCountSpecific() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("revert"); + } + + function testFailReverCountSpecifc() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("second-revert"); + } + + function testNoRevertSpecific() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", count); + reverter.doNotRevert(); + } + + function testFailNoRevertSpecific() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", count); + reverter.revertWithMessage("revert"); + } + + function testNoRevertSpecificButDiffRevert() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", count); + reverter.revertWithMessage("revert2"); + } + + function testRevertCountWithConstructor() public { + uint64 count = 1; + vm.expectRevert("constructor revert", count); + new ConstructorReverter("constructor revert"); + } + + function testNoRevertWithConstructor() public { + uint64 count = 0; + vm.expectRevert("constructor revert", count); + new CContract(); + } + + function testRevertCountNestedSpecific() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + Reverter inner = new Reverter(); + + vm.expectRevert("nested revert", count); + reverter.revertWithMessage("nested revert"); + reverter.nestedRevert(inner, "nested revert"); + + vm.expectRevert("nested revert", count); + reverter.nestedRevert(inner, "nested revert"); + reverter.nestedRevert(inner, "nested revert"); + } + + function testRevertCountCallsThenReverts() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + Dummy dummy = new Dummy(); + + vm.expectRevert("called a function and then reverted", count); + reverter.callThenRevert(dummy, "called a function and then reverted"); + reverter.callThenRevert(dummy, "called a function and then reverted"); + } + + function testFailRevertCountCallsThenReverts() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + Dummy dummy = new Dummy(); + + vm.expectRevert("called a function and then reverted", count); + reverter.callThenRevert(dummy, "called a function and then reverted"); + reverter.callThenRevert(dummy, "wrong revert"); + } + + function testNoRevertCall() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + Dummy dummy = new Dummy(); + + vm.expectRevert("called a function and then reverted", count); + reverter.callThenNoRevert(dummy); + } +} + +contract ExpectRevertCountWithReverter is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + function testRevertCountWithReverter() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + vm.expectRevert(address(reverter), count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("revert"); + } + + function testFailRevertCountWithReverter() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + Reverter reverter2 = new Reverter(); + vm.expectRevert(address(reverter), count); + reverter.revertWithMessage("revert"); + reverter2.revertWithMessage("revert"); + } + + function testNoRevertWithReverter() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert(address(reverter), count); + reverter.doNotRevert(); + } + + function testNoRevertWithWrongReverter() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + Reverter reverter2 = new Reverter(); + vm.expectRevert(address(reverter), count); + reverter2.revertWithMessage("revert"); // revert from wrong reverter + } + + function testFailNoRevertWithReverter() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert(address(reverter), count); + reverter.revertWithMessage("revert"); + } + + function testReverterCountWithData() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", address(reverter), count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("revert"); + } + + function testFailReverterCountWithWrongData() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", address(reverter), count); + reverter.revertWithMessage("revert"); + reverter.revertWithMessage("wrong revert"); + } + + function testFailWrongReverterCountWithData() public { + uint64 count = 2; + Reverter reverter = new Reverter(); + Reverter reverter2 = new Reverter(); + vm.expectRevert("revert", address(reverter), count); + reverter.revertWithMessage("revert"); + reverter2.revertWithMessage("revert"); + } + + function testNoReverterCountWithData() public { + uint64 count = 0; + Reverter reverter = new Reverter(); + vm.expectRevert("revert", address(reverter), count); + reverter.doNotRevert(); + + vm.expectRevert("revert", address(reverter), count); + reverter.revertWithMessage("revert2"); + } +}