Skip to content

Commit

Permalink
use quotient domain instead of FRI domain wherever applicable
Browse files Browse the repository at this point in the history
- in low-degree extension (LDE):
    - use optimally sized subgroup for fast interpolation of columns
    - drop unnecessary parameter
    - perform LDE over two domains: quotient and FRI
- use codewords over quotient domain to perform most computation
- only at the very end, LDE combination codeword to FRI domain
- rename `FriDomain` to `ArithmeticDomain`
- only ever interpolate and evaluate over `BFieldElements`
- remove duplicate FRI domain in struct `Stark`.
- change variable names and wording:
    - “omicron” is now “trace domain generator”
    - “omega” is now “FRI domain generator”
  • Loading branch information
jan-ferdinand authored Nov 18, 2022
2 parents e7fd6cc + 465d42d commit 776fa19
Show file tree
Hide file tree
Showing 15 changed files with 624 additions and 576 deletions.
76 changes: 41 additions & 35 deletions triton-vm/src/fri_domain.rs → triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
@@ -1,96 +1,102 @@
use std::marker::PhantomData;
use std::ops::MulAssign;
use std::ops::{Mul, MulAssign};

use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::polynomial::Polynomial;
use twenty_first::shared_math::traits::{FiniteField, ModPowU32};

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FriDomain<FF> {
pub struct ArithmeticDomain<FF> {
pub offset: BFieldElement,
pub omega: BFieldElement,
pub generator: BFieldElement,
pub length: usize,
_field: PhantomData<FF>,
_finite_field: PhantomData<FF>,
}

impl<FF> FriDomain<FF>
impl<FF> ArithmeticDomain<FF>
where
FF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
FF: FiniteField
+ From<BFieldElement>
+ Mul<BFieldElement, Output = FF>
+ MulAssign<BFieldElement>,
{
pub fn new(offset: BFieldElement, omega: BFieldElement, length: usize) -> Self {
let _field = PhantomData;
pub fn new(offset: BFieldElement, generator: BFieldElement, length: usize) -> Self {
Self {
offset,
omega,
generator,
length,
_field,
_finite_field: PhantomData,
}
}

pub fn evaluate(&self, polynomial: &Polynomial<FF>) -> Vec<FF> {
polynomial.fast_coset_evaluate(&self.offset, self.omega, self.length)
pub fn evaluate<GF>(&self, polynomial: &Polynomial<GF>) -> Vec<GF>
where
GF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
{
polynomial.fast_coset_evaluate(&self.offset, self.generator, self.length)
}

pub fn interpolate(&self, values: &[FF]) -> Polynomial<FF> {
Polynomial::<FF>::fast_coset_interpolate(&self.offset, self.omega, values)
pub fn interpolate<GF>(&self, values: &[GF]) -> Polynomial<GF>
where
GF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
{
Polynomial::<GF>::fast_coset_interpolate(&self.offset, self.generator, values)
}

pub fn domain_value(&self, index: u32) -> FF {
let domain_value: BFieldElement = self.omega.mod_pow_u32(index) * self.offset;
let domain_value = self.generator.mod_pow_u32(index) * self.offset;
domain_value.into()
}

pub fn domain_values(&self) -> Vec<FF> {
let mut res = Vec::with_capacity(self.length);
let mut acc = FF::one();
let mut accumulator = FF::one();
let mut domain_values = Vec::with_capacity(self.length);

for _ in 0..self.length {
let domain_value = {
let mut tmp = acc;
tmp *= self.offset;
tmp
};
res.push(domain_value);
acc *= self.omega;
domain_values.push(accumulator * self.offset);
accumulator *= self.generator;
}

res
domain_values
}
}

#[cfg(test)]
mod fri_domain_tests {
use super::*;
mod domain_tests {
use itertools::Itertools;
use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::traits::PrimitiveRootOfUnity;
use twenty_first::shared_math::x_field_element::XFieldElement;

use super::*;

#[test]
fn domain_values_test() {
// f(x) = x^3
let x_squared_coefficients = vec![0u64.into(), 0u64.into(), 0u64.into(), 1u64.into()];
let poly = Polynomial::<BFieldElement>::new(x_squared_coefficients.clone());

for order in [4, 8, 32] {
let omega = BFieldElement::primitive_root_of_unity(order).unwrap();
let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
let offset = BFieldElement::generator();
let b_domain = FriDomain::<BFieldElement>::new(offset, omega, order as usize);
let x_domain = FriDomain::<XFieldElement>::new(offset, omega, order as usize);
let b_domain =
ArithmeticDomain::<BFieldElement>::new(offset, generator, order as usize);
let x_domain =
ArithmeticDomain::<XFieldElement>::new(offset, generator, order as usize);

let expected_b_values: Vec<BFieldElement> =
(0..order).map(|i| offset * omega.mod_pow(i)).collect();
(0..order).map(|i| offset * generator.mod_pow(i)).collect();
let actual_b_values_1 = b_domain.domain_values();
let actual_b_values_2 = (0..order as u32)
.map(|i| b_domain.domain_value(i))
.collect_vec();
assert_eq!(
expected_b_values, actual_b_values_1,
"domain_values() generates the FRI domain BFieldElement values"
"domain_values() generates the arithmetic domain's BFieldElement values"
);
assert_eq!(
expected_b_values, actual_b_values_2,
"domain_value() generates the given FRI domain BFieldElement value"
"domain_value() generates the given domain BFieldElement value"
);

let expected_x_values: Vec<XFieldElement> =
Expand All @@ -101,11 +107,11 @@ mod fri_domain_tests {
.collect_vec();
assert_eq!(
expected_x_values, actual_x_values_1,
"domain_values() generates the FRI domain XFieldElement values"
"domain_values() generates the arithmetic domain's XFieldElement values"
);
assert_eq!(
expected_x_values, actual_x_values_2,
"domain_value() generates the given FRI domain XFieldElement values"
"domain_value() generates the given domain XFieldElement values"
);

let values = b_domain.evaluate(&poly);
Expand Down
30 changes: 16 additions & 14 deletions triton-vm/src/cross_table_arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use twenty_first::shared_math::mpolynomial::Degree;
use twenty_first::shared_math::traits::{FiniteField, Inverse};
use twenty_first::shared_math::x_field_element::XFieldElement;

use crate::fri_domain::FriDomain;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::table::processor_table::PROCESSOR_TABLE_NUM_PERMUTATION_ARGUMENTS;
use crate::table::table_collection::TableId::{
HashTable, InstructionTable, JumpStackTable, OpStackTable, ProcessorTable, ProgramTable,
Expand Down Expand Up @@ -44,18 +44,19 @@ pub trait CrossTableArg {
fn terminal_quotient(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &FriDomain<XFieldElement>,
omicron: XFieldElement,
quotient_domain: &ArithmeticDomain<BFieldElement>,
trace_domain_generator: BFieldElement,
) -> Vec<XFieldElement> {
let from_codeword = self.combined_from_codeword(ext_codeword_tables);
let to_codeword = self.combined_to_codeword(ext_codeword_tables);

let zerofier = fri_domain
let trace_domain_generator_inverse = trace_domain_generator.inverse();
let zerofier = quotient_domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
.map(|x| x - trace_domain_generator_inverse)
.collect();
let zerofier_inverse = XFieldElement::batch_inversion(zerofier);
let zerofier_inverse = BFieldElement::batch_inversion(zerofier);

zerofier_inverse
.into_iter()
Expand Down Expand Up @@ -453,10 +454,10 @@ impl GrandCrossTableArg {
pub fn terminal_quotient_codeword(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &FriDomain<XFieldElement>,
omicron: XFieldElement,
quotient_domain: &ArithmeticDomain<BFieldElement>,
trace_domain_generator: BFieldElement,
) -> Vec<XFieldElement> {
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); fri_domain.length];
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); quotient_domain.length];

// cross-table arguments
for (arg, weight) in self.into_iter() {
Expand All @@ -472,7 +473,7 @@ impl GrandCrossTableArg {
}

// standard input
let input_terminal_codeword = vec![self.input_terminal; fri_domain.length];
let input_terminal_codeword = vec![self.input_terminal; quotient_domain.length];
let (to_table, to_column) = self.input_to_processor;
let to_codeword = &ext_codeword_tables.data(to_table)[to_column];
let weight = self.input_to_processor_weight;
Expand All @@ -487,7 +488,7 @@ impl GrandCrossTableArg {
// standard output
let (from_table, from_column) = self.processor_to_output;
let from_codeword = &ext_codeword_tables.data(from_table)[from_column];
let output_terminal_codeword = vec![self.output_terminal; fri_domain.length];
let output_terminal_codeword = vec![self.output_terminal; quotient_domain.length];
let weight = self.processor_to_output_weight;
let non_linear_summand =
weighted_difference_codeword(from_codeword, &output_terminal_codeword, weight);
Expand All @@ -497,12 +498,13 @@ impl GrandCrossTableArg {
XFieldElement::add,
);

let zerofier = fri_domain
let trace_domain_generator_inverse = trace_domain_generator.inverse();
let zerofier = quotient_domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
.map(|x| x - trace_domain_generator_inverse)
.collect();
let zerofier_inverse = XFieldElement::batch_inversion(zerofier);
let zerofier_inverse = BFieldElement::batch_inversion(zerofier);

zerofier_inverse
.into_iter()
Expand Down
28 changes: 15 additions & 13 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use twenty_first::shared_math::x_field_element::XFieldElement;
use twenty_first::util_types::algebraic_hasher::{AlgebraicHasher, Hashable};
use twenty_first::util_types::merkle_tree::{MerkleTree, PartialAuthenticationPath};

use crate::fri_domain::FriDomain;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::proof_item::{FriResponse, ProofItem};
use crate::proof_stream::ProofStream;

Expand Down Expand Up @@ -50,19 +50,19 @@ pub struct Fri<H> {
// nearest power of 2.
pub expansion_factor: usize,
pub colinearity_checks_count: usize,
pub domain: FriDomain<XFieldElement>,
pub domain: ArithmeticDomain<BFieldElement>,
_hasher: PhantomData<H>,
}

impl<H: AlgebraicHasher> Fri<H> {
pub fn new(
offset: BFieldElement,
omega: BFieldElement,
fri_domain_generator: BFieldElement,
domain_length: usize,
expansion_factor: usize,
colinearity_checks_count: usize,
) -> Self {
let domain = FriDomain::new(offset, omega, domain_length);
let domain = ArithmeticDomain::new(offset, fri_domain_generator, domain_length);
let _hasher = PhantomData;
Self {
domain,
Expand Down Expand Up @@ -166,7 +166,7 @@ impl<H: AlgebraicHasher> Fri<H> {
codeword: &[XFieldElement],
proof_stream: &mut ProofStream<ProofItem, H>,
) -> Result<Vec<(Vec<XFieldElement>, MerkleTree<H>)>, Box<dyn Error>> {
let mut subgroup_generator = self.domain.omega;
let mut subgroup_generator = self.domain.generator;
let mut offset = self.domain.offset;
let mut codeword_local = codeword.to_vec();

Expand Down Expand Up @@ -346,9 +346,11 @@ impl<H: AlgebraicHasher> Fri<H> {
let log_2_of_n = log_2_floor(last_codeword.len() as u128) as u32;
let mut last_polynomial = last_codeword.clone();

// XXX
let last_omega = self.domain.omega.mod_pow_u32(2u32.pow(num_rounds as u32));
intt::<XFieldElement>(&mut last_polynomial, last_omega, log_2_of_n);
let last_fri_domain_generator = self
.domain
.generator
.mod_pow_u32(2u32.pow(num_rounds as u32));
intt::<XFieldElement>(&mut last_polynomial, last_fri_domain_generator, log_2_of_n);

let last_poly_degree: isize = (Polynomial::<XFieldElement> {
coefficients: last_polynomial,
Expand Down Expand Up @@ -445,7 +447,7 @@ impl<H: AlgebraicHasher> Fri<H> {
/// FRI (co-)domain. This corresponds to `ω^i` in `f(ω^i)` from
/// [STARK-Anatomy](https://neptune.cash/learn/stark-anatomy/fri/#split-and-fold).
fn get_evaluation_argument(&self, idx: usize, round: usize) -> XFieldElement {
let domain_value = self.domain.offset * self.domain.omega.mod_pow_u32(idx as u32);
let domain_value = self.domain.offset * self.domain.generator.mod_pow_u32(idx as u32);
let round_exponent = 2u32.pow(round as u32);
let evaluation_argument = domain_value.mod_pow_u32(round_exponent);

Expand Down Expand Up @@ -574,7 +576,7 @@ mod triton_xfri_tests {
let fri: Fri<Hasher> =
get_x_field_fri_test_object(subgroup_order, expansion_factor, colinearity_check_count);
let mut proof_stream: ProofStream<ProofItem, Hasher> = ProofStream::new();
let subgroup = fri.domain.omega.lift().get_cyclic_group_elements(None);
let subgroup = fri.domain.generator.lift().get_cyclic_group_elements(None);

let (_, merkle_root_of_round_0) = fri.prove(&subgroup, &mut proof_stream).unwrap();
let verdict = fri.verify(&mut proof_stream, &merkle_root_of_round_0, &mut None);
Expand Down Expand Up @@ -616,7 +618,7 @@ mod triton_xfri_tests {
let colinearity_check_count = 6;
let fri: Fri<Hasher> =
get_x_field_fri_test_object(subgroup_order, expansion_factor, colinearity_check_count);
let subgroup = fri.domain.omega.lift().get_cyclic_group_elements(None);
let subgroup = fri.domain.generator.lift().get_cyclic_group_elements(None);

let mut points: Vec<XFieldElement>;
for n in [1, 5, 20, 30, 31] {
Expand Down Expand Up @@ -664,7 +666,7 @@ mod triton_xfri_tests {
expansion_factor: usize,
colinearity_checks: usize,
) -> Fri<H> {
let omega = BFieldElement::primitive_root_of_unity(subgroup_order).unwrap();
let fri_domain_generator = BFieldElement::primitive_root_of_unity(subgroup_order).unwrap();

// The following offset was picked arbitrarily by copying the one found in
// `get_b_field_fri_test_object`. It does not generate the full Z_p\{0}, but
Expand All @@ -673,7 +675,7 @@ mod triton_xfri_tests {

let fri: Fri<H> = Fri::new(
offset,
omega,
fri_domain_generator,
subgroup_order as usize,
expansion_factor,
colinearity_checks,
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pub mod arithmetic_domain;
pub mod bfield_codec;
pub mod cross_table_arguments;
pub mod error;
pub mod fri;
pub mod fri_domain;
pub mod instruction;
pub mod op_stack;
pub mod ord_n;
Expand Down
Loading

0 comments on commit 776fa19

Please sign in to comment.