diff --git a/uint256.go b/uint256.go index 5faa400f..5c10cbd8 100644 --- a/uint256.go +++ b/uint256.go @@ -190,8 +190,12 @@ func (z *Int) AddOverflow(x, y *Int) bool { return carry != 0 } -// AddMod sets z to the sum ( x+y ) mod m, and returns z +// AddMod sets z to the sum ( x+y ) mod m, and returns z. +// If m == 0, z is set to 0 (OBS: differs from the big.Int) func (z *Int) AddMod(x, y, m *Int) *Int { + if m.IsZero() { + return z.Clear() + } if z == m { // z is an alias for m // TODO: Understand why needed and add tests for all "division" methods. m = m.Clone() } @@ -567,8 +571,12 @@ func (z *Int) SMod(x, y *Int) *Int { } // MulMod calculates the modulo-m multiplication of x and y and -// returns z +// returns z. +// If m == 0, z is set to 0 (OBS: differs from the big.Int) func (z *Int) MulMod(x, y, m *Int) *Int { + if x.IsZero() || y.IsZero() || m.IsZero() { + return z.Clear() + } p := umul(x, y) var ( pl Int diff --git a/uint256_test.go b/uint256_test.go index 35a17d85..ff6622fa 100644 --- a/uint256_test.go +++ b/uint256_test.go @@ -74,6 +74,10 @@ var ( // A collection of interesting input values for ternary operators (addmod, mulmod). ternTestCases = [][3]string{ + {"0", "0", "0"}, + {"1", "0", "0"}, + {"1", "1", "0"}, + {"0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", "0"}, {"0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, {"0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "3", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, {"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}, @@ -920,9 +924,22 @@ func TestTernOp(t *testing.T) { } } } - - t.Run("AddMod", func(t *testing.T) { proc(t, (*Int).AddMod, addMod) }) - t.Run("MulMod", func(t *testing.T) { proc(t, (*Int).MulMod, mulMod) }) + t.Run("AddMod", func(t *testing.T) { + proc(t, (*Int).AddMod, func(z, x, y, m *big.Int) *big.Int { + if m.Sign() == 0 { + return z.SetUint64(0) + } + return addMod(z, x, y, m) + }) + }) + t.Run("MulMod", func(t *testing.T) { + proc(t, (*Int).MulMod, func(z, x, y, m *big.Int) *big.Int { + if m.Sign() == 0 { + return z.SetUint64(0) + } + return mulMod(z, x, y, m) + }) + }) } func TestCmpOp(t *testing.T) {