diff --git a/CHANGELOG.md b/CHANGELOG.md index 9465d42fc6f..07ddd738fa0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New features * `SafeCast`: added functions to downcast signed integers (e.g. `toInt32`), improving usability of `SignedSafeMath`. ([#2243](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2243)) + * `functionCall`: new helpers that replicate Solidity's function call semantics, reducing the need to rely on `call`. ([#2264](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2264)) * `ERC1155`: added support for a base implementation, non-standard extensions and a preset contract. ([#2014](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2014), [#2230](https://github.com/OpenZeppelin/openzeppelin-contracts/issues/2230)) ### Improvements diff --git a/contracts/mocks/AddressImpl.sol b/contracts/mocks/AddressImpl.sol index 0bcfcbb85fc..19dcb15b440 100644 --- a/contracts/mocks/AddressImpl.sol +++ b/contracts/mocks/AddressImpl.sol @@ -5,6 +5,8 @@ pragma solidity ^0.6.0; import "../utils/Address.sol"; contract AddressImpl { + event CallReturnValue(string data); + function isContract(address account) external view returns (bool) { return Address.isContract(account); } @@ -13,6 +15,18 @@ contract AddressImpl { Address.sendValue(receiver, amount); } + function functionCall(address target, bytes calldata data) external { + bytes memory returnData = Address.functionCall(target, data); + + emit CallReturnValue(abi.decode(returnData, (string))); + } + + function functionCallWithValue(address target, bytes calldata data, uint256 value) external payable { + bytes memory returnData = Address.functionCallWithValue(target, data, value); + + emit CallReturnValue(abi.decode(returnData, (string))); + } + // sendValue's tests require the contract to hold Ether receive () external payable { } } diff --git a/contracts/mocks/CallReceiverMock.sol b/contracts/mocks/CallReceiverMock.sol new file mode 100644 index 00000000000..2e3297617b9 --- /dev/null +++ b/contracts/mocks/CallReceiverMock.sol @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.6.0; + +contract CallReceiverMock { + + event MockFunctionCalled(); + + uint256[] private _array; + + function mockFunction() public payable returns (string memory) { + emit MockFunctionCalled(); + + return "0x1234"; + } + + function mockFunctionNonPayable() public returns (string memory) { + emit MockFunctionCalled(); + + return "0x1234"; + } + + function mockFunctionRevertsNoReason() public payable { + revert(); + } + + function mockFunctionRevertsReason() public payable { + revert("CallReceiverMock: reverting"); + } + + function mockFunctionThrows() public payable { + assert(false); + } + + function mockFunctionOutOfGas() public payable { + for (uint256 i = 0; ; ++i) { + _array.push(i); + } + } +} diff --git a/contracts/token/ERC20/SafeERC20.sol b/contracts/token/ERC20/SafeERC20.sol index ac359400b0e..e5f0d23aa56 100644 --- a/contracts/token/ERC20/SafeERC20.sol +++ b/contracts/token/ERC20/SafeERC20.sol @@ -63,19 +63,10 @@ library SafeERC20 { */ function _callOptionalReturn(IERC20 token, bytes memory data) private { // We need to perform a low level call here, to bypass Solidity's return data size checking mechanism, since - // we're implementing it ourselves. - - // A Solidity high level call has three parts: - // 1. The target address is checked to verify it contains contract code - // 2. The call itself is made, and success asserted - // 3. The return value is decoded, which in turn checks the size of the returned data. - // solhint-disable-next-line max-line-length - require(address(token).isContract(), "SafeERC20: call to non-contract"); - - // solhint-disable-next-line avoid-low-level-calls - (bool success, bytes memory returndata) = address(token).call(data); - require(success, "SafeERC20: low-level call failed"); + // we're implementing it ourselves. We use {Address.functionCall} to perform this call, which verifies that + // the target address contains contract code and also asserts for success in the low-level call. + bytes memory returndata = address(token).functionCall(data, "SafeERC20: low-level call failed"); if (returndata.length > 0) { // Return data is optional // solhint-disable-next-line max-line-length require(abi.decode(returndata, (bool)), "SafeERC20: ERC20 operation did not succeed"); diff --git a/contracts/token/ERC721/ERC721.sol b/contracts/token/ERC721/ERC721.sol index 4ad45f36ff0..e6d120e478c 100644 --- a/contracts/token/ERC721/ERC721.sol +++ b/contracts/token/ERC721/ERC721.sol @@ -437,28 +437,15 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable if (!to.isContract()) { return true; } - // solhint-disable-next-line avoid-low-level-calls - (bool success, bytes memory returndata) = to.call(abi.encodeWithSelector( + bytes memory returndata = to.functionCall(abi.encodeWithSelector( IERC721Receiver(to).onERC721Received.selector, _msgSender(), from, tokenId, _data - )); - if (!success) { - if (returndata.length > 0) { - // solhint-disable-next-line no-inline-assembly - assembly { - let returndata_size := mload(returndata) - revert(add(32, returndata), returndata_size) - } - } else { - revert("ERC721: transfer to non ERC721Receiver implementer"); - } - } else { - bytes4 retval = abi.decode(returndata, (bytes4)); - return (retval == _ERC721_RECEIVED); - } + ), "ERC721: transfer to non ERC721Receiver implementer"); + bytes4 retval = abi.decode(returndata, (bytes4)); + return (retval == _ERC721_RECEIVED); } function _approve(address to, uint256 tokenId) private { diff --git a/contracts/utils/Address.sol b/contracts/utils/Address.sol index 2c54fe44968..1a96adae229 100644 --- a/contracts/utils/Address.sol +++ b/contracts/utils/Address.sol @@ -57,4 +57,79 @@ library Address { (bool success, ) = recipient.call{ value: amount }(""); require(success, "Address: unable to send value, recipient may have reverted"); } + + /** + * @dev Performs a Solidity function call using a low level `call`. A + * plain`call` is an unsafe replacement for a function call: use this + * function instead. + * + * If `target` reverts with a revert reason, it is bubbled up by this + * function (like regular Solidity function calls). + * + * Requirements: + * + * - `target` must be a contract. + * - calling `target` with `data` must not revert. + */ + function functionCall(address target, bytes memory data) internal returns (bytes memory) { + return functionCall(target, data, "Address: low-level call failed"); + } + + /** + * @dev Same as {Address-functionCall-address-bytes-}, but with + * `errorMessage` as a fallback revert reason when `target` reverts. + */ + function functionCall(address target, bytes memory data, string memory errorMessage) internal returns (bytes memory) { + return _functionCallWithValue(target, data, 0, errorMessage); + } + + /** + * @dev Performs a Solidity function call using a low level `call`, + * transferring `value` wei. A plain`call` is an unsafe replacement for a + * function call: use this function instead. + * + * If `target` reverts with a revert reason, it is bubbled up by this + * function (like regular Solidity function calls). + * + * Requirements: + * + * - `target` must be a contract. + * - the calling contract must have an ETH balance of at least `value`. + * - calling `target` with `data` must not revert. + */ + function functionCallWithValue(address target, bytes memory data, uint256 value) internal returns (bytes memory) { + return functionCallWithValue(target, data, value, "Address: low-level call with value failed"); + } + + /** + * @dev Same as {Address-functionCallWithValue-address-bytes-uint256-}, but + * with `errorMessage` as a fallback revert reason when `target` reverts. + */ + function functionCallWithValue(address target, bytes memory data, uint256 value, string memory errorMessage) internal returns (bytes memory) { + require(address(this).balance >= value, "Address: insufficient balance for call"); + return _functionCallWithValue(target, data, value, errorMessage); + } + + function _functionCallWithValue(address target, bytes memory data, uint256 weiValue, string memory errorMessage) private returns (bytes memory) { + require(isContract(target), "Address: call to non-contract"); + + // solhint-disable-next-line avoid-low-level-calls + (bool success, bytes memory returndata) = target.call{ value: weiValue }(data); + if (success) { + return returndata; + } else { + // Look for revert reason and bubble it up if present + if (returndata.length > 0) { + // The easiest way to bubble the revert reason is using memory via assembly + + // solhint-disable-next-line no-inline-assembly + assembly { + let returndata_size := mload(returndata) + revert(add(32, returndata), returndata_size) + } + } else { + revert(errorMessage); + } + } + } } diff --git a/test/token/ERC20/SafeERC20.test.js b/test/token/ERC20/SafeERC20.test.js index c50496ab7a0..11eac3100ef 100644 --- a/test/token/ERC20/SafeERC20.test.js +++ b/test/token/ERC20/SafeERC20.test.js @@ -15,7 +15,7 @@ describe('SafeERC20', function () { this.wrapper = await SafeERC20Wrapper.new(hasNoCode); }); - shouldRevertOnAllCalls('SafeERC20: call to non-contract'); + shouldRevertOnAllCalls('Address: call to non-contract'); }); describe('with token that returns false on all calls', function () { diff --git a/test/utils/Address.test.js b/test/utils/Address.test.js index 1229302caa2..8c08442874f 100644 --- a/test/utils/Address.test.js +++ b/test/utils/Address.test.js @@ -1,10 +1,11 @@ -const { accounts, contract } = require('@openzeppelin/test-environment'); +const { accounts, contract, web3 } = require('@openzeppelin/test-environment'); -const { balance, ether, expectRevert, send } = require('@openzeppelin/test-helpers'); +const { balance, ether, expectRevert, send, expectEvent } = require('@openzeppelin/test-helpers'); const { expect } = require('chai'); const AddressImpl = contract.fromArtifact('AddressImpl'); const EtherReceiver = contract.fromArtifact('EtherReceiverMock'); +const CallReceiverMock = contract.fromArtifact('CallReceiverMock'); describe('Address', function () { const [ recipient, other ] = accounts; @@ -90,4 +91,192 @@ describe('Address', function () { }); }); }); + + describe('functionCall', function () { + beforeEach(async function () { + this.contractRecipient = await CallReceiverMock.new(); + }); + + context('with valid contract receiver', function () { + it('calls the requested function', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + const receipt = await this.mock.functionCall(this.contractRecipient.address, abiEncodedCall); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); + }); + + it('reverts when the called function reverts with no reason', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionRevertsNoReason', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), + 'Address: low-level call failed' + ); + }); + + it('reverts when the called function reverts, bubbling up the revert reason', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionRevertsReason', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), + 'CallReceiverMock: reverting' + ); + }); + + it('reverts when the called function runs out of gas', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionOutOfGas', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), + 'Address: low-level call failed' + ); + }).timeout(5000); + + it('reverts when the called function throws', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionThrows', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), + 'Address: low-level call failed' + ); + }); + + it('reverts when function does not exist', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionDoesNotExist', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), + 'Address: low-level call failed' + ); + }); + }); + + context('with non-contract receiver', function () { + it('reverts when address is not a contract', async function () { + const [ recipient ] = accounts; + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + await expectRevert(this.mock.functionCall(recipient, abiEncodedCall), 'Address: call to non-contract'); + }); + }); + }); + + describe('functionCallWithValue', function () { + beforeEach(async function () { + this.contractRecipient = await CallReceiverMock.new(); + }); + + context('with zero value', function () { + it('calls the requested function', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + const receipt = await this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, 0); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); + }); + }); + + context('with non-zero value', function () { + const amount = ether('1.2'); + + it('reverts if insufficient sender balance', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount), + 'Address: insufficient balance for call' + ); + }); + + it('calls the requested function with existing value', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + const tracker = await balance.tracker(this.contractRecipient.address); + + await send.ether(other, this.mock.address, amount); + const receipt = await this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount); + + expect(await tracker.delta()).to.be.bignumber.equal(amount); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); + }); + + it('calls the requested function with transaction funds', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + const tracker = await balance.tracker(this.contractRecipient.address); + + expect(await balance.current(this.mock.address)).to.be.bignumber.equal('0'); + const receipt = await this.mock.functionCallWithValue( + this.contractRecipient.address, abiEncodedCall, amount, { from: other, value: amount } + ); + + expect(await tracker.delta()).to.be.bignumber.equal(amount); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); + }); + + it('reverts when calling non-payable functions', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionNonPayable', + type: 'function', + inputs: [], + }, []); + + await send.ether(other, this.mock.address, amount); + await expectRevert( + this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount), + 'Address: low-level call with value failed' + ); + }); + }); + }); });