diff --git a/contracts/math/SafeMath.sol b/contracts/math/SafeMath.sol index 7975c634cca..68ef98b2615 100644 --- a/contracts/math/SafeMath.sol +++ b/contracts/math/SafeMath.sol @@ -33,6 +33,23 @@ library SafeMath { return c; } + /** + * @dev Returns the addition of two unsigned integers, reverting with custom message on + * overflow. + * + * Counterpart to Solidity's `+` operator. + * + * Requirements: + * + * - Addition cannot overflow. + */ + function add(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) { + uint256 c = a + b; + require(c >= a, errorMessage); + + return c; + } + /** * @dev Returns the subtraction of two unsigned integers, reverting on * overflow (when the result is negative). @@ -44,7 +61,10 @@ library SafeMath { * - Subtraction cannot overflow. */ function sub(uint256 a, uint256 b) internal pure returns (uint256) { - return sub(a, b, "SafeMath: subtraction overflow"); + require(b <= a, "SafeMath: subtraction overflow"); + uint256 c = a - b; + + return c; } /** @@ -89,7 +109,31 @@ library SafeMath { } /** - * @dev Returns the integer division of two unsigned integers. Reverts on + * @dev Returns the multiplication of two unsigned integers, reverting with custom message on + * overflow. + * + * Counterpart to Solidity's `*` operator. + * + * Requirements: + * + * - Multiplication cannot overflow. + */ + function mul(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) { + // Gas optimization: this is cheaper than requiring 'a' not being zero, but the + // benefit is lost if 'b' is also tested. + // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522 + if (a == 0) { + return 0; + } + + uint256 c = a * b; + require(c / a == b, errorMessage); + + return c; + } + + /** + * @dev Returns the integer division of two unsigned integers, reverting on * division by zero. The result is rounded towards zero. * * Counterpart to Solidity's `/` operator. Note: this function uses a @@ -101,11 +145,15 @@ library SafeMath { * - The divisor cannot be zero. */ function div(uint256 a, uint256 b) internal pure returns (uint256) { - return div(a, b, "SafeMath: division by zero"); + require(b > 0, "SafeMath: division by zero"); + uint256 c = a / b; + // assert(a == b * c + a % b); // There is no case in which this doesn't hold + + return c; } /** - * @dev Returns the integer division of two unsigned integers. Reverts with custom message on + * @dev Returns the integer division of two unsigned integers, reverting with custom message on * division by zero. The result is rounded towards zero. * * Counterpart to Solidity's `/` operator. Note: this function uses a @@ -126,7 +174,7 @@ library SafeMath { /** * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo), - * Reverts when dividing by zero. + * reverting when dividing by zero. * * Counterpart to Solidity's `%` operator. This function uses a `revert` * opcode (which leaves remaining gas untouched) while Solidity uses an @@ -137,12 +185,13 @@ library SafeMath { * - The divisor cannot be zero. */ function mod(uint256 a, uint256 b) internal pure returns (uint256) { - return mod(a, b, "SafeMath: modulo by zero"); + require(b != 0, "SafeMath: modulo by zero"); + return a % b; } /** * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo), - * Reverts with custom message when dividing by zero. + * reverting with custom message when dividing by zero. * * Counterpart to Solidity's `%` operator. This function uses a `revert` * opcode (which leaves remaining gas untouched) while Solidity uses an diff --git a/contracts/mocks/SafeMathMock.sol b/contracts/mocks/SafeMathMock.sol index 5d2b8d8b2da..627c68cfcf5 100644 --- a/contracts/mocks/SafeMathMock.sol +++ b/contracts/mocks/SafeMathMock.sol @@ -24,4 +24,40 @@ contract SafeMathMock { function mod(uint256 a, uint256 b) public pure returns (uint256) { return SafeMath.mod(a, b); } + + function mulMemoryCheck() public pure returns (uint256 mem) { + uint256 length = 32; + assembly { mem := mload(0x40) } + for (uint256 i = 0; i < length; ++i) { SafeMath.mul(1, 1); } + assembly { mem := sub(mload(0x40), mem) } + } + + function divMemoryCheck() public pure returns (uint256 mem) { + uint256 length = 32; + assembly { mem := mload(0x40) } + for (uint256 i = 0; i < length; ++i) { SafeMath.div(1, 1); } + assembly { mem := sub(mload(0x40), mem) } + } + + function subMemoryCheck() public pure returns (uint256 mem) { + uint256 length = 32; + assembly { mem := mload(0x40) } + for (uint256 i = 0; i < length; ++i) { SafeMath.sub(1, 1); } + assembly { mem := sub(mload(0x40), mem) } + } + + function addMemoryCheck() public pure returns (uint256 mem) { + uint256 length = 32; + assembly { mem := mload(0x40) } + for (uint256 i = 0; i < length; ++i) { SafeMath.add(1, 1); } + assembly { mem := sub(mload(0x40), mem) } + } + + function modMemoryCheck() public pure returns (uint256 mem) { + uint256 length = 32; + assembly { mem := mload(0x40) } + for (uint256 i = 0; i < length; ++i) { SafeMath.mod(1, 1); } + assembly { mem := sub(mload(0x40), mem) } + } + } diff --git a/test/math/SafeMath.test.js b/test/math/SafeMath.test.js index 44cc1ba8e7c..cda5c9ff2c0 100644 --- a/test/math/SafeMath.test.js +++ b/test/math/SafeMath.test.js @@ -143,4 +143,26 @@ contract('SafeMath', function (accounts) { await expectRevert(this.safeMath.mod(a, b), 'SafeMath: modulo by zero'); }); }); + + describe('memory leakage', function () { + it('add does not leak', async function () { + expect(await this.safeMath.addMemoryCheck()).to.be.bignumber.equal('0'); + }); + + it('sub does not leak', async function () { + expect(await this.safeMath.subMemoryCheck()).to.be.bignumber.equal('0'); + }); + + it('mul does not leak', async function () { + expect(await this.safeMath.mulMemoryCheck()).to.be.bignumber.equal('0'); + }); + + it('div does not leak', async function () { + expect(await this.safeMath.divMemoryCheck()).to.be.bignumber.equal('0'); + }); + + it('mod does not leak', async function () { + expect(await this.safeMath.modMemoryCheck()).to.be.bignumber.equal('0'); + }); + }); });