diff --git a/jolt-core/src/jolt/vm/instruction_lookups.rs b/jolt-core/src/jolt/vm/instruction_lookups.rs
index af5f2ea2d..ab5a3026f 100644
--- a/jolt-core/src/jolt/vm/instruction_lookups.rs
+++ b/jolt-core/src/jolt/vm/instruction_lookups.rs
@@ -122,7 +122,7 @@ where
     fn batch(&self) -> Self::BatchedPolynomials {
         use rayon::prelude::*;
         let (batched_dim_read, (batched_final, batched_E, batched_flag)) = rayon::join(
-            || DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
+            || DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
             || {
                 let batched_final = DensePolynomial::merge(&self.final_cts);
                 let (batched_E, batched_flag) = rayon::join(
diff --git a/jolt-core/src/lasso/surge.rs b/jolt-core/src/lasso/surge.rs
index 941e9dd23..1769e833f 100644
--- a/jolt-core/src/lasso/surge.rs
+++ b/jolt-core/src/lasso/surge.rs
@@ -65,7 +65,7 @@ where
     #[tracing::instrument(skip_all, name = "SurgePolys::batch")]
     fn batch(&self) -> Self::BatchedPolynomials {
         let (batched_dim_read, (batched_final, batched_E)) = rayon::join(
-            || DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
+            || DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
             || {
                 rayon::join(
                     || DensePolynomial::merge(&self.final_cts),
diff --git a/jolt-core/src/msm/mod.rs b/jolt-core/src/msm/mod.rs
index 6a76ec523..50c40c1c9 100644
--- a/jolt-core/src/msm/mod.rs
+++ b/jolt-core/src/msm/mod.rs
@@ -328,3 +328,86 @@ fn ln_without_floats(a: usize) -> usize {
     // log2(a) * ln(2)
     (ark_std::log2(a) * 69 / 100) as usize
 }
+
+/// Special MSM where all scalar values are 0 / 1 – the 0/1 assumption is not checked.
+pub(crate) fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
+    assert_eq!(scalars.len(), bases.len());
+    // Sum exactly the bases whose flag is non-zero; no scalar multiplications are needed.
+    let result = scalars
+        .into_iter()
+        .enumerate()
+        .filter(|(_index, scalar)| !scalar.is_zero())
+        .map(|(index, _scalar)| bases[index])
+        .sum();
+
+    result
+}
+
+/// MSM specialized to small scalars (assumed to fit in a single limb and be < 2^16):
+/// bases are grouped into one bucket per scalar value, then aggregated via suffix sums.
+pub(crate) fn sm_msm<V: VariableBaseMSM>(
+    scalars: &[<V::ScalarField as PrimeField>::BigInt],
+    bases: &[V::MulBase],
+) -> V {
+    assert_eq!(scalars.len(), bases.len());
+    let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}
+
+    // #[cfg(test)]
+    // scalars.for_each(|scalar| {
+    //     assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
+    // });
+
+    // Assign things to buckets based on the scalar
+    let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
+    scalars.into_iter().enumerate().for_each(|(index, scalar)| {
+        let bucket_index: u64 = scalar.as_ref()[0];
+        buckets[bucket_index as usize] += bases[index];
+    });
+
+    // result = sum_j j * buckets[j], accumulated as a running suffix sum from the
+    // highest bucket downwards; bucket 0 is skipped since it contributes nothing.
+    let mut result = V::zero();
+    let mut running_sum = V::zero();
+    buckets
+        .into_iter()
+        .skip(1)
+        .enumerate()
+        .rev()
+        .for_each(|(_index, bucket)| {
+            running_sum += bucket;
+            result += running_sum;
+        });
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use ark_std::{test_rng, UniformRand};
+
+    use super::*;
+
+    #[test]
+    fn sm_msm_parity() {
+        use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
+        let mut rng = test_rng();
+        let bases = vec![
+            G1Affine::rand(&mut rng),
+            G1Affine::rand(&mut rng),
+            G1Affine::rand(&mut rng),
+        ];
+        let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
+        let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
+        assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
+        let expected_result_b =
+            bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
+        assert_eq!(expected_result, expected_result_b);
+
+        let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
+        assert_eq!(calc_result_a, expected_result);
+
+        let scalars_bigint: Vec<_> = scalars
+            .into_iter()
+            .map(|scalar| scalar.into_bigint())
+            .collect();
+        let calc_result_b: G1Projective = sm_msm(&scalars_bigint, &bases);
+        assert_eq!(calc_result_b, expected_result);
+    }
+}
diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs
index e4da0ac33..884c34e15 100644
--- a/jolt-core/src/poly/dense_mlpoly.rs
+++ b/jolt-core/src/poly/dense_mlpoly.rs
@@ -1,4 +1,5 @@
 #![allow(clippy::too_many_arguments)]
+use crate::msm::{flags_msm, sm_msm};
 use crate::poly::eq_poly::EqPolynomial;
 use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized, mul_0_1_optimized};
@@ -219,10 +220,10 @@ impl<F: PrimeField> DensePolynomial<F> {
             let scalars = self.Z[R_size * i..R_size * (i + 1)].as_ref();
             match hint {
                 CommitHint::Normal => Commitments::batch_commit(scalars, &gens),
-                CommitHint::Flags => Self::flags_msm(scalars, &gens),
+                CommitHint::Flags => flags_msm(scalars, &gens),
                 CommitHint::Small => {
                     let bigints: Vec<_> = scalars.iter().map(|s| s.into_bigint()).collect();
-                    Self::sm_msm(&bigints, &gens)
+                    sm_msm(&bigints, &gens)
                 }
             }
         })
@@ -230,52 +231,6 @@ impl<F: PrimeField> DensePolynomial<F> {
         PolyCommitment { C }
     }
 
-    /// Special MSM where all scalar values are 0 / 1 – does not verify.
-    fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
-        assert_eq!(scalars.len(), bases.len());
-        let result = scalars
-            .into_iter()
-            .enumerate()
-            .filter(|(_index, scalar)| !scalar.is_zero())
-            .map(|(index, scalar)| bases[index])
-            .sum();
-
-        result
-    }
-
-    pub fn sm_msm<V: VariableBaseMSM>(
-        scalars: &[<V::ScalarField as PrimeField>::BigInt],
-        bases: &[V::MulBase],
-    ) -> V {
-        assert_eq!(scalars.len(), bases.len());
-        let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}
-
-        // #[cfg(test)]
-        // scalars.for_each(|scalar| {
-        //     assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
-        // });
-
-        // Assign things to buckets based on the scalar
-        let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
-        scalars.into_iter().enumerate().for_each(|(index, scalar)| {
-            let bucket_index: u64 = scalar.as_ref()[0];
-            buckets[bucket_index as usize] += bases[index];
-        });
-
-        let mut result = V::zero();
-        let mut running_sum = V::zero();
-        buckets
-            .into_iter()
-            .skip(1)
-            .enumerate()
-            .rev()
-            .for_each(|(index, bucket)| {
-                running_sum += bucket;
-                result += running_sum;
-            });
-        result
-    }
-
     #[tracing::instrument(skip_all, name = "DensePolynomial.bound")]
     pub fn bound(&self, L: &[F]) -> Vec<F> {
         let (left_num_vars, right_num_vars) =
@@ -442,24 +397,15 @@ impl<F: PrimeField> DensePolynomial<F> {
         self.Z.as_ref()
     }
 
-    pub fn extend(&mut self, other: &DensePolynomial<F>) {
-        assert_eq!(self.Z.len(), self.len);
-        let other_vec = other.vec();
-        assert_eq!(other_vec.len(), self.len);
-        self.Z.extend(other_vec);
-        self.num_vars += 1;
-        self.len *= 2;
-        assert_eq!(self.Z.len(), self.len);
-    }
-
     #[tracing::instrument(skip_all, name = "DensePoly.merge")]
-    pub fn merge<T>(polys: &[T]) -> DensePolynomial<F>
-    where
-        T: AsRef<DensePolynomial<F>>,
-    {
-        let total_len: usize = polys.iter().map(|poly| poly.as_ref().vec().len()).sum();
+    pub fn merge(polys: impl IntoIterator<Item = impl AsRef<DensePolynomial<F>>> + Clone) -> DensePolynomial<F> {
+        let polys_iter_cloned = polys.clone().into_iter();
+        let total_len: usize = polys
+            .into_iter()
+            .map(|poly| poly.as_ref().vec().len())
+            .sum();
         let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
-        for poly in polys {
+        for poly in polys_iter_cloned {
             Z.extend_from_slice(poly.as_ref().vec());
         }
 
@@ -469,25 +415,6 @@ impl<F: PrimeField> DensePolynomial<F> {
         DensePolynomial::new(Z)
     }
 
-    #[tracing::instrument(skip_all, name = "DensePoly.merge_dual")]
-    pub fn merge_dual<T>(polys_a: &[T], polys_b: &[T]) -> DensePolynomial<F>
-    where
-        T: AsRef<DensePolynomial<F>>,
-    {
-        let total_len_a: usize = polys_a.iter().map(|poly| poly.as_ref().len()).sum();
-        let total_len_b: usize = polys_b.iter().map(|poly| poly.as_ref().len()).sum();
-        let total_len = total_len_a + total_len_b;
-
-        let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
-        polys_a.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));
-        polys_b.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));
-
-        // pad the polynomial with zero polynomial at the end
-        Z.resize(Z.capacity(), F::zero());
-
-        DensePolynomial::new(Z)
-    }
-
     pub fn combined_commit(
         &self,
         label: &'static [u8],
@@ -1024,32 +951,6 @@ mod tests {
         );
     }
 
-    #[test]
-    fn sm_msm_parity() {
-        use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
-        let mut rng = test_rng();
-        let bases = vec![
-            G1Affine::rand(&mut rng),
-            G1Affine::rand(&mut rng),
-            G1Affine::rand(&mut rng),
-        ];
-        let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
-        let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
-        assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
-        let expected_result_b =
-            bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
-        assert_eq!(expected_result, expected_result_b);
-
-        let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
-        assert_eq!(calc_result_a, expected_result);
-
-        let scalars_bigint: Vec<_> = scalars
-            .into_iter()
-            .map(|scalar| scalar.into_bigint())
-            .collect();
-        let calc_result_b: G1Projective = DensePolynomial::<Fr>::sm_msm(&scalars_bigint, &bases);
-        assert_eq!(calc_result_b, expected_result);
-    }
-
     #[test]
     fn commit_with_hint_parity() {
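
Aside on the bucket step in `sm_msm` above: it relies on the identity sum_j j * B_j = sum over k >= 1 of (B_k + B_{k+1} + ... + B_max), so each bucket is folded into a running suffix sum that is added back once per index. Below is a minimal standalone sketch of that aggregation with plain integers standing in for group elements; the helper name `bucket_aggregate` is illustrative only and not part of this patch.

// Illustrative only: the same suffix-sum aggregation as in `sm_msm`,
// with u64 values playing the role of curve points.
fn bucket_aggregate(buckets: &[u64]) -> u64 {
    let mut result = 0u64;
    let mut running_sum = 0u64;
    // Walk from the highest bucket down to bucket 1; bucket 0 contributes nothing.
    for bucket in buckets.iter().skip(1).rev() {
        running_sum += bucket;
        result += running_sum;
    }
    result
}

fn main() {
    // bucket[j] stands in for "sum of all bases whose scalar equals j".
    let buckets = [5u64, 7, 0, 3];
    let naive: u64 = buckets.iter().enumerate().map(|(j, b)| j as u64 * b).sum();
    assert_eq!(bucket_aggregate(&buckets), naive); // 1*7 + 2*0 + 3*3 = 16
    println!("suffix-sum aggregation matches the naive weighted sum: {naive}");
}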