Skip to content

Commit

Permalink
feat!: streamline accessing AET's heights
Browse files Browse the repository at this point in the history
In particular:
- the height of any table can be queried using `aet.height_of_table(…)`
- the height-dominating table can be identified using `aet.height()`

BREAKING_CHANGE: Previous methods for accessing various table's lengths
are replaced and thus removed.
  • Loading branch information
jan-ferdinand committed Mar 26, 2024
1 parent c265cf4 commit 3f3a9fd
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 107 deletions.
3 changes: 2 additions & 1 deletion triton-vm/benches/mem_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use criterion::Criterion;
use triton_vm::prelude::*;
use triton_vm::profiler::Report;
use triton_vm::profiler::TritonProfiler;
use triton_vm::table::master_table::TableId;

criterion_main!(benches);
criterion_group!(
Expand Down Expand Up @@ -82,7 +83,7 @@ impl MemIOBench {
let proof = stark.prove(&claim, &aet, &mut profiler).unwrap();
let mut profiler = profiler.unwrap();

let trace_len = aet.op_stack_table_length();
let trace_len = aet.height_of_table(TableId::OpStack);
let padded_height = proof.padded_height().unwrap();
let fri = stark.derive_fri(padded_height).unwrap();

Expand Down
163 changes: 96 additions & 67 deletions triton-vm/src/aet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::collections::hash_map::Entry::Vacant;
use std::collections::HashMap;
use std::ops::AddAssign;

use arbitrary::Arbitrary;
use itertools::Itertools;
use ndarray::s;
use ndarray::Array2;
use ndarray::Axis;
use strum::IntoEnumIterator;
use twenty_first::prelude::*;

use crate::error::InstructionError;
Expand All @@ -15,6 +17,7 @@ use crate::instruction::Instruction;
use crate::program::Program;
use crate::table::hash_table::HashTable;
use crate::table::hash_table::PermutationTrace;
use crate::table::master_table::TableId;
use crate::table::op_stack_table::OpStackTableEntry;
use crate::table::ram_table::RamTableCall;
use crate::table::table_column::HashBaseTableColumn::CI;
Expand Down Expand Up @@ -74,7 +77,15 @@ pub struct AlgebraicExecutionTrace {
pub lookup_table_lookup_multiplicities: [u64; 1 << 8],
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct TableHeight {
pub table: TableId,
pub height: usize,
}

impl AlgebraicExecutionTrace {
const LOOKUP_TABLE_HEIGHT: usize = 1 << 8;

pub fn new(program: Program) -> Self {
let program_len = program.len_bwords();

Expand All @@ -89,26 +100,66 @@ impl AlgebraicExecutionTrace {
sponge_trace: Array2::default([0, hash_table::BASE_WIDTH]),
u32_entries: HashMap::new(),
cascade_table_lookup_multiplicities: HashMap::new(),
lookup_table_lookup_multiplicities: [0; 1 << 8],
lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT],
};
aet.fill_program_hash_trace();
aet
}

/// The height of the [AET](AlgebraicExecutionTrace) after [padding][pad].
///
/// Guaranteed to be a power of two.
///
/// [pad]: master_table::MasterBaseTable::pad
pub fn padded_height(&self) -> usize {
let relevant_table_heights = [
self.program_table_length(),
self.processor_table_length(),
self.op_stack_table_length(),
self.ram_table_length(),
self.hash_table_length(),
self.cascade_table_length(),
self.lookup_table_length(),
self.u32_table_length(),
];
let max_height = relevant_table_heights.into_iter().max().unwrap_or(0);
max_height.next_power_of_two()
self.height().height.next_power_of_two()
}

/// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad].
/// Corresponds to the height of the longest table.
///
/// [pad]: master_table::MasterBaseTable::pad
pub fn height(&self) -> TableHeight {
let relevant_tables = TableId::iter().filter(|&t| t != TableId::DegreeLowering);
let heights = relevant_tables.map(|t| TableHeight::new(t, self.height_of_table(t)));
heights.max().unwrap()
}

pub fn height_of_table(&self, table: TableId) -> usize {
let hash_table_height = || {
self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
};

match table {
TableId::Program => Self::padded_program_length(&self.program),
TableId::Processor => self.processor_trace.nrows(),
TableId::OpStack => self.op_stack_underflow_trace.nrows(),
TableId::Ram => self.ram_trace.nrows(),
TableId::JumpStack => self.processor_trace.nrows(),
TableId::Hash => hash_table_height(),
TableId::Cascade => self.cascade_table_lookup_multiplicities.len(),
TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT,
TableId::U32 => self.u32_table_height(),
TableId::DegreeLowering => self.height().height,
}
}

/// # Panics
///
/// - if the table height exceeds [`u32::MAX`]
/// - if the table height exceeds [`usize::MAX`]
fn u32_table_height(&self) -> usize {
let entry_len = U32TableEntry::table_height_contribution;
let height = self.u32_entries.keys().map(entry_len).sum::<u32>();
height.try_into().unwrap()
}

fn padded_program_length(program: &Program) -> usize {
// Padding is at least one 1.
// Also note that the Program Table's side of the instruction lookup argument requires at
// least one padding row to account for the processor's “next instruction or argument.”
// Both of these are captured by the “+ 1” in the following line.
(program.len_bwords() + 1).next_multiple_of(Tip5::RATE)
}

/// Hash the program and record the entire Sponge's trace for program attestation.
Expand Down Expand Up @@ -156,58 +207,6 @@ impl AlgebraicExecutionTrace {
.collect()
}

pub fn program_table_length(&self) -> usize {
Self::padded_program_length(&self.program)
}

fn padded_program_length(program: &Program) -> usize {
// After adding one 1, the program table is padded to the next smallest multiple of the
// sponge's rate with 0s.
// Also note that the Program Table's side of the instruction lookup argument requires at
// least one padding row to account for the processor's “next instruction or argument.”
// Both of these are captured by the “+ 1” in the following line.
let min_padded_len = program.len_bwords() + 1;
let remainder_len = min_padded_len % Tip5::RATE;
let num_zeros_to_add = match remainder_len {
0 => 0,
_ => Tip5::RATE - remainder_len,
};
min_padded_len + num_zeros_to_add
}

pub fn processor_table_length(&self) -> usize {
self.processor_trace.nrows()
}

pub fn op_stack_table_length(&self) -> usize {
self.op_stack_underflow_trace.nrows()
}

pub fn ram_table_length(&self) -> usize {
self.ram_trace.nrows()
}

pub fn hash_table_length(&self) -> usize {
self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
}

pub fn cascade_table_length(&self) -> usize {
self.cascade_table_lookup_multiplicities.len()
}

pub fn lookup_table_length(&self) -> usize {
1 << 8
}

/// # Panics
///
/// Panics if the table length exceeds [`u32::MAX`].
pub fn u32_table_length(&self) -> usize {
let entry_len = U32TableEntry::table_length_contribution;
let len = self.u32_entries.keys().map(entry_len).sum::<u32>();
len.try_into().unwrap()
}

pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> {
self.record_instruction_lookup(state.instruction_pointer)?;
self.append_state_to_processor_trace(state);
Expand Down Expand Up @@ -337,12 +336,29 @@ impl AlgebraicExecutionTrace {
}
}

impl TableHeight {
fn new(table: TableId, height: usize) -> Self {
Self { table, height }
}
}

impl PartialOrd for TableHeight {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for TableHeight {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.height.cmp(&other.height)
}
}

#[cfg(test)]
mod tests {
use assert2::assert;

use crate::triton_asm;
use crate::triton_program;
use crate::prelude::*;

use super::*;

Expand All @@ -355,4 +371,17 @@ mod tests {
let expected = [program.to_bwords(), vec![bfe!(1)]].concat();
assert!(expected == padded_program);
}

#[test]
fn height_of_any_table_can_be_computed() {
let program = triton_program!(halt);
let (aet, _) = program
.trace_execution(PublicInput::default(), NonDeterminism::default())
.unwrap();

let _ = aet.height();
for table in TableId::iter() {
let _ = aet.height_of_table(table);
}
}
}
8 changes: 5 additions & 3 deletions triton-vm/src/table/hash_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,7 @@ pub(crate) mod tests {
use crate::shared_tests::ProgramAndInput;
use crate::stark::tests::master_tables_for_low_security_level;
use crate::table::master_table::MasterTable;
use crate::table::master_table::TableId;
use crate::triton_asm;
use crate::triton_program;

Expand Down Expand Up @@ -1894,10 +1895,11 @@ pub(crate) mod tests {
};

let (aet, _) = program.trace_execution([].into(), [].into()).unwrap();
dbg!(aet.height());
dbg!(aet.padded_height());
dbg!(aet.hash_table_length());
dbg!(aet.op_stack_table_length());
dbg!(aet.cascade_table_length());
dbg!(aet.height_of_table(TableId::Hash));
dbg!(aet.height_of_table(TableId::OpStack));
dbg!(aet.height_of_table(TableId::Cascade));

let (_, _, master_base_table, master_ext_table, challenges) =
master_tables_for_low_security_level(ProgramAndInput::new(program));
Expand Down
16 changes: 8 additions & 8 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub const EXT_DEGREE_LOWERING_TABLE_END: usize =
const NUM_TABLES_WITHOUT_DEGREE_LOWERING: usize = TableId::COUNT - 1;

/// A `TableId` uniquely determines one of Triton VM's tables.
#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)]
#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)]
pub enum TableId {
Program,
Processor,
Expand Down Expand Up @@ -590,13 +590,13 @@ impl MasterBaseTable {

let mut master_base_table = Self {
num_trace_randomizers,
program_table_len: aet.program_table_length(),
main_execution_len: aet.processor_table_length(),
op_stack_table_len: aet.op_stack_table_length(),
ram_table_len: aet.ram_table_length(),
hash_coprocessor_execution_len: aet.hash_table_length(),
cascade_table_len: aet.cascade_table_length(),
u32_coprocesor_execution_len: aet.u32_table_length(),
program_table_len: aet.height_of_table(TableId::Program),
main_execution_len: aet.height_of_table(TableId::Processor),
op_stack_table_len: aet.height_of_table(TableId::OpStack),
ram_table_len: aet.height_of_table(TableId::Ram),
hash_coprocessor_execution_len: aet.height_of_table(TableId::Hash),
cascade_table_len: aet.height_of_table(TableId::Cascade),
u32_coprocesor_execution_len: aet.height_of_table(TableId::U32),
trace_domain,
randomized_trace_domain,
quotient_domain,
Expand Down
4 changes: 3 additions & 1 deletion triton-vm/src/table/op_stack_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::table::constraint_circuit::DualRowIndicator::*;
use crate::table::constraint_circuit::SingleRowIndicator::*;
use crate::table::constraint_circuit::*;
use crate::table::cross_table_argument::*;
use crate::table::master_table::TableId;
use crate::table::table_column::OpStackBaseTableColumn::*;
use crate::table::table_column::OpStackExtTableColumn::*;
use crate::table::table_column::*;
Expand Down Expand Up @@ -274,7 +275,8 @@ impl OpStackTable {
op_stack_table: &mut ArrayViewMut2<BFieldElement>,
aet: &AlgebraicExecutionTrace,
) -> Vec<BFieldElement> {
let mut op_stack_table = op_stack_table.slice_mut(s![0..aet.op_stack_table_length(), ..]);
let mut op_stack_table =
op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]);
let trace_iter = aet.op_stack_underflow_trace.rows().into_iter();

let sorted_rows =
Expand Down
3 changes: 2 additions & 1 deletion triton-vm/src/table/program_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::table::constraint_circuit::*;
use crate::table::cross_table_argument::CrossTableArg;
use crate::table::cross_table_argument::EvalArg;
use crate::table::cross_table_argument::LookupArg;
use crate::table::master_table::TableId;
use crate::table::table_column::ProgramBaseTableColumn::*;
use crate::table::table_column::ProgramExtTableColumn::*;
use crate::table::table_column::*;
Expand Down Expand Up @@ -290,7 +291,7 @@ impl ProgramTable {

let instructions = aet.program.to_bwords();
let program_len = instructions.len();
let padded_program_len = aet.program_table_length();
let padded_program_len = aet.height_of_table(TableId::Program);

let one_iter = [bfe!(1)].into_iter();
let zero_iter = [bfe!(0)].into_iter();
Expand Down
3 changes: 2 additions & 1 deletion triton-vm/src/table/ram_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::table::constraint_circuit::DualRowIndicator::*;
use crate::table::constraint_circuit::SingleRowIndicator::*;
use crate::table::constraint_circuit::*;
use crate::table::cross_table_argument::*;
use crate::table::master_table::TableId;
use crate::table::table_column::RamBaseTableColumn::*;
use crate::table::table_column::RamExtTableColumn::*;
use crate::table::table_column::*;
Expand Down Expand Up @@ -70,7 +71,7 @@ impl RamTable {
ram_table: &mut ArrayViewMut2<BFieldElement>,
aet: &AlgebraicExecutionTrace,
) -> Vec<BFieldElement> {
let mut ram_table = ram_table.slice_mut(s![0..aet.ram_table_length(), ..]);
let mut ram_table = ram_table.slice_mut(s![0..aet.height_of_table(TableId::Ram), ..]);
let trace_iter = aet.ram_trace.rows().into_iter();

let sorted_rows =
Expand Down
35 changes: 13 additions & 22 deletions triton-vm/src/table/u32_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,29 @@ pub struct U32TableEntry {
}

impl U32TableEntry {
pub fn new(instruction: Instruction, left_operand: u32, right_operand: u32) -> Self {
let left_operand: u64 = left_operand.into();
let right_operand: u64 = right_operand.into();
pub fn new<L, R>(instruction: Instruction, left_operand: L, right_operand: R) -> Self
where
L: Into<BFieldElement>,
R: Into<BFieldElement>,
{
Self {
instruction,
left_operand: left_operand.into(),
right_operand: right_operand.into(),
}
}

pub fn new_from_base_field_element(
instruction: Instruction,
left_operand: BFieldElement,
right_operand: BFieldElement,
) -> Self {
Self {
instruction,
left_operand,
right_operand,
}
}

/// The number of rows this entry contributes to the U32 Table.
pub fn table_length_contribution(&self) -> u32 {
pub(crate) fn table_height_contribution(&self) -> u32 {
let lhs = self.left_operand.value();
let rhs = self.right_operand.value();
let dominant_operand = match self.instruction {
// for instruction `pow`, the left-hand side doesn't change between rows
Instruction::Pow => self.right_operand.value(),
_ => max(self.left_operand.value(), self.right_operand.value()),
Instruction::Pow => rhs, // left-hand side doesn't change between rows
_ => max(lhs, rhs),
};
match dominant_operand == 0 {
true => 2 - 1,
false => 2 + dominant_operand.ilog2(),
match dominant_operand {
0 => 2 - 1,
_ => 2 + dominant_operand.ilog2(),
}
}
}
Expand Down
Loading

0 comments on commit 3f3a9fd

Please sign in to comment.