diff --git a/types/math/dec.go b/types/math/dec.go index 0212a66cde..1ab6bec2fd 100644 --- a/types/math/dec.go +++ b/types/math/dec.go @@ -5,6 +5,7 @@ import ( "math/big" "github.com/cockroachdb/apd/v2" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/errors" ) @@ -205,6 +206,27 @@ func (x Dec) BigInt() (*big.Int, error) { return z, nil } +// SdkIntTrim rounds decimal number to the integer towards zero and converts it to `sdk.Int`. +// Panics if x is bigger the SDK Int max value +func (x Dec) SdkIntTrim() sdk.Int { + y, _ := x.Reduce() + var r = y.dec.Coeff + if y.dec.Exponent != 0 { + decs := big.NewInt(10) + if y.dec.Exponent > 0 { + decs.Exp(decs, big.NewInt(int64(y.dec.Exponent)), nil) + r.Mul(&y.dec.Coeff, decs) + } else { + decs.Exp(decs, big.NewInt(int64(-y.dec.Exponent)), nil) + r.Quo(&y.dec.Coeff, decs) + } + } + if x.dec.Negative { + r.Neg(&r) + } + return sdk.NewIntFromBigInt(&r) +} + func (x Dec) String() string { return x.dec.Text('f') } diff --git a/types/math/dec_bench_test.go b/types/math/dec_bench_test.go new file mode 100644 index 0000000000..9537b7899b --- /dev/null +++ b/types/math/dec_bench_test.go @@ -0,0 +1,61 @@ +package math + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +func BenchmarkSdkIntTrim(b *testing.B) { + s := "12345678901234567890.12345678901234567890" + d, err := NewDecFromString(s) + if err != nil { + b.Error("can't convert test number") + } + + b.Run("exp", func(b *testing.B) { + for n := 0; n < b.N; n++ { + d.SdkIntTrim() + } + }) + + b.Run("quo-integer", func(b *testing.B) { + for n := 0; n < b.N; n++ { + sdkIntTrimQuo(d) + } + }) + + b.Run("string", func(b *testing.B) { + for n := 0; n < b.N; n++ { + sdkIntTrimNaive(d) + } + }) + +} + +func sdkIntTrimQuo(d Dec) sdk.Int { + d, err := d.QuoInteger(NewDecFromInt64(1)) + if err != nil { + panic(err) + } + + i, err := d.BigInt() + if err != nil { + panic(err) + } + return sdk.NewIntFromBigInt(i) +} + +func sdkIntTrimNaive(d Dec) sdk.Int { + d, err := d.QuoInteger(NewDecFromInt64(1)) + if err != nil { + panic(err) + } + + s := d.String() + i, ok := sdk.NewIntFromString(s) + if !ok { + panic("can't convert from string") + } + return i +} diff --git a/types/math/dec_test.go b/types/math/dec_test.go index d0fc0f6cdf..c02ce47f23 100644 --- a/types/math/dec_test.go +++ b/types/math/dec_test.go @@ -644,27 +644,51 @@ func TestQuoExactBad(t *testing.T) { } func TestToBigInt(t *testing.T) { - intStr := "1000000000000000000000000000000000000123456789" - a, err := NewDecFromString(intStr) - require.NoError(t, err) - b, err := a.BigInt() - require.Equal(t, intStr, b.String()) - - intStrWithTrailingZeros := "1000000000000000000000000000000000000123456789.00000000" - a, err = NewDecFromString(intStrWithTrailingZeros) - require.NoError(t, err) - b, err = a.BigInt() - require.Equal(t, intStr, b.String()) - - intStr2 := "123.456e6" - a, err = NewDecFromString(intStr2) - require.NoError(t, err) - b, err = a.BigInt() - require.Equal(t, "123456000", b.String()) + i1 := "1000000000000000000000000000000000000123456789" + tcs := []struct { + intStr string + out string + isError error + }{ + {i1, i1, nil}, + {"1000000000000000000000000000000000000123456789.00000000", i1, nil}, + {"123.456e6", "123456000", nil}, + {"12345.6", "", ErrNonIntegeral}, + } + for idx, tc := range tcs { + a, err := NewDecFromString(tc.intStr) + require.NoError(t, err) + b, err := a.BigInt() + if tc.isError == nil { + require.NoError(t, err, "test_%d", idx) + require.Equal(t, tc.out, b.String(), "test_%d", idx) + } else { + require.ErrorIs(t, err, tc.isError, "test_%d", idx) + } + } +} - intStr3 := "12345.6" - a, err = NewDecFromString(intStr3) - require.NoError(t, err) - _, err = a.BigInt() - require.ErrorIs(t, err, ErrNonIntegeral) +func TestToSdkInt(t *testing.T) { + i1 := "1000000000000000000000000000000000000123456789" + tcs := []struct { + intStr string + out string + }{ + {i1, i1}, + {"1000000000000000000000000000000000000123456789.00000000", i1}, + {"123.456e6", "123456000"}, + {"123.456e1", "1234"}, + {"123.456", "123"}, + {"123.956", "123"}, + {"-123.456", "-123"}, + {"-123.956", "-123"}, + {"-0.956", "0"}, + {"-0.9", "0"}, + } + for idx, tc := range tcs { + a, err := NewDecFromString(tc.intStr) + require.NoError(t, err) + b := a.SdkIntTrim() + require.Equal(t, tc.out, b.String(), "test_%d", idx) + } }