diff --git a/osmoutils/binary_search.go b/osmoutils/binary_search.go new file mode 100644 index 00000000000..18c2a0866e0 --- /dev/null +++ b/osmoutils/binary_search.go @@ -0,0 +1,96 @@ +package osmoutils + +import ( + "errors" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// ErrTolerance is used to define a compare function, which checks if two +// ints are within a certain error tolerance of one another. +// ErrTolerance.Compare(a, b) returns true iff: +// |a - b| <= AdditiveTolerance +// |a - b| / min(a, b) <= MultiplicativeTolerance +// Each check is respectively ignored if the entry is nil (sdk.Dec{}, sdk.Int{}) +// Note that if AdditiveTolerance == 0, then this is equivalent to a standard compare. +type ErrTolerance struct { + AdditiveTolerance sdk.Int + MultiplicativeTolerance sdk.Dec +} + +// Compare returns if actual is within errTolerance of expected. +// returns 0 if it is +// returns 1 if not, and expected > actual. +// returns -1 if not, and expected < actual +func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int { + diff := expected.Sub(actual).Abs() + + comparisonSign := 0 + if expected.GT(actual) { + comparisonSign = 1 + } else { + comparisonSign = -1 + } + + // if no error accepted, do a direct compare. + if e.AdditiveTolerance.IsZero() { + if expected.Equal(actual) { + return 0 + } else { + return comparisonSign + } + } + + // Check additive tolerance equations + if !e.AdditiveTolerance.IsNil() && !e.AdditiveTolerance.IsZero() { + if diff.GT(e.AdditiveTolerance) { + return comparisonSign + } + } + // Check multiplicative tolerance equations + if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() { + errTerm := diff.ToDec().Quo(sdk.MinInt(expected, actual).ToDec()) + if errTerm.GT(e.MultiplicativeTolerance) { + return comparisonSign + } + } + + return 0 +} + +// Binary search inputs between [lowerbound, upperbound] to a monotonic increasing function f. +// We stop once f(found_input) meets the ErrTolerance constraints. +// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error. +func BinarySearch(f func(input sdk.Int) (sdk.Int, error), + lowerbound sdk.Int, + upperbound sdk.Int, + targetOutput sdk.Int, + errTolerance ErrTolerance, + maxIterations int) (sdk.Int, error) { + // Setup base case of loop + curEstimate := lowerbound.Add(upperbound).QuoRaw(2) + curOutput, err := f(curEstimate) + if err != nil { + return sdk.Int{}, err + } + curIteration := 0 + for ; curIteration < maxIterations; curIteration += 1 { + compRes := errTolerance.Compare(curOutput, targetOutput) + if compRes > 0 { + upperbound = curEstimate + } else if compRes < 0 { + lowerbound = curEstimate + } else { + break + } + curEstimate = lowerbound.Add(upperbound).QuoRaw(2) + curOutput, err = f(curEstimate) + if err != nil { + return sdk.Int{}, err + } + } + if curIteration == maxIterations { + return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough") + } + return curEstimate, nil +} diff --git a/osmoutils/binary_search_test.go b/osmoutils/binary_search_test.go new file mode 100644 index 00000000000..c125a17ba3e --- /dev/null +++ b/osmoutils/binary_search_test.go @@ -0,0 +1,154 @@ +package osmoutils + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" +) + +func TestBinarySearch(t *testing.T) { + // straight line function that returns input. Simplest to binary search on, + // binary search directly reveals one bit of the answer in each iteration with this function. + lineF := func(a sdk.Int) (sdk.Int, error) { + return a, nil + } + noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} + tests := []struct { + f func(sdk.Int) (sdk.Int, error) + lowerbound sdk.Int + upperbound sdk.Int + targetOutput sdk.Int + errTolerance ErrTolerance + maxIterations int + + expectedSolvedInput sdk.Int + expectErr bool + }{ + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, + } + + for _, tc := range tests { + actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) + } + } +} + +func TestBinarySearchNonlinear(t *testing.T) { + // straight line function that returns input. Simplest to binary search on, + // binary search directly reveals one bit of the answer in each iteration with this function. + lineF := func(a sdk.Int) (sdk.Int, error) { + return a, nil + } + noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} + tests := []struct { + f func(sdk.Int) (sdk.Int, error) + lowerbound sdk.Int + upperbound sdk.Int + targetOutput sdk.Int + errTolerance ErrTolerance + maxIterations int + + expectedSolvedInput sdk.Int + expectErr bool + }{ + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, + } + + for _, tc := range tests { + actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) + } + } +} + +func TestBinarySearchNonlinearNonzero(t *testing.T) { + // non-linear function that returns input. Simplest to binary search on, + // binary search directly reveals one bit of the answer in each iteration with this function. + lineF := func(a sdk.Int) (sdk.Int, error) { + return a, nil + } + noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} + tests := []struct { + f func(sdk.Int) (sdk.Int, error) + lowerbound sdk.Int + upperbound sdk.Int + targetOutput sdk.Int + errTolerance ErrTolerance + maxIterations int + + expectedSolvedInput sdk.Int + expectErr bool + }{ + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, + {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, + } + + for _, tc := range tests { + actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) + } + } +} + +func TestErrTolerance_Compare(t *testing.T) { + ZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt(), MultiplicativeTolerance: sdk.Dec{}} + tests := []struct { + name string + tol ErrTolerance + input sdk.Int + reference sdk.Int + + expectedCompareResult int + }{ + {"0 tolerance: <", ZeroErrTolerance, sdk.NewInt(1000), sdk.NewInt(1001), -1}, + {"0 tolerance: =", ZeroErrTolerance, sdk.NewInt(1001), sdk.NewInt(1001), 0}, + {"0 tolerance: >", ZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult { + t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult) + } + }) + } +} + +func TestErrToleranceNonzero_Compare(t *testing.T) { + // Nonzero error tolerance test + NonZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.NewInt(10), MultiplicativeTolerance: sdk.Dec{}} + tests := []struct { + name string + tol ErrTolerance + input sdk.Int + reference sdk.Int + + expectedCompareResult int + }{ + {"Nonzero tolerance: <", NonZeroErrTolerance, sdk.NewInt(420), sdk.NewInt(1001), -1}, + {"Nonzero tolerance: =", NonZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 0}, + {"Nonzero tolerance: >", NonZeroErrTolerance, sdk.NewInt(1230), sdk.NewInt(1001), 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult { + t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult) + } + }) + } +} diff --git a/x/gamm/pool-models/balancer/amm.go b/x/gamm/pool-models/balancer/amm.go index 2e034cb52ff..a8e6c9a9b40 100644 --- a/x/gamm/pool-models/balancer/amm.go +++ b/x/gamm/pool-models/balancer/amm.go @@ -7,6 +7,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/osmosis-labs/osmosis/v7/osmomath" + "github.com/osmosis-labs/osmosis/v7/x/gamm/pool-models/internal/cfmm_common" "github.com/osmosis-labs/osmosis/v7/x/gamm/types" ) @@ -246,48 +247,6 @@ func (p *Pool) calcSingleAssetJoin(tokenIn sdk.Coin, swapFee sdk.Dec, tokenInPoo ).TruncateInt(), nil } -func (p *Pool) maximalExactRatioJoin(tokensIn sdk.Coins) (numShares sdk.Int, remCoins sdk.Coins, err error) { - coinShareRatios := make([]sdk.Dec, len(tokensIn), len(tokensIn)) - minShareRatio := sdk.MaxSortableDec - maxShareRatio := sdk.ZeroDec() - - poolLiquidity := p.GetTotalPoolLiquidity(sdk.Context{}) - - for i, coin := range tokensIn { - shareRatio := coin.Amount.ToDec().QuoInt(poolLiquidity.AmountOfNoDenomValidation(coin.Denom)) - if shareRatio.LT(minShareRatio) { - minShareRatio = shareRatio - } - if shareRatio.GT(maxShareRatio) { - maxShareRatio = shareRatio - } - coinShareRatios[i] = shareRatio - } - - remCoins = sdk.Coins{} - if minShareRatio.Equal(sdk.MaxSortableDec) { - return numShares, remCoins, errors.New("unexpected error in balancer maximalExactRatioJoin") - } - numShares = minShareRatio.MulInt(p.TotalShares.Amount).TruncateInt() - - // if we have multiple shares, calculate remCoins - if !minShareRatio.Equal(maxShareRatio) { - // we have to calculate remCoins - for i, coin := range tokensIn { - if !coinShareRatios[i].Equal(minShareRatio) { - usedAmount := minShareRatio.MulInt(coin.Amount).Ceil().TruncateInt() - newAmt := coin.Amount.Sub(usedAmount) - // add to RemCoins - if !newAmt.IsZero() { - remCoins = remCoins.Add(sdk.Coin{Denom: coin.Denom, Amount: newAmt}) - } - } - } - } - - return numShares, remCoins, nil -} - func (p *Pool) JoinPool(_ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, err error) { numShares, newLiquidity, err := p.CalcJoinPoolShares(_ctx, tokensIn, swapFee) if err != nil { @@ -313,8 +272,8 @@ func (p *Pool) CalcJoinPoolShares(_ctx sdk.Context, tokensIn sdk.Coins, swapFee return sdk.ZeroInt(), sdk.NewCoins(), errors.New( "balancer pool only supports LP'ing with one asset, or all assets in pool") } - // Add all exact coins we can (no swap) - numShares, remCoins, err := p.maximalExactRatioJoin(tokensIn) + // Add all exact coins we can (no swap). ctx arg doesn't matter for Balancer + numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn) if err != nil { return sdk.ZeroInt(), sdk.NewCoins(), err } @@ -369,31 +328,7 @@ func (p *Pool) exitPool(ctx sdk.Context, exitingCoins sdk.Coins, exitingShares s } func (p *Pool) CalcExitPoolShares(ctx sdk.Context, exitingShares sdk.Int, exitFee sdk.Dec) (exitedCoins sdk.Coins, err error) { - totalShares := p.GetTotalShares() - if exitingShares.GTE(totalShares) { - return sdk.Coins{}, sdkerrors.Wrapf(types.ErrLimitMaxAmount, errMsgFormatSharesLargerThanMax, exitingShares.Int64(), totalShares.Uint64()) - } - - refundedShares := exitingShares - if !exitFee.IsZero() { - // exitingShares * (1 - exit fee) - // Todo: make a -1 constant - oneSubExitFee := sdk.OneDec().Sub(exitFee) - refundedShares = oneSubExitFee.MulInt(exitingShares).TruncateInt() - } - - shareOutRatio := refundedShares.ToDec().QuoInt(totalShares) - // Make it shareOutRatio * pool LP balances - exitedCoins = sdk.Coins{} - balances := p.GetTotalPoolLiquidity(ctx) - for _, asset := range balances { - exitAmt := shareOutRatio.MulInt(asset.Amount).TruncateInt() - if exitAmt.LTE(sdk.ZeroInt()) { - continue - } - exitedCoins = exitedCoins.Add(sdk.NewCoin(asset.Denom, exitAmt)) - } - return exitedCoins, nil + return cfmm_common.CalcExitPool(ctx, p, exitingShares, exitFee) } // feeRatio returns the fee ratio that is defined as follows: diff --git a/x/gamm/pool-models/internal/cfmm_common/lp.go b/x/gamm/pool-models/internal/cfmm_common/lp.go new file mode 100644 index 00000000000..18ea63e782a --- /dev/null +++ b/x/gamm/pool-models/internal/cfmm_common/lp.go @@ -0,0 +1,165 @@ +package cfmm_common + +import ( + "errors" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + + "github.com/osmosis-labs/osmosis/v7/osmoutils" + "github.com/osmosis-labs/osmosis/v7/x/gamm/types" +) + +const errMsgFormatSharesLargerThanMax = "%d resulted shares is larger than the max amount of %d" + +// CalcExitPool returns how many tokens should come out, when exiting k LP shares against a "standard" CFMM +func CalcExitPool(ctx sdk.Context, pool types.PoolI, exitingShares sdk.Int, exitFee sdk.Dec) (sdk.Coins, error) { + totalShares := pool.GetTotalShares() + if exitingShares.GTE(totalShares) { + return sdk.Coins{}, sdkerrors.Wrapf(types.ErrLimitMaxAmount, errMsgFormatSharesLargerThanMax, exitingShares.Int64(), totalShares.Uint64()) + } + + // refundedShares = exitingShares * (1 - exit fee) + // with 0 exit fee optimization + var refundedShares sdk.Dec + if !exitFee.IsZero() { + // exitingShares * (1 - exit fee) + oneSubExitFee := sdk.OneDec().SubMut(exitFee) + refundedShares = oneSubExitFee.MulIntMut(exitingShares) + } else { + refundedShares = exitingShares.ToDec() + } + + shareOutRatio := refundedShares.QuoInt(totalShares) + // exitedCoins = shareOutRatio * pool liquidity + exitedCoins := sdk.Coins{} + poolLiquidity := pool.GetTotalPoolLiquidity(ctx) + + for _, asset := range poolLiquidity { + // round down here, due to not wanting to over-exit + exitAmt := shareOutRatio.MulInt(asset.Amount).TruncateInt() + if exitAmt.LTE(sdk.ZeroInt()) { + continue + } + if exitAmt.GTE(asset.Amount) { + return sdk.Coins{}, errors.New("too many shares out") + } + exitedCoins = exitedCoins.Add(sdk.NewCoin(asset.Denom, exitAmt)) + } + + return exitedCoins, nil +} + +// MaximalExactRatioJoin LP's the maximal amount of tokens in possible, and returns the number of shares that'd be +// and how many coins would be left over. +func MaximalExactRatioJoin(p types.PoolI, ctx sdk.Context, tokensIn sdk.Coins) (numShares sdk.Int, remCoins sdk.Coins, err error) { + coinShareRatios := make([]sdk.Dec, len(tokensIn)) + minShareRatio := sdk.MaxSortableDec + maxShareRatio := sdk.ZeroDec() + + poolLiquidity := p.GetTotalPoolLiquidity(ctx) + totalShares := p.GetTotalShares() + + for i, coin := range tokensIn { + shareRatio := coin.Amount.ToDec().QuoInt(poolLiquidity.AmountOfNoDenomValidation(coin.Denom)) + if shareRatio.LT(minShareRatio) { + minShareRatio = shareRatio + } + if shareRatio.GT(maxShareRatio) { + maxShareRatio = shareRatio + } + coinShareRatios[i] = shareRatio + } + + if minShareRatio.Equal(sdk.MaxSortableDec) { + return numShares, remCoins, errors.New("unexpected error in MaximalExactRatioJoin") + } + + remCoins = sdk.Coins{} + numShares = minShareRatio.MulInt(totalShares).TruncateInt() + + // if we have multiple share values, calculate remainingCoins + if !minShareRatio.Equal(maxShareRatio) { + // we have to calculate remCoins + for i, coin := range tokensIn { + // if coinShareRatios[i] == minShareRatio, no remainder + if coinShareRatios[i].Equal(minShareRatio) { + continue + } + + usedAmount := minShareRatio.MulInt(coin.Amount).Ceil().TruncateInt() + newAmt := coin.Amount.Sub(usedAmount) + // if newAmt is non-zero, add to RemCoins. (It could be zero due to rounding) + if !newAmt.IsZero() { + remCoins = remCoins.Add(sdk.Coin{Denom: coin.Denom, Amount: newAmt}) + } + } + } + + return numShares, remCoins, nil +} + +// We binary search a number of LP shares, s.t. if we exited the pool with the updated liquidity, +// and swapped all the tokens back to the input denom, we'd get the same amount. (under 0 swap fee) +// Thanks to CFMM path-independence, we can estimate slippage with these swaps to be sure to get the right numbers here. +// (by path-independence, swap all of B -> A, and then swap all of C -> A will yield same amount of A, regardless +// of order and interleaving) +// +// This implementation requires each of pool.GetTotalPoolLiquidity, pool.ExitPool, and pool.SwapExactAmountIn +// to not update or read from state, and instead only do updates based upon the pool struct. +func BinarySearchSingleAssetJoin( + pool types.PoolI, + tokenIn sdk.Coin, + poolWithAddedLiquidityAndShares func(newLiquidity sdk.Coin, newShares sdk.Int) types.PoolI, +) (numLPShares sdk.Int, err error) { + // use dummy context + ctx := sdk.Context{} + // Need to get something that makes the result correct within 1 LP share + // If we fail to reach it within maxIterations, we return an error + correctnessThreshold := sdk.NewInt(2) + maxIterations := 300 + // upperbound of number of LP shares = existingShares * tokenIn.Amount / pool.totalLiquidity.AmountOf(tokenIn.Denom) + existingTokenLiquidity := pool.GetTotalPoolLiquidity(ctx).AmountOf(tokenIn.Denom) + existingLPShares := pool.GetTotalShares() + LPShareUpperBound := existingLPShares.Mul(tokenIn.Amount).ToDec().QuoInt(existingTokenLiquidity).Ceil().TruncateInt() + LPShareLowerBound := sdk.ZeroInt() + + // Creates a pool with tokenIn liquidity added, where it created `sharesIn` number of shares. + // Returns how many tokens you'd get, if you then exited all of `sharesIn` for tokenIn.Denom + estimateCoinOutGivenShares := func(sharesIn sdk.Int) (tokenOut sdk.Int, err error) { + // new pool with added liquidity & LP shares, which we can mutate. + poolWithUpdatedLiquidity := poolWithAddedLiquidityAndShares(tokenIn, sharesIn) + swapToDenom := tokenIn.Denom + // so now due to correctness of exitPool, we exitPool and swap all remaining assets to base asset + exitFee := sdk.ZeroDec() + exitedCoins, err := poolWithUpdatedLiquidity.ExitPool(ctx, sharesIn, exitFee) + if err != nil { + return sdk.Int{}, err + } + + return swapAllCoinsToSingleAsset(poolWithUpdatedLiquidity, ctx, exitedCoins, swapToDenom) + } + // TODO: Come back and revisit err tolerance + errTolerance := osmoutils.ErrTolerance{AdditiveTolerance: correctnessThreshold, MultiplicativeTolerance: sdk.Dec{}} + numLPShares, err = osmoutils.BinarySearch( + estimateCoinOutGivenShares, + LPShareLowerBound, LPShareUpperBound, tokenIn.Amount, errTolerance, maxIterations) + + return numLPShares, err +} + +func swapAllCoinsToSingleAsset(pool types.PoolI, ctx sdk.Context, inTokens sdk.Coins, swapToDenom string) (sdk.Int, error) { + swapFee := sdk.ZeroDec() + tokenOutAmt := inTokens.AmountOfNoDenomValidation(swapToDenom) + for _, coin := range inTokens { + if coin.Denom == swapToDenom { + continue + } + tokenOut, err := pool.SwapOutAmtGivenIn(ctx, sdk.NewCoins(coin), swapToDenom, swapFee) + if err != nil { + return sdk.Int{}, err + } + tokenOutAmt = tokenOutAmt.Add(tokenOut.Amount) + } + return tokenOutAmt, nil +} diff --git a/x/gamm/pool-models/stableswap/amm.go b/x/gamm/pool-models/stableswap/amm.go index 431516bca3a..d1e4b786710 100644 --- a/x/gamm/pool-models/stableswap/amm.go +++ b/x/gamm/pool-models/stableswap/amm.go @@ -1,7 +1,12 @@ package stableswap import ( + "errors" + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/osmosis-labs/osmosis/v7/x/gamm/pool-models/internal/cfmm_common" + types "github.com/osmosis-labs/osmosis/v7/x/gamm/types" ) var ( @@ -261,7 +266,7 @@ func approxDecEqual(a, b, tol sdk.Dec) bool { var ( twodec = sdk.MustNewDecFromStr("2.0") - threshold = sdk.MustNewDecFromStr("0.00001") // 0.001% + threshold = sdk.NewDecWithPrec(1, 10) // Correct within a factor of 1 * 10^{-10} ) // solveCFMMBinarySearch searches the correct dx using binary search over constant K. @@ -338,8 +343,7 @@ func (pa *Pool) calcOutAmtGivenIn(tokenIn sdk.Coin, tokenOutDenom string, swapFe if err != nil { return sdk.Dec{}, err } - tokenInSupply := reserves[0].ToDec() - tokenOutSupply := reserves[1].ToDec() + tokenInSupply, tokenOutSupply := reserves[0], reserves[1] // We are solving for the amount of token out, hence x = tokenOutSupply, y = tokenInSupply cfmmOut := solveCfmm(tokenOutSupply, tokenInSupply, tokenIn.Amount.ToDec()) outAmt := pa.getDescaledPoolAmt(tokenOutDenom, cfmmOut) @@ -352,11 +356,51 @@ func (pa *Pool) calcInAmtGivenOut(tokenOut sdk.Coin, tokenInDenom string, swapFe if err != nil { return sdk.Dec{}, err } - tokenInSupply := reserves[0].ToDec() - tokenOutSupply := reserves[1].ToDec() + tokenInSupply, tokenOutSupply := reserves[0], reserves[1] // We are solving for the amount of token in, cfmm(x,y) = cfmm(x + x_in, y - y_out) // x = tokenInSupply, y = tokenOutSupply, yIn = -tokenOutAmount cfmmIn := solveCfmm(tokenInSupply, tokenOutSupply, tokenOut.Amount.ToDec().Neg()) inAmt := pa.getDescaledPoolAmt(tokenInDenom, cfmmIn.NegMut()) return inAmt, nil } + +func (pa *Pool) calcSingleAssetJoinShares(tokenIn sdk.Coin, swapFee sdk.Dec) (sdk.Int, error) { + poolWithAddedLiquidityAndShares := func(newLiquidity sdk.Coin, newShares sdk.Int) types.PoolI { + paCopy := pa.Copy() + paCopy.updatePoolForJoin(sdk.NewCoins(tokenIn), newShares) + return &paCopy + } + // TODO: Correctly handle swap fee + return cfmm_common.BinarySearchSingleAssetJoin(pa, tokenIn, poolWithAddedLiquidityAndShares) +} + +// We can mutate pa here +// TODO: some day switch this to a COW wrapped pa, for better perf +func (pa *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, newLiquidity sdk.Coins, err error) { + if len(tokensIn) == 1 { + numShares, err = pa.calcSingleAssetJoinShares(tokensIn[0], swapFee) + newLiquidity = tokensIn + return numShares, newLiquidity, err + } else if len(tokensIn) != pa.NumAssets() { + return sdk.ZeroInt(), sdk.NewCoins(), errors.New( + "stableswap pool only supports LP'ing with one asset, or all assets in pool") + } + + // Add all exact coins we can (no swap). ctx arg doesn't matter for Stableswap + numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(pa, sdk.Context{}, tokensIn) + if err != nil { + return sdk.ZeroInt(), sdk.NewCoins(), err + } + pa.updatePoolForJoin(tokensIn.Sub(remCoins), numShares) + + for _, coin := range remCoins { + // TODO: Perhaps add a method to skip if this is too small. + newShare, err := pa.calcSingleAssetJoinShares(coin, swapFee) + if err != nil { + return sdk.ZeroInt(), sdk.NewCoins(), err + } + pa.updatePoolForJoin(sdk.NewCoins(coin), newShare) + } + + return numShares, tokensIn, nil +} diff --git a/x/gamm/pool-models/stableswap/pool.go b/x/gamm/pool-models/stableswap/pool.go index b41bd492fa9..e94b1ebb3f3 100644 --- a/x/gamm/pool-models/stableswap/pool.go +++ b/x/gamm/pool-models/stableswap/pool.go @@ -9,6 +9,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/osmosis-labs/osmosis/v7/x/gamm/pool-models/internal/cfmm_common" "github.com/osmosis-labs/osmosis/v7/x/gamm/types" ) @@ -64,6 +65,10 @@ func (pa Pool) GetScalingFactorByLiquidityIndex(liquidityIndex int) uint64 { return pa.ScalingFactor[liquidityIndex] } +func (pa Pool) NumAssets() int { + return len(pa.PoolLiquidity) +} + // returns pool liquidity of the provided denoms, in the same order the denoms were provided in func (pa Pool) getPoolAmts(denoms ...string) ([]sdk.Int, error) { result := make([]sdk.Int, len(denoms)) @@ -79,8 +84,8 @@ func (pa Pool) getPoolAmts(denoms ...string) ([]sdk.Int, error) { } // getScaledPoolAmts returns scaled amount of pool liquidity based on each asset's precisions -func (pa Pool) getScaledPoolAmts(denoms ...string) ([]sdk.Int, error) { - result := make([]sdk.Int, len(denoms)) +func (pa Pool) getScaledPoolAmts(denoms ...string) ([]sdk.Dec, error) { + result := make([]sdk.Dec, len(denoms)) poolLiquidity := pa.PoolLiquidity liquidityIndexes := pa.getLiquidityIndexMap() @@ -89,10 +94,10 @@ func (pa Pool) getScaledPoolAmts(denoms ...string) ([]sdk.Int, error) { amt := poolLiquidity.AmountOf(denom) if amt.IsZero() { - return []sdk.Int{}, fmt.Errorf("denom %s does not exist in pool", denom) + return []sdk.Dec{}, fmt.Errorf("denom %s does not exist in pool", denom) } scalingFactor := pa.GetScalingFactorByLiquidityIndex(liquidityIndex) - result[i] = amt.QuoRaw(int64(scalingFactor)) + result[i] = amt.ToDec().QuoInt64Mut(int64(scalingFactor)) } return result, nil } @@ -130,6 +135,18 @@ func (p *Pool) updatePoolLiquidityForSwap(tokensIn sdk.Coins, tokensOut sdk.Coin } } +// updatePoolLiquidityForExit updates the pool liquidity after an exit. +// The function sanity checks that not all tokens of a given denom are removed, +// and panics if thats the case. +func (p *Pool) updatePoolLiquidityForExit(tokensOut sdk.Coins) { + p.updatePoolLiquidityForSwap(sdk.Coins{}, tokensOut) +} + +func (p *Pool) updatePoolForJoin(tokensIn sdk.Coins, newShares sdk.Int) { + p.PoolLiquidity = p.PoolLiquidity.Add(tokensIn...) + p.TotalShares.Amount = p.TotalShares.Amount.Add(newShares) +} + // TODO: These should all get moved to amm.go func (pa Pool) CalcOutAmtGivenIn(ctx sdk.Context, tokenIn sdk.Coins, tokenOutDenom string, swapFee sdk.Dec) (tokenOut sdk.Coin, err error) { if tokenIn.Len() != 1 { @@ -195,26 +212,42 @@ func (pa Pool) SpotPrice(ctx sdk.Context, baseAssetDenom string, quoteAssetDenom if err != nil { return sdk.Dec{}, err } - scaledSpotPrice := spotPrice(reserves[0].ToDec(), reserves[1].ToDec()) + scaledSpotPrice := spotPrice(reserves[0], reserves[1]) spotPrice := pa.getDescaledPoolAmt(baseAssetDenom, scaledSpotPrice) return spotPrice, nil } -func (pa Pool) CalcJoinPoolShares(ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, newLiquidity sdk.Coins, err error) { - return sdk.Int{}, sdk.Coins{}, types.ErrNotImplemented +func (pa Pool) Copy() Pool { + pa2 := pa + pa2.PoolLiquidity = sdk.NewCoins(pa.PoolLiquidity...) + return pa2 +} + +func (pa *Pool) CalcJoinPoolShares(ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, newLiquidity sdk.Coins, err error) { + paCopy := pa.Copy() + return paCopy.joinPoolSharesInternal(ctx, tokensIn, swapFee) } func (pa *Pool) JoinPool(ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, err error) { - return sdk.Int{}, types.ErrNotImplemented + numShares, _, err = pa.joinPoolSharesInternal(ctx, tokensIn, swapFee) + return numShares, err } -func (pa *Pool) ExitPool(ctx sdk.Context, numShares sdk.Int, exitFee sdk.Dec) (exitedCoins sdk.Coins, err error) { - return sdk.Coins{}, types.ErrNotImplemented +func (pa *Pool) ExitPool(ctx sdk.Context, exitingShares sdk.Int, exitFee sdk.Dec) (exitingCoins sdk.Coins, err error) { + exitingCoins, err = pa.CalcExitPoolShares(ctx, exitingShares, exitFee) + if err != nil { + return sdk.Coins{}, err + } + + pa.TotalShares.Amount = pa.TotalShares.Amount.Sub(exitingShares) + pa.updatePoolLiquidityForExit(exitingCoins) + + return exitingCoins, nil } -func (pa Pool) CalcExitPoolShares(ctx sdk.Context, numShares sdk.Int, exitFee sdk.Dec) (exitedCoins sdk.Coins, err error) { - return sdk.Coins{}, types.ErrNotImplemented +func (pa Pool) CalcExitPoolShares(ctx sdk.Context, exitingShares sdk.Int, exitFee sdk.Dec) (exitingCoins sdk.Coins, err error) { + return cfmm_common.CalcExitPool(ctx, &pa, exitingShares, exitFee) } // no-op for stableswap