From 6b8889662665e26c554e2db3f324b83a7a480716 Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 18 Apr 2023 19:09:28 -0500 Subject: [PATCH] Performance: use fewer operations in generate sum loop (#71) --- CHANGELOG.md | 2 ++ src/ml_sumcheck/protocol/prover.rs | 24 ++++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ceb772..0f7725b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ ### Improvements +- [\#71](https://github.com/arkworks-rs/sumcheck/pull/71) Improve prover performance by using an arithmetic sequence rather than interpolation inside of the `prove_round` loop. + - [\#55](https://github.com/arkworks-rs/sumcheck/pull/55) Improve the interpolation performance and avoid unnecessary state clones. ### Bug fixes diff --git a/src/ml_sumcheck/protocol/prover.rs b/src/ml_sumcheck/protocol/prover.rs index 506cba6..721cc77 100644 --- a/src/ml_sumcheck/protocol/prover.rs +++ b/src/ml_sumcheck/protocol/prover.rs @@ -98,21 +98,25 @@ impl IPForMLSumcheck { let mut products_sum = Vec::with_capacity(degree + 1); products_sum.resize(degree + 1, F::zero()); + let mut product = Vec::with_capacity(degree + 1); + product.resize(degree + 1, F::zero()); // generate sum for b in 0..1 << (nv - i) { - let mut t_as_field = F::zero(); - for old_product in products_sum.iter_mut().take(degree + 1) { - for (coefficient, products) in &prover_state.list_of_products { - let mut product = *coefficient; - for &jth_product in products { - let table = &prover_state.flattened_ml_extensions[jth_product]; - product *= table[b << 1] * (F::one() - t_as_field) - + table[(b << 1) + 1] * t_as_field; + for (coefficient, products) in &prover_state.list_of_products { + product.fill(*coefficient); + for &jth_product in products { + let table = &prover_state.flattened_ml_extensions[jth_product]; + let mut start = table[b << 1]; + let step = table[(b << 1) + 1] - start; + for p in product.iter_mut() { + *p *= start; + start += step; } - *old_product += product; } - t_as_field += F::one(); + for t in 0..degree + 1 { + products_sum[t] += product[t]; + } } }