diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5d4656af6..39c56113f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -116,7 +116,9 @@ jobs: command: test args: "--workspace \ --package ark-test-curves \ - --all-features" + --all-features + -- -Z macro-backtrace + " if: matrix.rust == 'nightly' test_assembly: diff --git a/ff-macros/src/montgomery/mod.rs b/ff-macros/src/montgomery/mod.rs index 0e01e3233..a086399c1 100644 --- a/ff-macros/src/montgomery/mod.rs +++ b/ff-macros/src/montgomery/mod.rs @@ -14,6 +14,9 @@ use double::*; mod mul; use mul::*; +mod square; +use square::*; + mod sum_of_products; use sum_of_products::*; @@ -77,6 +80,12 @@ pub fn mont_config_helper( &modulus_limbs, modulus_has_spare_bit, ); + let square_in_place = square_in_place_impl( + can_use_no_carry_mul_opt, + limbs, + &modulus_limbs, + modulus_has_spare_bit, + ); let sum_of_products = sum_of_products_impl(limbs, &modulus_limbs); let mixed_radix = if let Some(large_subgroup_generator) = large_subgroup_generator { @@ -141,6 +150,10 @@ pub fn mont_config_helper( fn mul_assign(a: &mut F, b: &F) { #mul_assign } + #[inline(always)] + fn square_in_place(a: &mut F) { + #square_in_place + } fn sum_of_products( a: &[F; M], diff --git a/ff-macros/src/montgomery/mul.rs b/ff-macros/src/montgomery/mul.rs index 63e58cca6..8f6859aea 100644 --- a/ff-macros/src/montgomery/mul.rs +++ b/ff-macros/src/montgomery/mul.rs @@ -57,13 +57,21 @@ pub(super) fn mul_assign_impl( #[allow(unsafe_code, unused_mut)] ark_ff::x86_64_asm_mul!(#num_limbs, (a.0).0, (b.0).0); } else { - #default + #[cfg( + not(all( + feature = "asm", + target_feature = "bmi2", + target_feature = "adx", + target_arch = "x86_64" + )) + )] + { + #default + } } })) } else { - body.extend(quote!({ - #default - })) + body.extend(quote!({ #default })) } body.extend(quote!(__subtract_modulus(a);)); } else { diff --git a/ff-macros/src/montgomery/square.rs b/ff-macros/src/montgomery/square.rs new file mode 100644 index 000000000..df4c866e5 --- /dev/null +++ b/ff-macros/src/montgomery/square.rs @@ -0,0 +1,114 @@ +use quote::quote; + +pub(super) fn square_in_place_impl( + can_use_no_carry_mul_opt: bool, + num_limbs: usize, + modulus_limbs: &[u64], + modulus_has_spare_bit: bool, +) -> proc_macro2::TokenStream { + let mut body = proc_macro2::TokenStream::new(); + let mut default = proc_macro2::TokenStream::new(); + + let modulus_0 = modulus_limbs[0]; + let double_num_limbs = 2 * num_limbs; + default.extend(quote! { + let mut r = [0u64; #double_num_limbs]; + let mut carry = 0; + }); + for i in 0..(num_limbs - 1) { + for j in (i + 1)..num_limbs { + let idx = i + j; + default.extend(quote! { + r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#j], &mut carry); + }) + } + default.extend(quote! { + r[#num_limbs + #i] = carry; + carry = 0; + }); + } + default.extend(quote! { r[#double_num_limbs - 1] = r[#double_num_limbs - 2] >> 63; }); + for i in 2..(double_num_limbs - 1) { + let idx = double_num_limbs - i; + default.extend(quote! { r[#idx] = (r[#idx] << 1) | (r[#idx - 1] >> 63); }); + } + default.extend(quote! { r[1] <<= 1; }); + + for i in 0..num_limbs { + let idx = 2 * i; + default.extend(quote! { + r[#idx] = fa::mac_with_carry(r[#idx], (a.0).0[#i], (a.0).0[#i], &mut carry); + carry = fa::adc(&mut r[#idx + 1], 0, carry); + }); + } + // Montgomery reduction + default.extend(quote! { let mut carry2 = 0; }); + for i in 0..num_limbs { + default.extend(quote! { + let k = r[#i].wrapping_mul(Self::INV); + let mut carry = 0; + fa::mac_discard(r[#i], k, #modulus_0, &mut carry); + }); + for j in 1..num_limbs { + let idx = j + i; + let modulus_j = modulus_limbs[j]; + default.extend(quote! { + r[#idx] = fa::mac_with_carry(r[#idx], k, #modulus_j, &mut carry); + }); + } + default.extend(quote! { carry2 = fa::adc(&mut r[#num_limbs + #i], carry, carry2); }); + } + default.extend(quote! { (a.0).0 = r[#num_limbs..].try_into().unwrap(); }); + + if num_limbs == 1 { + // We default to multiplying with `a` using the `Mul` impl + // for the N == 1 case + quote!({ + *a *= *a; + }) + } else if (2..=6).contains(&num_limbs) && can_use_no_carry_mul_opt { + body.extend(quote!({ + if cfg!(all( + feature = "asm", + target_feature = "bmi2", + target_feature = "adx", + target_arch = "x86_64" + )) { + #[cfg( + all( + feature = "asm", + target_feature = "bmi2", + target_feature = "adx", + target_arch = "x86_64" + ) + )] + #[allow(unsafe_code, unused_mut)] + { + ark_ff::x86_64_asm_square!(#num_limbs, (a.0).0); + } + } else { + #[cfg( + not(all( + feature = "asm", + target_feature = "bmi2", + target_feature = "adx", + target_arch = "x86_64" + )) + )] + { + #default + } + } + })); + body.extend(quote!(__subtract_modulus(a);)); + body + } else { + body.extend(quote!( #default )); + if modulus_has_spare_bit { + body.extend(quote!(__subtract_modulus(a);)); + } else { + body.extend(quote!(__subtract_modulus_with_carry(a, carry2 != 0);)); + } + body + } +} diff --git a/ff/src/fields/models/fp/mod.rs b/ff/src/fields/models/fp/mod.rs index ceb108aee..7ddf59b98 100644 --- a/ff/src/fields/models/fp/mod.rs +++ b/ff/src/fields/models/fp/mod.rs @@ -131,6 +131,7 @@ pub type Fp768

= Fp; pub type Fp832

= Fp; impl, const N: usize> Fp { + #[doc(hidden)] #[inline] pub fn is_geq_modulus(&self) -> bool { self.0 >= P::MODULUS diff --git a/test-templates/src/fields.rs b/test-templates/src/fields.rs index f249a8395..dbe7938fd 100644 --- a/test-templates/src/fields.rs +++ b/test-templates/src/fields.rs @@ -209,55 +209,55 @@ macro_rules! __test_field { let mut rng = test_rng(); let zero = <$field>::zero(); let one = <$field>::one(); - assert_eq!(one.inverse().unwrap(), one); - assert!(one.is_one()); + assert_eq!(one.inverse().unwrap(), one, "One inverse failed"); + assert!(one.is_one(), "One is not one"); - assert!(<$field>::ONE.is_one()); - assert_eq!(<$field>::ONE, one); + assert!(<$field>::ONE.is_one(), "One constant is not one"); + assert_eq!(<$field>::ONE, one, "One constant is incorrect"); for _ in 0..ITERATIONS { // Associativity let a = <$field>::rand(&mut rng); let b = <$field>::rand(&mut rng); let c = <$field>::rand(&mut rng); - assert_eq!((a * b) * c, a * (b * c)); + assert_eq!((a * b) * c, a * (b * c), "Associativity failed"); // Commutativity - assert_eq!(a * b, b * a); + assert_eq!(a * b, b * a, "Commutativity failed"); // Identity - assert_eq!(one * a, a); - assert_eq!(one * b, b); - assert_eq!(one * c, c); + assert_eq!(one * a, a, "Identity mul failed"); + assert_eq!(one * b, b, "Identity mul failed"); + assert_eq!(one * c, c, "Identity mul failed"); - assert_eq!(zero * a, zero); - assert_eq!(zero * b, zero); - assert_eq!(zero * c, zero); + assert_eq!(zero * a, zero, "Mul by zero failed"); + assert_eq!(zero * b, zero, "Mul by zero failed"); + assert_eq!(zero * c, zero, "Mul by zero failed"); // Inverses - assert_eq!(a * a.inverse().unwrap(), one); - assert_eq!(b * b.inverse().unwrap(), one); - assert_eq!(c * c.inverse().unwrap(), one); + assert_eq!(a * a.inverse().unwrap(), one, "Mul by inverse failed"); + assert_eq!(b * b.inverse().unwrap(), one, "Mul by inverse failed"); + assert_eq!(c * c.inverse().unwrap(), one, "Mul by inverse failed"); // Associativity and commutativity simultaneously let t0 = (a * b) * c; let t1 = (a * c) * b; let t2 = (b * c) * a; - assert_eq!(t0, t1); - assert_eq!(t1, t2); + assert_eq!(t0, t1, "Associativity + commutativity failed"); + assert_eq!(t1, t2, "Associativity + commutativity failed"); // Squaring - assert_eq!(a * a, a.square()); - assert_eq!(b * b, b.square()); - assert_eq!(c * c, c.square()); + assert_eq!(a * a, a.square(), "Squaring failed"); + assert_eq!(b * b, b.square(), "Squaring failed"); + assert_eq!(c * c, c.square(), "Squaring failed"); // Distributivity - assert_eq!(a * (b + c), a * b + a * c); - assert_eq!(b * (a + c), b * a + b * c); - assert_eq!(c * (a + b), c * a + c * b); - assert_eq!((a + b).square(), a.square() + b.square() + a * b.double()); - assert_eq!((b + c).square(), c.square() + b.square() + c * b.double()); - assert_eq!((c + a).square(), a.square() + c.square() + a * c.double()); + assert_eq!(a * (b + c), a * b + a * c, "Distributivity failed"); + assert_eq!(b * (a + c), b * a + b * c, "Distributivity failed"); + assert_eq!(c * (a + b), c * a + c * b, "Distributivity failed"); + assert_eq!((a + b).square(), a.square() + b.square() + a * b.double(), "Distributivity for square failed"); + assert_eq!((b + c).square(), c.square() + b.square() + c * b.double(), "Distributivity for square failed"); + assert_eq!((c + a).square(), a.square() + c.square() + a * c.double(), "Distributivity for square failed"); } }