Skip to content

Commit

Permalink
add compose_mod and powmod with large exp
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Aug 5, 2024
1 parent 30e71dc commit 075e9e3
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 12 deletions.
24 changes: 22 additions & 2 deletions src/flint/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2090,6 +2090,18 @@ def test_fmpz_mod_poly():
assert f*f == f**2
assert f*f == f**fmpz(2)

# powmod
# assert ui and fmpz exp agree for polynomials and generators
R_gen = R_test.gen()
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)

# Check other typechecks
assert raises(lambda: pow(f, -2, g), ValueError)
assert raises(lambda: pow(f, 1, "A"), TypeError)
assert raises(lambda: pow(f, "A", g), TypeError)
assert raises(lambda: f.powmod(2**32, g, mod_rev_inv="A"), TypeError)

# Shifts
assert raises(lambda: R_test([1,2,3]).left_shift(-1), ValueError)
assert raises(lambda: R_test([1,2,3]).right_shift(-1), ValueError)
Expand Down Expand Up @@ -2121,6 +2133,13 @@ def test_fmpz_mod_poly():
# compose
assert raises(lambda: h.compose("AAA"), TypeError)

# compose mod
mod = R_test([1,2,3,4])
assert f.compose(h) % mod == f.compose_mod(h, mod)
assert raises(lambda: h.compose_mod("AAA", mod), TypeError)
assert raises(lambda: h.compose_mod(f, "AAA"), TypeError)
assert raises(lambda: h.compose_mod(f, R_test(0)), ZeroDivisionError)

# Reverse
assert raises(lambda: h.reverse(degree=-100), ValueError)
assert R_test([-1,-2,-3]).reverse() == R_test([-3,-2,-1])
Expand Down Expand Up @@ -2606,8 +2625,9 @@ def setbad(obj, i, val):
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)

# # XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)
# XXX: Not sure what this should do in general:
# TODO: this now fails as fmpz_mod_poly allows modulus
# assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
Expand Down
95 changes: 85 additions & 10 deletions src/flint/types/fmpz_mod_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ cdef class fmpz_mod_poly(flint_poly):

def __pow__(self, e, mod=None):
if mod is not None:
raise NotImplementedError
return self.powmod(e, mod)

cdef fmpz_mod_poly res
if e < 0:
Expand Down Expand Up @@ -784,11 +784,11 @@ cdef class fmpz_mod_poly(flint_poly):

return evaluations

def compose(self, input):
def compose(self, other):
"""
Returns the composition of two polynomials
To be precise about the order of composition, given ``self``, and ``input``
To be precise about the order of composition, given ``self``, and ``other``
by `f(x)`, `g(x)`, returns `f(g(x))`.
>>> R = fmpz_mod_poly_ctx(163)
Expand All @@ -800,12 +800,45 @@ cdef class fmpz_mod_poly(flint_poly):
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(input)
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"Cannot compose the polynomial with input: {input}")
raise TypeError(f"Cannot compose the polynomial with input: {other}")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose(res.val, self.val, (<fmpz_mod_poly>val).val, self.ctx.mod.val)
return res

def compose_mod(self, other, modulus):
"""
Returns the composition of two polynomials modulo a third.
To be precise about the order of composition, given ``self``, and ``other``
and ``modulus`` by `f(x)`, `g(x)` and `h(x)`, returns `f(g(x)) \mod h(x)`.
We require that `h(x)` is non-zero.
>>> R = fmpz_mod_poly_ctx(163)
>>> f = R([1,2,3,4,5])
>>> g = R([3,2,1])
>>> h = R([1,0,1,0,1])
>>> f.compose_mod(g, h)
63*x^3 + 100*x^2 + 17*x + 63
>>> g.compose_mod(f, h)
147*x^3 + 159*x^2 + 4*x + 7
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"cannot compose the polynomial with input: {other}")

h = self.ctx.any_as_fmpz_mod_poly(modulus)
if h is NotImplemented:
raise TypeError(f"cannot reduce the polynomial with input: {modulus}")

if h.is_zero():
raise ZeroDivisionError("cannot reduce modulo zero")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose_mod(res.val, self.val, (<fmpz_mod_poly>val).val, (<fmpz_mod_poly>h).val, self.ctx.mod.val)
return res

cpdef long length(self):
Expand Down Expand Up @@ -1110,10 +1143,14 @@ cdef class fmpz_mod_poly(flint_poly):
)
return res

def powmod(self, e, modulus):
def powmod(self, e, modulus, mod_rev_inv=None):
"""
Returns ``self`` raised to the power ``e`` modulo ``modulus``:
:math:`f^e \mod g`
:math:`f^e \mod g`/
``mod_rev_inv`` is the inverse of the reverse of the modulus,
precomputing it and passing it to ``powmod()`` can optimise
powering of polynomials with large exponents.
>>> R = fmpz_mod_poly_ctx(163)
>>> x = R.gen()
Expand All @@ -1123,17 +1160,55 @@ cdef class fmpz_mod_poly(flint_poly):
>>>
>>> f.powmod(123, mod)
3*x^3 + 25*x^2 + 115*x + 161
>>> f.powmod(2**64, mod)
52*x^3 + 96*x^2 + 136*x + 9
>>> mod_rev_inv = mod.reverse().inverse_series_trunc(4)
>>> f.powmod(2**64, mod, mod_rev_inv)
52*x^3 + 96*x^2 + 136*x + 9
"""
cdef fmpz_mod_poly res

if e < 0:
raise ValueError("Exponent must be non-negative")

modulus = self.ctx.any_as_fmpz_mod_poly(modulus)
if modulus is NotImplemented:
raise TypeError(f"Cannot interpret {modulus} as a polynomial")

# Output polynomial
res = self.ctx.new_ctype_poly()
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)

# For small exponents, use a simple powering method
if e.bit_length() < 32:
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)
return res

# For larger exponents we can use faster algorithms, first convert exp to fmpz type
e_fmpz = any_as_fmpz(e)
if e_fmpz is NotImplemented:
raise ValueError(f"exponent cannot be cast to an fmpz type: {e = }")

# To optimise powering, we precompute the inverse of the reverse of the modulus
if mod_rev_inv is not None:
mod_rev_inv = self.ctx.any_as_fmpz_mod_poly(mod_rev_inv)
if mod_rev_inv is NotImplemented:
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
else:
mod_rev_inv = modulus.reverse().inverse_series_trunc(modulus.length())

# Use windows exponentiation optimisation when self = x
if self.is_gen():
fmpz_mod_poly_powmod_x_fmpz_preinv(
res.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

# Otherwise using binary exponentiation for all other inputs
fmpz_mod_poly_powmod_fmpz_binexp_preinv(
res.val, self.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

def divmod(self, other):
Expand Down

0 comments on commit 075e9e3

Please sign in to comment.