Skip to content

Commit

Permalink
feat!: return Err(_), don't panic!, on unsupported root of unity
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Feb 10, 2024
1 parent 9653716 commit 60289eb
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 55 deletions.
1 change: 1 addition & 0 deletions triton-vm/src/aet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl AlgebraicExecutionTrace {
aet
}

/// Guaranteed to be a power of two.
pub fn padded_height(&self) -> usize {
let relevant_table_heights = [
self.program_table_length(),
Expand Down
76 changes: 43 additions & 33 deletions triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ use twenty_first::prelude::*;
use twenty_first::shared_math::traits::FiniteField;
use twenty_first::shared_math::traits::PrimitiveRootOfUnity;

use crate::error::ArithmeticDomainError;

type Result<T> = std::result::Result<T, ArithmeticDomainError>;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ArithmeticDomain {
pub offset: BFieldElement,
Expand All @@ -14,29 +18,35 @@ pub struct ArithmeticDomain {

impl ArithmeticDomain {
/// Create a new domain with the given length.
/// No offset is applied, but can added through [`with_offset()`](Self::with_offset).
pub fn of_length(length: usize) -> Self {
Self {
/// No offset is applied, but can be added through [`with_offset()`](Self::with_offset).
///
/// # Errors
///
/// Errors if the domain length is not a power of 2.
pub fn of_length(length: usize) -> Result<Self> {
let domain = Self {
offset: BFieldElement::one(),
generator: Self::generator_for_length(length as u64),
generator: Self::generator_for_length(length as u64)?,
length,
}
};
Ok(domain)
}

/// Set the offset of the domain.
#[must_use]
pub fn with_offset(mut self, offset: BFieldElement) -> Self {
self.offset = offset;
self
}

/// Derive a generator for a domain of the given length.
/// The domain length must be a power of 2.
pub fn generator_for_length(domain_length: u64) -> BFieldElement {
assert!(
0 == domain_length || domain_length.is_power_of_two(),
"The domain length must be a power of 2 but was {domain_length}.",
);
BFieldElement::primitive_root_of_unity(domain_length).unwrap()
///
/// # Errors
///
/// Errors if the domain length is not a power of 2.
pub fn generator_for_length(domain_length: u64) -> Result<BFieldElement> {
let error = ArithmeticDomainError::PrimitiveRootNotSupported(domain_length);
BFieldElement::primitive_root_of_unity(domain_length).ok_or(error)
}

pub fn evaluate<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
Expand Down Expand Up @@ -94,19 +104,22 @@ impl ArithmeticDomain {
domain_values
}

#[must_use]
pub(crate) fn halve(&self) -> Self {
assert!(self.length >= 2);
Self {
pub(crate) fn halve(&self) -> Result<Self> {
if self.length < 2 {
return Err(ArithmeticDomainError::TooSmallForHalving(self.length));
}
let domain = Self {
offset: self.offset.square(),
generator: self.generator.square(),
length: self.length / 2,
}
};
Ok(domain)
}
}

#[cfg(test)]
mod tests {
use assert2::let_assert;
use itertools::Itertools;
use proptest::prelude::*;
use proptest_arbitrary_interop::arb;
Expand Down Expand Up @@ -140,7 +153,7 @@ mod tests {
fn arbitrary_domain_of_length(length: usize)(
offset in arb(),
) -> ArithmeticDomain {
ArithmeticDomain::of_length(length).with_offset(offset)
ArithmeticDomain::of_length(length).unwrap().with_offset(offset)
}
}

Expand Down Expand Up @@ -184,7 +197,9 @@ mod tests {
for order in [4, 8, 32] {
let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
let offset = BFieldElement::generator();
let b_domain = ArithmeticDomain::of_length(order as usize).with_offset(offset);
let b_domain = ArithmeticDomain::of_length(order as usize)
.unwrap()
.with_offset(offset);

let expected_b_values = (0..order)
.map(|i| offset * generator.mod_pow(i))
Expand Down Expand Up @@ -224,8 +239,8 @@ mod tests {
let long_domain_len = 128;
let unit_distance = long_domain_len / short_domain_len;

let short_domain = ArithmeticDomain::of_length(short_domain_len);
let long_domain = ArithmeticDomain::of_length(long_domain_len);
let short_domain = ArithmeticDomain::of_length(short_domain_len).unwrap();
let long_domain = ArithmeticDomain::of_length(long_domain_len).unwrap();

let polynomial = Polynomial::new([1, 2, 3, 4].map(BFieldElement::new).to_vec());
let short_codeword = short_domain.evaluate(&polynomial);
Expand All @@ -244,7 +259,7 @@ mod tests {
fn halving_domain_squares_all_points(
#[strategy(arbitrary_halveable_domain())] domain: ArithmeticDomain,
) {
let half_domain = domain.halve();
let half_domain = domain.halve()?;
prop_assert_eq!(domain.length / 2, half_domain.length);

let domain_points = domain.domain_values();
Expand All @@ -259,16 +274,11 @@ mod tests {
}

#[test]
#[should_panic]
fn domain_of_length_one_cannot_be_halved() {
let domain = ArithmeticDomain::of_length(1);
let _ = domain.halve();
}

#[test]
#[should_panic]
fn domain_of_length_zero_cannot_be_halved() {
let domain = ArithmeticDomain::of_length(0);
let _ = domain.halve();
fn too_small_domains_cannot_be_halved() {
for i in [0, 1] {
let domain = ArithmeticDomain::of_length(i).unwrap();
let_assert!(Err(err) = domain.halve());
assert!(ArithmeticDomainError::TooSmallForHalving(i) == err);
}
}
}
25 changes: 25 additions & 0 deletions triton-vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ pub enum InstructionError {
MachineHalted,
}

#[non_exhaustive]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
pub enum ArithmeticDomainError {
#[error("the domain's length must be a power of 2 but was {0}")]
PrimitiveRootNotSupported(u64),

#[error("the domain's length must be at least 2 to be halved, but it was {0}")]
TooSmallForHalving(usize),
}

#[non_exhaustive]
#[derive(Debug, Error)]
pub enum ProofStreamError {
Expand Down Expand Up @@ -131,13 +141,19 @@ pub enum FriSetupError {

#[error("the expansion factor must be smaller than the domain length")]
ExpansionFactorMismatch,

#[error(transparent)]
ArithmeticDomainError(#[from] ArithmeticDomainError),
}

#[non_exhaustive]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
pub enum FriProvingError {
#[error(transparent)]
MerkleTreeError(#[from] MerkleTreeError),

#[error(transparent)]
ArithmeticDomainError(#[from] ArithmeticDomainError),
}

#[non_exhaustive]
Expand All @@ -163,6 +179,9 @@ pub enum FriValidationError {

#[error(transparent)]
MerkleTreeError(#[from] MerkleTreeError),

#[error(transparent)]
ArithmeticDomainError(#[from] ArithmeticDomainError),
}

#[non_exhaustive]
Expand Down Expand Up @@ -227,6 +246,9 @@ pub enum ProvingError {
#[error(transparent)]
MerkleTreeError(#[from] MerkleTreeError),

#[error(transparent)]
ArithmeticDomainError(#[from] ArithmeticDomainError),

#[error(transparent)]
FriSetupError(#[from] FriSetupError),

Expand Down Expand Up @@ -273,6 +295,9 @@ pub enum VerificationError {
#[error(transparent)]
ProofStreamError(#[from] ProofStreamError),

#[error(transparent)]
ArithmeticDomainError(#[from] ArithmeticDomainError),

#[error(transparent)]
FriSetupError(#[from] FriSetupError),

Expand Down
16 changes: 9 additions & 7 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
let previous_round = self.rounds.last().unwrap();
let folding_challenge = self.proof_stream.sample_scalars(1)[0];
let codeword = previous_round.split_and_fold(folding_challenge);
let domain = previous_round.domain.halve();
let domain = previous_round.domain.halve()?;
ProverRound::new(domain, &codeword)
}

Expand Down Expand Up @@ -243,7 +243,7 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> {

fn construct_next_round(&mut self) -> VerifierResult<VerifierRound> {
let previous_round = self.rounds.last().unwrap();
let domain = previous_round.domain.halve();
let domain = previous_round.domain.halve()?;
self.construct_round_with_domain(domain)
}

Expand Down Expand Up @@ -701,7 +701,9 @@ mod tests {
let min_expanded_domain_length = min_domain_length * expansion_factor;
let domain_length = max(sampled_domain_length, min_expanded_domain_length);

let fri_domain = ArithmeticDomain::of_length(domain_length).with_offset(offset);
let maybe_domain = ArithmeticDomain::of_length(domain_length);
let fri_domain = maybe_domain.unwrap().with_offset(offset);

Fri::new(fri_domain, expansion_factor, num_collinearity_checks).unwrap()
}
}
Expand Down Expand Up @@ -798,15 +800,15 @@ mod tests {
}

fn smallest_fri() -> Fri<Tip5> {
let domain = ArithmeticDomain::of_length(2);
let domain = ArithmeticDomain::of_length(2).unwrap();
let expansion_factor = 2;
let num_collinearity_checks = 1;
Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap()
}

#[test]
fn too_small_expansion_factor_is_rejected() {
let domain = ArithmeticDomain::of_length(2);
let domain = ArithmeticDomain::of_length(2).unwrap();
let expansion_factor = 1;
let num_collinearity_checks = 1;
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
Expand All @@ -820,7 +822,7 @@ mod tests {
expansion_factor: usize,
) {
let largest_supported_domain_size = 1 << 32;
let domain = ArithmeticDomain::of_length(largest_supported_domain_size);
let domain = ArithmeticDomain::of_length(largest_supported_domain_size).unwrap();
let num_collinearity_checks = 1;
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
prop_assert_eq!(FriSetupError::ExpansionFactorUnsupported, err);
Expand All @@ -833,7 +835,7 @@ mod tests {
) {
let expansion_factor = 1 << log_2_expansion_factor;
let domain_length = 1 << log_2_domain_length;
let domain = ArithmeticDomain::of_length(domain_length);
let domain = ArithmeticDomain::of_length(domain_length).unwrap();
let num_collinearity_checks = 1;
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
prop_assert_eq!(FriSetupError::ExpansionFactorMismatch, err);
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/shared_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ pub(crate) fn construct_master_base_table(
let padded_height = aet.padded_height();
let fri = stark.derive_fri(padded_height).unwrap();
let max_degree = stark.derive_max_degree(padded_height);
let quotient_domain = Stark::quotient_domain(fri.domain, max_degree);
let quotient_domain = Stark::quotient_domain(fri.domain, max_degree).unwrap();
MasterBaseTable::new(
aet,
stark.num_trace_randomizers,
Expand Down
12 changes: 6 additions & 6 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Stark {
let padded_height = aet.padded_height();
let max_degree = self.derive_max_degree(padded_height);
let fri = self.derive_fri(padded_height)?;
let quotient_domain = Self::quotient_domain(fri.domain, max_degree);
let quotient_domain = Self::quotient_domain(fri.domain, max_degree)?;
proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2()));
prof_stop!(maybe_profiler, "derive additional parameters");

Expand Down Expand Up @@ -557,14 +557,14 @@ impl Stark {
pub(crate) fn quotient_domain(
fri_domain: ArithmeticDomain,
max_degree: Degree,
) -> ArithmeticDomain {
) -> Result<ArithmeticDomain, ProvingError> {
let maybe_blowup_factor = match cfg!(debug_assertions) {
true => 2,
false => 1,
};
let domain_length = (max_degree as u64).next_power_of_two() as usize;
let domain_length = maybe_blowup_factor * domain_length;
ArithmeticDomain::of_length(domain_length).with_offset(fri_domain.offset)
Ok(ArithmeticDomain::of_length(domain_length)?.with_offset(fri_domain.offset))
}

/// Compute the upper bound to use for the maximum degree the quotients given the length of the
Expand Down Expand Up @@ -595,7 +595,7 @@ impl Stark {
let interpolant_codeword_length = interpolant_degree as usize + 1;
let fri_domain_length = self.fri_expansion_factor * interpolant_codeword_length;
let coset_offset = BFieldElement::generator();
let domain = ArithmeticDomain::of_length(fri_domain_length).with_offset(coset_offset);
let domain = ArithmeticDomain::of_length(fri_domain_length)?.with_offset(coset_offset);

Fri::new(
domain,
Expand Down Expand Up @@ -721,7 +721,7 @@ impl Stark {
prof_stop!(maybe_profiler, "Fiat-Shamir 1");

prof_start!(maybe_profiler, "dequeue ood point and rows", "hash");
let trace_domain_generator = ArithmeticDomain::generator_for_length(padded_height as u64);
let trace_domain_generator = ArithmeticDomain::generator_for_length(padded_height as u64)?;
let out_of_domain_point_curr_row = proof_stream.sample_scalars(1)[0];
let out_of_domain_point_next_row = trace_domain_generator * out_of_domain_point_curr_row;
let out_of_domain_point_curr_row_pow_num_segments =
Expand Down Expand Up @@ -2310,7 +2310,7 @@ pub(crate) mod tests {
#[test]
fn deep_update() {
let domain_length = 1 << 10;
let domain = ArithmeticDomain::of_length(domain_length);
let domain = ArithmeticDomain::of_length(domain_length).unwrap();

let poly_degree = thread_rng().gen_range(2..20);
let low_deg_poly_coeffs: Vec<XFieldElement> = random_elements(poly_degree);
Expand Down
Loading

0 comments on commit 60289eb

Please sign in to comment.