Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FixedBase refactor: rename msm -> mul, declutter the interface and hide aux methods #746

Merged
merged 13 commits into from
Jan 14, 2024
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
- [\#593](https://github.com/arkworks-rs/algebra/pull/593) (`ark-ec`) Change `AffineRepr::xy()` to return owned values.
- [\#633](https://github.com/arkworks-rs/algebra/pull/633) (`ark-ec`) Generic pairing implementation for the curves from the BW6 family.
- [\#659](https://github.com/arkworks-rs/algebra/pull/659) (`ark-ec`) Move auxiliary `parity` function from `ark_ec::hashing::curve_maps::swu` to `ark_ec::hashing::curve_maps`.
- [\#746](https://github.com/arkworks-rs/algebra/pull/746) (`ark-ec`) Refactor fixed-based batch multiplication:
- Move functionality to `ScalarMul::batch_mul` and `ScalarMul::batch_mul_with_preprocessing`.
- Create new struct `BatchMulPreprocessing` for to hold preprocessed powers of `base`.
- Provide high-level constructor `new` that calculates window size and scalar size.
- Provide low-level constructor `with_window_and_scalar_size` that allows setting these parameters.
- Make `windowed_mul` a private method of `BatchMulPreprocessing`.
- Rename `get_mul_window_size` to `compute_window_size` and make it private.
- [\#748](https://github.com/arkworks-rs/algebra/pull/748) (`ark-ff`) Add `FromStr` for `BigInteger`.

### Features
Expand Down
98 changes: 0 additions & 98 deletions ec/src/scalar_mul/fixed_base.rs

This file was deleted.

139 changes: 136 additions & 3 deletions ec/src/scalar_mul/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
pub mod glv;
pub mod wnaf;

pub mod fixed_base;
pub mod variable_base;

use crate::{
short_weierstrass::{Affine, Projective, SWCurveConfig},
PrimeGroup,
};
use ark_ff::{AdditiveGroup, Zero};
use ark_ff::{AdditiveGroup, BigInteger, PrimeField, Zero};
use ark_std::{
cfg_iter, cfg_iter_mut,
ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
vec::Vec,
};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// The result of this function is only approximately `ln(a)`
/// [`Explanation of usage`]
///
Expand Down Expand Up @@ -76,9 +79,139 @@ pub trait ScalarMul:
+ core::hash::Hash
+ Mul<Self::ScalarField, Output = Self>
+ for<'a> Mul<&'a Self::ScalarField, Output = Self>
+ Neg<Output = Self::MulBase>;
+ Neg<Output = Self::MulBase>
+ From<Self>;

const NEGATION_IS_CHEAP: bool;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase>;

/// Compute the vector v[0].G, v[1].G, ..., v[n-1].G, given:
/// - an element `g`
/// - a list `v` of n scalars
///
/// # Example
/// ```
/// use ark_std::{One, UniformRand};
/// use ark_ec::pairing::Pairing;
/// use ark_test_curves::bls12_381::G1Projective as G;
/// use ark_test_curves::bls12_381::Fr;
/// use ark_ec::scalar_mul::ScalarMul;
///
/// // Compute G, s.G, s^2.G, ..., s^9.G
/// let mut rng = ark_std::test_rng();
/// let max_degree = 10;
/// let s = Fr::rand(&mut rng);
/// let g = G::rand(&mut rng);
/// let mut powers_of_s = vec![Fr::one()];
/// let mut cur = s;
/// for _ in 0..max_degree {
/// powers_of_s.push(cur);
/// cur *= &s;
/// }
/// let powers_of_g = g.batch_mul(&powers_of_s);
/// let naive_powers_of_g: Vec<G> = powers_of_s.iter().map(|e| g * e).collect();
/// assert_eq!(powers_of_g, naive_powers_of_g);
/// ```
fn batch_mul(self, v: &[Self::ScalarField]) -> Vec<Self::MulBase> {
let table = BatchMulPreprocessing::new(self, v.len());
self.batch_mul_with_preprocessing(v, &table)
}

fn batch_mul_with_preprocessing(
self,
v: &[Self::ScalarField],
preprocessing: &BatchMulPreprocessing<Self>,
) -> Vec<Self::MulBase> {
let result = cfg_iter!(v)
.map(|e| preprocessing.windowed_mul(e))
.collect::<Vec<_>>();
Self::batch_convert_to_mul_base(&result)
}
}

/// Preprocessing used internally for batch scalar multiplication via [`ScalarMul::batch_mul`].
/// - `window` is the window size used for the precomputation
/// - `max_scalar_size` is the maximum size of the scalars that will be multiplied
/// - `table` is the precomputed table of multiples of `base`
pub struct BatchMulPreprocessing<T: ScalarMul> {
pub window: usize,
pub max_scalar_size: usize,
pub table: Vec<Vec<T::MulBase>>,
}

impl<T: ScalarMul> BatchMulPreprocessing<T> {
pub fn new(base: T, num_scalars: usize) -> Self {
let window = Self::compute_window_size(num_scalars);
let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
Self::with_window_and_scalar_size(base, window, scalar_size)
}

fn compute_window_size(num_scalars: usize) -> usize {
if num_scalars < 32 {
3
} else {
ln_without_floats(num_scalars)
}
}

pub fn with_window_and_scalar_size(base: T, window: usize, max_scalar_size: usize) -> Self {
let in_window = 1 << window;
let outerc = (max_scalar_size + window - 1) / window;
let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window);

let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc];

let mut g_outer = base;
let mut g_outers = Vec::with_capacity(outerc);
for _ in 0..outerc {
g_outers.push(g_outer);
for _ in 0..window {
g_outer.double_in_place();
}
}
cfg_iter_mut!(multiples_of_g)
.enumerate()
.take(outerc)
.zip(g_outers)
.for_each(|((outer, multiples_of_g), g_outer)| {
let cur_in_window = if outer == outerc - 1 {
last_in_window
} else {
in_window
};

let mut g_inner = T::zero();
for inner in multiples_of_g.iter_mut().take(cur_in_window) {
*inner = g_inner;
g_inner += &g_outer;
}
});
let table = cfg_iter!(multiples_of_g)
.map(|s| T::batch_convert_to_mul_base(s))
.collect();
Self {
window,
max_scalar_size,
table,
}
}

fn windowed_mul(&self, scalar: &T::ScalarField) -> T {
let outerc = (self.max_scalar_size + self.window - 1) / self.window;
let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
let scalar_val = scalar.into_bigint().to_bits_le();

let mut res = T::from(self.table[0][0]);
for outer in 0..outerc {
let mut inner = 0usize;
for i in 0..self.window {
if outer * self.window + i < modulus_size && scalar_val[outer * self.window + i] {
inner |= 1 << i;
}
}
res += &self.table[outer][inner];
}
res
}
}
Loading