diff --git a/contracts/SupplyVault.sol b/contracts/SupplyVault.sol index 26815235..1c7ae015 100644 --- a/contracts/SupplyVault.sol +++ b/contracts/SupplyVault.sol @@ -140,7 +140,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { function setFee() external timelockElapsed(pendingFee.timestamp) onlyOwner { // Accrue interest using the previous fee set before changing it. - _accrueFee(); + _updateLastTotalAssets(_accrueFee()); fee = uint96(pendingFee.value); @@ -153,7 +153,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { require(newFeeRecipient != feeRecipient, ErrorsLib.ALREADY_SET); // Accrue interest to the previous fee recipient set before changing it. - _accrueFee(); + _updateLastTotalAssets(_accrueFee()); feeRecipient = newFeeRecipient; @@ -242,16 +242,22 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { return _convertToShares(maxWithdraw(owner), Math.Rounding.Down); } - function deposit(uint256 assets, address receiver) public virtual override returns (uint256) { - _accrueFee(); + function deposit(uint256 assets, address receiver) public virtual override returns (uint256 shares) { + uint256 newTotalAssets = _accrueFee(); - return super.deposit(assets, receiver); + shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Down); + _deposit(_msgSender(), receiver, assets, shares); + + _updateLastTotalAssets(newTotalAssets + assets); } - function mint(uint256 shares, address receiver) public virtual override returns (uint256) { - _accrueFee(); + function mint(uint256 shares, address receiver) public virtual override returns (uint256 assets) { + uint256 newTotalAssets = _accrueFee(); + + assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Up); + _deposit(_msgSender(), receiver, assets, shares); - return super.mint(shares, receiver); + _updateLastTotalAssets(newTotalAssets + assets); } function withdraw(uint256 assets, address receiver, address owner) @@ -260,21 +266,25 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { override returns (uint256 shares) { - _accrueFee(); + uint256 newTotalAssets = _accrueFee(); // Do not call expensive `maxWithdraw` and optimistically withdraw assets. - shares = previewWithdraw(assets); + shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Up); _withdraw(_msgSender(), receiver, owner, assets, shares); + + _updateLastTotalAssets(newTotalAssets - assets); } function redeem(uint256 shares, address receiver, address owner) public virtual override returns (uint256 assets) { - _accrueFee(); + uint256 newTotalAssets = _accrueFee(); // Do not call expensive `maxRedeem` and optimistically redeem shares. - assets = previewRedeem(shares); + assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Down); _withdraw(_msgSender(), receiver, owner, assets, shares); + + _updateLastTotalAssets(newTotalAssets - assets); } function totalAssets() public view override returns (uint256 assets) { @@ -306,6 +316,46 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { super._withdraw(caller, receiver, owner, assets, shares); } + function _convertToShares(uint256 assets, Math.Rounding rounding) + internal + view + virtual + override + returns (uint256) + { + (uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares(); + + return assets.mulDiv(totalSupply() + feeShares + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding); + } + + function _convertToAssets(uint256 shares, Math.Rounding rounding) + internal + view + virtual + override + returns (uint256) + { + (uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares(); + + return shares.mulDiv(newTotalAssets + 1, totalSupply() + feeShares + 10 ** _decimalsOffset(), rounding); + } + + function _convertToSharesWithFeeAccrued(uint256 assets, uint256 newTotalAssets, Math.Rounding rounding) + internal + view + returns (uint256) + { + return assets.mulDiv(totalSupply() + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding); + } + + function _convertToAssetsWithFeeAccrued(uint256 shares, uint256 newTotalAssets, Math.Rounding rounding) + internal + view + returns (uint256) + { + return shares.mulDiv(newTotalAssets + 1, totalSupply() + 10 ** _decimalsOffset(), rounding); + } + /* INTERNAL */ function _market(Id id) internal view returns (VaultMarket storage) { @@ -455,28 +505,29 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { } } - function _accrueFee() internal { - if (fee == 0 || feeRecipient == address(0)) return; - - (uint256 newTotalAssets, uint256 feeShares) = _accruedFeeShares(); - + function _updateLastTotalAssets(uint256 newTotalAssets) internal { lastTotalAssets = newTotalAssets; - if (feeShares != 0) _mint(feeRecipient, feeShares); + emit EventsLib.UpdateLastTotalAssets(newTotalAssets); + } + + function _accrueFee() internal returns (uint256 newTotalAssets) { + uint256 feeShares; + (feeShares, newTotalAssets) = _accruedFeeShares(); - emit EventsLib.AccrueFee(newTotalAssets, feeShares); + if (feeShares != 0 && feeRecipient != address(0)) _mint(feeRecipient, feeShares); } - function _accruedFeeShares() internal view returns (uint256 newTotalAssets, uint256 feeShares) { + function _accruedFeeShares() internal view returns (uint256 feeShares, uint256 newTotalAssets) { newTotalAssets = totalAssets(); - uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets); - if (totalInterest != 0) { - uint256 feeAmount = totalInterest.mulDiv(fee, WAD); - // The fee amount is subtracted from the total assets in this calculation to compensate for the fact - // that total assets is already increased by the total interest (including the fee amount). - feeShares = feeAmount.mulDiv( - totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAmount + 1, Math.Rounding.Down + uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets); + if (totalInterest != 0 && fee != 0) { + uint256 feeAssets = totalInterest.mulDiv(fee, WAD); + // The fee assets is subtracted from the total assets in this calculation to compensate for the fact + // that total assets is already increased by the total interest (including the fee assets). + feeShares = feeAssets.mulDiv( + totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAssets + 1, Math.Rounding.Down ); } } diff --git a/contracts/libraries/EventsLib.sol b/contracts/libraries/EventsLib.sol index b9bff7a8..d99e9112 100644 --- a/contracts/libraries/EventsLib.sol +++ b/contracts/libraries/EventsLib.sol @@ -30,8 +30,7 @@ library EventsLib { event DisableMarket(Id id); - /// @notice Emitted when the vault's performance fee is accrued. + /// @notice Emitted when the vault's last total assets is updated. /// @param totalAssets The total amount of assets this vault manages. - /// @param feeShares The shares minted corresponding to the fee accrued. - event AccrueFee(uint256 totalAssets, uint256 feeShares); + event UpdateLastTotalAssets(uint256 totalAssets); } diff --git a/lib/morpho-blue b/lib/morpho-blue index 7d60ca06..cdc0f008 160000 --- a/lib/morpho-blue +++ b/lib/morpho-blue @@ -1 +1 @@ -Subproject commit 7d60ca06e827780115d90508339de5560be9f643 +Subproject commit cdc0f0080e49949e50b87a6cd206fd73f118e7a0