Skip to content

Commit

Permalink
Merge branch 'main' into mattverse/disable-stableswap
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderbez committed May 17, 2022
2 parents dcf6216 + f304ed8 commit 597fab5
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 86 deletions.
96 changes: 96 additions & 0 deletions osmoutils/binary_search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package osmoutils

import (
"errors"

sdk "github.com/cosmos/cosmos-sdk/types"
)

// ErrTolerance is used to define a compare function, which checks if two
// ints are within a certain error tolerance of one another.
// ErrTolerance.Compare(a, b) returns true iff:
// |a - b| <= AdditiveTolerance
// |a - b| / min(a, b) <= MultiplicativeTolerance
// Each check is respectively ignored if the entry is nil (sdk.Dec{}, sdk.Int{})
// Note that if AdditiveTolerance == 0, then this is equivalent to a standard compare.
type ErrTolerance struct {
AdditiveTolerance sdk.Int
MultiplicativeTolerance sdk.Dec
}

// Compare returns if actual is within errTolerance of expected.
// returns 0 if it is
// returns 1 if not, and expected > actual.
// returns -1 if not, and expected < actual
func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int {
diff := expected.Sub(actual).Abs()

comparisonSign := 0
if expected.GT(actual) {
comparisonSign = 1
} else {
comparisonSign = -1
}

// if no error accepted, do a direct compare.
if e.AdditiveTolerance.IsZero() {
if expected.Equal(actual) {
return 0
} else {
return comparisonSign
}
}

// Check additive tolerance equations
if !e.AdditiveTolerance.IsNil() && !e.AdditiveTolerance.IsZero() {
if diff.GT(e.AdditiveTolerance) {
return comparisonSign
}
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.ToDec().Quo(sdk.MinInt(expected, actual).ToDec())
if errTerm.GT(e.MultiplicativeTolerance) {
return comparisonSign
}
}

return 0
}

// Binary search inputs between [lowerbound, upperbound] to a monotonic increasing function f.
// We stop once f(found_input) meets the ErrTolerance constraints.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
lowerbound sdk.Int,
upperbound sdk.Int,
targetOutput sdk.Int,
errTolerance ErrTolerance,
maxIterations int) (sdk.Int, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err := f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
compRes := errTolerance.Compare(curOutput, targetOutput)
if compRes > 0 {
upperbound = curEstimate
} else if compRes < 0 {
lowerbound = curEstimate
} else {
break
}
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
}
if curIteration == maxIterations {
return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")
}
return curEstimate, nil
}
154 changes: 154 additions & 0 deletions osmoutils/binary_search_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package osmoutils

import (
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/require"
)

func TestBinarySearch(t *testing.T) {
// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
lineF := func(a sdk.Int) (sdk.Int, error) {
return a, nil
}
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()}
tests := []struct {
f func(sdk.Int) (sdk.Int, error)
lowerbound sdk.Int
upperbound sdk.Int
targetOutput sdk.Int
errTolerance ErrTolerance
maxIterations int

expectedSolvedInput sdk.Int
expectErr bool
}{
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false},
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
}

for _, tc := range tests {
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput))
}
}
}

func TestBinarySearchNonlinear(t *testing.T) {
// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
lineF := func(a sdk.Int) (sdk.Int, error) {
return a, nil
}
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()}
tests := []struct {
f func(sdk.Int) (sdk.Int, error)
lowerbound sdk.Int
upperbound sdk.Int
targetOutput sdk.Int
errTolerance ErrTolerance
maxIterations int

expectedSolvedInput sdk.Int
expectErr bool
}{
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false},
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
}

for _, tc := range tests {
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput))
}
}
}

func TestBinarySearchNonlinearNonzero(t *testing.T) {
// non-linear function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
lineF := func(a sdk.Int) (sdk.Int, error) {
return a, nil
}
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()}
tests := []struct {
f func(sdk.Int) (sdk.Int, error)
lowerbound sdk.Int
upperbound sdk.Int
targetOutput sdk.Int
errTolerance ErrTolerance
maxIterations int

expectedSolvedInput sdk.Int
expectErr bool
}{
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false},
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
}

for _, tc := range tests {
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput))
}
}
}

func TestErrTolerance_Compare(t *testing.T) {
ZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt(), MultiplicativeTolerance: sdk.Dec{}}
tests := []struct {
name string
tol ErrTolerance
input sdk.Int
reference sdk.Int

expectedCompareResult int
}{
{"0 tolerance: <", ZeroErrTolerance, sdk.NewInt(1000), sdk.NewInt(1001), -1},
{"0 tolerance: =", ZeroErrTolerance, sdk.NewInt(1001), sdk.NewInt(1001), 0},
{"0 tolerance: >", ZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult {
t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult)
}
})
}
}

func TestErrToleranceNonzero_Compare(t *testing.T) {
// Nonzero error tolerance test
NonZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.NewInt(10), MultiplicativeTolerance: sdk.Dec{}}
tests := []struct {
name string
tol ErrTolerance
input sdk.Int
reference sdk.Int

expectedCompareResult int
}{
{"Nonzero tolerance: <", NonZeroErrTolerance, sdk.NewInt(420), sdk.NewInt(1001), -1},
{"Nonzero tolerance: =", NonZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 0},
{"Nonzero tolerance: >", NonZeroErrTolerance, sdk.NewInt(1230), sdk.NewInt(1001), 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult {
t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult)
}
})
}
}
73 changes: 4 additions & 69 deletions x/gamm/pool-models/balancer/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

"github.com/osmosis-labs/osmosis/v7/osmomath"
"github.com/osmosis-labs/osmosis/v7/x/gamm/pool-models/internal/cfmm_common"
"github.com/osmosis-labs/osmosis/v7/x/gamm/types"
)

Expand Down Expand Up @@ -246,48 +247,6 @@ func (p *Pool) calcSingleAssetJoin(tokenIn sdk.Coin, swapFee sdk.Dec, tokenInPoo
).TruncateInt(), nil
}

func (p *Pool) maximalExactRatioJoin(tokensIn sdk.Coins) (numShares sdk.Int, remCoins sdk.Coins, err error) {
coinShareRatios := make([]sdk.Dec, len(tokensIn), len(tokensIn))
minShareRatio := sdk.MaxSortableDec
maxShareRatio := sdk.ZeroDec()

poolLiquidity := p.GetTotalPoolLiquidity(sdk.Context{})

for i, coin := range tokensIn {
shareRatio := coin.Amount.ToDec().QuoInt(poolLiquidity.AmountOfNoDenomValidation(coin.Denom))
if shareRatio.LT(minShareRatio) {
minShareRatio = shareRatio
}
if shareRatio.GT(maxShareRatio) {
maxShareRatio = shareRatio
}
coinShareRatios[i] = shareRatio
}

remCoins = sdk.Coins{}
if minShareRatio.Equal(sdk.MaxSortableDec) {
return numShares, remCoins, errors.New("unexpected error in balancer maximalExactRatioJoin")
}
numShares = minShareRatio.MulInt(p.TotalShares.Amount).TruncateInt()

// if we have multiple shares, calculate remCoins
if !minShareRatio.Equal(maxShareRatio) {
// we have to calculate remCoins
for i, coin := range tokensIn {
if !coinShareRatios[i].Equal(minShareRatio) {
usedAmount := minShareRatio.MulInt(coin.Amount).Ceil().TruncateInt()
newAmt := coin.Amount.Sub(usedAmount)
// add to RemCoins
if !newAmt.IsZero() {
remCoins = remCoins.Add(sdk.Coin{Denom: coin.Denom, Amount: newAmt})
}
}
}
}

return numShares, remCoins, nil
}

func (p *Pool) JoinPool(_ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, err error) {
numShares, newLiquidity, err := p.CalcJoinPoolShares(_ctx, tokensIn, swapFee)
if err != nil {
Expand All @@ -313,8 +272,8 @@ func (p *Pool) CalcJoinPoolShares(_ctx sdk.Context, tokensIn sdk.Coins, swapFee
return sdk.ZeroInt(), sdk.NewCoins(), errors.New(
"balancer pool only supports LP'ing with one asset, or all assets in pool")
}
// Add all exact coins we can (no swap)
numShares, remCoins, err := p.maximalExactRatioJoin(tokensIn)
// Add all exact coins we can (no swap). ctx arg doesn't matter for Balancer
numShares, remCoins, err := cfmm_common.MaximalExactRatioJoin(p, sdk.Context{}, tokensIn)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}
Expand Down Expand Up @@ -369,31 +328,7 @@ func (p *Pool) exitPool(ctx sdk.Context, exitingCoins sdk.Coins, exitingShares s
}

func (p *Pool) CalcExitPoolShares(ctx sdk.Context, exitingShares sdk.Int, exitFee sdk.Dec) (exitedCoins sdk.Coins, err error) {
totalShares := p.GetTotalShares()
if exitingShares.GTE(totalShares) {
return sdk.Coins{}, sdkerrors.Wrapf(types.ErrLimitMaxAmount, errMsgFormatSharesLargerThanMax, exitingShares.Int64(), totalShares.Uint64())
}

refundedShares := exitingShares
if !exitFee.IsZero() {
// exitingShares * (1 - exit fee)
// Todo: make a -1 constant
oneSubExitFee := sdk.OneDec().Sub(exitFee)
refundedShares = oneSubExitFee.MulInt(exitingShares).TruncateInt()
}

shareOutRatio := refundedShares.ToDec().QuoInt(totalShares)
// Make it shareOutRatio * pool LP balances
exitedCoins = sdk.Coins{}
balances := p.GetTotalPoolLiquidity(ctx)
for _, asset := range balances {
exitAmt := shareOutRatio.MulInt(asset.Amount).TruncateInt()
if exitAmt.LTE(sdk.ZeroInt()) {
continue
}
exitedCoins = exitedCoins.Add(sdk.NewCoin(asset.Denom, exitAmt))
}
return exitedCoins, nil
return cfmm_common.CalcExitPool(ctx, p, exitingShares, exitFee)
}

// feeRatio returns the fee ratio that is defined as follows:
Expand Down
Loading

0 comments on commit 597fab5

Please sign in to comment.