Skip to content

Commit

Permalink
use arithmetic domain instead of fri domain wherever applicable
Browse files Browse the repository at this point in the history
Fix #114
  • Loading branch information
jan-ferdinand committed Nov 15, 2022
1 parent f17d80a commit 095f645
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 220 deletions.
14 changes: 7 additions & 7 deletions triton-vm/src/cross_table_arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ pub trait CrossTableArg {
fn terminal_quotient(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &Domain<BFieldElement>,
domain: &Domain<BFieldElement>,
omicron: BFieldElement,
) -> Vec<XFieldElement> {
let from_codeword = self.combined_from_codeword(ext_codeword_tables);
let to_codeword = self.combined_to_codeword(ext_codeword_tables);

let zerofier = fri_domain
let zerofier = domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
Expand Down Expand Up @@ -453,10 +453,10 @@ impl GrandCrossTableArg {
pub fn terminal_quotient_codeword(
&self,
ext_codeword_tables: &ExtTableCollection,
fri_domain: &Domain<BFieldElement>,
domain: &Domain<BFieldElement>,
omicron: BFieldElement,
) -> Vec<XFieldElement> {
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); fri_domain.length];
let mut non_linear_sum_codeword = vec![XFieldElement::zero(); domain.length];

// cross-table arguments
for (arg, weight) in self.into_iter() {
Expand All @@ -472,7 +472,7 @@ impl GrandCrossTableArg {
}

// standard input
let input_terminal_codeword = vec![self.input_terminal; fri_domain.length];
let input_terminal_codeword = vec![self.input_terminal; domain.length];
let (to_table, to_column) = self.input_to_processor;
let to_codeword = &ext_codeword_tables.data(to_table)[to_column];
let weight = self.input_to_processor_weight;
Expand All @@ -487,7 +487,7 @@ impl GrandCrossTableArg {
// standard output
let (from_table, from_column) = self.processor_to_output;
let from_codeword = &ext_codeword_tables.data(from_table)[from_column];
let output_terminal_codeword = vec![self.output_terminal; fri_domain.length];
let output_terminal_codeword = vec![self.output_terminal; domain.length];
let weight = self.processor_to_output_weight;
let non_linear_summand =
weighted_difference_codeword(from_codeword, &output_terminal_codeword, weight);
Expand All @@ -497,7 +497,7 @@ impl GrandCrossTableArg {
XFieldElement::add,
);

let zerofier = fri_domain
let zerofier = domain
.domain_values()
.into_iter()
.map(|x| x - omicron.inverse())
Expand Down
119 changes: 63 additions & 56 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,16 @@ impl Stark {

prof_start!(maybe_profiler, "dual LDE 1");
let arithmetic_domain = self.arithmetic_domain();
let base_fri_domain_tables = base_trace_tables.to_arithmetic_and_fri_domain_tables(
&arithmetic_domain,
&self.fri.domain,
self.parameters.num_trace_randomizers,
);
let (base_arithmetic_domain_tables, base_fri_domain_tables) = base_trace_tables
.to_arithmetic_and_fri_domain_tables(
&arithmetic_domain,
&self.fri.domain,
self.parameters.num_trace_randomizers,
);
let base_arithmetic_domain_codewords = base_arithmetic_domain_tables.get_all_base_columns();
let base_fri_domain_codewords = base_fri_domain_tables.get_all_base_columns();
let randomizer_and_base_fri_domain_codewords =
vec![b_rand_codewords, base_fri_domain_codewords.clone()].concat();
vec![b_rand_codewords, base_fri_domain_codewords].concat();
prof_stop!(maybe_profiler, "dual LDE 1");

prof_start!(maybe_profiler, "Merkle tree 1");
Expand Down Expand Up @@ -188,11 +190,14 @@ impl Stark {
prof_stop!(maybe_profiler, "extend");

prof_start!(maybe_profiler, "dual LDE 2");
let ext_fri_domain_tables = ext_trace_tables.to_arithmetic_and_fri_domain_tables(
&arithmetic_domain,
&self.fri.domain,
self.parameters.num_trace_randomizers,
);
let (ext_arithmetic_domain_tables, ext_fri_domain_tables) = ext_trace_tables
.to_arithmetic_and_fri_domain_tables(
&arithmetic_domain,
&self.fri.domain,
self.parameters.num_trace_randomizers,
);
let extension_arithmetic_domain_codewords =
ext_arithmetic_domain_tables.collect_all_columns();
let extension_fri_domain_codewords = ext_fri_domain_tables.collect_all_columns();
prof_stop!(maybe_profiler, "dual LDE 2");

Expand All @@ -206,26 +211,26 @@ impl Stark {

prof_start!(maybe_profiler, "degree bounds");
prof_start!(maybe_profiler, "base");
let base_degree_bounds =
base_fri_domain_tables.get_base_degree_bounds(self.parameters.num_trace_randomizers);
let base_degree_bounds = base_arithmetic_domain_tables
.get_base_degree_bounds(self.parameters.num_trace_randomizers);
prof_stop!(maybe_profiler, "base");

prof_start!(maybe_profiler, "extension");
let extension_degree_bounds = ext_fri_domain_tables
let extension_degree_bounds = ext_arithmetic_domain_tables
.get_extension_degree_bounds(self.parameters.num_trace_randomizers);
prof_stop!(maybe_profiler, "extension");

prof_start!(maybe_profiler, "quotient");
let full_fri_domain_tables =
ExtTableCollection::join(base_fri_domain_tables, ext_fri_domain_tables);
let mut quotient_degree_bounds = full_fri_domain_tables
let full_arithmetic_domain_tables =
ExtTableCollection::join(base_arithmetic_domain_tables, ext_arithmetic_domain_tables);
let mut quotient_degree_bounds = full_arithmetic_domain_tables
.get_all_quotient_degree_bounds(self.parameters.num_trace_randomizers);
prof_stop!(maybe_profiler, "quotient");
prof_stop!(maybe_profiler, "degree bounds");

prof_start!(maybe_profiler, "quotient codewords");
let mut quotient_codewords = full_fri_domain_tables.get_all_quotients(
&self.fri.domain,
let mut quotient_codewords = full_arithmetic_domain_tables.get_all_quotients(
&arithmetic_domain,
&extension_challenges,
maybe_profiler,
);
Expand All @@ -234,8 +239,8 @@ impl Stark {
prof_start!(maybe_profiler, "grand cross table");
let num_grand_cross_table_args = 1;
let num_non_lin_combi_weights = self.parameters.num_randomizer_polynomials
+ 2 * base_fri_domain_codewords.len()
+ 2 * extension_fri_domain_codewords.len()
+ 2 * base_arithmetic_domain_codewords.len()
+ 2 * extension_arithmetic_domain_codewords.len()
+ 2 * quotient_degree_bounds.len()
+ 2 * num_grand_cross_table_args;
let num_grand_cross_table_arg_weights = NUM_CROSS_TABLE_ARGS + NUM_PUBLIC_EVAL_ARGS;
Expand Down Expand Up @@ -272,38 +277,52 @@ impl Stark {
);
let grand_cross_table_arg_quotient_codeword = grand_cross_table_arg
.terminal_quotient_codeword(
&full_fri_domain_tables,
&self.fri.domain,
derive_omicron(full_fri_domain_tables.padded_height as u64),
&full_arithmetic_domain_tables,
&arithmetic_domain,
derive_omicron(full_arithmetic_domain_tables.padded_height as u64),
);
quotient_codewords.push(grand_cross_table_arg_quotient_codeword);

let grand_cross_table_arg_quotient_degree_bound = grand_cross_table_arg
.quotient_degree_bound(
&full_fri_domain_tables,
&full_arithmetic_domain_tables,
self.parameters.num_trace_randomizers,
);
quotient_degree_bounds.push(grand_cross_table_arg_quotient_degree_bound);
prof_stop!(maybe_profiler, "grand cross table");

prof_start!(maybe_profiler, "nonlinear combination");
// magic number `1` corresponds to `num_randomizer_polynomials`, which is currently ignored
let (randomizer_weight, base_ext_quot_weights) = non_lin_combi_weights.split_at(1);
let combination_codeword = self.create_combination_codeword(
vec![x_rand_codeword],
base_fri_domain_codewords,
extension_fri_domain_codewords,
&arithmetic_domain,
base_arithmetic_domain_codewords,
extension_arithmetic_domain_codewords,
quotient_codewords,
non_lin_combi_weights.to_vec(),
base_ext_quot_weights.to_vec(),
base_degree_bounds,
extension_degree_bounds,
quotient_degree_bounds,
maybe_profiler,
);

prof_start!(maybe_profiler, "LDE 3");
let combination_polynomial = arithmetic_domain.interpolate(&combination_codeword);
let fri_combination_codeword_without_randomizer =
self.fri.domain.evaluate(&combination_polynomial);
prof_stop!(maybe_profiler, "LDE 3");

let fri_combination_codeword: Vec<_> = fri_combination_codeword_without_randomizer
.into_par_iter()
.zip_eq(x_rand_codeword.into_par_iter())
.map(|(cc_elem, rand_elem)| cc_elem + randomizer_weight[0] * rand_elem)
.collect();
prof_stop!(maybe_profiler, "nonlinear combination");

prof_start!(maybe_profiler, "Merkle tree 3");
let mut combination_codeword_digests: Vec<Digest> =
Vec::with_capacity(combination_codeword.len());
combination_codeword
Vec::with_capacity(fri_combination_codeword.len());
fri_combination_codeword
.clone()
.into_par_iter()
.map(|elem| StarkHasher::hash(&elem))
Expand Down Expand Up @@ -331,7 +350,7 @@ impl Stark {
}

prof_start!(maybe_profiler, "FRI");
match self.fri.prove(&combination_codeword, &mut proof_stream) {
match self.fri.prove(&fri_combination_codeword, &mut proof_stream) {
Ok((_, fri_first_round_merkle_root)) => assert_eq!(
combination_root, fri_first_round_merkle_root,
"Combination root from STARK and from FRI must agree."
Expand Down Expand Up @@ -368,7 +387,7 @@ impl Stark {
// as the latter includes adjacent table rows relative to the values in `indices`
let revealed_combination_elements: Vec<XFieldElement> = cross_codeword_slice_indices
.iter()
.map(|i| combination_codeword[*i])
.map(|i| fri_combination_codeword[*i])
.collect();
let revealed_combination_auth_paths =
combination_tree.get_authentication_structure(&cross_codeword_slice_indices);
Expand Down Expand Up @@ -429,7 +448,7 @@ impl Stark {
#[allow(clippy::too_many_arguments)]
fn create_combination_codeword(
&self,
randomizer_codewords: Vec<Vec<XFieldElement>>,
arithmetic_domain: &Domain<BFieldElement>,
base_codewords: Vec<Vec<BFieldElement>>,
extension_codewords: Vec<Vec<XFieldElement>>,
quotient_codewords: Vec<Vec<XFieldElement>>,
Expand All @@ -439,13 +458,7 @@ impl Stark {
quotient_degree_bounds: Vec<i64>,
maybe_profiler: &mut Option<TritonProfiler>,
) -> Vec<XFieldElement> {
assert_eq!(
self.parameters.num_randomizer_polynomials,
randomizer_codewords.len()
);

prof_start!(maybe_profiler, "create combination codeword");

let base_codewords_lifted = base_codewords
.into_iter()
.map(|base_codeword| {
Expand All @@ -456,20 +469,10 @@ impl Stark {
})
.collect_vec();
let mut weights_iterator = weights.into_iter();
let mut combination_codeword: Vec<XFieldElement> = vec![0.into(); self.fri.domain.length];
let mut combination_codeword: Vec<XFieldElement> = vec![0.into(); arithmetic_domain.length];

// TODO don't keep the entire domain's values in memory, create them lazily when needed
let fri_x_values = self.fri.domain.domain_values();

for randomizer_codeword in randomizer_codewords {
combination_codeword = Self::non_linearly_add_to_codeword(
&combination_codeword,
&randomizer_codeword,
&weights_iterator.next().unwrap(),
&randomizer_codeword,
&0.into(),
);
}
let domain_values = arithmetic_domain.domain_values();

for (codewords, bounds, identifier) in [
(base_codewords_lifted, base_degree_bounds, "base"),
Expand All @@ -484,7 +487,7 @@ impl Stark {
codewords.into_iter().zip_eq(bounds.iter()).enumerate()
{
let shift = (self.max_degree as Degree - degree_bound) as u32;
let codeword_shifted = Self::shift_codeword(&fri_x_values, &codeword, shift);
let codeword_shifted = Self::shift_codeword(&domain_values, &codeword, shift);

combination_codeword = Self::non_linearly_add_to_codeword(
&combination_codeword,
Expand All @@ -494,6 +497,7 @@ impl Stark {
&weights_iterator.next().unwrap(),
);
self.debug_check_degrees(
arithmetic_domain,
&idx,
degree_bound,
&shift,
Expand All @@ -507,7 +511,9 @@ impl Stark {
if std::env::var("DEBUG").is_ok() {
println!(
"The combination codeword corresponds to a polynomial of degree {}",
self.fri.domain.interpolate(&combination_codeword).degree()
arithmetic_domain
.interpolate(&combination_codeword)
.degree()
);
}

Expand All @@ -519,6 +525,7 @@ impl Stark {
#[allow(clippy::too_many_arguments)]
fn debug_check_degrees(
&self,
domain: &Domain<BFieldElement>,
idx: &usize,
degree_bound: &Degree,
shift: &u32,
Expand All @@ -529,8 +536,8 @@ impl Stark {
if std::env::var("DEBUG").is_err() {
return;
}
let interpolated = self.fri.domain.interpolate(extension_codeword);
let interpolated_shifted = self.fri.domain.interpolate(extension_codeword_shifted);
let interpolated = domain.interpolate(extension_codeword);
let interpolated_shifted = domain.interpolate(extension_codeword_shifted);
let int_shift_deg = interpolated_shifted.degree();
let maybe_excl_mark = if int_shift_deg > self.max_degree as isize {
"!!!"
Expand Down
Loading

0 comments on commit 095f645

Please sign in to comment.