diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index 25a9b5085..919794659 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -98,6 +98,7 @@ impl AlgebraicExecutionTrace { aet } + /// Guaranteed to be a power of two. pub fn padded_height(&self) -> usize { let relevant_table_heights = [ self.program_table_length(), diff --git a/triton-vm/src/arithmetic_domain.rs b/triton-vm/src/arithmetic_domain.rs index 70b31a490..44e81e493 100644 --- a/triton-vm/src/arithmetic_domain.rs +++ b/triton-vm/src/arithmetic_domain.rs @@ -5,6 +5,10 @@ use twenty_first::prelude::*; use twenty_first::shared_math::traits::FiniteField; use twenty_first::shared_math::traits::PrimitiveRootOfUnity; +use crate::error::ArithmeticDomainError; + +type Result = std::result::Result; + #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ArithmeticDomain { pub offset: BFieldElement, @@ -14,29 +18,35 @@ pub struct ArithmeticDomain { impl ArithmeticDomain { /// Create a new domain with the given length. - /// No offset is applied, but can added through [`with_offset()`](Self::with_offset). - pub fn of_length(length: usize) -> Self { - Self { + /// No offset is applied, but can be added through [`with_offset()`](Self::with_offset). + /// + /// # Errors + /// + /// Errors if the domain length is not a power of 2. + pub fn of_length(length: usize) -> Result { + let domain = Self { offset: BFieldElement::one(), - generator: Self::generator_for_length(length as u64), + generator: Self::generator_for_length(length as u64)?, length, - } + }; + Ok(domain) } /// Set the offset of the domain. + #[must_use] pub fn with_offset(mut self, offset: BFieldElement) -> Self { self.offset = offset; self } /// Derive a generator for a domain of the given length. - /// The domain length must be a power of 2. - pub fn generator_for_length(domain_length: u64) -> BFieldElement { - assert!( - 0 == domain_length || domain_length.is_power_of_two(), - "The domain length must be a power of 2 but was {domain_length}.", - ); - BFieldElement::primitive_root_of_unity(domain_length).unwrap() + /// + /// # Errors + /// + /// Errors if the domain length is not a power of 2. + pub fn generator_for_length(domain_length: u64) -> Result { + let error = ArithmeticDomainError::PrimitiveRootNotSupported(domain_length); + BFieldElement::primitive_root_of_unity(domain_length).ok_or(error) } pub fn evaluate(&self, polynomial: &Polynomial) -> Vec @@ -94,19 +104,22 @@ impl ArithmeticDomain { domain_values } - #[must_use] - pub(crate) fn halve(&self) -> Self { - assert!(self.length >= 2); - Self { + pub(crate) fn halve(&self) -> Result { + if self.length < 2 { + return Err(ArithmeticDomainError::TooSmallForHalving(self.length)); + } + let domain = Self { offset: self.offset.square(), generator: self.generator.square(), length: self.length / 2, - } + }; + Ok(domain) } } #[cfg(test)] mod tests { + use assert2::let_assert; use itertools::Itertools; use proptest::prelude::*; use proptest_arbitrary_interop::arb; @@ -140,7 +153,7 @@ mod tests { fn arbitrary_domain_of_length(length: usize)( offset in arb(), ) -> ArithmeticDomain { - ArithmeticDomain::of_length(length).with_offset(offset) + ArithmeticDomain::of_length(length).unwrap().with_offset(offset) } } @@ -184,7 +197,9 @@ mod tests { for order in [4, 8, 32] { let generator = BFieldElement::primitive_root_of_unity(order).unwrap(); let offset = BFieldElement::generator(); - let b_domain = ArithmeticDomain::of_length(order as usize).with_offset(offset); + let b_domain = ArithmeticDomain::of_length(order as usize) + .unwrap() + .with_offset(offset); let expected_b_values = (0..order) .map(|i| offset * generator.mod_pow(i)) @@ -224,8 +239,8 @@ mod tests { let long_domain_len = 128; let unit_distance = long_domain_len / short_domain_len; - let short_domain = ArithmeticDomain::of_length(short_domain_len); - let long_domain = ArithmeticDomain::of_length(long_domain_len); + let short_domain = ArithmeticDomain::of_length(short_domain_len).unwrap(); + let long_domain = ArithmeticDomain::of_length(long_domain_len).unwrap(); let polynomial = Polynomial::new([1, 2, 3, 4].map(BFieldElement::new).to_vec()); let short_codeword = short_domain.evaluate(&polynomial); @@ -244,7 +259,7 @@ mod tests { fn halving_domain_squares_all_points( #[strategy(arbitrary_halveable_domain())] domain: ArithmeticDomain, ) { - let half_domain = domain.halve(); + let half_domain = domain.halve()?; prop_assert_eq!(domain.length / 2, half_domain.length); let domain_points = domain.domain_values(); @@ -259,16 +274,11 @@ mod tests { } #[test] - #[should_panic] - fn domain_of_length_one_cannot_be_halved() { - let domain = ArithmeticDomain::of_length(1); - let _ = domain.halve(); - } - - #[test] - #[should_panic] - fn domain_of_length_zero_cannot_be_halved() { - let domain = ArithmeticDomain::of_length(0); - let _ = domain.halve(); + fn too_small_domains_cannot_be_halved() { + for i in [0, 1] { + let domain = ArithmeticDomain::of_length(i).unwrap(); + let_assert!(Err(err) = domain.halve()); + assert!(ArithmeticDomainError::TooSmallForHalving(i) == err); + } } } diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index 8a8eab501..185f1bf81 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -98,6 +98,16 @@ pub enum InstructionError { MachineHalted, } +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum ArithmeticDomainError { + #[error("the domain's length must be a power of 2 but was {0}")] + PrimitiveRootNotSupported(u64), + + #[error("the domain's length must be at least 2 to be halved, but it was {0}")] + TooSmallForHalving(usize), +} + #[non_exhaustive] #[derive(Debug, Error)] pub enum ProofStreamError { @@ -131,6 +141,9 @@ pub enum FriSetupError { #[error("the expansion factor must be smaller than the domain length")] ExpansionFactorMismatch, + + #[error(transparent)] + ArithmeticDomainError(#[from] ArithmeticDomainError), } #[non_exhaustive] @@ -138,6 +151,9 @@ pub enum FriSetupError { pub enum FriProvingError { #[error(transparent)] MerkleTreeError(#[from] MerkleTreeError), + + #[error(transparent)] + ArithmeticDomainError(#[from] ArithmeticDomainError), } #[non_exhaustive] @@ -163,6 +179,9 @@ pub enum FriValidationError { #[error(transparent)] MerkleTreeError(#[from] MerkleTreeError), + + #[error(transparent)] + ArithmeticDomainError(#[from] ArithmeticDomainError), } #[non_exhaustive] @@ -227,6 +246,9 @@ pub enum ProvingError { #[error(transparent)] MerkleTreeError(#[from] MerkleTreeError), + #[error(transparent)] + ArithmeticDomainError(#[from] ArithmeticDomainError), + #[error(transparent)] FriSetupError(#[from] FriSetupError), @@ -273,6 +295,9 @@ pub enum VerificationError { #[error(transparent)] ProofStreamError(#[from] ProofStreamError), + #[error(transparent)] + ArithmeticDomainError(#[from] ArithmeticDomainError), + #[error(transparent)] FriSetupError(#[from] FriSetupError), diff --git a/triton-vm/src/fri.rs b/triton-vm/src/fri.rs index e2b55dac6..cf55a421a 100644 --- a/triton-vm/src/fri.rs +++ b/triton-vm/src/fri.rs @@ -87,7 +87,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> { let previous_round = self.rounds.last().unwrap(); let folding_challenge = self.proof_stream.sample_scalars(1)[0]; let codeword = previous_round.split_and_fold(folding_challenge); - let domain = previous_round.domain.halve(); + let domain = previous_round.domain.halve()?; ProverRound::new(domain, &codeword) } @@ -243,7 +243,7 @@ impl<'stream, H: AlgebraicHasher> FriVerifier<'stream, H> { fn construct_next_round(&mut self) -> VerifierResult { let previous_round = self.rounds.last().unwrap(); - let domain = previous_round.domain.halve(); + let domain = previous_round.domain.halve()?; self.construct_round_with_domain(domain) } @@ -701,7 +701,9 @@ mod tests { let min_expanded_domain_length = min_domain_length * expansion_factor; let domain_length = max(sampled_domain_length, min_expanded_domain_length); - let fri_domain = ArithmeticDomain::of_length(domain_length).with_offset(offset); + let maybe_domain = ArithmeticDomain::of_length(domain_length); + let fri_domain = maybe_domain.unwrap().with_offset(offset); + Fri::new(fri_domain, expansion_factor, num_collinearity_checks).unwrap() } } @@ -798,7 +800,7 @@ mod tests { } fn smallest_fri() -> Fri { - let domain = ArithmeticDomain::of_length(2); + let domain = ArithmeticDomain::of_length(2).unwrap(); let expansion_factor = 2; let num_collinearity_checks = 1; Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap() @@ -806,7 +808,7 @@ mod tests { #[test] fn too_small_expansion_factor_is_rejected() { - let domain = ArithmeticDomain::of_length(2); + let domain = ArithmeticDomain::of_length(2).unwrap(); let expansion_factor = 1; let num_collinearity_checks = 1; let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); @@ -820,7 +822,7 @@ mod tests { expansion_factor: usize, ) { let largest_supported_domain_size = 1 << 32; - let domain = ArithmeticDomain::of_length(largest_supported_domain_size); + let domain = ArithmeticDomain::of_length(largest_supported_domain_size).unwrap(); let num_collinearity_checks = 1; let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); prop_assert_eq!(FriSetupError::ExpansionFactorUnsupported, err); @@ -833,7 +835,7 @@ mod tests { ) { let expansion_factor = 1 << log_2_expansion_factor; let domain_length = 1 << log_2_domain_length; - let domain = ArithmeticDomain::of_length(domain_length); + let domain = ArithmeticDomain::of_length(domain_length).unwrap(); let num_collinearity_checks = 1; let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); prop_assert_eq!(FriSetupError::ExpansionFactorMismatch, err); diff --git a/triton-vm/src/shared_tests.rs b/triton-vm/src/shared_tests.rs index 6af8048e3..bda0cac8a 100644 --- a/triton-vm/src/shared_tests.rs +++ b/triton-vm/src/shared_tests.rs @@ -151,7 +151,7 @@ pub(crate) fn construct_master_base_table( let padded_height = aet.padded_height(); let fri = stark.derive_fri(padded_height).unwrap(); let max_degree = stark.derive_max_degree(padded_height); - let quotient_domain = Stark::quotient_domain(fri.domain, max_degree); + let quotient_domain = Stark::quotient_domain(fri.domain, max_degree).unwrap(); MasterBaseTable::new( aet, stark.num_trace_randomizers, diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 5e7775475..816c14a5a 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -122,7 +122,7 @@ impl Stark { let padded_height = aet.padded_height(); let max_degree = self.derive_max_degree(padded_height); let fri = self.derive_fri(padded_height)?; - let quotient_domain = Self::quotient_domain(fri.domain, max_degree); + let quotient_domain = Self::quotient_domain(fri.domain, max_degree)?; proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2())); prof_stop!(maybe_profiler, "derive additional parameters"); @@ -557,14 +557,14 @@ impl Stark { pub(crate) fn quotient_domain( fri_domain: ArithmeticDomain, max_degree: Degree, - ) -> ArithmeticDomain { + ) -> Result { let maybe_blowup_factor = match cfg!(debug_assertions) { true => 2, false => 1, }; let domain_length = (max_degree as u64).next_power_of_two() as usize; let domain_length = maybe_blowup_factor * domain_length; - ArithmeticDomain::of_length(domain_length).with_offset(fri_domain.offset) + Ok(ArithmeticDomain::of_length(domain_length)?.with_offset(fri_domain.offset)) } /// Compute the upper bound to use for the maximum degree the quotients given the length of the @@ -595,7 +595,7 @@ impl Stark { let interpolant_codeword_length = interpolant_degree as usize + 1; let fri_domain_length = self.fri_expansion_factor * interpolant_codeword_length; let coset_offset = BFieldElement::generator(); - let domain = ArithmeticDomain::of_length(fri_domain_length).with_offset(coset_offset); + let domain = ArithmeticDomain::of_length(fri_domain_length)?.with_offset(coset_offset); Fri::new( domain, @@ -721,7 +721,7 @@ impl Stark { prof_stop!(maybe_profiler, "Fiat-Shamir 1"); prof_start!(maybe_profiler, "dequeue ood point and rows", "hash"); - let trace_domain_generator = ArithmeticDomain::generator_for_length(padded_height as u64); + let trace_domain_generator = ArithmeticDomain::generator_for_length(padded_height as u64)?; let out_of_domain_point_curr_row = proof_stream.sample_scalars(1)[0]; let out_of_domain_point_next_row = trace_domain_generator * out_of_domain_point_curr_row; let out_of_domain_point_curr_row_pow_num_segments = @@ -2310,7 +2310,7 @@ pub(crate) mod tests { #[test] fn deep_update() { let domain_length = 1 << 10; - let domain = ArithmeticDomain::of_length(domain_length); + let domain = ArithmeticDomain::of_length(domain_length).unwrap(); let poly_degree = thread_rng().gen_range(2..20); let low_deg_poly_coeffs: Vec = random_elements(poly_degree); diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 83a784e77..549bec657 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -579,11 +579,12 @@ impl MasterBaseTable { fri_domain: ArithmeticDomain, ) -> Self { let padded_height = aet.padded_height(); - let trace_domain = ArithmeticDomain::of_length(padded_height); + let trace_domain = ArithmeticDomain::of_length(padded_height).unwrap(); let randomized_padded_trace_len = randomized_padded_trace_len(padded_height, num_trace_randomizers); - let randomized_trace_domain = ArithmeticDomain::of_length(randomized_padded_trace_len); + let randomized_trace_domain = + ArithmeticDomain::of_length(randomized_padded_trace_len).unwrap(); let num_rows = randomized_padded_trace_len; let num_columns = NUM_BASE_COLUMNS; @@ -1127,6 +1128,7 @@ pub fn num_quotients() -> usize { + MasterExtTable::num_terminal_quotients() } +/// Guaranteed to be a power of two. pub fn randomized_padded_trace_len(padded_height: usize, num_trace_randomizers: usize) -> usize { let total_table_length = padded_height + num_trace_randomizers; total_table_length.next_power_of_two() @@ -1265,10 +1267,12 @@ mod tests { fn zerofiers_are_correct() { let big_order = 16; let big_offset = BFieldElement::generator(); - let big_domain = ArithmeticDomain::of_length(big_order as usize).with_offset(big_offset); + let big_domain = ArithmeticDomain::of_length(big_order as usize) + .unwrap() + .with_offset(big_offset); let small_order = 8; - let small_domain = ArithmeticDomain::of_length(small_order as usize); + let small_domain = ArithmeticDomain::of_length(small_order as usize).unwrap(); let initial_zerofier_inv = initial_quotient_zerofier_inverse(big_domain); let initial_zerofier = BFieldElement::batch_inversion(initial_zerofier_inv.to_vec()); @@ -1530,10 +1534,10 @@ mod tests { #[test] fn master_ext_table_mut() { - let trace_domain = ArithmeticDomain::of_length(1 << 8); - let randomized_trace_domain = ArithmeticDomain::of_length(1 << 9); - let quotient_domain = ArithmeticDomain::of_length(1 << 10); - let fri_domain = ArithmeticDomain::of_length(1 << 11); + let trace_domain = ArithmeticDomain::of_length(1 << 8).unwrap(); + let randomized_trace_domain = ArithmeticDomain::of_length(1 << 9).unwrap(); + let quotient_domain = ArithmeticDomain::of_length(1 << 10).unwrap(); + let fri_domain = ArithmeticDomain::of_length(1 << 11).unwrap(); let randomized_trace_table = Array2::zeros((randomized_trace_domain.length, NUM_EXT_COLUMNS));