Skip to content

Commit

Permalink
Merge pull request #79 from oscarbenjamin/pr_nmod_pow
Browse files Browse the repository at this point in the history
fix(nmod): ZeroDivisionError instead of coredump
  • Loading branch information
oscarbenjamin authored Sep 9, 2023
2 parents 9a86d4e + 630ea01 commit 243d896
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
13 changes: 10 additions & 3 deletions src/flint/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,7 @@ def test_nmod():
assert G(1,2) != G(0,2)
assert G(0,2) != G(0,3)
assert G(3,5) == G(8,5)
assert G(1,2) != (1,2)
assert isinstance(hash(G(3, 5)), int)
assert raises(lambda: G([], 3), TypeError)
#assert G(3,5) == 8 # do we want this?
Expand All @@ -1304,14 +1305,20 @@ def test_nmod():
assert G(0,3) / G(1,3) == G(0,3)
assert G(3,17) * flint.fmpq(11,5) == G(10,17)
assert G(3,17) / flint.fmpq(11,5) == G(6,17)
assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError)
assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError)
assert raises(lambda: G(2,5) / 0, ZeroDivisionError)
assert G(1,6) / G(5,6) == G(5,6)
assert raises(lambda: G(1,6) / G(3,6), ZeroDivisionError)
assert G(1,3) ** 2 == G(1,3)
assert G(2,3) ** flint.fmpz(2) == G(1,3)
assert ~G(2,7) == G(2,7) ** -1 == G(4,7)
assert raises(lambda: G(3,6) ** -1, ZeroDivisionError)
assert raises(lambda: ~G(3,6), ZeroDivisionError)
assert raises(lambda: pow(G(1,3), 2, 7), TypeError)
assert G(flint.fmpq(2, 3), 5) == G(4,5)
assert raises(lambda: G(2,5) ** G(2,5), TypeError)
assert raises(lambda: flint.fmpz(2) ** G(2,5), TypeError)
assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError)
assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError)
assert raises(lambda: G(2,5) / 0, ZeroDivisionError)
assert raises(lambda: G(2,5) + G(2,7), ValueError)
assert raises(lambda: G(2,5) - G(2,7), ValueError)
assert raises(lambda: G(2,5) * G(2,7), ValueError)
Expand Down
56 changes: 42 additions & 14 deletions src/flint/types/nmod.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ from flint.types.fmpz cimport any_as_fmpz
from flint.types.fmpz cimport fmpz
from flint.types.fmpq cimport fmpq

from flint.flintlib.flint cimport ulong
from flint.flintlib.fmpz cimport fmpz_t
from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv
from flint.flintlib.nmod_vec cimport *
from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear
from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui
from flint.flintlib.fmpq cimport fmpq_mod_fmpz
from flint.flintlib.ulong_extras cimport n_gcdinv

cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
cdef int success
Expand Down Expand Up @@ -64,9 +66,6 @@ cdef class nmod(flint_scalar):
def __int__(self):
return int(self.val)

def __long__(self):
return self.val

def modulus(self):
return self.mod.n

Expand Down Expand Up @@ -170,6 +169,8 @@ cdef class nmod(flint_scalar):
cdef nmod r
cdef mp_limb_t sval, tval, x
cdef nmod_t mod
cdef ulong tinvval

if typecheck(s, nmod):
mod = (<nmod>s).mod
sval = (<nmod>s).val
Expand All @@ -180,17 +181,19 @@ cdef class nmod(flint_scalar):
tval = (<nmod>t).val
if not any_as_nmod(&sval, s, mod):
return NotImplemented

if tval == 0:
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))
if not s:
return s
# XXX: check invertibility?
x = nmod_div(sval, tval, mod)
if x == 0:

g = n_gcdinv(&tinvval, <ulong>tval, <ulong>mod.n)
if g != 1:
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))

r = nmod.__new__(nmod)
r.mod = mod
r.val = x
r.val = nmod_mul(sval, <mp_limb_t>tinvval, mod)
return r

def __truediv__(s, t):
Expand All @@ -200,18 +203,43 @@ cdef class nmod(flint_scalar):
return nmod._div_(t, s)

def __invert__(self):
return (1 / self) # XXX: speed up
cdef nmod r
cdef ulong g, inv, sval
sval = <ulong>(<nmod>self).val
g = n_gcdinv(&inv, sval, self.mod.n)
if g != 1:
raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n))
r = nmod.__new__(nmod)
r.mod = self.mod
r.val = <mp_limb_t>inv
return r

def __pow__(self, exp):
def __pow__(self, exp, modulus=None):
cdef nmod r
cdef mp_limb_t rval, mod
cdef ulong g, rinv

if modulus is not None:
raise TypeError("three-argument pow() not supported by nmod")

e = any_as_fmpz(exp)
if e is NotImplemented:
return NotImplemented
r = nmod.__new__(nmod)
r.mod = self.mod
r.val = self.val

rval = (<nmod>self).val
mod = (<nmod>self).mod.n

# XXX: It is not clear that it is necessary to special case negative
# exponents here. The nmod_pow_fmpz function seems to handle this fine
# but the Flint docs say that the exponent must be nonnegative.
if e < 0:
r.val = nmod_inv(r.val, self.mod)
g = n_gcdinv(&rinv, <ulong>rval, <ulong>mod)
if g != 1:
raise ZeroDivisionError("%s is not invertible mod %s" % (rval, mod))
rval = <mp_limb_t>rinv
e = -e
r.val = nmod_pow_fmpz(r.val, (<fmpz>e).val, self.mod)

r = nmod.__new__(nmod)
r.mod = self.mod
r.val = nmod_pow_fmpz(rval, (<fmpz>e).val, self.mod)
return r

0 comments on commit 243d896

Please sign in to comment.