Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup dense poly #151

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/vm/instruction_lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ where
fn batch(&self) -> Self::BatchedPolynomials {
use rayon::prelude::*;
let (batched_dim_read, (batched_final, batched_E, batched_flag)) = rayon::join(
|| DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
|| DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
|| {
let batched_final = DensePolynomial::merge(&self.final_cts);
let (batched_E, batched_flag) = rayon::join(
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/lasso/surge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ where
#[tracing::instrument(skip_all, name = "SurgePolys::batch")]
fn batch(&self) -> Self::BatchedPolynomials {
let (batched_dim_read, (batched_final, batched_E)) = rayon::join(
|| DensePolynomial::merge_dual(self.dim.as_ref(), self.read_cts.as_ref()),
|| DensePolynomial::merge(self.dim.iter().chain(&self.read_cts)),
|| {
rayon::join(
|| DensePolynomial::merge(&self.final_cts),
Expand Down
83 changes: 83 additions & 0 deletions jolt-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,86 @@ fn ln_without_floats(a: usize) -> usize {
// log2(a) * ln(2)
(ark_std::log2(a) * 69 / 100) as usize
}

/// Special MSM where all scalar values are 0 / 1 – does not verify.
pub(crate) fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
assert_eq!(scalars.len(), bases.len());
let result = scalars
.into_iter()
.enumerate()
.filter(|(_index, scalar)| !scalar.is_zero())
.map(|(index, scalar)| bases[index])
.sum();

result
}

pub(crate) fn sm_msm<V: VariableBaseMSM>(
scalars: &[<V::ScalarField as PrimeField>::BigInt],
bases: &[V::MulBase],
) -> V {
assert_eq!(scalars.len(), bases.len());
let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}

// #[cfg(test)]
// scalars.for_each(|scalar| {
// assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
// });

// Assign things to buckets based on the scalar
let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
scalars.into_iter().enumerate().for_each(|(index, scalar)| {
let bucket_index: u64 = scalar.as_ref()[0];
buckets[bucket_index as usize] += bases[index];
});

let mut result = V::zero();
let mut running_sum = V::zero();
buckets
.into_iter()
.skip(1)
.enumerate()
.rev()
.for_each(|(index, bucket)| {
running_sum += bucket;
result += running_sum;
});
result
}

#[cfg(test)]
mod tests {

use ark_std::test_rng;

use crate::poly::dense_mlpoly::DensePolynomial;

use super::*;

#[test]
fn sm_msm_parity() {
use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
let mut rng = test_rng();
let bases = vec![
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
];
let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
let expected_result_b =
bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
assert_eq!(expected_result, expected_result_b);

let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
assert_eq!(calc_result_a, expected_result);

let scalars_bigint: Vec<_> = scalars
.into_iter()
.map(|scalar| scalar.into_bigint())
.collect();
let calc_result_b: G1Projective = sm_msm(&scalars_bigint, &bases);
assert_eq!(calc_result_b, expected_result);
}
}
119 changes: 10 additions & 109 deletions jolt-core/src/poly/dense_mlpoly.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(clippy::too_many_arguments)]
use crate::msm::{flags_msm, sm_msm};
use crate::poly::eq_poly::EqPolynomial;
use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized, mul_0_1_optimized};

Expand Down Expand Up @@ -219,63 +220,17 @@ impl<F: PrimeField> DensePolynomial<F> {
let scalars = self.Z[R_size * i..R_size * (i + 1)].as_ref();
match hint {
CommitHint::Normal => Commitments::batch_commit(scalars, &gens),
CommitHint::Flags => Self::flags_msm(scalars, &gens),
CommitHint::Flags => flags_msm(scalars, &gens),
CommitHint::Small => {
let bigints: Vec<_> = scalars.iter().map(|s| s.into_bigint()).collect();
Self::sm_msm(&bigints, &gens)
sm_msm(&bigints, &gens)
}
}
})
.collect();
PolyCommitment { C }
}

/// Special MSM where all scalar values are 0 / 1 – does not verify.
fn flags_msm<G: CurveGroup>(scalars: &[G::ScalarField], bases: &[G::Affine]) -> G {
assert_eq!(scalars.len(), bases.len());
let result = scalars
.into_iter()
.enumerate()
.filter(|(_index, scalar)| !scalar.is_zero())
.map(|(index, scalar)| bases[index])
.sum();

result
}

pub fn sm_msm<V: VariableBaseMSM>(
scalars: &[<V::ScalarField as PrimeField>::BigInt],
bases: &[V::MulBase],
) -> V {
assert_eq!(scalars.len(), bases.len());
let num_buckets: usize = 1 << 16; // TODO(sragss): This should be passed in / dependent on M = N^{1/C}

// #[cfg(test)]
// scalars.for_each(|scalar| {
// assert!(scalar < V::ScalarField::from(num_buckets as u64).into_bigint())
// });

// Assign things to buckets based on the scalar
let mut buckets: Vec<V> = vec![V::zero(); num_buckets];
scalars.into_iter().enumerate().for_each(|(index, scalar)| {
let bucket_index: u64 = scalar.as_ref()[0];
buckets[bucket_index as usize] += bases[index];
});

let mut result = V::zero();
let mut running_sum = V::zero();
buckets
.into_iter()
.skip(1)
.enumerate()
.rev()
.for_each(|(index, bucket)| {
running_sum += bucket;
result += running_sum;
});
result
}

#[tracing::instrument(skip_all, name = "DensePolynomial.bound")]
pub fn bound(&self, L: &[F]) -> Vec<F> {
let (left_num_vars, right_num_vars) =
Expand Down Expand Up @@ -442,24 +397,15 @@ impl<F: PrimeField> DensePolynomial<F> {
self.Z.as_ref()
}

pub fn extend(&mut self, other: &DensePolynomial<F>) {
assert_eq!(self.Z.len(), self.len);
let other_vec = other.vec();
assert_eq!(other_vec.len(), self.len);
self.Z.extend(other_vec);
self.num_vars += 1;
self.len *= 2;
assert_eq!(self.Z.len(), self.len);
}

#[tracing::instrument(skip_all, name = "DensePoly.merge")]
pub fn merge<T>(polys: &[T]) -> DensePolynomial<F>
where
T: AsRef<DensePolynomial<F>>,
{
let total_len: usize = polys.iter().map(|poly| poly.as_ref().vec().len()).sum();
pub fn merge(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> DensePolynomial<F> {
let polys_iter_cloned = polys.clone().into_iter();
let total_len: usize = polys
.into_iter()
.map(|poly| poly.as_ref().vec().len())
.sum();
let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
for poly in polys {
for poly in polys_iter_cloned {
Z.extend_from_slice(poly.as_ref().vec());
}

Expand All @@ -469,25 +415,6 @@ impl<F: PrimeField> DensePolynomial<F> {
DensePolynomial::new(Z)
}

#[tracing::instrument(skip_all, name = "DensePoly.merge_dual")]
pub fn merge_dual<T>(polys_a: &[T], polys_b: &[T]) -> DensePolynomial<F>
where
T: AsRef<DensePolynomial<F>>,
{
let total_len_a: usize = polys_a.iter().map(|poly| poly.as_ref().len()).sum();
let total_len_b: usize = polys_b.iter().map(|poly| poly.as_ref().len()).sum();
let total_len = total_len_a + total_len_b;

let mut Z: Vec<F> = Vec::with_capacity(total_len.next_power_of_two());
polys_a.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));
polys_b.iter().for_each(|poly| Z.extend_from_slice(poly.as_ref().vec()));

// pad the polynomial with zero polynomial at the end
Z.resize(Z.capacity(), F::zero());

DensePolynomial::new(Z)
}

pub fn combined_commit<G>(
&self,
label: &'static [u8],
Expand Down Expand Up @@ -1024,32 +951,6 @@ mod tests {
);
}

#[test]
fn sm_msm_parity() {
use ark_curve25519::{EdwardsAffine as G1Affine, EdwardsProjective as G1Projective, Fr};
let mut rng = test_rng();
let bases = vec![
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
G1Affine::rand(&mut rng),
];
let scalars = vec![Fr::from(3), Fr::from(2), Fr::from(1)];
let expected_result = bases[0] + bases[0] + bases[0] + bases[1] + bases[1] + bases[2];
assert_eq!(bases[0] + bases[0] + bases[0], bases[0] * scalars[0]);
let expected_result_b =
bases[0] * scalars[0] + bases[1] * scalars[1] + bases[2] * scalars[2];
assert_eq!(expected_result, expected_result_b);

let calc_result_a: G1Projective = VariableBaseMSM::msm(&bases, &scalars).unwrap();
assert_eq!(calc_result_a, expected_result);

let scalars_bigint: Vec<_> = scalars
.into_iter()
.map(|scalar| scalar.into_bigint())
.collect();
let calc_result_b: G1Projective = DensePolynomial::<Fr>::sm_msm(&scalars_bigint, &bases);
assert_eq!(calc_result_b, expected_result);
}

#[test]
fn commit_with_hint_parity() {
Expand Down