Skip to content

Commit

Permalink
fix: deprecated PowInt, replaced by PowInt32 (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
quagmt authored Nov 14, 2024
1 parent a5482f2 commit 1307249
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 5 deletions.
93 changes: 89 additions & 4 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ var (
// ErrInvalidBinaryData is returned when unmarshalling invalid binary data
// The binary data should follow the format as described in MarshalBinary
ErrInvalidBinaryData = fmt.Errorf("invalid binary data")

// ErrZeroPowNegative is returned when raising zero to a negative power
ErrZeroPowNegative = fmt.Errorf("can't raise zero to a negative power")
)

var (
Expand Down Expand Up @@ -1141,11 +1144,23 @@ func trailingZerosU128(n u128) uint8 {
return zeros
}

// PowInt returns d^e where e is an integer.
// Deprecated: Use PowInt32 instead for correct handling of 0^0 and negative exponents.
// This function treats 0 raised to any power as 0, which may not align with mathematical conventions
// but is practical in certain cases. See: https://github.com/quagmt/udecimal/issues/25.
//
// PowInt raises the decimal d to the integer power e (d^e).
//
// Special cases:
// - 0^e = 0 for any integer e
// - d^0 = 1 for any decimal d ≠ 0
//
// Examples:
//
// PowInt(2.5, 2) = 6.25
// PowInt(0, 0) = 0
// PowInt(0, 1) = 0
// PowInt(0, -1) = 0
// PowInt(2, 0) = 1
// PowInt(2.5, 2) = 6.25
// PowInt(2.5, -2) = 0.16
func (d Decimal) PowInt(e int) Decimal {
// check 0 first to avoid 0^0 = 1
Expand Down Expand Up @@ -1196,6 +1211,76 @@ func (d Decimal) PowInt(e int) Decimal {
return newDecimal(neg, bintFromBigInt(qBig), uint8(powPrecision))
}

// PowInt32 returns d raised to the power of e, where e is an int32.
//
// Returns:
//
// The result of d raised to the power of e.
// An error if d is zero and e is a negative integer.
//
// Special cases:
//
// 0^0 = 1
// 0^(any negative integer) results in an error
//
// Examples:
//
// PowInt32(0, 0) = 1
// PowInt32(2, 0) = 1
// PowInt32(0, 1) = 0
// PowInt32(0, -1) results in an error
// PowInt32(2.5, 2) = 6.25
// PowInt32(2.5, -2) = 0.16
func (d Decimal) PowInt32(e int32) (Decimal, error) {
// special case: 0 raised to a negative power
if d.coef.IsZero() && e < 0 {
return Decimal{}, ErrZeroPowNegative
}

if e == 0 {
return One, nil
}

if e == 1 {
return d, nil
}

// Rescale first to remove trailing zeros
dTrim := d.trimTrailingZeros()

if e < 0 {
return dTrim.powIntInverse(int(-e)), nil
}

// e > 1 && d != 0
q, err := dTrim.tryPowIntU128(int(e))
if err == nil {
return q, nil
}

// overflow, fallback to big.Int
dBig := dTrim.coef.GetBig()

var factor int32
powPrecision := int32(dTrim.prec) * e
if powPrecision >= int32(defaultPrec) {
factor = powPrecision - int32(defaultPrec)
powPrecision = int32(defaultPrec)
}

m := new(big.Int).Exp(bigTen, big.NewInt(int64(factor)), nil)
dBig = new(big.Int).Exp(dBig, big.NewInt(int64(e)), nil)
qBig := dBig.Quo(dBig, m)

neg := d.neg
if e%2 == 0 {
neg = false
}

//nolint:gosec
return newDecimal(neg, bintFromBigInt(qBig), uint8(powPrecision)), nil
}

// powIntInverse returns d^(-e), with e > 0
func (d Decimal) powIntInverse(e int) Decimal {
q, err := d.tryInversePowIntU128(e)
Expand Down Expand Up @@ -1279,8 +1364,8 @@ func (d Decimal) tryInversePowIntU128(e int) (Decimal, error) {
return Decimal{}, errOverflow
}

if d.coef.u128.hi != 0 && e >= 3 {
// e > 3 and u128.hi != 0 means the result will >= 2^192,
if d.coef.u128.hi != 0 && e >= 4 {
// e >= 4 and u128.hi != 0 means the result will >= 2^256,
// which we can't use fast division. So we need to use big.Int instead
return Decimal{}, errOverflow
}
Expand Down
127 changes: 127 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2243,6 +2243,133 @@ func TestRandomPow(t *testing.T) {
}
}

func TestPowInt32(t *testing.T) {
testcases := []struct {
a string
b int32
want string
wantErr error
}{
{"123456789012345678901234567890123456789.9999999999999999999", 2, "15241578753238836750495351562566681945252248135650053345652796829976527968319.753086421975308642", nil},
{"0.5", -14, "16384", nil},
{"5", -18, "0.000000000000262144", nil},
{"-96", 384, "155651563400161893689540829251750532876602528021691915200061141022544075854496838643052295888420136905906567539126502582243693732125449523059780613380755061052491943449381255863820131332142779769865996188291542971996702478765598563482106934995948481892528830806840727897892513634949541154348143236794203399068607458789100280733156671481421737413484548654754828937861442964361485155011834501441449057827522043722520499866143913624005535732240536689495728164138830318329923569260213567200238743687906030695515032990022513102670332644203639546984105586335760789206424524917450457774575904047665710191104154700220406574406611422191187238002842748820651406984670104474060413271629299557918370269495849383625416400964818595369246834495413046931303826618633216386400256", nil},
{"-70", -8, "0.0000000000000017346", nil},
{"0.12", 100, "0", nil},
{"0", 0, "1", nil},
{"0", -1, "0", ErrZeroPowNegative},
{"0", 1, "0", nil},
{"0", 10, "0", nil},
{"1.12345", 4, "1.5929971334827095062", nil},
{"123456789012345678901234567890123456789.9999999999999999999", 0, "1", nil},
{"123456789012345678901234567890123456789.9999999999999999999", 1, "123456789012345678901234567890123456789.9999999999999999999", nil},
{"1.5", 3, "3.375", nil},
{"1.12345", 1, "1.12345", nil},
{"1.12345", 2, "1.2621399025", nil},
{"1.12345", 3, "1.417951073463625", nil},
{"1.12345", 4, "1.5929971334827095062", nil},
{"1.12345", 5, "1.7896526296111499947", nil},
{"1.12345", 6, "2.0105852467366464616", nil},
{"1.12345", 7, "2.2587919954462854673", nil},
{"-1.12345", 4, "1.5929971334827095062", nil},
}

for _, tc := range testcases {
t.Run(fmt.Sprintf("%s.pow(%d)", tc.a, tc.b), func(t *testing.T) {
a, err := Parse(tc.a)
require.NoError(t, err)

aStr := a.String()

b, err := a.PowInt32(tc.b)
if tc.wantErr != nil {
require.Equal(t, tc.wantErr, err)
return
}

require.Equal(t, tc.want, b.String())

// make sure a is immutable
require.Equal(t, aStr, a.String())

// cross check with shopspring/decimal

aa := decimal.RequireFromString(tc.a)
aa, err = aa.PowWithPrecision(decimal.New(int64(tc.b), 0), int32(b.prec)+4)

// special case for 0^0
// udecimal: 0^0 = 1
// shopspring/decimal: 0^0 is undefined and will return an error
if tc.a == "0" && tc.b == 0 {
require.EqualError(t, err, "cannot represent undefined value of 0**0")
return
}

require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String())
})
}
}

func TestRandomPowInt32(t *testing.T) {
inputs := []string{
"0.1234",
"-0.1234",
"1.123456789012345679",
"-1.123456789012345679",
"1.12345",
"-1.12345",
"123456789012345678901234567890123456789.9999999999999999999",
"123456789012345678901234567890123456789.9999999999999999999",
"1.5",
"123456.789",
"123.4",
"1234567890123456789.1234567890123456789",
"-1234567890123456789.1234567890123456789",
}

for _, input := range inputs {
t.Run(fmt.Sprintf("pow(%s)", input), func(t *testing.T) {
a := MustParse(input)

for i := 0; i <= 1000; i++ {
b, err := a.PowInt32(int32(i))
require.NoError(t, err)

aa := decimal.RequireFromString(input)
aa, err = aa.PowWithPrecision(decimal.New(int64(i), 0), int32(b.prec)+4)
require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String(), "%s.pow(%d)", input, i)
}
})
}

for _, input := range inputs {
t.Run(fmt.Sprintf("powInverse(%s)", input), func(t *testing.T) {
a := MustParse(input)

for i := 0; i >= -100; i-- {
b, err := a.PowInt32(int32(i))
require.NoError(t, err)

aa := decimal.RequireFromString(input)
aa, err = aa.PowWithPrecision(decimal.New(int64(i), 0), int32(b.prec)+4)
require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String(), "%s.pow(%d)", input, i)
}
})
}
}

func TestSqrt(t *testing.T) {
testcases := []struct {
a string
Expand Down
14 changes: 14 additions & 0 deletions doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,20 @@ func ExampleDecimal_PowInt() {
// 0.6609822195782933439
}

func ExampleDecimal_PowInt32() {
fmt.Println(MustParse("1.23").PowInt32(2))
fmt.Println(MustParse("1.23").PowInt32(0))
fmt.Println(MustParse("1.23").PowInt32(-2))
fmt.Println(MustParse("0").PowInt32(0))
fmt.Println(MustParse("0").PowInt32(-2))
// Output:
// 1.5129 <nil>
// 1 <nil>
// 0.6609822195782933439 <nil>
// 1 <nil>
// 0 can't raise zero to a negative power
}

func ExampleDecimal_Prec() {
fmt.Println(MustParse("1.23").Prec())
// Output:
Expand Down
51 changes: 50 additions & 1 deletion fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ func FuzzTrunc(f *testing.F) {
})
}

func FuzzPowInt(f *testing.F) {
func FuzzDepcrecatedPowInt(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int())
}
Expand Down Expand Up @@ -605,6 +605,55 @@ func FuzzPowInt(f *testing.F) {
})
}

func FuzzPowInt32(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int())
}

f.Fuzz(func(t *testing.T, aneg bool, ahi uint64, alo uint64, aprec uint8, pow int) {
a, err := NewFromHiLo(aneg, ahi, alo, aprec)
if err == ErrPrecOutOfRange {
t.Skip()
} else {
require.NoError(t, err)
}

// use pow less than 10000
p := pow % 10000

c, err := a.PowInt32(int32(p))
if a.IsZero() && p < 0 {
require.Equal(t, err, ErrDivideByZero)
return
}

if c.coef.overflow() {
require.NotNil(t, c.coef.bigInt)
require.Equal(t, u128{}, c.coef.u128)
} else {
require.Nil(t, c.coef.bigInt)
}

// compare with shopspring/decimal
aa := ssDecimal(aneg, ahi, alo, aprec)
aa, err = aa.PowWithPrecision(ss.New(int64(p), 0), int32(c.prec)+4)

// special case for 0^0
// udecimal: 0^0 = 1
// shopspring/decimal: 0^0 is undefined and will return an error
if a.IsZero() && p == 0 {
require.EqualError(t, err, "cannot represent undefined value of 0**0")
require.Equal(t, "1", c.String())
return
}

require.NoError(t, err)
aa = aa.Truncate(int32(c.prec))

require.Equal(t, aa.String(), c.String(), "powInt %s %d", a, p)
})
}

func FuzzPowNegative(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int64())
Expand Down

0 comments on commit 1307249

Please sign in to comment.