Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement BigDec binary search in osmoutils #2802

Merged
merged 10 commits into from
Sep 22, 2022
82 changes: 78 additions & 4 deletions osmoutils/binary_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"

sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/osmosis-labs/osmosis/v12/osmomath"
)

// ErrTolerance is used to define a compare function, which checks if two
Expand Down Expand Up @@ -56,6 +58,44 @@ func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int {
return 0
}

// CompareBigDec validates if actual is within errTolerance of expected.
// returns 0 if it is
// returns 1 if not, and expected > actual.
// returns -1 if not, and expected < actual
func (e ErrTolerance) CompareBigDec(expected osmomath.BigDec, actual osmomath.BigDec) int {
diff := expected.Sub(actual).Abs()

comparisonSign := 0
if expected.GT(actual) {
comparisonSign = 1
} else {
comparisonSign = -1
}

// Check additive tolerance equations
if !e.AdditiveTolerance.IsNil() {
// if no error accepted, do a direct compare.
if e.AdditiveTolerance.IsZero() {
if expected.Equal(actual) {
return 0
}
}

if diff.GT(osmomath.NewBigDec(e.AdditiveTolerance.Int64())) {
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved
return comparisonSign
}
}
// Check multiplicative tolerance equations
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() {
errTerm := diff.Quo(osmomath.MinDec(expected, actual))
if errTerm.GT(osmomath.BigDecFromSDKDec(e.MultiplicativeTolerance)) {
return comparisonSign
}
}

return 0
}

// Binary search inputs between [lowerbound, upperbound] to a monotonic increasing function f.
// We stop once f(found_input) meets the ErrTolerance constraints.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
Expand All @@ -80,16 +120,50 @@ func BinarySearch(f func(input sdk.Int) (sdk.Int, error),
} else if compRes < 0 {
lowerbound = curEstimate
} else {
break
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).QuoRaw(2)
curOutput, err = f(curEstimate)
if err != nil {
return sdk.Int{}, err
}
}
if curIteration == maxIterations {
return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")

return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough")
}

// Binary search BigDec inputs between [lowerbound, upperbound] to a monotonic increasing function f
// We stop once f(found_input) meets the ErrTolerance constraints.
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error.
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved
func BinarySearchBigDec(f func(input osmomath.BigDec) (osmomath.BigDec, error),
lowerbound osmomath.BigDec,
upperbound osmomath.BigDec,
targetOutput osmomath.BigDec,
errTolerance ErrTolerance,
maxIterations int,
) (osmomath.BigDec, error) {
// Setup base case of loop
curEstimate := lowerbound.Add(upperbound).Quo(osmomath.NewBigDec(2))
curOutput, err := f(curEstimate)
if err != nil {
return osmomath.BigDec{}, err
}
curIteration := 0
for ; curIteration < maxIterations; curIteration += 1 {
compRes := errTolerance.CompareBigDec(curOutput, targetOutput)
if compRes > 0 {
upperbound = curEstimate
} else if compRes < 0 {
lowerbound = curEstimate
} else {
return curEstimate, nil
}
curEstimate = lowerbound.Add(upperbound).Quo(osmomath.NewBigDec(2))
curOutput, err = f(curEstimate)
if err != nil {
return osmomath.BigDec{}, err
}
}
return curEstimate, nil

return osmomath.BigDec{}, errors.New("hit maximum iterations, did not converge fast enough")
}
139 changes: 105 additions & 34 deletions osmoutils/binary_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/require"

"github.com/osmosis-labs/osmosis/v12/osmomath"
)

func TestBinarySearch(t *testing.T) {
Expand All @@ -23,7 +25,7 @@ func TestBinarySearch(t *testing.T) {
testErrToleranceAdditive := ErrTolerance{AdditiveTolerance: sdk.NewInt(1 << 20)}
testErrToleranceMultiplicative := ErrTolerance{AdditiveTolerance: sdk.ZeroInt(), MultiplicativeTolerance: sdk.NewDec(10)}
testErrToleranceBoth := ErrTolerance{AdditiveTolerance: sdk.NewInt(1 << 20), MultiplicativeTolerance: sdk.NewDec(1 << 3)}
tests := []struct {
tests := map[string]struct {
f func(sdk.Int) (sdk.Int, error)
lowerbound sdk.Int
upperbound sdk.Int
Expand All @@ -41,26 +43,89 @@ func TestBinarySearch(t *testing.T) {
// If it is, we return current output
// Additive error bounds are solid addition / subtraction bounds to error, while multiplicative bounds take effect after dividing by the minimum between the two compared numbers.
}{
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false},
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(322539792367616), false},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 15)), testErrToleranceAdditive, 51, sdk.NewInt(1 << 46), false},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 30)), testErrToleranceAdditive, 10, sdk.Int{}, true},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), testErrToleranceMultiplicative, 51, sdk.NewInt(322539792367616), false},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), testErrToleranceMultiplicative, 10, sdk.Int{}, true},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 15)), testErrToleranceBoth, 51, sdk.NewInt(1 << 45), false},
{expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 30)), testErrToleranceBoth, 10, sdk.Int{}, true},
"linear f, no err tolerance, converges": {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false},
"linear f, no err tolerance, does not converge": {lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
"exponential f, no err tolerance, converges": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(322539792367616), false},
"exponential f, no err tolerance, does not converge": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true},
"exponential f, large additive err tolerance, converges": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 15)), testErrToleranceAdditive, 51, sdk.NewInt(1 << 46), false},
"exponential f, large additive err tolerance, does not converge": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 30)), testErrToleranceAdditive, 10, sdk.Int{}, true},
"exponential f, large multiplicative err tolerance, converges": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), testErrToleranceMultiplicative, 51, sdk.NewInt(322539792367616), false},
"exponential f, large multiplicative err tolerance, does not converge": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), testErrToleranceMultiplicative, 10, sdk.Int{}, true},
"exponential f, both err tolerances, converges": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 15)), testErrToleranceBoth, 51, sdk.NewInt(1 << 45), false},
"exponential f, both err tolerances, does not converge": {expF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt((1 << 30)), testErrToleranceBoth, 10, sdk.Int{}, true},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput))
}
})
}
}

func TestBinarySearchBigDec(t *testing.T) {
// straight line function that returns input. Simplest to binary search on,
// binary search directly reveals one bit of the answer in each iteration with this function.
lineF := func(a osmomath.BigDec) (osmomath.BigDec, error) {
return a, nil
}
expF := func(a osmomath.BigDec) (osmomath.BigDec, error) {
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved
// these precision shifts are done implicitly in the int binary search tests
// we keep them here to maintain parity between test cases across implementations
calculation := a.Quo(osmomath.NewBigDec(10).Power(18))
result := calculation.Power(3)
output := result.Mul(osmomath.NewBigDec(10).Power(18))
return output, nil
}
lowErrTolerance := ErrTolerance{AdditiveTolerance: sdk.OneInt()}
testErrToleranceAdditive := ErrTolerance{AdditiveTolerance: sdk.NewInt(1 << 20)}
testErrToleranceMultiplicative := ErrTolerance{AdditiveTolerance: sdk.OneInt(), MultiplicativeTolerance: sdk.NewDec(10)}
testErrToleranceBoth := ErrTolerance{AdditiveTolerance: sdk.NewInt(1 << 20), MultiplicativeTolerance: sdk.NewDec(1 << 3)}
tests := map[string]struct {
f func(osmomath.BigDec) (osmomath.BigDec, error)
lowerbound osmomath.BigDec
upperbound osmomath.BigDec
targetOutput osmomath.BigDec
errTolerance ErrTolerance
maxIterations int

expectedSolvedInput osmomath.BigDec
expectErr bool
// This binary searches inputs to a monotonic increasing function F
// We stop when the answer is within error bounds stated by errTolerance
// First, (lowerbound + upperbound) / 2 becomes the current estimate.
// A current output is also defined as f(current estimate). In this case f is lineF
// We then compare the current output with the target output to see if it's within error tolerance bounds. If not, continue binary searching by iterating.
// If it is, we return current output
// Additive error bounds are solid addition / subtraction bounds to error, while multiplicative bounds take effect after dividing by the minimum between the two compared numbers.
}{
"linear f, no err tolerance, converges": {lineF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), lowErrTolerance, 51, osmomath.NewBigDec(1 + (1 << 25)), false},
"linear f, no err tolerance, does not converge": {lineF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), lowErrTolerance, 10, osmomath.BigDec{}, true},
"exponential f, no err tolerance, converges": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), lowErrTolerance, 51, osmomath.NewBigDec(322539792367616), false},
"exponential f, no err tolerance, does not converge": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), lowErrTolerance, 10, osmomath.BigDec{}, true},
"exponential f, large additive err tolerance, converges": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec((1 << 15)), testErrToleranceAdditive, 51, osmomath.NewBigDec(1 << 46), false},
"exponential f, large additive err tolerance, does not converge": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec((1 << 30)), testErrToleranceAdditive, 10, osmomath.BigDec{}, true},
"exponential f, large multiplicative err tolerance, converges": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), testErrToleranceMultiplicative, 51, osmomath.NewBigDec(322539792367616), false},
"exponential f, large multiplicative err tolerance, does not converge": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec(1 + (1 << 25)), testErrToleranceMultiplicative, 10, osmomath.BigDec{}, true},
"exponential f, both err tolerances, converges": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec((1 << 15)), testErrToleranceBoth, 51, osmomath.NewBigDec(1 << 45), false},
"exponential f, both err tolerances, does not converge": {expF, osmomath.ZeroDec(), osmomath.NewBigDec(1 << 50), osmomath.NewBigDec((1 << 30)), testErrToleranceBoth, 10, osmomath.BigDec{}, true},
AlpinYukseloglu marked this conversation as resolved.
Show resolved Hide resolved
}

for _, tc := range tests {
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput))
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
actualSolvedInput, err := BinarySearchBigDec(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations)
if tc.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(osmomath.DecApproxEq(t, tc.expectedSolvedInput, actualSolvedInput, osmomath.OneDec()))
}
})
}
}

Expand All @@ -72,29 +137,35 @@ func TestErrTolerance_Compare(t *testing.T) {
tests := []struct {
name string
tol ErrTolerance
input sdk.Int
reference sdk.Int
intInput sdk.Int
intReference sdk.Int

bigDecInput osmomath.BigDec
bigDecReference osmomath.BigDec

expectedCompareResult int
}{
{"0 tolerance: <", ZeroErrTolerance, sdk.NewInt(1000), sdk.NewInt(1001), -1},
{"0 tolerance: =", ZeroErrTolerance, sdk.NewInt(1001), sdk.NewInt(1001), 0},
{"0 tolerance: >", ZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 1},
{"Nonzero additive tolerance: <", NonZeroErrAdditive, sdk.NewInt(420), sdk.NewInt(1001), -1},
{"Nonzero additive tolerance: =", NonZeroErrAdditive, sdk.NewInt(1011), sdk.NewInt(1001), 0},
{"Nonzero additive tolerance: >", NonZeroErrAdditive, sdk.NewInt(1230), sdk.NewInt(1001), 1},
{"Nonzero multiplicative tolerance: <", NonZeroErrMultiplicative, sdk.NewInt(1000), sdk.NewInt(1001), -1},
{"Nonzero multiplicative tolerance: =", NonZeroErrMultiplicative, sdk.NewInt(1001), sdk.NewInt(1001), 0},
{"Nonzero multiplicative tolerance: >", NonZeroErrMultiplicative, sdk.NewInt(1002), sdk.NewInt(1001), 1},
{"Nonzero both tolerance: <", NonZeroErrBoth, sdk.NewInt(990), sdk.NewInt(1001), -1},
{"Nonzero both tolerance: =", NonZeroErrBoth, sdk.NewInt(1002), sdk.NewInt(1001), 0},
{"Nonzero both tolerance: >", NonZeroErrBoth, sdk.NewInt(1011), sdk.NewInt(1001), 1},
{"0 tolerance: <", ZeroErrTolerance, sdk.NewInt(1000), sdk.NewInt(1001), osmomath.NewBigDec(1000), osmomath.NewBigDec(1001), -1},
{"0 tolerance: =", ZeroErrTolerance, sdk.NewInt(1001), sdk.NewInt(1001), osmomath.NewBigDec(1001), osmomath.NewBigDec(1001), 0},
{"0 tolerance: >", ZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), osmomath.NewBigDec(1002), osmomath.NewBigDec(1001), 1},
{"Nonzero additive tolerance: <", NonZeroErrAdditive, sdk.NewInt(420), sdk.NewInt(1001), osmomath.NewBigDec(420), osmomath.NewBigDec(1001), -1},
{"Nonzero additive tolerance: =", NonZeroErrAdditive, sdk.NewInt(1011), sdk.NewInt(1001), osmomath.NewBigDec(1011), osmomath.NewBigDec(1001), 0},
{"Nonzero additive tolerance: >", NonZeroErrAdditive, sdk.NewInt(1230), sdk.NewInt(1001), osmomath.NewBigDec(1230), osmomath.NewBigDec(1001), 1},
{"Nonzero multiplicative tolerance: <", NonZeroErrMultiplicative, sdk.NewInt(1000), sdk.NewInt(1001), osmomath.NewBigDec(1000), osmomath.NewBigDec(1001), -1},
{"Nonzero multiplicative tolerance: =", NonZeroErrMultiplicative, sdk.NewInt(1001), sdk.NewInt(1001), osmomath.NewBigDec(1001), osmomath.NewBigDec(1001), 0},
{"Nonzero multiplicative tolerance: >", NonZeroErrMultiplicative, sdk.NewInt(1002), sdk.NewInt(1001), osmomath.NewBigDec(1002), osmomath.NewBigDec(1001), 1},
{"Nonzero both tolerance: <", NonZeroErrBoth, sdk.NewInt(990), sdk.NewInt(1001), osmomath.NewBigDec(990), osmomath.NewBigDec(1001), -1},
{"Nonzero both tolerance: =", NonZeroErrBoth, sdk.NewInt(1002), sdk.NewInt(1001), osmomath.NewBigDec(1002), osmomath.NewBigDec(1001), 0},
{"Nonzero both tolerance: >", NonZeroErrBoth, sdk.NewInt(1011), sdk.NewInt(1001), osmomath.NewBigDec(1011), osmomath.NewBigDec(1001), 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult {
if got := tt.tol.Compare(tt.intInput, tt.intReference); got != tt.expectedCompareResult {
t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult)
}
if got := tt.tol.CompareBigDec(tt.bigDecInput, tt.bigDecReference); got != tt.expectedCompareResult {
t.Errorf("ErrTolerance.CompareBigDec() = %v, want %v", got, tt.expectedCompareResult)
}
})
}
}