Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add derived impl of squaring for MontBackend #530

Merged
merged 11 commits into from
Dec 10, 2022
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ jobs:
command: test
args: "--workspace \
--package ark-test-curves \
--all-features"
--all-features
-- -Z macro-backtrace
"
if: matrix.rust == 'nightly'

test_assembly:
Expand Down
13 changes: 13 additions & 0 deletions ff-macros/src/montgomery/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use double::*;
mod mul;
use mul::*;

mod square;
use square::*;

mod sum_of_products;
use sum_of_products::*;

Expand Down Expand Up @@ -77,6 +80,12 @@ pub fn mont_config_helper(
&modulus_limbs,
modulus_has_spare_bit,
);
let square_in_place = square_in_place_impl(
can_use_no_carry_mul_opt,
limbs,
&modulus_limbs,
modulus_has_spare_bit,
);
let sum_of_products = sum_of_products_impl(limbs, &modulus_limbs);

let mixed_radix = if let Some(large_subgroup_generator) = large_subgroup_generator {
Expand Down Expand Up @@ -141,6 +150,10 @@ pub fn mont_config_helper(
fn mul_assign(a: &mut F, b: &F) {
#mul_assign
}
#[inline(always)]
fn square_in_place(a: &mut F) {
#square_in_place
}

fn sum_of_products<const M: usize>(
a: &[F; M],
Expand Down
16 changes: 12 additions & 4 deletions ff-macros/src/montgomery/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,21 @@ pub(super) fn mul_assign_impl(
#[allow(unsafe_code, unused_mut)]
ark_ff::x86_64_asm_mul!(#num_limbs, (a.0).0, (b.0).0);
} else {
#default
#[cfg(
not(all(
feature = "asm",
target_feature = "bmi2",
target_feature = "adx",
target_arch = "x86_64"
))
)]
{
#default
}
}
}))
} else {
body.extend(quote!({
#default
}))
body.extend(quote!({ #default }))
}
body.extend(quote!(__subtract_modulus(a);));
} else {
Expand Down
114 changes: 114 additions & 0 deletions ff-macros/src/montgomery/square.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use quote::quote;

pub(super) fn square_in_place_impl(
can_use_no_carry_mul_opt: bool,
num_limbs: usize,
modulus_limbs: &[u64],
modulus_has_spare_bit: bool,
) -> proc_macro2::TokenStream {
let mut body = proc_macro2::TokenStream::new();
let mut default = proc_macro2::TokenStream::new();

let modulus_0 = modulus_limbs[0];
let double_num_limbs = 2 * num_limbs;
default.extend(quote! {
let mut r = [0u64; #double_num_limbs];
let mut carry = 0;
});
for i in 0..(num_limbs - 1) {
for j in (i + 1)..num_limbs {
let idx = i + j;
default.extend(quote! {
r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#j], &mut carry);
})
}
default.extend(quote! {
r[#num_limbs + #i] = carry;
carry = 0;
});
}
default.extend(quote! { r[#double_num_limbs - 1] = r[#double_num_limbs - 2] >> 63; });
for i in 2..(double_num_limbs - 1) {
let idx = double_num_limbs - i;
default.extend(quote! { r[#idx] = (r[#idx] << 1) | (r[#idx - 1] >> 63); });
}
default.extend(quote! { r[1] <<= 1; });

for i in 0..num_limbs {
let idx = 2 * i;
default.extend(quote! {
r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#i], &mut carry);
carry = fa::adc(&mut r[#idx + 1], 0, carry);
});
}
// Montgomery reduction
default.extend(quote! { let mut carry2 = 0; });
for i in 0..num_limbs {
default.extend(quote! {
let k = r[#i].wrapping_mul(Self::INV);
let mut carry = 0;
fa::mac_discard(r[#i], k, #modulus_0, &mut carry);
});
for j in 1..num_limbs {
let idx = j + i;
let modulus_j = modulus_limbs[j];
default.extend(quote! {
r[#idx] = fa::mac_with_carry(r[#idx], k, #modulus_j, &mut carry);
});
}
default.extend(quote! { carry2 = fa::adc(&mut r[#num_limbs + #i], carry, carry2); });
}
default.extend(quote! { (a.0).0 = r[#num_limbs..].try_into().unwrap(); });

if num_limbs == 1 {
// We default to multiplying with `a` using the `Mul` impl
// for the N == 1 case
quote!({
*a *= *a;
})
} else if (2..=6).contains(&num_limbs) && can_use_no_carry_mul_opt {
body.extend(quote!({
if cfg!(all(
feature = "asm",
target_feature = "bmi2",
target_feature = "adx",
target_arch = "x86_64"
)) {
#[cfg(
all(
feature = "asm",
target_feature = "bmi2",
target_feature = "adx",
target_arch = "x86_64"
)
)]
#[allow(unsafe_code, unused_mut)]
{
ark_ff::x86_64_asm_square!(#num_limbs, (a.0).0);
}
} else {
#[cfg(
not(all(
feature = "asm",
target_feature = "bmi2",
target_feature = "adx",
target_arch = "x86_64"
))
)]
{
#default
}
}
}));
body.extend(quote!(__subtract_modulus(a);));
body
} else {
body.extend(quote!( #default ));
if modulus_has_spare_bit {
body.extend(quote!(__subtract_modulus(a);));
} else {
body.extend(quote!(__subtract_modulus_with_carry(a, carry2 != 0);));
}
body
}
}
1 change: 1 addition & 0 deletions ff/src/fields/models/fp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub type Fp768<P> = Fp<P, 12>;
pub type Fp832<P> = Fp<P, 13>;

impl<P: FpConfig<N>, const N: usize> Fp<P, N> {
#[doc(hidden)]
#[inline]
pub fn is_geq_modulus(&self) -> bool {
self.0 >= P::MODULUS
Expand Down
52 changes: 26 additions & 26 deletions test-templates/src/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,55 +209,55 @@ macro_rules! __test_field {
let mut rng = test_rng();
let zero = <$field>::zero();
let one = <$field>::one();
assert_eq!(one.inverse().unwrap(), one);
assert!(one.is_one());
assert_eq!(one.inverse().unwrap(), one, "One inverse failed");
assert!(one.is_one(), "One is not one");

assert!(<$field>::ONE.is_one());
assert_eq!(<$field>::ONE, one);
assert!(<$field>::ONE.is_one(), "One constant is not one");
assert_eq!(<$field>::ONE, one, "One constant is incorrect");

for _ in 0..ITERATIONS {
// Associativity
let a = <$field>::rand(&mut rng);
let b = <$field>::rand(&mut rng);
let c = <$field>::rand(&mut rng);
assert_eq!((a * b) * c, a * (b * c));
assert_eq!((a * b) * c, a * (b * c), "Associativity failed");

// Commutativity
assert_eq!(a * b, b * a);
assert_eq!(a * b, b * a, "Commutativity failed");

// Identity
assert_eq!(one * a, a);
assert_eq!(one * b, b);
assert_eq!(one * c, c);
assert_eq!(one * a, a, "Identity mul failed");
assert_eq!(one * b, b, "Identity mul failed");
assert_eq!(one * c, c, "Identity mul failed");

assert_eq!(zero * a, zero);
assert_eq!(zero * b, zero);
assert_eq!(zero * c, zero);
assert_eq!(zero * a, zero, "Mul by zero failed");
assert_eq!(zero * b, zero, "Mul by zero failed");
assert_eq!(zero * c, zero, "Mul by zero failed");

// Inverses
assert_eq!(a * a.inverse().unwrap(), one);
assert_eq!(b * b.inverse().unwrap(), one);
assert_eq!(c * c.inverse().unwrap(), one);
assert_eq!(a * a.inverse().unwrap(), one, "Mul by inverse failed");
assert_eq!(b * b.inverse().unwrap(), one, "Mul by inverse failed");
assert_eq!(c * c.inverse().unwrap(), one, "Mul by inverse failed");

// Associativity and commutativity simultaneously
let t0 = (a * b) * c;
let t1 = (a * c) * b;
let t2 = (b * c) * a;
assert_eq!(t0, t1);
assert_eq!(t1, t2);
assert_eq!(t0, t1, "Associativity + commutativity failed");
assert_eq!(t1, t2, "Associativity + commutativity failed");

// Squaring
assert_eq!(a * a, a.square());
assert_eq!(b * b, b.square());
assert_eq!(c * c, c.square());
assert_eq!(a * a, a.square(), "Squaring failed");
assert_eq!(b * b, b.square(), "Squaring failed");
assert_eq!(c * c, c.square(), "Squaring failed");

// Distributivity
assert_eq!(a * (b + c), a * b + a * c);
assert_eq!(b * (a + c), b * a + b * c);
assert_eq!(c * (a + b), c * a + c * b);
assert_eq!((a + b).square(), a.square() + b.square() + a * b.double());
assert_eq!((b + c).square(), c.square() + b.square() + c * b.double());
assert_eq!((c + a).square(), a.square() + c.square() + a * c.double());
assert_eq!(a * (b + c), a * b + a * c, "Distributivity failed");
assert_eq!(b * (a + c), b * a + b * c, "Distributivity failed");
assert_eq!(c * (a + b), c * a + c * b, "Distributivity failed");
assert_eq!((a + b).square(), a.square() + b.square() + a * b.double(), "Distributivity for square failed");
assert_eq!((b + c).square(), c.square() + b.square() + c * b.double(), "Distributivity for square failed");
assert_eq!((c + a).square(), a.square() + c.square() + a * c.double(), "Distributivity for square failed");
}
}

Expand Down