Skip to content

Commit

Permalink
more boilerplate code
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Aug 6, 2024
1 parent b728899 commit c8fe699
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 23 deletions.
73 changes: 72 additions & 1 deletion src/flint/flint_base/flint_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,78 @@ cdef class flint_elem:


cdef class flint_scalar(flint_elem):
pass
# =================================================
# These are the functions a new class should define
# assumes that addition and multiplication are
# commutative
# =================================================
def _neg_(self):
return NotImplemented

def _add_(self, other):
return NotImplemented

@staticmethod
def _sub_(left, right):
return NotImplemented

def _mul_(self, other):
return NotImplemented

@staticmethod
def _div_(left, right):
return NotImplemented

@staticmethod
def _floordiv_(left, right):
return NotImplemented

def _invert_(self):
return NotImplemented

# =================================================
# Generic arithmetic using the above functions
# =================================================

def __pos__(self):
return self

def __neg__(self):
return self._neg_()

def __add__(self, other):
return self._add_(other)

def __radd__(self, other):
return self._add_(other)

def __sub__(self, other):
return self._sub_(self, other)

def __rsub__(self, other):
return self._sub_(other, self)

def __mul__(self, other):
return self._mul_(other)

def __rmul__(self, other):
return self._mul_(other)

def __truediv__(self, other):
return self._div_(self, other)

def __rtruediv__(self, other):
return self._div_(other, self)

def __floordiv__(self, other):
return self._floordiv_(self, other)

def __rfloordiv__(self, other):
return self._floordiv_(other, self)

def __invert__(self):
return self._invert_()



cdef class flint_poly(flint_elem):
Expand Down
6 changes: 4 additions & 2 deletions src/flint/flintlib/fq_default.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ cdef extern from "flint/fq_default.h":
int fq_default_get_fmpz(fmpz_t rop, const fq_default_t op, const fq_default_ctx_t ctx)
void fq_default_get_nmod_poly(nmod_poly_t poly, const fq_default_t op, const fq_default_ctx_t ctx)
void fq_default_set_nmod_poly(fq_default_t op, const nmod_poly_t poly, const fq_default_ctx_t ctx)
void fq_default_get_fmpz_mod_poly(fmpz_mod_poly_t poly, const fq_default_t op, const fmpz_mod_ctx_t mod_ctx, const fq_default_ctx_t ctx)
void fq_default_set_fmpz_mod_poly(fq_default_t op, const fmpz_mod_poly_t poly, const fmpz_mod_ctx_t mod_ctx, const fq_default_ctx_t ctx)
# void fq_default_get_fmpz_mod_poly(fmpz_mod_poly_t poly, const fq_default_t op, const fmpz_mod_ctx_t mod_ctx, const fq_default_ctx_t ctx)
# void fq_default_set_fmpz_mod_poly(fq_default_t op, const fmpz_mod_poly_t poly, const fmpz_mod_ctx_t mod_ctx, const fq_default_ctx_t ctx)
void fq_default_get_fmpz_mod_poly(fmpz_mod_poly_t poly, const fq_default_t op, const fq_default_ctx_t ctx)
void fq_default_set_fmpz_mod_poly(fq_default_t op, const fmpz_mod_poly_t poly, const fq_default_ctx_t ctx)
void fq_default_get_fmpz_poly(fmpz_poly_t a, const fq_default_t b, const fq_default_ctx_t ctx)
void fq_default_set_fmpz_poly(fq_default_struct a, const fmpz_poly_t b, const fq_default_ctx_t ctx)
int fq_default_is_zero(const fq_default_t op, const fq_default_ctx_t ctx)
Expand Down
5 changes: 3 additions & 2 deletions src/flint/types/fq_default.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ cpdef enum fq_default_type:
cdef class fq_default_ctx:
cdef fq_default_ctx_t val
cdef readonly char *var
cdef bint initialized
cdef bint _initialized

cdef new_ctype_fq_default(self)
cdef set_any_as_fq_default(self, fq_default_t val, obj)
cdef any_as_fmpz_mod(self, obj)

@staticmethod
cdef fq_default_ctx c_from_order(fmpz p, int d, char *var, fq_default_type type=*)

@staticmethod
cdef fq_default_ctx c_from_modulus(modulus, char *var, fq_default_type type=*)

Expand Down
188 changes: 170 additions & 18 deletions src/flint/types/fq_default.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from flint.pyflint cimport global_random_state
from flint.types.fmpz cimport fmpz, any_as_fmpz
from flint.types.fmpz_poly cimport fmpz_poly
from flint.types.fmpz_mod_poly cimport fmpz_mod_poly, fmpz_mod_poly_ctx
from flint.types.nmod_poly cimport nmod_poly
from flint.utils.typecheck cimport typecheck
Expand All @@ -22,11 +24,14 @@ cdef class fq_default_ctx:
For more details, see the documentation of :method:`~.from_order`
and :method:`~.from_modulus`.
"""
def __cinit__(self):
pass

def __dealloc__(self):
if self.initialized is not None:
if self._initialized:
fq_default_ctx_clear(self.val)

def __init__(self):
def __init__(self, *args, **kwargs):
raise TypeError("This class cannot be instantiated directly. Use .from_order() or .from_modulus().")

@staticmethod
Expand All @@ -35,7 +40,7 @@ cdef class fq_default_ctx:
cdef fq_default_ctx ctx = fq_default_ctx.__new__(fq_default_ctx)
ctx.var = var
fq_default_ctx_init_type(ctx.val, p.val, d, ctx.var, type)
ctx.initialized = True
ctx._initialized = True
return ctx

@staticmethod
Expand Down Expand Up @@ -64,8 +69,6 @@ cdef class fq_default_ctx:
fq_default_ctx.from_modulus(x^2 + 4*x + 2, b'y', 2)
>>> gf.type
<fq_default_type.FQ_NMOD: 2>
"""
# c_from_order expects the characteristic to be fmpz type
order = any_as_fmpz(p)
Expand Down Expand Up @@ -163,6 +166,8 @@ cdef class fq_default_ctx:
p = fmpz.__new__(fmpz)
fq_default_ctx_prime(p.val, self.val)
return p

characteristic = prime

def order(self):
"""
Expand Down Expand Up @@ -195,11 +200,15 @@ cdef class fq_default_ctx:
"""
Return the zero element
>>> F = fmpz_mod_ctx(163)
>>> F.zero()
>>> Fq = fq_default_ctx.from_order(5, 1, "x")
>>> Fq.zero()
fmpz_mod(0, 163)
"""
pass
cdef fq_default res
res = self.new_ctype_fq_default()
res.ctx = self
fq_default_zero(res.val, self.val)
return res

def one(self):
"""
Expand All @@ -209,19 +218,33 @@ cdef class fq_default_ctx:
>>> F.one()
fmpz_mod(1, 163)
"""
pass
cdef fq_default res
res = self.new_ctype_fq_default()
res.ctx = self
fq_default_one(res.val, self.val)
return res

def random_element(self):
def random_element(self, not_zero=False):
r"""
Return a random element in :math:`\mathbb{Z}/N\mathbb{Z}`
Return a random element of the finite field
"""
pass
cdef fq_default res
res = self.new_ctype_fq_default()
res.ctx = self
if not_zero:
fq_default_rand_not_zero(res.val, global_random_state, self.val)
else:
fq_default_rand(res.val, global_random_state, self.val)
return res

cdef set_any_as_fq_default(self, fq_default_t val, obj):
pass
cdef new_ctype_fq_default(self):
return fq_default.__new__(fq_default, None, self)

cdef any_as_fmpz_mod(self, obj):
pass
cdef set_any_as_fq_default(self, fq_default_t val, obj):
if typecheck(obj, fmpz):
fq_default_set_fmpz(val, (<fmpz>obj).val, self.val)
return 0
return NotImplemented

def __eq__(self, other):
"""
Expand Down Expand Up @@ -255,11 +278,140 @@ cdef class fq_default_ctx:
return f"Context for fq_default in GF({self.prime()}^{self.degree()})[{self.var.decode()}]/({self.modulus().str(var=self.var.decode())})"

def __repr__(self):
return f"fq_default_ctx.from_modulus({self.modulus()!r}, {self.var.encode()}, {self.type})"
return f"fq_default_ctx.from_modulus({self.modulus()!r}, {self.var.decode()}, {self.type})"

def __call__(self, val):
return fq_default(val, self)


cdef class fq_default(flint_scalar):
pass
def __cinit__(self, val, ctx):
if not typecheck(ctx, fq_default_ctx):
raise TypeError
self.ctx = ctx
fq_default_init(self.val, self.ctx.val)

def __dealloc__(self):
if self.ctx is not None:
fq_default_clear(self.val, self.ctx.val)

def __init__(self, val, ctx):
if not typecheck(ctx, fq_default_ctx):
raise TypeError
self.ctx = ctx

check = self.ctx.set_any_as_fq_default(self.val, val)
if check is NotImplemented:
raise TypeError

def __int__(self):
"""
Attempts to lift self to an integer of type fmpz in [0, p-1]
"""
cdef fmpz x = fmpz.__new__(fmpz)
res = fq_default_get_fmpz(x.val, self.val, self.ctx.val)
if res == 1:
return int(x)
raise ValueError("fq element has no lift to the integers")

def polynomial(self):
"""
Returns a representative of ``self`` as a polynomial in `(Z/pZ)[x] / h(x)`
where `h(x)` is the defining polynomial of the finite field.
"""
cdef fmpz_mod_poly_ctx ctx
cdef fmpz_mod_poly pol

ring_ctx = fmpz_mod_poly_ctx(self.ctx.prime())
pol = ring_ctx.new_ctype_poly()
fq_default_get_fmpz_mod_poly((<fmpz_mod_poly>pol).val, self.val, self.ctx.val)

return pol

def __repr__(self):
return f"fq_default({self.polynomial()}, {self.ctx.__repr__()})"

def str(self):
return self.polynomial().__str__()

# =================================================
# Comparisons
# =================================================
def is_zero(self):
return 1 == fq_default_is_zero(self.val, self.ctx.val)

def is_one(self):
return 1 == fq_default_is_zero(self.val, self.ctx.val)

# =================================================
# Generic arithmetic required by flint_scalar
# =================================================

def _neg_(self):
cdef fq_default res
res = self.ctx.new_ctype_fq_default()
fq_default_neg(res.val, self.val, self.ctx.val)
return res

def _add_(self, other):
return NotImplemented

@staticmethod
def _sub_(left, right):
return NotImplemented

def _mul_(self, other):
return NotImplemented

@staticmethod
def _div_(left, right):
return NotImplemented

def _invert_(self):
cdef fq_default res
res = self.ctx.new_ctype_fq_default()
fq_default_inv(res.val, self.val, self.ctx.val)
return res

# =================================================
# Additional arithmetic
# =================================================

def square(self):
cdef fq_default res
res = self.ctx.new_ctype_fq_default()
fq_default_sqr(res.val, self.val, self.ctx.val)
return res

def __pow__(self, e):
return NotImplemented

def sqrt(self):
cdef fq_default res
res = self.ctx.new_ctype_fq_default()
check = fq_default_sqrt(res.val, self.val, self.ctx.val)
if check:
return res
raise ValueError("element is not a square")

def is_square(self):
return 1 == fq_default_is_square(self.val, self.ctx.val)

def pth_root(self):
cdef fq_default res
res = self.ctx.new_ctype_fq_default()
fq_default_pth_root(res.val, self.val, self.ctx.val)
return res

# =================================================
# Special functions
# =================================================

def trace(self):
return NotImplemented

def norm(self):
return NotImplemented

def frobenius(self):
return NotImplemented

0 comments on commit c8fe699

Please sign in to comment.