Skip to content

Commit

Permalink
Generalize the finite field arithmetic to GF(p) for p < 2^126
Browse files Browse the repository at this point in the history
This change swaps the existing implementation of GF(2^32 - 2^20 + 1)
with a constant-time implementation suitable for any GF(p) for which p <
2^126. It is based on Montgomery multiplication.
  • Loading branch information
cjpatton committed Mar 22, 2021
1 parent e0247d0 commit 941ec20
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 89 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ thiserror = "1.0"

[dev-dependencies]
assert_matches = "1.5.0"
modinverse = "0.1.0"
num-bigint = "0.4.0"

[[example]]
name = "sum"
149 changes: 64 additions & 85 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

//! Finite field arithmetic over a prime field using a 32bit prime.

use crate::fp::FieldParameters;

/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FiniteFieldError {
Expand All @@ -11,14 +13,22 @@ pub enum FiniteFieldError {
InputSizeMismatch,
}

/// Newtype wrapper over u32
/// Newtype wrapper over u128
///
/// Implements the arithmetic over the finite prime field
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Field(u32);
#[derive(Clone, Copy, Debug, PartialOrd, Ord, Hash, Default)]
pub struct Field(u128);

/// Parameters for GF(2^32 - 2^20 + 1).
pub(crate) const SMALL_FP: FieldParameters = FieldParameters {
p: 4293918721,
p2: 8587837442,
mu: 17302828673139736575,
r2: 1676699750,
};

/// Modulus for the field, a FFT friendly prime: 2^32 - 2^20 + 1
pub const MODULUS: u32 = 4293918721;
pub const MODULUS: u32 = SMALL_FP.p as u32;
/// Generator for the multiplicative subgroup
pub(crate) const GENERATOR: u32 = 3925978153;
/// Number of primitive roots
Expand All @@ -28,7 +38,7 @@ impl std::ops::Add for Field {
type Output = Field;

fn add(self, rhs: Self) -> Self {
self - Field(MODULUS - rhs.0)
Self(SMALL_FP.add(self.0, rhs.0))
}
}

Expand All @@ -42,14 +52,7 @@ impl std::ops::Sub for Field {
type Output = Field;

fn sub(self, rhs: Self) -> Self {
let l = self.0;
let r = rhs.0;

if l >= r {
Field(l - r)
} else {
Field(MODULUS - r + l)
}
Self(SMALL_FP.sub(self.0, rhs.0))
}
}

Expand All @@ -62,12 +65,8 @@ impl std::ops::SubAssign for Field {
impl std::ops::Mul for Field {
type Output = Field;

#[allow(clippy::suspicious_arithmetic_impl)]
fn mul(self, rhs: Self) -> Self {
let l = self.0 as u64;
let r = rhs.0 as u64;
let mul = l * r;
Field((mul % (MODULUS as u64)) as u32)
Self(SMALL_FP.mul(self.0, rhs.0))
}
}

Expand All @@ -80,7 +79,6 @@ impl std::ops::MulAssign for Field {
impl std::ops::Div for Field {
type Output = Field;

#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self {
self * rhs.inv()
}
Expand All @@ -92,95 +90,76 @@ impl std::ops::DivAssign for Field {
}
}

impl PartialEq for Field {
fn eq(&self, rhs: &Self) -> bool {
SMALL_FP.from_elem(self.0) == SMALL_FP.from_elem(rhs.0)
}
}

impl Eq for Field {}

impl Field {
/// Modular exponentation
pub fn pow(self, exp: Self) -> Self {
// repeated squaring
let mut base = self;
let mut exp = exp.0;
let mut result: Field = Field(1);
while exp > 0 {
while (exp & 1) == 0 {
exp /= 2;
base *= base;
}
exp -= 1;
result *= base;
}
result
Self(SMALL_FP.pow(self.0, SMALL_FP.from_elem(exp.0)))
}

/// Modular inverse
///
/// Note: inverse of 0 is defined as 0.
pub fn inv(self) -> Self {
// extended Euclidean
let mut x1: i32 = 1;
let mut a1: u32 = self.0;
let mut x0: i32 = 0;
let mut a2: u32 = MODULUS;
let mut q: u32 = 0;

while a2 != 0 {
let x2 = x0 - (q as i32) * x1;
x0 = x1;
let a0 = a1;
x1 = x2;
a1 = a2;
q = a0 / a1;
a2 = a0 - q * a1;
}
if x1 < 0 {
let (r, _) = MODULUS.overflowing_add(x1 as u32);
Field(r)
} else {
Field(x1 as u32)
}
Self(SMALL_FP.inv(self.0))
}
}

impl From<u32> for Field {
fn from(x: u32) -> Self {
Field(x % MODULUS)
Field(SMALL_FP.elem(x as u128))
}
}

impl From<Field> for u32 {
fn from(x: Field) -> Self {
x.0
SMALL_FP.from_elem(x.0) as u32
}
}

impl PartialEq<u32> for Field {
fn eq(&self, rhs: &u32) -> bool {
self.0 == *rhs
SMALL_FP.from_elem(self.0) == *rhs as u128
}
}

impl std::fmt::Display for Field {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", SMALL_FP.from_elem(self.0))
}
}

#[test]
fn test_small_fp() {
assert_eq!(SMALL_FP.check(), Ok(()));
}

#[test]
fn test_arithmetic() {
use rand::prelude::*;

// add
assert_eq!(Field(MODULUS - 1) + Field(1), 0);
assert_eq!(Field(MODULUS - 2) + Field(2), 0);
assert_eq!(Field(MODULUS - 2) + Field(3), 1);
assert_eq!(Field(1) + Field(1), 2);
assert_eq!(Field(2) + Field(MODULUS), 2);
assert_eq!(Field(3) + Field(MODULUS - 1), 2);
assert_eq!(Field::from(MODULUS - 1) + Field::from(1), 0);
assert_eq!(Field::from(MODULUS - 2) + Field::from(2), 0);
assert_eq!(Field::from(MODULUS - 2) + Field::from(3), 1);
assert_eq!(Field::from(1) + Field::from(1), 2);
assert_eq!(Field::from(2) + Field::from(MODULUS), 2);
assert_eq!(Field::from(3) + Field::from(MODULUS - 1), 2);

// sub
assert_eq!(Field(0) - Field(1), MODULUS - 1);
assert_eq!(Field(1) - Field(2), MODULUS - 1);
assert_eq!(Field(15) - Field(3), 12);
assert_eq!(Field(1) - Field(1), 0);
assert_eq!(Field(2) - Field(MODULUS), 2);
assert_eq!(Field(3) - Field(MODULUS - 1), 4);
assert_eq!(Field::from(0) - Field::from(1), MODULUS - 1);
assert_eq!(Field::from(1) - Field::from(2), MODULUS - 1);
assert_eq!(Field::from(15) - Field::from(3), 12);
assert_eq!(Field::from(1) - Field::from(1), 0);
assert_eq!(Field::from(2) - Field::from(MODULUS), 2);
assert_eq!(Field::from(3) - Field::from(MODULUS - 1), 4);

// add + sub
for _ in 0..100 {
Expand All @@ -192,35 +171,35 @@ fn test_arithmetic() {
}

// mul
assert_eq!(Field(35) * Field(123), 4305);
assert_eq!(Field(1) * Field(MODULUS), 0);
assert_eq!(Field(0) * Field(123), 0);
assert_eq!(Field(123) * Field(0), 0);
assert_eq!(Field(123123123) * Field(123123123), 1237630077);
assert_eq!(Field::from(35) * Field::from(123), 4305);
assert_eq!(Field::from(1) * Field::from(MODULUS), 0);
assert_eq!(Field::from(0) * Field::from(123), 0);
assert_eq!(Field::from(123) * Field::from(0), 0);
assert_eq!(Field::from(123123123) * Field::from(123123123), 1237630077);

// div
assert_eq!(Field(35) / Field(5), 7);
assert_eq!(Field(35) / Field(0), 0);
assert_eq!(Field(0) / Field(5), 0);
assert_eq!(Field(1237630077) / Field(123123123), 123123123);
assert_eq!(Field::from(35) / Field::from(5), 7);
assert_eq!(Field::from(35) / Field::from(0), 0);
assert_eq!(Field::from(0) / Field::from(5), 0);
assert_eq!(Field::from(1237630077) / Field::from(123123123), 123123123);

assert_eq!(Field(0).inv(), 0);
assert_eq!(Field::from(0).inv(), 0);

// mul and div
let uniform = rand::distributions::Uniform::from(1..MODULUS);
let mut rng = thread_rng();
for _ in 0..100 {
// non-zero element
let f = Field(uniform.sample(&mut rng));
let f = Field::from(uniform.sample(&mut rng));
assert_eq!(f * f.inv(), 1);
assert_eq!(f.inv() * f, 1);
}

// pow
assert_eq!(Field(2).pow(3.into()), 8);
assert_eq!(Field(3).pow(9.into()), 19683);
assert_eq!(Field(51).pow(27.into()), 3760729523);
assert_eq!(Field(432).pow(0.into()), 1);
assert_eq!(Field::from(2).pow(3.into()), 8);
assert_eq!(Field::from(3).pow(9.into()), 19683);
assert_eq!(Field::from(51).pow(27.into()), 3760729523);
assert_eq!(Field::from(432).pow(0.into()), 1);
assert_eq!(Field(0).pow(123.into()), 0);
}

Expand Down
Loading

0 comments on commit 941ec20

Please sign in to comment.