diff --git a/src/MetaMorpho.sol b/src/MetaMorpho.sol index 79d49d6b..3783efcc 100644 --- a/src/MetaMorpho.sol +++ b/src/MetaMorpho.sol @@ -481,8 +481,6 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph shares = _convertToSharesWithFeeAccrued(assets, totalSupply(), newTotalAssets, Math.Rounding.Floor); _deposit(_msgSender(), receiver, assets, shares); - - _updateLastTotalAssets(newTotalAssets + assets); } /// @inheritdoc IERC4626 @@ -491,8 +489,6 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph assets = _convertToAssetsWithFeeAccrued(shares, totalSupply(), newTotalAssets, Math.Rounding.Ceil); _deposit(_msgSender(), receiver, assets, shares); - - _updateLastTotalAssets(newTotalAssets + assets); } /// @inheritdoc IERC4626 @@ -507,8 +503,6 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph shares = _convertToSharesWithFeeAccrued(assets, totalSupply(), newTotalAssets, Math.Rounding.Ceil); _withdraw(_msgSender(), receiver, owner, assets, shares); - - _updateLastTotalAssets(newTotalAssets - assets); } /// @inheritdoc IERC4626 @@ -523,8 +517,6 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph assets = _convertToAssetsWithFeeAccrued(shares, totalSupply(), newTotalAssets, Math.Rounding.Floor); _withdraw(_msgSender(), receiver, owner, assets, shares); - - _updateLastTotalAssets(newTotalAssets - assets); } /// @inheritdoc IERC4626 @@ -569,7 +561,7 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph } /// @inheritdoc ERC4626 - /// @dev The accrual of fees is taken into account in the conversion. + /// @dev The accrual of performance fees is taken into account in the conversion. function _convertToAssets(uint256 shares, Math.Rounding rounding) internal view override returns (uint256) { (uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares(); @@ -607,6 +599,9 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph super._deposit(caller, owner, assets, shares); _supplyMorpho(assets); + + // `newTotalAssets + assets` cannot be used as input because of rounding errors so we must use `totalAssets`. + _updateLastTotalAssets(totalAssets()); } /// @inheritdoc ERC4626 @@ -624,6 +619,9 @@ contract MetaMorpho is ERC4626, ERC20Permit, Ownable2Step, Multicall, IMetaMorph if (_withdrawMorpho(assets) != 0) revert ErrorsLib.WithdrawMorphoFailed(); super._withdraw(caller, receiver, owner, assets, shares); + + // `newTotalAssets - assets` cannot be used as input because of rounding errors so we must use `totalAssets`. + _updateLastTotalAssets(totalAssets()); } /* INTERNAL */ diff --git a/test/forge/FeeTest.sol b/test/forge/FeeTest.sol index 85388b69..7754c2c8 100644 --- a/test/forge/FeeTest.sol +++ b/test/forge/FeeTest.sol @@ -23,17 +23,17 @@ contract FeeTest is BaseTest { // Create some debt on the market to accrue interest. - loanToken.setBalance(SUPPLIER, 1); + loanToken.setBalance(SUPPLIER, MAX_TEST_ASSETS); vm.prank(SUPPLIER); - morpho.supply(marketParams, 1, 0, ONBEHALF, hex""); + morpho.supply(marketParams, MAX_TEST_ASSETS, 0, ONBEHALF, hex""); - uint256 collateral = uint256(1).wDivUp(marketParams.lltv); + uint256 collateral = uint256(MAX_TEST_ASSETS).wDivUp(marketParams.lltv); collateralToken.setBalance(BORROWER, collateral); vm.startPrank(BORROWER); morpho.supplyCollateral(marketParams, collateral, BORROWER, hex""); - morpho.borrow(marketParams, 1, 0, BORROWER, BORROWER); + morpho.borrow(marketParams, MAX_TEST_ASSETS, 0, BORROWER, BORROWER); vm.stopPrank(); } @@ -43,10 +43,10 @@ contract FeeTest is BaseTest { function _feeShares(uint256 totalAssetsBefore) internal view returns (uint256) { uint256 totalAssetsAfter = vault.totalAssets(); uint256 interest = totalAssetsAfter - totalAssetsBefore; - uint256 feeAmount = interest.wMulDown(FEE); + uint256 feeAssets = interest.mulDiv(FEE, WAD); - return feeAmount.mulDiv( - vault.totalSupply() + 10 ** DECIMALS_OFFSET, totalAssetsAfter - feeAmount + 1, Math.Rounding.Floor + return feeAssets.mulDiv( + vault.totalSupply() + 10 ** DECIMALS_OFFSET, totalAssetsAfter - feeAssets + 1, Math.Rounding.Floor ); } @@ -62,8 +62,9 @@ contract FeeTest is BaseTest { } function testAccrueFeeWithinABlock(uint256 deposited, uint256 withdrawn) public { - deposited = bound(deposited, MIN_TEST_ASSETS, MAX_TEST_ASSETS); - withdrawn = bound(withdrawn, MIN_TEST_ASSETS, deposited); + deposited = bound(deposited, MIN_TEST_ASSETS + 1, MAX_TEST_ASSETS); + // The deposited amount is rounded down on Morpho and thus cannot be withdrawn in a block in most cases. + withdrawn = bound(withdrawn, MIN_TEST_ASSETS, deposited - 1); loanToken.setBalance(SUPPLIER, deposited); @@ -73,7 +74,7 @@ contract FeeTest is BaseTest { vm.prank(ONBEHALF); vault.withdraw(withdrawn, RECEIVER, ONBEHALF); - assertEq(vault.balanceOf(FEE_RECIPIENT), 0, "vault.balanceOf(FEE_RECIPIENT)"); + assertApproxEqAbs(vault.balanceOf(FEE_RECIPIENT), 0, 1, "vault.balanceOf(FEE_RECIPIENT)"); } function testDepositAccrueFee(uint256 deposited, uint256 newDeposit, uint256 blocks) public { @@ -178,11 +179,9 @@ contract FeeTest is BaseTest { function testSetFeeAccrueFee(uint256 deposited, uint256 fee, uint256 blocks) public { deposited = bound(deposited, MIN_TEST_ASSETS, MAX_TEST_ASSETS); - fee = bound(fee, 0, MAX_FEE); + fee = bound(fee, 0, FEE - 1); blocks = _boundBlocks(blocks); - vm.assume(fee != FEE); - loanToken.setBalance(SUPPLIER, deposited); vm.prank(SUPPLIER); @@ -253,4 +252,54 @@ contract FeeTest is BaseTest { vm.expectRevert(ErrorsLib.ZeroFeeRecipient.selector); vault.setFeeRecipient(address(0)); } + + function testConvertToAssetsWithFeeAndInterest(uint256 deposited, uint256 assets, uint256 blocks) public { + deposited = bound(deposited, MIN_TEST_ASSETS, MAX_TEST_ASSETS); + assets = bound(assets, 1, MAX_TEST_ASSETS); + blocks = _boundBlocks(blocks); + + loanToken.setBalance(SUPPLIER, deposited); + + vm.prank(SUPPLIER); + vault.deposit(deposited, ONBEHALF); + + uint256 totalAssetsBefore = vault.totalAssets(); + uint256 sharesBefore = vault.convertToShares(assets); + + _forward(blocks); + + uint256 feeShares = _feeShares(totalAssetsBefore); + uint256 expectedShares = assets.mulDiv( + vault.totalSupply() + feeShares + 10 ** DECIMALS_OFFSET, vault.totalAssets() + 1, Math.Rounding.Floor + ); + uint256 shares = vault.convertToShares(assets); + + assertEq(shares, expectedShares, "shares"); + assertLt(shares, sharesBefore, "shares decreased"); + } + + function testConvertToSharesWithFeeAndInterest(uint256 deposited, uint256 shares, uint256 blocks) public { + deposited = bound(deposited, MIN_TEST_ASSETS, MAX_TEST_ASSETS); + shares = bound(shares, 10 ** DECIMALS_OFFSET, MAX_TEST_ASSETS); + blocks = _boundBlocks(blocks); + + loanToken.setBalance(SUPPLIER, deposited); + + vm.prank(SUPPLIER); + vault.deposit(deposited, ONBEHALF); + + uint256 totalAssetsBefore = vault.totalAssets(); + uint256 assetsBefore = vault.convertToAssets(shares); + + _forward(blocks); + + uint256 feeShares = _feeShares(totalAssetsBefore); + uint256 expectedAssets = shares.mulDiv( + vault.totalAssets() + 1, vault.totalSupply() + feeShares + 10 ** DECIMALS_OFFSET, Math.Rounding.Floor + ); + uint256 assets = vault.convertToAssets(shares); + + assertEq(assets, expectedAssets, "assets"); + assertGe(assets, assetsBefore, "assets increased"); + } } diff --git a/test/forge/helpers/BaseTest.sol b/test/forge/helpers/BaseTest.sol index 8ac92a4c..a0584fbc 100644 --- a/test/forge/helpers/BaseTest.sol +++ b/test/forge/helpers/BaseTest.sol @@ -149,7 +149,7 @@ contract BaseTest is Test { /// @dev Bounds the fuzzing input to a realistic number of blocks. function _boundBlocks(uint256 blocks) internal view returns (uint256) { - return bound(blocks, 1, type(uint24).max); + return bound(blocks, 2, type(uint24).max); } /// @dev Bounds the fuzzing input to a non-zero address.