Skip to content

Commit

Permalink
perf: optimize sumcheck terms
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Dec 12, 2024
1 parent be88f56 commit 52c3f90
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 22 deletions.
56 changes: 41 additions & 15 deletions crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use core::ffi::c_void;

use super::{SumcheckRandomScalars, SumcheckSubpolynomial, SumcheckSubpolynomialType};
use super::{
sumcheck_term_optimizer::SumcheckTermOptimizer, SumcheckRandomScalars, SumcheckSubpolynomial,
SumcheckSubpolynomialType,
};
use crate::{
base::{map::IndexMap, polynomial::MultilinearExtension, scalar::Scalar},
proof_primitive::sumcheck::ProverState,
};
use alloc::vec::Vec;
use itertools::Itertools;
use tracing::Level;

#[tracing::instrument(
name = "query_proof::make_sumcheck_prover_state",
Expand All @@ -26,11 +30,19 @@ pub fn make_sumcheck_prover_state<S: Scalar>(
.iter()
.zip(subpolynomials)
.flat_map(|(multiplier, terms)| terms.iter_mul_by(*multiplier));

// Optimization should be very fast. We put this span to double check this. There is almost no copying being done.
let span = tracing::span!(Level::DEBUG, "optimize sumcheck terms").entered();
let optimizer = SumcheckTermOptimizer::new(all_terms, 1 << num_vars);
let optimized_terms = optimizer.terms();
let optimized_term_iter = optimized_terms.into_iter();
span.exit();

let mut builder = FlattenedMLEBuilder::new(
needs_entrywise_multipliers.then(|| scalars.compute_entrywise_multipliers()),
num_vars,
);
let list_of_products = all_terms
let list_of_products = optimized_term_iter
.map(|(ty, coeff, term)| {
(
coeff,
Expand Down Expand Up @@ -136,10 +148,10 @@ mod tests {
assert_eq!(
prover_state.list_of_products,
vec![
(TestScalar::from(101 * 202), vec![1, 0]),
(TestScalar::from(102 * 202), vec![2, 1, 0]),
(TestScalar::from(103 * 203), vec![1]),
(TestScalar::from(104 * 203), vec![2, 1])
(TestScalar::from(104 * 203), vec![2, 1]),
(TestScalar::from(101 * 202), vec![1, 0]),
(TestScalar::from(102 * 202), vec![2, 1, 0])
]
);
assert_eq!(
Expand Down Expand Up @@ -219,18 +231,12 @@ mod tests {
assert_eq!(
prover_state.list_of_products,
vec![
(TestScalar::from(101 * 204), vec![0]),
(TestScalar::from(102 * 204), vec![0]),
(TestScalar::from(103 * 204), vec![1, 0]),
(TestScalar::from(104 * 204), vec![2, 0]),
(TestScalar::from(111 * 207), vec![1, 2]),
(TestScalar::from(112 * 207), vec![3, 2, 4]),
(TestScalar::from(105 * 205), vec![2, 3, 0]),
(TestScalar::from(106 * 205), vec![1, 2, 4, 0]),
(TestScalar::from(107 * 206), vec![]),
(TestScalar::from(108 * 206), vec![]),
(TestScalar::from(109 * 206), vec![3]),
(TestScalar::from(110 * 206), vec![4]),
(TestScalar::from(111 * 207), vec![1, 2]),
(TestScalar::from(112 * 207), vec![3, 2, 4])
(TestScalar::from(1), vec![5]),
(TestScalar::from(1), vec![6, 0]),
]
);
assert_eq!(
Expand All @@ -250,6 +256,26 @@ mod tests {
vec![1, 0, 0, 0, 0, 0, 0, 0],
vec![2, 3, 0, 0, 0, 0, 0, 0],
vec![4, 5, 6, 7, 8, 0, 0, 0],
vec![
107 * 206 + 108 * 206 + 109 * 206 * 2 + 110 * 206 * 4,
107 * 206 + 108 * 206 + 109 * 206 * 3 + 110 * 206 * 5,
107 * 206 + 108 * 206 + 110 * 206 * 6,
107 * 206 + 108 * 206 + 110 * 206 * 7,
107 * 206 + 108 * 206 + 110 * 206 * 8,
107 * 206 + 108 * 206,
107 * 206 + 108 * 206,
107 * 206 + 108 * 206
],
vec![
101 * 204 + 102 * 204 + 104 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204,
101 * 204 + 102 * 204
],
]
.into_iter()
.map(|v| v.into_iter().map(TestScalar::from).collect_vec())
Expand Down
2 changes: 2 additions & 0 deletions crates/proof-of-sql/src/sql/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ pub(crate) use first_round_builder::FirstRoundBuilder;
mod provable_query_result_test;

mod make_sumcheck_state;

mod sumcheck_term_optimizer;
10 changes: 3 additions & 7 deletions crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::base::{polynomial::MultilinearExtension, scalar::Scalar};
use alloc::{boxed::Box, vec::Vec};

/// The type of a sumcheck subpolynomial
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum SumcheckSubpolynomialType {
/// The subpolynomial should be zero at every entry/row
Identity,
Expand Down Expand Up @@ -82,15 +82,11 @@ impl<'a, S: Scalar> SumcheckSubpolynomial<'a, S> {
Item = (
SumcheckSubpolynomialType,
S,
&[Box<dyn MultilinearExtension<S> + 'a>],
&Vec<Box<dyn MultilinearExtension<S> + 'a>>,
),
> {
self.terms.iter().map(move |(coeff, multiplicands)| {
(
self.subpolynomial_type,
multiplier * *coeff,
multiplicands.as_slice(),
)
(self.subpolynomial_type, multiplier * *coeff, multiplicands)
})
}
}
Expand Down
143 changes: 143 additions & 0 deletions crates/proof-of-sql/src/sql/proof/sumcheck_term_optimizer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use super::SumcheckSubpolynomialType;
use crate::base::{map::IndexMap, polynomial::MultilinearExtension, scalar::Scalar};
use alloc::{boxed::Box, vec, vec::Vec};
use core::{
iter::{Chain, Copied, Flatten, Map},
slice,
};

type SumcheckTerm<'a, S> = Vec<Box<dyn MultilinearExtension<S> + 'a>>;

pub struct SumcheckTermOptimizer<'a, S: Scalar> {
merged_terms: Vec<(SumcheckSubpolynomialType, S, Vec<Vec<S>>)>,
old_grouped_terms: Vec<Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>>,
}
pub struct OptimizedSumcheckTerms<'a, S: Scalar> {
old_grouped_terms: &'a Vec<Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>>,
new_mle_terms: Vec<(SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>)>,
}

fn merge_subquadratic_terms<'a, S: Scalar + 'a>(
maybe_constant_terms: Option<Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>>,
maybe_linear_terms: Option<Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>>,
merged_terms: &mut Vec<(SumcheckSubpolynomialType, S, Vec<Vec<S>>)>,
term_length: usize,
ty: SumcheckSubpolynomialType,
) -> Option<Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>> {
let maybe_constant_sum =
maybe_constant_terms.map(|terms| terms.into_iter().map(|(_, coeff, _)| coeff).sum());

match (maybe_constant_sum, maybe_linear_terms) {
(Some(constant_sum), None) => {
merged_terms.push((ty, constant_sum, vec![]));
None
}
(maybe_constant_sum, Some(linear_terms))
if maybe_constant_sum.is_some() || linear_terms.len() >= 2 =>
{
let mut combined_term = vec![maybe_constant_sum.unwrap_or(S::ZERO); term_length];
for (_, coeff, linear_term) in linear_terms {
linear_term[0].mul_add(&mut combined_term, &coeff);
}
merged_terms.push((ty, S::ONE, vec![combined_term]));
None
}
(_, maybe_linear_terms) => maybe_linear_terms,
}
}

impl<'a, S: Scalar + 'a> SumcheckTermOptimizer<'a, S> {
pub fn new(
all_terms: impl Iterator<Item = (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>,
term_length: usize,
) -> Self {
let mut groups = all_terms.fold(
IndexMap::<_, Vec<_>>::default(),
|mut lookup, (ty, coeff, multiplicands)| {
lookup
.entry((ty, multiplicands.len().min(2)))
.or_default()
.push((ty, coeff, multiplicands));
lookup
},
);
let mut merged_terms = Vec::with_capacity(2);
let old_grouped_terms = [
SumcheckSubpolynomialType::ZeroSum,
SumcheckSubpolynomialType::Identity,
]
.into_iter()
.flat_map(|ty| {
let maybe_constant_terms = groups.swap_remove(&(ty, 0));
let maybe_linear_terms = groups.swap_remove(&(ty, 1));
let maybe_superlinear_terms = groups.swap_remove(&(ty, 2));

let maybe_combined_terms = merge_subquadratic_terms(
maybe_constant_terms,
maybe_linear_terms,
&mut merged_terms,
term_length,
ty,
);

[maybe_combined_terms, maybe_superlinear_terms]
.into_iter()
.flatten()
})
.collect();

Self {
merged_terms,
old_grouped_terms,
}
}
}

impl<'a, S: Scalar + 'a> SumcheckTermOptimizer<'a, S> {
pub fn terms(&'a self) -> OptimizedSumcheckTerms<'a, S> {
OptimizedSumcheckTerms {
old_grouped_terms: &self.old_grouped_terms,
new_mle_terms: self
.merged_terms
.iter()
.map(|(ty, coeff, terms)| {
(
*ty,
*coeff,
terms
.iter()
.map(|mle| -> Box<dyn MultilinearExtension<S>> { Box::new(mle) })
.collect::<Vec<_>>(),
)
})
.collect(),
}
}
}

impl<'a, S: Scalar + 'a> IntoIterator for &'a OptimizedSumcheckTerms<'a, S> {
type Item = (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>);

// Currently, `impl Trait` in associated types is unstable. We can change this to the following when it stabilizes:
// type IntoIter = impl Iterator<Item = (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>;
type IntoIter = Chain<
Copied<
Flatten<slice::Iter<'a, Vec<(SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>)>>>,
>,
Map<
slice::Iter<'a, (SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>)>,
fn(
&'a (SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>),
) -> (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>),
>,
>;

fn into_iter(self) -> Self::IntoIter {
let result = self.old_grouped_terms.iter().flatten().copied().chain(
self.new_mle_terms
.iter()
.map((|(ty, coeff, terms)| (*ty, *coeff, terms)) as fn(&'a _) -> _),
);
result
}
}

0 comments on commit 52c3f90

Please sign in to comment.