Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(vault): only read totalAssets once #42

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 78 additions & 27 deletions contracts/SupplyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {

function setFee() external timelockElapsed(pendingFee.timestamp) onlyOwner {
// Accrue interest using the previous fee set before changing it.
_accrueFee();
_updateLastTotalAssets(_accrueFee());

fee = uint96(pendingFee.value);

Expand All @@ -153,7 +153,7 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
require(newFeeRecipient != feeRecipient, ErrorsLib.ALREADY_SET);

// Accrue interest to the previous fee recipient set before changing it.
_accrueFee();
_updateLastTotalAssets(_accrueFee());

feeRecipient = newFeeRecipient;

Expand Down Expand Up @@ -242,16 +242,22 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
return _convertToShares(maxWithdraw(owner), Math.Rounding.Down);
}

function deposit(uint256 assets, address receiver) public virtual override returns (uint256) {
_accrueFee();
function deposit(uint256 assets, address receiver) public virtual override returns (uint256 shares) {
uint256 newTotalAssets = _accrueFee();

return super.deposit(assets, receiver);
shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Down);
_deposit(_msgSender(), receiver, assets, shares);

_updateLastTotalAssets(newTotalAssets + assets);
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved
}

function mint(uint256 shares, address receiver) public virtual override returns (uint256) {
_accrueFee();
function mint(uint256 shares, address receiver) public virtual override returns (uint256 assets) {
uint256 newTotalAssets = _accrueFee();

assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Up);
_deposit(_msgSender(), receiver, assets, shares);

return super.mint(shares, receiver);
_updateLastTotalAssets(newTotalAssets + assets);
}

function withdraw(uint256 assets, address receiver, address owner)
Expand All @@ -260,21 +266,25 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
override
returns (uint256 shares)
{
_accrueFee();
uint256 newTotalAssets = _accrueFee();

// Do not call expensive `maxWithdraw` and optimistically withdraw assets.

shares = previewWithdraw(assets);
shares = _convertToSharesWithFeeAccrued(assets, newTotalAssets, Math.Rounding.Up);
_withdraw(_msgSender(), receiver, owner, assets, shares);

_updateLastTotalAssets(newTotalAssets - assets);
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved
}

function redeem(uint256 shares, address receiver, address owner) public virtual override returns (uint256 assets) {
_accrueFee();
uint256 newTotalAssets = _accrueFee();

// Do not call expensive `maxRedeem` and optimistically redeem shares.

assets = previewRedeem(shares);
assets = _convertToAssetsWithFeeAccrued(shares, newTotalAssets, Math.Rounding.Down);
_withdraw(_msgSender(), receiver, owner, assets, shares);

_updateLastTotalAssets(newTotalAssets - assets);
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved
}

function totalAssets() public view override returns (uint256 assets) {
Expand Down Expand Up @@ -306,6 +316,46 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
super._withdraw(caller, receiver, owner, assets, shares);
}

function _convertToShares(uint256 assets, Math.Rounding rounding)
internal
view
virtual
override
returns (uint256)
{
(uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares();

return assets.mulDiv(totalSupply() + feeShares + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding);
}

function _convertToAssets(uint256 shares, Math.Rounding rounding)
internal
view
virtual
override
returns (uint256)
{
(uint256 feeShares, uint256 newTotalAssets) = _accruedFeeShares();

return shares.mulDiv(newTotalAssets + 1, totalSupply() + feeShares + 10 ** _decimalsOffset(), rounding);
}

function _convertToSharesWithFeeAccrued(uint256 assets, uint256 newTotalAssets, Math.Rounding rounding)
internal
view
returns (uint256)
{
return assets.mulDiv(totalSupply() + 10 ** _decimalsOffset(), newTotalAssets + 1, rounding);
}

function _convertToAssetsWithFeeAccrued(uint256 shares, uint256 newTotalAssets, Math.Rounding rounding)
internal
view
returns (uint256)
{
return shares.mulDiv(newTotalAssets + 1, totalSupply() + 10 ** _decimalsOffset(), rounding);
}

/* INTERNAL */

function _market(Id id) internal view returns (VaultMarket storage) {
Expand Down Expand Up @@ -455,28 +505,29 @@ contract SupplyVault is ERC4626, Ownable2Step, ISupplyVault {
}
}

function _accrueFee() internal {
if (fee == 0 || feeRecipient == address(0)) return;

(uint256 newTotalAssets, uint256 feeShares) = _accruedFeeShares();

function _updateLastTotalAssets(uint256 newTotalAssets) internal {
lastTotalAssets = newTotalAssets;

if (feeShares != 0) _mint(feeRecipient, feeShares);
emit EventsLib.UpdateLastTotalAssets(newTotalAssets);
}

function _accrueFee() internal returns (uint256 newTotalAssets) {
uint256 feeShares;
(feeShares, newTotalAssets) = _accruedFeeShares();

emit EventsLib.AccrueFee(newTotalAssets, feeShares);
if (feeShares != 0 && feeRecipient != address(0)) _mint(feeRecipient, feeShares);
}

function _accruedFeeShares() internal view returns (uint256 newTotalAssets, uint256 feeShares) {
function _accruedFeeShares() internal view returns (uint256 feeShares, uint256 newTotalAssets) {
newTotalAssets = totalAssets();
uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets);

if (totalInterest != 0) {
uint256 feeAmount = totalInterest.mulDiv(fee, WAD);
// The fee amount is subtracted from the total assets in this calculation to compensate for the fact
// that total assets is already increased by the total interest (including the fee amount).
feeShares = feeAmount.mulDiv(
totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAmount + 1, Math.Rounding.Down
uint256 totalInterest = newTotalAssets.zeroFloorSub(lastTotalAssets);
if (totalInterest != 0 && fee != 0) {
MerlinEgalite marked this conversation as resolved.
Show resolved Hide resolved
uint256 feeAssets = totalInterest.mulDiv(fee, WAD);
// The fee assets is subtracted from the total assets in this calculation to compensate for the fact
// that total assets is already increased by the total interest (including the fee assets).
feeShares = feeAssets.mulDiv(
totalSupply() + 10 ** _decimalsOffset(), newTotalAssets - feeAssets + 1, Math.Rounding.Down
);
}
}
Expand Down
5 changes: 2 additions & 3 deletions contracts/libraries/EventsLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ library EventsLib {

event DisableMarket(Id id);

/// @notice Emitted when the vault's performance fee is accrued.
/// @notice Emitted when the vault's last total assets is updated.
/// @param totalAssets The total amount of assets this vault manages.
/// @param feeShares The shares minted corresponding to the fee accrued.
event AccrueFee(uint256 totalAssets, uint256 feeShares);
event UpdateLastTotalAssets(uint256 totalAssets);
}
Loading