From 3f3a9fd1f90d5e5ead611589de17f2523e4982db Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 26 Mar 2024 18:37:59 +0100 Subject: [PATCH] feat!: streamline accessing AET's heights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In particular: - the height of any table can be queried using `aet.height_of_table(…)` - the height-dominating table can be identified using `aet.height()` BREAKING_CHANGE: Previous methods for accessing various table's lengths are replaced and thus removed. --- triton-vm/benches/mem_io.rs | 3 +- triton-vm/src/aet.rs | 163 +++++++++++++++----------- triton-vm/src/table/hash_table.rs | 8 +- triton-vm/src/table/master_table.rs | 16 +-- triton-vm/src/table/op_stack_table.rs | 4 +- triton-vm/src/table/program_table.rs | 3 +- triton-vm/src/table/ram_table.rs | 3 +- triton-vm/src/table/u32_table.rs | 35 ++---- triton-vm/src/vm.rs | 5 +- 9 files changed, 133 insertions(+), 107 deletions(-) diff --git a/triton-vm/benches/mem_io.rs b/triton-vm/benches/mem_io.rs index 202c3038b..de1bee314 100644 --- a/triton-vm/benches/mem_io.rs +++ b/triton-vm/benches/mem_io.rs @@ -7,6 +7,7 @@ use criterion::Criterion; use triton_vm::prelude::*; use triton_vm::profiler::Report; use triton_vm::profiler::TritonProfiler; +use triton_vm::table::master_table::TableId; criterion_main!(benches); criterion_group!( @@ -82,7 +83,7 @@ impl MemIOBench { let proof = stark.prove(&claim, &aet, &mut profiler).unwrap(); let mut profiler = profiler.unwrap(); - let trace_len = aet.op_stack_table_length(); + let trace_len = aet.height_of_table(TableId::OpStack); let padded_height = proof.padded_height().unwrap(); let fri = stark.derive_fri(padded_height).unwrap(); diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index a40e98f3e..4d686fb70 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -3,10 +3,12 @@ use std::collections::hash_map::Entry::Vacant; use std::collections::HashMap; use std::ops::AddAssign; +use arbitrary::Arbitrary; use itertools::Itertools; use ndarray::s; use ndarray::Array2; use ndarray::Axis; +use strum::IntoEnumIterator; use twenty_first::prelude::*; use crate::error::InstructionError; @@ -15,6 +17,7 @@ use crate::instruction::Instruction; use crate::program::Program; use crate::table::hash_table::HashTable; use crate::table::hash_table::PermutationTrace; +use crate::table::master_table::TableId; use crate::table::op_stack_table::OpStackTableEntry; use crate::table::ram_table::RamTableCall; use crate::table::table_column::HashBaseTableColumn::CI; @@ -74,7 +77,15 @@ pub struct AlgebraicExecutionTrace { pub lookup_table_lookup_multiplicities: [u64; 1 << 8], } +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct TableHeight { + pub table: TableId, + pub height: usize, +} + impl AlgebraicExecutionTrace { + const LOOKUP_TABLE_HEIGHT: usize = 1 << 8; + pub fn new(program: Program) -> Self { let program_len = program.len_bwords(); @@ -89,26 +100,66 @@ impl AlgebraicExecutionTrace { sponge_trace: Array2::default([0, hash_table::BASE_WIDTH]), u32_entries: HashMap::new(), cascade_table_lookup_multiplicities: HashMap::new(), - lookup_table_lookup_multiplicities: [0; 1 << 8], + lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT], }; aet.fill_program_hash_trace(); aet } + /// The height of the [AET](AlgebraicExecutionTrace) after [padding][pad]. + /// /// Guaranteed to be a power of two. + /// + /// [pad]: master_table::MasterBaseTable::pad pub fn padded_height(&self) -> usize { - let relevant_table_heights = [ - self.program_table_length(), - self.processor_table_length(), - self.op_stack_table_length(), - self.ram_table_length(), - self.hash_table_length(), - self.cascade_table_length(), - self.lookup_table_length(), - self.u32_table_length(), - ]; - let max_height = relevant_table_heights.into_iter().max().unwrap_or(0); - max_height.next_power_of_two() + self.height().height.next_power_of_two() + } + + /// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad]. + /// Corresponds to the height of the longest table. + /// + /// [pad]: master_table::MasterBaseTable::pad + pub fn height(&self) -> TableHeight { + let relevant_tables = TableId::iter().filter(|&t| t != TableId::DegreeLowering); + let heights = relevant_tables.map(|t| TableHeight::new(t, self.height_of_table(t))); + heights.max().unwrap() + } + + pub fn height_of_table(&self, table: TableId) -> usize { + let hash_table_height = || { + self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows() + }; + + match table { + TableId::Program => Self::padded_program_length(&self.program), + TableId::Processor => self.processor_trace.nrows(), + TableId::OpStack => self.op_stack_underflow_trace.nrows(), + TableId::Ram => self.ram_trace.nrows(), + TableId::JumpStack => self.processor_trace.nrows(), + TableId::Hash => hash_table_height(), + TableId::Cascade => self.cascade_table_lookup_multiplicities.len(), + TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT, + TableId::U32 => self.u32_table_height(), + TableId::DegreeLowering => self.height().height, + } + } + + /// # Panics + /// + /// - if the table height exceeds [`u32::MAX`] + /// - if the table height exceeds [`usize::MAX`] + fn u32_table_height(&self) -> usize { + let entry_len = U32TableEntry::table_height_contribution; + let height = self.u32_entries.keys().map(entry_len).sum::(); + height.try_into().unwrap() + } + + fn padded_program_length(program: &Program) -> usize { + // Padding is at least one 1. + // Also note that the Program Table's side of the instruction lookup argument requires at + // least one padding row to account for the processor's “next instruction or argument.” + // Both of these are captured by the “+ 1” in the following line. + (program.len_bwords() + 1).next_multiple_of(Tip5::RATE) } /// Hash the program and record the entire Sponge's trace for program attestation. @@ -156,58 +207,6 @@ impl AlgebraicExecutionTrace { .collect() } - pub fn program_table_length(&self) -> usize { - Self::padded_program_length(&self.program) - } - - fn padded_program_length(program: &Program) -> usize { - // After adding one 1, the program table is padded to the next smallest multiple of the - // sponge's rate with 0s. - // Also note that the Program Table's side of the instruction lookup argument requires at - // least one padding row to account for the processor's “next instruction or argument.” - // Both of these are captured by the “+ 1” in the following line. - let min_padded_len = program.len_bwords() + 1; - let remainder_len = min_padded_len % Tip5::RATE; - let num_zeros_to_add = match remainder_len { - 0 => 0, - _ => Tip5::RATE - remainder_len, - }; - min_padded_len + num_zeros_to_add - } - - pub fn processor_table_length(&self) -> usize { - self.processor_trace.nrows() - } - - pub fn op_stack_table_length(&self) -> usize { - self.op_stack_underflow_trace.nrows() - } - - pub fn ram_table_length(&self) -> usize { - self.ram_trace.nrows() - } - - pub fn hash_table_length(&self) -> usize { - self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows() - } - - pub fn cascade_table_length(&self) -> usize { - self.cascade_table_lookup_multiplicities.len() - } - - pub fn lookup_table_length(&self) -> usize { - 1 << 8 - } - - /// # Panics - /// - /// Panics if the table length exceeds [`u32::MAX`]. - pub fn u32_table_length(&self) -> usize { - let entry_len = U32TableEntry::table_length_contribution; - let len = self.u32_entries.keys().map(entry_len).sum::(); - len.try_into().unwrap() - } - pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> { self.record_instruction_lookup(state.instruction_pointer)?; self.append_state_to_processor_trace(state); @@ -337,12 +336,29 @@ impl AlgebraicExecutionTrace { } } +impl TableHeight { + fn new(table: TableId, height: usize) -> Self { + Self { table, height } + } +} + +impl PartialOrd for TableHeight { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TableHeight { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.height.cmp(&other.height) + } +} + #[cfg(test)] mod tests { use assert2::assert; - use crate::triton_asm; - use crate::triton_program; + use crate::prelude::*; use super::*; @@ -355,4 +371,17 @@ mod tests { let expected = [program.to_bwords(), vec![bfe!(1)]].concat(); assert!(expected == padded_program); } + + #[test] + fn height_of_any_table_can_be_computed() { + let program = triton_program!(halt); + let (aet, _) = program + .trace_execution(PublicInput::default(), NonDeterminism::default()) + .unwrap(); + + let _ = aet.height(); + for table in TableId::iter() { + let _ = aet.height_of_table(table); + } + } } diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 51f416780..2fd5aa1a9 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -1863,6 +1863,7 @@ pub(crate) mod tests { use crate::shared_tests::ProgramAndInput; use crate::stark::tests::master_tables_for_low_security_level; use crate::table::master_table::MasterTable; + use crate::table::master_table::TableId; use crate::triton_asm; use crate::triton_program; @@ -1894,10 +1895,11 @@ pub(crate) mod tests { }; let (aet, _) = program.trace_execution([].into(), [].into()).unwrap(); + dbg!(aet.height()); dbg!(aet.padded_height()); - dbg!(aet.hash_table_length()); - dbg!(aet.op_stack_table_length()); - dbg!(aet.cascade_table_length()); + dbg!(aet.height_of_table(TableId::Hash)); + dbg!(aet.height_of_table(TableId::OpStack)); + dbg!(aet.height_of_table(TableId::Cascade)); let (_, _, master_base_table, master_ext_table, challenges) = master_tables_for_low_security_level(ProgramAndInput::new(program)); diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 06b176bad..2e5567cf3 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -145,7 +145,7 @@ pub const EXT_DEGREE_LOWERING_TABLE_END: usize = const NUM_TABLES_WITHOUT_DEGREE_LOWERING: usize = TableId::COUNT - 1; /// A `TableId` uniquely determines one of Triton VM's tables. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] +#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] pub enum TableId { Program, Processor, @@ -590,13 +590,13 @@ impl MasterBaseTable { let mut master_base_table = Self { num_trace_randomizers, - program_table_len: aet.program_table_length(), - main_execution_len: aet.processor_table_length(), - op_stack_table_len: aet.op_stack_table_length(), - ram_table_len: aet.ram_table_length(), - hash_coprocessor_execution_len: aet.hash_table_length(), - cascade_table_len: aet.cascade_table_length(), - u32_coprocesor_execution_len: aet.u32_table_length(), + program_table_len: aet.height_of_table(TableId::Program), + main_execution_len: aet.height_of_table(TableId::Processor), + op_stack_table_len: aet.height_of_table(TableId::OpStack), + ram_table_len: aet.height_of_table(TableId::Ram), + hash_coprocessor_execution_len: aet.height_of_table(TableId::Hash), + cascade_table_len: aet.height_of_table(TableId::Cascade), + u32_coprocesor_execution_len: aet.height_of_table(TableId::U32), trace_domain, randomized_trace_domain, quotient_domain, diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index 052f92582..210f7ada0 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -21,6 +21,7 @@ use crate::table::constraint_circuit::DualRowIndicator::*; use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; +use crate::table::master_table::TableId; use crate::table::table_column::OpStackBaseTableColumn::*; use crate::table::table_column::OpStackExtTableColumn::*; use crate::table::table_column::*; @@ -274,7 +275,8 @@ impl OpStackTable { op_stack_table: &mut ArrayViewMut2, aet: &AlgebraicExecutionTrace, ) -> Vec { - let mut op_stack_table = op_stack_table.slice_mut(s![0..aet.op_stack_table_length(), ..]); + let mut op_stack_table = + op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]); let trace_iter = aet.op_stack_underflow_trace.rows().into_iter(); let sorted_rows = diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs index 7694e61a1..e1141d826 100644 --- a/triton-vm/src/table/program_table.rs +++ b/triton-vm/src/table/program_table.rs @@ -19,6 +19,7 @@ use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::EvalArg; use crate::table::cross_table_argument::LookupArg; +use crate::table::master_table::TableId; use crate::table::table_column::ProgramBaseTableColumn::*; use crate::table::table_column::ProgramExtTableColumn::*; use crate::table::table_column::*; @@ -290,7 +291,7 @@ impl ProgramTable { let instructions = aet.program.to_bwords(); let program_len = instructions.len(); - let padded_program_len = aet.program_table_length(); + let padded_program_len = aet.height_of_table(TableId::Program); let one_iter = [bfe!(1)].into_iter(); let zero_iter = [bfe!(0)].into_iter(); diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index ea4c9b31e..c590ec509 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -22,6 +22,7 @@ use crate::table::constraint_circuit::DualRowIndicator::*; use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; +use crate::table::master_table::TableId; use crate::table::table_column::RamBaseTableColumn::*; use crate::table::table_column::RamExtTableColumn::*; use crate::table::table_column::*; @@ -70,7 +71,7 @@ impl RamTable { ram_table: &mut ArrayViewMut2, aet: &AlgebraicExecutionTrace, ) -> Vec { - let mut ram_table = ram_table.slice_mut(s![0..aet.ram_table_length(), ..]); + let mut ram_table = ram_table.slice_mut(s![0..aet.height_of_table(TableId::Ram), ..]); let trace_iter = aet.ram_trace.rows().into_iter(); let sorted_rows = diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index 8e3f65d94..b58dc96bf 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -46,9 +46,11 @@ pub struct U32TableEntry { } impl U32TableEntry { - pub fn new(instruction: Instruction, left_operand: u32, right_operand: u32) -> Self { - let left_operand: u64 = left_operand.into(); - let right_operand: u64 = right_operand.into(); + pub fn new(instruction: Instruction, left_operand: L, right_operand: R) -> Self + where + L: Into, + R: Into, + { Self { instruction, left_operand: left_operand.into(), @@ -56,28 +58,17 @@ impl U32TableEntry { } } - pub fn new_from_base_field_element( - instruction: Instruction, - left_operand: BFieldElement, - right_operand: BFieldElement, - ) -> Self { - Self { - instruction, - left_operand, - right_operand, - } - } - /// The number of rows this entry contributes to the U32 Table. - pub fn table_length_contribution(&self) -> u32 { + pub(crate) fn table_height_contribution(&self) -> u32 { + let lhs = self.left_operand.value(); + let rhs = self.right_operand.value(); let dominant_operand = match self.instruction { - // for instruction `pow`, the left-hand side doesn't change between rows - Instruction::Pow => self.right_operand.value(), - _ => max(self.left_operand.value(), self.right_operand.value()), + Instruction::Pow => rhs, // left-hand side doesn't change between rows + _ => max(lhs, rhs), }; - match dominant_operand == 0 { - true => 2 - 1, - false => 2 + dominant_operand.ilog2(), + match dominant_operand { + 0 => 2 - 1, + _ => 2 + dominant_operand.ilog2(), } } } diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index 078b797f0..02f4d0a8f 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -571,7 +571,7 @@ impl VMState { self.op_stack.push(hi); self.op_stack.push(lo); - let u32_table_entry = U32TableEntry::new_from_base_field_element(Split, lo, hi); + let u32_table_entry = U32TableEntry::new(Split, lo, hi); let co_processor_calls = vec![U32Call(u32_table_entry)]; self.instruction_pointer += 1; @@ -650,8 +650,7 @@ impl VMState { let base_pow_exponent = base.mod_pow(exponent.into()); self.op_stack.push(base_pow_exponent); - let u32_table_entry = - U32TableEntry::new_from_base_field_element(Pow, base, exponent.into()); + let u32_table_entry = U32TableEntry::new(Pow, base, exponent); let co_processor_calls = vec![U32Call(u32_table_entry)]; self.instruction_pointer += 1;