diff --git a/contracts/SupplyVault.sol b/contracts/SupplyVault.sol index c9f822f9..aae006d7 100644 --- a/contracts/SupplyVault.sol +++ b/contracts/SupplyVault.sol @@ -55,9 +55,8 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { Pending public pendingTimelock; uint256 public timelock; - /// @dev Stores the total assets owned by this vault when the fee was last accrued. + /// @dev Stores the total assets this vault manage when it last handled a deposit/withdraw. uint256 public lastTotalAssets; - uint256 public lastUpdateTimestamp; ConfigSet private _config; @@ -98,6 +97,12 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { _; } + modifier updateLastTotalAssets() { + _; + + lastTotalAssets = totalAssets(); + } + /* ONLY OWNER FUNCTIONS */ function submitPendingTimelock(uint256 newTimelock) external onlyOwner { @@ -139,7 +144,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { emit EventsLib.SubmitPendingFee(newFee); } - function setFee() external timelockElapsed(pendingFee.timestamp) onlyOwner { + function setFee() external timelockElapsed(pendingFee.timestamp) onlyOwner updateLastTotalAssets { // Accrue interest using the previous fee set before changing it. _accrueFee(); @@ -150,7 +155,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { delete pendingFee; } - function setFeeRecipient(address newFeeRecipient) external onlyOwner { + function setFeeRecipient(address newFeeRecipient) external onlyOwner updateLastTotalAssets { require(newFeeRecipient != feeRecipient, ErrorsLib.ALREADY_SET); // Accrue interest to the previous fee recipient set before changing it. @@ -231,11 +236,9 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { return _market(id).cap; } - /* ERC4626 */ + /* ERC4626 (PUBLIC) */ function maxWithdraw(address owner) public view virtual override returns (uint256) { - _accruedFeeShares(); - return _staticWithdrawOrder(super.maxWithdraw(owner)); } @@ -243,16 +246,14 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { return _convertToShares(maxWithdraw(owner), Math.Rounding.Down); } - function deposit(uint256 assets, address receiver) public virtual override returns (uint256) { - _accrueFee(); - - return super.deposit(assets, receiver); + function deposit(uint256 assets, address receiver) public virtual override returns (uint256 shares) { + shares = _convertToSharesWithFeeAccrued(assets, Math.Rounding.Down); + _deposit(_msgSender(), receiver, assets, shares); } - function mint(uint256 shares, address receiver) public virtual override returns (uint256) { - _accrueFee(); - - return super.mint(shares, receiver); + function mint(uint256 shares, address receiver) public virtual override returns (uint256 assets) { + assets = _convertToAssetsWithFeeAccrued(shares, Math.Rounding.Up); + _deposit(_msgSender(), receiver, assets, shares); } function withdraw(uint256 assets, address receiver, address owner) @@ -261,20 +262,16 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { override returns (uint256 shares) { - _accrueFee(); - // Do not call expensive `maxWithdraw` and optimistically withdraw assets. - shares = previewWithdraw(assets); + shares = _convertToSharesWithFeeAccrued(assets, Math.Rounding.Up); _withdraw(_msgSender(), receiver, owner, assets, shares); } function redeem(uint256 shares, address receiver, address owner) public virtual override returns (uint256 assets) { - _accrueFee(); - // Do not call expensive `maxRedeem` and optimistically redeem shares. - assets = previewRedeem(shares); + assets = _convertToAssetsWithFeeAccrued(shares, Math.Rounding.Down); _withdraw(_msgSender(), receiver, owner, assets, shares); } @@ -290,8 +287,14 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { assets += ERC20(asset()).balanceOf(address(this)); } + /* ERC4626 (INTERNAL) */ + /// @dev Used in mint or deposit to deposit the underlying asset to Blue markets. - function _deposit(address caller, address owner, uint256 assets, uint256 shares) internal override { + function _deposit(address caller, address owner, uint256 assets, uint256 shares) + internal + override + updateLastTotalAssets + { super._deposit(caller, owner, assets, shares); require(_depositOrder(assets) == 0, ErrorsLib.DEPOSIT_ORDER_FAILED); @@ -301,12 +304,45 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { function _withdraw(address caller, address receiver, address owner, uint256 assets, uint256 shares) internal override + updateLastTotalAssets { require(_withdrawOrder(assets) == 0, ErrorsLib.WITHDRAW_ORDER_FAILED); super._withdraw(caller, receiver, owner, assets, shares); } + function _convertToShares(uint256 assets, Math.Rounding rounding) + internal + view + virtual + override + returns (uint256) + { + return assets.mulDiv(totalSupply() + _accruedFeeShares() + 10 ** _decimalsOffset(), totalAssets() + 1, rounding); + } + + function _convertToAssets(uint256 shares, Math.Rounding rounding) + internal + view + virtual + override + returns (uint256) + { + return shares.mulDiv(totalAssets() + 1, totalSupply() + _accruedFeeShares() + 10 ** _decimalsOffset(), rounding); + } + + function _convertToSharesWithFeeAccrued(uint256 assets, Math.Rounding rounding) internal returns (uint256) { + _accrueFee(); + + return assets.mulDiv(totalSupply() + 10 ** _decimalsOffset(), totalAssets() + 1, rounding); + } + + function _convertToAssetsWithFeeAccrued(uint256 shares, Math.Rounding rounding) internal returns (uint256) { + _accrueFee(); + + return shares.mulDiv(totalAssets() + 1, totalSupply() + 10 ** _decimalsOffset(), rounding); + } + /* INTERNAL */ function _market(Id id) internal view returns (VaultMarket storage) { @@ -457,30 +493,23 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault { } function _accrueFee() internal { - uint256 lastUpdate = lastUpdateTimestamp; - lastUpdateTimestamp = block.timestamp; - - if (lastUpdate == block.timestamp || fee == 0 || feeRecipient == address(0)) return; + if (fee == 0 || feeRecipient == address(0)) return; - (uint256 newTotalAssets, uint256 feeShares) = _accruedFeeShares(); - - lastTotalAssets = newTotalAssets; + uint256 feeShares = _accruedFeeShares(); if (feeShares != 0) _mint(feeRecipient, feeShares); - - emit EventsLib.AccrueFee(newTotalAssets, feeShares); } - function _accruedFeeShares() internal view returns (uint256 newTotalAssets, uint256 feeShares) { - newTotalAssets = totalAssets(); + function _accruedFeeShares() internal view returns (uint256 feeShares) { + uint256 newTotalAssets = totalAssets(); uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets); if (totalInterest != 0) { - uint256 feeAmount = totalInterest.mulDiv(fee, WAD); - // The fee amount is subtracted from the total assets in this calculation to compensate for the fact - // that total assets is already increased by the total interest (including the fee amount). - feeShares = feeAmount.mulDiv( - totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAmount + 1, Math.Rounding.Down + uint256 feeAssets = totalInterest.mulDiv(fee, WAD); + // The fee assets is subtracted from the total assets in this calculation to compensate for the fact + // that total assets is already increased by the total interest (including the fee assets). + feeShares = feeAssets.mulDiv( + totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAssets + 1, Math.Rounding.Down ); } } diff --git a/contracts/libraries/EventsLib.sol b/contracts/libraries/EventsLib.sol index b9bff7a8..62cd7611 100644 --- a/contracts/libraries/EventsLib.sol +++ b/contracts/libraries/EventsLib.sol @@ -29,9 +29,4 @@ library EventsLib { event SetCap(uint128 cap); event DisableMarket(Id id); - - /// @notice Emitted when the vault's performance fee is accrued. - /// @param totalAssets The total amount of assets this vault manages. - /// @param feeShares The shares minted corresponding to the fee accrued. - event AccrueFee(uint256 totalAssets, uint256 feeShares); } diff --git a/contracts/mocks/IrmMock.sol b/contracts/mocks/IrmMock.sol index da1125c6..bc1bd2f1 100644 --- a/contracts/mocks/IrmMock.sol +++ b/contracts/mocks/IrmMock.sol @@ -18,8 +18,7 @@ contract IrmMock is IIrm { function borrowRateView(MarketParams memory, Market memory market) public view returns (uint256) { uint256 utilization = market.totalBorrowAssets.wDivDown(market.totalSupplyAssets); - // Divide by the number of seconds in a year. - // This is a very simple model where x% utilization corresponds to x% APR. + // When rate is zero, x% utilization corresponds to x% APR. return rate == 0 ? utilization / 365 days : rate / 365 days; }