Skip to content

Commit

Permalink
perf(vault): only read totalAssets once
Browse files Browse the repository at this point in the history
  • Loading branch information
Rubilmax committed Sep 13, 2023
1 parent d14f4c3 commit 5134be6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 30 deletions.
105 changes: 78 additions & 27 deletions contracts/SupplyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {

function setFee() external timelockElapsed(pendingFee.timestamp) onlyOwner {
// Accrue interest using the previous fee set before changing it.
_accrueFee();
_updateLastTotalAssets(_accrueFee());

fee = uint96(pendingFee.value);

Expand All @@ -153,7 +153,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
require(newFeeRecipient != feeRecipient, ErrorsLib.ALREADY_SET);

// Accrue interest to the previous fee recipient set before changing it.
_accrueFee();
_updateLastTotalAssets(_accrueFee());

feeRecipient = newFeeRecipient;

Expand Down Expand Up @@ -242,16 +242,22 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
return _convertToShares(maxWithdraw(owner), Math.Rounding.Down);
}

function deposit(uint256 assets, address receiver) public virtual override returns (uint256) {
_accrueFee();
function deposit(uint256 assets, address receiver) public virtual override returns (uint256 shares) {
uint256 newTotalAssets = _accrueFee();

return super.deposit(assets, receiver);
shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Down);
_deposit(_msgSender(), receiver, assets, shares);

_updateLastTotalAssets(newTotalAssets + assets);
}

function mint(uint256 shares, address receiver) public virtual override returns (uint256) {
_accrueFee();
function mint(uint256 shares, address receiver) public virtual override returns (uint256 assets) {
uint256 newTotalAssets = _accrueFee();

assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Up);
_deposit(_msgSender(), receiver, assets, shares);

return super.mint(shares, receiver);
_updateLastTotalAssets(newTotalAssets + assets);
}

function withdraw(uint256 assets, address receiver, address owner)
Expand All @@ -260,21 +266,25 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
override
returns (uint256 shares)
{
_accrueFee();
uint256 newTotalAssets = _accrueFee();

// Do not call expensive `maxWithdraw` and optimistically withdraw assets.

shares = previewWithdraw(assets);
shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Up);
_withdraw(_msgSender(), receiver, owner, assets, shares);

_updateLastTotalAssets(newTotalAssets - assets);
}

function redeem(uint256 shares, address receiver, address owner) public virtual override returns (uint256 assets) {
_accrueFee();
uint256 newTotalAssets = _accrueFee();

// Do not call expensive `maxRedeem` and optimistically redeem shares.

assets = previewRedeem(shares);
assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Down);
_withdraw(_msgSender(), receiver, owner, assets, shares);

_updateLastTotalAssets(newTotalAssets - assets);
}

function totalAssets() public view override returns (uint256 assets) {
Expand Down Expand Up @@ -306,6 +316,46 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
super._withdraw(caller, receiver, owner, assets, shares);
}

function _convertToShares(uint256 assets, Math.Rounding rounding)
internal
view
virtual
override
returns (uint256)
{
(uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares();

return assets.mulDiv(totalSupply() + feeShares + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding);
}

function _convertToAssets(uint256 shares, Math.Rounding rounding)
internal
view
virtual
override
returns (uint256)
{
(uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares();

return shares.mulDiv(newTotalAssets + 1, totalSupply() + feeShares + 10 ** _decimalsOffset(), rounding);
}

function _convertToSharesWithFeeAccrued(uint256 assets, uint256 newTotalAssets, Math.Rounding rounding)
internal
view
returns (uint256)
{
return assets.mulDiv(totalSupply() + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding);
}

function _convertToAssetsWithFeeAccrued(uint256 shares, uint256 newTotalAssets, Math.Rounding rounding)
internal
view
returns (uint256)
{
return shares.mulDiv(newTotalAssets + 1, totalSupply() + 10 ** _decimalsOffset(), rounding);
}

/* INTERNAL */

function _market(Id id) internal view returns (VaultMarket storage) {
Expand Down Expand Up @@ -455,28 +505,29 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
}
}

function _accrueFee() internal {
if (fee == 0 || feeRecipient == address(0)) return;

(uint256 newTotalAssets, uint256 feeShares) = _accruedFeeShares();

function _updateLastTotalAssets(uint256 newTotalAssets) internal {
lastTotalAssets = newTotalAssets;

if (feeShares != 0) _mint(feeRecipient, feeShares);
emit EventsLib.UpdateLastTotalAssets(newTotalAssets);
}

function _accrueFee() internal returns (uint256 newTotalAssets) {
uint256 feeShares;
(feeShares, newTotalAssets) = _accruedFeeShares();

emit EventsLib.AccrueFee(newTotalAssets, feeShares);
if (feeShares != 0 && feeRecipient != address(0)) _mint(feeRecipient, feeShares);
}

function _accruedFeeShares() internal view returns (uint256 newTotalAssets, uint256 feeShares) {
function _accruedFeeShares() internal view returns (uint256 feeShares, uint256 newTotalAssets) {
newTotalAssets = totalAssets();
uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets);

if (totalInterest != 0) {
uint256 feeAmount = totalInterest.mulDiv(fee, WAD);
// The fee amount is subtracted from the total assets in this calculation to compensate for the fact
// that total assets is already increased by the total interest (including the fee amount).
feeShares = feeAmount.mulDiv(
totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAmount + 1, Math.Rounding.Down
uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets);
if (totalInterest != 0 && fee != 0) {
uint256 feeAssets = totalInterest.mulDiv(fee, WAD);
// The fee assets is subtracted from the total assets in this calculation to compensate for the fact
// that total assets is already increased by the total interest (including the fee assets).
feeShares = feeAssets.mulDiv(
totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAssets + 1, Math.Rounding.Down
);
}
}
Expand Down
5 changes: 2 additions & 3 deletions contracts/libraries/EventsLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ library EventsLib {

event DisableMarket(Id id);

/// @notice Emitted when the vault's performance fee is accrued.
/// @notice Emitted when the vault's last total assets is updated.
/// @param totalAssets The total amount of assets this vault manages.
/// @param feeShares The shares minted corresponding to the fee accrued.
event AccrueFee(uint256 totalAssets, uint256 feeShares);
event UpdateLastTotalAssets(uint256 totalAssets);
}

0 comments on commit 5134be6

Please sign in to comment.