Skip to content

Commit

Permalink
feat: monotonic sqrt big dec (#6053)
Browse files Browse the repository at this point in the history
* feat: monotonic sqrt big dec

* changelog

(cherry picked from commit 2b703e6)

# Conflicts:
#	CHANGELOG.md
  • Loading branch information
p0mvn authored and mergify[bot] committed Aug 28, 2023
1 parent 3378f8f commit 499055b
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### API breaks

* [#6071](https://github.com/osmosis-labs/osmosis/pull/6071) reduce number of returns for UpdatePosition and TicksToSqrtPrice functions
<<<<<<< HEAD
=======
* [#5906](https://github.com/osmosis-labs/osmosis/pull/5906) Add `AccountLockedCoins` query in lockup module to stargate whitelist.
* [#6053](https://github.com/osmosis-labs/osmosis/pull/6053) monotonic sqrt with 36 decimals
>>>>>>> 2b703e6e (feat: monotonic sqrt big dec (#6053))
## v17.0.0

Expand Down
42 changes: 42 additions & 0 deletions osmomath/sqrt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
)

var smallestDec = sdk.SmallestDec()
var smallestBigDec = SmallestDec()
var tenTo18 = big.NewInt(1e18)
var tenTo36 = big.NewInt(0).Mul(tenTo18, tenTo18)
var oneBigInt = big.NewInt(1)

// Returns square root of d
Expand Down Expand Up @@ -49,6 +51,37 @@ func MonotonicSqrt(d sdk.Dec) (sdk.Dec, error) {
return root, nil
}

func MonotonicSqrtBigDec(d BigDec) (BigDec, error) {
if d.IsNegative() {
return d, errors.New("cannot take square root of negative number")
}

// A decimal value of d, is represented as an integer of value v = 10^18 * d.
// We have an integer square root function, and we'd like to get the square root of d.
// recall integer square root is floor(sqrt(x)), hence its accurate up to 1 integer.
// we want sqrt d accurate to 18 decimal places.
// So first we multiply our current value by 10^18, then we take the integer square root.
// since sqrt(10^18 * v) = 10^9 * sqrt(v) = 10^18 * sqrt(d), we get the answer we want.
//
// We can than interpret sqrt(10^18 * v) as our resulting decimal and return it.
// monotonicity is guaranteed by correctness of integer square root.
dBi := d.BigInt()
r := big.NewInt(0).Mul(dBi, tenTo36)
r.Sqrt(r)
// However this square root r is s.t. r^2 <= d. We want to flip this to be r^2 >= d.
// To do so, we check that if r^2 < d, do r += 1. Then by correctness we will be in the case we want.
// To compare r^2 and d, we can just compare r^2 and 10^18 * v. (recall r = 10^18 * sqrt(d), v = 10^18 * d)
check := big.NewInt(0).Mul(r, r)
// dBi is a copy of d, so we can modify it.
shiftedD := dBi.Mul(dBi, tenTo36)
if check.Cmp(shiftedD) == -1 {
r.Add(r, oneBigInt)
}
root := NewDecFromBigIntWithPrec(r, 36)

return root, nil
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
sqrt, err := MonotonicSqrt(d)
Expand All @@ -57,3 +90,12 @@ func MustMonotonicSqrt(d sdk.Dec) sdk.Dec {
}
return sqrt
}

// MustMonotonicSqrt returns the output of MonotonicSqrt, panicking on error.
func MustMonotonicSqrtBigDec(d BigDec) BigDec {
sqrt, err := MonotonicSqrtBigDec(d)
if err != nil {
panic(err)
}
return sqrt
}
149 changes: 149 additions & 0 deletions osmomath/sqrt_big_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package osmomath

import (
"math/big"
"math/rand"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlenBigDec(r *rand.Rand, numPerBitlen int) []BigDec {
return generateRandomDecForEachBitlen[BigDec](r, numPerBitlen, NewDecFromBigIntWithPrec, Precision)
}

func TestSdkApproxSqrtVectors_BigDec(t *testing.T) {
testCases := []struct {
input BigDec
expected BigDec
}{
{OneDec(), OneDec()}, // 1.0 => 1.0
{NewDecWithPrec(25, 2), NewDecWithPrec(5, 1)}, // 0.25 => 0.5
{NewDecWithPrec(4, 2), NewDecWithPrec(2, 1)}, // 0.09 => 0.3
{NewDecFromInt(NewInt(9)), NewDecFromInt(NewInt(3))}, // 9 => 3
{NewDecFromInt(NewInt(2)), MustNewDecFromStr("1.414213562373095048801688724209698079")}, // 2 => 1.414213562373095048801688724209698079
{smallestBigDec, NewDecWithPrec(1, 18)}, // 10^-36 => 10^-18
{smallestBigDec.MulInt64(3), NewDecWithPrec(1732050807568877294, 36)}, // 3*10^-36 => sqrt(3)*10^-18
}

for i, tc := range testCases {
res, err := MonotonicSqrtBigDec(tc.input)
require.NoError(t, err)
require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input)
}
}

func testMonotonicityAroundBigDec(t *testing.T, x BigDec) {
// test that sqrt(x) is monotonic around x
// i.e. sqrt(x-1) <= sqrt(x) <= sqrt(x+1)
sqrtX, err := MonotonicSqrtBigDec(x)
require.NoError(t, err)
sqrtXMinusOne, err := MonotonicSqrtBigDec(x.Sub(smallestBigDec))
require.NoError(t, err)
sqrtXPlusOne, err := MonotonicSqrtBigDec(x.Add(smallestBigDec))
require.NoError(t, err)
assert.True(t, sqrtXMinusOne.LTE(sqrtX), "sqrtXMinusOne: %s, sqrtX: %s", sqrtXMinusOne, sqrtX)
assert.True(t, sqrtX.LTE(sqrtXPlusOne), "sqrtX: %s, sqrtXPlusOne: %s", sqrtX, sqrtXPlusOne)
}

func TestSqrtMonotinicity_BigDec(t *testing.T) {
type testcase struct {
smaller BigDec
bigger BigDec
}
testCases := []testcase{
{MustNewDecFromStr("120.120060020005000000"), MustNewDecFromStr("120.120060020005000001")},
{smallestBigDec, smallestBigDec.MulInt64(2)},
}
// create random test vectors for every bit-length
r := rand.New(rand.NewSource(rand.Int63()))
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
d := NewDecFromBigIntWithPrec(v, 36)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}
}
for i := 0; i < 1024; i++ {
d := NewDecWithPrec(int64(i), 18)
testCases = append(testCases, testcase{d, d.Add(smallestBigDec)})
}

for _, i := range testCases {
sqrtSmaller, err := MonotonicSqrtBigDec(i.smaller)
require.NoError(t, err, "smaller: %s", i.smaller)
sqrtBigger, err := MonotonicSqrtBigDec(i.bigger)
require.NoError(t, err, "bigger: %s", i.bigger)
assert.True(t, sqrtSmaller.LTE(sqrtBigger), "sqrtSmaller: %s, sqrtBigger: %s", sqrtSmaller, sqrtBigger)

// separately sanity check that sqrt * sqrt >= input
sqrtSmallerSquared := sqrtSmaller.Mul(sqrtSmaller)
assert.True(t, sqrtSmallerSquared.GTE(i.smaller), "sqrt %s, sqrtSmallerSquared: %s, smaller: %s", sqrtSmaller, sqrtSmallerSquared, i.smaller)
}
}

// Test that square(sqrt(x)) = x when x is a perfect square.
// We do this by sampling sqrt(v) from the set of numbers `a.b`, where a in [0, 2^128], b in [0, 10^9].
// and then setting x = sqrt(v)
// this is because this is the set of values whose squares are perfectly representable.
func TestPerfectSquares_BigDec(t *testing.T) {
cases := []BigDec{
NewBigDec(100),
}
r := rand.New(rand.NewSource(rand.Int63()))
tenToMin9 := big.NewInt(1_000_000_000)
for i := 0; i < 128; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < 100; j++ {
v := big.NewInt(0).Rand(r, upperbound)
dec := big.NewInt(0).Rand(r, tenToMin9)
d := NewDecFromBigInt(v).Add(NewDecFromBigIntWithPrec(dec, 9))
cases = append(cases, d.MulMut(d))
}
}

for _, i := range cases {
sqrt, err := MonotonicSqrtBigDec(i)
require.NoError(t, err, "smaller: %s", i)
assert.Equal(t, i, sqrt.MulMut(sqrt))
if !i.IsZero() {
testMonotonicityAroundBigDec(t, i)
}
}
}

func TestSqrtRounding_BigDec(t *testing.T) {
testCases := []BigDec{
MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlenBigDec(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrtBigDec(i)
require.NoError(t, err, "smaller: %s", i)
// Sanity check that sqrt * sqrt >= input
sqrtSquared := sqrt.Mul(sqrt)
assert.True(t, sqrtSquared.GTE(i), "sqrt %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
// (aside) check that (sqrt - 1ulp)^2 <= input
sqrtMin1 := sqrt.Sub(smallestBigDec)
sqrtSquared = sqrtMin1.Mul(sqrtMin1)
assert.True(t, sqrtSquared.LTE(i), "sqrtMin1ULP %s, sqrtSquared: %s, original: %s", sqrt, sqrtSquared, i)
}
}

// benchmarks the new square root across bit-lengths, for comparison with the SDK square root.
func BenchmarkMonotonicSqrt_BigDec(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlenBigDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrtBigDec(vectors[j])
_ = a
}
}
}
16 changes: 10 additions & 6 deletions osmomath/sqrt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ import (
"github.com/stretchr/testify/require"
)

func generateRandomDecForEachBitlen(r *rand.Rand, numPerBitlen int) []sdk.Dec {
res := make([]sdk.Dec, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
func generateRandomDecForEachBitlenDec(r *rand.Rand, numPerBitlen int) []sdk.Dec {
return generateRandomDecForEachBitlen[sdk.Dec](r, numPerBitlen, sdk.NewDecFromBigIntWithPrec, sdk.Precision)
}

func generateRandomDecForEachBitlen[T any](r *rand.Rand, numPerBitlen int, constructor func(*big.Int, int64) T, precision int64) []T {
res := make([]T, (255+sdk.DecimalPrecisionBits)*numPerBitlen)
for i := 0; i < 255+sdk.DecimalPrecisionBits; i++ {
upperbound := big.NewInt(1)
upperbound.Lsh(upperbound, uint(i))
for j := 0; j < numPerBitlen; j++ {
v := big.NewInt(0).Rand(r, upperbound)
res[i*numPerBitlen+j] = sdk.NewDecFromBigIntWithPrec(v, 18)
res[i*numPerBitlen+j] = constructor(v, precision)
}
}
return res
Expand Down Expand Up @@ -133,7 +137,7 @@ func TestSqrtRounding(t *testing.T) {
// sdk.MustNewDecFromStr("11662930532952632574132537947829685675668532938920838254939577167671385459971.396347723368091000"),
}
r := rand.New(rand.NewSource(rand.Int63()))
testCases = append(testCases, generateRandomDecForEachBitlen(r, 10)...)
testCases = append(testCases, generateRandomDecForEachBitlenDec(r, 10)...)
for _, i := range testCases {
sqrt, err := MonotonicSqrt(i)
require.NoError(t, err, "smaller: %s", i)
Expand All @@ -150,7 +154,7 @@ func TestSqrtRounding(t *testing.T) {
// benchmarks the SDK square root across bit-lengths, for comparison with the new square root.
func BenchmarkSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
vectors := generateRandomDecForEachBitlenDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := vectors[j].ApproxSqrt()
Expand All @@ -162,7 +166,7 @@ func BenchmarkSqrt(b *testing.B) {
// benchmarks the new square root across bit-lengths, for comparison with the SDK square root.
func BenchmarkMonotonicSqrt(b *testing.B) {
r := rand.New(rand.NewSource(1))
vectors := generateRandomDecForEachBitlen(r, 1)
vectors := generateRandomDecForEachBitlenDec(r, 1)
for i := 0; i < b.N; i++ {
for j := 0; j < len(vectors); j++ {
a, _ := MonotonicSqrt(vectors[j])
Expand Down

0 comments on commit 499055b

Please sign in to comment.