Skip to content

Commit

Permalink
perf: remove repeated reallocations in swap step iterations (#5211)
Browse files Browse the repository at this point in the history
* perf: remove repeated reallocations in swap step iterations

* smallest dec

* unused

* comment

* Update x/concentrated-liquidity/swaps.go

Co-authored-by: Dev Ojha <ValarDragon@users.noreply.github.com>

* Update x/concentrated-liquidity/swaps.go

Co-authored-by: Dev Ojha <ValarDragon@users.noreply.github.com>

---------

Co-authored-by: Dev Ojha <ValarDragon@users.noreply.github.com>
  • Loading branch information
p0mvn and ValarDragon authored May 18, 2023
1 parent cc54de1 commit 17fa8e9
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 88 deletions.
26 changes: 5 additions & 21 deletions x/concentrated-liquidity/math/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"github.com/osmosis-labs/osmosis/osmomath"
)

var smallestDec = sdk.SmallestDec()

// liquidity0 takes an amount of asset0 in the pool as well as the sqrtpCur and the nextPrice
// sqrtPriceA is the smaller of sqrtpCur and the nextPrice
// sqrtPriceB is the larger of sqrtpCur and the nextPrice
Expand Down Expand Up @@ -81,14 +79,14 @@ func CalcAmount0Delta(liq, sqrtPriceA, sqrtPriceB sdk.Dec, roundUp bool) sdk.Dec
// - adding liquidity (request user to provide more tokens in in favor of the pool)
// The denominator is truncated to get a higher final amount.
denom := sqrtPriceA.MulTruncate(sqrtPriceB)
return liq.Mul(diff).Quo(denom).Ceil()
return liq.Mul(diff).QuoMut(denom).Ceil()
}
// These are truncated at precision end to round in favor of the pool when:
// - calculating amount out during swap
// - withdrawing liquidity
// The denominator is rounded up to get a smaller final amount.
denom := sqrtPriceA.MulRoundUp(sqrtPriceB)
return liq.MulTruncate(diff).QuoTruncate(denom)
return liq.MulTruncate(diff).QuoTruncateMut(denom)
}

// CalcAmount1 takes the asset with the smaller liquidity in the pool as well as the sqrtpCur and the nextPrice and calculates the amount of asset 1
Expand Down Expand Up @@ -133,8 +131,8 @@ func GetNextSqrtPriceFromAmount0InRoundingUp(sqrtPriceCurrent, liquidity, amount
}

product := amountZeroRemainingIn.Mul(sqrtPriceCurrent)
denominator := liquidity.Add(product)
return liquidity.Mul(sqrtPriceCurrent).QuoRoundUp(denominator)
denominator := product.AddMut(liquidity)
return liquidity.Mul(sqrtPriceCurrent).QuoRoundupMut(denominator)
}

// GetNextSqrtPriceFromAmount0OutRoundingUp utilizes sqrtPriceCurrent, liquidity, and amount of denom0 that still needs
Expand All @@ -149,7 +147,7 @@ func GetNextSqrtPriceFromAmount0OutRoundingUp(sqrtPriceCurrent, liquidity, amoun

product := amountZeroRemainingOut.Mul(sqrtPriceCurrent)
denominator := liquidity.Sub(product)
return liquidity.Mul(sqrtPriceCurrent).QuoRoundUp(denominator)
return liquidity.Mul(sqrtPriceCurrent).QuoRoundupMut(denominator)
}

// GetNextSqrtPriceFromAmount1InRoundingDown utilizes the current sqrtPriceCurrent, liquidity, and amount of denom1 that still needs
Expand Down Expand Up @@ -195,20 +193,6 @@ func GetLiquidityFromAmounts(sqrtPrice, sqrtPriceA, sqrtPriceB sdk.Dec, amount0,
return liquidity
}

// AddLiquidity adds or subtracts liquidityB from liquidityA, depending on whether liquidityB is positive or negative.
func AddLiquidity(liquidityA, liquidityB sdk.Dec) (finalLiquidity sdk.Dec) {
if liquidityB.LT(sdk.ZeroDec()) {
return liquidityA.Sub(liquidityB.Abs())
}
return liquidityA.Add(liquidityB)
}

// MulRoundUp multiplies a by b and rounds up to the nearest integer
// at precision end.
func MulRoundUp(a, b sdk.Dec) sdk.Dec {
return a.MulTruncate(b).Add(smallestDec)
}

// SquareRoundUp squares and rounds up at precision end.
func SquareRoundUp(sqrtPrice sdk.Dec) sdk.Dec {
return sqrtPrice.MulRoundUp(sqrtPrice)
Expand Down
43 changes: 0 additions & 43 deletions x/concentrated-liquidity/math/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,49 +80,6 @@ func (suite *ConcentratedMathTestSuite) TestLiquidity0() {
}
}

func (suite *ConcentratedMathTestSuite) TestAddLiquidity() {
testCases := map[string]struct {
inputLiqA sdk.Dec
inputLiqB sdk.Dec

expectedOutout sdk.Dec
}{
"happy path": {
inputLiqA: sdk.MustNewDecFromStr("1000000000"),
inputLiqB: sdk.MustNewDecFromStr("300000999"),

expectedOutout: sdk.MustNewDecFromStr("1300000999"),
},
"second value negative": {
inputLiqA: sdk.MustNewDecFromStr("1000000000"),
inputLiqB: sdk.MustNewDecFromStr("-300000999"),

expectedOutout: sdk.MustNewDecFromStr("699999001"),
},
"first value negative": {
inputLiqA: sdk.MustNewDecFromStr("-1000000000"),
inputLiqB: sdk.MustNewDecFromStr("300000999"),

expectedOutout: sdk.MustNewDecFromStr("-699999001"),
},
"both values negative": {
inputLiqA: sdk.MustNewDecFromStr("-1000000000"),
inputLiqB: sdk.MustNewDecFromStr("-300000999"),

expectedOutout: sdk.MustNewDecFromStr("-1300000999"),
},
}

for name, tc := range testCases {
tc := tc

suite.Run(name, func() {
actualOutput := math.AddLiquidity(tc.inputLiqA, tc.inputLiqB)
suite.Require().Equal(tc.expectedOutout, actualOutput)
})
}
}

// TestGetNextSqrtPriceFromAmount0RoundingUp tests that getNextSqrtPriceFromAmount0RoundingUp utilizes
// the current squareRootPrice, liquidity of denom0, and amount of denom0 that still needs
// to be swapped in order to determine the next squareRootPrice
Expand Down
59 changes: 36 additions & 23 deletions x/concentrated-liquidity/swaps.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type SwapState struct {
feeGrowthGlobal sdk.Dec
}

var (
smallestDec = sdk.SmallestDec()
)

// updateFeeGrowthGlobal updates the swap state's fee growth global per unit of liquidity
// when liquidity is positive.
//
Expand Down Expand Up @@ -263,9 +267,12 @@ func (k Keeper) computeOutAmtGivenIn(
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.NoSpotPriceWhenNoLiquidityError{PoolId: poolId}
}

asset0 := p.GetToken0()
asset1 := p.GetToken1()
tokenAmountInSpecified := tokenInMin.Amount.ToDec()
var (
tickSpacing = p.GetTickSpacing()
asset0 = p.GetToken0()
asset1 = p.GetToken1()
tokenAmountInSpecified = tokenInMin.Amount.ToDec()
)

// If swapping asset0 for asset1, zeroForOne is true
zeroForOne := tokenInMin.Denom == asset0
Expand Down Expand Up @@ -309,8 +316,11 @@ func (k Keeper) computeOutAmtGivenIn(
// initialize swap state with the following parameters:
// as we iterate through the following for loop, this swap state will get updated after each required iteration
swapState := SwapState{
amountSpecifiedRemaining: tokenAmountInSpecified, // tokenIn
amountCalculated: sdk.ZeroDec(), // tokenOut
// N.B. We clone tokenAmountInSpecified because swapState.amountSpecifiedRemaining will get
// mutated during the compute swap step loop. However, we still need the original
// value of tokenAmountInSpecified to calculate the amount of tokenOut to return.
amountSpecifiedRemaining: tokenAmountInSpecified.Clone(), // tokenIn
amountCalculated: sdk.ZeroDec(), // tokenOut
sqrtPrice: curSqrtPrice,
// Pad (or don't pad) current tick based on swap direction to avoid off-by-one errors
tick: swapStrategy.InitializeTickValue(p.GetCurrentTick()),
Expand All @@ -328,7 +338,7 @@ func (k Keeper) computeOutAmtGivenIn(
// TODO: for now, we check if amountSpecifiedRemaining is GT 0.0000001. This is because there are times when the remaining
// amount may be extremely small, and that small amount cannot generate and amountIn/amountOut and we are therefore left
// in an infinite loop.
for swapState.amountSpecifiedRemaining.GT(sdk.SmallestDec()) && !swapState.sqrtPrice.Equal(sqrtPriceLimit) {
for swapState.amountSpecifiedRemaining.GT(smallestDec) && !swapState.sqrtPrice.Equal(sqrtPriceLimit) {
// Log the sqrtPrice we start the iteration with
sqrtPriceStart := swapState.sqrtPrice

Expand Down Expand Up @@ -378,15 +388,15 @@ func (k Keeper) computeOutAmtGivenIn(
// Update the swapState with the new sqrtPrice from the above swap
swapState.sqrtPrice = sqrtPrice
// We deduct the amount of tokens we input in the computeSwapStep above from the user's defined tokenIn amount
swapState.amountSpecifiedRemaining = swapState.amountSpecifiedRemaining.Sub(amountIn.Add(feeCharge))
swapState.amountSpecifiedRemaining.SubMut(amountIn.Add(feeCharge))
// We add the amount of tokens we received (amountOut) from the computeSwapStep above to the amountCalculated accumulator
swapState.amountCalculated = swapState.amountCalculated.Add(amountOut)
swapState.amountCalculated.AddMut(amountOut)

// If the computeSwapStep calculated a sqrtPrice that is equal to the nextSqrtPrice, this means all liquidity in the current
// tick has been consumed and we must move on to the next tick to complete the swap
if nextTickSqrtPrice.Equal(sqrtPrice) {
// Retrieve the liquidity held in the next closest initialized tick
liquidityNet, err := k.crossTick(ctx, p.GetId(), nextTick, sdk.NewDecCoinFromDec(tokenInMin.Denom, swapState.feeGrowthGlobal))
liquidityNet, err := k.crossTick(ctx, poolId, nextTick, sdk.NewDecCoinFromDec(tokenInMin.Denom, swapState.feeGrowthGlobal))
if err != nil {
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
Expand All @@ -399,7 +409,7 @@ func (k Keeper) computeOutAmtGivenIn(

liquidityNet = swapStrategy.SetLiquidityDeltaSign(liquidityNet)
// Update the swapState's liquidity with the new tick's liquidity
newLiquidity := math.AddLiquidity(swapState.liquidity, liquidityNet)
newLiquidity := swapState.liquidity.AddMut(liquidityNet)
swapState.liquidity = newLiquidity

// Update the swapState's tick with the tick we retrieved liquidity from
Expand All @@ -408,7 +418,7 @@ func (k Keeper) computeOutAmtGivenIn(
// Otherwise if the sqrtPrice calculated from computeSwapStep does not equal the sqrtPrice we started with at the
// beginning of this iteration, we set the swapState tick to the corresponding tick of the sqrtPrice calculated from computeSwapStep
price := sqrtPrice.Mul(sqrtPrice)
swapState.tick, err = math.PriceToTickRoundDown(price, p.GetTickSpacing())
swapState.tick, err = math.PriceToTickRoundDown(price, tickSpacing)
if err != nil {
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
Expand Down Expand Up @@ -459,8 +469,12 @@ func (k Keeper) computeInAmtGivenOut(
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, types.NoSpotPriceWhenNoLiquidityError{PoolId: poolId}
}

asset0 := p.GetToken0()
asset1 := p.GetToken1()
var (
tickSpacing = p.GetTickSpacing()
asset0 = p.GetToken0()
asset1 = p.GetToken1()
tokenAmountOutSpecified = desiredTokenOut.Amount.ToDec()
)

// if swapping asset0 (in) for asset1 (out), zeroForOne is true
zeroForOne := desiredTokenOut.Denom == asset1
Expand All @@ -479,7 +493,7 @@ func (k Keeper) computeInAmtGivenOut(
}

// set the swap strategy
swapStrategy := swapstrategy.New(zeroForOne, sqrtPriceLimit, k.storeKey, swapFee, p.GetTickSpacing())
swapStrategy := swapstrategy.New(zeroForOne, sqrtPriceLimit, k.storeKey, swapFee, tickSpacing)

// get current sqrt price from pool
curSqrtPrice := p.GetCurrentSqrtPrice()
Expand All @@ -504,8 +518,8 @@ func (k Keeper) computeInAmtGivenOut(
// initialize swap state with the following parameters:
// as we iterate through the following for loop, this swap state will get updated after each required iteration
swapState := SwapState{
amountSpecifiedRemaining: desiredTokenOut.Amount.ToDec(), // tokenOut
amountCalculated: sdk.ZeroDec(), // tokenIn
amountSpecifiedRemaining: tokenAmountOutSpecified, // tokenOut
amountCalculated: sdk.ZeroDec(), // tokenIn
sqrtPrice: curSqrtPrice,
tick: swapStrategy.InitializeTickValue(p.GetCurrentTick()),
liquidity: p.GetLiquidity(),
Expand All @@ -514,7 +528,7 @@ func (k Keeper) computeInAmtGivenOut(

// TODO: This should be GT 0 but some instances have very small remainder
// need to look into fixing this
for swapState.amountSpecifiedRemaining.GT(sdk.SmallestDec()) && !swapState.sqrtPrice.Equal(sqrtPriceLimit) {
for swapState.amountSpecifiedRemaining.GT(smallestDec) && !swapState.sqrtPrice.Equal(sqrtPriceLimit) {
// log the sqrtPrice we start the iteration with
sqrtPriceStart := swapState.sqrtPrice

Expand Down Expand Up @@ -556,29 +570,28 @@ func (k Keeper) computeInAmtGivenOut(

// update the swapState with the new sqrtPrice from the above swap
swapState.sqrtPrice = sqrtPrice
swapState.amountSpecifiedRemaining = swapState.amountSpecifiedRemaining.Sub(amountOut)
swapState.amountCalculated = swapState.amountCalculated.Add(amountIn.Add(feeChargeTotal))
swapState.amountSpecifiedRemaining = swapState.amountSpecifiedRemaining.SubMut(amountOut)
swapState.amountCalculated = swapState.amountCalculated.AddMut(amountIn.Add(feeChargeTotal))

// if the computeSwapStep calculated a sqrtPrice that is equal to the nextSqrtPrice, this means all liquidity in the current
// tick has been consumed and we must move on to the next tick to complete the swap
if sqrtPriceNextTick.Equal(sqrtPrice) {
// retrieve the liquidity held in the next closest initialized tick
liquidityNet, err := k.crossTick(ctx, p.GetId(), nextTick, sdk.NewDecCoinFromDec(desiredTokenOut.Denom, swapState.feeGrowthGlobal))
liquidityNet, err := k.crossTick(ctx, poolId, nextTick, sdk.NewDecCoinFromDec(desiredTokenOut.Denom, swapState.feeGrowthGlobal))
if err != nil {
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
liquidityNet = swapStrategy.SetLiquidityDeltaSign(liquidityNet)
// update the swapState's liquidity with the new tick's liquidity
newLiquidity := math.AddLiquidity(swapState.liquidity, liquidityNet)
swapState.liquidity = newLiquidity
swapState.liquidity = swapState.liquidity.AddMut(liquidityNet)

// update the swapState's tick with the tick we retrieved liquidity from
swapState.tick = nextTick
} else if !sqrtPriceStart.Equal(sqrtPrice) {
// otherwise if the sqrtPrice calculated from computeSwapStep does not equal the sqrtPrice we started with at the
// beginning of this iteration, we set the swapState tick to the corresponding tick of the sqrtPrice calculated from computeSwapStep
price := sqrtPrice.Mul(sqrtPrice)
swapState.tick, err = math.PriceToTickRoundDown(price, p.GetTickSpacing())
swapState.tick, err = math.PriceToTickRoundDown(price, tickSpacing)
if err != nil {
return sdk.Coin{}, sdk.Coin{}, 0, sdk.Dec{}, sdk.Dec{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion x/concentrated-liquidity/tick.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (k Keeper) initOrUpdateTick(ctx sdk.Context, poolId uint64, currentTick int

// note that liquidityIn can be either positive or negative.
// If negative, this would work as a subtraction from liquidityBefore
liquidityAfter := math.AddLiquidity(liquidityBefore, liquidityDelta)
liquidityAfter := liquidityBefore.Add(liquidityDelta)

tickInfo.LiquidityGross = liquidityAfter

Expand Down

0 comments on commit 17fa8e9

Please sign in to comment.