Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use quotient domain instead of fri domain wherever applicable #124

Merged
merged 11 commits into from
Nov 18, 2022
Merged
76 changes: 41 additions & 35 deletions triton-vm/src/fri_domain.rs → triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
@@ -1,96 +1,102 @@
use std::marker::PhantomData;
use std::ops::MulAssign;
use std::ops::{Mul, MulAssign};

use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::polynomial::Polynomial;
use twenty_first::shared_math::traits::{FiniteField, ModPowU32};

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FriDomain<FF> {
pub struct ArithmeticDomain<FF> {
pub offset: BFieldElement,
pub omega: BFieldElement,
pub generator: BFieldElement,
pub length: usize,
_field: PhantomData<FF>,
_finite_field: PhantomData<FF>,
}

impl<FF> FriDomain<FF>
impl<FF> ArithmeticDomain<FF>
where
FF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
FF: FiniteField
+ From<BFieldElement>
+ Mul<BFieldElement, Output = FF>
+ MulAssign<BFieldElement>,
{
pub fn new(offset: BFieldElement, omega: BFieldElement, length: usize) -> Self {
let _field = PhantomData;
pub fn new(offset: BFieldElement, generator: BFieldElement, length: usize) -> Self {
Self {
offset,
omega,
generator,
length,
_field,
_finite_field: PhantomData,
}
}

pub fn evaluate(&self, polynomial: &Polynomial<FF>) -> Vec<FF> {
polynomial.fast_coset_evaluate(&self.offset, self.omega, self.length)
pub fn evaluate<GF>(&self, polynomial: &Polynomial<GF>) -> Vec<GF>
where
GF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
{
polynomial.fast_coset_evaluate(&self.offset, self.generator, self.length)
}

pub fn interpolate(&self, values: &[FF]) -> Polynomial<FF> {
Polynomial::<FF>::fast_coset_interpolate(&self.offset, self.omega, values)
pub fn interpolate<GF>(&self, values: &[GF]) -> Polynomial<GF>
where
GF: FiniteField + From<BFieldElement> + MulAssign<BFieldElement>,
{
Polynomial::<GF>::fast_coset_interpolate(&self.offset, self.generator, values)
}

pub fn domain_value(&self, index: u32) -> FF {
let domain_value: BFieldElement = self.omega.mod_pow_u32(index) * self.offset;
let domain_value = self.generator.mod_pow_u32(index) * self.offset;
domain_value.into()
}

pub fn domain_values(&self) -> Vec<FF> {
let mut res = Vec::with_capacity(self.length);
let mut acc = FF::one();
let mut accumulator = FF::one();
let mut domain_values = Vec::with_capacity(self.length);

for _ in 0..self.length {
let domain_value = {
let mut tmp = acc;
tmp *= self.offset;
tmp
};
res.push(domain_value);
acc *= self.omega;
domain_values.push(accumulator * self.offset);
accumulator *= self.generator;
}

res
domain_values
}
}

#[cfg(test)]
mod fri_domain_tests {
use super::*;
mod domain_tests {
use itertools::Itertools;
use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::traits::PrimitiveRootOfUnity;
use twenty_first::shared_math::x_field_element::XFieldElement;

use super::*;

#[test]
fn domain_values_test() {
// f(x) = x^3
let x_squared_coefficients = vec![0u64.into(), 0u64.into(), 0u64.into(), 1u64.into()];
let poly = Polynomial::<BFieldElement>::new(x_squared_coefficients.clone());

for order in [4, 8, 32] {
let omega = BFieldElement::primitive_root_of_unity(order).unwrap();
let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
let offset = BFieldElement::generator();
let b_domain = FriDomain::<BFieldElement>::new(offset, omega, order as usize);
let x_domain = FriDomain::<XFieldElement>::new(offset, omega, order as usize);
let b_domain =
ArithmeticDomain::<BFieldElement>::new(offset, generator, order as usize);
let x_domain =
ArithmeticDomain::<XFieldElement>::new(offset, generator, order as usize);

let expected_b_values: Vec<BFieldElement> =
(0..order).map(|i| offset * omega.mod_pow(i)).collect();
(0..order).map(|i| offset * generator.mod_pow(i)).collect();
let actual_b_values_1 = b_domain.domain_values();
let actual_b_values_2 = (0..order as u32)
.map(|i| b_domain.domain_value(i))
.collect_vec();
assert_eq!(
expected_b_values, actual_b_values_1,
"domain_values() generates the FRI domain BFieldElement values"
"domain_values() generates the arithmetic domain's BFieldElement values"
);
assert_eq!(
expected_b_values, actual_b_values_2,
"domain_value() generates the given FRI domain BFieldElement value"
"domain_value() generates the given domain BFieldElement value"
);

let expected_x_values: Vec<XFieldElement> =
Expand All @@ -101,11 +107,11 @@ mod fri_domain_tests {
.collect_vec();
assert_eq!(
expected_x_values, actual_x_values_1,
"domain_values() generates the FRI domain XFieldElement values"
"domain_values() generates the arithmetic domain's XFieldElement values"
);
assert_eq!(
expected_x_values, actual_x_values_2,
"domain_value() generates the given FRI domain XFieldElement values"
"domain_value() generates the given domain XFieldElement values"
);

let values = b_domain.evaluate(&poly);
Expand Down
30 changes: 16 additions & 14 deletions triton-vm/src/cross_table_arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use twenty_first::shared_math::mpolynomial::Degree;
use twenty_first::shared_math::traits::{FiniteField, Inverse};
use twenty_first::shared_math::x_field_element::XFieldElement;

use crate::fri_domain::FriDomain;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::table::processor_table::PROCESSOR_TABLE_NUM_PERMUTATION_ARGUMENTS;
use crate::table::table_collection::TableId::{
HashTable, InstructionTable, JumpStackTable, OpStackTable, ProcessorTable, ProgramTable,
Expand Down Expand Up @@ -44,18 +44,19 @@ pub trait CrossTableArg {
fn terminal_quotient(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &FriDomain<XFieldElement>,
omicron: XFieldElement,
quotient_domain: &ArithmeticDomain<BFieldElement>,
trace_domain_generator: BFieldElement,
) -> Vec<XFieldElement> {
let from_codeword = self.combined_from_codeword(ext_codeword_tables);
let to_codeword = self.combined_to_codeword(ext_codeword_tables);

let zerofier = fri_domain
let trace_domain_generator_inverse = trace_domain_generator.inverse();
let zerofier = quotient_domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
.map(|x| x - trace_domain_generator_inverse)
.collect();
let zerofier_inverse = XFieldElement::batch_inversion(zerofier);
let zerofier_inverse = BFieldElement::batch_inversion(zerofier);

zerofier_inverse
.into_iter()
Expand Down Expand Up @@ -453,10 +454,10 @@ impl GrandCrossTableArg {
pub fn terminal_quotient_codeword(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &FriDomain<XFieldElement>,
omicron: XFieldElement,
quotient_domain: &ArithmeticDomain<BFieldElement>,
trace_domain_generator: BFieldElement,
) -> Vec<XFieldElement> {
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); fri_domain.length];
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); quotient_domain.length];

// cross-table arguments
for (arg, weight) in self.into_iter() {
Expand All @@ -472,7 +473,7 @@ impl GrandCrossTableArg {
}

// standard input
let input_terminal_codeword = vec![self.input_terminal; fri_domain.length];
let input_terminal_codeword = vec![self.input_terminal; quotient_domain.length];
let (to_table, to_column) = self.input_to_processor;
let to_codeword = &ext_codeword_tables.data(to_table)[to_column];
let weight = self.input_to_processor_weight;
Expand All @@ -487,7 +488,7 @@ impl GrandCrossTableArg {
// standard output
let (from_table, from_column) = self.processor_to_output;
let from_codeword = &ext_codeword_tables.data(from_table)[from_column];
let output_terminal_codeword = vec![self.output_terminal; fri_domain.length];
let output_terminal_codeword = vec![self.output_terminal; quotient_domain.length];
let weight = self.processor_to_output_weight;
let non_linear_summand =
weighted_difference_codeword(from_codeword, &output_terminal_codeword, weight);
Expand All @@ -497,12 +498,13 @@ impl GrandCrossTableArg {
XFieldElement::add,
);

let zerofier = fri_domain
let trace_domain_generator_inverse = trace_domain_generator.inverse();
let zerofier = quotient_domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
.map(|x| x - trace_domain_generator_inverse)
.collect();
let zerofier_inverse = XFieldElement::batch_inversion(zerofier);
let zerofier_inverse = BFieldElement::batch_inversion(zerofier);

zerofier_inverse
.into_iter()
Expand Down
28 changes: 15 additions & 13 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use twenty_first::shared_math::x_field_element::XFieldElement;
use twenty_first::util_types::algebraic_hasher::{AlgebraicHasher, Hashable};
use twenty_first::util_types::merkle_tree::{MerkleTree, PartialAuthenticationPath};

use crate::fri_domain::FriDomain;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::proof_item::{FriResponse, ProofItem};
use crate::proof_stream::ProofStream;

Expand Down Expand Up @@ -50,19 +50,19 @@ pub struct Fri<H> {
// nearest power of 2.
pub expansion_factor: usize,
pub colinearity_checks_count: usize,
pub domain: FriDomain<XFieldElement>,
pub domain: ArithmeticDomain<BFieldElement>,
_hasher: PhantomData<H>,
}

impl<H: AlgebraicHasher> Fri<H> {
pub fn new(
offset: BFieldElement,
omega: BFieldElement,
fri_domain_generator: BFieldElement,
domain_length: usize,
expansion_factor: usize,
colinearity_checks_count: usize,
) -> Self {
let domain = FriDomain::new(offset, omega, domain_length);
let domain = ArithmeticDomain::new(offset, fri_domain_generator, domain_length);
let _hasher = PhantomData;
Self {
domain,
Expand Down Expand Up @@ -166,7 +166,7 @@ impl<H: AlgebraicHasher> Fri<H> {
codeword: &[XFieldElement],
proof_stream: &mut ProofStream<ProofItem, H>,
) -> Result<Vec<(Vec<XFieldElement>, MerkleTree<H>)>, Box<dyn Error>> {
let mut subgroup_generator = self.domain.omega;
let mut subgroup_generator = self.domain.generator;
let mut offset = self.domain.offset;
let mut codeword_local = codeword.to_vec();

Expand Down Expand Up @@ -346,9 +346,11 @@ impl<H: AlgebraicHasher> Fri<H> {
let log_2_of_n = log_2_floor(last_codeword.len() as u128) as u32;
let mut last_polynomial = last_codeword.clone();

// XXX
let last_omega = self.domain.omega.mod_pow_u32(2u32.pow(num_rounds as u32));
intt::<XFieldElement>(&mut last_polynomial, last_omega, log_2_of_n);
let last_fri_domain_generator = self
.domain
.generator
.mod_pow_u32(2u32.pow(num_rounds as u32));
intt::<XFieldElement>(&mut last_polynomial, last_fri_domain_generator, log_2_of_n);

let last_poly_degree: isize = (Polynomial::<XFieldElement> {
coefficients: last_polynomial,
Expand Down Expand Up @@ -445,7 +447,7 @@ impl<H: AlgebraicHasher> Fri<H> {
/// FRI (co-)domain. This corresponds to `ω^i` in `f(ω^i)` from
/// [STARK-Anatomy](https://neptune.cash/learn/stark-anatomy/fri/#split-and-fold).
fn get_evaluation_argument(&self, idx: usize, round: usize) -> XFieldElement {
let domain_value = self.domain.offset * self.domain.omega.mod_pow_u32(idx as u32);
let domain_value = self.domain.offset * self.domain.generator.mod_pow_u32(idx as u32);
let round_exponent = 2u32.pow(round as u32);
let evaluation_argument = domain_value.mod_pow_u32(round_exponent);

Expand Down Expand Up @@ -574,7 +576,7 @@ mod triton_xfri_tests {
let fri: Fri<Hasher> =
get_x_field_fri_test_object(subgroup_order, expansion_factor, colinearity_check_count);
let mut proof_stream: ProofStream<ProofItem, Hasher> = ProofStream::new();
let subgroup = fri.domain.omega.lift().get_cyclic_group_elements(None);
let subgroup = fri.domain.generator.lift().get_cyclic_group_elements(None);

let (_, merkle_root_of_round_0) = fri.prove(&subgroup, &mut proof_stream).unwrap();
let verdict = fri.verify(&mut proof_stream, &merkle_root_of_round_0, &mut None);
Expand Down Expand Up @@ -616,7 +618,7 @@ mod triton_xfri_tests {
let colinearity_check_count = 6;
let fri: Fri<Hasher> =
get_x_field_fri_test_object(subgroup_order, expansion_factor, colinearity_check_count);
let subgroup = fri.domain.omega.lift().get_cyclic_group_elements(None);
let subgroup = fri.domain.generator.lift().get_cyclic_group_elements(None);

let mut points: Vec<XFieldElement>;
for n in [1, 5, 20, 30, 31] {
Expand Down Expand Up @@ -664,7 +666,7 @@ mod triton_xfri_tests {
expansion_factor: usize,
colinearity_checks: usize,
) -> Fri<H> {
let omega = BFieldElement::primitive_root_of_unity(subgroup_order).unwrap();
let fri_domain_generator = BFieldElement::primitive_root_of_unity(subgroup_order).unwrap();

// The following offset was picked arbitrarily by copying the one found in
// `get_b_field_fri_test_object`. It does not generate the full Z_p\{0}, but
Expand All @@ -673,7 +675,7 @@ mod triton_xfri_tests {

let fri: Fri<H> = Fri::new(
offset,
omega,
fri_domain_generator,
subgroup_order as usize,
expansion_factor,
colinearity_checks,
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pub mod arithmetic_domain;
pub mod bfield_codec;
pub mod cross_table_arguments;
pub mod error;
pub mod fri;
pub mod fri_domain;
pub mod instruction;
pub mod op_stack;
pub mod ord_n;
Expand Down
Loading