Skip to content

Commit

Permalink
perf!: use barycentric evaluation in FRI
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Apr 23, 2024
2 parents 6b8ffd9 + 0fc7b7f commit 991688a
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 9 deletions.
3 changes: 3 additions & 0 deletions triton-vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ pub enum FriValidationError {
#[error("computed and received codeword of last round do not match")]
LastCodewordMismatch,

#[error("evaluations of last round's polynomial and last round codeword do not match")]
LastRoundPolynomialEvaluationMismatch,

#[error("last round's polynomial has too high degree")]
LastRoundPolynomialHasTooHighDegree,

Expand Down
127 changes: 118 additions & 9 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::marker::PhantomData;

use itertools::Itertools;
use num_traits::Zero;
use rayon::iter::*;
use twenty_first::math::traits::FiniteField;
use twenty_first::math::traits::PrimitiveRootOfUnity;
use twenty_first::prelude::*;

use crate::arithmetic_domain::ArithmeticDomain;
Expand Down Expand Up @@ -53,6 +55,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
self.commit_to_next_round()?;
}
self.send_last_codeword();
self.send_last_polynomial();
Ok(())
}

Expand Down Expand Up @@ -94,6 +97,15 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
self.proof_stream.enqueue(proof_item);
}

fn send_last_polynomial(&mut self) {
let last_codeword = &self.rounds.last().unwrap().codeword;
let last_polynomial = ArithmeticDomain::of_length(last_codeword.len())
.unwrap()
.interpolate(last_codeword);
let proof_item = ProofItem::FriPolynomial(last_polynomial.coefficients);
self.proof_stream.enqueue(proof_item);
}

fn query(&mut self) -> ProverResult<()> {
self.sample_first_round_collinearity_check_indices();

Expand Down Expand Up @@ -193,6 +205,7 @@ struct FriVerifier<'stream, H: AlgebraicHasher> {
rounds: Vec<VerifierRound>,
first_round_domain: ArithmeticDomain,
last_round_codeword: Vec<XFieldElement>,
last_round_polynomial: Polynomial<XFieldElement>,
last_round_max_degree: usize,
num_rounds: usize,
num_collinearity_checks: usize,
Expand All @@ -211,7 +224,8 @@ struct VerifierRound {
impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
fn initialize(&mut self) -> VerifierResult<()> {
self.initialize_verification_rounds()?;
self.receive_last_round_codeword()
self.receive_last_round_codeword()?;
self.receive_last_round_polynomial()
}

fn initialize_verification_rounds(&mut self) -> VerifierResult<()> {
Expand Down Expand Up @@ -288,6 +302,12 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
Ok(())
}

fn receive_last_round_polynomial(&mut self) -> VerifierResult<()> {
let coefficients = self.proof_stream.dequeue()?.try_into_fri_polynomial()?;
self.last_round_polynomial = Polynomial::new(coefficients);
Ok(())
}

fn compute_last_round_folded_partial_codeword(&mut self) -> VerifierResult<()> {
self.sample_first_round_collinearity_check_indices();
self.receive_authentic_partially_revealed_codewords()?;
Expand Down Expand Up @@ -450,7 +470,7 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
.collect()
}

fn authenticate_last_round_codeword(&self) -> VerifierResult<()> {
fn authenticate_last_round_codeword(&mut self) -> VerifierResult<()> {
self.assert_last_round_codeword_matches_last_round_commitment()?;
self.assert_last_round_codeword_agrees_with_last_round_folded_codeword()?;
self.assert_last_round_codeword_corresponds_to_low_degree_polynomial()
Expand Down Expand Up @@ -500,19 +520,20 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {
}

fn assert_last_round_codeword_corresponds_to_low_degree_polynomial(
&self,
&mut self,
) -> VerifierResult<()> {
if self.last_round_polynomial().degree() > self.last_round_max_degree as isize {
let indeterminate = self.proof_stream.sample_scalars(1)[0];
let horner_evaluation = self.last_round_polynomial.evaluate(indeterminate);
let barycentric_evaluation = barycentric_evaluate(&self.last_round_codeword, indeterminate);
if horner_evaluation != barycentric_evaluation {
return Err(LastRoundPolynomialEvaluationMismatch);
}
if self.last_round_polynomial.degree() > self.last_round_max_degree.try_into().unwrap() {
return Err(LastRoundPolynomialHasTooHighDegree);
}
Ok(())
}

fn last_round_polynomial(&self) -> Polynomial<XFieldElement> {
let domain = self.rounds.last().unwrap().domain;
domain.interpolate(&self.last_round_codeword)
}

fn first_round_partially_revealed_codeword(&self) -> Vec<(usize, XFieldElement)> {
let partial_codeword_a = self.rounds[0].partial_codeword_a.clone();
let partial_codeword_b = self.rounds[0].partial_codeword_b.clone();
Expand Down Expand Up @@ -571,6 +592,12 @@ impl<H: AlgebraicHasher> Fri<H> {
prover.commit(codeword)?;
prover.query()?;

// Sample one XFieldElement from Fiat-Shamir and then throw it away. This
// scalar is the indeterminate for the low degree test using the barycentric
// evaluation formula. This indeterminate is used only by the verifier, but
// it is important to modify the sponge state the same way.
prover.proof_stream.sample_scalars(1);

let indices = prover.all_top_level_collinearity_check_indices();
Ok(indices)
}
Expand Down Expand Up @@ -615,6 +642,7 @@ impl<H: AlgebraicHasher> Fri<H> {
rounds: vec![],
first_round_domain: self.domain,
last_round_codeword: vec![],
last_round_polynomial: Polynomial::zero(),
last_round_max_degree: self.last_round_max_degree(),
num_rounds: self.num_rounds(),
num_collinearity_checks: self.num_collinearity_checks,
Expand Down Expand Up @@ -650,6 +678,46 @@ fn codeword_as_digests(codeword: &[XFieldElement]) -> Vec<Digest> {
codeword.par_iter().map(|&xfe| xfe.into()).collect()
}

/// Use the barycentric Lagrange evaluation formula to extrapolate the codeword
/// to an out-of-domain location.
///
/// [Credit] for (re)discovering this formula and especially its application to
/// FRI goes to Al-Kindi.
///
/// # Panics
///
/// Panics if the codeword is some length that is not a power of 2 or greater than (1 << 32).
///
/// [Credit]: https://github.com/0xPolygonMiden/miden-vm/issues/568
pub fn barycentric_evaluate(
codeword: &[XFieldElement],
indeterminate: XFieldElement,
) -> XFieldElement {
let root_order = codeword.len().try_into().unwrap();
let generator = BFieldElement::primitive_root_of_unity(root_order).unwrap();
let domain_iter = (0..root_order)
.scan(bfe!(1), |acc, _| {
let to_yield = Some(*acc);
*acc *= generator;
to_yield
})
.collect_vec();

let domain_shift = domain_iter.iter().map(|&d| indeterminate - d).collect();
let domain_shift_inverses = XFieldElement::batch_inversion(domain_shift);
let domain_over_domain_shift = domain_iter
.into_iter()
.zip(domain_shift_inverses)
.map(|(d, inv)| d * inv);
let numerator = domain_over_domain_shift
.clone()
.zip(codeword)
.map(|(dsi, &abscis)| dsi * abscis)
.sum::<XFieldElement>();
let denominator = domain_over_domain_shift.sum::<XFieldElement>();
numerator / denominator
}

#[cfg(test)]
mod tests {
use std::cmp::max;
Expand All @@ -658,6 +726,7 @@ mod tests {
use assert2::assert;
use assert2::let_assert;
use itertools::Itertools;
use proptest::collection::vec;
use proptest::prelude::*;
use proptest_arbitrary_interop::arb;
use rand::prelude::*;
Expand Down Expand Up @@ -856,6 +925,7 @@ mod tests {
(MerkleRoot(p), MerkleRoot(v)) => prop_assert_eq!(p, v),
(FriResponse(p), FriResponse(v)) => prop_assert_eq!(p, v),
(FriCodeword(p), FriCodeword(v)) => prop_assert_eq!(p, v),
(FriPolynomial(p), FriPolynomial(v)) => prop_assert_eq!(p, v),
_ => panic!("Unknown items.\nProver: {prover_item:?}\nVerifier: {verifier_item:?}"),
}
}
Expand Down Expand Up @@ -1015,11 +1085,50 @@ mod tests {
}
}

#[proptest]
fn incorrect_last_round_polynomial_results_in_verification_failure(
#[strategy(arbitrary_fri())] fri: Fri<Tip5>,
#[strategy(arbitrary_polynomial())] polynomial: Polynomial<XFieldElement>,
#[strategy(arb())] incorrect_coefficients: Vec<XFieldElement>,
) {
let codeword = fri.domain.evaluate(&polynomial);
let mut proof_stream = ProofStream::new();
fri.prove(&codeword, &mut proof_stream).unwrap();

let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
proof_stream.items.iter_mut().for_each(|item| {
if let ProofItem::FriPolynomial(coefficients) = item {
*coefficients = incorrect_coefficients.clone();
}
});

let verdict = fri.verify(&mut proof_stream, &mut None);
let_assert!(Err(err) = verdict);
assert!(let LastRoundPolynomialEvaluationMismatch = err);
}

#[proptest]
fn verifying_arbitrary_proof_does_not_panic(
#[strategy(arbitrary_fri())] fri: Fri<Tip5>,
#[strategy(arb())] mut proof_stream: ProofStream,
) {
let _ = fri.verify(&mut proof_stream, &mut None);
}

#[proptest]
fn polynomial_evaluation_and_barycentric_evaluation_are_equivalent(
#[strategy(1usize..13)] _log_num_coefficients: usize,
#[strategy(1usize..6)] log_expansion_factor: usize,
#[strategy(vec(arb(), 1 << #_log_num_coefficients))] coefficients: Vec<XFieldElement>,
#[strategy(arb())] indeterminate: XFieldElement,
) {
let domain_len = coefficients.len() * (1 << log_expansion_factor);
let domain = ArithmeticDomain::of_length(domain_len).unwrap();
let polynomial = Polynomial::from(&coefficients);
let codeword = domain.evaluate(&polynomial);
prop_assert_eq!(
polynomial.evaluate(indeterminate),
barycentric_evaluate(&codeword, indeterminate)
);
}
}
2 changes: 2 additions & 0 deletions triton-vm/src/proof_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ proof_items!(
Log2PaddedHeight(u32) => false, try_into_log2_padded_height,
QuotientSegmentsElements(Vec<QuotientSegments>) => false, try_into_quot_segments_elements,
FriCodeword(Vec<XFieldElement>) => false, try_into_fri_codeword,
FriPolynomial(Vec<XFieldElement>) => false, try_into_fri_polynomial,
FriResponse(FriResponse) => false, try_into_fri_response,
);

Expand Down Expand Up @@ -189,6 +190,7 @@ pub(crate) mod tests {
assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_log2_padded_height());
assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_quot_segments_elements());
assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_fri_codeword());
assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_fri_polynomial());
assert!(let Err(UnexpectedItem{..}) = item.try_into_fri_response());
}

Expand Down

0 comments on commit 991688a

Please sign in to comment.