diff --git a/CHANGELOG.md b/CHANGELOG.md index aa846a2a1..804735a34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ - [\#593](https://github.com/arkworks-rs/algebra/pull/593) (`ark-ec`) Change `AffineRepr::xy()` to return owned values. - [\#633](https://github.com/arkworks-rs/algebra/pull/633) (`ark-ec`) Generic pairing implementation for the curves from the BW6 family. - [\#659](https://github.com/arkworks-rs/algebra/pull/659) (`ark-ec`) Move auxiliary `parity` function from `ark_ec::hashing::curve_maps::swu` to `ark_ec::hashing::curve_maps`. +- [\#746](https://github.com/arkworks-rs/algebra/pull/746) (`ark-ec`) Refactor fixed-based batch multiplication: + - Move functionality to `ScalarMul::batch_mul` and `ScalarMul::batch_mul_with_preprocessing`. + - Create new struct `BatchMulPreprocessing` for to hold preprocessed powers of `base`. + - Provide high-level constructor `new` that calculates window size and scalar size. + - Provide low-level constructor `with_window_and_scalar_size` that allows setting these parameters. + - Make `windowed_mul` a private method of `BatchMulPreprocessing`. + - Rename `get_mul_window_size` to `compute_window_size` and make it private. - [\#748](https://github.com/arkworks-rs/algebra/pull/748) (`ark-ff`) Add `FromStr` for `BigInteger`. ### Features diff --git a/ec/src/scalar_mul/fixed_base.rs b/ec/src/scalar_mul/fixed_base.rs deleted file mode 100644 index ce8001ccd..000000000 --- a/ec/src/scalar_mul/fixed_base.rs +++ /dev/null @@ -1,98 +0,0 @@ -use ark_ff::{BigInteger, PrimeField}; -use ark_std::{cfg_iter, cfg_iter_mut, vec::Vec}; - -#[cfg(feature = "parallel")] -use rayon::prelude::*; - -use super::ScalarMul; - -pub struct FixedBase; - -impl FixedBase { - pub fn get_mul_window_size(num_scalars: usize) -> usize { - if num_scalars < 32 { - 3 - } else { - super::ln_without_floats(num_scalars) - } - } - - pub fn get_window_table( - scalar_size: usize, - window: usize, - g: T, - ) -> Vec> { - let in_window = 1 << window; - let outerc = (scalar_size + window - 1) / window; - let last_in_window = 1 << (scalar_size - (outerc - 1) * window); - - let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc]; - - let mut g_outer = g; - let mut g_outers = Vec::with_capacity(outerc); - for _ in 0..outerc { - g_outers.push(g_outer); - for _ in 0..window { - g_outer.double_in_place(); - } - } - cfg_iter_mut!(multiples_of_g) - .enumerate() - .take(outerc) - .zip(g_outers) - .for_each(|((outer, multiples_of_g), g_outer)| { - let cur_in_window = if outer == outerc - 1 { - last_in_window - } else { - in_window - }; - - let mut g_inner = T::zero(); - for inner in multiples_of_g.iter_mut().take(cur_in_window) { - *inner = g_inner; - g_inner += &g_outer; - } - }); - cfg_iter!(multiples_of_g) - .map(|s| T::batch_convert_to_mul_base(s)) - .collect() - } - - pub fn windowed_mul( - outerc: usize, - window: usize, - multiples_of_g: &[Vec<::MulBase>], - scalar: &T::ScalarField, - ) -> T { - let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize; - let scalar_val = scalar.into_bigint().to_bits_le(); - - let mut res = T::from(multiples_of_g[0][0]); - for outer in 0..outerc { - let mut inner = 0usize; - for i in 0..window { - if outer * window + i < modulus_size && scalar_val[outer * window + i] { - inner |= 1 << i; - } - } - res += &multiples_of_g[outer][inner]; - } - res - } - - // TODO use const-generics for the scalar size and window - // TODO use iterators of iterators of T::Affine instead of taking owned Vec - pub fn msm( - scalar_size: usize, - window: usize, - table: &[Vec<::MulBase>], - v: &[T::ScalarField], - ) -> Vec { - let outerc = (scalar_size + window - 1) / window; - assert!(outerc <= table.len()); - - cfg_iter!(v) - .map(|e| Self::windowed_mul::(outerc, window, table, e)) - .collect::>() - } -} diff --git a/ec/src/scalar_mul/mod.rs b/ec/src/scalar_mul/mod.rs index 10f0f7109..bcf17843b 100644 --- a/ec/src/scalar_mul/mod.rs +++ b/ec/src/scalar_mul/mod.rs @@ -1,19 +1,22 @@ pub mod glv; pub mod wnaf; -pub mod fixed_base; pub mod variable_base; use crate::{ short_weierstrass::{Affine, Projective, SWCurveConfig}, PrimeGroup, }; -use ark_ff::{AdditiveGroup, Zero}; +use ark_ff::{AdditiveGroup, BigInteger, PrimeField, Zero}; use ark_std::{ + cfg_iter, cfg_iter_mut, ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}, vec::Vec, }; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + /// The result of this function is only approximately `ln(a)` /// [`Explanation of usage`] /// @@ -76,9 +79,139 @@ pub trait ScalarMul: + core::hash::Hash + Mul + for<'a> Mul<&'a Self::ScalarField, Output = Self> - + Neg; + + Neg + + From; const NEGATION_IS_CHEAP: bool; fn batch_convert_to_mul_base(bases: &[Self]) -> Vec; + + /// Compute the vector v[0].G, v[1].G, ..., v[n-1].G, given: + /// - an element `g` + /// - a list `v` of n scalars + /// + /// # Example + /// ``` + /// use ark_std::{One, UniformRand}; + /// use ark_ec::pairing::Pairing; + /// use ark_test_curves::bls12_381::G1Projective as G; + /// use ark_test_curves::bls12_381::Fr; + /// use ark_ec::scalar_mul::ScalarMul; + /// + /// // Compute G, s.G, s^2.G, ..., s^9.G + /// let mut rng = ark_std::test_rng(); + /// let max_degree = 10; + /// let s = Fr::rand(&mut rng); + /// let g = G::rand(&mut rng); + /// let mut powers_of_s = vec![Fr::one()]; + /// let mut cur = s; + /// for _ in 0..max_degree { + /// powers_of_s.push(cur); + /// cur *= &s; + /// } + /// let powers_of_g = g.batch_mul(&powers_of_s); + /// let naive_powers_of_g: Vec = powers_of_s.iter().map(|e| g * e).collect(); + /// assert_eq!(powers_of_g, naive_powers_of_g); + /// ``` + fn batch_mul(self, v: &[Self::ScalarField]) -> Vec { + let table = BatchMulPreprocessing::new(self, v.len()); + self.batch_mul_with_preprocessing(v, &table) + } + + fn batch_mul_with_preprocessing( + self, + v: &[Self::ScalarField], + preprocessing: &BatchMulPreprocessing, + ) -> Vec { + let result = cfg_iter!(v) + .map(|e| preprocessing.windowed_mul(e)) + .collect::>(); + Self::batch_convert_to_mul_base(&result) + } +} + +/// Preprocessing used internally for batch scalar multiplication via [`ScalarMul::batch_mul`]. +/// - `window` is the window size used for the precomputation +/// - `max_scalar_size` is the maximum size of the scalars that will be multiplied +/// - `table` is the precomputed table of multiples of `base` +pub struct BatchMulPreprocessing { + pub window: usize, + pub max_scalar_size: usize, + pub table: Vec>, +} + +impl BatchMulPreprocessing { + pub fn new(base: T, num_scalars: usize) -> Self { + let window = Self::compute_window_size(num_scalars); + let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize; + Self::with_window_and_scalar_size(base, window, scalar_size) + } + + fn compute_window_size(num_scalars: usize) -> usize { + if num_scalars < 32 { + 3 + } else { + ln_without_floats(num_scalars) + } + } + + pub fn with_window_and_scalar_size(base: T, window: usize, max_scalar_size: usize) -> Self { + let in_window = 1 << window; + let outerc = (max_scalar_size + window - 1) / window; + let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window); + + let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc]; + + let mut g_outer = base; + let mut g_outers = Vec::with_capacity(outerc); + for _ in 0..outerc { + g_outers.push(g_outer); + for _ in 0..window { + g_outer.double_in_place(); + } + } + cfg_iter_mut!(multiples_of_g) + .enumerate() + .take(outerc) + .zip(g_outers) + .for_each(|((outer, multiples_of_g), g_outer)| { + let cur_in_window = if outer == outerc - 1 { + last_in_window + } else { + in_window + }; + + let mut g_inner = T::zero(); + for inner in multiples_of_g.iter_mut().take(cur_in_window) { + *inner = g_inner; + g_inner += &g_outer; + } + }); + let table = cfg_iter!(multiples_of_g) + .map(|s| T::batch_convert_to_mul_base(s)) + .collect(); + Self { + window, + max_scalar_size, + table, + } + } + + fn windowed_mul(&self, scalar: &T::ScalarField) -> T { + let outerc = (self.max_scalar_size + self.window - 1) / self.window; + let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize; + let scalar_val = scalar.into_bigint().to_bits_le(); + + let mut res = T::from(self.table[0][0]); + for outer in 0..outerc { + let mut inner = 0usize; + for i in 0..self.window { + if outer * self.window + i < modulus_size && scalar_val[outer * self.window + i] { + inner |= 1 << i; + } + } + res += &self.table[outer][inner]; + } + res + } }