diff --git a/PENDING.md b/PENDING.md index 3dba15fc8115..bfe66c6e3385 100644 --- a/PENDING.md +++ b/PENDING.md @@ -45,6 +45,8 @@ ### SDK +* [\#3665] Overhaul sdk.Uint type in preparation for Coins's Int -> Uint migration. + ### Tendermint diff --git a/types/int.go b/types/int.go index e19809548f6f..907b492b7702 100644 --- a/types/int.go +++ b/types/int.go @@ -318,11 +318,6 @@ func (i Int) String() string { return i.i.String() } -// Testing purpose random Int generator -func randomInt(i Int) Int { - return NewIntFromBigInt(random(i.BigInt())) -} - // MarshalAmino defines custom encoding scheme func (i Int) MarshalAmino() (string, error) { if i.i == nil { // Necessary since default Uint initialization has i.i as nil @@ -355,256 +350,6 @@ func (i *Int) UnmarshalJSON(bz []byte) error { return unmarshalJSON(i.i, bz) } -// Int wraps integer with 256 bit range bound -// Checks overflow, underflow and division by zero -// Exists in range from 0 to 2^256-1 -type Uint struct { - i *big.Int -} - -// BigInt converts Uint to big.Unt -func (i Uint) BigInt() *big.Int { - return new(big.Int).Set(i.i) -} - -// NewUint constructs Uint from int64 -func NewUint(n uint64) Uint { - i := new(big.Int) - i.SetUint64(n) - return Uint{i} -} - -// NewUintFromBigUint constructs Uint from big.Uint -func NewUintFromBigInt(i *big.Int) Uint { - res := Uint{i} - if UintOverflow(res) { - panic("Uint overflow") - } - return res -} - -// NewUintFromString constructs Uint from string -func NewUintFromString(s string) (res Uint, ok bool) { - i, ok := newIntegerFromString(s) - if !ok { - return - } - // Check overflow - if i.Sign() == -1 || i.Sign() == 1 && i.BitLen() > 256 { - ok = false - return - } - return Uint{i}, true -} - -// NewUintWithDecimal constructs Uint with decimal -// Result value is n*10^dec -func NewUintWithDecimal(n uint64, dec int) Uint { - if dec < 0 { - panic("NewUintWithDecimal() decimal is negative") - } - exp := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dec)), nil) - i := new(big.Int) - i.Mul(new(big.Int).SetUint64(n), exp) - - res := Uint{i} - if UintOverflow(res) { - panic("NewUintWithDecimal() out of bound") - } - - return res -} - -// ZeroUint returns Uint value with zero -func ZeroUint() Uint { return Uint{big.NewInt(0)} } - -// OneUint returns Uint value with one -func OneUint() Uint { return Uint{big.NewInt(1)} } - -// Uint64 converts Uint to uint64 -// Panics if the value is out of range -func (i Uint) Uint64() uint64 { - if !i.i.IsUint64() { - panic("Uint64() out of bound") - } - return i.i.Uint64() -} - -// IsUint64 returns true if Uint64() not panics -func (i Uint) IsUint64() bool { - return i.i.IsUint64() -} - -// IsZero returns true if Uint is zero -func (i Uint) IsZero() bool { - return i.i.Sign() == 0 -} - -// Sign returns sign of Uint -func (i Uint) Sign() int { - return i.i.Sign() -} - -// Equal compares two Uints -func (i Uint) Equal(i2 Uint) bool { - return equal(i.i, i2.i) -} - -// GT returns true if first Uint is greater than second -func (i Uint) GT(i2 Uint) bool { - return gt(i.i, i2.i) -} - -// LT returns true if first Uint is lesser than second -func (i Uint) LT(i2 Uint) bool { - return lt(i.i, i2.i) -} - -// Add adds Uint from another -func (i Uint) Add(i2 Uint) (res Uint) { - res = Uint{add(i.i, i2.i)} - if UintOverflow(res) { - panic("Uint overflow") - } - return -} - -// AddRaw adds uint64 to Uint -func (i Uint) AddRaw(i2 uint64) Uint { - return i.Add(NewUint(i2)) -} - -// Sub subtracts Uint from another -func (i Uint) Sub(i2 Uint) (res Uint) { - res = Uint{sub(i.i, i2.i)} - if UintOverflow(res) { - panic("Uint overflow") - } - return -} - -// SafeSub attempts to subtract one Uint from another. A boolean is also returned -// indicating if the result contains integer overflow. -func (i Uint) SafeSub(i2 Uint) (Uint, bool) { - res := Uint{sub(i.i, i2.i)} - if UintOverflow(res) { - return res, true - } - - return res, false -} - -// SubRaw subtracts uint64 from Uint -func (i Uint) SubRaw(i2 uint64) Uint { - return i.Sub(NewUint(i2)) -} - -// Mul multiples two Uints -func (i Uint) Mul(i2 Uint) (res Uint) { - if i.i.BitLen()+i2.i.BitLen()-1 > 256 { - panic("Uint overflow") - } - - res = Uint{mul(i.i, i2.i)} - if UintOverflow(res) { - panic("Uint overflow") - } - - return -} - -// MulRaw multipies Uint and uint64 -func (i Uint) MulRaw(i2 uint64) Uint { - return i.Mul(NewUint(i2)) -} - -// Div divides Uint with Uint -func (i Uint) Div(i2 Uint) (res Uint) { - // Check division-by-zero - if i2.Sign() == 0 { - panic("division-by-zero") - } - return Uint{div(i.i, i2.i)} -} - -// Div divides Uint with uint64 -func (i Uint) DivRaw(i2 uint64) Uint { - return i.Div(NewUint(i2)) -} - -// Mod returns remainder after dividing with Uint -func (i Uint) Mod(i2 Uint) Uint { - if i2.Sign() == 0 { - panic("division-by-zero") - } - return Uint{mod(i.i, i2.i)} -} - -// ModRaw returns remainder after dividing with uint64 -func (i Uint) ModRaw(i2 uint64) Uint { - return i.Mod(NewUint(i2)) -} - -// Return the minimum of the Uints -func MinUint(i1, i2 Uint) Uint { - return Uint{min(i1.BigInt(), i2.BigInt())} -} - -// MaxUint returns the maximum between two unsigned integers. -func MaxUint(i, i2 Uint) Uint { - return Uint{max(i.BigInt(), i2.BigInt())} -} - -// Human readable string -func (i Uint) String() string { - return i.i.String() -} - -// Testing purpose random Uint generator -func randomUint(i Uint) Uint { - return NewUintFromBigInt(random(i.BigInt())) -} - -// MarshalAmino defines custom encoding scheme -func (i Uint) MarshalAmino() (string, error) { - if i.i == nil { // Necessary since default Uint initialization has i.i as nil - i.i = new(big.Int) - } - return marshalAmino(i.i) -} - -// UnmarshalAmino defines custom decoding scheme -func (i *Uint) UnmarshalAmino(text string) error { - if i.i == nil { // Necessary since default Uint initialization has i.i as nil - i.i = new(big.Int) - } - return unmarshalAmino(i.i, text) -} - -// MarshalJSON defines custom encoding scheme -func (i Uint) MarshalJSON() ([]byte, error) { - if i.i == nil { // Necessary since default Uint initialization has i.i as nil - i.i = new(big.Int) - } - return marshalJSON(i.i) -} - -// UnmarshalJSON defines custom decoding scheme -func (i *Uint) UnmarshalJSON(bz []byte) error { - if i.i == nil { // Necessary since default Uint initialization has i.i as nil - i.i = new(big.Int) - } - return unmarshalJSON(i.i, bz) -} - -//__________________________________________________________________________ - -// UintOverflow returns true if a given unsigned integer overflows and false -// otherwise. -func UintOverflow(x Uint) bool { - return x.i.Sign() == -1 || x.i.Sign() == 1 && x.i.BitLen() > 256 -} - // intended to be used with require/assert: require.True(IntEq(...)) func IntEq(t *testing.T, exp, got Int) (*testing.T, bool, string, string, string) { return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp.String(), got.String() diff --git a/types/int_test.go b/types/int_test.go index 20a2f3f61bf4..2fef894a4412 100644 --- a/types/int_test.go +++ b/types/int_test.go @@ -1,7 +1,6 @@ package types import ( - "math" "math/big" "math/rand" "strconv" @@ -73,45 +72,6 @@ func TestIntPanic(t *testing.T) { require.Panics(t, func() { i1.Div(NewInt(0)) }) } -func TestUintPanic(t *testing.T) { - // Max Uint = 1.15e+77 - // Min Uint = 0 - require.NotPanics(t, func() { NewUintWithDecimal(5, 76) }) - i1 := NewUintWithDecimal(5, 76) - require.NotPanics(t, func() { NewUintWithDecimal(10, 76) }) - i2 := NewUintWithDecimal(10, 76) - require.NotPanics(t, func() { NewUintWithDecimal(11, 76) }) - i3 := NewUintWithDecimal(11, 76) - - require.Panics(t, func() { NewUintWithDecimal(12, 76) }) - require.Panics(t, func() { NewUintWithDecimal(1, 80) }) - - // Overflow check - require.NotPanics(t, func() { i1.Add(i1) }) - require.Panics(t, func() { i2.Add(i2) }) - require.Panics(t, func() { i3.Add(i3) }) - - require.Panics(t, func() { i1.Mul(i1) }) - require.Panics(t, func() { i2.Mul(i2) }) - require.Panics(t, func() { i3.Mul(i3) }) - - // Underflow check - require.NotPanics(t, func() { i2.Sub(i1) }) - require.NotPanics(t, func() { i2.Sub(i2) }) - require.Panics(t, func() { i2.Sub(i3) }) - - // Bound check - uintmax := NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))) - uintmin := NewUint(0) - require.NotPanics(t, func() { uintmax.Add(ZeroUint()) }) - require.NotPanics(t, func() { uintmin.Sub(ZeroUint()) }) - require.Panics(t, func() { uintmax.Add(OneUint()) }) - require.Panics(t, func() { uintmin.Sub(OneUint()) }) - - // Division-by-zero check - require.Panics(t, func() { i1.Div(uintmin) }) -} - // Tests below uses randomness // Since we are using *big.Int as underlying value // and (U/)Int is immutable value(see TestImmutability(U/)Int) @@ -205,28 +165,6 @@ func TestCompInt(t *testing.T) { } } -func TestIdentUint(t *testing.T) { - for d := 0; d < 1000; d++ { - n := rand.Uint64() - i := NewUint(n) - - ifromstr, ok := NewUintFromString(strconv.FormatUint(n, 10)) - require.True(t, ok) - - cases := []uint64{ - i.Uint64(), - i.BigInt().Uint64(), - ifromstr.Uint64(), - NewUintFromBigInt(new(big.Int).SetUint64(n)).Uint64(), - NewUintWithDecimal(n, 0).Uint64(), - } - - for tcnum, tc := range cases { - require.Equal(t, n, tc, "Uint is modified during conversion. tc #%d", tcnum) - } - } -} - func minuint(i1, i2 uint64) uint64 { if i1 < i2 { return i1 @@ -241,71 +179,6 @@ func maxuint(i1, i2 uint64) uint64 { return i2 } -func TestArithUint(t *testing.T) { - for d := 0; d < 1000; d++ { - n1 := uint64(rand.Uint32()) - i1 := NewUint(n1) - n2 := uint64(rand.Uint32()) - i2 := NewUint(n2) - - cases := []struct { - ires Uint - nres uint64 - }{ - {i1.Add(i2), n1 + n2}, - {i1.Mul(i2), n1 * n2}, - {i1.Div(i2), n1 / n2}, - {i1.AddRaw(n2), n1 + n2}, - {i1.MulRaw(n2), n1 * n2}, - {i1.DivRaw(n2), n1 / n2}, - {MinUint(i1, i2), minuint(n1, n2)}, - {MaxUint(i1, i2), maxuint(n1, n2)}, - } - - for tcnum, tc := range cases { - require.Equal(t, tc.nres, tc.ires.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum) - } - - if n2 > n1 { - continue - } - - subs := []struct { - ires Uint - nres uint64 - }{ - {i1.Sub(i2), n1 - n2}, - {i1.SubRaw(n2), n1 - n2}, - } - - for tcnum, tc := range subs { - require.Equal(t, tc.nres, tc.ires.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum) - } - } -} - -func TestCompUint(t *testing.T) { - for d := 0; d < 1000; d++ { - n1 := rand.Uint64() - i1 := NewUint(n1) - n2 := rand.Uint64() - i2 := NewUint(n2) - - cases := []struct { - ires bool - nres bool - }{ - {i1.Equal(i2), n1 == n2}, - {i1.GT(i2), n1 > n2}, - {i1.LT(i2), n1 < n2}, - } - - for tcnum, tc := range cases { - require.Equal(t, tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum) - } - } -} - func randint() Int { return NewInt(rand.Int63()) } @@ -394,108 +267,6 @@ func TestImmutabilityArithInt(t *testing.T) { } } } -func TestImmutabilityAllUint(t *testing.T) { - ops := []func(*Uint){ - func(i *Uint) { _ = i.Add(NewUint(rand.Uint64())) }, - func(i *Uint) { _ = i.Sub(NewUint(rand.Uint64() % i.Uint64())) }, - func(i *Uint) { _ = i.Mul(randuint()) }, - func(i *Uint) { _ = i.Div(randuint()) }, - func(i *Uint) { _ = i.AddRaw(rand.Uint64()) }, - func(i *Uint) { _ = i.SubRaw(rand.Uint64() % i.Uint64()) }, - func(i *Uint) { _ = i.MulRaw(rand.Uint64()) }, - func(i *Uint) { _ = i.DivRaw(rand.Uint64()) }, - func(i *Uint) { _ = i.IsZero() }, - func(i *Uint) { _ = i.Sign() }, - func(i *Uint) { _ = i.Equal(randuint()) }, - func(i *Uint) { _ = i.GT(randuint()) }, - func(i *Uint) { _ = i.LT(randuint()) }, - func(i *Uint) { _ = i.String() }, - } - - for i := 0; i < 1000; i++ { - n := rand.Uint64() - ni := NewUint(n) - - for opnum, op := range ops { - op(&ni) - - require.Equal(t, n, ni.Uint64(), "Uint is modified by operation. #%d", opnum) - require.Equal(t, NewUint(n), ni, "Uint is modified by operation. #%d", opnum) - } - } -} - -type uintop func(Uint, *big.Int) (Uint, *big.Int) - -func uintarith(uifn func(Uint, Uint) Uint, bifn func(*big.Int, *big.Int, *big.Int) *big.Int, sub bool) uintop { - return func(ui Uint, bi *big.Int) (Uint, *big.Int) { - r := rand.Uint64() - if sub && ui.IsUint64() { - if ui.IsZero() { - return ui, bi - } - r = r % ui.Uint64() - } - ur := NewUint(r) - br := new(big.Int).SetUint64(r) - return uifn(ui, ur), bifn(new(big.Int), bi, br) - } -} - -func uintarithraw(uifn func(Uint, uint64) Uint, bifn func(*big.Int, *big.Int, *big.Int) *big.Int, sub bool) uintop { - return func(ui Uint, bi *big.Int) (Uint, *big.Int) { - r := rand.Uint64() - if sub && ui.IsUint64() { - if ui.IsZero() { - return ui, bi - } - r = r % ui.Uint64() - } - br := new(big.Int).SetUint64(r) - mui := ui.ModRaw(math.MaxUint64) - mbi := new(big.Int).Mod(bi, new(big.Int).SetUint64(math.MaxUint64)) - return uifn(mui, r), bifn(new(big.Int), mbi, br) - } -} - -func TestImmutabilityArithUint(t *testing.T) { - size := 500 - - ops := []uintop{ - uintarith(Uint.Add, (*big.Int).Add, false), - uintarith(Uint.Sub, (*big.Int).Sub, true), - uintarith(Uint.Mul, (*big.Int).Mul, false), - uintarith(Uint.Div, (*big.Int).Div, false), - uintarithraw(Uint.AddRaw, (*big.Int).Add, false), - uintarithraw(Uint.SubRaw, (*big.Int).Sub, true), - uintarithraw(Uint.MulRaw, (*big.Int).Mul, false), - uintarithraw(Uint.DivRaw, (*big.Int).Div, false), - } - - for i := 0; i < 100; i++ { - uis := make([]Uint, size) - bis := make([]*big.Int, size) - - n := rand.Uint64() - ui := NewUint(n) - bi := new(big.Int).SetUint64(n) - - for j := 0; j < size; j++ { - op := ops[rand.Intn(len(ops))] - uis[j], bis[j] = op(ui, bi) - } - - for j := 0; j < size; j++ { - require.Equal(t, 0, bis[j].Cmp(uis[j].BigInt()), "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) - require.Equal(t, NewUintFromBigInt(bis[j]), uis[j], "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) - require.True(t, uis[j].i != bis[j], "Pointer addresses are equal. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) - } - } -} - -func randuint() Uint { - return NewUint(rand.Uint64()) -} func TestEncodingRandom(t *testing.T) { for i := 0; i < 1000; i++ { @@ -607,31 +378,6 @@ func TestEncodingTableUint(t *testing.T) { } } -func TestSafeSub(t *testing.T) { - testCases := []struct { - x, y Uint - expected uint64 - overflow bool - }{ - {NewUint(0), NewUint(0), 0, false}, - {NewUint(10), NewUint(5), 5, false}, - {NewUint(5), NewUint(10), 5, true}, - {NewUint(math.MaxUint64), NewUint(0), math.MaxUint64, false}, - } - - for i, tc := range testCases { - res, overflow := tc.x.SafeSub(tc.y) - require.Equal( - t, tc.overflow, overflow, - "invalid overflow result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i, - ) - require.Equal( - t, tc.expected, res.BigInt().Uint64(), - "invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i, - ) - } -} - func TestSerializationOverflow(t *testing.T) { bx, _ := new(big.Int).SetString("91888242871839275229946405745257275988696311157297823662689937894645226298583", 10) x := Int{bx} diff --git a/types/uint.go b/types/uint.go new file mode 100644 index 000000000000..cdbba52174fd --- /dev/null +++ b/types/uint.go @@ -0,0 +1,172 @@ +package types + +import ( + "errors" + "fmt" + "math/big" +) + +// Uint wraps integer with 256 bit range bound +// Checks overflow, underflow and division by zero +// Exists in range from 0 to 2^256-1 +type Uint struct { + i *big.Int +} + +// NewUintFromBigUint constructs Uint from big.Uint +func NewUintFromBigInt(i *big.Int) Uint { + u, err := checkNewUint(i) + if err != nil { + panic(fmt.Errorf("overflow: %s", err)) + } + return u +} + +// NewUint constructs Uint from int64 +func NewUint(n uint64) Uint { + i := new(big.Int) + i.SetUint64(n) + return NewUintFromBigInt(i) +} + +// NewUintFromString constructs Uint from string +func NewUintFromString(s string) Uint { + u, err := ParseUint(s) + if err != nil { + panic(err) + } + return u +} + +// ZeroUint returns unsigned zero. +func ZeroUint() Uint { return Uint{big.NewInt(0)} } + +// OneUint returns Uint value with one. +func OneUint() Uint { return Uint{big.NewInt(1)} } + +// Uint64 converts Uint to uint64 +// Panics if the value is out of range +func (u Uint) Uint64() uint64 { + if !u.i.IsUint64() { + panic("Uint64() out of bound") + } + return u.i.Uint64() +} + +// IsZero returns 1 if the uint equals to 0. +func (u Uint) IsZero() bool { return u.Equal(ZeroUint()) } + +// Equal compares two Uints +func (u Uint) Equal(u2 Uint) bool { return equal(u.i, u2.i) } + +// GT returns true if first Uint is greater than second +func (u Uint) GT(u2 Uint) bool { return gt(u.i, u2.i) } + +// GTE returns true if first Uint is greater than second +func (u Uint) GTE(u2 Uint) bool { return u.GT(u2) || u.Equal(u2) } + +// LT returns true if first Uint is lesser than second +func (u Uint) LT(u2 Uint) bool { return lt(u.i, u2.i) } + +// LTE returns true if first Uint is lesser than or equal to the second +func (u Uint) LTE(u2 Uint) bool { return !u.GTE(u2) } + +// Add adds Uint from another +func (u Uint) Add(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Add(u.i, u2.i)) } + +// Add convert uint64 and add it to Uint +func (u Uint) AddUint64(u2 uint64) Uint { return u.Add(NewUint(u2)) } + +// Sub adds Uint from another +func (u Uint) Sub(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Sub(u.i, u2.i)) } + +// SubUint64 adds Uint from another +func (u Uint) SubUint64(u2 uint64) Uint { return u.Sub(NewUint(u2)) } + +// Mul multiplies two Uints +func (u Uint) Mul(u2 Uint) (res Uint) { + return NewUintFromBigInt(new(big.Int).Mul(u.i, u2.i)) +} + +// Mul multiplies two Uints +func (u Uint) MulUint64(u2 uint64) (res Uint) { return u.Mul(NewUint(u2)) } + +// Div divides Uint with Uint +func (u Uint) Div(u2 Uint) (res Uint) { return NewUintFromBigInt(div(u.i, u2.i)) } + +// Div divides Uint with uint64 +func (u Uint) DivUint64(u2 uint64) Uint { return u.Div(NewUint(u2)) } + +// Return the minimum of the Uints +func MinUint(u1, u2 Uint) Uint { return NewUintFromBigInt(min(u1.i, u2.i)) } + +// Return the maximum of the Uints +func MaxUint(u1, u2 Uint) Uint { return NewUintFromBigInt(max(u1.i, u2.i)) } + +// Human readable string +func (u Uint) String() string { return u.i.String() } + +// Testing purpose random Uint generator +func randomUint(u Uint) Uint { return NewUintFromBigInt(random(u.i)) } + +// MarshalAmino defines custom encoding scheme +func (u Uint) MarshalAmino() (string, error) { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return marshalAmino(u.i) +} + +// UnmarshalAmino defines custom decoding scheme +func (u *Uint) UnmarshalAmino(text string) error { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return unmarshalAmino(u.i, text) +} + +// MarshalJSON defines custom encoding scheme +func (u Uint) MarshalJSON() ([]byte, error) { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return marshalJSON(u.i) +} + +// UnmarshalJSON defines custom decoding scheme +func (u *Uint) UnmarshalJSON(bz []byte) error { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return unmarshalJSON(u.i, bz) +} + +//__________________________________________________________________________ + +// UintOverflow returns true if a given unsigned integer overflows and false +// otherwise. +func UintOverflow(i *big.Int) error { + if i.Sign() < 0 { + return errors.New("non-positive integer") + } + if i.BitLen() > 256 { + return fmt.Errorf("bit length %d greater than 256", i.BitLen()) + } + return nil +} + +// ParseUint reads a string-encoded Uint value and return a Uint. +func ParseUint(s string) (Uint, error) { + i, ok := new(big.Int).SetString(s, 0) + if !ok { + return Uint{}, fmt.Errorf("cannot convert %q to big.Int", s) + } + return checkNewUint(i) +} + +func checkNewUint(i *big.Int) (Uint, error) { + if err := UintOverflow(i); err != nil { + return Uint{}, err + } + return Uint{i}, nil +} diff --git a/types/uint_test.go b/types/uint_test.go new file mode 100644 index 000000000000..8c1c7a8e8dff --- /dev/null +++ b/types/uint_test.go @@ -0,0 +1,253 @@ +package types + +import ( + "math" + "math/big" + "math/rand" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUintPanics(t *testing.T) { + // Max Uint = 1.15e+77 + // Min Uint = 0 + u1 := NewUint(0) + u2 := OneUint() + + require.Equal(t, uint64(0), u1.Uint64()) + require.Equal(t, uint64(1), u2.Uint64()) + + require.Panics(t, func() { NewUintFromBigInt(big.NewInt(-5)) }) + require.Panics(t, func() { NewUintFromString("-1") }) + require.NotPanics(t, func() { + require.True(t, NewUintFromString("0").Equal(ZeroUint())) + require.True(t, NewUintFromString("5").Equal(NewUint(5))) + }) + + // Overflow check + require.True(t, u1.Add(u1).Equal(ZeroUint())) + require.True(t, u1.Add(OneUint()).Equal(OneUint())) + require.Equal(t, uint64(0), u1.Uint64()) + require.Equal(t, uint64(1), OneUint().Uint64()) + require.Panics(t, func() { u1.SubUint64(2) }) + require.True(t, u1.SubUint64(0).Equal(ZeroUint())) + require.True(t, u2.Add(OneUint()).Sub(OneUint()).Equal(OneUint())) // i2 == 1 + require.True(t, u2.Add(OneUint()).Mul(NewUint(5)).Equal(NewUint(10))) // i2 == 10 + require.True(t, NewUint(7).Div(NewUint(2)).Equal(NewUint(3))) + require.True(t, NewUint(0).Div(NewUint(2)).Equal(ZeroUint())) + require.True(t, NewUint(5).MulUint64(4).Equal(NewUint(20))) + require.True(t, NewUint(5).MulUint64(0).Equal(ZeroUint())) + + uintmax := NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))) + uintmin := ZeroUint() + + // divs by zero + require.Panics(t, func() { OneUint().Mul(ZeroUint().SubUint64(uint64(1))) }) + require.Panics(t, func() { OneUint().DivUint64(0) }) + require.Panics(t, func() { OneUint().Div(ZeroUint()) }) + require.Panics(t, func() { ZeroUint().DivUint64(0) }) + require.Panics(t, func() { OneUint().Div(ZeroUint().Sub(OneUint())) }) + require.Panics(t, func() { uintmax.Add(OneUint()) }) + require.Panics(t, func() { uintmin.Sub(OneUint()) }) + + require.Equal(t, uint64(0), MinUint(ZeroUint(), OneUint()).Uint64()) + require.Equal(t, uint64(1), MaxUint(ZeroUint(), OneUint()).Uint64()) + + // comparison ops + require.True(t, + OneUint().GT(ZeroUint()), + ) + require.False(t, + OneUint().LT(ZeroUint()), + ) + require.True(t, + OneUint().GTE(ZeroUint()), + ) + require.False(t, + OneUint().LTE(ZeroUint()), + ) + + require.False(t, ZeroUint().GT(OneUint())) + require.True(t, ZeroUint().LT(OneUint())) + require.False(t, ZeroUint().GTE(OneUint())) + require.True(t, ZeroUint().LTE(OneUint())) +} + +func TestIdentUint(t *testing.T) { + for d := 0; d < 1000; d++ { + n := rand.Uint64() + i := NewUint(n) + + ifromstr := NewUintFromString(strconv.FormatUint(n, 10)) + + cases := []uint64{ + i.Uint64(), + i.i.Uint64(), + ifromstr.Uint64(), + NewUintFromBigInt(new(big.Int).SetUint64(n)).Uint64(), + } + + for tcnum, tc := range cases { + require.Equal(t, n, tc, "Uint is modified during conversion. tc #%d", tcnum) + } + } +} + +func TestArithUint(t *testing.T) { + for d := 0; d < 1000; d++ { + n1 := uint64(rand.Uint32()) + u1 := NewUint(n1) + n2 := uint64(rand.Uint32()) + u2 := NewUint(n2) + + cases := []struct { + ures Uint + nres uint64 + }{ + {u1.Add(u2), n1 + n2}, + {u1.Mul(u2), n1 * n2}, + {u1.Div(u2), n1 / n2}, + {u1.AddUint64(n2), n1 + n2}, + {u1.MulUint64(n2), n1 * n2}, + {u1.DivUint64(n2), n1 / n2}, + {MinUint(u1, u2), minuint(n1, n2)}, + {MaxUint(u1, u2), maxuint(n1, n2)}, + } + + for tcnum, tc := range cases { + require.Equal(t, tc.nres, tc.ures.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum) + } + + if n2 > n1 { + n1, n2 = n2, n1 + u1, u2 = NewUint(n1), NewUint(n2) + } + + subs := []struct { + ures Uint + nres uint64 + }{ + {u1.Sub(u2), n1 - n2}, + {u1.SubUint64(n2), n1 - n2}, + } + + for tcnum, tc := range subs { + require.Equal(t, tc.nres, tc.ures.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum) + } + } +} + +func TestCompUint(t *testing.T) { + for d := 0; d < 1000; d++ { + n1 := rand.Uint64() + i1 := NewUint(n1) + n2 := rand.Uint64() + i2 := NewUint(n2) + + cases := []struct { + ires bool + nres bool + }{ + {i1.Equal(i2), n1 == n2}, + {i1.GT(i2), n1 > n2}, + {i1.LT(i2), n1 < n2}, + {i1.GTE(i2), !i1.LT(i2)}, + {!i1.GTE(i2), i1.LT(i2)}, + } + + for tcnum, tc := range cases { + require.Equal(t, tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum) + } + } +} + +func TestImmutabilityAllUint(t *testing.T) { + ops := []func(*Uint){ + func(i *Uint) { _ = i.Add(NewUint(rand.Uint64())) }, + func(i *Uint) { _ = i.Sub(NewUint(rand.Uint64() % i.Uint64())) }, + func(i *Uint) { _ = i.Mul(randuint()) }, + func(i *Uint) { _ = i.Div(randuint()) }, + func(i *Uint) { _ = i.AddUint64(rand.Uint64()) }, + func(i *Uint) { _ = i.SubUint64(rand.Uint64() % i.Uint64()) }, + func(i *Uint) { _ = i.MulUint64(rand.Uint64()) }, + func(i *Uint) { _ = i.DivUint64(rand.Uint64()) }, + func(i *Uint) { _ = i.IsZero() }, + func(i *Uint) { _ = i.Equal(randuint()) }, + func(i *Uint) { _ = i.GT(randuint()) }, + func(i *Uint) { _ = i.GTE(randuint()) }, + func(i *Uint) { _ = i.LT(randuint()) }, + func(i *Uint) { _ = i.LTE(randuint()) }, + func(i *Uint) { _ = i.String() }, + } + + for i := 0; i < 1000; i++ { + n := rand.Uint64() + ni := NewUint(n) + + for opnum, op := range ops { + op(&ni) + + require.Equal(t, n, ni.Uint64(), "Uint is modified by operation. #%d", opnum) + require.Equal(t, NewUint(n), ni, "Uint is modified by operation. #%d", opnum) + } + } +} + +func TestSafeSub(t *testing.T) { + testCases := []struct { + x, y Uint + expected uint64 + panic bool + }{ + {NewUint(0), NewUint(0), 0, false}, + {NewUint(10), NewUint(5), 5, false}, + {NewUint(5), NewUint(10), 5, true}, + {NewUint(math.MaxUint64), NewUint(0), math.MaxUint64, false}, + } + + for i, tc := range testCases { + if tc.panic { + require.Panics(t, func() { tc.x.Sub(tc.y) }) + continue + } + require.Equal( + t, tc.expected, tc.x.Sub(tc.y).Uint64(), + "invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i, + ) + } +} + +func TestParseUint(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want Uint + wantErr bool + }{ + {"malformed", args{"malformed"}, Uint{}, true}, + {"empty", args{""}, Uint{}, true}, + {"positive", args{"50"}, NewUint(uint64(50)), false}, + {"negative", args{"-1"}, Uint{}, true}, + {"zero", args{"0"}, ZeroUint(), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseUint(tt.args.s) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.True(t, got.Equal(tt.want)) + }) + } +} + +func randuint() Uint { + return NewUint(rand.Uint64()) +}