From 4ec6b1d1697fea512d3512fd65c7f2bf023736dd Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 27 Aug 2024 12:12:15 +0200 Subject: [PATCH 01/15] refactor: Don't wildcard-use `CircuitExpression` changelog: ignore --- triton-vm/src/table/constraint_circuit.rs | 140 ++++++++++++---------- 1 file changed, 76 insertions(+), 64 deletions(-) diff --git a/triton-vm/src/table/constraint_circuit.rs b/triton-vm/src/table/constraint_circuit.rs index 44ba52bc2..4c335df8b 100644 --- a/triton-vm/src/table/constraint_circuit.rs +++ b/triton-vm/src/table/constraint_circuit.rs @@ -33,8 +33,6 @@ use quote::quote; use quote::ToTokens; use twenty_first::prelude::*; -use CircuitExpression::*; - #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum BinOp { Add, @@ -274,23 +272,23 @@ pub enum CircuitExpression { impl Hash for CircuitExpression { fn hash(&self, state: &mut H) { match self { - BConstant(bfe) => { + Self::BConstant(bfe) => { "bfe".hash(state); bfe.hash(state); } - XConstant(xfe) => { + Self::XConstant(xfe) => { "xfe".hash(state); xfe.hash(state); } - Input(index) => { + Self::Input(index) => { "input".hash(state); index.hash(state); } - Challenge(table_challenge_id) => { + Self::Challenge(table_challenge_id) => { "challenge".hash(state); table_challenge_id.hash(state); } - BinaryOperation(binop, lhs, rhs) => { + Self::BinaryOperation(binop, lhs, rhs) => { "binop".hash(state); binop.hash(state); lhs.borrow().hash(state); @@ -303,12 +301,12 @@ impl Hash for CircuitExpression { impl PartialEq for CircuitExpression { fn eq(&self, other: &Self) -> bool { match (self, other) { - (BConstant(bfe_self), BConstant(bfe_other)) => bfe_self == bfe_other, - (XConstant(xfe_self), XConstant(xfe_other)) => xfe_self == xfe_other, - (Input(input_self), Input(input_other)) => input_self == input_other, - (Challenge(id_self), Challenge(id_other)) => id_self == id_other, - (BinaryOperation(op_s, lhs_s, rhs_s), BinaryOperation(op_o, lhs_o, rhs_o)) => { - op_s == op_o && lhs_s == lhs_o && rhs_s == rhs_o + (Self::BConstant(bfe_self), Self::BConstant(bfe_other)) => bfe_self == bfe_other, + (Self::XConstant(xfe_self), Self::XConstant(xfe_other)) => xfe_self == xfe_other, + (Self::Input(input_self), Self::Input(input_other)) => input_self == input_other, + (Self::Challenge(id_self), Self::Challenge(id_other)) => id_self == id_other, + (Self::BinaryOperation(op, l, r), Self::BinaryOperation(op_o, l_o, r_o)) => { + op == op_o && l == l_o && r == r_o } _ => false, } @@ -349,11 +347,11 @@ impl PartialEq for ConstraintCircuit { impl Display for ConstraintCircuit { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match &self.expression { - XConstant(xfe) => write!(f, "{xfe}"), - BConstant(bfe) => write!(f, "{bfe}"), - Input(input) => write!(f, "{input} "), - Challenge(self_challenge_idx) => write!(f, "{self_challenge_idx}"), - BinaryOperation(operation, lhs, rhs) => { + CircuitExpression::XConstant(xfe) => write!(f, "{xfe}"), + CircuitExpression::BConstant(bfe) => write!(f, "{bfe}"), + CircuitExpression::Input(input) => write!(f, "{input} "), + CircuitExpression::Challenge(self_challenge_idx) => write!(f, "{self_challenge_idx}"), + CircuitExpression::BinaryOperation(operation, lhs, rhs) => { write!(f, "({}) {operation} ({})", lhs.borrow(), rhs.borrow()) } } @@ -373,7 +371,7 @@ impl ConstraintCircuit { fn reset_ref_count_for_tree(&mut self) { self.ref_count = 0; - if let BinaryOperation(_, lhs, rhs) = &self.expression { + if let CircuitExpression::BinaryOperation(_, lhs, rhs) = &self.expression { lhs.borrow_mut().reset_ref_count_for_tree(); rhs.borrow_mut().reset_ref_count_for_tree(); } @@ -397,7 +395,7 @@ impl ConstraintCircuit { panic!("Repeated ID: {self_id}\nSelf:\n{self}\n{self:?}\nOther:\n{other}\n{other:?}"); } - if let BinaryOperation(_, lhs, rhs) = &self.expression { + if let CircuitExpression::BinaryOperation(_, lhs, rhs) = &self.expression { lhs.borrow_mut().assert_unique_ids_inner(ids); rhs.borrow_mut().assert_unique_ids_inner(ids); } @@ -427,7 +425,7 @@ impl ConstraintCircuit { } match &self.expression { - BinaryOperation(binop, lhs, rhs) => { + CircuitExpression::BinaryOperation(binop, lhs, rhs) => { let degree_lhs = lhs.borrow().degree(); let degree_rhs = rhs.borrow().degree(); let degree_additive = cmp::max(degree_lhs, degree_rhs); @@ -440,15 +438,17 @@ impl ConstraintCircuit { BinOp::Mul => degree_multiplicative, } } - Input(_) => 1, - BConstant(_) | XConstant(_) | Challenge(_) => 0, + CircuitExpression::Input(_) => 1, + CircuitExpression::BConstant(_) + | CircuitExpression::XConstant(_) + | CircuitExpression::Challenge(_) => 0, } } /// All unique reference counters in the subtree, sorted. pub fn all_ref_counters(&self) -> Vec { let mut ref_counters = vec![self.ref_count]; - if let BinaryOperation(_, lhs, rhs) = &self.expression { + if let CircuitExpression::BinaryOperation(_, lhs, rhs) = &self.expression { ref_counters.extend(lhs.borrow().all_ref_counters()); ref_counters.extend(rhs.borrow().all_ref_counters()); }; @@ -461,8 +461,8 @@ impl ConstraintCircuit { /// Does not catch composite expressions that will always evaluate to zero, like `0·a`. pub fn is_zero(&self) -> bool { match self.expression { - BConstant(bfe) => bfe.is_zero(), - XConstant(xfe) => xfe.is_zero(), + CircuitExpression::BConstant(bfe) => bfe.is_zero(), + CircuitExpression::XConstant(xfe) => xfe.is_zero(), _ => false, } } @@ -471,16 +471,16 @@ impl ConstraintCircuit { /// Does not catch composite expressions that will always evaluate to one, like `1·1`. pub fn is_one(&self) -> bool { match self.expression { - BConstant(bfe) => bfe.is_one(), - XConstant(xfe) => xfe.is_one(), + CircuitExpression::BConstant(bfe) => bfe.is_one(), + CircuitExpression::XConstant(xfe) => xfe.is_one(), _ => false, } } pub fn is_neg_one(&self) -> bool { match self.expression { - BConstant(bfe) => (-bfe).is_one(), - XConstant(xfe) => (-xfe).is_one(), + CircuitExpression::BConstant(bfe) => (-bfe).is_one(), + CircuitExpression::XConstant(xfe) => (-xfe).is_one(), _ => false, } } @@ -491,11 +491,11 @@ impl ConstraintCircuit { /// 3. binary operations on BFieldElements. pub fn evaluates_to_base_element(&self) -> bool { match &self.expression { - BConstant(_) => true, - XConstant(_) => false, - Input(indicator) => indicator.is_base_table_column(), - Challenge(_) => false, - BinaryOperation(_, lhs, rhs) => { + CircuitExpression::BConstant(_) => true, + CircuitExpression::XConstant(_) => false, + CircuitExpression::Input(indicator) => indicator.is_base_table_column(), + CircuitExpression::Challenge(_) => false, + CircuitExpression::BinaryOperation(_, lhs, rhs) => { lhs.borrow().evaluates_to_base_element() && rhs.borrow().evaluates_to_base_element() } } @@ -508,11 +508,11 @@ impl ConstraintCircuit { challenges: &[XFieldElement], ) -> XFieldElement { match &self.expression { - BConstant(bfe) => bfe.lift(), - XConstant(xfe) => *xfe, - Input(input) => input.evaluate(base_table, ext_table), - Challenge(challenge_id) => challenges[*challenge_id], - BinaryOperation(binop, lhs, rhs) => { + CircuitExpression::BConstant(bfe) => bfe.lift(), + CircuitExpression::XConstant(xfe) => *xfe, + CircuitExpression::Input(input) => input.evaluate(base_table, ext_table), + CircuitExpression::Challenge(challenge_id) => challenges[*challenge_id], + CircuitExpression::BinaryOperation(binop, lhs, rhs) => { let lhs_value = lhs.borrow().evaluate(base_table, ext_table, challenges); let rhs_value = rhs.borrow().evaluate(base_table, ext_table, challenges); binop.operation(lhs_value, rhs_value) @@ -583,10 +583,18 @@ fn binop( &lhs.circuit.borrow().expression, &rhs.circuit.borrow().expression, ) { - (&BConstant(l), &BConstant(r)) => return lhs.builder.b_constant(binop.operation(l, r)), - (&BConstant(l), &XConstant(r)) => return lhs.builder.x_constant(binop.operation(l, r)), - (&XConstant(l), &BConstant(r)) => return lhs.builder.x_constant(binop.operation(l, r)), - (&XConstant(l), &XConstant(r)) => return lhs.builder.x_constant(binop.operation(l, r)), + (&CircuitExpression::BConstant(l), &CircuitExpression::BConstant(r)) => { + return lhs.builder.b_constant(binop.operation(l, r)) + } + (&CircuitExpression::BConstant(l), &CircuitExpression::XConstant(r)) => { + return lhs.builder.x_constant(binop.operation(l, r)) + } + (&CircuitExpression::XConstant(l), &CircuitExpression::BConstant(r)) => { + return lhs.builder.x_constant(binop.operation(l, r)) + } + (&CircuitExpression::XConstant(l), &CircuitExpression::XConstant(r)) => { + return lhs.builder.x_constant(binop.operation(l, r)) + } _ => (), }; @@ -616,7 +624,8 @@ fn binop_new_node( rhs: &ConstraintCircuitMonad, ) -> ConstraintCircuitMonad { let id = lhs.builder.id_counter.borrow().to_owned(); - let expression = BinaryOperation(binop, lhs.circuit.clone(), rhs.circuit.clone()); + let expression = + CircuitExpression::BinaryOperation(binop, lhs.circuit.clone(), rhs.circuit.clone()); let circuit = ConstraintCircuit::new(id, expression); lhs.builder.new_monad(circuit) } @@ -723,9 +732,9 @@ impl ConstraintCircuitMonad { } let evaluation = match &circuit.borrow().expression { - BConstant(bfe) => bfe.lift(), - XConstant(xfe) => *xfe, - Input(input) => { + CircuitExpression::BConstant(bfe) => bfe.lift(), + CircuitExpression::XConstant(xfe) => *xfe, + CircuitExpression::Input(input) => { let [s0, s1, s2] = master_seed.coefficients; let dom_sep = if input.is_current_row() { DOMAIN_SEPARATOR_CURR_ROW @@ -736,14 +745,14 @@ impl ConstraintCircuitMonad { let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, i]); xfe!([d0, d1, d2]) } - Challenge(challenge) => { + CircuitExpression::Challenge(challenge) => { let [s0, s1, s2] = master_seed.coefficients; let dom_sep = DOMAIN_SEPARATOR_CHALLENGE; let ch = bfe!(u64::try_from(*challenge).unwrap()); let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, ch]); xfe!([d0, d1, d2]) } - BinaryOperation(bin_op, lhs, rhs) => { + CircuitExpression::BinaryOperation(bin_op, lhs, rhs) => { let l = Self::probe_random(lhs, id_to_eval, eval_to_ids, id_to_node, master_seed); let r = Self::probe_random(rhs, id_to_eval, eval_to_ids, id_to_node, master_seed); bin_op.operation(l, r) @@ -910,7 +919,7 @@ impl ConstraintCircuitMonad { /// Internal helper function to recursively find all nodes in a circuit. fn all_nodes_in_circuit(circuit: &ConstraintCircuit) -> Vec> { let mut all_nodes = vec![]; - if let BinaryOperation(_, lhs, rhs) = circuit.expression.clone() { + if let CircuitExpression::BinaryOperation(_, lhs, rhs) = circuit.expression.clone() { let lhs_nodes = Self::all_nodes_in_circuit(&lhs.borrow()); let rhs_nodes = Self::all_nodes_in_circuit(&rhs.borrow()); all_nodes.extend(lhs_nodes); @@ -999,7 +1008,7 @@ impl ConstraintCircuitBuilder { where B: Into, { - self.make_leaf(BConstant(bfe.into())) + self.make_leaf(CircuitExpression::BConstant(bfe.into())) } /// Leaf node with constant over the [extension field][XFieldElement]. @@ -1007,12 +1016,12 @@ impl ConstraintCircuitBuilder { where X: Into, { - self.make_leaf(XConstant(xfe.into())) + self.make_leaf(CircuitExpression::XConstant(xfe.into())) } /// Create deterministic input leaf node. pub fn input(&self, input: II) -> ConstraintCircuitMonad { - self.make_leaf(Input(input)) + self.make_leaf(CircuitExpression::Input(input)) } /// Create challenge leaf node. @@ -1020,14 +1029,14 @@ impl ConstraintCircuitBuilder { where C: Into, { - self.make_leaf(Challenge(challenge.into())) + self.make_leaf(CircuitExpression::Challenge(challenge.into())) } fn make_leaf(&self, mut expression: CircuitExpression) -> ConstraintCircuitMonad { // Don't generate an X field leaf if it can be expressed as a B field leaf - if let XConstant(xfe) = expression { + if let CircuitExpression::XConstant(xfe) = expression { if let Some(bfe) = xfe.unlift() { - expression = BConstant(bfe); + expression = CircuitExpression::BConstant(bfe); } } @@ -1055,7 +1064,9 @@ impl ConstraintCircuitBuilder { self.all_nodes.borrow_mut().remove(&old_id); for node in self.all_nodes.borrow_mut().values_mut() { let node_expression = &mut node.circuit.borrow_mut().expression; - let BinaryOperation(_, ref mut node_lhs, ref mut node_rhs) = node_expression else { + let CircuitExpression::BinaryOperation(_, ref mut node_lhs, ref mut node_rhs) = + node_expression + else { continue; }; @@ -1180,7 +1191,7 @@ mod tests { if self == other { return true; } - let BinaryOperation(_, lhs, rhs) = self else { + let CircuitExpression::BinaryOperation(_, lhs, rhs) = self else { return false; }; @@ -1343,7 +1354,7 @@ mod tests { values: &mut HashMap)>, ) -> XFieldElement { let value = match &constraint.expression { - BinaryOperation(binop, lhs, rhs) => { + CircuitExpression::BinaryOperation(binop, lhs, rhs) => { let lhs = lhs.borrow(); let rhs = rhs.borrow(); let lhs = evaluate_assert_unique(&lhs, challenges, base_rows, ext_rows, values); @@ -2012,10 +2023,11 @@ mod tests { ] { for (i, constraint) in constraints.iter().enumerate() { let expression = constraint.circuit.borrow().expression.clone(); - let BinaryOperation(BinOp::Add, lhs, rhs) = expression else { + let CircuitExpression::BinaryOperation(BinOp::Add, lhs, rhs) = expression else { panic!("New {constraint_type} constraint {i} must be a subtraction."); }; - let Input(input_indicator) = lhs.borrow().expression.clone() else { + let CircuitExpression::Input(input_indicator) = lhs.borrow().expression.clone() + else { panic!("New {constraint_type} constraint {i} must be a simple substitution."); }; let substitution_rule = rhs.borrow().clone(); @@ -2087,13 +2099,13 @@ mod tests { substitution_rule: &ConstraintCircuit, ) { match substitution_rule.expression.clone() { - BinaryOperation(_, lhs, rhs) => { + CircuitExpression::BinaryOperation(_, lhs, rhs) => { let lhs = lhs.borrow(); let rhs = rhs.borrow(); assert_substitution_rule_uses_legal_variables(new_var, &lhs); assert_substitution_rule_uses_legal_variables(new_var, &rhs); } - Input(old_var) => { + CircuitExpression::Input(old_var) => { let new_var_is_base = new_var.is_base_table_column(); let old_var_is_base = old_var.is_base_table_column(); let legal_substitute = match (new_var_is_base, old_var_is_base) { From 65daf4190904490d9da69080f93277731f96fccd Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Wed, 28 Aug 2024 13:22:18 +0200 Subject: [PATCH 02/15] refactor!: Combine existing crates in Triton VM changelog: ignore --- Cargo.toml | 3 +- constraint-evaluation-generator/Cargo.toml | 29 -- constraint-evaluation-generator/README.md | 9 - .../src/codegen.rs | 60 --- .../src/codegen/rust.rs | 366 --------------- .../src/constraints.rs | 280 ----------- constraint-evaluation-generator/src/main.rs | 89 ---- triton-vm/Cargo.toml | 4 +- .../circuit.rs} | 104 ++--- .../src/codegen/constraints.rs | 435 +++++++++++++++++- triton-vm/src/codegen/mod.rs | 270 +++++++++++ .../src/codegen/substitutions.rs | 28 +- triton-vm/src/instruction.rs | 4 +- triton-vm/src/lib.rs | 4 +- triton-vm/src/op_stack.rs | 3 +- triton-vm/src/program.rs | 4 +- triton-vm/src/stark.rs | 2 +- triton-vm/src/table.rs | 118 ++++- triton-vm/src/table/cascade_table.rs | 12 +- triton-vm/src/table/constraints.rs | 2 +- triton-vm/src/table/cross_table_argument.rs | 12 +- triton-vm/src/table/hash_table.rs | 14 +- triton-vm/src/table/jump_stack_table.rs | 6 +- triton-vm/src/table/lookup_table.rs | 12 +- triton-vm/src/table/master_table.rs | 8 +- triton-vm/src/table/op_stack_table.rs | 6 +- triton-vm/src/table/processor_table.rs | 6 +- triton-vm/src/table/program_table.rs | 6 +- triton-vm/src/table/ram_table.rs | 9 +- triton-vm/src/table/table_column.rs | 28 +- triton-vm/src/table/u32_table.rs | 14 +- triton-vm/src/vm.rs | 3 +- 32 files changed, 920 insertions(+), 1030 deletions(-) delete mode 100644 constraint-evaluation-generator/Cargo.toml delete mode 100644 constraint-evaluation-generator/README.md delete mode 100644 constraint-evaluation-generator/src/codegen.rs delete mode 100644 constraint-evaluation-generator/src/codegen/rust.rs delete mode 100644 constraint-evaluation-generator/src/constraints.rs delete mode 100644 constraint-evaluation-generator/src/main.rs rename triton-vm/src/{table/constraint_circuit.rs => codegen/circuit.rs} (96%) rename constraint-evaluation-generator/src/codegen/tasm.rs => triton-vm/src/codegen/constraints.rs (56%) create mode 100644 triton-vm/src/codegen/mod.rs rename constraint-evaluation-generator/src/substitution.rs => triton-vm/src/codegen/substitutions.rs (94%) diff --git a/Cargo.toml b/Cargo.toml index 6e812bdab..7659013d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["triton-vm", "constraint-evaluation-generator"] +members = ["triton-vm"] resolver = "2" [profile.test] @@ -47,7 +47,6 @@ rand = "0.8.5" rand_core = "0.6.4" rayon = "1.10" serde = { version = "1", features = ["derive"] } -serde_derive = "1" serde_json = "1.0" strum = { version = "0.26", features = ["derive"] } syn = "2.0" diff --git a/constraint-evaluation-generator/Cargo.toml b/constraint-evaluation-generator/Cargo.toml deleted file mode 100644 index de5f09597..000000000 --- a/constraint-evaluation-generator/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "constraint-evaluation-generator" -description = "Generate constraint evaluation functions for Triton VM." - -version.workspace = true -edition.workspace = true -authors.workspace = true -license.workspace = true -homepage.workspace = true -documentation.workspace = true -repository.workspace = true -readme.workspace = true - -[dependencies] -itertools.workspace = true -prettyplease.workspace = true -proc-macro2.workspace = true -quote.workspace = true -syn.workspace = true -triton-vm = { path = "../triton-vm" } -twenty-first.workspace = true - -[dev-dependencies] -proptest.workspace = true -criterion.workspace = true -cargo-husky.workspace = true - -[lints] -workspace = true diff --git a/constraint-evaluation-generator/README.md b/constraint-evaluation-generator/README.md deleted file mode 100644 index 3caf14753..000000000 --- a/constraint-evaluation-generator/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# constraint-evaluation-generator - -Generate constraint evaluation functions for Triton VM. - -## How to run - -```sh -cargo run --bin constraint-evaluation-generator -``` diff --git a/constraint-evaluation-generator/src/codegen.rs b/constraint-evaluation-generator/src/codegen.rs deleted file mode 100644 index cea9343b7..000000000 --- a/constraint-evaluation-generator/src/codegen.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::collections::HashSet; - -use proc_macro2::TokenStream; -use quote::quote; -use twenty_first::prelude::BFieldElement; -use twenty_first::prelude::XFieldElement; - -use crate::constraints::Constraints; - -mod rust; -mod tasm; - -pub(crate) trait Codegen { - fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream; - - fn tokenize_bfe(bfe: BFieldElement) -> TokenStream { - let raw_u64 = bfe.raw_u64(); - quote!(BFieldElement::from_raw_u64(#raw_u64)) - } - - fn tokenize_xfe(xfe: XFieldElement) -> TokenStream { - let [c_0, c_1, c_2] = xfe.coefficients.map(Self::tokenize_bfe); - quote!(XFieldElement::new([#c_0, #c_1, #c_2])) - } -} - -#[derive(Debug, Default, Clone, Eq, PartialEq)] -pub(crate) struct RustBackend { - /// All [circuit] IDs known to be in scope. - /// - /// [circuit]: triton_vm::table::constraint_circuit::ConstraintCircuit - scope: HashSet, -} - -#[derive(Debug, Default, Clone, Eq, PartialEq)] -pub(crate) struct TasmBackend { - /// All [circuit] IDs known to be processed and stored to memory. - /// - /// [circuit]: triton_vm::table::constraint_circuit::ConstraintCircuit - scope: HashSet, - - /// The number of elements written to the output list. - elements_written: usize, - - /// Whether the code that is to be generated can assume statically provided - /// addresses for the various input arrays. - input_location_is_static: bool, -} - -#[cfg(test)] -pub mod tests { - use super::*; - - pub fn print_constraints(constraints: &Constraints) { - let code = B::constraint_evaluation_code(constraints); - let syntax_tree = syn::parse2(code).unwrap(); - let code = prettyplease::unparse(&syntax_tree); - println!("{code}"); - } -} diff --git a/constraint-evaluation-generator/src/codegen/rust.rs b/constraint-evaluation-generator/src/codegen/rust.rs deleted file mode 100644 index d2480ebe0..000000000 --- a/constraint-evaluation-generator/src/codegen/rust.rs +++ /dev/null @@ -1,366 +0,0 @@ -use itertools::Itertools; -use proc_macro2::TokenStream; -use quote::format_ident; -use quote::quote; - -use triton_vm::table::constraint_circuit::CircuitExpression; -use triton_vm::table::constraint_circuit::ConstraintCircuit; -use triton_vm::table::constraint_circuit::InputIndicator; - -use crate::codegen::Codegen; -use crate::codegen::RustBackend; -use crate::Constraints; - -impl Codegen for RustBackend { - fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream { - let num_init_constraints = constraints.init.len(); - let num_cons_constraints = constraints.cons.len(); - let num_tran_constraints = constraints.tran.len(); - let num_term_constraints = constraints.term.len(); - - let (init_constraint_degrees, init_constraints_bfe, init_constraints_xfe) = - Self::tokenize_circuits(&constraints.init()); - let (cons_constraint_degrees, cons_constraints_bfe, cons_constraints_xfe) = - Self::tokenize_circuits(&constraints.cons()); - let (tran_constraint_degrees, tran_constraints_bfe, tran_constraints_xfe) = - Self::tokenize_circuits(&constraints.tran()); - let (term_constraint_degrees, term_constraints_bfe, term_constraints_xfe) = - Self::tokenize_circuits(&constraints.term()); - - let uses = Self::uses(); - let evaluable_over_base_field = Self::generate_evaluable_implementation_over_field( - &init_constraints_bfe, - &cons_constraints_bfe, - &tran_constraints_bfe, - &term_constraints_bfe, - quote!(BFieldElement), - ); - let evaluable_over_ext_field = Self::generate_evaluable_implementation_over_field( - &init_constraints_xfe, - &cons_constraints_xfe, - &tran_constraints_xfe, - &term_constraints_xfe, - quote!(XFieldElement), - ); - - let quotient_trait_impl = quote!( - impl Quotientable for MasterExtTable { - const NUM_INITIAL_CONSTRAINTS: usize = #num_init_constraints; - const NUM_CONSISTENCY_CONSTRAINTS: usize = #num_cons_constraints; - const NUM_TRANSITION_CONSTRAINTS: usize = #num_tran_constraints; - const NUM_TERMINAL_CONSTRAINTS: usize = #num_term_constraints; - - #[allow(unused_variables)] - fn initial_quotient_degree_bounds(interpolant_degree: isize) -> Vec { - let zerofier_degree = 1; - [#init_constraint_degrees].to_vec() - } - - #[allow(unused_variables)] - fn consistency_quotient_degree_bounds( - interpolant_degree: isize, - padded_height: usize, - ) -> Vec { - let zerofier_degree = padded_height as isize; - [#cons_constraint_degrees].to_vec() - } - - #[allow(unused_variables)] - fn transition_quotient_degree_bounds( - interpolant_degree: isize, - padded_height: usize, - ) -> Vec { - let zerofier_degree = padded_height as isize - 1; - [#tran_constraint_degrees].to_vec() - } - - #[allow(unused_variables)] - fn terminal_quotient_degree_bounds(interpolant_degree: isize) -> Vec { - let zerofier_degree = 1; - [#term_constraint_degrees].to_vec() - } - } - ); - - quote!( - #uses - #evaluable_over_base_field - #evaluable_over_ext_field - #quotient_trait_impl - ) - } -} - -impl RustBackend { - fn uses() -> TokenStream { - quote!( - use ndarray::ArrayView1; - use twenty_first::prelude::BFieldElement; - use twenty_first::prelude::XFieldElement; - - use crate::table::challenges::Challenges; - use crate::table::extension_table::Evaluable; - use crate::table::extension_table::Quotientable; - use crate::table::master_table::MasterExtTable; - ) - } - - fn generate_evaluable_implementation_over_field( - init_constraints: &TokenStream, - cons_constraints: &TokenStream, - tran_constraints: &TokenStream, - term_constraints: &TokenStream, - field: TokenStream, - ) -> TokenStream { - quote!( - impl Evaluable<#field> for MasterExtTable { - #[allow(unused_variables)] - fn evaluate_initial_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, - challenges: &Challenges, - ) -> Vec { - #init_constraints - } - - #[allow(unused_variables)] - fn evaluate_consistency_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, - challenges: &Challenges, - ) -> Vec { - #cons_constraints - } - - #[allow(unused_variables)] - fn evaluate_transition_constraints( - current_base_row: ArrayView1<#field>, - current_ext_row: ArrayView1, - next_base_row: ArrayView1<#field>, - next_ext_row: ArrayView1, - challenges: &Challenges, - ) -> Vec { - #tran_constraints - } - - #[allow(unused_variables)] - fn evaluate_terminal_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, - challenges: &Challenges, - ) -> Vec { - #term_constraints - } - } - ) - } - - /// Return a tuple of [`TokenStream`]s corresponding to code evaluating these constraints as - /// well as their degrees. In particular: - /// 1. The first stream contains code that, when evaluated, produces the constraints' degrees, - /// 1. the second stream contains code that, when evaluated, produces the constraints' values, - /// with the input type for the base row being `BFieldElement`, and - /// 1. the third stream is like the second, except that the input type for the base row is - /// `XFieldElement`. - fn tokenize_circuits( - constraints: &[ConstraintCircuit], - ) -> (TokenStream, TokenStream, TokenStream) { - if constraints.is_empty() { - return (quote!(), quote!(vec![]), quote!(vec![])); - } - - let mut backend = Self::default(); - let shared_declarations = backend.declare_shared_nodes(constraints); - let (base_constraints, ext_constraints): (Vec<_>, Vec<_>) = constraints - .iter() - .partition(|constraint| constraint.evaluates_to_base_element()); - - // The order of the constraints' degrees must match the order of the constraints. - // Hence, listing the degrees is only possible after the partition into base and extension - // constraints is known. - let tokenized_degree_bounds = base_constraints - .iter() - .chain(&ext_constraints) - .map(|circuit| match circuit.degree() { - d if d > 1 => quote!(interpolant_degree * #d - zerofier_degree), - 1 => quote!(interpolant_degree - zerofier_degree), - _ => panic!("Constraint degree must be positive"), - }) - .collect_vec(); - let tokenized_degree_bounds = quote!(#(#tokenized_degree_bounds),*); - - let tokenize_constraint_evaluation = |constraints: &[&ConstraintCircuit]| { - constraints - .iter() - .map(|constraint| backend.evaluate_single_node(constraint)) - .collect_vec() - }; - let tokenized_base_constraints = tokenize_constraint_evaluation(&base_constraints); - let tokenized_ext_constraints = tokenize_constraint_evaluation(&ext_constraints); - - // If there are no base constraints, the type needs to be explicitly declared. - let tokenized_bfe_base_constraints = match base_constraints.is_empty() { - true => quote!(let base_constraints: [BFieldElement; 0] = []), - false => quote!(let base_constraints = [#(#tokenized_base_constraints),*]), - }; - let tokenized_bfe_constraints = quote!( - #(#shared_declarations)* - #tokenized_bfe_base_constraints; - let ext_constraints = [#(#tokenized_ext_constraints),*]; - base_constraints - .into_iter() - .map(|bfe| bfe.lift()) - .chain(ext_constraints) - .collect() - ); - - let tokenized_xfe_constraints = quote!( - #(#shared_declarations)* - let base_constraints = [#(#tokenized_base_constraints),*]; - let ext_constraints = [#(#tokenized_ext_constraints),*]; - base_constraints - .into_iter() - .chain(ext_constraints) - .collect() - ); - - ( - tokenized_degree_bounds, - tokenized_bfe_constraints, - tokenized_xfe_constraints, - ) - } - - /// Declare all shared variables, i.e., those with a ref count greater than 1. - /// These declarations must be made starting from the highest ref count. - /// Otherwise, the resulting code will refer to bindings that have not yet been made. - fn declare_shared_nodes( - &mut self, - constraints: &[ConstraintCircuit], - ) -> Vec { - let constraints_iter = constraints.iter(); - let all_ref_counts = constraints_iter.flat_map(ConstraintCircuit::all_ref_counters); - let relevant_ref_counts = all_ref_counts.unique().filter(|&x| x > 1); - let ordered_ref_counts = relevant_ref_counts.sorted().rev(); - - ordered_ref_counts - .map(|count| self.declare_nodes_with_ref_count(constraints, count)) - .collect() - } - - /// Produce the code to evaluate code for all nodes that share a ref count. - fn declare_nodes_with_ref_count( - &mut self, - circuits: &[ConstraintCircuit], - ref_count: usize, - ) -> TokenStream { - let all_nodes_in_circuit = - |circuit| self.declare_single_node_with_ref_count(circuit, ref_count); - let tokenized_circuits = circuits.iter().filter_map(all_nodes_in_circuit); - quote!(#(#tokenized_circuits)*) - } - - fn declare_single_node_with_ref_count( - &mut self, - circuit: &ConstraintCircuit, - ref_count: usize, - ) -> Option { - if self.scope.contains(&circuit.id) { - return None; - } - - // constants can be declared trivially - let CircuitExpression::BinaryOperation(_, lhs, rhs) = &circuit.expression else { - return None; - }; - - if circuit.ref_count < ref_count { - let out_left = self.declare_single_node_with_ref_count(&lhs.borrow(), ref_count); - let out_right = self.declare_single_node_with_ref_count(&rhs.borrow(), ref_count); - return match (out_left, out_right) { - (None, None) => None, - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - (Some(l), Some(r)) => Some(quote!(#l #r)), - }; - } - - assert_eq!(circuit.ref_count, ref_count); - let binding_name = Self::binding_name(circuit); - let evaluation = self.evaluate_single_node(circuit); - let new_binding = quote!(let #binding_name = #evaluation;); - - let is_new_insertion = self.scope.insert(circuit.id); - assert!(is_new_insertion); - - Some(new_binding) - } - - /// Recursively construct the code for evaluating a single node. - pub fn evaluate_single_node( - &self, - circuit: &ConstraintCircuit, - ) -> TokenStream { - if self.scope.contains(&circuit.id) { - return Self::binding_name(circuit); - } - - let CircuitExpression::BinaryOperation(binop, lhs, rhs) = &circuit.expression else { - return Self::binding_name(circuit); - }; - - let lhs = self.evaluate_single_node(&lhs.borrow()); - let rhs = self.evaluate_single_node(&rhs.borrow()); - quote!((#lhs) #binop (#rhs)) - } - - fn binding_name(circuit: &ConstraintCircuit) -> TokenStream { - match &circuit.expression { - CircuitExpression::BConstant(bfe) => Self::tokenize_bfe(*bfe), - CircuitExpression::XConstant(xfe) => Self::tokenize_xfe(*xfe), - CircuitExpression::Input(idx) => quote!(#idx), - CircuitExpression::Challenge(challenge) => quote!(challenges[#challenge]), - CircuitExpression::BinaryOperation(_, _, _) => { - let node_ident = format_ident!("node_{}", circuit.id); - quote!(#node_ident) - } - } - } -} - -#[cfg(test)] -mod tests { - use twenty_first::prelude::*; - - use crate::codegen::tests::print_constraints; - - use super::*; - - #[test] - fn tokenizing_base_field_elements_produces_expected_result() { - let bfe = bfe!(42); - let expected = "BFieldElement :: from_raw_u64 (180388626390u64)"; - assert_eq!(expected, RustBackend::tokenize_bfe(bfe).to_string()); - } - - #[test] - fn tokenizing_extension_field_elements_produces_expected_result() { - let xfe = xfe!([42, 43, 44]); - let expected = "XFieldElement :: new ([\ - BFieldElement :: from_raw_u64 (180388626390u64) , \ - BFieldElement :: from_raw_u64 (184683593685u64) , \ - BFieldElement :: from_raw_u64 (188978560980u64)\ - ])"; - assert_eq!(expected, RustBackend::tokenize_xfe(xfe).to_string()); - } - - #[test] - fn print_mini_constraints() { - print_constraints::(&Constraints::mini_constraints()); - } - - #[test] - fn print_test_constraints() { - print_constraints::(&Constraints::test_constraints()); - } -} diff --git a/constraint-evaluation-generator/src/constraints.rs b/constraint-evaluation-generator/src/constraints.rs deleted file mode 100644 index 82506566e..000000000 --- a/constraint-evaluation-generator/src/constraints.rs +++ /dev/null @@ -1,280 +0,0 @@ -use itertools::Itertools; - -use triton_vm::table; -use triton_vm::table::cascade_table::ExtCascadeTable; -use triton_vm::table::constraint_circuit::ConstraintCircuit; -use triton_vm::table::constraint_circuit::ConstraintCircuitBuilder; -use triton_vm::table::constraint_circuit::ConstraintCircuitMonad; -use triton_vm::table::constraint_circuit::DualRowIndicator; -use triton_vm::table::constraint_circuit::InputIndicator; -use triton_vm::table::constraint_circuit::SingleRowIndicator; -use triton_vm::table::cross_table_argument::GrandCrossTableArg; -use triton_vm::table::degree_lowering_table; -use triton_vm::table::hash_table::ExtHashTable; -use triton_vm::table::jump_stack_table::ExtJumpStackTable; -use triton_vm::table::lookup_table::ExtLookupTable; -use triton_vm::table::master_table; -use triton_vm::table::op_stack_table::ExtOpStackTable; -use triton_vm::table::processor_table::ExtProcessorTable; -use triton_vm::table::program_table::ExtProgramTable; -use triton_vm::table::ram_table::ExtRamTable; -use triton_vm::table::u32_table::ExtU32Table; - -use crate::substitution::AllSubstitutions; -use crate::substitution::Substitutions; - -pub(crate) struct Constraints { - pub init: Vec>, - pub cons: Vec>, - pub tran: Vec>, - pub term: Vec>, -} - -impl Constraints { - pub fn all() -> Self { - Self { - init: Self::initial_constraints(), - cons: Self::consistency_constraints(), - tran: Self::transition_constraints(), - term: Self::terminal_constraints(), - } - } - - fn initial_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ExtProgramTable::initial_constraints(&circuit_builder), - ExtProcessorTable::initial_constraints(&circuit_builder), - ExtOpStackTable::initial_constraints(&circuit_builder), - ExtRamTable::initial_constraints(&circuit_builder), - ExtJumpStackTable::initial_constraints(&circuit_builder), - ExtHashTable::initial_constraints(&circuit_builder), - ExtCascadeTable::initial_constraints(&circuit_builder), - ExtLookupTable::initial_constraints(&circuit_builder), - ExtU32Table::initial_constraints(&circuit_builder), - GrandCrossTableArg::initial_constraints(&circuit_builder), - ] - .concat() - } - - fn consistency_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ExtProgramTable::consistency_constraints(&circuit_builder), - ExtProcessorTable::consistency_constraints(&circuit_builder), - ExtOpStackTable::consistency_constraints(&circuit_builder), - ExtRamTable::consistency_constraints(&circuit_builder), - ExtJumpStackTable::consistency_constraints(&circuit_builder), - ExtHashTable::consistency_constraints(&circuit_builder), - ExtCascadeTable::consistency_constraints(&circuit_builder), - ExtLookupTable::consistency_constraints(&circuit_builder), - ExtU32Table::consistency_constraints(&circuit_builder), - GrandCrossTableArg::consistency_constraints(&circuit_builder), - ] - .concat() - } - - fn transition_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ExtProgramTable::transition_constraints(&circuit_builder), - ExtProcessorTable::transition_constraints(&circuit_builder), - ExtOpStackTable::transition_constraints(&circuit_builder), - ExtRamTable::transition_constraints(&circuit_builder), - ExtJumpStackTable::transition_constraints(&circuit_builder), - ExtHashTable::transition_constraints(&circuit_builder), - ExtCascadeTable::transition_constraints(&circuit_builder), - ExtLookupTable::transition_constraints(&circuit_builder), - ExtU32Table::transition_constraints(&circuit_builder), - GrandCrossTableArg::transition_constraints(&circuit_builder), - ] - .concat() - } - - fn terminal_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ExtProgramTable::terminal_constraints(&circuit_builder), - ExtProcessorTable::terminal_constraints(&circuit_builder), - ExtOpStackTable::terminal_constraints(&circuit_builder), - ExtRamTable::terminal_constraints(&circuit_builder), - ExtJumpStackTable::terminal_constraints(&circuit_builder), - ExtHashTable::terminal_constraints(&circuit_builder), - ExtCascadeTable::terminal_constraints(&circuit_builder), - ExtLookupTable::terminal_constraints(&circuit_builder), - ExtU32Table::terminal_constraints(&circuit_builder), - GrandCrossTableArg::terminal_constraints(&circuit_builder), - ] - .concat() - } - - pub fn lower_to_target_degree_through_substitutions(&mut self) -> AllSubstitutions { - // Subtract the degree lowering table's width from the total number of columns to guarantee - // the same number of columns even for repeated runs of the constraint evaluation generator. - let mut num_base_cols = table::NUM_BASE_COLUMNS - degree_lowering_table::BASE_WIDTH; - let mut num_ext_cols = table::NUM_EXT_COLUMNS - degree_lowering_table::EXT_WIDTH; - let (init_base_substitutions, init_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.init, - master_table::AIR_TARGET_DEGREE, - num_base_cols, - num_ext_cols, - ); - num_base_cols += init_base_substitutions.len(); - num_ext_cols += init_ext_substitutions.len(); - - let (cons_base_substitutions, cons_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.cons, - master_table::AIR_TARGET_DEGREE, - num_base_cols, - num_ext_cols, - ); - num_base_cols += cons_base_substitutions.len(); - num_ext_cols += cons_ext_substitutions.len(); - - let (tran_base_substitutions, tran_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.tran, - master_table::AIR_TARGET_DEGREE, - num_base_cols, - num_ext_cols, - ); - num_base_cols += tran_base_substitutions.len(); - num_ext_cols += tran_ext_substitutions.len(); - - let (term_base_substitutions, term_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.term, - master_table::AIR_TARGET_DEGREE, - num_base_cols, - num_ext_cols, - ); - - AllSubstitutions { - base: Substitutions { - init: init_base_substitutions, - cons: cons_base_substitutions, - tran: tran_base_substitutions, - term: term_base_substitutions, - }, - ext: Substitutions { - init: init_ext_substitutions, - cons: cons_ext_substitutions, - tran: tran_ext_substitutions, - term: term_ext_substitutions, - }, - } - } - - #[must_use] - pub fn combine_with_substitution_induced_constraints( - self, - AllSubstitutions { base, ext }: AllSubstitutions, - ) -> Self { - Self { - init: [self.init, base.init, ext.init].concat(), - cons: [self.cons, base.cons, ext.cons].concat(), - tran: [self.tran, base.tran, ext.tran].concat(), - term: [self.term, base.term, ext.term].concat(), - } - } - - pub fn init(&self) -> Vec> { - Self::consume(&self.init) - } - - pub fn cons(&self) -> Vec> { - Self::consume(&self.cons) - } - - pub fn tran(&self) -> Vec> { - Self::consume(&self.tran) - } - - pub fn term(&self) -> Vec> { - Self::consume(&self.term) - } - - fn consume( - constraints: &[ConstraintCircuitMonad], - ) -> Vec> { - let mut constraints = constraints.iter().map(|c| c.consume()).collect_vec(); - ConstraintCircuit::assert_unique_ids(&mut constraints); - constraints - } -} - -#[cfg(test)] -pub(crate) mod tests { - use twenty_first::bfe; - - use triton_vm::prelude::BFieldElement; - use triton_vm::table::challenges::ChallengeId; - use triton_vm::table::constraint_circuit::DualRowIndicator; - - use super::*; - - impl Constraints { - pub(crate) fn mini_constraints() -> Self { - let circuit_builder = ConstraintCircuitBuilder::new(); - let challenge = |c| circuit_builder.challenge(c); - let constant = |c: u32| circuit_builder.x_constant(c); - let base_row = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); - let ext_row = |i| circuit_builder.input(SingleRowIndicator::ExtRow(i)); - - let constraint = - base_row(0) * challenge(ChallengeId::StackWeight5) - ext_row(1) * constant(42); - - Self { - init: vec![constraint], - cons: vec![], - tran: vec![], - term: vec![], - } - } - - /// For testing purposes only. There is no meaning behind any of the constraints. - pub(crate) fn test_constraints() -> Self { - Self { - init: Self::small_init_constraints(), - cons: vec![], - tran: Self::small_transition_constraints(), - term: vec![], - } - } - - fn small_init_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - let challenge = |c| circuit_builder.challenge(c); - let constant = |c: u32| circuit_builder.b_constant(bfe!(c)); - let input = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); - let input_to_the_4th = |i| input(i) * input(i) * input(i) * input(i); - - vec![ - input(0) * input(1) - input(2), - input_to_the_4th(0) - challenge(ChallengeId::StackWeight3) - constant(16), - input(2) * input_to_the_4th(0) - input_to_the_4th(1), - ] - } - - fn small_transition_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - let challenge = |c| circuit_builder.challenge(c); - let constant = |c: u32| circuit_builder.x_constant(c); - - let curr_b_row = |col| circuit_builder.input(DualRowIndicator::CurrentBaseRow(col)); - let next_b_row = |col| circuit_builder.input(DualRowIndicator::NextBaseRow(col)); - let curr_x_row = |col| circuit_builder.input(DualRowIndicator::CurrentExtRow(col)); - let next_x_row = |col| circuit_builder.input(DualRowIndicator::NextExtRow(col)); - - vec![ - curr_b_row(0) * next_x_row(1) - next_b_row(1) * curr_x_row(0), - curr_b_row(1) * next_x_row(2) - next_b_row(2) * curr_x_row(1), - curr_b_row(2) * next_x_row(0) * next_x_row(1) * next_x_row(3) + constant(42), - curr_b_row(0) * challenge(ChallengeId::StackWeight12) - - challenge(ChallengeId::StackWeight5), - ] - } - } -} diff --git a/constraint-evaluation-generator/src/main.rs b/constraint-evaluation-generator/src/main.rs deleted file mode 100644 index 44b74a2f0..000000000 --- a/constraint-evaluation-generator/src/main.rs +++ /dev/null @@ -1,89 +0,0 @@ -//! The constraint generator is a tool that generates efficient-to-evaluate code -//! for the constraints of Triton Virtual Machine, in particular, for the -//! Arithmetic Intermediate Representation (AIR) constraints of the -//! Zero-Knowledge Proof System underpinning the STARK proof system. -//! -//! The constraints are defined in the Triton VM crate. In order to leverage -//! compiler optimizations, rust code is generated using those constraints. -//! -//! Additionally, the constraints are also translated to Triton Assembly (TASM). -//! This allows Triton VM to evaluate its own constraints, which is essential -//! for recursive proof verification, or Incrementally Verifiable Computation. -//! -//! The constraint generator can be run by executing -//! `cargo run --bin constraint-evaluation-generator` -//! in the root of the repository. - -#![warn(missing_debug_implementations)] -#![warn(missing_docs)] - -use proc_macro2::TokenStream; -use std::fs::write; - -use crate::codegen::Codegen; -use crate::codegen::RustBackend; -use crate::codegen::TasmBackend; -use crate::constraints::Constraints; - -mod codegen; -mod constraints; -mod substitution; - -fn main() { - let mut constraints = Constraints::all(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(); - let degree_lowering_table_code = substitutions.generate_degree_lowering_table_code(); - - let constraints = constraints.combine_with_substitution_induced_constraints(substitutions); - let rust = RustBackend::constraint_evaluation_code(&constraints); - let tasm = TasmBackend::constraint_evaluation_code(&constraints); - - write_code_to_file( - degree_lowering_table_code, - "triton-vm/src/table/degree_lowering_table.rs", - ); - write_code_to_file(rust, "triton-vm/src/table/constraints.rs"); - write_code_to_file(tasm, "triton-vm/src/air/tasm_air_constraints.rs"); -} - -fn write_code_to_file(code: TokenStream, file_name: &str) { - let syntax_tree = syn::parse2(code).unwrap(); - let code = prettyplease::unparse(&syntax_tree); - write(file_name, code).unwrap(); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_constraints_can_be_fetched() { - let _ = Constraints::test_constraints(); - } - - #[test] - fn degree_lowering_tables_code_can_be_generated_for_test_constraints() { - let mut constraints = Constraints::test_constraints(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(); - let _ = substitutions.generate_degree_lowering_table_code(); - } - - #[test] - fn all_constraints_can_be_fetched() { - let _ = Constraints::all(); - } - - #[test] - fn degree_lowering_tables_code_can_be_generated_from_all_constraints() { - let mut constraints = Constraints::all(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(); - let _ = substitutions.generate_degree_lowering_table_code(); - } - - #[test] - fn constraints_and_substitutions_can_be_combined() { - let mut constraints = Constraints::test_constraints(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(); - let _ = constraints.combine_with_substitution_induced_constraints(substitutions); - } -} diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 3654d65c4..c3e97e71d 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -28,14 +28,15 @@ lazy_static.workspace = true ndarray.workspace = true nom.workspace = true num-traits.workspace = true +prettyplease.workspace = true proc-macro2.workspace = true quote.workspace = true rand.workspace = true rand_core.workspace = true rayon.workspace = true serde.workspace = true -serde_derive.workspace = true strum.workspace = true +syn.workspace = true thiserror.workspace = true twenty-first.workspace = true unicode-width.workspace = true @@ -45,6 +46,7 @@ assert2.workspace = true cargo-husky.workspace = true fs-err.workspace = true pretty_assertions.workspace = true +prettyplease.workspace = true proptest.workspace = true proptest-arbitrary-interop.workspace = true serde_json.workspace = true diff --git a/triton-vm/src/table/constraint_circuit.rs b/triton-vm/src/codegen/circuit.rs similarity index 96% rename from triton-vm/src/table/constraint_circuit.rs rename to triton-vm/src/codegen/circuit.rs index 4c335df8b..338b84d41 100644 --- a/triton-vm/src/table/constraint_circuit.rs +++ b/triton-vm/src/codegen/circuit.rs @@ -1,10 +1,12 @@ -//! Constraint circuits are a way to represent constraint polynomials in a way that is amenable -//! to optimizations. The constraint circuit is a directed acyclic graph (DAG) of -//! [`CircuitExpression`]s, where each `CircuitExpression` is a node in the graph. The edges of the -//! graph are labeled with [`BinOp`]s. The leafs of the graph are the inputs to the constraint -//! polynomial, and the (multiple) roots of the graph are the outputs of all the -//! constraint polynomials, with each root corresponding to a different constraint polynomial. -//! Because the graph has multiple roots, it is called a “multitree.” +//! Constraint circuits are a way to represent constraint polynomials in a way +//! that is amenable to optimizations. The constraint circuit is a directed +//! acyclic graph (DAG) of [`CircuitExpression`]s, where each +//! `CircuitExpression` is a node in the graph. The edges of the graph are +//! labeled with [`BinOp`]s. The leafs of the graph are the inputs to the +//! constraint polynomial, and the (multiple) roots of the graph are the outputs +//! of all the constraint polynomials, with each root corresponding to a +//! different constraint polynomial. Because the graph has multiple roots, it is +//! called a “multitree.” use std::cell::RefCell; use std::cmp; @@ -102,7 +104,7 @@ pub trait InputIndicator: Debug + Display + Copy + Hash + Eq + ToTokens { /// The position of a variable in a constraint polynomial that operates on a single row of the /// execution trace. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub enum SingleRowIndicator { BaseRow(usize), ExtRow(usize), @@ -165,7 +167,7 @@ impl InputIndicator for SingleRowIndicator { /// The position of a variable in a constraint polynomial that operates on two rows (current and /// next) of the execution trace. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub enum DualRowIndicator { CurrentBaseRow(usize), CurrentExtRow(usize), @@ -301,10 +303,10 @@ impl Hash for CircuitExpression { impl PartialEq for CircuitExpression { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::BConstant(bfe_self), Self::BConstant(bfe_other)) => bfe_self == bfe_other, - (Self::XConstant(xfe_self), Self::XConstant(xfe_other)) => xfe_self == xfe_other, - (Self::Input(input_self), Self::Input(input_other)) => input_self == input_other, - (Self::Challenge(id_self), Self::Challenge(id_other)) => id_self == id_other, + (Self::BConstant(b), Self::BConstant(b_o)) => b == b_o, + (Self::XConstant(x), Self::XConstant(x_o)) => x == x_o, + (Self::Input(i), Self::Input(i_o)) => i == i_o, + (Self::Challenge(c), Self::Challenge(c_o)) => c == c_o, (Self::BinaryOperation(op, l, r), Self::BinaryOperation(op_o, l_o, r_o)) => { op == op_o && l == l_o && r == r_o } @@ -429,9 +431,10 @@ impl ConstraintCircuit { let degree_lhs = lhs.borrow().degree(); let degree_rhs = rhs.borrow().degree(); let degree_additive = cmp::max(degree_lhs, degree_rhs); - let degree_multiplicative = match degree_lhs == -1 || degree_rhs == -1 { - true => -1, - false => degree_lhs + degree_rhs, + let degree_multiplicative = if cmp::min(degree_lhs, degree_rhs) <= -1 { + -1 + } else { + degree_lhs + degree_rhs }; match binop { BinOp::Add => degree_additive, @@ -1119,31 +1122,6 @@ fn random_circuit_leaf<'a, II: InputIndicator + Arbitrary<'a>>( Ok(leaf) } -impl<'a> Arbitrary<'a> for SingleRowIndicator { - fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { - let col_idx = u.arbitrary()?; - let indicator = match u.arbitrary()? { - true => Self::BaseRow(col_idx), - false => Self::ExtRow(col_idx), - }; - Ok(indicator) - } -} - -impl<'a> Arbitrary<'a> for DualRowIndicator { - fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { - let col_idx = u.arbitrary()?; - let indicator = match u.int_in_range(0..=3)? { - 0 => Self::CurrentBaseRow(col_idx), - 1 => Self::CurrentExtRow(col_idx), - 2 => Self::NextBaseRow(col_idx), - 3 => Self::NextExtRow(col_idx), - _ => unreachable!(), - }; - Ok(indicator) - } -} - #[cfg(test)] mod tests { use std::collections::hash_map::DefaultHasher; @@ -1159,20 +1137,39 @@ mod tests { use rand::SeedableRng; use test_strategy::proptest; + use crate::prelude::Claim; use crate::table::cascade_table::ExtCascadeTable; use crate::table::challenges::Challenges; - use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::degree_lowering_table::DegreeLoweringTable; use crate::table::hash_table::ExtHashTable; use crate::table::jump_stack_table::ExtJumpStackTable; use crate::table::lookup_table::ExtLookupTable; - use crate::table::master_table::*; + use crate::table::master_table::AIR_TARGET_DEGREE; + use crate::table::master_table::CASCADE_TABLE_END; + use crate::table::master_table::EXT_CASCADE_TABLE_END; + use crate::table::master_table::EXT_HASH_TABLE_END; + use crate::table::master_table::EXT_JUMP_STACK_TABLE_END; + use crate::table::master_table::EXT_LOOKUP_TABLE_END; + use crate::table::master_table::EXT_OP_STACK_TABLE_END; + use crate::table::master_table::EXT_PROCESSOR_TABLE_END; + use crate::table::master_table::EXT_PROGRAM_TABLE_END; + use crate::table::master_table::EXT_RAM_TABLE_END; + use crate::table::master_table::EXT_U32_TABLE_END; + use crate::table::master_table::HASH_TABLE_END; + use crate::table::master_table::JUMP_STACK_TABLE_END; + use crate::table::master_table::LOOKUP_TABLE_END; + use crate::table::master_table::OP_STACK_TABLE_END; + use crate::table::master_table::PROCESSOR_TABLE_END; + use crate::table::master_table::PROGRAM_TABLE_END; + use crate::table::master_table::RAM_TABLE_END; + use crate::table::master_table::U32_TABLE_END; use crate::table::op_stack_table::ExtOpStackTable; use crate::table::processor_table::ExtProcessorTable; use crate::table::program_table::ExtProgramTable; use crate::table::ram_table::ExtRamTable; use crate::table::u32_table::ExtU32Table; - use crate::Claim; + use crate::table::NUM_BASE_COLUMNS; + use crate::table::NUM_EXT_COLUMNS; use super::*; @@ -1242,7 +1239,10 @@ mod tests { fn printing_constraint_circuit_gives_expected_strings() { let builder = ConstraintCircuitBuilder::new(); assert_eq!("1", builder.b_constant(1).to_string()); - assert_eq!("base_row[5] ", builder.input(BaseRow(5)).to_string()); + assert_eq!( + "base_row[5] ", + builder.input(SingleRowIndicator::BaseRow(5)).to_string() + ); assert_eq!("6", builder.challenge(6_usize).to_string()); let xfe_str = builder.x_constant([2, 3, 4]).to_string(); @@ -1312,7 +1312,7 @@ mod tests { #[test] fn substitution_replaces_a_node_in_a_circuit() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); let constant = |c: u32| builder.b_constant(c); let challenge = |i: usize| builder.challenge(i); @@ -1544,7 +1544,7 @@ mod tests { #[test] fn simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = || builder.input(BaseRow(0)); + let x = || builder.input(SingleRowIndicator::BaseRow(0)); let x_pow_3 = x() * x() * x(); let x_pow_5 = x() * x() * x() * x() * x(); let mut multicircuit = [x_pow_5, x_pow_3]; @@ -1566,8 +1566,8 @@ mod tests { #[test] fn somewhat_simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(BaseRow(i)); - let y = |i| builder.input(ExtRow(i)); + let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); + let y = |i| builder.input(SingleRowIndicator::ExtRow(i)); let b_con = |i: u64| builder.b_constant(i); let constraint_0 = x(0) * x(0) * (x(1) - x(2)) - x(0) * x(2) - b_con(42); @@ -1596,7 +1596,7 @@ mod tests { #[test] fn less_simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); let constraint_0 = (x(0) * x(1) * x(2)) * (x(3) * x(4)) * x(5); let constraint_1 = (x(6) * x(7)) * (x(3) * x(4)) * x(8); @@ -2123,7 +2123,7 @@ mod tests { fn all_nodes_in_multicircuit_are_identified_correctly() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); let b_con = |i: u64| builder.b_constant(i); let sub_tree_0 = x(0) * x(1) * (x(2) - b_con(1)) * x(3) * x(4); @@ -2219,7 +2219,7 @@ mod tests { fn equivalent_nodes_are_detected_when_present() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); let ch = |i: usize| builder.challenge(i); let u0 = x(0) + x(1); diff --git a/constraint-evaluation-generator/src/codegen/tasm.rs b/triton-vm/src/codegen/constraints.rs similarity index 56% rename from constraint-evaluation-generator/src/codegen/tasm.rs rename to triton-vm/src/codegen/constraints.rs index e6c5e51d2..2a7402e3c 100644 --- a/constraint-evaluation-generator/src/codegen/tasm.rs +++ b/triton-vm/src/codegen/constraints.rs @@ -1,32 +1,385 @@ +//! The various tables' constraints are very inefficient to evaluate if they live in RAM. +//! Instead, the build script turns them into rust code, which is then optimized by rustc. + use std::collections::HashSet; use itertools::Itertools; use proc_macro2::TokenStream; +use quote::format_ident; use quote::quote; use quote::ToTokens; -use twenty_first::prelude::bfe; use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; -use twenty_first::prelude::BFieldElement; -use twenty_first::prelude::XFieldElement; +use twenty_first::prelude::*; + +use crate::instruction::Instruction; +use crate::op_stack::NumberOfWords; + +use crate::codegen::circuit::BinOp; +use crate::codegen::circuit::CircuitExpression; +use crate::codegen::circuit::ConstraintCircuit; +use crate::codegen::circuit::InputIndicator; +use crate::codegen::Constraints; -use triton_vm::air::memory_layout; -use triton_vm::instruction::Instruction; -use triton_vm::op_stack::NumberOfWords; -use triton_vm::table::constraint_circuit::BinOp; -use triton_vm::table::constraint_circuit::CircuitExpression; -use triton_vm::table::constraint_circuit::ConstraintCircuit; -use triton_vm::table::constraint_circuit::InputIndicator; +pub(crate) trait Codegen { + fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream; + + fn tokenize_bfe(bfe: BFieldElement) -> TokenStream { + let raw_u64 = bfe.raw_u64(); + quote!(BFieldElement::from_raw_u64(#raw_u64)) + } -use crate::codegen::Codegen; -use crate::codegen::TasmBackend; -use crate::constraints::Constraints; + fn tokenize_xfe(xfe: XFieldElement) -> TokenStream { + let [c_0, c_1, c_2] = xfe.coefficients.map(Self::tokenize_bfe); + quote!(XFieldElement::new([#c_0, #c_1, #c_2])) + } +} + +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub(crate) struct RustBackend { + /// All [circuit] IDs known to be in scope. + /// + /// [circuit]: triton_vm::table::circuit::ConstraintCircuit + scope: HashSet, +} + +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub(crate) struct TasmBackend { + /// All [circuit] IDs known to be processed and stored to memory. + /// + /// [circuit]: triton_vm::table::circuit::ConstraintCircuit + scope: HashSet, + + /// The number of elements written to the output list. + elements_written: usize, + + /// Whether the code that is to be generated can assume statically provided + /// addresses for the various input arrays. + input_location_is_static: bool, +} + +impl Codegen for RustBackend { + fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream { + let num_init_constraints = constraints.init.len(); + let num_cons_constraints = constraints.cons.len(); + let num_tran_constraints = constraints.tran.len(); + let num_term_constraints = constraints.term.len(); + + let (init_constraint_degrees, init_constraints_bfe, init_constraints_xfe) = + Self::tokenize_circuits(&constraints.init()); + let (cons_constraint_degrees, cons_constraints_bfe, cons_constraints_xfe) = + Self::tokenize_circuits(&constraints.cons()); + let (tran_constraint_degrees, tran_constraints_bfe, tran_constraints_xfe) = + Self::tokenize_circuits(&constraints.tran()); + let (term_constraint_degrees, term_constraints_bfe, term_constraints_xfe) = + Self::tokenize_circuits(&constraints.term()); + + let uses = Self::uses(); + let evaluable_over_base_field = Self::generate_evaluable_implementation_over_field( + &init_constraints_bfe, + &cons_constraints_bfe, + &tran_constraints_bfe, + &term_constraints_bfe, + quote!(BFieldElement), + ); + let evaluable_over_ext_field = Self::generate_evaluable_implementation_over_field( + &init_constraints_xfe, + &cons_constraints_xfe, + &tran_constraints_xfe, + &term_constraints_xfe, + quote!(XFieldElement), + ); + + let quotient_trait_impl = quote!( + impl Quotientable for MasterExtTable { + const NUM_INITIAL_CONSTRAINTS: usize = #num_init_constraints; + const NUM_CONSISTENCY_CONSTRAINTS: usize = #num_cons_constraints; + const NUM_TRANSITION_CONSTRAINTS: usize = #num_tran_constraints; + const NUM_TERMINAL_CONSTRAINTS: usize = #num_term_constraints; + + #[allow(unused_variables)] + fn initial_quotient_degree_bounds(interpolant_degree: isize) -> Vec { + let zerofier_degree = 1; + [#init_constraint_degrees].to_vec() + } + + #[allow(unused_variables)] + fn consistency_quotient_degree_bounds( + interpolant_degree: isize, + padded_height: usize, + ) -> Vec { + let zerofier_degree = padded_height as isize; + [#cons_constraint_degrees].to_vec() + } + + #[allow(unused_variables)] + fn transition_quotient_degree_bounds( + interpolant_degree: isize, + padded_height: usize, + ) -> Vec { + let zerofier_degree = padded_height as isize - 1; + [#tran_constraint_degrees].to_vec() + } + + #[allow(unused_variables)] + fn terminal_quotient_degree_bounds(interpolant_degree: isize) -> Vec { + let zerofier_degree = 1; + [#term_constraint_degrees].to_vec() + } + } + ); + + quote!( + #uses + #evaluable_over_base_field + #evaluable_over_ext_field + #quotient_trait_impl + ) + } +} + +impl RustBackend { + fn uses() -> TokenStream { + quote!( + use ndarray::ArrayView1; + use twenty_first::prelude::BFieldElement; + use twenty_first::prelude::XFieldElement; + + use crate::table::challenges::Challenges; + use crate::table::extension_table::Evaluable; + use crate::table::extension_table::Quotientable; + use crate::table::master_table::MasterExtTable; + ) + } + + fn generate_evaluable_implementation_over_field( + init_constraints: &TokenStream, + cons_constraints: &TokenStream, + tran_constraints: &TokenStream, + term_constraints: &TokenStream, + field: TokenStream, + ) -> TokenStream { + quote!( + impl Evaluable<#field> for MasterExtTable { + #[allow(unused_variables)] + fn evaluate_initial_constraints( + base_row: ArrayView1<#field>, + ext_row: ArrayView1, + challenges: &Challenges, + ) -> Vec { + #init_constraints + } + + #[allow(unused_variables)] + fn evaluate_consistency_constraints( + base_row: ArrayView1<#field>, + ext_row: ArrayView1, + challenges: &Challenges, + ) -> Vec { + #cons_constraints + } + + #[allow(unused_variables)] + fn evaluate_transition_constraints( + current_base_row: ArrayView1<#field>, + current_ext_row: ArrayView1, + next_base_row: ArrayView1<#field>, + next_ext_row: ArrayView1, + challenges: &Challenges, + ) -> Vec { + #tran_constraints + } + + #[allow(unused_variables)] + fn evaluate_terminal_constraints( + base_row: ArrayView1<#field>, + ext_row: ArrayView1, + challenges: &Challenges, + ) -> Vec { + #term_constraints + } + } + ) + } + + /// Return a tuple of [`TokenStream`]s corresponding to code evaluating these constraints as + /// well as their degrees. In particular: + /// 1. The first stream contains code that, when evaluated, produces the constraints' degrees, + /// 1. the second stream contains code that, when evaluated, produces the constraints' values, + /// with the input type for the base row being `BFieldElement`, and + /// 1. the third stream is like the second, except that the input type for the base row is + /// `XFieldElement`. + fn tokenize_circuits( + constraints: &[ConstraintCircuit], + ) -> (TokenStream, TokenStream, TokenStream) { + if constraints.is_empty() { + return (quote!(), quote!(vec![]), quote!(vec![])); + } + + let mut backend = Self::default(); + let shared_declarations = backend.declare_shared_nodes(constraints); + let (base_constraints, ext_constraints): (Vec<_>, Vec<_>) = constraints + .iter() + .partition(|constraint| constraint.evaluates_to_base_element()); + + // The order of the constraints' degrees must match the order of the constraints. + // Hence, listing the degrees is only possible after the partition into base and extension + // constraints is known. + let tokenized_degree_bounds = base_constraints + .iter() + .chain(&ext_constraints) + .map(|circuit| match circuit.degree() { + d if d > 1 => quote!(interpolant_degree * #d - zerofier_degree), + 1 => quote!(interpolant_degree - zerofier_degree), + _ => panic!("Constraint degree must be positive"), + }) + .collect_vec(); + let tokenized_degree_bounds = quote!(#(#tokenized_degree_bounds),*); + + let tokenize_constraint_evaluation = |constraints: &[&ConstraintCircuit]| { + constraints + .iter() + .map(|constraint| backend.evaluate_single_node(constraint)) + .collect_vec() + }; + let tokenized_base_constraints = tokenize_constraint_evaluation(&base_constraints); + let tokenized_ext_constraints = tokenize_constraint_evaluation(&ext_constraints); + + // If there are no base constraints, the type needs to be explicitly declared. + let tokenized_bfe_base_constraints = match base_constraints.is_empty() { + true => quote!(let base_constraints: [BFieldElement; 0] = []), + false => quote!(let base_constraints = [#(#tokenized_base_constraints),*]), + }; + let tokenized_bfe_constraints = quote!( + #(#shared_declarations)* + #tokenized_bfe_base_constraints; + let ext_constraints = [#(#tokenized_ext_constraints),*]; + base_constraints + .into_iter() + .map(|bfe| bfe.lift()) + .chain(ext_constraints) + .collect() + ); + + let tokenized_xfe_constraints = quote!( + #(#shared_declarations)* + let base_constraints = [#(#tokenized_base_constraints),*]; + let ext_constraints = [#(#tokenized_ext_constraints),*]; + base_constraints + .into_iter() + .chain(ext_constraints) + .collect() + ); + + ( + tokenized_degree_bounds, + tokenized_bfe_constraints, + tokenized_xfe_constraints, + ) + } + + /// Declare all shared variables, i.e., those with a ref count greater than 1. + /// These declarations must be made starting from the highest ref count. + /// Otherwise, the resulting code will refer to bindings that have not yet been made. + fn declare_shared_nodes( + &mut self, + constraints: &[ConstraintCircuit], + ) -> Vec { + let constraints_iter = constraints.iter(); + let all_ref_counts = constraints_iter.flat_map(ConstraintCircuit::all_ref_counters); + let relevant_ref_counts = all_ref_counts.unique().filter(|&x| x > 1); + let ordered_ref_counts = relevant_ref_counts.sorted().rev(); + + ordered_ref_counts + .map(|count| self.declare_nodes_with_ref_count(constraints, count)) + .collect() + } + + /// Produce the code to evaluate code for all nodes that share a ref count. + fn declare_nodes_with_ref_count( + &mut self, + circuits: &[ConstraintCircuit], + ref_count: usize, + ) -> TokenStream { + let all_nodes_in_circuit = + |circuit| self.declare_single_node_with_ref_count(circuit, ref_count); + let tokenized_circuits = circuits.iter().filter_map(all_nodes_in_circuit); + quote!(#(#tokenized_circuits)*) + } + + fn declare_single_node_with_ref_count( + &mut self, + circuit: &ConstraintCircuit, + ref_count: usize, + ) -> Option { + if self.scope.contains(&circuit.id) { + return None; + } + + // constants can be declared trivially + let CircuitExpression::BinaryOperation(_, lhs, rhs) = &circuit.expression else { + return None; + }; + + if circuit.ref_count < ref_count { + let out_left = self.declare_single_node_with_ref_count(&lhs.borrow(), ref_count); + let out_right = self.declare_single_node_with_ref_count(&rhs.borrow(), ref_count); + return match (out_left, out_right) { + (None, None) => None, + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + (Some(l), Some(r)) => Some(quote!(#l #r)), + }; + } + + assert_eq!(circuit.ref_count, ref_count); + let binding_name = Self::binding_name(circuit); + let evaluation = self.evaluate_single_node(circuit); + let new_binding = quote!(let #binding_name = #evaluation;); + + let is_new_insertion = self.scope.insert(circuit.id); + assert!(is_new_insertion); + + Some(new_binding) + } + + /// Recursively construct the code for evaluating a single node. + pub fn evaluate_single_node( + &self, + circuit: &ConstraintCircuit, + ) -> TokenStream { + if self.scope.contains(&circuit.id) { + return Self::binding_name(circuit); + } + + let CircuitExpression::BinaryOperation(binop, lhs, rhs) = &circuit.expression else { + return Self::binding_name(circuit); + }; + + let lhs = self.evaluate_single_node(&lhs.borrow()); + let rhs = self.evaluate_single_node(&rhs.borrow()); + quote!((#lhs) #binop (#rhs)) + } + + fn binding_name(circuit: &ConstraintCircuit) -> TokenStream { + match &circuit.expression { + CircuitExpression::BConstant(bfe) => Self::tokenize_bfe(*bfe), + CircuitExpression::XConstant(xfe) => Self::tokenize_xfe(*xfe), + CircuitExpression::Input(idx) => quote!(#idx), + CircuitExpression::Challenge(challenge) => quote!(challenges[#challenge]), + CircuitExpression::BinaryOperation(_, _, _) => { + let node_ident = format_ident!("node_{}", circuit.id); + quote!(#node_ident) + } + } + } +} /// An offset from the [memory layout][layout]'s `free_mem_page_ptr`, in number of /// [`XFieldElement`]s. Indicates the start of the to-be-returned array. /// /// [layout]: memory_layout::IntegralMemoryLayout const OUT_ARRAY_OFFSET: usize = { - let mem_page_size = memory_layout::MEM_PAGE_SIZE; + let mem_page_size = crate::air::memory_layout::MEM_PAGE_SIZE; let max_num_words_for_evaluated_constraints = 1 << 16; // magic! let out_array_offset_in_words = mem_page_size - max_num_words_for_evaluated_constraints; assert!(out_array_offset_in_words % EXTENSION_DEGREE == 0); @@ -541,17 +894,61 @@ impl ToTokens for IOList { #[cfg(test)] mod tests { - use crate::codegen::tests::print_constraints; + use crate::codegen::circuit::ConstraintCircuitBuilder; + use crate::codegen::circuit::SingleRowIndicator; + use twenty_first::prelude::*; use super::*; + pub(crate) fn mini_constraints() -> Constraints { + let circuit_builder = ConstraintCircuitBuilder::new(); + let challenge = |c: usize| circuit_builder.challenge(c); + let constant = |c: u32| circuit_builder.x_constant(c); + let base_row = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); + let ext_row = |i| circuit_builder.input(SingleRowIndicator::ExtRow(i)); + + let constraint = base_row(0) * challenge(3) - ext_row(1) * constant(42); + + Constraints { + init: vec![constraint], + cons: vec![], + tran: vec![], + term: vec![], + } + } + + pub fn print_constraints(constraints: &Constraints) { + let code = B::constraint_evaluation_code(constraints); + let syntax_tree = syn::parse2(code).unwrap(); + let code = prettyplease::unparse(&syntax_tree); + println!("{code}"); + } + + #[test] + fn tokenizing_base_field_elements_produces_expected_result() { + let bfe = bfe!(42); + let expected = "BFieldElement :: from_raw_u64 (180388626390u64)"; + assert_eq!(expected, RustBackend::tokenize_bfe(bfe).to_string()); + } + + #[test] + fn tokenizing_extension_field_elements_produces_expected_result() { + let xfe = xfe!([42, 43, 44]); + let expected = "XFieldElement :: new ([\ + BFieldElement :: from_raw_u64 (180388626390u64) , \ + BFieldElement :: from_raw_u64 (184683593685u64) , \ + BFieldElement :: from_raw_u64 (188978560980u64)\ + ])"; + assert_eq!(expected, RustBackend::tokenize_xfe(xfe).to_string()); + } + #[test] - fn print_mini_constraints() { - print_constraints::(&Constraints::mini_constraints()); + fn print_mini_constraints_rust() { + print_constraints::(&mini_constraints()); } #[test] - fn print_test_constraints() { - print_constraints::(&Constraints::test_constraints()); + fn print_mini_constraints_tasm() { + print_constraints::(&mini_constraints()); } } diff --git a/triton-vm/src/codegen/mod.rs b/triton-vm/src/codegen/mod.rs new file mode 100644 index 000000000..526c29273 --- /dev/null +++ b/triton-vm/src/codegen/mod.rs @@ -0,0 +1,270 @@ +use arbitrary::Arbitrary; +use itertools::Itertools; +use proc_macro2::TokenStream; +use std::fs::write; + +use crate::codegen::circuit::ConstraintCircuit; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::InputIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::constraints::Codegen; +use crate::codegen::constraints::RustBackend; +use crate::codegen::constraints::TasmBackend; +use crate::codegen::substitutions::AllSubstitutions; +use crate::codegen::substitutions::Substitutions; + +pub(crate) mod circuit; +mod constraints; +mod substitutions; + +pub fn gen(mut constraints: Constraints, info: DegreeLoweringInfo) { + let substitutions = constraints.lower_to_target_degree_through_substitutions(info); + let degree_lowering_table_code = substitutions.generate_degree_lowering_table_code(); + + let constraints = constraints.combine_with_substitution_induced_constraints(substitutions); + let rust = RustBackend::constraint_evaluation_code(&constraints); + let tasm = TasmBackend::constraint_evaluation_code(&constraints); + + write_code_to_file( + degree_lowering_table_code, + "triton-vm/src/table/degree_lowering_table.rs", + ); + write_code_to_file(rust, "triton-vm/src/table/constraints.rs"); + write_code_to_file(tasm, "triton-vm/src/air/tasm_air_constraints.rs"); +} + +fn write_code_to_file(code: TokenStream, file_name: &str) { + let syntax_tree = syn::parse2(code).unwrap(); + let code = prettyplease::unparse(&syntax_tree); + write(file_name, code).unwrap(); +} + +#[derive(Debug, Clone)] +pub(crate) struct Constraints { + pub init: Vec>, + pub cons: Vec>, + pub tran: Vec>, + pub term: Vec>, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub(crate) struct DegreeLoweringInfo { + pub target_degree: isize, + + /// The total number of base columns _before_ degree lowering has happened. + pub num_base_cols: usize, + + /// The total number of extension columns _before_ degree lowering has happened. + pub num_ext_cols: usize, +} + +impl Constraints { + pub fn lower_to_target_degree_through_substitutions( + &mut self, + info: DegreeLoweringInfo, + ) -> AllSubstitutions { + // Subtract the degree lowering table's width from the total number of columns to guarantee + // the same number of columns even for repeated runs of the constraint evaluation generator. + let mut num_base_cols = info.num_base_cols; + let mut num_ext_cols = info.num_ext_cols; + let (init_base_substitutions, init_ext_substitutions) = + ConstraintCircuitMonad::lower_to_degree( + &mut self.init, + info.target_degree, + num_base_cols, + num_ext_cols, + ); + num_base_cols += init_base_substitutions.len(); + num_ext_cols += init_ext_substitutions.len(); + + let (cons_base_substitutions, cons_ext_substitutions) = + ConstraintCircuitMonad::lower_to_degree( + &mut self.cons, + info.target_degree, + num_base_cols, + num_ext_cols, + ); + num_base_cols += cons_base_substitutions.len(); + num_ext_cols += cons_ext_substitutions.len(); + + let (tran_base_substitutions, tran_ext_substitutions) = + ConstraintCircuitMonad::lower_to_degree( + &mut self.tran, + info.target_degree, + num_base_cols, + num_ext_cols, + ); + num_base_cols += tran_base_substitutions.len(); + num_ext_cols += tran_ext_substitutions.len(); + + let (term_base_substitutions, term_ext_substitutions) = + ConstraintCircuitMonad::lower_to_degree( + &mut self.term, + info.target_degree, + num_base_cols, + num_ext_cols, + ); + + AllSubstitutions { + base: Substitutions { + lowering_info: info, + init: init_base_substitutions, + cons: cons_base_substitutions, + tran: tran_base_substitutions, + term: term_base_substitutions, + }, + ext: Substitutions { + lowering_info: info, + init: init_ext_substitutions, + cons: cons_ext_substitutions, + tran: tran_ext_substitutions, + term: term_ext_substitutions, + }, + } + } + + #[must_use] + pub fn combine_with_substitution_induced_constraints( + self, + AllSubstitutions { base, ext }: AllSubstitutions, + ) -> Self { + Self { + init: [self.init, base.init, ext.init].concat(), + cons: [self.cons, base.cons, ext.cons].concat(), + tran: [self.tran, base.tran, ext.tran].concat(), + term: [self.term, base.term, ext.term].concat(), + } + } + + pub fn init(&self) -> Vec> { + Self::consume(&self.init) + } + + pub fn cons(&self) -> Vec> { + Self::consume(&self.cons) + } + + pub fn tran(&self) -> Vec> { + Self::consume(&self.tran) + } + + pub fn term(&self) -> Vec> { + Self::consume(&self.term) + } + + fn consume( + constraints: &[ConstraintCircuitMonad], + ) -> Vec> { + let mut constraints = constraints.iter().map(|c| c.consume()).collect_vec(); + ConstraintCircuit::assert_unique_ids(&mut constraints); + constraints + } +} + +#[cfg(test)] +mod tests { + use twenty_first::prelude::*; + + use crate::codegen::circuit::ConstraintCircuitBuilder; + use crate::table; + + use super::*; + + impl Default for DegreeLoweringInfo { + /// For testing purposes only. + fn default() -> Self { + Self { + target_degree: 4, + num_base_cols: 42, + num_ext_cols: 13, + } + } + } + + #[repr(usize)] + enum TestChallenges { + Ch0, + Ch1, + Ch2, + } + + impl From for usize { + fn from(challenge: TestChallenges) -> Self { + challenge as usize + } + } + + #[test] + fn test_constraints_can_be_fetched() { + let _ = Constraints::test_constraints(); + } + + #[test] + fn degree_lowering_tables_code_can_be_generated_for_test_constraints() { + let lowering_info = DegreeLoweringInfo::default(); + let mut constraints = Constraints::test_constraints(); + let substitutions = constraints.lower_to_target_degree_through_substitutions(lowering_info); + let _ = substitutions.generate_degree_lowering_table_code(); + } + + #[test] + fn degree_lowering_tables_code_can_be_generated_from_all_constraints() { + let lowering_info = DegreeLoweringInfo::default(); + let mut constraints = table::constraints(); + let substitutions = constraints.lower_to_target_degree_through_substitutions(lowering_info); + let _ = substitutions.generate_degree_lowering_table_code(); + } + + #[test] + fn constraints_and_substitutions_can_be_combined() { + let mut constraints = Constraints::test_constraints(); + let substitutions = + constraints.lower_to_target_degree_through_substitutions(DegreeLoweringInfo::default()); + let _ = constraints.combine_with_substitution_induced_constraints(substitutions); + } + + impl Constraints { + /// For testing purposes only. There is no meaning behind any of the constraints. + pub(crate) fn test_constraints() -> Self { + Self { + init: Self::small_init_constraints(), + cons: vec![], + tran: Self::small_transition_constraints(), + term: vec![], + } + } + + fn small_init_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + let challenge = |c| circuit_builder.challenge(c); + let constant = |c: u32| circuit_builder.b_constant(bfe!(c)); + let input = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); + let input_to_the_4th = |i| input(i) * input(i) * input(i) * input(i); + + vec![ + input(0) * input(1) - input(2), + input_to_the_4th(0) - challenge(TestChallenges::Ch1) - constant(16), + input(2) * input_to_the_4th(0) - input_to_the_4th(1), + ] + } + + fn small_transition_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + let challenge = |c| circuit_builder.challenge(c); + let constant = |c: u32| circuit_builder.x_constant(c); + + let curr_b_row = |col| circuit_builder.input(DualRowIndicator::CurrentBaseRow(col)); + let next_b_row = |col| circuit_builder.input(DualRowIndicator::NextBaseRow(col)); + let curr_x_row = |col| circuit_builder.input(DualRowIndicator::CurrentExtRow(col)); + let next_x_row = |col| circuit_builder.input(DualRowIndicator::NextExtRow(col)); + + vec![ + curr_b_row(0) * next_x_row(1) - next_b_row(1) * curr_x_row(0), + curr_b_row(1) * next_x_row(2) - next_b_row(2) * curr_x_row(1), + curr_b_row(2) * next_x_row(0) * next_x_row(1) * next_x_row(3) + constant(42), + curr_b_row(0) * challenge(TestChallenges::Ch0) - challenge(TestChallenges::Ch1), + ] + } + } +} diff --git a/constraint-evaluation-generator/src/substitution.rs b/triton-vm/src/codegen/substitutions.rs similarity index 94% rename from constraint-evaluation-generator/src/substitution.rs rename to triton-vm/src/codegen/substitutions.rs index b7cce408b..769e73d27 100644 --- a/constraint-evaluation-generator/src/substitution.rs +++ b/triton-vm/src/codegen/substitutions.rs @@ -3,17 +3,15 @@ use proc_macro2::TokenStream; use quote::format_ident; use quote::quote; -use triton_vm::table; -use triton_vm::table::constraint_circuit::BinOp; -use triton_vm::table::constraint_circuit::CircuitExpression; -use triton_vm::table::constraint_circuit::ConstraintCircuit; -use triton_vm::table::constraint_circuit::ConstraintCircuitMonad; -use triton_vm::table::constraint_circuit::DualRowIndicator; -use triton_vm::table::constraint_circuit::InputIndicator; -use triton_vm::table::constraint_circuit::SingleRowIndicator; -use triton_vm::table::degree_lowering_table; - -use crate::codegen::RustBackend; +use crate::codegen::circuit::BinOp; +use crate::codegen::circuit::CircuitExpression; +use crate::codegen::circuit::ConstraintCircuit; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::InputIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::constraints::RustBackend; +use crate::codegen::DegreeLoweringInfo; pub(crate) struct AllSubstitutions { pub base: Substitutions, @@ -21,6 +19,7 @@ pub(crate) struct AllSubstitutions { } pub(crate) struct Substitutions { + pub lowering_info: DegreeLoweringInfo, pub init: Vec>, pub cons: Vec>, pub tran: Vec>, @@ -59,7 +58,7 @@ impl AllSubstitutions { quote!( //! The degree lowering table contains the introduced variables that allow //! lowering the degree of the AIR. See - //! [`crate::table::master_table::AIR_TARGET_DEGREE`] + //! [`table::master_table::AIR_TARGET_DEGREE`] //! for additional information. //! //! This file has been auto-generated. Any modifications _will_ be lost. @@ -115,8 +114,7 @@ impl Substitutions { } fn generate_fill_base_columns_code(&self) -> TokenStream { - let derived_section_init_start = - table::NUM_BASE_COLUMNS - degree_lowering_table::BASE_WIDTH; + let derived_section_init_start = self.lowering_info.num_base_cols; let derived_section_cons_start = derived_section_init_start + self.init.len(); let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); @@ -148,7 +146,7 @@ impl Substitutions { } fn generate_fill_ext_columns_code(&self) -> TokenStream { - let derived_section_init_start = table::NUM_EXT_COLUMNS - degree_lowering_table::EXT_WIDTH; + let derived_section_init_start = self.lowering_info.num_ext_cols; let derived_section_cons_start = derived_section_init_start + self.init.len(); let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); diff --git a/triton-vm/src/instruction.rs b/triton-vm/src/instruction.rs index 777070b96..7560e666c 100644 --- a/triton-vm/src/instruction.rs +++ b/triton-vm/src/instruction.rs @@ -10,8 +10,8 @@ use get_size::GetSize; use itertools::Itertools; use lazy_static::lazy_static; use num_traits::ConstZero; -use serde_derive::Deserialize; -use serde_derive::Serialize; +use serde::Deserialize; +use serde::Serialize; use strum::EnumCount; use strum::EnumIter; use strum::IntoEnumIterator; diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 141768594..e72d8c51b 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -169,6 +169,7 @@ use crate::prelude::*; pub mod aet; pub mod air; pub mod arithmetic_domain; +mod codegen; pub mod config; pub mod error; pub mod example_programs; @@ -645,9 +646,6 @@ mod tests { implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); diff --git a/triton-vm/src/op_stack.rs b/triton-vm/src/op_stack.rs index ae45c40a7..4a4e033e5 100644 --- a/triton-vm/src/op_stack.rs +++ b/triton-vm/src/op_stack.rs @@ -7,7 +7,8 @@ use std::ops::IndexMut; use arbitrary::Arbitrary; use get_size::GetSize; use itertools::Itertools; -use serde_derive::*; +use serde::Deserialize; +use serde::Serialize; use strum::EnumCount; use strum::EnumIter; use strum::IntoEnumIterator; diff --git a/triton-vm/src/program.rs b/triton-vm/src/program.rs index 57d96d796..a50734bac 100644 --- a/triton-vm/src/program.rs +++ b/triton-vm/src/program.rs @@ -13,8 +13,8 @@ use std::ops::Sub; use arbitrary::Arbitrary; use get_size::GetSize; use itertools::Itertools; -use serde_derive::Deserialize; -use serde_derive::Serialize; +use serde::Deserialize; +use serde::Serialize; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 653bc5144..60110b90d 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1325,6 +1325,7 @@ pub(crate) mod tests { use test_strategy::proptest; use twenty_first::math::other::random_elements; + use crate::codegen::circuit::ConstraintCircuitBuilder; use crate::error::InstructionError; use crate::example_programs::*; use crate::instruction::Instruction; @@ -1334,7 +1335,6 @@ pub(crate) mod tests { use crate::table::cascade_table::ExtCascadeTable; use crate::table::challenges::ChallengeId::StandardInputIndeterminate; use crate::table::challenges::ChallengeId::StandardOutputIndeterminate; - use crate::table::constraint_circuit::ConstraintCircuitBuilder; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::EvalArg; use crate::table::cross_table_argument::GrandCrossTableArg; diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index 8c84eea01..f3254a602 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -1,19 +1,31 @@ -use std::fmt::Display; -use std::fmt::Formatter; -use std::fmt::Result as FmtResult; +pub use crate::stark::NUM_QUOTIENT_SEGMENTS; +pub use crate::table::master_table::NUM_BASE_COLUMNS; +pub use crate::table::master_table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; +use strum::Display; use strum::EnumCount; use strum::EnumIter; use twenty_first::prelude::XFieldElement; -pub use crate::stark::NUM_QUOTIENT_SEGMENTS; -pub use crate::table::master_table::NUM_BASE_COLUMNS; -pub use crate::table::master_table::NUM_EXT_COLUMNS; +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::Constraints; +use crate::table::cascade_table::ExtCascadeTable; +use crate::table::cross_table_argument::GrandCrossTableArg; +use crate::table::hash_table::ExtHashTable; +use crate::table::jump_stack_table::ExtJumpStackTable; +use crate::table::lookup_table::ExtLookupTable; +use crate::table::op_stack_table::ExtOpStackTable; +use crate::table::processor_table::ExtProcessorTable; +use crate::table::program_table::ExtProgramTable; +use crate::table::ram_table::ExtRamTable; +use crate::table::u32_table::ExtU32Table; pub mod cascade_table; pub mod challenges; -pub mod constraint_circuit; #[rustfmt::skip] pub mod constraints; pub mod cross_table_argument; @@ -31,7 +43,9 @@ pub mod ram_table; pub mod table_column; pub mod u32_table; -#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, EnumCount, EnumIter)] +#[derive( + Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, EnumCount, EnumIter, +)] pub enum ConstraintType { /// Pertains only to the first row of the execution trace. Initial, @@ -46,17 +60,6 @@ pub enum ConstraintType { Terminal, } -impl Display for ConstraintType { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - ConstraintType::Initial => write!(f, "initial"), - ConstraintType::Consistency => write!(f, "consistency"), - ConstraintType::Transition => write!(f, "transition"), - ConstraintType::Terminal => write!(f, "terminal"), - } - } -} - /// A single row of a [`MasterBaseTable`][table]. /// /// Usually, the elements in the table are [`BFieldElement`][bfe]s. For out-of-domain rows, which is @@ -76,3 +79,80 @@ pub type ExtensionRow = [XFieldElement; NUM_EXT_COLUMNS]; /// /// See also [`NUM_QUOTIENT_SEGMENTS`]. pub type QuotientSegments = [XFieldElement; NUM_QUOTIENT_SEGMENTS]; + +pub fn constraints() -> Constraints { + Constraints { + init: initial_constraints(), + cons: consistency_constraints(), + tran: transition_constraints(), + term: terminal_constraints(), + } +} + +fn initial_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ExtProgramTable::initial_constraints(&circuit_builder), + ExtProcessorTable::initial_constraints(&circuit_builder), + ExtOpStackTable::initial_constraints(&circuit_builder), + ExtRamTable::initial_constraints(&circuit_builder), + ExtJumpStackTable::initial_constraints(&circuit_builder), + ExtHashTable::initial_constraints(&circuit_builder), + ExtCascadeTable::initial_constraints(&circuit_builder), + ExtLookupTable::initial_constraints(&circuit_builder), + ExtU32Table::initial_constraints(&circuit_builder), + GrandCrossTableArg::initial_constraints(&circuit_builder), + ] + .concat() +} + +fn consistency_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ExtProgramTable::consistency_constraints(&circuit_builder), + ExtProcessorTable::consistency_constraints(&circuit_builder), + ExtOpStackTable::consistency_constraints(&circuit_builder), + ExtRamTable::consistency_constraints(&circuit_builder), + ExtJumpStackTable::consistency_constraints(&circuit_builder), + ExtHashTable::consistency_constraints(&circuit_builder), + ExtCascadeTable::consistency_constraints(&circuit_builder), + ExtLookupTable::consistency_constraints(&circuit_builder), + ExtU32Table::consistency_constraints(&circuit_builder), + GrandCrossTableArg::consistency_constraints(&circuit_builder), + ] + .concat() +} + +fn transition_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ExtProgramTable::transition_constraints(&circuit_builder), + ExtProcessorTable::transition_constraints(&circuit_builder), + ExtOpStackTable::transition_constraints(&circuit_builder), + ExtRamTable::transition_constraints(&circuit_builder), + ExtJumpStackTable::transition_constraints(&circuit_builder), + ExtHashTable::transition_constraints(&circuit_builder), + ExtCascadeTable::transition_constraints(&circuit_builder), + ExtLookupTable::transition_constraints(&circuit_builder), + ExtU32Table::transition_constraints(&circuit_builder), + GrandCrossTableArg::transition_constraints(&circuit_builder), + ] + .concat() +} + +fn terminal_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ExtProgramTable::terminal_constraints(&circuit_builder), + ExtProcessorTable::terminal_constraints(&circuit_builder), + ExtOpStackTable::terminal_constraints(&circuit_builder), + ExtRamTable::terminal_constraints(&circuit_builder), + ExtJumpStackTable::terminal_constraints(&circuit_builder), + ExtHashTable::terminal_constraints(&circuit_builder), + ExtCascadeTable::terminal_constraints(&circuit_builder), + ExtLookupTable::terminal_constraints(&circuit_builder), + ExtU32Table::terminal_constraints(&circuit_builder), + GrandCrossTableArg::terminal_constraints(&circuit_builder), + ] + .concat() +} diff --git a/triton-vm/src/table/cascade_table.rs b/triton-vm/src/table/cascade_table.rs index 3ef1f9910..c4a8dfe9f 100644 --- a/triton-vm/src/table/cascade_table.rs +++ b/triton-vm/src/table/cascade_table.rs @@ -1,3 +1,9 @@ +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::circuit::SingleRowIndicator::*; use ndarray::s; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; @@ -11,12 +17,6 @@ use crate::profiler::profiler; use crate::table::challenges::ChallengeId; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::ConstraintCircuitBuilder; -use crate::table::constraint_circuit::ConstraintCircuitMonad; -use crate::table::constraint_circuit::DualRowIndicator; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::LookupArg; use crate::table::table_column::CascadeBaseTableColumn; diff --git a/triton-vm/src/table/constraints.rs b/triton-vm/src/table/constraints.rs index ae9b737fb..4933815c2 100644 --- a/triton-vm/src/table/constraints.rs +++ b/triton-vm/src/table/constraints.rs @@ -2,11 +2,11 @@ //! Run `cargo run --bin constraint-evaluation-generator` //! to fill in this file with optimized constraints. -use crate::table::challenges::Challenges; use ndarray::ArrayView1; use twenty_first::prelude::BFieldElement; use twenty_first::prelude::XFieldElement; +use crate::table::challenges::Challenges; use crate::table::extension_table::Evaluable; use crate::table::extension_table::Quotientable; use crate::table::master_table::MasterExtTable; diff --git a/triton-vm/src/table/cross_table_argument.rs b/triton-vm/src/table/cross_table_argument.rs index efa4d1aa5..8eb1b41b5 100644 --- a/triton-vm/src/table/cross_table_argument.rs +++ b/triton-vm/src/table/cross_table_argument.rs @@ -1,14 +1,14 @@ use std::ops::Add; use std::ops::Mul; +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::circuit::SingleRowIndicator::ExtRow; use twenty_first::prelude::*; use crate::table::challenges::ChallengeId::*; -use crate::table::constraint_circuit::ConstraintCircuitBuilder; -use crate::table::constraint_circuit::ConstraintCircuitMonad; -use crate::table::constraint_circuit::DualRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator::ExtRow; use crate::table::table_column::CascadeExtTableColumn; use crate::table::table_column::HashExtTableColumn; use crate::table::table_column::HashExtTableColumn::*; @@ -265,7 +265,7 @@ mod tests { #[proptest] fn lookup_argument_is_identical_to_inverse_of_evaluation_of_zerofier_polynomial( #[strategy(arb())] - #[filter(#roots.iter().all(|&r| r != #challenge) )] + #[filter(#roots.iter().all(|&r| r != #challenge))] roots: Vec, #[strategy(arb())] initial: XFieldElement, #[strategy(arb())] challenge: BFieldElement, diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 31e7aba84..323750013 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -1,3 +1,10 @@ +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::InputIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::circuit::SingleRowIndicator::*; use itertools::Itertools; use ndarray::*; use num_traits::Zero; @@ -23,13 +30,6 @@ use crate::profiler::profiler; use crate::table::cascade_table::CascadeTable; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::ConstraintCircuitBuilder; -use crate::table::constraint_circuit::ConstraintCircuitMonad; -use crate::table::constraint_circuit::DualRowIndicator; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::InputIndicator; -use crate::table::constraint_circuit::SingleRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::EvalArg; use crate::table::cross_table_argument::LookupArg; diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs index 1d3eed41d..ec055bc97 100644 --- a/triton-vm/src/table/jump_stack_table.rs +++ b/triton-vm/src/table/jump_stack_table.rs @@ -2,6 +2,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::ops::Range; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator::*; +use crate::codegen::circuit::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; @@ -17,9 +20,6 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator::*; -use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; use crate::table::table_column::JumpStackBaseTableColumn::*; use crate::table::table_column::JumpStackExtTableColumn::*; diff --git a/triton-vm/src/table/lookup_table.rs b/triton-vm/src/table/lookup_table.rs index 21ed18ebe..f1350a355 100644 --- a/triton-vm/src/table/lookup_table.rs +++ b/triton-vm/src/table/lookup_table.rs @@ -1,3 +1,9 @@ +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::circuit::SingleRowIndicator::*; use itertools::Itertools; use ndarray::prelude::*; use num_traits::ConstOne; @@ -14,12 +20,6 @@ use crate::profiler::profiler; use crate::table::challenges::ChallengeId; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::ConstraintCircuitBuilder; -use crate::table::constraint_circuit::ConstraintCircuitMonad; -use crate::table::constraint_circuit::DualRowIndicator; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::EvalArg; use crate::table::cross_table_argument::LookupArg; diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 0315692e1..c9c8244f5 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1292,6 +1292,10 @@ mod tests { use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; use crate::arithmetic_domain::ArithmeticDomain; + use crate::codegen::circuit::ConstraintCircuitBuilder; + use crate::codegen::circuit::ConstraintCircuitMonad; + use crate::codegen::circuit::DualRowIndicator; + use crate::codegen::circuit::SingleRowIndicator; use crate::instruction::tests::InstructionBucket; use crate::instruction::Instruction; use crate::instruction::InstructionBit; @@ -1304,10 +1308,6 @@ mod tests { use crate::triton_program; use self::cascade_table::ExtCascadeTable; - use self::constraint_circuit::ConstraintCircuitBuilder; - use self::constraint_circuit::ConstraintCircuitMonad; - use self::constraint_circuit::DualRowIndicator; - use self::constraint_circuit::SingleRowIndicator; use self::hash_table::ExtHashTable; use self::jump_stack_table::ExtJumpStackTable; use self::lookup_table::ExtLookupTable; diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index b300260a3..3e17c60a1 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -2,6 +2,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::ops::Range; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator::*; +use crate::codegen::circuit::*; use arbitrary::Arbitrary; use itertools::Itertools; use ndarray::parallel::prelude::*; @@ -19,9 +22,6 @@ use crate::op_stack::UnderflowIO; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator::*; -use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; use crate::table::master_table::TableId; use crate::table::table_column::OpStackBaseTableColumn::*; diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index bbbc49510..394851ef9 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -1,6 +1,9 @@ use std::cmp::max; use std::ops::Mul; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator::*; +use crate::codegen::circuit::*; use itertools::izip; use itertools::Itertools; use ndarray::parallel::prelude::*; @@ -28,9 +31,6 @@ use crate::profiler::profiler; use crate::table::challenges::ChallengeId; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator::*; -use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; use crate::table::ram_table; use crate::table::table_column::ProcessorBaseTableColumn::*; diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs index 814b100e8..409b52fcc 100644 --- a/triton-vm/src/table/program_table.rs +++ b/triton-vm/src/table/program_table.rs @@ -1,5 +1,8 @@ use std::cmp::Ordering; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator::*; +use crate::codegen::circuit::*; use ndarray::s; use ndarray::Array1; use ndarray::ArrayView1; @@ -14,9 +17,6 @@ use crate::aet::AlgebraicExecutionTrace; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator::*; -use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::EvalArg; use crate::table::cross_table_argument::LookupArg; diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index b07e7899f..46d112e41 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -1,5 +1,8 @@ use std::cmp::Ordering; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::SingleRowIndicator::*; +use crate::codegen::circuit::*; use arbitrary::Arbitrary; use itertools::Itertools; use ndarray::parallel::prelude::*; @@ -7,7 +10,8 @@ use ndarray::prelude::*; use num_traits::ConstOne; use num_traits::One; use num_traits::Zero; -use serde_derive::*; +use serde::Deserialize; +use serde::Serialize; use strum::EnumCount; use strum::IntoEnumIterator; use twenty_first::math::traits::FiniteField; @@ -19,9 +23,6 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::SingleRowIndicator::*; -use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; use crate::table::master_table::TableId; use crate::table::table_column::RamBaseTableColumn::*; diff --git a/triton-vm/src/table/table_column.rs b/triton-vm/src/table/table_column.rs index f16215ec8..e1792a42a 100644 --- a/triton-vm/src/table/table_column.rs +++ b/triton-vm/src/table/table_column.rs @@ -30,8 +30,6 @@ use crate::table::master_table::PROGRAM_TABLE_START; use crate::table::master_table::RAM_TABLE_START; use crate::table::master_table::U32_TABLE_START; -// -------- Program Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum ProgramBaseTableColumn { @@ -97,8 +95,6 @@ pub enum ProgramExtTableColumn { SendChunkRunningEvaluation, } -// -------- Processor Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum ProcessorBaseTableColumn { @@ -170,8 +166,6 @@ pub enum ProcessorExtTableColumn { ClockJumpDifferenceLookupServerLogDerivative, } -// -------- OpStack Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum OpStackBaseTableColumn { @@ -190,8 +184,6 @@ pub enum OpStackExtTableColumn { ClockJumpDifferenceLookupClientLogDerivative, } -// -------- RAM Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum RamBaseTableColumn { @@ -224,8 +216,6 @@ pub enum RamExtTableColumn { ClockJumpDifferenceLookupClientLogDerivative, } -// -------- JumpStack Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum JumpStackBaseTableColumn { @@ -245,8 +235,6 @@ pub enum JumpStackExtTableColumn { ClockJumpDifferenceLookupClientLogDerivative, } -// -------- Hash Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum HashBaseTableColumn { @@ -383,8 +371,6 @@ pub enum HashExtTableColumn { CascadeState3LowestClientLogDerivative, } -// -------- Cascade Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum CascadeBaseTableColumn { @@ -427,8 +413,6 @@ pub enum CascadeExtTableColumn { LookupTableClientLogDerivative, } -// -------- Lookup Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum LookupBaseTableColumn { @@ -458,8 +442,6 @@ pub enum LookupExtTableColumn { PublicEvaluationArgument, } -// -------- U32 Table -------- - #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum U32BaseTableColumn { @@ -505,15 +487,13 @@ pub enum U32ExtTableColumn { LookupServerLogDerivative, } -// -------------------------------------------------------------------- - /// A trait for the columns of the master base table. This trait is implemented for all enums /// relating to the base tables. This trait provides two methods: -/// - one to get the index of the column in the ”local“ base table, _i.e., not the master base +/// - one to get the index of the column in the “local” base table, _i.e., not the master base /// table, and /// - one to get the index of the column in the master base table. pub trait MasterBaseTableColumn { - /// The index of the column in the ”local“ base table, _i.e., not the master base table. + /// The index of the column in the “local” base table, _i.e., not the master base table. fn base_table_index(&self) -> usize; /// The index of the column in the master base table. @@ -640,8 +620,6 @@ impl MasterBaseTableColumn for DegreeLoweringBaseTableColumn { } } -// -------------------------------------------------------------------- - /// A trait for the columns in the master extension table. This trait is implemented for all enums /// relating to the extension tables. The trait provides two methods: /// - one to get the index of the column in the “local” extension table, _i.e._, not the master @@ -776,8 +754,6 @@ impl MasterExtTableColumn for DegreeLoweringExtTableColumn { } } -// -------------------------------------------------------------------- - #[cfg(test)] mod tests { use strum::IntoEnumIterator; diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index 63bbef54d..58c7c7604 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -1,6 +1,13 @@ use std::cmp::max; use std::ops::Mul; +use crate::codegen::circuit::ConstraintCircuitBuilder; +use crate::codegen::circuit::ConstraintCircuitMonad; +use crate::codegen::circuit::DualRowIndicator; +use crate::codegen::circuit::DualRowIndicator::*; +use crate::codegen::circuit::InputIndicator; +use crate::codegen::circuit::SingleRowIndicator; +use crate::codegen::circuit::SingleRowIndicator::*; use arbitrary::Arbitrary; use ndarray::parallel::prelude::*; use ndarray::s; @@ -19,13 +26,6 @@ use crate::instruction::Instruction; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; -use crate::table::constraint_circuit::ConstraintCircuitBuilder; -use crate::table::constraint_circuit::ConstraintCircuitMonad; -use crate::table::constraint_circuit::DualRowIndicator; -use crate::table::constraint_circuit::DualRowIndicator::*; -use crate::table::constraint_circuit::InputIndicator; -use crate::table::constraint_circuit::SingleRowIndicator; -use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::cross_table_argument::CrossTableArg; use crate::table::cross_table_argument::LookupArg; use crate::table::table_column::MasterBaseTableColumn; diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index b7c43bc16..c9407a676 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -11,7 +11,8 @@ use ndarray::Array1; use num_traits::ConstZero; use num_traits::One; use num_traits::Zero; -use serde_derive::*; +use serde::Deserialize; +use serde::Serialize; use twenty_first::math::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; use twenty_first::util_types::algebraic_hasher::Domain; From b3f298fd65a81b1f3a9a477fd72a9e90b7616e98 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Wed, 28 Aug 2024 16:48:32 +0200 Subject: [PATCH 03/15] refactor!: Move constraint circuits to own crate changelog: ignore --- Cargo.toml | 2 +- constraint-builder/Cargo.toml | 29 + .../src/lib.rs | 1032 +++-------------- triton-vm/Cargo.toml | 1 + triton-vm/src/codegen/constraints.rs | 12 +- triton-vm/src/codegen/mod.rs | 116 +- triton-vm/src/codegen/substitutions.rs | 16 +- triton-vm/src/stark.rs | 2 +- triton-vm/src/table.rs | 387 ++++++- triton-vm/src/table/cascade_table.rs | 12 +- triton-vm/src/table/cross_table_argument.rs | 10 +- triton-vm/src/table/hash_table.rs | 14 +- triton-vm/src/table/jump_stack_table.rs | 6 +- triton-vm/src/table/lookup_table.rs | 12 +- triton-vm/src/table/master_table.rs | 90 +- triton-vm/src/table/op_stack_table.rs | 6 +- triton-vm/src/table/processor_table.rs | 6 +- triton-vm/src/table/program_table.rs | 6 +- triton-vm/src/table/ram_table.rs | 6 +- triton-vm/src/table/u32_table.rs | 14 +- 20 files changed, 688 insertions(+), 1091 deletions(-) create mode 100644 constraint-builder/Cargo.toml rename triton-vm/src/codegen/circuit.rs => constraint-builder/src/lib.rs (57%) diff --git a/Cargo.toml b/Cargo.toml index 7659013d4..9ef2f97cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["triton-vm"] +members = ["constraint-builder", "triton-vm"] resolver = "2" [profile.test] diff --git a/constraint-builder/Cargo.toml b/constraint-builder/Cargo.toml new file mode 100644 index 000000000..51bdd441c --- /dev/null +++ b/constraint-builder/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "constraint-builder" +description = """ +AIR constraints build helper for Triton VM. +""" + +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +readme.workspace = true + +[dependencies] +arbitrary.workspace = true +itertools.workspace = true +num-traits.workspace = true +ndarray.workspace = true +proc-macro2.workspace = true +quote.workspace = true +twenty-first.workspace = true + +[dev-dependencies] +proptest.workspace = true +proptest-arbitrary-interop.workspace = true +rand.workspace = true +test-strategy.workspace = true diff --git a/triton-vm/src/codegen/circuit.rs b/constraint-builder/src/lib.rs similarity index 57% rename from triton-vm/src/codegen/circuit.rs rename to constraint-builder/src/lib.rs index 338b84d41..30292f9b2 100644 --- a/triton-vm/src/codegen/circuit.rs +++ b/constraint-builder/src/lib.rs @@ -35,6 +35,18 @@ use quote::quote; use quote::ToTokens; use twenty_first::prelude::*; +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct DegreeLoweringInfo { + /// The degree after degree lowering. Must be greater than 1. + pub target_degree: isize, + + /// The total number of base columns _before_ degree lowering has happened. + pub num_base_cols: usize, + + /// The total number of extension columns _before_ degree lowering has happened. + pub num_ext_cols: usize, +} + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum BinOp { Add, @@ -680,95 +692,6 @@ impl ConstraintCircuitMonad { self.circuit.borrow().to_owned() } - /// Traverse the circuit and find all nodes that are equivalent. Note that - /// two nodes are equivalent if they compute the same value on all identical - /// inputs. Equivalence is different from identity, which is when two nodes - /// connect the same set of neighbors in the same way. (There may be two - /// different ways to compute the same result; they are equivalent but - /// unequal.) - /// - /// This function returns a list of lists of equivalent nodes such that - /// every inner list can be reduced to a single node without changing the - /// circuit's function. - /// - /// Equivalent nodes are detected probabilistically using the multivariate - /// Schwartz-Zippel lemma. The false positive probability is zero (we can be - /// certain that equivalent nodes will be found). The false negative - /// probability is bounded by max_degree / (2^64 - 2^32 + 1)^3. - pub fn find_equivalent_nodes(&self) -> Vec>>>> { - let mut id_to_eval = HashMap::new(); - let mut eval_to_ids = HashMap::new(); - let mut id_to_node = HashMap::new(); - Self::probe_random( - &self.circuit, - &mut id_to_eval, - &mut eval_to_ids, - &mut id_to_node, - rand::random(), - ); - - eval_to_ids - .values() - .filter(|ids| ids.len() >= 2) - .map(|ids| ids.iter().map(|i| id_to_node[i].clone()).collect_vec()) - .collect_vec() - } - - /// Populate the dictionaries such that they associate with every node in - /// the circuit its evaluation in a random point. The inputs are assigned - /// random values. Equivalent nodes are detected based on evaluating to the - /// same value using the Schwartz-Zippel lemma. - fn probe_random( - circuit: &Rc>>, - id_to_eval: &mut HashMap, - eval_to_ids: &mut HashMap>, - id_to_node: &mut HashMap>>>, - master_seed: XFieldElement, - ) -> XFieldElement { - const DOMAIN_SEPARATOR_CURR_ROW: BFieldElement = BFieldElement::new(0); - const DOMAIN_SEPARATOR_NEXT_ROW: BFieldElement = BFieldElement::new(1); - const DOMAIN_SEPARATOR_CHALLENGE: BFieldElement = BFieldElement::new(2); - - let circuit_id = circuit.borrow().id; - if let Some(&xfe) = id_to_eval.get(&circuit_id) { - return xfe; - } - - let evaluation = match &circuit.borrow().expression { - CircuitExpression::BConstant(bfe) => bfe.lift(), - CircuitExpression::XConstant(xfe) => *xfe, - CircuitExpression::Input(input) => { - let [s0, s1, s2] = master_seed.coefficients; - let dom_sep = if input.is_current_row() { - DOMAIN_SEPARATOR_CURR_ROW - } else { - DOMAIN_SEPARATOR_NEXT_ROW - }; - let i = bfe!(u64::try_from(input.column()).unwrap()); - let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, i]); - xfe!([d0, d1, d2]) - } - CircuitExpression::Challenge(challenge) => { - let [s0, s1, s2] = master_seed.coefficients; - let dom_sep = DOMAIN_SEPARATOR_CHALLENGE; - let ch = bfe!(u64::try_from(*challenge).unwrap()); - let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, ch]); - xfe!([d0, d1, d2]) - } - CircuitExpression::BinaryOperation(bin_op, lhs, rhs) => { - let l = Self::probe_random(lhs, id_to_eval, eval_to_ids, id_to_node, master_seed); - let r = Self::probe_random(rhs, id_to_eval, eval_to_ids, id_to_node, master_seed); - bin_op.operation(l, r) - } - }; - - id_to_eval.insert(circuit_id, evaluation); - eval_to_ids.entry(evaluation).or_default().push(circuit_id); - id_to_node.insert(circuit_id, circuit.clone()); - - evaluation - } - /// Lowers the degree of a given multicircuit to the target degree. /// This is achieved by introducing additional variables and constraints. /// The appropriate substitutions are applied to the given multicircuit. @@ -791,10 +714,9 @@ impl ConstraintCircuitMonad { /// when a tables' constraints are built using the master table's column indices. pub fn lower_to_degree( multicircuit: &mut [Self], - target_degree: isize, - num_base_cols: usize, - num_ext_cols: usize, + info: DegreeLoweringInfo, ) -> (Vec, Vec) { + let target_degree = info.target_degree; assert!( target_degree > 1, "Target degree must be greater than 1. Got {target_degree}." @@ -816,10 +738,10 @@ impl ConstraintCircuitMonad { let chosen_node = builder.all_nodes.borrow()[&chosen_node_id].clone(); let chosen_node_is_base_col = chosen_node.circuit.borrow().evaluates_to_base_element(); let new_input_indicator = if chosen_node_is_base_col { - let new_base_col_idx = num_base_cols + base_constraints.len(); + let new_base_col_idx = info.num_base_cols + base_constraints.len(); II::base_table_input(new_base_col_idx) } else { - let new_ext_col_idx = num_ext_cols + ext_constraints.len(); + let new_ext_col_idx = info.num_ext_cols + ext_constraints.len(); II::ext_table_input(new_ext_col_idx) }; let new_variable = builder.input(new_input_indicator); @@ -943,7 +865,7 @@ impl ConstraintCircuitMonad { } /// Returns the maximum degree of all circuits in the multicircuit. - pub(crate) fn multicircuit_degree(multicircuit: &[ConstraintCircuitMonad]) -> isize { + pub fn multicircuit_degree(multicircuit: &[ConstraintCircuitMonad]) -> isize { multicircuit .iter() .map(|circuit| circuit.circuit.borrow().degree()) @@ -1128,52 +1050,105 @@ mod tests { use std::hash::Hasher; use itertools::Itertools; - use ndarray::Array2; use proptest::prelude::*; use proptest_arbitrary_interop::arb; use rand::random; - use rand::rngs::StdRng; - use rand::Rng; - use rand::SeedableRng; use test_strategy::proptest; - use crate::prelude::Claim; - use crate::table::cascade_table::ExtCascadeTable; - use crate::table::challenges::Challenges; - use crate::table::degree_lowering_table::DegreeLoweringTable; - use crate::table::hash_table::ExtHashTable; - use crate::table::jump_stack_table::ExtJumpStackTable; - use crate::table::lookup_table::ExtLookupTable; - use crate::table::master_table::AIR_TARGET_DEGREE; - use crate::table::master_table::CASCADE_TABLE_END; - use crate::table::master_table::EXT_CASCADE_TABLE_END; - use crate::table::master_table::EXT_HASH_TABLE_END; - use crate::table::master_table::EXT_JUMP_STACK_TABLE_END; - use crate::table::master_table::EXT_LOOKUP_TABLE_END; - use crate::table::master_table::EXT_OP_STACK_TABLE_END; - use crate::table::master_table::EXT_PROCESSOR_TABLE_END; - use crate::table::master_table::EXT_PROGRAM_TABLE_END; - use crate::table::master_table::EXT_RAM_TABLE_END; - use crate::table::master_table::EXT_U32_TABLE_END; - use crate::table::master_table::HASH_TABLE_END; - use crate::table::master_table::JUMP_STACK_TABLE_END; - use crate::table::master_table::LOOKUP_TABLE_END; - use crate::table::master_table::OP_STACK_TABLE_END; - use crate::table::master_table::PROCESSOR_TABLE_END; - use crate::table::master_table::PROGRAM_TABLE_END; - use crate::table::master_table::RAM_TABLE_END; - use crate::table::master_table::U32_TABLE_END; - use crate::table::op_stack_table::ExtOpStackTable; - use crate::table::processor_table::ExtProcessorTable; - use crate::table::program_table::ExtProgramTable; - use crate::table::ram_table::ExtRamTable; - use crate::table::u32_table::ExtU32Table; - use crate::table::NUM_BASE_COLUMNS; - use crate::table::NUM_EXT_COLUMNS; - use super::*; impl ConstraintCircuitMonad { + /// Traverse the circuit and find all nodes that are equivalent. Note that + /// two nodes are equivalent if they compute the same value on all identical + /// inputs. Equivalence is different from identity, which is when two nodes + /// connect the same set of neighbors in the same way. (There may be two + /// different ways to compute the same result; they are equivalent but + /// unequal.) + /// + /// This function returns a list of lists of equivalent nodes such that + /// every inner list can be reduced to a single node without changing the + /// circuit's function. + /// + /// Equivalent nodes are detected probabilistically using the multivariate + /// Schwartz-Zippel lemma. The false positive probability is zero (we can be + /// certain that equivalent nodes will be found). The false negative + /// probability is bounded by max_degree / (2^64 - 2^32 + 1)^3. + pub fn find_equivalent_nodes(&self) -> Vec>>>> { + let mut id_to_eval = HashMap::new(); + let mut eval_to_ids = HashMap::new(); + let mut id_to_node = HashMap::new(); + Self::probe_random( + &self.circuit, + &mut id_to_eval, + &mut eval_to_ids, + &mut id_to_node, + random(), + ); + + eval_to_ids + .values() + .filter(|ids| ids.len() >= 2) + .map(|ids| ids.iter().map(|i| id_to_node[i].clone()).collect_vec()) + .collect_vec() + } + + /// Populate the dictionaries such that they associate with every node in + /// the circuit its evaluation in a random point. The inputs are assigned + /// random values. Equivalent nodes are detected based on evaluating to the + /// same value using the Schwartz-Zippel lemma. + fn probe_random( + circuit: &Rc>>, + id_to_eval: &mut HashMap, + eval_to_ids: &mut HashMap>, + id_to_node: &mut HashMap>>>, + master_seed: XFieldElement, + ) -> XFieldElement { + const DOMAIN_SEPARATOR_CURR_ROW: BFieldElement = BFieldElement::new(0); + const DOMAIN_SEPARATOR_NEXT_ROW: BFieldElement = BFieldElement::new(1); + const DOMAIN_SEPARATOR_CHALLENGE: BFieldElement = BFieldElement::new(2); + + let circuit_id = circuit.borrow().id; + if let Some(&xfe) = id_to_eval.get(&circuit_id) { + return xfe; + } + + let evaluation = match &circuit.borrow().expression { + CircuitExpression::BConstant(bfe) => bfe.lift(), + CircuitExpression::XConstant(xfe) => *xfe, + CircuitExpression::Input(input) => { + let [s0, s1, s2] = master_seed.coefficients; + let dom_sep = if input.is_current_row() { + DOMAIN_SEPARATOR_CURR_ROW + } else { + DOMAIN_SEPARATOR_NEXT_ROW + }; + let i = bfe!(u64::try_from(input.column()).unwrap()); + let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, i]); + xfe!([d0, d1, d2]) + } + CircuitExpression::Challenge(challenge) => { + let [s0, s1, s2] = master_seed.coefficients; + let dom_sep = DOMAIN_SEPARATOR_CHALLENGE; + let ch = bfe!(u64::try_from(*challenge).unwrap()); + let Digest([d0, d1, d2, _, _]) = Tip5::hash_varlen(&[s0, s1, s2, dom_sep, ch]); + xfe!([d0, d1, d2]) + } + CircuitExpression::BinaryOperation(bin_op, lhs, rhs) => { + let l = + Self::probe_random(lhs, id_to_eval, eval_to_ids, id_to_node, master_seed); + let r = + Self::probe_random(rhs, id_to_eval, eval_to_ids, id_to_node, master_seed); + bin_op.operation(l, r) + } + }; + + id_to_eval.insert(circuit_id, evaluation); + eval_to_ids.entry(evaluation).or_default().push(circuit_id); + id_to_node.insert(circuit_id, circuit.clone()); + + evaluation + } + /// Check whether the given node is contained in this circuit. fn contains(&self, other: &Self) -> bool { let self_expression = &self.circuit.borrow().expression; @@ -1338,209 +1313,6 @@ mod tests { assert!(root_2.contains(&new_variable)); } - /// Recursively evaluates the given constraint circuit and its sub-circuits on the given - /// base and extension table, and returns the result of the evaluation. - /// At each recursive step, updates the given HashMap with the result of the evaluation. - /// If the HashMap already contains the result of the evaluation, panics. - /// This function is used to assert that the evaluation of a constraint circuit - /// and its sub-circuits is unique. - /// It is used to identify redundant constraints or sub-circuits. - /// The employed method is the Schwartz-Zippel lemma. - fn evaluate_assert_unique( - constraint: &ConstraintCircuit, - challenges: &[XFieldElement], - base_rows: ArrayView2, - ext_rows: ArrayView2, - values: &mut HashMap)>, - ) -> XFieldElement { - let value = match &constraint.expression { - CircuitExpression::BinaryOperation(binop, lhs, rhs) => { - let lhs = lhs.borrow(); - let rhs = rhs.borrow(); - let lhs = evaluate_assert_unique(&lhs, challenges, base_rows, ext_rows, values); - let rhs = evaluate_assert_unique(&rhs, challenges, base_rows, ext_rows, values); - binop.operation(lhs, rhs) - } - _ => constraint.evaluate(base_rows, ext_rows, challenges), - }; - - let own_id = constraint.id.to_owned(); - let maybe_entry = values.insert(value, (own_id, constraint.clone())); - if let Some((other_id, other_circuit)) = maybe_entry { - assert_eq!( - own_id, other_id, - "Circuit ID {other_id} and circuit ID {own_id} are not unique. \ - Collision on:\n\ - ID {other_id} – {other_circuit}\n\ - ID {own_id} – {constraint}\n\ - Both evaluate to {value}.", - ); - } - - value - } - - /// Verify that all nodes evaluate to a unique value when given a randomized input. - /// If this is not the case two nodes that are not equal evaluate to the same value. - fn table_constraints_prop( - constraints: &[ConstraintCircuit], - table_name: &str, - ) { - let seed = random(); - let mut rng = StdRng::seed_from_u64(seed); - println!("seed: {seed}"); - - let dummy_claim = Claim::default(); - let challenges: [XFieldElement; Challenges::SAMPLE_COUNT] = rng.gen(); - let challenges = challenges.to_vec(); - let challenges = Challenges::new(challenges, &dummy_claim); - let challenges = &challenges.challenges; - - let num_rows = 2; - let base_shape = [num_rows, NUM_BASE_COLUMNS]; - let ext_shape = [num_rows, NUM_EXT_COLUMNS]; - let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); - let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); - let base_rows = base_rows.view(); - let ext_rows = ext_rows.view(); - - let mut values = HashMap::new(); - for c in constraints { - evaluate_assert_unique(c, challenges, base_rows, ext_rows, &mut values); - } - - let circuit_degree = constraints.iter().map(|c| c.degree()).max().unwrap_or(-1); - println!("Max degree constraint for {table_name} table: {circuit_degree}"); - } - - fn build_constraints( - multicircuit_builder: &dyn Fn( - &ConstraintCircuitBuilder, - ) -> Vec>, - ) -> Vec> { - let multicircuit = build_multicircuit(multicircuit_builder); - let mut constraints = multicircuit.into_iter().map(|c| c.consume()).collect_vec(); - ConstraintCircuit::assert_unique_ids(&mut constraints); - constraints - } - - fn build_multicircuit( - multicircuit_builder: &dyn Fn( - &ConstraintCircuitBuilder, - ) -> Vec>, - ) -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - multicircuit_builder(&circuit_builder) - } - - #[test] - fn constant_folding_processor_table() { - let init = build_constraints(&ExtProcessorTable::initial_constraints); - let cons = build_constraints(&ExtProcessorTable::consistency_constraints); - let tran = build_constraints(&ExtProcessorTable::transition_constraints); - let term = build_constraints(&ExtProcessorTable::terminal_constraints); - table_constraints_prop(&init, "processor initial"); - table_constraints_prop(&cons, "processor consistency"); - table_constraints_prop(&tran, "processor transition"); - table_constraints_prop(&term, "processor terminal"); - } - - #[test] - fn constant_folding_program_table() { - let init = build_constraints(&ExtProgramTable::initial_constraints); - let cons = build_constraints(&ExtProgramTable::consistency_constraints); - let tran = build_constraints(&ExtProgramTable::transition_constraints); - let term = build_constraints(&ExtProgramTable::terminal_constraints); - table_constraints_prop(&init, "program initial"); - table_constraints_prop(&cons, "program consistency"); - table_constraints_prop(&tran, "program transition"); - table_constraints_prop(&term, "program terminal"); - } - - #[test] - fn constant_folding_jump_stack_table() { - let init = build_constraints(&ExtJumpStackTable::initial_constraints); - let cons = build_constraints(&ExtJumpStackTable::consistency_constraints); - let tran = build_constraints(&ExtJumpStackTable::transition_constraints); - let term = build_constraints(&ExtJumpStackTable::terminal_constraints); - table_constraints_prop(&init, "jump stack initial"); - table_constraints_prop(&cons, "jump stack consistency"); - table_constraints_prop(&tran, "jump stack transition"); - table_constraints_prop(&term, "jump stack terminal"); - } - - #[test] - fn constant_folding_op_stack_table() { - let init = build_constraints(&ExtOpStackTable::initial_constraints); - let cons = build_constraints(&ExtOpStackTable::consistency_constraints); - let tran = build_constraints(&ExtOpStackTable::transition_constraints); - let term = build_constraints(&ExtOpStackTable::terminal_constraints); - table_constraints_prop(&init, "op stack initial"); - table_constraints_prop(&cons, "op stack consistency"); - table_constraints_prop(&tran, "op stack transition"); - table_constraints_prop(&term, "op stack terminal"); - } - - #[test] - fn constant_folding_ram_table() { - let init = build_constraints(&ExtRamTable::initial_constraints); - let cons = build_constraints(&ExtRamTable::consistency_constraints); - let tran = build_constraints(&ExtRamTable::transition_constraints); - let term = build_constraints(&ExtRamTable::terminal_constraints); - table_constraints_prop(&init, "ram initial"); - table_constraints_prop(&cons, "ram consistency"); - table_constraints_prop(&tran, "ram transition"); - table_constraints_prop(&term, "ram terminal"); - } - - #[test] - fn constant_folding_hash_table() { - let init = build_constraints(&ExtHashTable::initial_constraints); - let cons = build_constraints(&ExtHashTable::consistency_constraints); - let tran = build_constraints(&ExtHashTable::transition_constraints); - let term = build_constraints(&ExtHashTable::terminal_constraints); - table_constraints_prop(&init, "hash initial"); - table_constraints_prop(&cons, "hash consistency"); - table_constraints_prop(&tran, "hash transition"); - table_constraints_prop(&term, "hash terminal"); - } - - #[test] - fn constant_folding_u32_table() { - let init = build_constraints(&ExtU32Table::initial_constraints); - let cons = build_constraints(&ExtU32Table::consistency_constraints); - let tran = build_constraints(&ExtU32Table::transition_constraints); - let term = build_constraints(&ExtU32Table::terminal_constraints); - table_constraints_prop(&init, "u32 initial"); - table_constraints_prop(&cons, "u32 consistency"); - table_constraints_prop(&tran, "u32 transition"); - table_constraints_prop(&term, "u32 terminal"); - } - - #[test] - fn constant_folding_cascade_table() { - let init = build_constraints(&ExtCascadeTable::initial_constraints); - let cons = build_constraints(&ExtCascadeTable::consistency_constraints); - let tran = build_constraints(&ExtCascadeTable::transition_constraints); - let term = build_constraints(&ExtCascadeTable::terminal_constraints); - table_constraints_prop(&init, "cascade initial"); - table_constraints_prop(&cons, "cascade consistency"); - table_constraints_prop(&tran, "cascade transition"); - table_constraints_prop(&term, "cascade terminal"); - } - - #[test] - fn constant_folding_lookup_table() { - let init = build_constraints(&ExtLookupTable::initial_constraints); - let cons = build_constraints(&ExtLookupTable::consistency_constraints); - let tran = build_constraints(&ExtLookupTable::transition_constraints); - let term = build_constraints(&ExtLookupTable::terminal_constraints); - table_constraints_prop(&init, "lookup initial"); - table_constraints_prop(&cons, "lookup consistency"); - table_constraints_prop(&tran, "lookup transition"); - table_constraints_prop(&term, "lookup terminal"); - } - #[test] fn simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); @@ -1549,18 +1321,16 @@ mod tests { let x_pow_5 = x() * x() * x() * x() * x(); let mut multicircuit = [x_pow_5, x_pow_3]; - let target_degree = 3; - let num_base_cols = 1; - let num_ext_cols = 0; - let (new_base_constraints, new_ext_constraints) = lower_degree_and_assert_properties( - &mut multicircuit, - target_degree, - num_base_cols, - num_ext_cols, - ); + let degree_lowering_info = DegreeLoweringInfo { + target_degree: 3, + num_base_cols: 1, + num_ext_cols: 0, + }; + let (new_base_constraints, new_ext_constraints) = + ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); - assert!(new_ext_constraints.is_empty()); assert_eq!(1, new_base_constraints.len()); + assert!(new_ext_constraints.is_empty()); } #[test] @@ -1579,15 +1349,13 @@ mod tests { let mut multicircuit = [constraint_0, constraint_1, constraint_2]; - let target_degree = 2; - let num_base_cols = 3; - let num_ext_cols = 2; - let (new_base_constraints, new_ext_constraints) = lower_degree_and_assert_properties( - &mut multicircuit, - target_degree, - num_base_cols, - num_ext_cols, - ); + let degree_lowering_info = DegreeLoweringInfo { + target_degree: 2, + num_base_cols: 3, + num_ext_cols: 2, + }; + let (new_base_constraints, new_ext_constraints) = + ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); assert!(new_base_constraints.len() <= 3); assert!(new_ext_constraints.len() <= 1); @@ -1603,522 +1371,18 @@ mod tests { let mut multicircuit = [constraint_0, constraint_1]; - let target_degree = 3; - let num_base_cols = 9; - let num_ext_cols = 0; - let (new_base_constraints, new_ext_constraints) = lower_degree_and_assert_properties( - &mut multicircuit, - target_degree, - num_base_cols, - num_ext_cols, - ); + let degree_lowering_info = DegreeLoweringInfo { + target_degree: 2, + num_base_cols: 9, + num_ext_cols: 0, + }; + let (new_base_constraints, new_ext_constraints) = + ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); assert!(new_base_constraints.len() <= 3); assert!(new_ext_constraints.is_empty()); } - #[test] - fn program_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProgramTable::initial_constraints), - AIR_TARGET_DEGREE, - PROGRAM_TABLE_END, - EXT_PROGRAM_TABLE_END, - ); - } - - #[test] - fn program_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProgramTable::consistency_constraints), - AIR_TARGET_DEGREE, - PROGRAM_TABLE_END, - EXT_PROGRAM_TABLE_END, - ); - } - - #[test] - fn program_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProgramTable::transition_constraints), - AIR_TARGET_DEGREE, - PROGRAM_TABLE_END, - EXT_PROGRAM_TABLE_END, - ); - } - - #[test] - fn program_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProgramTable::terminal_constraints), - AIR_TARGET_DEGREE, - PROGRAM_TABLE_END, - EXT_PROGRAM_TABLE_END, - ); - } - - #[test] - fn processor_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProcessorTable::initial_constraints), - AIR_TARGET_DEGREE, - PROCESSOR_TABLE_END, - EXT_PROCESSOR_TABLE_END, - ); - } - - #[test] - fn processor_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProcessorTable::consistency_constraints), - AIR_TARGET_DEGREE, - PROCESSOR_TABLE_END, - EXT_PROCESSOR_TABLE_END, - ); - } - - #[test] - fn processor_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProcessorTable::transition_constraints), - AIR_TARGET_DEGREE, - PROCESSOR_TABLE_END, - EXT_PROCESSOR_TABLE_END, - ); - } - - #[test] - fn processor_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtProcessorTable::terminal_constraints), - AIR_TARGET_DEGREE, - PROCESSOR_TABLE_END, - EXT_PROCESSOR_TABLE_END, - ); - } - - #[test] - fn op_stack_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtOpStackTable::initial_constraints), - AIR_TARGET_DEGREE, - OP_STACK_TABLE_END, - EXT_OP_STACK_TABLE_END, - ); - } - - #[test] - fn op_stack_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtOpStackTable::consistency_constraints), - AIR_TARGET_DEGREE, - OP_STACK_TABLE_END, - EXT_OP_STACK_TABLE_END, - ); - } - - #[test] - fn op_stack_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtOpStackTable::transition_constraints), - AIR_TARGET_DEGREE, - OP_STACK_TABLE_END, - EXT_OP_STACK_TABLE_END, - ); - } - - #[test] - fn op_stack_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtOpStackTable::terminal_constraints), - AIR_TARGET_DEGREE, - OP_STACK_TABLE_END, - EXT_OP_STACK_TABLE_END, - ); - } - - #[test] - fn ram_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtRamTable::initial_constraints), - AIR_TARGET_DEGREE, - RAM_TABLE_END, - EXT_RAM_TABLE_END, - ); - } - - #[test] - fn ram_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtRamTable::consistency_constraints), - AIR_TARGET_DEGREE, - RAM_TABLE_END, - EXT_RAM_TABLE_END, - ); - } - - #[test] - fn ram_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtRamTable::transition_constraints), - AIR_TARGET_DEGREE, - RAM_TABLE_END, - EXT_RAM_TABLE_END, - ); - } - - #[test] - fn ram_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtRamTable::terminal_constraints), - AIR_TARGET_DEGREE, - RAM_TABLE_END, - EXT_RAM_TABLE_END, - ); - } - - #[test] - fn jump_stack_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtJumpStackTable::initial_constraints), - AIR_TARGET_DEGREE, - JUMP_STACK_TABLE_END, - EXT_JUMP_STACK_TABLE_END, - ); - } - - #[test] - fn jump_stack_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtJumpStackTable::consistency_constraints), - AIR_TARGET_DEGREE, - JUMP_STACK_TABLE_END, - EXT_JUMP_STACK_TABLE_END, - ); - } - - #[test] - fn jump_stack_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtJumpStackTable::transition_constraints), - AIR_TARGET_DEGREE, - JUMP_STACK_TABLE_END, - EXT_JUMP_STACK_TABLE_END, - ); - } - - #[test] - fn jump_stack_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtJumpStackTable::terminal_constraints), - AIR_TARGET_DEGREE, - JUMP_STACK_TABLE_END, - EXT_JUMP_STACK_TABLE_END, - ); - } - - #[test] - fn hash_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtHashTable::initial_constraints), - AIR_TARGET_DEGREE, - HASH_TABLE_END, - EXT_HASH_TABLE_END, - ); - } - - #[test] - fn hash_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtHashTable::consistency_constraints), - AIR_TARGET_DEGREE, - HASH_TABLE_END, - EXT_HASH_TABLE_END, - ); - } - - #[test] - fn hash_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtHashTable::transition_constraints), - AIR_TARGET_DEGREE, - HASH_TABLE_END, - EXT_HASH_TABLE_END, - ); - } - - #[test] - fn hash_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtHashTable::terminal_constraints), - AIR_TARGET_DEGREE, - HASH_TABLE_END, - EXT_HASH_TABLE_END, - ); - } - - #[test] - fn cascade_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtCascadeTable::initial_constraints), - AIR_TARGET_DEGREE, - CASCADE_TABLE_END, - EXT_CASCADE_TABLE_END, - ); - } - - #[test] - fn cascade_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtCascadeTable::consistency_constraints), - AIR_TARGET_DEGREE, - CASCADE_TABLE_END, - EXT_CASCADE_TABLE_END, - ); - } - - #[test] - fn cascade_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtCascadeTable::transition_constraints), - AIR_TARGET_DEGREE, - CASCADE_TABLE_END, - EXT_CASCADE_TABLE_END, - ); - } - - #[test] - fn cascade_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtCascadeTable::terminal_constraints), - AIR_TARGET_DEGREE, - CASCADE_TABLE_END, - EXT_CASCADE_TABLE_END, - ); - } - - #[test] - fn lookup_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtLookupTable::initial_constraints), - AIR_TARGET_DEGREE, - LOOKUP_TABLE_END, - EXT_LOOKUP_TABLE_END, - ); - } - - #[test] - fn lookup_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtLookupTable::consistency_constraints), - AIR_TARGET_DEGREE, - LOOKUP_TABLE_END, - EXT_LOOKUP_TABLE_END, - ); - } - - #[test] - fn lookup_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtLookupTable::transition_constraints), - AIR_TARGET_DEGREE, - LOOKUP_TABLE_END, - EXT_LOOKUP_TABLE_END, - ); - } - - #[test] - fn lookup_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtLookupTable::terminal_constraints), - AIR_TARGET_DEGREE, - LOOKUP_TABLE_END, - EXT_LOOKUP_TABLE_END, - ); - } - - #[test] - fn u32_table_initial_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtU32Table::initial_constraints), - AIR_TARGET_DEGREE, - U32_TABLE_END, - EXT_U32_TABLE_END, - ); - } - - #[test] - fn u32_table_consistency_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtU32Table::consistency_constraints), - AIR_TARGET_DEGREE, - U32_TABLE_END, - EXT_U32_TABLE_END, - ); - } - - #[test] - fn u32_table_transition_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtU32Table::transition_constraints), - AIR_TARGET_DEGREE, - U32_TABLE_END, - EXT_U32_TABLE_END, - ); - } - - #[test] - fn u32_table_terminal_constraints_degree_lowering() { - lower_degree_and_assert_properties( - &mut build_multicircuit(&ExtU32Table::terminal_constraints), - AIR_TARGET_DEGREE, - U32_TABLE_END, - EXT_U32_TABLE_END, - ); - } - - /// Like [`ConstraintCircuitMonad::lower_to_degree`] with additional assertion of expected - /// properties. Also prints: - /// - the given multicircuit prior to degree lowering - /// - the multicircuit after degree lowering - /// - the new base constraints - /// - the new extension constraints - /// - the numbers of original and new constraints - fn lower_degree_and_assert_properties( - multicircuit: &mut [ConstraintCircuitMonad], - target_deg: isize, - num_base_cols: usize, - num_ext_cols: usize, - ) -> ( - Vec>, - Vec>, - ) { - let seed = random(); - let mut rng = StdRng::seed_from_u64(seed); - println!("seed: {seed}"); - - let num_constraints = multicircuit.len(); - println!("original multicircuit:"); - for circuit in multicircuit.iter() { - println!(" {circuit}"); - } - - let (new_base_constraints, new_ext_constraints) = ConstraintCircuitMonad::lower_to_degree( - multicircuit, - target_deg, - num_base_cols, - num_ext_cols, - ); - - assert_eq!(num_constraints, multicircuit.len()); - assert!(ConstraintCircuitMonad::multicircuit_degree(multicircuit) <= target_deg); - assert!(ConstraintCircuitMonad::multicircuit_degree(&new_base_constraints) <= target_deg); - assert!(ConstraintCircuitMonad::multicircuit_degree(&new_ext_constraints) <= target_deg); - - // Check that the new constraints are simple substitutions. - let mut substitution_rules = vec![]; - for (constraint_type, constraints) in [ - ("base", &new_base_constraints), - ("ext", &new_ext_constraints), - ] { - for (i, constraint) in constraints.iter().enumerate() { - let expression = constraint.circuit.borrow().expression.clone(); - let CircuitExpression::BinaryOperation(BinOp::Add, lhs, rhs) = expression else { - panic!("New {constraint_type} constraint {i} must be a subtraction."); - }; - let CircuitExpression::Input(input_indicator) = lhs.borrow().expression.clone() - else { - panic!("New {constraint_type} constraint {i} must be a simple substitution."); - }; - let substitution_rule = rhs.borrow().clone(); - assert_substitution_rule_uses_legal_variables(input_indicator, &substitution_rule); - substitution_rules.push(substitution_rule); - } - } - - // Use the Schwartz-Zippel lemma to check no two substitution rules are equal. - let dummy_claim = Claim::default(); - let challenges: [XFieldElement; Challenges::SAMPLE_COUNT] = rng.gen(); - let challenges = challenges.to_vec(); - let challenges = Challenges::new(challenges, &dummy_claim); - let challenges = &challenges.challenges; - - let num_rows = 2; - let num_new_base_constraints = new_base_constraints.len(); - let num_new_ext_constraints = new_ext_constraints.len(); - let num_base_cols = NUM_BASE_COLUMNS + num_new_base_constraints; - let num_ext_cols = NUM_EXT_COLUMNS + num_new_ext_constraints; - let base_shape = [num_rows, num_base_cols]; - let ext_shape = [num_rows, num_ext_cols]; - let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); - let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); - let base_rows = base_rows.view(); - let ext_rows = ext_rows.view(); - - let evaluated_substitution_rules = substitution_rules - .iter() - .map(|c| c.evaluate(base_rows, ext_rows, challenges)); - - let mut values_to_index = HashMap::new(); - for (idx, value) in evaluated_substitution_rules.enumerate() { - if let Some(index) = values_to_index.get(&value) { - panic!("Substitution {idx} must be distinct from substitution {index}."); - } else { - values_to_index.insert(value, idx); - } - } - - // Print the multicircuit and new constraints after degree lowering. - println!("new multicircuit:"); - for circuit in multicircuit.iter() { - println!(" {circuit}"); - } - println!("new base constraints:"); - for constraint in &new_base_constraints { - println!(" {constraint}"); - } - println!("new ext constraints:"); - for constraint in &new_ext_constraints { - println!(" {constraint}"); - } - - println!( - "Started with {num_constraints} constraints. \ - Derived {num_new_base_constraints} new base, \ - {num_new_ext_constraints} new extension constraints." - ); - - (new_base_constraints, new_ext_constraints) - } - - /// Panics if the given substitution rule uses variables with an index greater than (or equal) - /// to the given index. In practice, this given index corresponds to a newly introduced - /// variable. - fn assert_substitution_rule_uses_legal_variables( - new_var: II, - substitution_rule: &ConstraintCircuit, - ) { - match substitution_rule.expression.clone() { - CircuitExpression::BinaryOperation(_, lhs, rhs) => { - let lhs = lhs.borrow(); - let rhs = rhs.borrow(); - assert_substitution_rule_uses_legal_variables(new_var, &lhs); - assert_substitution_rule_uses_legal_variables(new_var, &rhs); - } - CircuitExpression::Input(old_var) => { - let new_var_is_base = new_var.is_base_table_column(); - let old_var_is_base = old_var.is_base_table_column(); - let legal_substitute = match (new_var_is_base, old_var_is_base) { - (true, false) => false, - (false, true) => true, - _ => old_var.column() < new_var.column(), - }; - assert!(legal_substitute, "Cannot replace {old_var} with {new_var}."); - } - _ => (), - }; - } - #[test] fn all_nodes_in_multicircuit_are_identified_correctly() { let builder = ConstraintCircuitBuilder::new(); @@ -2179,42 +1443,6 @@ mod tests { assert!(most_frequent_nodes.contains(&&x(10).consume())); } - /// Fills the derived columns of the degree-lowering table using randomly generated rows and - /// checks the resulting values for uniqueness. The described method corresponds to an - /// application of the Schwartz-Zippel lemma to check uniqueness of the substitution rules - /// generated during degree lowering. - #[test] - #[ignore = "(probably) requires normalization of circuit expressions"] - fn substitution_rules_are_unique() { - let challenges = Challenges::default(); - let mut base_table_rows = Array2::from_shape_fn((2, NUM_BASE_COLUMNS), |_| random()); - let mut ext_table_rows = Array2::from_shape_fn((2, NUM_EXT_COLUMNS), |_| random()); - - DegreeLoweringTable::fill_derived_base_columns(base_table_rows.view_mut()); - DegreeLoweringTable::fill_derived_ext_columns( - base_table_rows.view(), - ext_table_rows.view_mut(), - &challenges, - ); - - let mut encountered_values = HashMap::new(); - for col_idx in 0..NUM_BASE_COLUMNS { - let val = base_table_rows[(0, col_idx)].lift(); - let other_entry = encountered_values.insert(val, col_idx); - if let Some(other_idx) = other_entry { - panic!("Duplicate value {val} in derived base column {other_idx} and {col_idx}."); - } - } - println!("Now comparing extension columns…"); - for col_idx in 0..NUM_EXT_COLUMNS { - let val = ext_table_rows[(0, col_idx)]; - let other_entry = encountered_values.insert(val, col_idx); - if let Some(other_idx) = other_entry { - panic!("Duplicate value {val} in derived ext column {other_idx} and {col_idx}."); - } - } - } - #[test] fn equivalent_nodes_are_detected_when_present() { let builder = ConstraintCircuitBuilder::new(); diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index c3e97e71d..881880b9e 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -40,6 +40,7 @@ syn.workspace = true thiserror.workspace = true twenty-first.workspace = true unicode-width.workspace = true +constraint-builder = { path = "../constraint-builder" } [dev-dependencies] assert2.workspace = true diff --git a/triton-vm/src/codegen/constraints.rs b/triton-vm/src/codegen/constraints.rs index 2a7402e3c..c5eb1acc0 100644 --- a/triton-vm/src/codegen/constraints.rs +++ b/triton-vm/src/codegen/constraints.rs @@ -3,6 +3,10 @@ use std::collections::HashSet; +use constraint_builder::BinOp; +use constraint_builder::CircuitExpression; +use constraint_builder::ConstraintCircuit; +use constraint_builder::InputIndicator; use itertools::Itertools; use proc_macro2::TokenStream; use quote::format_ident; @@ -14,10 +18,6 @@ use twenty_first::prelude::*; use crate::instruction::Instruction; use crate::op_stack::NumberOfWords; -use crate::codegen::circuit::BinOp; -use crate::codegen::circuit::CircuitExpression; -use crate::codegen::circuit::ConstraintCircuit; -use crate::codegen::circuit::InputIndicator; use crate::codegen::Constraints; pub(crate) trait Codegen { @@ -894,8 +894,8 @@ impl ToTokens for IOList { #[cfg(test)] mod tests { - use crate::codegen::circuit::ConstraintCircuitBuilder; - use crate::codegen::circuit::SingleRowIndicator; + use constraint_builder::ConstraintCircuitBuilder; + use constraint_builder::SingleRowIndicator; use twenty_first::prelude::*; use super::*; diff --git a/triton-vm/src/codegen/mod.rs b/triton-vm/src/codegen/mod.rs index 526c29273..8531de53c 100644 --- a/triton-vm/src/codegen/mod.rs +++ b/triton-vm/src/codegen/mod.rs @@ -1,20 +1,19 @@ -use arbitrary::Arbitrary; +use constraint_builder::ConstraintCircuit; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DegreeLoweringInfo; +use constraint_builder::DualRowIndicator; +use constraint_builder::InputIndicator; +use constraint_builder::SingleRowIndicator; use itertools::Itertools; use proc_macro2::TokenStream; use std::fs::write; -use crate::codegen::circuit::ConstraintCircuit; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::InputIndicator; -use crate::codegen::circuit::SingleRowIndicator; use crate::codegen::constraints::Codegen; use crate::codegen::constraints::RustBackend; use crate::codegen::constraints::TasmBackend; use crate::codegen::substitutions::AllSubstitutions; use crate::codegen::substitutions::Substitutions; -pub(crate) mod circuit; mod constraints; mod substitutions; @@ -48,74 +47,41 @@ pub(crate) struct Constraints { pub term: Vec>, } -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct DegreeLoweringInfo { - pub target_degree: isize, - - /// The total number of base columns _before_ degree lowering has happened. - pub num_base_cols: usize, - - /// The total number of extension columns _before_ degree lowering has happened. - pub num_ext_cols: usize, -} - impl Constraints { pub fn lower_to_target_degree_through_substitutions( &mut self, - info: DegreeLoweringInfo, + mut info: DegreeLoweringInfo, ) -> AllSubstitutions { - // Subtract the degree lowering table's width from the total number of columns to guarantee - // the same number of columns even for repeated runs of the constraint evaluation generator. - let mut num_base_cols = info.num_base_cols; - let mut num_ext_cols = info.num_ext_cols; + let lowering_info = info; + let (init_base_substitutions, init_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.init, - info.target_degree, - num_base_cols, - num_ext_cols, - ); - num_base_cols += init_base_substitutions.len(); - num_ext_cols += init_ext_substitutions.len(); + ConstraintCircuitMonad::lower_to_degree(&mut self.init, info); + info.num_base_cols += init_base_substitutions.len(); + info.num_ext_cols += init_ext_substitutions.len(); let (cons_base_substitutions, cons_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.cons, - info.target_degree, - num_base_cols, - num_ext_cols, - ); - num_base_cols += cons_base_substitutions.len(); - num_ext_cols += cons_ext_substitutions.len(); + ConstraintCircuitMonad::lower_to_degree(&mut self.cons, info); + info.num_base_cols += cons_base_substitutions.len(); + info.num_ext_cols += cons_ext_substitutions.len(); let (tran_base_substitutions, tran_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.tran, - info.target_degree, - num_base_cols, - num_ext_cols, - ); - num_base_cols += tran_base_substitutions.len(); - num_ext_cols += tran_ext_substitutions.len(); + ConstraintCircuitMonad::lower_to_degree(&mut self.tran, info); + info.num_base_cols += tran_base_substitutions.len(); + info.num_ext_cols += tran_ext_substitutions.len(); let (term_base_substitutions, term_ext_substitutions) = - ConstraintCircuitMonad::lower_to_degree( - &mut self.term, - info.target_degree, - num_base_cols, - num_ext_cols, - ); + ConstraintCircuitMonad::lower_to_degree(&mut self.term, info); AllSubstitutions { base: Substitutions { - lowering_info: info, + lowering_info, init: init_base_substitutions, cons: cons_base_substitutions, tran: tran_base_substitutions, term: term_base_substitutions, }, ext: Substitutions { - lowering_info: info, + lowering_info, init: init_ext_substitutions, cons: cons_ext_substitutions, tran: tran_ext_substitutions, @@ -164,29 +130,17 @@ impl Constraints { #[cfg(test)] mod tests { + use constraint_builder::ConstraintCircuitBuilder; use twenty_first::prelude::*; - use crate::codegen::circuit::ConstraintCircuitBuilder; use crate::table; use super::*; - impl Default for DegreeLoweringInfo { - /// For testing purposes only. - fn default() -> Self { - Self { - target_degree: 4, - num_base_cols: 42, - num_ext_cols: 13, - } - } - } - #[repr(usize)] enum TestChallenges { Ch0, Ch1, - Ch2, } impl From for usize { @@ -195,33 +149,41 @@ mod tests { } } + fn degree_lowering_info() -> DegreeLoweringInfo { + DegreeLoweringInfo { + target_degree: 4, + num_base_cols: 42, + num_ext_cols: 13, + } + } + #[test] fn test_constraints_can_be_fetched() { - let _ = Constraints::test_constraints(); + Constraints::test_constraints(); } #[test] fn degree_lowering_tables_code_can_be_generated_for_test_constraints() { - let lowering_info = DegreeLoweringInfo::default(); let mut constraints = Constraints::test_constraints(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(lowering_info); - let _ = substitutions.generate_degree_lowering_table_code(); + let substitutions = + constraints.lower_to_target_degree_through_substitutions(degree_lowering_info()); + let _unused = substitutions.generate_degree_lowering_table_code(); } #[test] fn degree_lowering_tables_code_can_be_generated_from_all_constraints() { - let lowering_info = DegreeLoweringInfo::default(); let mut constraints = table::constraints(); - let substitutions = constraints.lower_to_target_degree_through_substitutions(lowering_info); - let _ = substitutions.generate_degree_lowering_table_code(); + let substitutions = + constraints.lower_to_target_degree_through_substitutions(degree_lowering_info()); + let _unused = substitutions.generate_degree_lowering_table_code(); } #[test] fn constraints_and_substitutions_can_be_combined() { let mut constraints = Constraints::test_constraints(); let substitutions = - constraints.lower_to_target_degree_through_substitutions(DegreeLoweringInfo::default()); - let _ = constraints.combine_with_substitution_induced_constraints(substitutions); + constraints.lower_to_target_degree_through_substitutions(degree_lowering_info()); + let _combined = constraints.combine_with_substitution_induced_constraints(substitutions); } impl Constraints { diff --git a/triton-vm/src/codegen/substitutions.rs b/triton-vm/src/codegen/substitutions.rs index 769e73d27..dc76fface 100644 --- a/triton-vm/src/codegen/substitutions.rs +++ b/triton-vm/src/codegen/substitutions.rs @@ -1,17 +1,17 @@ +use constraint_builder::BinOp; +use constraint_builder::CircuitExpression; +use constraint_builder::ConstraintCircuit; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DegreeLoweringInfo; +use constraint_builder::DualRowIndicator; +use constraint_builder::InputIndicator; +use constraint_builder::SingleRowIndicator; use itertools::Itertools; use proc_macro2::TokenStream; use quote::format_ident; use quote::quote; -use crate::codegen::circuit::BinOp; -use crate::codegen::circuit::CircuitExpression; -use crate::codegen::circuit::ConstraintCircuit; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::InputIndicator; -use crate::codegen::circuit::SingleRowIndicator; use crate::codegen::constraints::RustBackend; -use crate::codegen::DegreeLoweringInfo; pub(crate) struct AllSubstitutions { pub base: Substitutions, diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 60110b90d..030415d93 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1325,7 +1325,6 @@ pub(crate) mod tests { use test_strategy::proptest; use twenty_first::math::other::random_elements; - use crate::codegen::circuit::ConstraintCircuitBuilder; use crate::error::InstructionError; use crate::example_programs::*; use crate::instruction::Instruction; @@ -1362,6 +1361,7 @@ pub(crate) mod tests { use crate::triton_program; use crate::vm::tests::*; use crate::PublicInput; + use constraint_builder::ConstraintCircuitBuilder; use super::*; diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index f3254a602..113f7c1d2 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -3,15 +3,15 @@ pub use crate::table::master_table::NUM_BASE_COLUMNS; pub use crate::table::master_table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::SingleRowIndicator; use strum::Display; use strum::EnumCount; use strum::EnumIter; use twenty_first::prelude::XFieldElement; -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::SingleRowIndicator; use crate::codegen::Constraints; use crate::table::cascade_table::ExtCascadeTable; use crate::table::cross_table_argument::GrandCrossTableArg; @@ -80,7 +80,7 @@ pub type ExtensionRow = [XFieldElement; NUM_EXT_COLUMNS]; /// See also [`NUM_QUOTIENT_SEGMENTS`]. pub type QuotientSegments = [XFieldElement; NUM_QUOTIENT_SEGMENTS]; -pub fn constraints() -> Constraints { +pub(crate) fn constraints() -> Constraints { Constraints { init: initial_constraints(), cons: consistency_constraints(), @@ -156,3 +156,380 @@ fn terminal_constraints() -> Vec> { ] .concat() } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use constraint_builder::BinOp; + use constraint_builder::CircuitExpression; + use constraint_builder::ConstraintCircuit; + use constraint_builder::ConstraintCircuitBuilder; + use constraint_builder::ConstraintCircuitMonad; + use constraint_builder::DegreeLoweringInfo; + use constraint_builder::InputIndicator; + use itertools::Itertools; + use ndarray::Array2; + use ndarray::ArrayView2; + use rand::prelude::StdRng; + use rand::random; + use rand::Rng; + use rand_core::SeedableRng; + use twenty_first::prelude::BFieldElement; + + use crate::prelude::Claim; + use crate::table::challenges::Challenges; + use crate::table::degree_lowering_table::DegreeLoweringTable; + use crate::table::master_table::AIR_TARGET_DEGREE; + use crate::table::master_table::CASCADE_TABLE_END; + use crate::table::master_table::EXT_CASCADE_TABLE_END; + use crate::table::master_table::EXT_HASH_TABLE_END; + use crate::table::master_table::EXT_JUMP_STACK_TABLE_END; + use crate::table::master_table::EXT_LOOKUP_TABLE_END; + use crate::table::master_table::EXT_OP_STACK_TABLE_END; + use crate::table::master_table::EXT_PROCESSOR_TABLE_END; + use crate::table::master_table::EXT_PROGRAM_TABLE_END; + use crate::table::master_table::EXT_RAM_TABLE_END; + use crate::table::master_table::EXT_U32_TABLE_END; + use crate::table::master_table::HASH_TABLE_END; + use crate::table::master_table::JUMP_STACK_TABLE_END; + use crate::table::master_table::LOOKUP_TABLE_END; + use crate::table::master_table::OP_STACK_TABLE_END; + use crate::table::master_table::PROCESSOR_TABLE_END; + use crate::table::master_table::PROGRAM_TABLE_END; + use crate::table::master_table::RAM_TABLE_END; + use crate::table::master_table::U32_TABLE_END; + + use super::*; + + /// Verify that all nodes evaluate to a unique value when given a randomized input. + /// If this is not the case two nodes that are not equal evaluate to the same value. + fn table_constraints_prop( + constraints: &[ConstraintCircuit], + table_name: &str, + ) { + let seed = random(); + let mut rng = StdRng::seed_from_u64(seed); + println!("seed: {seed}"); + + let dummy_claim = Claim::default(); + let challenges: [XFieldElement; Challenges::SAMPLE_COUNT] = rng.gen(); + let challenges = challenges.to_vec(); + let challenges = Challenges::new(challenges, &dummy_claim); + let challenges = &challenges.challenges; + + let num_rows = 2; + let base_shape = [num_rows, NUM_BASE_COLUMNS]; + let ext_shape = [num_rows, NUM_EXT_COLUMNS]; + let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); + let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); + let base_rows = base_rows.view(); + let ext_rows = ext_rows.view(); + + let mut values = HashMap::new(); + for c in constraints { + evaluate_assert_unique(c, challenges, base_rows, ext_rows, &mut values); + } + + let circuit_degree = constraints.iter().map(|c| c.degree()).max().unwrap_or(-1); + println!("Max degree constraint for {table_name} table: {circuit_degree}"); + } + + /// Recursively evaluates the given constraint circuit and its sub-circuits on the given + /// base and extension table, and returns the result of the evaluation. + /// At each recursive step, updates the given HashMap with the result of the evaluation. + /// If the HashMap already contains the result of the evaluation, panics. + /// This function is used to assert that the evaluation of a constraint circuit + /// and its sub-circuits is unique. + /// It is used to identify redundant constraints or sub-circuits. + /// The employed method is the Schwartz-Zippel lemma. + fn evaluate_assert_unique( + constraint: &ConstraintCircuit, + challenges: &[XFieldElement], + base_rows: ArrayView2, + ext_rows: ArrayView2, + values: &mut HashMap)>, + ) -> XFieldElement { + let value = match &constraint.expression { + CircuitExpression::BinaryOperation(binop, lhs, rhs) => { + let lhs = lhs.borrow(); + let rhs = rhs.borrow(); + let lhs = evaluate_assert_unique(&lhs, challenges, base_rows, ext_rows, values); + let rhs = evaluate_assert_unique(&rhs, challenges, base_rows, ext_rows, values); + binop.operation(lhs, rhs) + } + _ => constraint.evaluate(base_rows, ext_rows, challenges), + }; + + let own_id = constraint.id.to_owned(); + let maybe_entry = values.insert(value, (own_id, constraint.clone())); + if let Some((other_id, other_circuit)) = maybe_entry { + assert_eq!( + own_id, other_id, + "Circuit ID {other_id} and circuit ID {own_id} are not unique. \ + Collision on:\n\ + ID {other_id} – {other_circuit}\n\ + ID {own_id} – {constraint}\n\ + Both evaluate to {value}.", + ); + } + + value + } + + #[test] + fn nodes_are_unique_for_all_constraints() { + fn build_constraints( + multicircuit_builder: &dyn Fn( + &ConstraintCircuitBuilder, + ) -> Vec>, + ) -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + let multicircuit = multicircuit_builder(&circuit_builder); + let mut constraints = multicircuit.into_iter().map(|c| c.consume()).collect_vec(); + ConstraintCircuit::assert_unique_ids(&mut constraints); + constraints + } + + macro_rules! assert_constraint_properties { + ($table:ident) => {{ + let init = build_constraints(&$table::initial_constraints); + let cons = build_constraints(&$table::consistency_constraints); + let tran = build_constraints(&$table::transition_constraints); + let term = build_constraints(&$table::terminal_constraints); + table_constraints_prop(&init, concat!(stringify!($table), " init")); + table_constraints_prop(&cons, concat!(stringify!($table), " cons")); + table_constraints_prop(&tran, concat!(stringify!($table), " tran")); + table_constraints_prop(&term, concat!(stringify!($table), " term")); + }}; + } + + assert_constraint_properties!(ExtProcessorTable); + assert_constraint_properties!(ExtProgramTable); + assert_constraint_properties!(ExtJumpStackTable); + assert_constraint_properties!(ExtOpStackTable); + assert_constraint_properties!(ExtRamTable); + assert_constraint_properties!(ExtHashTable); + assert_constraint_properties!(ExtU32Table); + assert_constraint_properties!(ExtCascadeTable); + assert_constraint_properties!(ExtLookupTable); + } + + /// Like [`ConstraintCircuitMonad::lower_to_degree`] with additional assertion of expected + /// properties. Also prints: + /// - the given multicircuit prior to degree lowering + /// - the multicircuit after degree lowering + /// - the new base constraints + /// - the new extension constraints + /// - the numbers of original and new constraints + fn lower_degree_and_assert_properties( + multicircuit: &mut [ConstraintCircuitMonad], + info: DegreeLoweringInfo, + ) -> ( + Vec>, + Vec>, + ) { + let seed = random(); + let mut rng = StdRng::seed_from_u64(seed); + println!("seed: {seed}"); + + let num_constraints = multicircuit.len(); + println!("original multicircuit:"); + for circuit in multicircuit.iter() { + println!(" {circuit}"); + } + + let (new_base_constraints, new_ext_constraints) = + ConstraintCircuitMonad::lower_to_degree(multicircuit, info); + + assert_eq!(num_constraints, multicircuit.len()); + + let target_deg = info.target_degree; + assert!(ConstraintCircuitMonad::multicircuit_degree(multicircuit) <= target_deg); + assert!(ConstraintCircuitMonad::multicircuit_degree(&new_base_constraints) <= target_deg); + assert!(ConstraintCircuitMonad::multicircuit_degree(&new_ext_constraints) <= target_deg); + + // Check that the new constraints are simple substitutions. + let mut substitution_rules = vec![]; + for (constraint_type, constraints) in [ + ("base", &new_base_constraints), + ("ext", &new_ext_constraints), + ] { + for (i, constraint) in constraints.iter().enumerate() { + let expression = constraint.circuit.borrow().expression.clone(); + let CircuitExpression::BinaryOperation(BinOp::Add, lhs, rhs) = expression else { + panic!("New {constraint_type} constraint {i} must be a subtraction."); + }; + let CircuitExpression::Input(input_indicator) = lhs.borrow().expression.clone() + else { + panic!("New {constraint_type} constraint {i} must be a simple substitution."); + }; + let substitution_rule = rhs.borrow().clone(); + assert_substitution_rule_uses_legal_variables(input_indicator, &substitution_rule); + substitution_rules.push(substitution_rule); + } + } + + // Use the Schwartz-Zippel lemma to check no two substitution rules are equal. + let dummy_claim = Claim::default(); + let challenges: [XFieldElement; Challenges::SAMPLE_COUNT] = rng.gen(); + let challenges = challenges.to_vec(); + let challenges = Challenges::new(challenges, &dummy_claim); + let challenges = &challenges.challenges; + + let num_rows = 2; + let num_new_base_constraints = new_base_constraints.len(); + let num_new_ext_constraints = new_ext_constraints.len(); + let num_base_cols = NUM_BASE_COLUMNS + num_new_base_constraints; + let num_ext_cols = NUM_EXT_COLUMNS + num_new_ext_constraints; + let base_shape = [num_rows, num_base_cols]; + let ext_shape = [num_rows, num_ext_cols]; + let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); + let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); + let base_rows = base_rows.view(); + let ext_rows = ext_rows.view(); + + let evaluated_substitution_rules = substitution_rules + .iter() + .map(|c| c.evaluate(base_rows, ext_rows, challenges)); + + let mut values_to_index = HashMap::new(); + for (idx, value) in evaluated_substitution_rules.enumerate() { + if let Some(index) = values_to_index.get(&value) { + panic!("Substitution {idx} must be distinct from substitution {index}."); + } else { + values_to_index.insert(value, idx); + } + } + + // Print the multicircuit and new constraints after degree lowering. + println!("new multicircuit:"); + for circuit in multicircuit.iter() { + println!(" {circuit}"); + } + println!("new base constraints:"); + for constraint in &new_base_constraints { + println!(" {constraint}"); + } + println!("new ext constraints:"); + for constraint in &new_ext_constraints { + println!(" {constraint}"); + } + + println!( + "Started with {num_constraints} constraints. \ + Derived {num_new_base_constraints} new base, \ + {num_new_ext_constraints} new extension constraints." + ); + + (new_base_constraints, new_ext_constraints) + } + + /// Panics if the given substitution rule uses variables with an index greater than (or equal) + /// to the given index. In practice, this given index corresponds to a newly introduced + /// variable. + fn assert_substitution_rule_uses_legal_variables( + new_var: II, + substitution_rule: &ConstraintCircuit, + ) { + match substitution_rule.expression.clone() { + CircuitExpression::BinaryOperation(_, lhs, rhs) => { + let lhs = lhs.borrow(); + let rhs = rhs.borrow(); + assert_substitution_rule_uses_legal_variables(new_var, &lhs); + assert_substitution_rule_uses_legal_variables(new_var, &rhs); + } + CircuitExpression::Input(old_var) => { + let new_var_is_base = new_var.is_base_table_column(); + let old_var_is_base = old_var.is_base_table_column(); + let legal_substitute = match (new_var_is_base, old_var_is_base) { + (true, false) => false, + (false, true) => true, + _ => old_var.column() < new_var.column(), + }; + assert!(legal_substitute, "Cannot replace {old_var} with {new_var}."); + } + _ => (), + }; + } + + #[test] + fn degree_lowering_works_correctly_for_all_tables() { + macro_rules! assert_degree_lowering { + ($table:ident ($base_end:ident, $ext_end:ident)) => {{ + let degree_lowering_info = DegreeLoweringInfo { + target_degree: AIR_TARGET_DEGREE, + num_base_cols: $base_end, + num_ext_cols: $ext_end, + }; + let circuit_builder = ConstraintCircuitBuilder::new(); + let mut init = $table::initial_constraints(&circuit_builder); + lower_degree_and_assert_properties(&mut init, degree_lowering_info); + + let circuit_builder = ConstraintCircuitBuilder::new(); + let mut cons = $table::consistency_constraints(&circuit_builder); + lower_degree_and_assert_properties(&mut cons, degree_lowering_info); + + let circuit_builder = ConstraintCircuitBuilder::new(); + let mut tran = $table::transition_constraints(&circuit_builder); + lower_degree_and_assert_properties(&mut tran, degree_lowering_info); + + let circuit_builder = ConstraintCircuitBuilder::new(); + let mut term = $table::terminal_constraints(&circuit_builder); + lower_degree_and_assert_properties(&mut term, degree_lowering_info); + }}; + } + + assert_degree_lowering!(ExtProgramTable(PROGRAM_TABLE_END, EXT_PROGRAM_TABLE_END)); + assert_degree_lowering!(ExtProcessorTable( + PROCESSOR_TABLE_END, + EXT_PROCESSOR_TABLE_END + )); + assert_degree_lowering!(ExtOpStackTable(OP_STACK_TABLE_END, EXT_OP_STACK_TABLE_END)); + assert_degree_lowering!(ExtRamTable(RAM_TABLE_END, EXT_RAM_TABLE_END)); + assert_degree_lowering!(ExtJumpStackTable( + JUMP_STACK_TABLE_END, + EXT_JUMP_STACK_TABLE_END + )); + assert_degree_lowering!(ExtHashTable(HASH_TABLE_END, EXT_HASH_TABLE_END)); + assert_degree_lowering!(ExtCascadeTable(CASCADE_TABLE_END, EXT_CASCADE_TABLE_END)); + assert_degree_lowering!(ExtLookupTable(LOOKUP_TABLE_END, EXT_LOOKUP_TABLE_END)); + assert_degree_lowering!(ExtU32Table(U32_TABLE_END, EXT_U32_TABLE_END)); + } + + /// Fills the derived columns of the degree-lowering table using randomly generated rows and + /// checks the resulting values for uniqueness. The described method corresponds to an + /// application of the Schwartz-Zippel lemma to check uniqueness of the substitution rules + /// generated during degree lowering. + #[test] + #[ignore = "(probably) requires normalization of circuit expressions"] + fn substitution_rules_are_unique() { + let challenges = Challenges::default(); + let mut base_table_rows = Array2::from_shape_fn((2, NUM_BASE_COLUMNS), |_| random()); + let mut ext_table_rows = Array2::from_shape_fn((2, NUM_EXT_COLUMNS), |_| random()); + + DegreeLoweringTable::fill_derived_base_columns(base_table_rows.view_mut()); + DegreeLoweringTable::fill_derived_ext_columns( + base_table_rows.view(), + ext_table_rows.view_mut(), + &challenges, + ); + + let mut encountered_values = HashMap::new(); + for col_idx in 0..NUM_BASE_COLUMNS { + let val = base_table_rows[(0, col_idx)].lift(); + let other_entry = encountered_values.insert(val, col_idx); + if let Some(other_idx) = other_entry { + panic!("Duplicate value {val} in derived base column {other_idx} and {col_idx}."); + } + } + println!("Now comparing extension columns…"); + for col_idx in 0..NUM_EXT_COLUMNS { + let val = ext_table_rows[(0, col_idx)]; + let other_entry = encountered_values.insert(val, col_idx); + if let Some(other_idx) = other_entry { + panic!("Duplicate value {val} in derived ext column {other_idx} and {col_idx}."); + } + } + } +} diff --git a/triton-vm/src/table/cascade_table.rs b/triton-vm/src/table/cascade_table.rs index c4a8dfe9f..463919564 100644 --- a/triton-vm/src/table/cascade_table.rs +++ b/triton-vm/src/table/cascade_table.rs @@ -1,9 +1,9 @@ -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator; -use crate::codegen::circuit::SingleRowIndicator::*; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator; +use constraint_builder::SingleRowIndicator::*; use ndarray::s; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; diff --git a/triton-vm/src/table/cross_table_argument.rs b/triton-vm/src/table/cross_table_argument.rs index 8eb1b41b5..43e5ab83d 100644 --- a/triton-vm/src/table/cross_table_argument.rs +++ b/triton-vm/src/table/cross_table_argument.rs @@ -1,11 +1,11 @@ use std::ops::Add; use std::ops::Mul; -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::SingleRowIndicator; -use crate::codegen::circuit::SingleRowIndicator::ExtRow; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::SingleRowIndicator; +use constraint_builder::SingleRowIndicator::ExtRow; use twenty_first::prelude::*; use crate::table::challenges::ChallengeId::*; diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 323750013..7e8411809 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -1,10 +1,10 @@ -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::InputIndicator; -use crate::codegen::circuit::SingleRowIndicator; -use crate::codegen::circuit::SingleRowIndicator::*; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::InputIndicator; +use constraint_builder::SingleRowIndicator; +use constraint_builder::SingleRowIndicator::*; use itertools::Itertools; use ndarray::*; use num_traits::Zero; diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs index ec055bc97..0547ab894 100644 --- a/triton-vm/src/table/jump_stack_table.rs +++ b/triton-vm/src/table/jump_stack_table.rs @@ -2,9 +2,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::ops::Range; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator::*; -use crate::codegen::circuit::*; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator::*; +use constraint_builder::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; diff --git a/triton-vm/src/table/lookup_table.rs b/triton-vm/src/table/lookup_table.rs index f1350a355..7ecb95491 100644 --- a/triton-vm/src/table/lookup_table.rs +++ b/triton-vm/src/table/lookup_table.rs @@ -1,9 +1,9 @@ -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator; -use crate::codegen::circuit::SingleRowIndicator::*; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator; +use constraint_builder::SingleRowIndicator::*; use itertools::Itertools; use ndarray::prelude::*; use num_traits::ConstOne; diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index c9c8244f5..8b10d4ca9 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1273,6 +1273,11 @@ mod tests { use fs_err as fs; use std::path::Path; + use constraint_builder::ConstraintCircuitBuilder; + use constraint_builder::ConstraintCircuitMonad; + use constraint_builder::DegreeLoweringInfo; + use constraint_builder::DualRowIndicator; + use constraint_builder::SingleRowIndicator; use master_table::cross_table_argument::GrandCrossTableArg; use ndarray::s; use ndarray::Array2; @@ -1292,10 +1297,6 @@ mod tests { use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; use crate::arithmetic_domain::ArithmeticDomain; - use crate::codegen::circuit::ConstraintCircuitBuilder; - use crate::codegen::circuit::ConstraintCircuitMonad; - use crate::codegen::circuit::DualRowIndicator; - use crate::codegen::circuit::SingleRowIndicator; use crate::instruction::tests::InstructionBucket; use crate::instruction::Instruction; use crate::instruction::InstructionBit; @@ -1577,6 +1578,12 @@ mod tests { continue; }; + let degree_lowering_info = DegreeLoweringInfo { + target_degree, + num_base_cols: 0, + num_ext_cols: 0, + }; + let initial_constraints = constraints_without_degree_lowering!(initial_constraints); let consistency_constraints = constraints_without_degree_lowering!(consistency_constraints); @@ -1586,10 +1593,10 @@ mod tests { // generic closures are not possible; define two variants :( let lower_to_target_degree_single_row = |mut constraints: Vec<_>| { - ConstraintCircuitMonad::lower_to_degree(&mut constraints, target_degree, 0, 0) + ConstraintCircuitMonad::lower_to_degree(&mut constraints, degree_lowering_info) }; let lower_to_target_degree_double_row = |mut constraints: Vec<_>| { - ConstraintCircuitMonad::lower_to_degree(&mut constraints, target_degree, 0, 0) + ConstraintCircuitMonad::lower_to_degree(&mut constraints, degree_lowering_info) }; let (init_main, init_aux) = lower_to_target_degree_single_row(initial_constraints); @@ -1784,45 +1791,38 @@ mod tests { let mut transition_constraints = table.transition_constraints.clone(); let mut terminal_constraints = table.terminal_constraints.clone(); - if let Some(target) = target_degree { - let (new_base_initial, new_ext_initial) = - ConstraintCircuitMonad::lower_to_degree( - &mut table.initial_constraints, - target, - table.last_base_column_index, - table.last_ext_column_index, - ); - let (new_base_consistency, new_ext_consistency) = - ConstraintCircuitMonad::lower_to_degree( - &mut table.consistency_constraints, - target, - table.last_base_column_index, - table.last_ext_column_index, - ); - let (new_base_transition, new_ext_transition) = - ConstraintCircuitMonad::lower_to_degree( - &mut table.transition_constraints, - target, - table.last_base_column_index, - table.last_ext_column_index, - ); - let (new_base_terminal, new_ext_terminal) = - ConstraintCircuitMonad::lower_to_degree( - &mut table.terminal_constraints, - target, - table.last_base_column_index, - table.last_ext_column_index, - ); - - initial_constraints.extend(new_base_initial); - consistency_constraints.extend(new_base_consistency); - transition_constraints.extend(new_base_transition); - terminal_constraints.extend(new_base_terminal); - - initial_constraints.extend(new_ext_initial); - consistency_constraints.extend(new_ext_consistency); - transition_constraints.extend(new_ext_transition); - terminal_constraints.extend(new_ext_terminal); + if let Some(target_degree) = target_degree { + let info = DegreeLoweringInfo { + target_degree, + num_base_cols: table.last_base_column_index, + num_ext_cols: table.last_ext_column_index, + }; + let (new_base_init, new_ext_init) = ConstraintCircuitMonad::lower_to_degree( + &mut table.initial_constraints, + info, + ); + let (new_base_cons, new_ext_cons) = ConstraintCircuitMonad::lower_to_degree( + &mut table.consistency_constraints, + info, + ); + let (new_base_tran, new_ext_tran) = ConstraintCircuitMonad::lower_to_degree( + &mut table.transition_constraints, + info, + ); + let (new_base_term, new_ext_term) = ConstraintCircuitMonad::lower_to_degree( + &mut table.terminal_constraints, + info, + ); + + initial_constraints.extend(new_base_init); + consistency_constraints.extend(new_base_cons); + transition_constraints.extend(new_base_tran); + terminal_constraints.extend(new_base_term); + + initial_constraints.extend(new_ext_init); + consistency_constraints.extend(new_ext_cons); + transition_constraints.extend(new_ext_tran); + terminal_constraints.extend(new_ext_term); } let table_max_degree = [ diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index 3e17c60a1..ac63e2259 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -2,10 +2,10 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::ops::Range; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator::*; -use crate::codegen::circuit::*; use arbitrary::Arbitrary; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator::*; +use constraint_builder::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index 394851ef9..d7780a738 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -1,9 +1,9 @@ use std::cmp::max; use std::ops::Mul; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator::*; -use crate::codegen::circuit::*; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator::*; +use constraint_builder::*; use itertools::izip; use itertools::Itertools; use ndarray::parallel::prelude::*; diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs index 409b52fcc..cd2f2736d 100644 --- a/triton-vm/src/table/program_table.rs +++ b/triton-vm/src/table/program_table.rs @@ -1,8 +1,8 @@ use std::cmp::Ordering; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator::*; -use crate::codegen::circuit::*; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator::*; +use constraint_builder::*; use ndarray::s; use ndarray::Array1; use ndarray::ArrayView1; diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index 46d112e41..ae285647e 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -1,9 +1,9 @@ use std::cmp::Ordering; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::SingleRowIndicator::*; -use crate::codegen::circuit::*; use arbitrary::Arbitrary; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::SingleRowIndicator::*; +use constraint_builder::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index 58c7c7604..748850dc4 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -1,14 +1,14 @@ use std::cmp::max; use std::ops::Mul; -use crate::codegen::circuit::ConstraintCircuitBuilder; -use crate::codegen::circuit::ConstraintCircuitMonad; -use crate::codegen::circuit::DualRowIndicator; -use crate::codegen::circuit::DualRowIndicator::*; -use crate::codegen::circuit::InputIndicator; -use crate::codegen::circuit::SingleRowIndicator; -use crate::codegen::circuit::SingleRowIndicator::*; use arbitrary::Arbitrary; +use constraint_builder::ConstraintCircuitBuilder; +use constraint_builder::ConstraintCircuitMonad; +use constraint_builder::DualRowIndicator; +use constraint_builder::DualRowIndicator::*; +use constraint_builder::InputIndicator; +use constraint_builder::SingleRowIndicator; +use constraint_builder::SingleRowIndicator::*; use ndarray::parallel::prelude::*; use ndarray::s; use ndarray::Array1; From 5e88a44be1d5253d83794d64ae347ce675b947f1 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Thu, 29 Aug 2024 09:28:56 +0200 Subject: [PATCH 04/15] refactor!: Move ISA to own crate changelog: ignore --- Cargo.toml | 7 +- .../Cargo.toml | 5 +- .../src/lib.rs | 0 triton-isa/Cargo.toml | 36 + triton-isa/src/error.rs | 6 + {triton-vm => triton-isa}/src/instruction.rs | 767 +++++------ triton-isa/src/lib.rs | 346 +++++ {triton-vm => triton-isa}/src/op_stack.rs | 232 ++-- {triton-vm => triton-isa}/src/parser.rs | 744 ++++++----- triton-isa/src/program.rs | 689 ++++++++++ triton-vm/Cargo.toml | 3 +- triton-vm/benches/cached_vs_jit_trace.rs | 4 +- triton-vm/benches/mem_io.rs | 10 +- triton-vm/benches/prove_fib.rs | 5 +- triton-vm/benches/prove_halt.rs | 11 +- .../benches/trace_mmr_new_peak_calculation.rs | 5 +- triton-vm/benches/verify_halt.rs | 6 +- triton-vm/src/aet.rs | 19 +- triton-vm/src/air.rs | 2 +- triton-vm/src/air/tasm_air_constraints.rs | 3 +- triton-vm/src/codegen/constraints.rs | 5 +- triton-vm/src/error.rs | 136 +- triton-vm/src/example_programs.rs | 5 +- triton-vm/src/execution_trace_profiler.rs | 274 ++++ triton-vm/src/lib.rs | 356 +----- triton-vm/src/prelude.rs | 19 +- triton-vm/src/program.rs | 1119 ----------------- triton-vm/src/proof.rs | 2 +- triton-vm/src/shared_tests.rs | 17 +- triton-vm/src/stark.rs | 28 +- triton-vm/src/table/hash_table.rs | 33 +- triton-vm/src/table/jump_stack_table.rs | 2 +- triton-vm/src/table/master_table.rs | 40 +- triton-vm/src/table/op_stack_table.rs | 7 +- triton-vm/src/table/processor_table.rs | 29 +- triton-vm/src/table/u32_table.rs | 2 +- triton-vm/src/vm.rs | 530 ++++++-- 37 files changed, 2861 insertions(+), 2643 deletions(-) rename {constraint-builder => triton-constraint-builder}/Cargo.toml (91%) rename {constraint-builder => triton-constraint-builder}/src/lib.rs (100%) create mode 100644 triton-isa/Cargo.toml create mode 100644 triton-isa/src/error.rs rename {triton-vm => triton-isa}/src/instruction.rs (55%) create mode 100644 triton-isa/src/lib.rs rename {triton-vm => triton-isa}/src/op_stack.rs (81%) rename {triton-vm => triton-isa}/src/parser.rs (70%) create mode 100644 triton-isa/src/program.rs create mode 100644 triton-vm/src/execution_trace_profiler.rs delete mode 100644 triton-vm/src/program.rs diff --git a/Cargo.toml b/Cargo.toml index 9ef2f97cd..11cc3d4e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["constraint-builder", "triton-vm"] +members = ["triton-vm", "triton-constraint-builder", "triton-isa"] resolver = "2" [profile.test] @@ -56,6 +56,11 @@ trybuild = "1.0" twenty-first = "0.42.0-alpha.9" unicode-width = "0.1" +[workspace.lints.rust] +let_underscore_drop = "warn" +missing_copy_implementations = "warn" +missing_debug_implementations = "warn" + [workspace.lints.clippy] cast_lossless = "warn" cloned_instead_of_copied = "warn" diff --git a/constraint-builder/Cargo.toml b/triton-constraint-builder/Cargo.toml similarity index 91% rename from constraint-builder/Cargo.toml rename to triton-constraint-builder/Cargo.toml index 51bdd441c..8dbd03f54 100644 --- a/constraint-builder/Cargo.toml +++ b/triton-constraint-builder/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "constraint-builder" +name = "triton-constraint-builder" description = """ AIR constraints build helper for Triton VM. """ @@ -27,3 +27,6 @@ proptest.workspace = true proptest-arbitrary-interop.workspace = true rand.workspace = true test-strategy.workspace = true + +[lints] +workspace = true diff --git a/constraint-builder/src/lib.rs b/triton-constraint-builder/src/lib.rs similarity index 100% rename from constraint-builder/src/lib.rs rename to triton-constraint-builder/src/lib.rs diff --git a/triton-isa/Cargo.toml b/triton-isa/Cargo.toml new file mode 100644 index 000000000..d697d3175 --- /dev/null +++ b/triton-isa/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "triton-isa" +description = """ +The instruction set architecture for Triton VM. +""" + +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +readme.workspace = true + +[dependencies] +arbitrary.workspace = true +get-size.workspace = true +itertools.workspace = true +lazy_static.workspace = true +nom.workspace = true +num-traits.workspace = true +serde.workspace = true +strum.workspace = true +thiserror.workspace = true +twenty-first.workspace = true + +[dev-dependencies] +assert2.workspace = true +proptest.workspace = true +proptest-arbitrary-interop.workspace = true +rand.workspace = true +test-strategy.workspace = true + +[lints] +workspace = true diff --git a/triton-isa/src/error.rs b/triton-isa/src/error.rs new file mode 100644 index 000000000..6cbbdc0d1 --- /dev/null +++ b/triton-isa/src/error.rs @@ -0,0 +1,6 @@ +pub use crate::instruction::InstructionError; +pub use crate::op_stack::NumberOfWordsError; +pub use crate::op_stack::OpStackElementError; +pub use crate::op_stack::OpStackError; +pub use crate::parser::ParseError; +pub use crate::program::ProgramDecodingError; diff --git a/triton-vm/src/instruction.rs b/triton-isa/src/instruction.rs similarity index 55% rename from triton-vm/src/instruction.rs rename to triton-isa/src/instruction.rs index 7560e666c..f6c2df29c 100644 --- a/triton-vm/src/instruction.rs +++ b/triton-isa/src/instruction.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::fmt::Display; use std::fmt::Formatter; use std::fmt::Result as FmtResult; +use std::num::TryFromIntError; use std::result; use arbitrary::Arbitrary; @@ -15,24 +16,74 @@ use serde::Serialize; use strum::EnumCount; use strum::EnumIter; use strum::IntoEnumIterator; +use thiserror::Error; use twenty_first::prelude::*; -use AnInstruction::*; - -use crate::error::InstructionError; -use crate::instruction::InstructionBit::*; -use crate::op_stack::NumberOfWords::*; -use crate::op_stack::OpStackElement::*; -use crate::op_stack::*; +use crate::op_stack::NumberOfWords; +use crate::op_stack::OpStackElement; +use crate::op_stack::OpStackError; type Result = result::Result; /// An `Instruction` has `call` addresses encoded as absolute integers. pub type Instruction = AnInstruction; -pub const ALL_INSTRUCTIONS: [Instruction; Instruction::COUNT] = - all_instructions_with_default_args(); -pub const ALL_INSTRUCTION_NAMES: [&str; Instruction::COUNT] = all_instruction_names(); +pub const ALL_INSTRUCTIONS: [Instruction; Instruction::COUNT] = [ + Instruction::Pop(NumberOfWords::N1), + Instruction::Push(BFieldElement::ZERO), + Instruction::Divine(NumberOfWords::N1), + Instruction::Dup(OpStackElement::ST0), + Instruction::Swap(OpStackElement::ST0), + Instruction::Halt, + Instruction::Nop, + Instruction::Skiz, + Instruction::Call(BFieldElement::ZERO), + Instruction::Return, + Instruction::Recurse, + Instruction::RecurseOrReturn, + Instruction::Assert, + Instruction::ReadMem(NumberOfWords::N1), + Instruction::WriteMem(NumberOfWords::N1), + Instruction::Hash, + Instruction::AssertVector, + Instruction::SpongeInit, + Instruction::SpongeAbsorb, + Instruction::SpongeAbsorbMem, + Instruction::SpongeSqueeze, + Instruction::Add, + Instruction::AddI(BFieldElement::ZERO), + Instruction::Mul, + Instruction::Invert, + Instruction::Eq, + Instruction::Split, + Instruction::Lt, + Instruction::And, + Instruction::Xor, + Instruction::Log2Floor, + Instruction::Pow, + Instruction::DivMod, + Instruction::PopCount, + Instruction::XxAdd, + Instruction::XxMul, + Instruction::XInvert, + Instruction::XbMul, + Instruction::ReadIo(NumberOfWords::N1), + Instruction::WriteIo(NumberOfWords::N1), + Instruction::MerkleStep, + Instruction::MerkleStepMem, + Instruction::XxDotStep, + Instruction::XbDotStep, +]; + +pub const ALL_INSTRUCTION_NAMES: [&str; Instruction::COUNT] = { + let mut names = [""; Instruction::COUNT]; + let mut i = 0; + while i < Instruction::COUNT { + names[i] = ALL_INSTRUCTIONS[i].name(); + i += 1; + } + names +}; lazy_static! { pub static ref OPCODE_TO_INSTRUCTION_MAP: HashMap = { @@ -225,99 +276,99 @@ impl AnInstruction { /// Assign a unique positive integer to each `Instruction`. pub const fn opcode(&self) -> u32 { match self { - Pop(_) => 3, - Push(_) => 1, - Divine(_) => 9, - Dup(_) => 17, - Swap(_) => 25, - Halt => 0, - Nop => 8, - Skiz => 2, - Call(_) => 33, - Return => 16, - Recurse => 24, - RecurseOrReturn => 32, - Assert => 10, - ReadMem(_) => 41, - WriteMem(_) => 11, - Hash => 18, - AssertVector => 26, - SpongeInit => 40, - SpongeAbsorb => 34, - SpongeAbsorbMem => 48, - SpongeSqueeze => 56, - Add => 42, - AddI(_) => 49, - Mul => 50, - Invert => 64, - Eq => 58, - Split => 4, - Lt => 6, - And => 14, - Xor => 22, - Log2Floor => 12, - Pow => 30, - DivMod => 20, - PopCount => 28, - XxAdd => 66, - XxMul => 74, - XInvert => 72, - XbMul => 82, - ReadIo(_) => 57, - WriteIo(_) => 19, - MerkleStep => 36, - MerkleStepMem => 44, - XxDotStep => 80, - XbDotStep => 88, + AnInstruction::Pop(_) => 3, + AnInstruction::Push(_) => 1, + AnInstruction::Divine(_) => 9, + AnInstruction::Dup(_) => 17, + AnInstruction::Swap(_) => 25, + AnInstruction::Halt => 0, + AnInstruction::Nop => 8, + AnInstruction::Skiz => 2, + AnInstruction::Call(_) => 33, + AnInstruction::Return => 16, + AnInstruction::Recurse => 24, + AnInstruction::RecurseOrReturn => 32, + AnInstruction::Assert => 10, + AnInstruction::ReadMem(_) => 41, + AnInstruction::WriteMem(_) => 11, + AnInstruction::Hash => 18, + AnInstruction::AssertVector => 26, + AnInstruction::SpongeInit => 40, + AnInstruction::SpongeAbsorb => 34, + AnInstruction::SpongeAbsorbMem => 48, + AnInstruction::SpongeSqueeze => 56, + AnInstruction::Add => 42, + AnInstruction::AddI(_) => 49, + AnInstruction::Mul => 50, + AnInstruction::Invert => 64, + AnInstruction::Eq => 58, + AnInstruction::Split => 4, + AnInstruction::Lt => 6, + AnInstruction::And => 14, + AnInstruction::Xor => 22, + AnInstruction::Log2Floor => 12, + AnInstruction::Pow => 30, + AnInstruction::DivMod => 20, + AnInstruction::PopCount => 28, + AnInstruction::XxAdd => 66, + AnInstruction::XxMul => 74, + AnInstruction::XInvert => 72, + AnInstruction::XbMul => 82, + AnInstruction::ReadIo(_) => 57, + AnInstruction::WriteIo(_) => 19, + AnInstruction::MerkleStep => 36, + AnInstruction::MerkleStepMem => 44, + AnInstruction::XxDotStep => 80, + AnInstruction::XbDotStep => 88, } } - pub(crate) const fn name(&self) -> &'static str { + pub const fn name(&self) -> &'static str { match self { - Pop(_) => "pop", - Push(_) => "push", - Divine(_) => "divine", - Dup(_) => "dup", - Swap(_) => "swap", - Halt => "halt", - Nop => "nop", - Skiz => "skiz", - Call(_) => "call", - Return => "return", - Recurse => "recurse", - RecurseOrReturn => "recurse_or_return", - Assert => "assert", - ReadMem(_) => "read_mem", - WriteMem(_) => "write_mem", - Hash => "hash", - AssertVector => "assert_vector", - SpongeInit => "sponge_init", - SpongeAbsorb => "sponge_absorb", - SpongeAbsorbMem => "sponge_absorb_mem", - SpongeSqueeze => "sponge_squeeze", - Add => "add", - AddI(_) => "addi", - Mul => "mul", - Invert => "invert", - Eq => "eq", - Split => "split", - Lt => "lt", - And => "and", - Xor => "xor", - Log2Floor => "log_2_floor", - Pow => "pow", - DivMod => "div_mod", - PopCount => "pop_count", - XxAdd => "xx_add", - XxMul => "xx_mul", - XInvert => "x_invert", - XbMul => "xb_mul", - ReadIo(_) => "read_io", - WriteIo(_) => "write_io", - MerkleStep => "merkle_step", - MerkleStepMem => "merkle_step_mem", - XxDotStep => "xx_dot_step", - XbDotStep => "xb_dot_step", + AnInstruction::Pop(_) => "pop", + AnInstruction::Push(_) => "push", + AnInstruction::Divine(_) => "divine", + AnInstruction::Dup(_) => "dup", + AnInstruction::Swap(_) => "swap", + AnInstruction::Halt => "halt", + AnInstruction::Nop => "nop", + AnInstruction::Skiz => "skiz", + AnInstruction::Call(_) => "call", + AnInstruction::Return => "return", + AnInstruction::Recurse => "recurse", + AnInstruction::RecurseOrReturn => "recurse_or_return", + AnInstruction::Assert => "assert", + AnInstruction::ReadMem(_) => "read_mem", + AnInstruction::WriteMem(_) => "write_mem", + AnInstruction::Hash => "hash", + AnInstruction::AssertVector => "assert_vector", + AnInstruction::SpongeInit => "sponge_init", + AnInstruction::SpongeAbsorb => "sponge_absorb", + AnInstruction::SpongeAbsorbMem => "sponge_absorb_mem", + AnInstruction::SpongeSqueeze => "sponge_squeeze", + AnInstruction::Add => "add", + AnInstruction::AddI(_) => "addi", + AnInstruction::Mul => "mul", + AnInstruction::Invert => "invert", + AnInstruction::Eq => "eq", + AnInstruction::Split => "split", + AnInstruction::Lt => "lt", + AnInstruction::And => "and", + AnInstruction::Xor => "xor", + AnInstruction::Log2Floor => "log_2_floor", + AnInstruction::Pow => "pow", + AnInstruction::DivMod => "div_mod", + AnInstruction::PopCount => "pop_count", + AnInstruction::XxAdd => "xx_add", + AnInstruction::XxMul => "xx_mul", + AnInstruction::XInvert => "x_invert", + AnInstruction::XbMul => "xb_mul", + AnInstruction::ReadIo(_) => "read_io", + AnInstruction::WriteIo(_) => "write_io", + AnInstruction::MerkleStep => "merkle_step", + AnInstruction::MerkleStepMem => "merkle_step_mem", + AnInstruction::XxDotStep => "xx_dot_step", + AnInstruction::XbDotStep => "xb_dot_step", } } @@ -328,13 +379,13 @@ impl AnInstruction { /// Number of words required to represent the instruction. pub fn size(&self) -> usize { match self { - Pop(_) | Push(_) => 2, - Divine(_) => 2, - Dup(_) | Swap(_) => 2, - Call(_) => 2, - ReadMem(_) | WriteMem(_) => 2, - AddI(_) => 2, - ReadIo(_) | WriteIo(_) => 2, + AnInstruction::Pop(_) | AnInstruction::Push(_) => 2, + AnInstruction::Divine(_) => 2, + AnInstruction::Dup(_) | AnInstruction::Swap(_) => 2, + AnInstruction::Call(_) => 2, + AnInstruction::ReadMem(_) | AnInstruction::WriteMem(_) => 2, + AnInstruction::AddI(_) => 2, + AnInstruction::ReadIo(_) | AnInstruction::WriteIo(_) => 2, _ => 1, } } @@ -353,99 +404,99 @@ impl AnInstruction { NewDest: PartialEq + Default, { match self { - Pop(x) => Pop(*x), - Push(x) => Push(*x), - Divine(x) => Divine(*x), - Dup(x) => Dup(*x), - Swap(x) => Swap(*x), - Halt => Halt, - Nop => Nop, - Skiz => Skiz, - Call(label) => Call(f(label)), - Return => Return, - Recurse => Recurse, - RecurseOrReturn => RecurseOrReturn, - Assert => Assert, - ReadMem(x) => ReadMem(*x), - WriteMem(x) => WriteMem(*x), - Hash => Hash, - AssertVector => AssertVector, - SpongeInit => SpongeInit, - SpongeAbsorb => SpongeAbsorb, - SpongeAbsorbMem => SpongeAbsorbMem, - SpongeSqueeze => SpongeSqueeze, - Add => Add, - AddI(x) => AddI(*x), - Mul => Mul, - Invert => Invert, - Eq => Eq, - Split => Split, - Lt => Lt, - And => And, - Xor => Xor, - Log2Floor => Log2Floor, - Pow => Pow, - DivMod => DivMod, - PopCount => PopCount, - XxAdd => XxAdd, - XxMul => XxMul, - XInvert => XInvert, - XbMul => XbMul, - ReadIo(x) => ReadIo(*x), - WriteIo(x) => WriteIo(*x), - MerkleStep => MerkleStep, - MerkleStepMem => MerkleStepMem, - XxDotStep => XxDotStep, - XbDotStep => XbDotStep, + AnInstruction::Pop(x) => AnInstruction::Pop(*x), + AnInstruction::Push(x) => AnInstruction::Push(*x), + AnInstruction::Divine(x) => AnInstruction::Divine(*x), + AnInstruction::Dup(x) => AnInstruction::Dup(*x), + AnInstruction::Swap(x) => AnInstruction::Swap(*x), + AnInstruction::Halt => AnInstruction::Halt, + AnInstruction::Nop => AnInstruction::Nop, + AnInstruction::Skiz => AnInstruction::Skiz, + AnInstruction::Call(label) => AnInstruction::Call(f(label)), + AnInstruction::Return => AnInstruction::Return, + AnInstruction::Recurse => AnInstruction::Recurse, + AnInstruction::RecurseOrReturn => AnInstruction::RecurseOrReturn, + AnInstruction::Assert => AnInstruction::Assert, + AnInstruction::ReadMem(x) => AnInstruction::ReadMem(*x), + AnInstruction::WriteMem(x) => AnInstruction::WriteMem(*x), + AnInstruction::Hash => AnInstruction::Hash, + AnInstruction::AssertVector => AnInstruction::AssertVector, + AnInstruction::SpongeInit => AnInstruction::SpongeInit, + AnInstruction::SpongeAbsorb => AnInstruction::SpongeAbsorb, + AnInstruction::SpongeAbsorbMem => AnInstruction::SpongeAbsorbMem, + AnInstruction::SpongeSqueeze => AnInstruction::SpongeSqueeze, + AnInstruction::Add => AnInstruction::Add, + AnInstruction::AddI(x) => AnInstruction::AddI(*x), + AnInstruction::Mul => AnInstruction::Mul, + AnInstruction::Invert => AnInstruction::Invert, + AnInstruction::Eq => AnInstruction::Eq, + AnInstruction::Split => AnInstruction::Split, + AnInstruction::Lt => AnInstruction::Lt, + AnInstruction::And => AnInstruction::And, + AnInstruction::Xor => AnInstruction::Xor, + AnInstruction::Log2Floor => AnInstruction::Log2Floor, + AnInstruction::Pow => AnInstruction::Pow, + AnInstruction::DivMod => AnInstruction::DivMod, + AnInstruction::PopCount => AnInstruction::PopCount, + AnInstruction::XxAdd => AnInstruction::XxAdd, + AnInstruction::XxMul => AnInstruction::XxMul, + AnInstruction::XInvert => AnInstruction::XInvert, + AnInstruction::XbMul => AnInstruction::XbMul, + AnInstruction::ReadIo(x) => AnInstruction::ReadIo(*x), + AnInstruction::WriteIo(x) => AnInstruction::WriteIo(*x), + AnInstruction::MerkleStep => AnInstruction::MerkleStep, + AnInstruction::MerkleStepMem => AnInstruction::MerkleStepMem, + AnInstruction::XxDotStep => AnInstruction::XxDotStep, + AnInstruction::XbDotStep => AnInstruction::XbDotStep, } } pub const fn op_stack_size_influence(&self) -> i32 { match self { - Pop(n) => -(n.num_words() as i32), - Push(_) => 1, - Divine(n) => n.num_words() as i32, - Dup(_) => 1, - Swap(_) => 0, - Halt => 0, - Nop => 0, - Skiz => -1, - Call(_) => 0, - Return => 0, - Recurse => 0, - RecurseOrReturn => 0, - Assert => -1, - ReadMem(n) => n.num_words() as i32, - WriteMem(n) => -(n.num_words() as i32), - Hash => -5, - AssertVector => -5, - SpongeInit => 0, - SpongeAbsorb => -10, - SpongeAbsorbMem => 0, - SpongeSqueeze => 10, - Add => -1, - AddI(_) => 0, - Mul => -1, - Invert => 0, - Eq => -1, - Split => 1, - Lt => -1, - And => -1, - Xor => -1, - Log2Floor => 0, - Pow => -1, - DivMod => 0, - PopCount => 0, - XxAdd => -3, - XxMul => -3, - XInvert => 0, - XbMul => -1, - ReadIo(n) => n.num_words() as i32, - WriteIo(n) => -(n.num_words() as i32), - MerkleStep => 0, - MerkleStepMem => 0, - XxDotStep => 0, - XbDotStep => 0, + AnInstruction::Pop(n) => -(n.num_words() as i32), + AnInstruction::Push(_) => 1, + AnInstruction::Divine(n) => n.num_words() as i32, + AnInstruction::Dup(_) => 1, + AnInstruction::Swap(_) => 0, + AnInstruction::Halt => 0, + AnInstruction::Nop => 0, + AnInstruction::Skiz => -1, + AnInstruction::Call(_) => 0, + AnInstruction::Return => 0, + AnInstruction::Recurse => 0, + AnInstruction::RecurseOrReturn => 0, + AnInstruction::Assert => -1, + AnInstruction::ReadMem(n) => n.num_words() as i32, + AnInstruction::WriteMem(n) => -(n.num_words() as i32), + AnInstruction::Hash => -5, + AnInstruction::AssertVector => -5, + AnInstruction::SpongeInit => 0, + AnInstruction::SpongeAbsorb => -10, + AnInstruction::SpongeAbsorbMem => 0, + AnInstruction::SpongeSqueeze => 10, + AnInstruction::Add => -1, + AnInstruction::AddI(_) => 0, + AnInstruction::Mul => -1, + AnInstruction::Invert => 0, + AnInstruction::Eq => -1, + AnInstruction::Split => 1, + AnInstruction::Lt => -1, + AnInstruction::And => -1, + AnInstruction::Xor => -1, + AnInstruction::Log2Floor => 0, + AnInstruction::Pow => -1, + AnInstruction::DivMod => 0, + AnInstruction::PopCount => 0, + AnInstruction::XxAdd => -3, + AnInstruction::XxMul => -3, + AnInstruction::XInvert => 0, + AnInstruction::XbMul => -1, + AnInstruction::ReadIo(n) => n.num_words() as i32, + AnInstruction::WriteIo(n) => -(n.num_words() as i32), + AnInstruction::MerkleStep => 0, + AnInstruction::MerkleStepMem => 0, + AnInstruction::XxDotStep => 0, + AnInstruction::XbDotStep => 0, } } @@ -453,16 +504,16 @@ impl AnInstruction { pub fn is_u32_instruction(&self) -> bool { matches!( self, - Split - | Lt - | And - | Xor - | Log2Floor - | Pow - | DivMod - | PopCount - | MerkleStep - | MerkleStepMem + AnInstruction::Split + | AnInstruction::Lt + | AnInstruction::And + | AnInstruction::Xor + | AnInstruction::Log2Floor + | AnInstruction::Pow + | AnInstruction::DivMod + | AnInstruction::PopCount + | AnInstruction::MerkleStep + | AnInstruction::MerkleStepMem ) } } @@ -471,13 +522,17 @@ impl Display for AnInstruction { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{}", self.name())?; match self { - Push(arg) => write!(f, " {arg}"), - Pop(arg) | Divine(arg) => write!(f, " {arg}"), - Dup(arg) | Swap(arg) => write!(f, " {arg}"), - Call(arg) => write!(f, " {arg}"), - ReadMem(arg) | WriteMem(arg) => write!(f, " {arg}"), - AddI(arg) => write!(f, " {arg}"), - ReadIo(arg) | WriteIo(arg) => write!(f, " {arg}"), + AnInstruction::Push(arg) => write!(f, " {arg}"), + AnInstruction::Pop(arg) => write!(f, " {arg}"), + AnInstruction::Divine(arg) => write!(f, " {arg}"), + AnInstruction::Dup(arg) => write!(f, " {arg}"), + AnInstruction::Swap(arg) => write!(f, " {arg}"), + AnInstruction::Call(arg) => write!(f, " {arg}"), + AnInstruction::ReadMem(arg) => write!(f, " {arg}"), + AnInstruction::WriteMem(arg) => write!(f, " {arg}"), + AnInstruction::AddI(arg) => write!(f, " {arg}"), + AnInstruction::ReadIo(arg) => write!(f, " {arg}"), + AnInstruction::WriteIo(arg) => write!(f, " {arg}"), _ => Ok(()), } } @@ -487,12 +542,17 @@ impl Instruction { /// Get the argument of the instruction, if it has one. pub fn arg(&self) -> Option { match self { - Push(arg) | Call(arg) => Some(*arg), - Pop(arg) | Divine(arg) => Some(arg.into()), - Dup(arg) | Swap(arg) => Some(arg.into()), - ReadMem(arg) | WriteMem(arg) => Some(arg.into()), - AddI(arg) => Some(*arg), - ReadIo(arg) | WriteIo(arg) => Some(arg.into()), + AnInstruction::Push(arg) => Some(*arg), + AnInstruction::Call(arg) => Some(*arg), + AnInstruction::Pop(arg) => Some(arg.into()), + AnInstruction::Divine(arg) => Some(arg.into()), + AnInstruction::Dup(arg) => Some(arg.into()), + AnInstruction::Swap(arg) => Some(arg.into()), + AnInstruction::ReadMem(arg) => Some(arg.into()), + AnInstruction::WriteMem(arg) => Some(arg.into()), + AnInstruction::AddI(arg) => Some(*arg), + AnInstruction::ReadIo(arg) => Some(arg.into()), + AnInstruction::WriteIo(arg) => Some(arg.into()), _ => None, } } @@ -505,17 +565,17 @@ impl Instruction { let op_stack_element = new_arg.try_into().map_err(|_| illegal_argument_error); let new_instruction = match self { - Pop(_) => Pop(num_words?), - Push(_) => Push(new_arg), - Divine(_) => Divine(num_words?), - Dup(_) => Dup(op_stack_element?), - Swap(_) => Swap(op_stack_element?), - Call(_) => Call(new_arg), - ReadMem(_) => ReadMem(num_words?), - WriteMem(_) => WriteMem(num_words?), - AddI(_) => AddI(new_arg), - ReadIo(_) => ReadIo(num_words?), - WriteIo(_) => WriteIo(num_words?), + AnInstruction::Pop(_) => AnInstruction::Pop(num_words?), + AnInstruction::Push(_) => AnInstruction::Push(new_arg), + AnInstruction::Divine(_) => AnInstruction::Divine(num_words?), + AnInstruction::Dup(_) => AnInstruction::Dup(op_stack_element?), + AnInstruction::Swap(_) => AnInstruction::Swap(op_stack_element?), + AnInstruction::Call(_) => AnInstruction::Call(new_arg), + AnInstruction::ReadMem(_) => AnInstruction::ReadMem(num_words?), + AnInstruction::WriteMem(_) => AnInstruction::WriteMem(num_words?), + AnInstruction::AddI(_) => AnInstruction::AddI(new_arg), + AnInstruction::ReadIo(_) => AnInstruction::ReadIo(num_words?), + AnInstruction::WriteIo(_) => AnInstruction::WriteIo(num_words?), _ => return Err(illegal_argument_error), }; @@ -561,67 +621,6 @@ impl TryFrom for Instruction { } } -/// A list of all instructions with default arguments, if any. -const fn all_instructions_with_default_args() -> [AnInstruction; Instruction::COUNT] -{ - [ - Pop(N1), - Push(BFieldElement::ZERO), - Divine(N1), - Dup(ST0), - Swap(ST0), - Halt, - Nop, - Skiz, - Call(BFieldElement::ZERO), - Return, - Recurse, - RecurseOrReturn, - Assert, - ReadMem(N1), - WriteMem(N1), - Hash, - AssertVector, - SpongeInit, - SpongeAbsorb, - SpongeAbsorbMem, - SpongeSqueeze, - Add, - AddI(BFieldElement::ZERO), - Mul, - Invert, - Eq, - Split, - Lt, - And, - Xor, - Log2Floor, - Pow, - DivMod, - PopCount, - XxAdd, - XxMul, - XInvert, - XbMul, - ReadIo(N1), - WriteIo(N1), - MerkleStep, - MerkleStepMem, - XxDotStep, - XbDotStep, - ] -} - -const fn all_instruction_names() -> [&'static str; Instruction::COUNT] { - let mut names = [""; Instruction::COUNT]; - let mut i = 0; - while i < Instruction::COUNT { - names[i] = ALL_INSTRUCTIONS[i].name(); - i += 1; - } - names -} - /// Indicators for all the possible bits in an [`Instruction`]. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, EnumCount, EnumIter)] pub enum InstructionBit { @@ -645,13 +644,13 @@ impl Display for InstructionBit { impl From for usize { fn from(instruction_bit: InstructionBit) -> Self { match instruction_bit { - IB0 => 0, - IB1 => 1, - IB2 => 2, - IB3 => 3, - IB4 => 4, - IB5 => 5, - IB6 => 6, + InstructionBit::IB0 => 0, + InstructionBit::IB1 => 1, + InstructionBit::IB2 => 2, + InstructionBit::IB3 => 3, + InstructionBit::IB4 => 4, + InstructionBit::IB5 => 5, + InstructionBit::IB6 => 6, } } } @@ -661,13 +660,13 @@ impl TryFrom for InstructionBit { fn try_from(bit_index: usize) -> result::Result { match bit_index { - 0 => Ok(IB0), - 1 => Ok(IB1), - 2 => Ok(IB2), - 3 => Ok(IB3), - 4 => Ok(IB4), - 5 => Ok(IB5), - 6 => Ok(IB6), + 0 => Ok(InstructionBit::IB0), + 1 => Ok(InstructionBit::IB1), + 2 => Ok(InstructionBit::IB2), + 3 => Ok(InstructionBit::IB3), + 4 => Ok(InstructionBit::IB4), + 5 => Ok(InstructionBit::IB5), + 6 => Ok(InstructionBit::IB6), _ => Err(format!( "Index {bit_index} is out of range for `InstructionBit`." )), @@ -771,12 +770,63 @@ impl<'a> Arbitrary<'a> for TypeHintTypeName { } } +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum InstructionError { + #[error("opcode {0} is invalid")] + InvalidOpcode(u32), + + #[error("opcode is out of range: {0}")] + OutOfRangeOpcode(#[from] TryFromIntError), + + #[error("invalid argument {1} for instruction `{0}`")] + IllegalArgument(Instruction, BFieldElement), + + #[error("instruction pointer points outside of program")] + InstructionPointerOverflow, + + #[error("jump stack is empty")] + JumpStackIsEmpty, + + #[error("assertion failed: st0 must be 1")] + AssertionFailed, + + #[error("vector assertion failed: stack[{0}] != stack[{}]", .0 + Digest::LEN)] + VectorAssertionFailed(usize), + + #[error("0 does not have a multiplicative inverse")] + InverseOfZero, + + #[error("division by 0 is impossible")] + DivisionByZero, + + #[error("the Sponge state must be initialized before it can be used")] + SpongeNotInitialized, + + #[error("the logarithm of 0 does not exist")] + LogarithmOfZero, + + #[error("public input buffer is empty after {0} reads")] + EmptyPublicInput(usize), + + #[error("secret input buffer is empty after {0} reads")] + EmptySecretInput(usize), + + #[error("no more secret digests available")] + EmptySecretDigestInput, + + #[error("Triton VM has halted and cannot execute any further instructions")] + MachineHalted, + + #[error(transparent)] + OpStackError(#[from] OpStackError), +} + #[cfg(test)] pub mod tests { use std::collections::HashMap; use assert2::assert; - use assert2::let_assert; use itertools::Itertools; use num_traits::One; use num_traits::Zero; @@ -787,15 +837,10 @@ pub mod tests { use strum::VariantNames; use twenty_first::prelude::*; - use crate::instruction::*; - use crate::op_stack::NUM_OP_STACK_REGISTERS; - use crate::program::PublicInput; use crate::triton_asm; use crate::triton_program; - use crate::vm::tests::test_program_for_call_recurse_return; - use crate::vm::VMState; - use crate::NonDeterminism; - use crate::Program; + + use super::*; #[derive(Debug, Copy, Clone, EnumCount, EnumIter, VariantNames)] pub enum InstructionBucket { @@ -883,7 +928,12 @@ pub mod tests { fn parse_push_pop() { let program = triton_program!(push 1 push 1 add pop 2); let instructions = program.into_iter().collect_vec(); - let expected = vec![Push(bfe!(1)), Push(bfe!(1)), Add, Pop(N2)]; + let expected = vec![ + Instruction::Push(bfe!(1)), + Instruction::Push(bfe!(1)), + Instruction::Add, + Instruction::Pop(NumberOfWords::N2), + ]; assert!(expected == instructions); } @@ -922,10 +972,8 @@ pub mod tests { /// is guaranteed at compile time, this test ensures the absence of repetitions. #[test] fn list_of_all_instructions_contains_unique_instructions() { - assert!(all_instructions_with_default_args() - .into_iter() - .all_unique()); - assert!(all_instruction_names().into_iter().all_unique()); + assert!(ALL_INSTRUCTIONS.into_iter().all_unique()); + assert!(ALL_INSTRUCTION_NAMES.into_iter().all_unique()); } #[test] @@ -938,14 +986,26 @@ pub mod tests { #[test] fn change_arguments_of_various_instructions() { - assert!(Push(bfe!(0)).change_arg(bfe!(7)).is_ok()); - assert!(Dup(ST0).change_arg(bfe!(1024)).is_err()); - assert!(Swap(ST0).change_arg(bfe!(1337)).is_err()); - assert!(Swap(ST0).change_arg(bfe!(0)).is_ok()); - assert!(Swap(ST0).change_arg(bfe!(1)).is_ok()); - assert!(Pop(N4).change_arg(bfe!(0)).is_err()); - assert!(Pop(N1).change_arg(bfe!(2)).is_ok()); - assert!(Nop.change_arg(bfe!(7)).is_err()); + assert!(Instruction::Push(bfe!(0)).change_arg(bfe!(7)).is_ok()); + assert!(Instruction::Dup(OpStackElement::ST0) + .change_arg(bfe!(1024)) + .is_err()); + assert!(Instruction::Swap(OpStackElement::ST0) + .change_arg(bfe!(1337)) + .is_err()); + assert!(Instruction::Swap(OpStackElement::ST0) + .change_arg(bfe!(0)) + .is_ok()); + assert!(Instruction::Swap(OpStackElement::ST0) + .change_arg(bfe!(1)) + .is_ok()); + assert!(Instruction::Pop(NumberOfWords::N4) + .change_arg(bfe!(0)) + .is_err()); + assert!(Instruction::Pop(NumberOfWords::N1) + .change_arg(bfe!(2)) + .is_ok()); + assert!(Instruction::Nop.change_arg(bfe!(7)).is_err()); } #[test] @@ -953,7 +1013,10 @@ pub mod tests { println!("instruction_push: {:?}", Instruction::Push(bfe!(7))); println!("instruction_assert: {}", Instruction::Assert); println!("instruction_invert: {:?}", Instruction::Invert); - println!("instruction_dup: {}", Instruction::Dup(ST14)); + println!( + "instruction_dup: {}", + Instruction::Dup(OpStackElement::ST14) + ); } #[test] @@ -1022,50 +1085,9 @@ pub mod tests { println!("{code}"); } - #[test] - fn instructions_act_on_op_stack_as_indicated() { - for test_instruction in all_instructions_with_default_args() { - let (program, stack_size_before_test_instruction) = - construct_test_program_for_instruction(test_instruction); - let public_input = PublicInput::from(bfe_array![0]); - let mock_digests = [Digest::default()]; - let non_determinism = NonDeterminism::from(bfe_array![0]).with_digests(mock_digests); - - let mut vm_state = VMState::new(&program, public_input, non_determinism); - let_assert!(Ok(()) = vm_state.run()); - let stack_size_after_test_instruction = vm_state.op_stack.len(); - - let stack_size_difference = (stack_size_after_test_instruction as i32) - - (stack_size_before_test_instruction as i32); - assert!( - test_instruction.op_stack_size_influence() == stack_size_difference, - "{test_instruction}" - ); - } - } - - fn construct_test_program_for_instruction( - instruction: AnInstruction, - ) -> (Program, usize) { - if matches!(instruction, Call(_) | Return | Recurse | RecurseOrReturn) { - // need jump stack setup - let program = test_program_for_call_recurse_return().program; - let stack_size = NUM_OP_STACK_REGISTERS; - (program, stack_size) - } else { - let num_push_instructions = 10; - let push_instructions = triton_asm![push 1; num_push_instructions]; - let program = triton_program!(sponge_init {&push_instructions} {instruction} nop halt); - - let stack_size_when_reaching_test_instruction = - NUM_OP_STACK_REGISTERS + num_push_instructions; - (program, stack_size_when_reaching_test_instruction) - } - } - #[test] fn labelled_instructions_act_on_op_stack_as_indicated() { - for instruction in all_instructions_with_default_args() { + for instruction in ALL_INSTRUCTIONS { let labelled_instruction = instruction.map_call_address(|_| "dummy_label".to_string()); let labelled_instruction = LabelledInstruction::Instruction(labelled_instruction); @@ -1084,14 +1106,11 @@ pub mod tests { #[test] fn can_change_arg() { - for instruction in all_instructions_with_default_args() { - if let Some(arg) = instruction.arg() { - assert_ne!( - instruction, - (instruction.change_arg(arg + bfe!(1))).unwrap() - ); + for intsr in ALL_INSTRUCTIONS { + if let Some(arg) = intsr.arg() { + assert_ne!(intsr, intsr.change_arg(arg + bfe!(1)).unwrap()); } else { - assert!(instruction.change_arg(bfe!(0)).is_err()) + assert!(intsr.change_arg(bfe!(0)).is_err()) } } } diff --git a/triton-isa/src/lib.rs b/triton-isa/src/lib.rs new file mode 100644 index 000000000..dcc84569d --- /dev/null +++ b/triton-isa/src/lib.rs @@ -0,0 +1,346 @@ +pub use twenty_first; + +pub mod error; +pub mod instruction; +pub mod op_stack; +pub mod parser; +pub mod program; + +/// Compile an entire program written in [Triton assembly][tasm]. +/// Triton VM can run the resulting [`Program`](program::Program); see there for +/// details. +/// +/// It is possible to use string-like interpolation to insert instructions, arguments, labels, +/// or other substrings into the program. +/// +/// # Examples +/// +/// ``` +/// # use triton_isa::triton_program; +/// # use twenty_first::prelude::*; +/// let program = triton_program!( +/// read_io 1 push 5 mul +/// call check_eq_15 +/// push 17 write_io 1 +/// halt +/// // assert that the top of the stack is 15 +/// check_eq_15: +/// push 15 eq assert +/// return +/// ); +/// ``` +/// +/// Any type with an appropriate [`Display`](std::fmt::Display) implementation can be +/// interpolated. This includes, for example, primitive types like `u64` and `&str`, but also +/// [`Instruction`](instruction::Instruction)s, +/// [`BFieldElement`](twenty_first::prelude::BFieldElement)s, and +/// [`Label`](instruction::LabelledInstruction)s, among others. +/// +/// ``` +/// # use twenty_first::prelude::*; +/// # use triton_isa::triton_program; +/// # use triton_isa::instruction::Instruction; +/// let element_0 = BFieldElement::new(0); +/// let label = "my_label"; +/// let instruction_push = Instruction::Push(bfe!(42)); +/// let dup_arg = 1; +/// let program = triton_program!( +/// push {element_0} +/// call {label} halt +/// {label}: +/// {instruction_push} +/// dup {dup_arg} +/// skiz recurse return +/// ); +/// ``` +/// +/// # Panics +/// +/// **Panics** if the program cannot be parsed. +/// Examples for parsing errors are: +/// - unknown (_e.g._ misspelled) instructions +/// - invalid instruction arguments, _e.g._, `push 1.5` or `swap 42` +/// - missing or duplicate labels +/// - invalid labels, _e.g._, using a reserved keyword or starting a label with a digit +/// +/// For a version that returns a `Result`, see [`Program::from_code()`][from_code]. +/// +/// [tasm]: https://triton-vm.org/spec/instructions.html +/// [from_code]: program::Program::from_code +#[macro_export] +macro_rules! triton_program { + {$($source_code:tt)*} => {{ + let labelled_instructions = $crate::triton_asm!($($source_code)*); + $crate::program::Program::new(&labelled_instructions) + }}; +} + +/// Compile [Triton assembly][tasm] into a list of labelled +/// [`Instruction`](instruction::LabelledInstruction)s. +/// Similar to [`triton_program!`](triton_program), it is possible to use string-like +/// interpolation to insert instructions, arguments, labels, or other expressions. +/// +/// Similar to [`vec!`], a single instruction can be repeated a specified number of times. +/// +/// Furthermore, a list of [`LabelledInstruction`](instruction::LabelledInstruction)s +/// can be inserted like so: `{&list}`. +/// +/// The labels for instruction `call`, if any, are also parsed. Instruction `call` can refer to +/// a label defined later in the program, _i.e.,_ labels are not checked for existence or +/// uniqueness by this parser. +/// +/// # Examples +/// +/// ``` +/// # use triton_isa::triton_asm; +/// let push_argument = 42; +/// let instructions = triton_asm!( +/// push 1 call some_label +/// push {push_argument} +/// some_other_label: skiz halt return +/// ); +/// assert_eq!(7, instructions.len()); +/// ``` +/// +/// One instruction repeated several times: +/// +/// ``` +/// # use triton_isa::triton_asm; +/// # use triton_isa::instruction::LabelledInstruction; +/// # use triton_isa::instruction::AnInstruction::SpongeAbsorb; +/// let instructions = triton_asm![sponge_absorb; 3]; +/// assert_eq!(3, instructions.len()); +/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[0]); +/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[1]); +/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[2]); +/// ``` +/// +/// Inserting substring of labelled instructions: +/// +/// ``` +/// # use triton_isa::instruction::AnInstruction::Push; +/// # use triton_isa::instruction::AnInstruction::Pop; +/// # use triton_isa::op_stack::NumberOfWords::N1; +/// # use triton_isa::instruction::LabelledInstruction; +/// # use triton_isa::triton_asm; +/// # use twenty_first::prelude::*; +/// let insert_me = triton_asm!( +/// pop 1 +/// nop +/// pop 1 +/// ); +/// let surrounding_code = triton_asm!( +/// push 0 +/// {&insert_me} +/// push 1 +/// ); +/// # let zero = bfe!(0); +/// # assert_eq!(LabelledInstruction::Instruction(Push(zero)), surrounding_code[0]); +/// assert_eq!(LabelledInstruction::Instruction(Pop(N1)), surrounding_code[1]); +/// assert_eq!(LabelledInstruction::Instruction(Pop(N1)), surrounding_code[3]); +/// # let one = bfe!(1); +/// # assert_eq!(LabelledInstruction::Instruction(Push(one)), surrounding_code[4]); +///``` +/// +/// # Panics +/// +/// **Panics** if the instructions cannot be parsed. +/// For examples, see [`triton_program!`](triton_program), with the exception that +/// labels are not checked for existence or uniqueness. +/// +/// [tasm]: https://triton-vm.org/spec/instructions.html +#[macro_export] +macro_rules! triton_asm { + (@fmt $fmt:expr, $($args:expr,)*; ) => { + format_args!($fmt $(,$args)*).to_string() + }; + (@fmt $fmt:expr, $($args:expr,)*; + hint $var:ident: $ty:ident = stack[$start:literal..$end:literal] $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " hint {}: {} = stack[{}..{}] "), + $($args,)* stringify!($var), stringify!($ty), $start, $end,; + $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; + hint $var:ident = stack[$start:literal..$end:literal] $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " hint {} = stack[{}..{}] "), + $($args,)* stringify!($var), $start, $end,; + $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; + hint $var:ident: $ty:ident = stack[$index:literal] $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " hint {}: {} = stack[{}] "), + $($args,)* stringify!($var), stringify!($ty), $index,; + $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; + hint $var:ident = stack[$index:literal] $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " hint {} = stack[{}] "), + $($args,)* stringify!($var), $index,; + $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; $label_declaration:ident: $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " ", stringify!($label_declaration), ": "), $($args,)*; $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; $instruction:ident $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " ", stringify!($instruction), " "), $($args,)*; $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; $instruction_argument:literal $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, " ", stringify!($instruction_argument), " "), $($args,)*; $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; {$label_declaration:expr}: $($tail:tt)*) => { + $crate::triton_asm!(@fmt concat!($fmt, "{}: "), $($args,)* $label_declaration,; $($tail)*) + }; + (@fmt $fmt:expr, $($args:expr,)*; {&$instruction_list:expr} $($tail:tt)*) => { + $crate::triton_asm!(@fmt + concat!($fmt, "{} "), $($args,)* + $instruction_list.iter().map(|instr| instr.to_string()).collect::>().join(" "),; + $($tail)* + ) + }; + (@fmt $fmt:expr, $($args:expr,)*; {$expression:expr} $($tail:tt)*) => { + $crate::triton_asm!(@fmt concat!($fmt, "{} "), $($args,)* $expression,; $($tail)*) + }; + + // repeated instructions + [pop $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(pop $arg); $num ] }; + [push $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(push $arg); $num ] }; + [divine $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(divine $arg); $num ] }; + [dup $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(dup $arg); $num ] }; + [swap $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(swap $arg); $num ] }; + [call $arg:ident; $num:expr] => { vec![ $crate::triton_instr!(call $arg); $num ] }; + [read_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_mem $arg); $num ] }; + [write_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_mem $arg); $num ] }; + [read_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_io $arg); $num ] }; + [write_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_io $arg); $num ] }; + [$instr:ident; $num:expr] => { vec![ $crate::triton_instr!($instr); $num ] }; + + // entry point + {$($source_code:tt)*} => {{ + let source_code = $crate::triton_asm!(@fmt "",; $($source_code)*); + let (_, instructions) = $crate::parser::tokenize(&source_code).unwrap(); + $crate::parser::to_labelled_instructions(&instructions) + }}; +} + +/// Compile a single [Triton assembly][tasm] instruction into a +/// [`LabelledInstruction`](instruction::LabelledInstruction). +/// +/// # Examples +/// +/// ``` +/// # use triton_isa::triton_instr; +/// # use triton_isa::instruction::LabelledInstruction; +/// # use triton_isa::instruction::AnInstruction::Call; +/// let instruction = triton_instr!(call my_label); +/// assert_eq!(LabelledInstruction::Instruction(Call("my_label".to_string())), instruction); +/// ``` +/// +/// [tasm]: https://triton-vm.org/spec/instructions.html +#[macro_export] +macro_rules! triton_instr { + (pop $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::Pop(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (push $arg:expr) => {{ + let argument = $crate::twenty_first::prelude::BFieldElement::from($arg); + let instruction = $crate::instruction::AnInstruction::::Push(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (divine $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::Divine(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (dup $arg:literal) => {{ + let argument = $crate::op_stack::OpStackElement::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::Dup(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (swap $arg:literal) => {{ + let argument = $crate::op_stack::OpStackElement::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::Swap(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (call $arg:ident) => {{ + let argument = stringify!($arg).to_string(); + let instruction = $crate::instruction::AnInstruction::::Call(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (read_mem $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::ReadMem(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (write_mem $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::WriteMem(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (addi $arg:expr) => {{ + let argument = $crate::twenty_first::prelude::BFieldElement::from($arg); + let instruction = $crate::instruction::AnInstruction::::AddI(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (read_io $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::ReadIo(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (write_io $arg:literal) => {{ + let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::WriteIo(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + ($instr:ident) => {{ + let (_, instructions) = $crate::parser::tokenize(stringify!($instr)).unwrap(); + instructions[0].to_labelled_instruction() + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn public_types_implement_usual_auto_traits() { + fn implements_auto_traits() {} + + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + + implements_auto_traits::(); + implements_auto_traits::>(); + implements_auto_traits::(); + implements_auto_traits::(); + + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + + implements_auto_traits::(); + + implements_auto_traits::(); + implements_auto_traits::(); + } +} diff --git a/triton-vm/src/op_stack.rs b/triton-isa/src/op_stack.rs similarity index 81% rename from triton-vm/src/op_stack.rs rename to triton-isa/src/op_stack.rs index 4a4e033e5..6c349e441 100644 --- a/triton-vm/src/op_stack.rs +++ b/triton-isa/src/op_stack.rs @@ -1,6 +1,7 @@ use std::fmt::Display; use std::fmt::Formatter; use std::fmt::Result as FmtResult; +use std::num::TryFromIntError; use std::ops::Index; use std::ops::IndexMut; @@ -12,13 +13,10 @@ use serde::Serialize; use strum::EnumCount; use strum::EnumIter; use strum::IntoEnumIterator; +use thiserror::Error; use twenty_first::prelude::*; -use crate::error::InstructionError::*; -use crate::error::*; -use crate::op_stack::OpStackElement::*; - -type Result = std::result::Result; +type Result = std::result::Result; type OpStackElementResult = std::result::Result; type NumWordsResult = std::result::Result; @@ -48,6 +46,16 @@ pub struct OpStack { underflow_io_sequence: Vec, } +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum OpStackError { + #[error("operational stack is too shallow")] + TooShallow, + + #[error("failed to convert BFieldElement {0} into u32")] + FailedU32Conversion(BFieldElement), +} + impl OpStack { pub fn new(program_digest: Digest) -> Self { let mut stack = bfe_vec![0; OpStackElement::COUNT]; @@ -65,14 +73,14 @@ impl OpStack { self.stack.len() } - pub(crate) fn push(&mut self, element: BFieldElement) { + pub fn push(&mut self, element: BFieldElement) { self.stack.push(element); self.record_underflow_io(UnderflowIO::Write); } - pub(crate) fn pop(&mut self) -> Result { + pub fn pop(&mut self) -> Result { self.record_underflow_io(UnderflowIO::Read); - self.stack.pop().ok_or(OpStackTooShallow) + self.stack.pop().ok_or(OpStackError::TooShallow) } fn record_underflow_io(&mut self, io_type: fn(BFieldElement) -> UnderflowIO) { @@ -80,40 +88,44 @@ impl OpStack { self.underflow_io_sequence.push(underflow_io); } - pub(crate) fn start_recording_underflow_io_sequence(&mut self) { + pub fn start_recording_underflow_io_sequence(&mut self) { self.underflow_io_sequence.clear(); } - pub(crate) fn stop_recording_underflow_io_sequence(&mut self) -> Vec { + pub fn stop_recording_underflow_io_sequence(&mut self) -> Vec { self.underflow_io_sequence.drain(..).collect() } - pub(crate) fn push_extension_field_element(&mut self, element: XFieldElement) { + pub fn push_extension_field_element(&mut self, element: XFieldElement) { for coefficient in element.coefficients.into_iter().rev() { self.push(coefficient); } } - pub(crate) fn pop_extension_field_element(&mut self) -> Result { + pub fn pop_extension_field_element(&mut self) -> Result { let coefficients = self.pop_multiple()?; Ok(xfe!(coefficients)) } - pub(crate) fn is_u32(&self, stack_element: OpStackElement) -> Result<()> { + pub fn is_u32(&self, stack_element: OpStackElement) -> Result<()> { self.get_u32(stack_element).map(|_| ()) } - pub(crate) fn get_u32(&self, stack_element: OpStackElement) -> Result { + pub fn get_u32(&self, stack_element: OpStackElement) -> Result { let element = self[stack_element]; - element.try_into().map_err(|_| FailedU32Conversion(element)) + element + .try_into() + .map_err(|_| OpStackError::FailedU32Conversion(element)) } - pub(crate) fn pop_u32(&mut self) -> Result { + pub fn pop_u32(&mut self) -> Result { let element = self.pop()?; - element.try_into().map_err(|_| FailedU32Conversion(element)) + element + .try_into() + .map_err(|_| OpStackError::FailedU32Conversion(element)) } - pub(crate) fn pop_multiple(&mut self) -> Result<[BFieldElement; N]> { + pub fn pop_multiple(&mut self) -> Result<[BFieldElement; N]> { let mut elements = bfe_array![0; N]; for element in &mut elements { *element = self.pop()?; @@ -121,18 +133,18 @@ impl OpStack { Ok(elements) } - pub(crate) fn peek_at_top_extension_field_element(&self) -> XFieldElement { + pub fn peek_at_top_extension_field_element(&self) -> XFieldElement { xfe!([self[0], self[1], self[2]]) } - pub(crate) fn would_be_too_shallow(&self, stack_delta: i32) -> bool { + pub fn would_be_too_shallow(&self, stack_delta: i32) -> bool { self.len() as i32 + stack_delta < OpStackElement::COUNT as i32 } /// The address of the next free address of the op-stack. Equivalent to the current length of /// the op-stack. - pub(crate) fn pointer(&self) -> BFieldElement { - (self.len() as u64).into() + pub fn pointer(&self) -> BFieldElement { + u64::try_from(self.len()).unwrap().into() } /// The first element of the op-stack underflow memory, or 0 if the op-stack underflow memory @@ -309,25 +321,35 @@ pub enum OpStackElement { ST15, } +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum OpStackElementError { + #[error("index {0} is out of range for `OpStackElement`")] + IndexOutOfBounds(u32), + + #[error(transparent)] + FailedIntegerConversion(#[from] TryFromIntError), +} + impl OpStackElement { pub const fn index(self) -> u32 { match self { - ST0 => 0, - ST1 => 1, - ST2 => 2, - ST3 => 3, - ST4 => 4, - ST5 => 5, - ST6 => 6, - ST7 => 7, - ST8 => 8, - ST9 => 9, - ST10 => 10, - ST11 => 11, - ST12 => 12, - ST13 => 13, - ST14 => 14, - ST15 => 15, + OpStackElement::ST0 => 0, + OpStackElement::ST1 => 1, + OpStackElement::ST2 => 2, + OpStackElement::ST3 => 3, + OpStackElement::ST4 => 4, + OpStackElement::ST5 => 5, + OpStackElement::ST6 => 6, + OpStackElement::ST7 => 7, + OpStackElement::ST8 => 8, + OpStackElement::ST9 => 9, + OpStackElement::ST10 => 10, + OpStackElement::ST11 => 11, + OpStackElement::ST12 => 12, + OpStackElement::ST13 => 13, + OpStackElement::ST14 => 14, + OpStackElement::ST15 => 15, } } } @@ -356,22 +378,22 @@ impl TryFrom for OpStackElement { fn try_from(stack_index: u32) -> OpStackElementResult { match stack_index { - 0 => Ok(ST0), - 1 => Ok(ST1), - 2 => Ok(ST2), - 3 => Ok(ST3), - 4 => Ok(ST4), - 5 => Ok(ST5), - 6 => Ok(ST6), - 7 => Ok(ST7), - 8 => Ok(ST8), - 9 => Ok(ST9), - 10 => Ok(ST10), - 11 => Ok(ST11), - 12 => Ok(ST12), - 13 => Ok(ST13), - 14 => Ok(ST14), - 15 => Ok(ST15), + 0 => Ok(OpStackElement::ST0), + 1 => Ok(OpStackElement::ST1), + 2 => Ok(OpStackElement::ST2), + 3 => Ok(OpStackElement::ST3), + 4 => Ok(OpStackElement::ST4), + 5 => Ok(OpStackElement::ST5), + 6 => Ok(OpStackElement::ST6), + 7 => Ok(OpStackElement::ST7), + 8 => Ok(OpStackElement::ST8), + 9 => Ok(OpStackElement::ST9), + 10 => Ok(OpStackElement::ST10), + 11 => Ok(OpStackElement::ST11), + 12 => Ok(OpStackElement::ST12), + 13 => Ok(OpStackElement::ST13), + 14 => Ok(OpStackElement::ST14), + 15 => Ok(OpStackElement::ST15), _ => Err(Self::Error::IndexOutOfBounds(stack_index)), } } @@ -478,6 +500,16 @@ pub enum NumberOfWords { N5, } +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum NumberOfWordsError { + #[error("index {0} is out of range for `NumberOfWords`")] + IndexOutOfBounds(usize), + + #[error(transparent)] + FailedIntegerConversion(#[from] TryFromIntError), +} + impl NumberOfWords { pub const fn num_words(self) -> usize { match self { @@ -489,12 +521,12 @@ impl NumberOfWords { } } - pub(crate) fn legal_values() -> [usize; Self::COUNT] { + pub fn legal_values() -> [usize; Self::COUNT] { let legal_indices = Self::iter().map(|n| n.num_words()).collect_vec(); legal_indices.try_into().unwrap() } - pub(crate) fn illegal_values() -> [usize; OpStackElement::COUNT - Self::COUNT] { + pub fn illegal_values() -> [usize; OpStackElement::COUNT - Self::COUNT] { let all_values = OpStackElement::iter().map(|st| st.index() as usize); let illegal_values = all_values .filter(|i| !Self::legal_values().contains(i)) @@ -642,8 +674,6 @@ mod tests { use strum::IntoEnumIterator; use test_strategy::proptest; - use crate::op_stack::NumberOfWords::N1; - use super::*; impl Default for OpStack { @@ -671,22 +701,22 @@ mod tests { assert!(op_stack.pointer().value() as usize == op_stack.len()); assert!([ - op_stack[ST0], - op_stack[ST1], - op_stack[ST2], - op_stack[ST3], - op_stack[ST4], - op_stack[ST5], - op_stack[ST6], - op_stack[ST7], - op_stack[ST8], - op_stack[ST9], - op_stack[ST10], - op_stack[ST11], - op_stack[ST12], - op_stack[ST13], - op_stack[ST14], - op_stack[ST15], + op_stack[OpStackElement::ST0], + op_stack[OpStackElement::ST1], + op_stack[OpStackElement::ST2], + op_stack[OpStackElement::ST3], + op_stack[OpStackElement::ST4], + op_stack[OpStackElement::ST5], + op_stack[OpStackElement::ST6], + op_stack[OpStackElement::ST7], + op_stack[OpStackElement::ST8], + op_stack[OpStackElement::ST9], + op_stack[OpStackElement::ST10], + op_stack[OpStackElement::ST11], + op_stack[OpStackElement::ST12], + op_stack[OpStackElement::ST13], + op_stack[OpStackElement::ST14], + op_stack[OpStackElement::ST15], op_stack.first_underflow_element(), ] .into_iter() @@ -719,7 +749,7 @@ mod tests { #[filter(#op_stack.len() > 0)] op_stack: OpStack, ) { - let top_element = op_stack[ST0]; + let top_element = op_stack[OpStackElement::ST0]; let mut iterator = op_stack.into_iter(); assert!(top_element == iterator.next().unwrap()); } @@ -854,40 +884,42 @@ mod tests { assert!(let Ok(_) = NumberOfWords::try_from(1_u64)); assert!(let Ok(_) = NumberOfWords::try_from(1_usize)); assert!(let Ok(_) = NumberOfWords::try_from(bfe!(1))); - assert!(let Ok(_) = NumberOfWords::try_from(ST1)); + assert!(let Ok(_) = NumberOfWords::try_from(OpStackElement::ST1)); } #[test] fn convert_from_op_stack_element_to_various_primitive_types() { - let _ = u32::from(ST0); - let _ = u64::from(ST0); - let _ = usize::from(ST0); - let _ = i32::from(ST0); - let _ = BFieldElement::from(ST0); - let _ = bfe!(ST0); + let _ = u32::from(OpStackElement::ST0); + let _ = u64::from(OpStackElement::ST0); + let _ = usize::from(OpStackElement::ST0); + let _ = i32::from(OpStackElement::ST0); + let _ = BFieldElement::from(OpStackElement::ST0); + let _ = bfe!(OpStackElement::ST0); - let _ = u32::from(&ST0); - let _ = usize::from(&ST0); - let _ = i32::from(&ST0); - let _ = BFieldElement::from(&ST0); - let _ = bfe!(&ST0); + let _ = u32::from(&OpStackElement::ST0); + let _ = usize::from(&OpStackElement::ST0); + let _ = i32::from(&OpStackElement::ST0); + let _ = BFieldElement::from(&OpStackElement::ST0); + let _ = bfe!(&OpStackElement::ST0); } #[test] fn convert_from_number_of_words_to_various_primitive_types() { - let _ = u32::from(N1); - let _ = u64::from(N1); - let _ = usize::from(N1); - let _ = BFieldElement::from(N1); - let _ = OpStackElement::from(N1); - let _ = bfe!(N1); - - let _ = u32::from(&N1); - let _ = u64::from(&N1); - let _ = usize::from(&N1); - let _ = BFieldElement::from(&N1); - let _ = OpStackElement::from(&N1); - let _ = bfe!(&N1); + let n1 = NumberOfWords::N1; + + let _ = u32::from(n1); + let _ = u64::from(n1); + let _ = usize::from(n1); + let _ = BFieldElement::from(n1); + let _ = OpStackElement::from(n1); + let _ = bfe!(n1); + + let _ = u32::from(&n1); + let _ = u64::from(&n1); + let _ = usize::from(&n1); + let _ = BFieldElement::from(&n1); + let _ = OpStackElement::from(&n1); + let _ = bfe!(&n1); } #[proptest] diff --git a/triton-vm/src/parser.rs b/triton-isa/src/parser.rs similarity index 70% rename from triton-vm/src/parser.rs rename to triton-isa/src/parser.rs index 5975c438e..f9748fdd8 100644 --- a/triton-vm/src/parser.rs +++ b/triton-isa/src/parser.rs @@ -1,3 +1,4 @@ +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::collections::HashSet; use std::error::Error; @@ -13,16 +14,16 @@ use nom::error::*; use nom::multi::*; use nom::Finish; use nom::IResult; +use twenty_first::bfe; use twenty_first::prelude::BFieldElement; -use crate::instruction::AnInstruction::*; +use crate::instruction::AnInstruction; +use crate::instruction::Instruction; use crate::instruction::LabelledInstruction; +use crate::instruction::TypeHint; use crate::instruction::ALL_INSTRUCTION_NAMES; -use crate::instruction::*; use crate::op_stack::NumberOfWords; -use crate::op_stack::NumberOfWords::*; use crate::op_stack::OpStackElement; -use crate::op_stack::OpStackElement::*; #[derive(Debug, PartialEq)] pub struct ParseError<'a> { @@ -85,7 +86,7 @@ pub fn to_labelled_instructions(instructions: &[InstructionToken]) -> Vec) -> String { +fn pretty_print_error(s: &str, mut e: VerboseError<&str>) -> String { let (_root_s, root_error) = e.errors[0].clone(); if matches!( root_error, @@ -97,7 +98,7 @@ pub fn pretty_print_error(s: &str, mut e: VerboseError<&str>) -> String { } /// Parse a program -pub fn parse(input: &str) -> Result, ParseError> { +pub(crate) fn parse(input: &str) -> Result, ParseError> { let instructions = match tokenize(input).finish() { Ok((_, instructions)) => Ok(instructions), Err(errors) => Err(ParseError { input, errors }), @@ -142,7 +143,7 @@ fn identify_missing_labels<'a>( ) -> HashSet> { let mut missing_labels = HashSet::default(); for instruction in instructions { - if let InstructionToken::Instruction(Call(label), _) = instruction { + if let InstructionToken::Instruction(AnInstruction::Call(label), _) = instruction { if !seen_labels.contains_key(label.as_str()) { missing_labels.insert(instruction.to_owned()); } @@ -225,14 +226,14 @@ fn an_instruction(s: &str) -> ParseResult> { let opstack_manipulation = alt((pop, push, divine, dup, swap)); // Control flow - let halt = instruction("halt", Halt); - let nop = instruction("nop", Nop); - let skiz = instruction("skiz", Skiz); + let halt = instruction("halt", AnInstruction::Halt); + let nop = instruction("nop", AnInstruction::Nop); + let skiz = instruction("skiz", AnInstruction::Skiz); let call = call_instruction(); - let return_ = instruction("return", Return); - let recurse = instruction("recurse", Recurse); - let recurse_or_return = instruction("recurse_or_return", RecurseOrReturn); - let assert = instruction("assert", Assert); + let return_ = instruction("return", AnInstruction::Return); + let recurse = instruction("recurse", AnInstruction::Recurse); + let recurse_or_return = instruction("recurse_or_return", AnInstruction::RecurseOrReturn); + let assert = instruction("assert", AnInstruction::Assert); let control_flow = alt((nop, skiz, call, return_, halt)); @@ -243,33 +244,33 @@ fn an_instruction(s: &str) -> ParseResult> { let memory_access = alt((read_mem, write_mem)); // Hashing-related instructions - let hash = instruction("hash", Hash); - let assert_vector = instruction("assert_vector", AssertVector); - let sponge_init = instruction("sponge_init", SpongeInit); - let sponge_absorb = instruction("sponge_absorb", SpongeAbsorb); - let sponge_absorb_mem = instruction("sponge_absorb_mem", SpongeAbsorbMem); - let sponge_squeeze = instruction("sponge_squeeze", SpongeSqueeze); + let hash = instruction("hash", AnInstruction::Hash); + let assert_vector = instruction("assert_vector", AnInstruction::AssertVector); + let sponge_init = instruction("sponge_init", AnInstruction::SpongeInit); + let sponge_absorb = instruction("sponge_absorb", AnInstruction::SpongeAbsorb); + let sponge_absorb_mem = instruction("sponge_absorb_mem", AnInstruction::SpongeAbsorbMem); + let sponge_squeeze = instruction("sponge_squeeze", AnInstruction::SpongeSqueeze); let hashing_related = alt((hash, sponge_init, sponge_squeeze)); // Arithmetic on stack instructions - let add = instruction("add", Add); + let add = instruction("add", AnInstruction::Add); let addi = addi_instruction(); - let mul = instruction("mul", Mul); - let invert = instruction("invert", Invert); - let eq = instruction("eq", Eq); - let split = instruction("split", Split); - let lt = instruction("lt", Lt); - let and = instruction("and", And); - let xor = instruction("xor", Xor); - let log_2_floor = instruction("log_2_floor", Log2Floor); - let pow = instruction("pow", Pow); - let div_mod = instruction("div_mod", DivMod); - let pop_count = instruction("pop_count", PopCount); - let xx_add = instruction("xx_add", XxAdd); - let xx_mul = instruction("xx_mul", XxMul); - let x_invert = instruction("x_invert", XInvert); - let xb_mul = instruction("xb_mul", XbMul); + let mul = instruction("mul", AnInstruction::Mul); + let invert = instruction("invert", AnInstruction::Invert); + let eq = instruction("eq", AnInstruction::Eq); + let split = instruction("split", AnInstruction::Split); + let lt = instruction("lt", AnInstruction::Lt); + let and = instruction("and", AnInstruction::And); + let xor = instruction("xor", AnInstruction::Xor); + let log_2_floor = instruction("log_2_floor", AnInstruction::Log2Floor); + let pow = instruction("pow", AnInstruction::Pow); + let div_mod = instruction("div_mod", AnInstruction::DivMod); + let pop_count = instruction("pop_count", AnInstruction::PopCount); + let xx_add = instruction("xx_add", AnInstruction::XxAdd); + let xx_mul = instruction("xx_mul", AnInstruction::XxMul); + let x_invert = instruction("x_invert", AnInstruction::XInvert); + let xb_mul = instruction("xb_mul", AnInstruction::XbMul); let base_field_arithmetic_on_stack = alt((mul, invert, eq)); let bitwise_arithmetic_on_stack = @@ -288,10 +289,10 @@ fn an_instruction(s: &str) -> ParseResult> { let read_write = alt((read_io, write_io)); // Many-in-One - let merkle_step = instruction("merkle_step", MerkleStep); - let merkle_step_mem = instruction("merkle_step_mem", MerkleStepMem); - let xx_dot_step = instruction("xx_dot_step", XxDotStep); - let xb_dot_step = instruction("xb_dot_step", XbDotStep); + let merkle_step = instruction("merkle_step", AnInstruction::MerkleStep); + let merkle_step_mem = instruction("merkle_step_mem", AnInstruction::MerkleStepMem); + let xx_dot_step = instruction("xx_dot_step", AnInstruction::XxDotStep); + let xb_dot_step = instruction("xb_dot_step", AnInstruction::XbDotStep); let many_to_one = alt((xx_dot_step, xb_dot_step)); @@ -342,7 +343,7 @@ fn pop_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("pop")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, Pop(arg))) + Ok((s, AnInstruction::Pop(arg))) } } @@ -350,7 +351,7 @@ fn push_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("push")(s)?; let (s, elem) = field_element(s)?; - Ok((s, Push(elem))) + Ok((s, AnInstruction::Push(elem))) } } @@ -358,7 +359,7 @@ fn addi_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("addi")(s)?; let (s, elem) = field_element(s)?; - Ok((s, AddI(elem))) + Ok((s, AnInstruction::AddI(elem))) } } @@ -366,7 +367,7 @@ fn divine_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("divine")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, Divine(arg))) + Ok((s, AnInstruction::Divine(arg))) } } @@ -374,7 +375,7 @@ fn dup_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("dup")(s)?; // require space before argument let (s, stack_register) = stack_register(s)?; - Ok((s, Dup(stack_register))) + Ok((s, AnInstruction::Dup(stack_register))) } } @@ -382,7 +383,7 @@ fn swap_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("swap")(s)?; let (s, stack_register) = stack_register(s)?; - Ok((s, Swap(stack_register))) + Ok((s, AnInstruction::Swap(stack_register))) } } @@ -413,7 +414,7 @@ fn call_instruction<'a>() -> impl Fn(&'a str) -> ParseResult impl Fn(&str) -> ParseResult> move |s: &str| { let (s, _) = token1("read_mem")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, ReadMem(arg))) + Ok((s, AnInstruction::ReadMem(arg))) } } @@ -429,7 +430,7 @@ fn write_mem_instruction() -> impl Fn(&str) -> ParseResult move |s: &str| { let (s, _) = token1("write_mem")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, WriteMem(arg))) + Ok((s, AnInstruction::WriteMem(arg))) } } @@ -437,7 +438,7 @@ fn read_io_instruction() -> impl Fn(&str) -> ParseResult> move |s: &str| { let (s, _) = token1("read_io")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, ReadIo(arg))) + Ok((s, AnInstruction::ReadIo(arg))) } } @@ -445,7 +446,7 @@ fn write_io_instruction() -> impl Fn(&str) -> ParseResult> move |s: &str| { let (s, _) = token1("write_io")(s)?; let (s, arg) = number_of_words(s)?; - Ok((s, WriteIo(arg))) + Ok((s, AnInstruction::WriteIo(arg))) } } @@ -474,22 +475,22 @@ fn field_element(s_orig: &str) -> ParseResult { fn stack_register(s: &str) -> ParseResult { let (s, n) = digit1(s)?; let stack_register = match n { - "0" => ST0, - "1" => ST1, - "2" => ST2, - "3" => ST3, - "4" => ST4, - "5" => ST5, - "6" => ST6, - "7" => ST7, - "8" => ST8, - "9" => ST9, - "10" => ST10, - "11" => ST11, - "12" => ST12, - "13" => ST13, - "14" => ST14, - "15" => ST15, + "0" => OpStackElement::ST0, + "1" => OpStackElement::ST1, + "2" => OpStackElement::ST2, + "3" => OpStackElement::ST3, + "4" => OpStackElement::ST4, + "5" => OpStackElement::ST5, + "6" => OpStackElement::ST6, + "7" => OpStackElement::ST7, + "8" => OpStackElement::ST8, + "9" => OpStackElement::ST9, + "10" => OpStackElement::ST10, + "11" => OpStackElement::ST11, + "12" => OpStackElement::ST12, + "13" => OpStackElement::ST13, + "14" => OpStackElement::ST14, + "15" => OpStackElement::ST15, _ => return context("using an out-of-bounds stack register (0-15 exist)", fail)(s), }; let (s, _) = comment_or_whitespace1(s)?; @@ -500,11 +501,11 @@ fn stack_register(s: &str) -> ParseResult { fn number_of_words(s: &str) -> ParseResult { let (s, n) = digit1(s)?; let arg = match n { - "1" => N1, - "2" => N2, - "3" => N3, - "4" => N4, - "5" => N5, + "1" => NumberOfWords::N1, + "2" => NumberOfWords::N2, + "3" => NumberOfWords::N3, + "4" => NumberOfWords::N4, + "5" => NumberOfWords::N5, _ => return context("using an out-of-bounds argument (1-5 allowed)", fail)(s), }; let (s, _) = comment_or_whitespace1(s)?; // require space after element @@ -694,6 +695,57 @@ fn parse_str_to_usize(s: &str) -> ParseResult { } } +pub(crate) fn build_label_to_address_map(program: &[LabelledInstruction]) -> HashMap { + let mut label_map = HashMap::new(); + let mut instruction_pointer = 0; + + for labelled_instruction in program { + if let LabelledInstruction::Instruction(instruction) = labelled_instruction { + instruction_pointer += instruction.size() as u64; + continue; + } + + let LabelledInstruction::Label(label) = labelled_instruction else { + continue; + }; + let Entry::Vacant(new_label_map_entry) = label_map.entry(label.clone()) else { + panic!("Duplicate label: {label}"); + }; + new_label_map_entry.insert(instruction_pointer); + } + + label_map +} + +pub(crate) fn turn_labels_into_addresses( + labelled_instructions: &[LabelledInstruction], + label_to_address: &HashMap, +) -> Vec { + fn turn_label_to_address_for_instruction( + labelled_instruction: &LabelledInstruction, + label_map: &HashMap, + ) -> Option { + let LabelledInstruction::Instruction(instruction) = labelled_instruction else { + return None; + }; + + let instruction_with_absolute_address = + instruction.map_call_address(|label| address_for_label(label, label_map)); + Some(instruction_with_absolute_address) + } + + fn address_for_label(label: &str, label_map: &HashMap) -> BFieldElement { + let maybe_address = label_map.get(label).map(|&a| bfe!(a)); + maybe_address.unwrap_or_else(|| panic!("Label not found: {label}")) + } + + labelled_instructions + .iter() + .filter_map(|inst| turn_label_to_address_for_instruction(inst, label_to_address)) + .flat_map(|inst| vec![inst; inst.size()]) + .collect() +} + #[cfg(test)] pub(crate) mod tests { use assert2::assert; @@ -709,11 +761,6 @@ pub(crate) mod tests { use twenty_first::bfe; use twenty_first::prelude::Digest; - use LabelledInstruction::Breakpoint; - use LabelledInstruction::Instruction; - use LabelledInstruction::Label; - - use crate::program::Program; use crate::triton_asm; use crate::triton_instr; use crate::triton_program; @@ -722,7 +769,7 @@ pub(crate) mod tests { struct TestCase<'a> { input: &'a str, - expected: Program, + expected: Vec, message: &'static str, } @@ -733,104 +780,121 @@ pub(crate) mod tests { message: &'static str, } - fn parse_program_prop(test_case: TestCase) { - let message = test_case.message; - let parse_result = parse(test_case.input).map_err(|err| format!("{message}:\n{err}")); - let_assert!(Ok(actual) = parse_result); + impl<'a> TestCase<'a> { + fn run(&self) { + let message = self.message; + let parse_result = parse(self.input).map_err(|err| format!("{message}:\n{err}")); + let_assert!(Ok(actual) = parse_result); - let actual_program = Program::new(&to_labelled_instructions(&actual)); - assert!(test_case.expected == actual_program, "{message}"); + let labelled_instructions = to_labelled_instructions(&actual); + let label_to_address = build_label_to_address_map(&labelled_instructions); + let instructions = + turn_labels_into_addresses(&labelled_instructions, &label_to_address); + assert!(self.expected == instructions, "{message}"); + } } - fn parse_program_neg_prop(test_case: NegativeTestCase) { - let result = parse(test_case.input); - if result.is_ok() { - eprintln!("parser input: {}", test_case.input); - eprintln!("parser output: {:?}", result.unwrap()); - panic!("parser should fail, but didn't: {}", test_case.message); - } + impl<'a> NegativeTestCase<'a> { + fn run(self) { + let result = parse(self.input); + if result.is_ok() { + eprintln!("parser input: {}", self.input); + eprintln!("parser output: {:?}", result.unwrap()); + panic!("parser should fail, but didn't: {}", self.message); + } - let error = result.unwrap_err(); - let actual_error_message = format!("{error}"); - let actual_error_count = actual_error_message - .match_indices(test_case.expected_error) - .count(); - if test_case.expected_error_count != actual_error_count { - eprintln!("Actual error message:"); - eprintln!("{actual_error_message}"); - assert_eq!( - test_case.expected_error_count, actual_error_count, - "parser should report '{}' {} times: {}", - test_case.expected_error, test_case.expected_error_count, test_case.message - ) + let error = result.unwrap_err(); + let actual_error_message = format!("{error}"); + let actual_error_count = actual_error_message + .match_indices(self.expected_error) + .count(); + if self.expected_error_count != actual_error_count { + eprintln!("Actual error message:"); + eprintln!("{actual_error_message}"); + assert_eq!( + self.expected_error_count, actual_error_count, + "parser should report '{}' {} times: {}", + self.expected_error, self.expected_error_count, self.message + ) + } } } #[test] fn parse_program_empty() { - parse_program_prop(TestCase { + TestCase { input: "", - expected: Program::new(&[]), + expected: vec![], message: "empty string should parse as empty program", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: " ", - expected: Program::new(&[]), + expected: vec![], message: "spaces should parse as empty program", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "\n", - expected: Program::new(&[]), + expected: vec![], message: "linebreaks should parse as empty program (1)", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: " \n ", - expected: Program::new(&[]), + expected: vec![], message: "linebreaks should parse as empty program (2)", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: " \n \n", - expected: Program::new(&[]), + expected: vec![], message: "linebreaks should parse as empty program (3)", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "// empty program", - expected: Program::new(&[]), + expected: vec![], message: "single comment should parse as empty program", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "// empty program\n", - expected: Program::new(&[]), + expected: vec![], message: "single comment with linebreak should parse as empty program", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "// multi-line\n// comment", - expected: Program::new(&[]), + expected: vec![], message: "multiple comments should parse as empty program", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "// multi-line\n// comment\n ", - expected: Program::new(&[]), + expected: vec![], message: "multiple comments with trailing whitespace should parse as empty program", - }); + } + .run(); } #[proptest] fn arbitrary_whitespace_and_comment_sequence_is_empty_program(whitespace: Vec) { let whitespace = whitespace.into_iter().join(""); - parse_program_prop(TestCase { + TestCase { input: &whitespace, - expected: Program::new(&[]), + expected: vec![], message: "arbitrary whitespace should parse as empty program", - }); + } + .run(); } #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, EnumCount, Arbitrary)] @@ -858,367 +922,278 @@ pub(crate) mod tests { #[test] fn parse_program_whitespace() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "poppop", expected_error: "n/a", expected_error_count: 0, message: "whitespace required between instructions (pop)", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "dup 0dup 0", expected_error: "n/a", expected_error_count: 0, message: "whitespace required between instructions (dup 0)", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "swap 10swap 10", expected_error: "n/a", expected_error_count: 0, message: "whitespace required between instructions (swap 10)", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "push10", expected_error: "n/a", expected_error_count: 0, message: "push requires whitespace before its constant", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "push 10pop", expected_error: "n/a", expected_error_count: 0, message: "push requires whitespace after its constant", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hello: callhello", expected_error: "n/a", expected_error_count: 0, message: "call requires whitespace before its label", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hello: popcall hello", expected_error: "n/a", expected_error_count: 0, message: "required space between pop and call", - }); + } + .run(); } #[test] fn parse_program_label() { - parse_program_prop(TestCase { + TestCase { input: "foo: call foo", - expected: Program::new(&[ - Label("foo".to_string()), - Instruction(Call("foo".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "parse labels and calls to labels", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "foo:call foo", - expected: Program::new(&[ - Label("foo".to_string()), - Instruction(Call("foo".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "whitespace is not required after 'label:'", - }); + } + .run(); // FIXME: Increase coverage of negative tests for duplicate labels. - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "foo: pop 1 foo: pop 1 call foo", expected_error: "duplicate label", expected_error_count: 2, message: "labels cannot occur twice", - }); + } + .run(); // FIXME: Increase coverage of negative tests for missing labels. - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "foo: pop 1 call herp call derp", expected_error: "missing label", expected_error_count: 2, message: "non-existent labels cannot be called", - }); + } + .run(); // FIXME: Increase coverage of negative tests for label/keyword overlap. - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "pop: call pop", expected_error: "label cannot be named after instruction", expected_error_count: 1, message: "label names may not overlap with instruction names", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "pops: call pops", - expected: Program::new(&[ - Label("pops".to_string()), - Instruction(Call("pops".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "labels that share a common prefix with instruction are labels", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "_call: call _call", - expected: Program::new(&[ - Label("_call".to_string()), - Instruction(Call("_call".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "labels that share a common suffix with instruction are labels", - }); + } + .run(); } #[test] fn parse_program_nonexistent_instructions() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "pop 0", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "instruction `pop` cannot take argument `0`", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "swap 16", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "there is no swap 16 instruction", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "dup 16", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "there is no dup 16 instruction", - }); + } + .run(); } #[test] fn parse_program_bracket_syntax() { - parse_program_prop(TestCase { + TestCase { input: "foo: [foo]", - expected: Program::new(&[ - Label("foo".to_string()), - Instruction(Call("foo".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "Handle brackets as call syntax sugar", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "foo: [bar]", expected_error: "missing label", expected_error_count: 1, message: "Handle missing labels with bracket syntax", - }) - } - - #[proptest] - fn parse_program(#[strategy(arb())] program: Program) { - parse(&program.to_string()).unwrap(); + } + .run(); } #[test] fn parse_program_label_must_start_with_alphabetic_character_or_underscore() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "1foo: call 1foo", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "labels cannot start with a digit", - }); + } + .run(); - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "-foo: call -foo", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "labels cannot start with a dash", - }); + } + .run(); - parse_program_prop(TestCase { + TestCase { input: "_foo: call _foo", - expected: Program::new(&[ - Label("_foo".to_string()), - Instruction(Call("_foo".to_string())), - ]), + expected: vec![Instruction::Call(bfe!(0))], message: "labels can start with an underscore", - }); - } - - #[test] - fn parse_simple_type_hint() { - let expected_type_hint = TypeHint { - starting_index: 0, - length: 1, - type_name: Some("Type".to_string()), - variable_name: "foo".to_string(), - }; - - parse_program_prop(TestCase { - input: "hint foo: Type = stack[0]", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint", - }); - } - - #[test] - fn parse_type_hint_with_range() { - let expected_type_hint = TypeHint { - starting_index: 0, - length: 5, - type_name: Some("Digest".to_string()), - variable_name: "foo".to_string(), - }; - - parse_program_prop(TestCase { - input: "hint foo: Digest = stack[0..5]", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint with range", - }); - } - - #[test] - fn parse_type_hint_with_range_and_offset() { - let expected_type_hint = TypeHint { - starting_index: 7, - length: 3, - type_name: Some("XFieldElement".to_string()), - variable_name: "bar".to_string(), - }; - - parse_program_prop(TestCase { - input: "hint bar: XFieldElement = stack[7..10]", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint with range and offset", - }); - } - - #[test] - fn parse_type_hint_with_range_and_offset_and_weird_whitespace() { - let expected_type_hint = TypeHint { - starting_index: 2, - length: 12, - type_name: Some("BigType".to_string()), - variable_name: "bar".to_string(), - }; - - parse_program_prop(TestCase { - input: " hint \t \t foo :BigType=stack[ 2\t.. 14 ]\t \n", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint with range and offset and weird whitespace", - }); - } - - #[test] - fn parse_type_hint_with_no_type_only_variable_name() { - let expected_type_hint = TypeHint { - starting_index: 0, - length: 1, - type_name: None, - variable_name: "foo".to_string(), - }; - - parse_program_prop(TestCase { - input: "hint foo = stack[0]", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint with no type, only variable name", - }); - } - - #[test] - fn parse_type_hint_with_no_type_only_variable_name_with_range() { - let expected_type_hint = TypeHint { - starting_index: 2, - length: 5, - type_name: None, - variable_name: "foo".to_string(), - }; - - parse_program_prop(TestCase { - input: "hint foo = stack[2..7]", - expected: Program::new(&[LabelledInstruction::TypeHint(expected_type_hint)]), - message: "parse type hint with no type, only variable name, with range", - }); + } + .run(); } #[test] fn parse_type_hint_with_zero_length_range() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo: Type = stack[0..0]", expected_error: "range end must be greater than range start", expected_error_count: 1, message: "parse type hint with zero-length range", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_closing_bracket() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo: Type = stack[2..5", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing closing bracket", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_opening_bracket() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo: Type = stack2..5]", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing opening bracket", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_equals_sign() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo: Type stack[2..5];", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing equals sign", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_type_name() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo: = stack[2..5]", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing type name", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_variable_name() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint : Type = stack[2..5]", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing variable name", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_colon() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "hint foo Type = stack[2..5]", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing colon", - }); + } + .run(); } #[test] fn parse_type_hint_with_range_and_offset_and_missing_hint() { - parse_program_neg_prop(NegativeTestCase { + NegativeTestCase { input: "foo: Type = stack[2..5];", expected_error: "expecting label, instruction or eof", expected_error_count: 1, message: "parse type hint with range and offset and missing hint", - }); + } + .run(); } #[proptest] @@ -1236,30 +1211,57 @@ pub(crate) mod tests { #[test] fn triton_asm_macro() { let instructions = triton_asm!(write_io 3 push 17 call huh lt swap 3); - assert_eq!(Instruction(WriteIo(N3)), instructions[0]); - assert_eq!(Instruction(Push(bfe!(17))), instructions[1]); - assert_eq!(Instruction(Call("huh".to_string())), instructions[2]); - assert_eq!(Instruction(Lt), instructions[3]); - assert_eq!(Instruction(Swap(ST3)), instructions[4]); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::WriteIo(NumberOfWords::N3)), + instructions[0] + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Push(bfe!(17))), + instructions[1] + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Call("huh".to_string())), + instructions[2] + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Lt), + instructions[3] + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Swap(OpStackElement::ST3)), + instructions[4] + ); } #[test] fn triton_asm_macro_with_a_single_return() { let instructions = triton_asm!(return); - assert_eq!(Instruction(Return), instructions[0]); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Return), + instructions[0] + ); } #[test] fn triton_asm_macro_with_a_single_assert() { let instructions = triton_asm!(assert); - assert_eq!(Instruction(Assert), instructions[0]); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Assert), + instructions[0] + ); } #[test] fn triton_asm_macro_with_only_assert_and_return() { let instructions = triton_asm!(assert return); - assert_eq!(Instruction(Assert), instructions[0]); - assert_eq!(Instruction(Return), instructions[1]); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Assert), + instructions[0] + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Return), + instructions[1] + ); } #[test] @@ -1295,18 +1297,26 @@ pub(crate) mod tests { #[test] fn triton_program_macro_interpolates_various_types() { let push_arg = thread_rng().gen_range(0_u64..BFieldElement::P); - let instruction_push = Instruction(Push(push_arg.into())); + let instruction_push = + LabelledInstruction::Instruction(AnInstruction::Push(push_arg.into())); let swap_argument = "1"; - triton_program!({instruction_push} push {push_arg} swap {swap_argument} eq assert halt) - .run([].into(), [].into()) - .unwrap(); + triton_program!({instruction_push} push {push_arg} swap {swap_argument} eq assert halt); } #[test] fn triton_instruction_macro_parses_simple_instructions() { - assert_eq!(Instruction(Halt), triton_instr!(halt)); - assert_eq!(Instruction(Add), triton_instr!(add)); - assert_eq!(Instruction(Pop(N3)), triton_instr!(pop 3)); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Halt), + triton_instr!(halt) + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Add), + triton_instr!(add) + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Pop(NumberOfWords::N3)), + triton_instr!(pop 3) + ); } #[test] @@ -1317,11 +1327,20 @@ pub(crate) mod tests { #[test] fn triton_instruction_macro_parses_instructions_with_argument() { - assert_eq!(Instruction(Push(bfe!(7))), triton_instr!(push 7)); - assert_eq!(Instruction(Dup(ST3)), triton_instr!(dup 3)); - assert_eq!(Instruction(Swap(ST5)), triton_instr!(swap 5)); assert_eq!( - Instruction(Call("my_label".to_string())), + LabelledInstruction::Instruction(AnInstruction::Push(bfe!(7))), + triton_instr!(push 7) + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Dup(OpStackElement::ST3)), + triton_instr!(dup 3) + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Swap(OpStackElement::ST5)), + triton_instr!(swap 5) + ); + assert_eq!( + LabelledInstruction::Instruction(AnInstruction::Call("my_label".to_string())), triton_instr!(call my_label) ); } @@ -1329,22 +1348,28 @@ pub(crate) mod tests { #[test] fn triton_asm_macro_can_repeat_instructions() { let instructions = triton_asm![push 42; 3]; - let expected_instructions = vec![Instruction(Push(bfe!(42))); 3]; + let expected_instructions = + vec![LabelledInstruction::Instruction(AnInstruction::Push(bfe!(42))); 3]; assert_eq!(expected_instructions, instructions); let instructions = triton_asm![read_io 2; 15]; - let expected_instructions = vec![Instruction(ReadIo(N2)); 15]; + let expected_instructions = + vec![LabelledInstruction::Instruction(AnInstruction::ReadIo(NumberOfWords::N2)); 15]; assert_eq!(expected_instructions, instructions); let instructions = triton_asm![divine 3; Digest::LEN]; - let expected_instructions = vec![Instruction(Divine(N3)); Digest::LEN]; + let expected_instructions = + vec![ + LabelledInstruction::Instruction(AnInstruction::Divine(NumberOfWords::N3)); + Digest::LEN + ]; assert_eq!(expected_instructions, instructions); } #[test] fn break_gets_turned_into_labelled_instruction() { let instructions = triton_asm![break]; - let expected_instructions = vec![Breakpoint]; + let expected_instructions = vec![LabelledInstruction::Breakpoint]; assert_eq!(expected_instructions, instructions); } @@ -1353,37 +1378,4 @@ pub(crate) mod tests { let program = triton_program! { break halt break }; assert_eq!(1, program.len_bwords()); } - - #[test] - fn printing_program_includes_debug_information() { - let source_code = "\ - call foo\n\ - break\n\ - call bar\n\ - halt\n\ - foo:\n\ - break\n\ - call baz\n\ - push 1\n\ - nop\n\ - return\n\ - baz:\n\ - hash\n\ - hint my_digest: Digest = stack[0..5]\n\ - hint random_stuff = stack[17]\n\ - return\n\ - nop\n\ - pop 1\n\ - bar:\n\ - divine 1\n\ - hint got_insight: Magic = stack[0]\n\ - skiz\n\ - split\n\ - break\n\ - return\n\ - "; - let program = Program::from_code(source_code).unwrap(); - let printed_program = format!("{program}"); - assert_eq!(source_code, &printed_program); - } } diff --git a/triton-isa/src/program.rs b/triton-isa/src/program.rs new file mode 100644 index 000000000..a0112dc8c --- /dev/null +++ b/triton-isa/src/program.rs @@ -0,0 +1,689 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::Result as FmtResult; +use std::hash::Hash; +use std::io::Cursor; + +use arbitrary::Arbitrary; +use get_size::GetSize; +use itertools::Itertools; +use serde::Deserialize; +use serde::Serialize; +use thiserror::Error; +use twenty_first::prelude::*; + +use crate::instruction::AnInstruction; +use crate::instruction::Instruction; +use crate::instruction::InstructionError; +use crate::instruction::LabelledInstruction; +use crate::instruction::TypeHint; +use crate::parser; +use crate::parser::ParseError; + +/// A program for Triton VM. Triton VM can run and profile such programs, +/// and trace its execution in order to generate a proof of correct execution. +/// See there for details. +/// +/// A program may contain debug information, such as label names and breakpoints. +/// Access this information through methods [`label_for_address()`][label_for_address] and +/// [`is_breakpoint()`][is_breakpoint]. Some operations, most notably +/// [BField-encoding](BFieldCodec::encode), discard this debug information. +/// +/// [program attestation]: https://triton-vm.org/spec/program-attestation.html +/// [label_for_address]: Program::label_for_address +/// [is_breakpoint]: Program::is_breakpoint +#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)] +pub struct Program { + pub instructions: Vec, + address_to_label: HashMap, + breakpoints: Vec, + type_hints: HashMap>, +} + +impl Display for Program { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + for instruction in self.labelled_instructions() { + writeln!(f, "{instruction}")?; + } + Ok(()) + } +} + +impl PartialEq for Program { + fn eq(&self, other: &Program) -> bool { + self.instructions.eq(&other.instructions) + } +} + +impl BFieldCodec for Program { + type Error = ProgramDecodingError; + + fn decode(sequence: &[BFieldElement]) -> Result, Self::Error> { + if sequence.is_empty() { + return Err(Self::Error::EmptySequence); + } + let program_length = sequence[0].value() as usize; + let sequence = &sequence[1..]; + if sequence.len() < program_length { + return Err(Self::Error::SequenceTooShort); + } + if sequence.len() > program_length { + return Err(Self::Error::SequenceTooLong); + } + + // instantiating with claimed capacity is a potential DOS vector + let mut instructions = vec![]; + let mut read_idx = 0; + while read_idx < program_length { + let opcode = sequence[read_idx]; + let mut instruction = Instruction::try_from(opcode) + .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?; + let instruction_has_arg = instruction.arg().is_some(); + if instruction_has_arg && instructions.len() + instruction.size() > program_length { + return Err(Self::Error::MissingArgument(read_idx, instruction)); + } + if instruction_has_arg { + let arg = sequence[read_idx + 1]; + instruction = instruction + .change_arg(arg) + .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?; + } + + instructions.extend(vec![instruction; instruction.size()]); + read_idx += instruction.size(); + } + + if read_idx != program_length { + return Err(Self::Error::LengthMismatch); + } + if instructions.len() != program_length { + return Err(Self::Error::LengthMismatch); + } + + Ok(Box::new(Program { + instructions, + address_to_label: HashMap::default(), + breakpoints: vec![], + type_hints: HashMap::default(), + })) + } + + fn encode(&self) -> Vec { + let mut sequence = Vec::with_capacity(self.len_bwords() + 1); + sequence.push(bfe!(self.len_bwords() as u64)); + sequence.extend(self.to_bwords()); + sequence + } + + fn static_length() -> Option { + None + } +} + +impl<'a> Arbitrary<'a> for Program { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let contains_label = |labelled_instructions: &[_], maybe_label: &_| { + let LabelledInstruction::Label(label) = maybe_label else { + return false; + }; + labelled_instructions + .iter() + .any(|labelled_instruction| match labelled_instruction { + LabelledInstruction::Label(l) => l == label, + _ => false, + }) + }; + + let mut labelled_instructions = vec![]; + for _ in 0..u.arbitrary_len::()? { + let labelled_instruction = u.arbitrary()?; + if contains_label(&labelled_instructions, &labelled_instruction) { + continue; + } + labelled_instructions.push(labelled_instruction); + } + + let call_targets = labelled_instructions + .iter() + .filter_map(|instruction| match instruction { + LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target), + _ => None, + }) + .unique(); + let additional_labels = call_targets + .map(|target| LabelledInstruction::Label(target.clone())) + .collect_vec(); + + for additional_label in additional_labels { + if contains_label(&labelled_instructions, &additional_label) { + continue; + } + let insertion_index = u.choose_index(labelled_instructions.len() + 1)?; + labelled_instructions.insert(insertion_index, additional_label); + } + + Ok(Program::new(&labelled_instructions)) + } +} + +/// An `InstructionIter` loops the instructions of a `Program` by skipping duplicate placeholders. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct InstructionIter { + cursor: Cursor>, +} + +impl Iterator for InstructionIter { + type Item = Instruction; + + fn next(&mut self) -> Option { + let pos = self.cursor.position() as usize; + let instructions = self.cursor.get_ref(); + let instruction = *instructions.get(pos)?; + self.cursor.set_position((pos + instruction.size()) as u64); + + Some(instruction) + } +} + +impl IntoIterator for Program { + type Item = Instruction; + + type IntoIter = InstructionIter; + + fn into_iter(self) -> Self::IntoIter { + let cursor = Cursor::new(self.instructions); + InstructionIter { cursor } + } +} + +impl Program { + pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self { + let label_to_address = parser::build_label_to_address_map(labelled_instructions); + let instructions = + parser::turn_labels_into_addresses(labelled_instructions, &label_to_address); + let address_to_label = Self::flip_map(label_to_address); + let (breakpoints, type_hints) = Self::extract_debug_information(labelled_instructions); + + assert_eq!(instructions.len(), breakpoints.len()); + Program { + instructions, + address_to_label, + breakpoints, + type_hints, + } + } + + fn flip_map(map: HashMap) -> HashMap { + map.into_iter().map(|(key, value)| (value, key)).collect() + } + + fn extract_debug_information( + labelled_instructions: &[LabelledInstruction], + ) -> (Vec, HashMap>) { + let mut breakpoints = vec![]; + let mut type_hints = HashMap::<_, Vec<_>>::new(); + let mut break_before_next_instruction = false; + + let mut address = 0; + for instruction in labelled_instructions { + match instruction { + LabelledInstruction::Instruction(instruction) => { + breakpoints.extend(vec![break_before_next_instruction; instruction.size()]); + break_before_next_instruction = false; + address += instruction.size() as u64; + } + LabelledInstruction::Label(_) => (), + LabelledInstruction::Breakpoint => break_before_next_instruction = true, + LabelledInstruction::TypeHint(type_hint) => match type_hints.entry(address) { + Entry::Occupied(mut entry) => entry.get_mut().push(type_hint.clone()), + Entry::Vacant(entry) => _ = entry.insert(vec![type_hint.clone()]), + }, + } + } + + (breakpoints, type_hints) + } + + /// Create a `Program` by parsing source code. + pub fn from_code(code: &str) -> Result { + parser::parse(code) + .map(|tokens| parser::to_labelled_instructions(&tokens)) + .map(|instructions| Program::new(&instructions)) + } + + pub fn labelled_instructions(&self) -> Vec { + let call_targets = self.call_targets(); + let instructions_with_labels = self.instructions.iter().map(|instruction| { + instruction.map_call_address(|&address| self.label_for_address(address.value())) + }); + + let mut labelled_instructions = vec![]; + let mut address = 0; + let mut instruction_stream = instructions_with_labels.into_iter(); + while let Some(instruction) = instruction_stream.next() { + let instruction_size = instruction.size() as u64; + if call_targets.contains(&address) { + let label = self.label_for_address(address); + let label = LabelledInstruction::Label(label); + labelled_instructions.push(label); + } + for type_hint in self.type_hints_at(address) { + labelled_instructions.push(LabelledInstruction::TypeHint(type_hint)); + } + if self.is_breakpoint(address) { + labelled_instructions.push(LabelledInstruction::Breakpoint); + } + labelled_instructions.push(LabelledInstruction::Instruction(instruction)); + + for _ in 1..instruction_size { + instruction_stream.next(); + } + address += instruction_size; + } + + let leftover_labels = self + .address_to_label + .iter() + .filter(|(&labels_address, _)| labels_address >= address) + .sorted(); + for (_, label) in leftover_labels { + labelled_instructions.push(LabelledInstruction::Label(label.clone())); + } + + labelled_instructions + } + + fn call_targets(&self) -> HashSet { + self.instructions + .iter() + .filter_map(|instruction| match instruction { + Instruction::Call(address) => Some(address.value()), + _ => None, + }) + .collect() + } + + pub fn is_breakpoint(&self, address: u64) -> bool { + let address: usize = address.try_into().unwrap(); + self.breakpoints.get(address).unwrap_or(&false).to_owned() + } + + pub fn type_hints_at(&self, address: u64) -> Vec { + self.type_hints.get(&address).cloned().unwrap_or_default() + } + + /// Turn the program into a sequence of `BFieldElement`s. Each instruction is encoded as its + /// opcode, followed by its argument (if any). + /// + /// **Note**: This is _almost_ (but not quite!) equivalent to [encoding](BFieldCodec::encode) + /// the program. For that, use [`encode()`](Self::encode()) instead. + pub fn to_bwords(&self) -> Vec { + self.clone() + .into_iter() + .flat_map(|instruction| { + let opcode = instruction.opcode_b(); + if let Some(arg) = instruction.arg() { + vec![opcode, arg] + } else { + vec![opcode] + } + }) + .collect() + } + + /// The total length of the program as `BFieldElement`s. Double-word instructions contribute + /// two `BFieldElement`s. + pub fn len_bwords(&self) -> usize { + self.instructions.len() + } + + pub fn is_empty(&self) -> bool { + self.instructions.is_empty() + } + + /// Produces the program's canonical hash digest. Uses [`Tip5`], the + /// canonical hash function for Triton VM. + pub fn hash(&self) -> Digest { + // not encoded using `BFieldCodec` because that would prepend the length + Tip5::hash_varlen(&self.to_bwords()) + } + + /// The label for the given address, or a deterministic, unique substitute if no label is found. + pub fn label_for_address(&self, address: u64) -> String { + // Uniqueness of the label is relevant for printing and subsequent parsing: + // Parsing fails on duplicate labels. + self.address_to_label + .get(&address) + .cloned() + .unwrap_or_else(|| format!("address_{address}")) + } +} + +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +pub enum ProgramDecodingError { + #[error("sequence to decode is empty")] + EmptySequence, + + #[error("sequence to decode is too short")] + SequenceTooShort, + + #[error("sequence to decode is too long")] + SequenceTooLong, + + #[error("length of decoded program is unexpected")] + LengthMismatch, + + #[error("sequence to decode contains invalid instruction at index {0}: {1}")] + InvalidInstruction(usize, InstructionError), + + #[error("missing argument for instruction {1} at index {0}")] + MissingArgument(usize, Instruction), +} + +#[cfg(test)] +mod tests { + use assert2::assert; + use assert2::let_assert; + use proptest::prelude::*; + use proptest_arbitrary_interop::arb; + use rand::thread_rng; + use rand::Rng; + use test_strategy::proptest; + + use crate::triton_program; + + use super::*; + + #[proptest] + fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) { + let encoding = program.encode(); + let decoding = *Program::decode(&encoding).unwrap(); + prop_assert_eq!(program, decoding); + } + + #[test] + fn decode_program_with_missing_argument_as_last_instruction() { + let program = triton_program!(push 3 push 3 eq assert push 3); + let program_length = program.len_bwords() as u64; + let encoded = program.encode(); + + let mut encoded = encoded[0..encoded.len() - 1].to_vec(); + encoded[0] = bfe!(program_length - 1); + + let_assert!(Err(err) = Program::decode(&encoded)); + let_assert!(ProgramDecodingError::MissingArgument(6, _) = err); + } + + #[test] + fn decode_program_with_shorter_than_indicated_sequence() { + let program = triton_program!(nop nop hash push 0 skiz end: halt call end); + let mut encoded = program.encode(); + encoded[0] += bfe!(1); + let_assert!(Err(err) = Program::decode(&encoded)); + let_assert!(ProgramDecodingError::SequenceTooShort = err); + } + + #[test] + fn decode_program_with_longer_than_indicated_sequence() { + let program = triton_program!(nop nop hash push 0 skiz end: halt call end); + let mut encoded = program.encode(); + encoded[0] -= bfe!(1); + let_assert!(Err(err) = Program::decode(&encoded)); + let_assert!(ProgramDecodingError::SequenceTooLong = err); + } + + #[test] + fn decode_program_from_empty_sequence() { + let encoded = vec![]; + let_assert!(Err(err) = Program::decode(&encoded)); + let_assert!(ProgramDecodingError::EmptySequence = err); + } + + #[test] + fn hash_simple_program() { + let program = triton_program!(halt); + let digest = program.hash(); + + let expected_digest = bfe_array![ + 0x4338_de79_520b_3949_u64, + 0xe6a2_129b_2885_0dc9_u64, + 0xfd3c_d098_6a86_0450_u64, + 0x69fd_ba91_0ceb_a7bc_u64, + 0x7e5b_118c_9594_c062_u64, + ]; + let expected_digest = Digest::new(expected_digest); + + assert!(expected_digest == digest); + } + + #[test] + fn empty_program_is_empty() { + let program = triton_program!(); + assert!(program.is_empty()); + } + + #[test] + fn create_program_from_code() { + let element_3 = thread_rng().gen_range(0_u64..BFieldElement::P); + let element_2 = 1337_usize; + let element_1 = "17"; + let element_0 = bfe!(0); + let instruction_push = Instruction::Push(bfe!(42)); + let dup_arg = 1; + let label = "my_label".to_string(); + + let source_code = format!( + "push {element_3} push {element_2} push {element_1} push {element_0} + call {label} halt + {label}: + {instruction_push} + dup {dup_arg} + skiz + recurse + return" + ); + let program_from_code = Program::from_code(&source_code).unwrap(); + let program_from_macro = triton_program!({ source_code }); + assert!(program_from_code == program_from_macro); + } + + #[test] + fn parser_macro_with_interpolated_label_as_first_argument() { + let label = "my_label"; + let _program = triton_program!( + {label}: push 1 assert halt + ); + } + + #[test] + fn breakpoints_propagate_to_debug_information_as_expected() { + let program = triton_program! { + break push 1 push 2 + break break break break + pop 2 hash halt + break // no effect + }; + + assert!(program.is_breakpoint(0)); + assert!(program.is_breakpoint(1)); + assert!(!program.is_breakpoint(2)); + assert!(!program.is_breakpoint(3)); + assert!(program.is_breakpoint(4)); + assert!(program.is_breakpoint(5)); + assert!(!program.is_breakpoint(6)); + assert!(!program.is_breakpoint(7)); + + // going beyond the length of the program must not break things + assert!(!program.is_breakpoint(8)); + assert!(!program.is_breakpoint(9)); + } + + #[test] + fn print_program_without_any_debug_information() { + let program = triton_program! { + call foo + call bar + call baz + halt + foo: nop nop return + bar: call baz return + baz: push 1 return + }; + let encoding = program.encode(); + let program = Program::decode(&encoding).unwrap(); + println!("{program}"); + } + + #[proptest] + fn printed_program_can_be_parsed_again(#[strategy(arb())] program: Program) { + parser::parse(&program.to_string()).unwrap(); + } + + struct TypeHintTestCase { + expected: TypeHint, + input: &'static str, + } + + impl TypeHintTestCase { + fn run(&self) { + let program = Program::from_code(self.input).unwrap(); + let [ref type_hint] = program.type_hints_at(0)[..] else { + panic!("Expected a single type hint at address 0"); + }; + assert!(&self.expected == type_hint); + } + } + + #[test] + fn parse_simple_type_hint() { + let expected = TypeHint { + starting_index: 0, + length: 1, + type_name: Some("Type".to_string()), + variable_name: "foo".to_string(), + }; + + TypeHintTestCase { + expected, + input: "hint foo: Type = stack[0]", + } + .run(); + } + + #[test] + fn parse_type_hint_with_range() { + let expected = TypeHint { + starting_index: 0, + length: 5, + type_name: Some("Digest".to_string()), + variable_name: "foo".to_string(), + }; + + TypeHintTestCase { + expected, + input: "hint foo: Digest = stack[0..5]", + } + .run(); + } + + #[test] + fn parse_type_hint_with_range_and_offset() { + let expected = TypeHint { + starting_index: 7, + length: 3, + type_name: Some("XFieldElement".to_string()), + variable_name: "bar".to_string(), + }; + + TypeHintTestCase { + expected, + input: "hint bar: XFieldElement = stack[7..10]", + } + .run(); + } + + #[test] + fn parse_type_hint_with_range_and_offset_and_weird_whitespace() { + let expected = TypeHint { + starting_index: 2, + length: 12, + type_name: Some("BigType".to_string()), + variable_name: "bar".to_string(), + }; + + TypeHintTestCase { + expected, + input: " hint \t \t bar :BigType=stack[ 2\t.. 14 ]\t \n", + } + .run(); + } + + #[test] + fn parse_type_hint_with_no_type_only_variable_name() { + let expected = TypeHint { + starting_index: 0, + length: 1, + type_name: None, + variable_name: "foo".to_string(), + }; + + TypeHintTestCase { + expected, + input: "hint foo = stack[0]", + } + .run(); + } + + #[test] + fn parse_type_hint_with_no_type_only_variable_name_with_range() { + let expected = TypeHint { + starting_index: 2, + length: 5, + type_name: None, + variable_name: "foo".to_string(), + }; + + TypeHintTestCase { + expected, + input: "hint foo = stack[2..7]", + } + .run(); + } + + #[test] + fn printing_program_includes_debug_information() { + let source_code = "\ + call foo\n\ + break\n\ + call bar\n\ + halt\n\ + foo:\n\ + break\n\ + call baz\n\ + push 1\n\ + nop\n\ + return\n\ + baz:\n\ + hash\n\ + hint my_digest: Digest = stack[0..5]\n\ + hint random_stuff = stack[17]\n\ + return\n\ + nop\n\ + pop 1\n\ + bar:\n\ + divine 1\n\ + hint got_insight: Magic = stack[0]\n\ + skiz\n\ + split\n\ + break\n\ + return\n\ + "; + let program = Program::from_code(source_code).unwrap(); + let printed_program = format!("{program}"); + assert_eq!(source_code, &printed_program); + } +} diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 881880b9e..a3e89f28e 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -20,9 +20,11 @@ readme.workspace = true [dependencies] arbitrary.workspace = true colored.workspace = true +constraint-builder = { path = "../triton-constraint-builder", package = "triton-constraint-builder" } criterion.workspace = true get-size.workspace = true indexmap.workspace = true +isa = { path = "../triton-isa", package = "triton-isa" } itertools.workspace = true lazy_static.workspace = true ndarray.workspace = true @@ -40,7 +42,6 @@ syn.workspace = true thiserror.workspace = true twenty-first.workspace = true unicode-width.workspace = true -constraint-builder = { path = "../constraint-builder" } [dev-dependencies] assert2.workspace = true diff --git a/triton-vm/benches/cached_vs_jit_trace.rs b/triton-vm/benches/cached_vs_jit_trace.rs index b44a697e2..552c48ed5 100644 --- a/triton-vm/benches/cached_vs_jit_trace.rs +++ b/triton-vm/benches/cached_vs_jit_trace.rs @@ -18,9 +18,7 @@ fn prove_fib(c: &mut Criterion) { let program = triton_vm::example_programs::FIBONACCI_SEQUENCE.clone(); let public_input = PublicInput::from(bfe_array![N]); let non_determinism = NonDeterminism::default(); - let (aet, output) = program - .trace_execution(public_input, non_determinism) - .unwrap(); + let (aet, output) = VM::trace_execution(&program, public_input, non_determinism).unwrap(); let claim = Claim::about_program(&program) .with_input(bfe_vec![N]) .with_output(output); diff --git a/triton-vm/benches/mem_io.rs b/triton-vm/benches/mem_io.rs index 1fd5ffe10..2964387d5 100644 --- a/triton-vm/benches/mem_io.rs +++ b/triton-vm/benches/mem_io.rs @@ -66,10 +66,12 @@ impl MemIOBench { } fn performance_profile(&self) -> VMPerformanceProfile { - let (aet, output) = self - .program - .trace_execution(self.public_input.clone(), self.secret_input.clone()) - .unwrap(); + let (aet, output) = VM::trace_execution( + &self.program, + self.public_input.clone(), + self.secret_input.clone(), + ) + .unwrap(); let claim = Claim::about_program(&self.program).with_output(output); let stark = Stark::default(); diff --git a/triton-vm/benches/prove_fib.rs b/triton-vm/benches/prove_fib.rs index 5dbd584f3..1bb1ef1b0 100644 --- a/triton-vm/benches/prove_fib.rs +++ b/triton-vm/benches/prove_fib.rs @@ -18,9 +18,8 @@ criterion_group! { fn prove_fib(c: &mut Criterion) { let program = FIBONACCI_SEQUENCE.clone(); let public_input = PublicInput::new(bfe_vec![FIBONACCI_INDEX]); - let (aet, output) = program - .trace_execution(public_input.clone(), NonDeterminism::default()) - .unwrap(); + let (aet, output) = + VM::trace_execution(&program, public_input.clone(), NonDeterminism::default()).unwrap(); let claim = Claim::about_program(&program) .with_input(public_input.individual_tokens) diff --git a/triton-vm/benches/prove_halt.rs b/triton-vm/benches/prove_halt.rs index 352ed692e..598304677 100644 --- a/triton-vm/benches/prove_halt.rs +++ b/triton-vm/benches/prove_halt.rs @@ -2,12 +2,8 @@ use criterion::criterion_group; use criterion::criterion_main; use criterion::Criterion; -use triton_vm::prelude::NonDeterminism; -use triton_vm::prelude::PublicInput; -use triton_vm::proof::Claim; -use triton_vm::stark::Stark; +use triton_vm::prelude::*; use triton_vm::table::master_table::TableId; -use triton_vm::triton_program; criterion_main!(benches); @@ -20,9 +16,8 @@ criterion_group! { /// cargo criterion --bench prove_halt fn prove_halt(c: &mut Criterion) { let program = triton_program!(halt); - let (aet, output) = program - .trace_execution(PublicInput::default(), NonDeterminism::default()) - .unwrap(); + let (aet, output) = + VM::trace_execution(&program, PublicInput::default(), NonDeterminism::default()).unwrap(); let stark = Stark::default(); let claim = Claim::about_program(&program).with_output(output); diff --git a/triton-vm/benches/trace_mmr_new_peak_calculation.rs b/triton-vm/benches/trace_mmr_new_peak_calculation.rs index 20e2b1047..917a16459 100644 --- a/triton-vm/benches/trace_mmr_new_peak_calculation.rs +++ b/triton-vm/benches/trace_mmr_new_peak_calculation.rs @@ -1,5 +1,6 @@ use criterion::*; use triton_vm::example_programs; +use triton_vm::prelude::VM; criterion_main!(benches); @@ -14,7 +15,7 @@ fn run_mmr_new_peak_calculation(criterion: &mut Criterion) { criterion.bench_function("Run finding new peaks for MMR", |bencher| { bencher.iter(|| { - program.run([].into(), [].into()).unwrap(); + VM::run(&program, [].into(), [].into()).unwrap(); }); }); } @@ -24,7 +25,7 @@ fn trace_mmr_new_peak_calculation(criterion: &mut Criterion) { criterion.bench_function("Trace execution of finding new peaks for MMR", |bencher| { bencher.iter(|| { - program.trace_execution([].into(), [].into()).unwrap(); + VM::trace_execution(&program, [].into(), [].into()).unwrap(); }); }); } diff --git a/triton-vm/benches/verify_halt.rs b/triton-vm/benches/verify_halt.rs index a23384567..0f83fe2a1 100644 --- a/triton-vm/benches/verify_halt.rs +++ b/triton-vm/benches/verify_halt.rs @@ -3,9 +3,7 @@ use criterion::criterion_main; use criterion::BenchmarkId; use criterion::Criterion; -use triton_vm::proof::Claim; -use triton_vm::stark::Stark; -use triton_vm::triton_program; +use triton_vm::prelude::*; /// cargo criterion --bench verify_halt fn verify_halt(criterion: &mut Criterion) { @@ -14,7 +12,7 @@ fn verify_halt(criterion: &mut Criterion) { let stark = Stark::default(); let claim = Claim::about_program(&program); - let (aet, _) = program.trace_execution([].into(), [].into()).unwrap(); + let (aet, _) = VM::trace_execution(&program, [].into(), [].into()).unwrap(); let proof = stark.prove(&claim, &aet).unwrap(); triton_vm::profiler::start("Verify Halt"); diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index 4f3b5833a..e0abb5c89 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -4,6 +4,10 @@ use std::collections::HashMap; use std::ops::AddAssign; use arbitrary::Arbitrary; +use isa::error::InstructionError; +use isa::error::InstructionError::InstructionPointerOverflow; +use isa::instruction::Instruction; +use isa::program::Program; use itertools::Itertools; use ndarray::s; use ndarray::Array2; @@ -11,10 +15,6 @@ use ndarray::Axis; use strum::IntoEnumIterator; use twenty_first::prelude::*; -use crate::error::InstructionError; -use crate::error::InstructionError::InstructionPointerOverflow; -use crate::instruction::Instruction; -use crate::program::Program; use crate::table::hash_table::HashTable; use crate::table::hash_table::PermutationTrace; use crate::table::master_table::TableId; @@ -357,10 +357,11 @@ impl Ord for TableHeight { #[cfg(test)] mod tests { use assert2::assert; - - use crate::prelude::*; + use isa::triton_asm; + use isa::triton_program; use super::*; + use crate::prelude::*; #[test] fn pad_program_requiring_no_padding_zeros() { @@ -375,9 +376,9 @@ mod tests { #[test] fn height_of_any_table_can_be_computed() { let program = triton_program!(halt); - let (aet, _) = program - .trace_execution(PublicInput::default(), NonDeterminism::default()) - .unwrap(); + let (aet, _) = + VM::trace_execution(&program, PublicInput::default(), NonDeterminism::default()) + .unwrap(); let _ = aet.height(); for table in TableId::iter() { diff --git a/triton-vm/src/air.rs b/triton-vm/src/air.rs index 2ea2033ad..00c403a15 100644 --- a/triton-vm/src/air.rs +++ b/triton-vm/src/air.rs @@ -4,6 +4,7 @@ pub mod tasm_air_constraints; #[cfg(test)] mod test { + use isa::instruction::AnInstruction; use itertools::Itertools; use ndarray::Array1; use proptest::collection::vec; @@ -16,7 +17,6 @@ mod test { use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; - use crate::instruction::AnInstruction; use crate::prelude::*; use crate::table::challenges::Challenges; use crate::table::extension_table::Evaluable; diff --git a/triton-vm/src/air/tasm_air_constraints.rs b/triton-vm/src/air/tasm_air_constraints.rs index 4c9c0b1a5..95edab7db 100644 --- a/triton-vm/src/air/tasm_air_constraints.rs +++ b/triton-vm/src/air/tasm_air_constraints.rs @@ -2,9 +2,10 @@ //! Run `cargo run --bin constraint-evaluation-generator` //! to fill in this file with optimized constraints. +use isa::instruction::LabelledInstruction; + use crate::air::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; use crate::air::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; -use crate::instruction::LabelledInstruction; use crate::table::constraints::ERROR_MESSAGE_GENERATE_CONSTRAINTS; pub fn static_air_constraint_evaluation_tasm( diff --git a/triton-vm/src/codegen/constraints.rs b/triton-vm/src/codegen/constraints.rs index c5eb1acc0..38719c27a 100644 --- a/triton-vm/src/codegen/constraints.rs +++ b/triton-vm/src/codegen/constraints.rs @@ -7,6 +7,8 @@ use constraint_builder::BinOp; use constraint_builder::CircuitExpression; use constraint_builder::ConstraintCircuit; use constraint_builder::InputIndicator; +use isa::instruction::Instruction; +use isa::op_stack::NumberOfWords; use itertools::Itertools; use proc_macro2::TokenStream; use quote::format_ident; @@ -15,9 +17,6 @@ use quote::ToTokens; use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; -use crate::instruction::Instruction; -use crate::op_stack::NumberOfWords; - use crate::codegen::Constraints; pub(crate) trait Codegen { diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index 5dab9d872..50ce9306c 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -1,18 +1,17 @@ +pub use isa::error::InstructionError; + use std::fmt; use std::fmt::Display; use std::fmt::Formatter; -use std::num::TryFromIntError; use thiserror::Error; use twenty_first::error::MerkleTreeError; use twenty_first::prelude::*; -use crate::instruction::Instruction; use crate::proof_item::ProofItem; use crate::proof_item::ProofItemVariant; use crate::proof_stream::ProofStream; use crate::vm::VMState; -use crate::BFieldElement; /// Indicates a runtime error that resulted in a crash of Triton VM. #[derive(Debug, Clone, Eq, PartialEq, Error)] @@ -39,61 +38,6 @@ impl Display for VMError { } } -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] -pub enum InstructionError { - #[error("opcode {0} is invalid")] - InvalidOpcode(u32), - - #[error("opcode is out of range: {0}")] - OutOfRangeOpcode(#[from] TryFromIntError), - - #[error("invalid argument {1} for instruction `{0}`")] - IllegalArgument(Instruction, BFieldElement), - - #[error("instruction pointer points outside of program")] - InstructionPointerOverflow, - - #[error("operational stack is too shallow")] - OpStackTooShallow, - - #[error("jump stack is empty")] - JumpStackIsEmpty, - - #[error("assertion failed: st0 must be 1")] - AssertionFailed, - - #[error("vector assertion failed: stack[{0}] != stack[{}]", .0 + Digest::LEN)] - VectorAssertionFailed(usize), - - #[error("0 does not have a multiplicative inverse")] - InverseOfZero, - - #[error("division by 0 is impossible")] - DivisionByZero, - - #[error("the Sponge state must be initialized before it can be used")] - SpongeNotInitialized, - - #[error("the logarithm of 0 does not exist")] - LogarithmOfZero, - - #[error("failed to convert BFieldElement {0} into u32")] - FailedU32Conversion(BFieldElement), - - #[error("public input buffer is empty after {0} reads")] - EmptyPublicInput(usize), - - #[error("secret input buffer is empty after {0} reads")] - EmptySecretInput(usize), - - #[error("no more secret digests available")] - EmptySecretDigestInput, - - #[error("Triton VM has halted and cannot execute any further instructions")] - MachineHalted, -} - #[non_exhaustive] #[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] pub enum ArithmeticDomainError { @@ -183,28 +127,6 @@ pub enum FriValidationError { ArithmeticDomainError(#[from] ArithmeticDomainError), } -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] -pub enum ProgramDecodingError { - #[error("sequence to decode is empty")] - EmptySequence, - - #[error("sequence to decode is too short")] - SequenceTooShort, - - #[error("sequence to decode is too long")] - SequenceTooLong, - - #[error("length of decoded program is unexpected")] - LengthMismatch, - - #[error("sequence to decode contains invalid instruction at index {0}: {1}")] - InvalidInstruction(usize, InstructionError), - - #[error("missing argument for instruction {1} at index {0}")] - MissingArgument(usize, Instruction), -} - #[non_exhaustive] #[derive(Debug, Clone, Eq, PartialEq, Error)] pub enum ProvingError { @@ -282,70 +204,51 @@ pub enum VerificationError { FriValidationError(#[from] FriValidationError), } -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] -pub enum OpStackElementError { - #[error("index {0} is out of range for `OpStackElement`")] - IndexOutOfBounds(u32), - - #[error(transparent)] - FailedIntegerConversion(#[from] TryFromIntError), -} - -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] -pub enum NumberOfWordsError { - #[error("index {0} is out of range for `NumberOfWords`")] - IndexOutOfBounds(usize), - - #[error(transparent)] - FailedIntegerConversion(#[from] TryFromIntError), -} - #[cfg(test)] mod tests { use assert2::assert; use assert2::let_assert; + use isa::op_stack::OpStackError; + use isa::triton_program; use proptest::prelude::*; use proptest_arbitrary_interop::arb; use test_strategy::proptest; - use crate::triton_program; - use super::*; + use crate::prelude::VM; #[test] fn instruction_pointer_overflow() { let program = triton_program!(nop); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::InstructionPointerOverflow = err.source); } #[test] fn shrink_op_stack_too_much() { let program = triton_program!(pop 3 halt); - let_assert!(Err(err) = program.run([].into(), [].into())); - let_assert!(InstructionError::OpStackTooShallow = err.source); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); + let_assert!(InstructionError::OpStackError(OpStackError::TooShallow) = err.source); } #[test] fn return_without_call() { let program = triton_program!(return halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::JumpStackIsEmpty = err.source); } #[test] fn recurse_without_call() { let program = triton_program!(recurse halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::JumpStackIsEmpty = err.source); } #[test] fn assert_false() { let program = triton_program!(push 0 assert halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::AssertionFailed = err.source); } @@ -356,7 +259,7 @@ mod tests { push 4 push 3 push 2 push 10 push 0 assert_vector halt }; - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::VectorAssertionFailed(index) = err.source); assert!(1 == index); } @@ -389,7 +292,7 @@ mod tests { halt }; - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::VectorAssertionFailed(index) = err.source); prop_assert_eq!(disturbance_index, index); } @@ -397,36 +300,37 @@ mod tests { #[test] fn inverse_of_zero() { let program = triton_program!(push 0 invert halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::InverseOfZero = err.source); } #[test] fn xfe_inverse_of_zero() { let program = triton_program!(push 0 push 0 push 0 x_invert halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::InverseOfZero = err.source); } #[test] fn division_by_zero() { let program = triton_program!(push 0 push 5 div_mod halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::DivisionByZero = err.source); } #[test] fn log_of_zero() { let program = triton_program!(push 0 log_2_floor halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::LogarithmOfZero = err.source); } #[test] fn failed_u32_conversion() { let program = triton_program!(push 4294967297 push 1 and halt); - let_assert!(Err(err) = program.run([].into(), [].into())); - let_assert!(InstructionError::FailedU32Conversion(element) = err.source); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); + let_assert!(InstructionError::OpStackError(err) = err.source); + let_assert!(OpStackError::FailedU32Conversion(element) = err); assert!(4_294_967_297 == element.value()); } } diff --git a/triton-vm/src/example_programs.rs b/triton-vm/src/example_programs.rs index 92c29992b..d2f25f42c 100644 --- a/triton-vm/src/example_programs.rs +++ b/triton-vm/src/example_programs.rs @@ -1,8 +1,7 @@ +use isa::program::Program; +use isa::triton_program; use lazy_static::lazy_static; -use crate::program::Program; -use crate::triton_program; - lazy_static! { pub static ref FIBONACCI_SEQUENCE: Program = fibonacci_sequence(); pub static ref GREATEST_COMMON_DIVISOR: Program = greatest_common_divisor(); diff --git a/triton-vm/src/execution_trace_profiler.rs b/triton-vm/src/execution_trace_profiler.rs new file mode 100644 index 000000000..d05eaf3e7 --- /dev/null +++ b/triton-vm/src/execution_trace_profiler.rs @@ -0,0 +1,274 @@ +use std::collections::HashSet; +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::Result as FmtResult; +use std::ops::Add; +use std::ops::AddAssign; +use std::ops::Sub; + +use arbitrary::Arbitrary; +use twenty_first::prelude::*; + +use crate::table::hash_table::PERMUTATION_TRACE_LENGTH; +use crate::table::u32_table::U32TableEntry; +use crate::vm::CoProcessorCall; + +#[derive(Debug, Default, Clone, Eq, PartialEq, Arbitrary)] +pub(crate) struct ExecutionTraceProfiler { + call_stack: Vec, + profile: Vec, + table_heights: VMTableHeights, + u32_table_entries: HashSet, +} + +/// A single line in a [profile report](ExecutionTraceProfile) for profiling +/// [Triton](crate) programs. +#[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct ProfileLine { + pub label: String, + pub call_depth: usize, + + /// Table heights at the start of this span, _i.e._, right before the corresponding + /// [`call`](isa::instruction::Instruction::Call) instruction was executed. + pub table_heights_start: VMTableHeights, + + table_heights_stop: VMTableHeights, +} + +/// A report for the completed execution of a [Triton](crate) program. +/// +/// Offers a human-readable [`Display`] implementation and can be processed +/// programmatically. +#[derive(Debug, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct ExecutionTraceProfile { + pub total: VMTableHeights, + pub profile: Vec, +} + +/// The heights of various [tables](crate::aet::AlgebraicExecutionTrace) relevant for +/// proving the correct execution in [Triton VM](crate). +#[non_exhaustive] +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct VMTableHeights { + pub processor: u32, + pub op_stack: u32, + pub ram: u32, + pub hash: u32, + pub u32: u32, +} + +impl ExecutionTraceProfiler { + pub fn new(num_instructions: usize) -> Self { + Self { + call_stack: vec![], + profile: vec![], + table_heights: VMTableHeights::new(num_instructions), + u32_table_entries: HashSet::default(), + } + } + + pub fn enter_span(&mut self, label: impl Into) { + let call_stack_len = self.call_stack.len(); + let line_number = self.profile.len(); + + let profile_line = ProfileLine { + label: label.into(), + call_depth: call_stack_len, + table_heights_start: self.table_heights, + table_heights_stop: VMTableHeights::default(), + }; + + self.profile.push(profile_line); + self.call_stack.push(line_number); + } + + pub fn exit_span(&mut self) { + if let Some(line_number) = self.call_stack.pop() { + self.profile[line_number].table_heights_stop = self.table_heights; + }; + } + + pub fn handle_co_processor_calls(&mut self, calls: Vec) { + self.table_heights.processor += 1; + for call in calls { + match call { + CoProcessorCall::SpongeStateReset => self.table_heights.hash += 1, + CoProcessorCall::Tip5Trace(_, trace) => { + self.table_heights.hash += u32::try_from(trace.len()).unwrap(); + } + CoProcessorCall::U32Call(c) => { + self.u32_table_entries.insert(c); + let contribution = U32TableEntry::table_height_contribution; + self.table_heights.u32 = self.u32_table_entries.iter().map(contribution).sum(); + } + CoProcessorCall::OpStackCall(_) => self.table_heights.op_stack += 1, + CoProcessorCall::RamCall(_) => self.table_heights.ram += 1, + } + } + } + + pub fn finish(mut self) -> ExecutionTraceProfile { + for &line_number in &self.call_stack { + self.profile[line_number].table_heights_stop = self.table_heights; + } + + ExecutionTraceProfile { + total: self.table_heights, + profile: self.profile, + } + } +} + +impl VMTableHeights { + fn new(num_instructions: usize) -> Self { + let padded_program_len = (num_instructions + 1).next_multiple_of(Tip5::RATE); + let num_absorbs = padded_program_len / Tip5::RATE; + let initial_hash_table_len = num_absorbs * PERMUTATION_TRACE_LENGTH; + + Self { + hash: initial_hash_table_len.try_into().unwrap(), + ..Default::default() + } + } +} + +impl Sub for VMTableHeights { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self { + processor: self.processor.saturating_sub(rhs.processor), + op_stack: self.op_stack.saturating_sub(rhs.op_stack), + ram: self.ram.saturating_sub(rhs.ram), + hash: self.hash.saturating_sub(rhs.hash), + u32: self.u32.saturating_sub(rhs.u32), + } + } +} + +impl Add for VMTableHeights { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + processor: self.processor + rhs.processor, + op_stack: self.op_stack + rhs.op_stack, + ram: self.ram + rhs.ram, + hash: self.hash + rhs.hash, + u32: self.u32 + rhs.u32, + } + } +} + +impl AddAssign for VMTableHeights { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl ProfileLine { + fn table_height_contributions(&self) -> VMTableHeights { + self.table_heights_stop - self.table_heights_start + } +} + +impl Display for ProfileLine { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + let indentation = " ".repeat(self.call_depth); + let label = &self.label; + let cycle_count = self.table_height_contributions().processor; + write!(f, "{indentation}{label}: {cycle_count}") + } +} + +impl Display for ExecutionTraceProfile { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + struct AggregateLine { + label: String, + call_depth: usize, + table_heights: VMTableHeights, + } + + const COL_WIDTH: usize = 20; + + let mut aggregated: Vec = vec![]; + for line in &self.profile { + if let Some(agg) = aggregated + .iter_mut() + .find(|agg| agg.label == line.label && agg.call_depth == line.call_depth) + { + agg.table_heights += line.table_height_contributions(); + } else { + aggregated.push(AggregateLine { + label: line.label.clone(), + call_depth: line.call_depth, + table_heights: line.table_height_contributions(), + }); + } + } + aggregated.push(AggregateLine { + label: "Total".to_string(), + call_depth: 0, + table_heights: self.total, + }); + + let label = |line: &AggregateLine| "··".repeat(line.call_depth) + &line.label; + let label_len = |line| label(line).len(); + + let max_label_len = aggregated.iter().map(label_len).max(); + let max_label_len = max_label_len.unwrap_or_default().max(COL_WIDTH); + + let [subroutine, processor, op_stack, ram, hash, u32_title] = + ["Subroutine", "Processor", "Op Stack", "RAM", "Hash", "U32"]; + + write!(f, "| {subroutine:COL_WIDTH$} ")?; + write!(f, "| {op_stack:>COL_WIDTH$} ")?; + write!(f, "| {ram:>COL_WIDTH$} ")?; + write!(f, "| {hash:>COL_WIDTH$} ")?; + write!(f, "| {u32_title:>COL_WIDTH$} ")?; + writeln!(f, "|")?; + + let dash = "-"; + write!(f, "|:{dash:-COL_WIDTH$}:")?; + write!(f, "|-{dash:->COL_WIDTH$}:")?; + write!(f, "|-{dash:->COL_WIDTH$}:")?; + write!(f, "|-{dash:->COL_WIDTH$}:")?; + write!(f, "|-{dash:->COL_WIDTH$}:")?; + writeln!(f, "|")?; + + for line in &aggregated { + let rel_precision = 1; + let rel_width = 3 + 1 + rel_precision; // eg '100.0' + let abs_width = COL_WIDTH - rel_width - 4; // ' (' and '%)' + + let label = label(line); + let proc_abs = line.table_heights.processor; + let proc_rel = 100.0 * f64::from(proc_abs) / f64::from(self.total.processor); + let proc_rel = format!("{proc_rel:.rel_precision$}"); + let stack_abs = line.table_heights.op_stack; + let stack_rel = 100.0 * f64::from(stack_abs) / f64::from(self.total.op_stack); + let stack_rel = format!("{stack_rel:.rel_precision$}"); + let ram_abs = line.table_heights.ram; + let ram_rel = 100.0 * f64::from(ram_abs) / f64::from(self.total.ram); + let ram_rel = format!("{ram_rel:.rel_precision$}"); + let hash_abs = line.table_heights.hash; + let hash_rel = 100.0 * f64::from(hash_abs) / f64::from(self.total.hash); + let hash_rel = format!("{hash_rel:.rel_precision$}"); + let u32_abs = line.table_heights.u32; + let u32_rel = 100.0 * f64::from(u32_abs) / f64::from(self.total.u32); + let u32_rel = format!("{u32_rel:.rel_precision$}"); + + write!(f, "| {label:abs_width$} ({proc_rel:>rel_width$}%) ")?; + write!(f, "| {stack_abs:>abs_width$} ({stack_rel:>rel_width$}%) ")?; + write!(f, "| {ram_abs:>abs_width$} ({ram_rel:>rel_width$}%) ")?; + write!(f, "| {hash_abs:>abs_width$} ({hash_rel:>rel_width$}%) ")?; + write!(f, "| {u32_abs:>abs_width$} ({u32_rel:>rel_width$}%) ")?; + writeln!(f, "|")?; + } + + Ok(()) + } +} diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index e72d8c51b..6bdac5ea3 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -140,29 +140,27 @@ //! Successful termination of a program is not guaranteed. For example, a program must execute //! `halt` as its last instruction. Certain instructions, such as `assert`, `invert`, or the u32 //! instructions, can also cause the VM to crash. Upon crashing Triton VM, methods like -//! [`run`](Program::run) and [`trace_execution`][trace_execution] will return a +//! [`run`](VM::run) and [`trace_execution`](VM::trace_execution) will return a //! [`VMError`][vm_error]. This can be helpful for debugging. //! //! ``` //! # use triton_vm::*; //! # use triton_vm::prelude::*; //! let crashing_program = triton_program!(push 2 assert halt); -//! let vm_error = crashing_program.run([].into(), [].into()).unwrap_err(); +//! let vm_error = VM::run(&crashing_program, [].into(), [].into()).unwrap_err(); //! assert!(matches!(vm_error.source, InstructionError::AssertionFailed)); //! // inspect the VM state //! eprintln!("{vm_error}"); //! ``` //! //! [vm_error]: error::VMError -//! [trace_execution]: Program::trace_execution #![recursion_limit = "4096"] -#![warn(let_underscore_drop)] -#![warn(missing_copy_implementations)] -#![warn(missing_debug_implementations)] pub use twenty_first; +use isa::program::Program; + use crate::error::ProvingError; use crate::prelude::*; @@ -173,14 +171,11 @@ mod codegen; pub mod config; pub mod error; pub mod example_programs; +pub mod execution_trace_profiler; pub mod fri; -pub mod instruction; mod ndarray_helper; -pub mod op_stack; -pub mod parser; pub mod prelude; pub mod profiler; -pub mod program; pub mod proof; pub mod proof_item; pub mod proof_stream; @@ -191,312 +186,6 @@ pub mod vm; #[cfg(test)] mod shared_tests; -/// Compile an entire program written in [Triton assembly][tasm]. -/// The resulting [`Program`](Program) can be [run](Program::run). -/// -/// It is possible to use string-like interpolation to insert instructions, arguments, labels, -/// or other substrings into the program. -/// -/// # Examples -/// -/// ``` -/// # use triton_vm::prelude::*; -/// let program = triton_program!( -/// read_io 1 push 5 mul -/// call check_eq_15 -/// push 17 write_io 1 -/// halt -/// // assert that the top of the stack is 15 -/// check_eq_15: -/// push 15 eq assert -/// return -/// ); -/// let public_input = PublicInput::from([bfe!(3)]); -/// let secret_input = NonDeterminism::default(); -/// let output = program.run(public_input, secret_input).unwrap(); -/// assert_eq!(17, output[0].value()); -/// ``` -/// -/// Any type with an appropriate [`Display`](std::fmt::Display) implementation can be -/// interpolated. This includes, for example, primitive types like `u64` and `&str`, but also -/// [`Instruction`](instruction::Instruction)s, -/// [`BFieldElement`](BFieldElement)s, and -/// [`Label`](instruction::LabelledInstruction)s, among others. -/// -/// ``` -/// # use triton_vm::prelude::*; -/// # use triton_vm::instruction::Instruction; -/// let element_0 = BFieldElement::new(0); -/// let label = "my_label"; -/// let instruction_push = Instruction::Push(bfe!(42)); -/// let dup_arg = 1; -/// let program = triton_program!( -/// push {element_0} -/// call {label} halt -/// {label}: -/// {instruction_push} -/// dup {dup_arg} -/// skiz recurse return -/// ); -/// ``` -/// -/// # Panics -/// -/// **Panics** if the program cannot be parsed. -/// Examples for parsing errors are: -/// - unknown (_e.g._ misspelled) instructions -/// - invalid instruction arguments, _e.g._, `push 1.5` or `swap 42` -/// - missing or duplicate labels -/// - invalid labels, _e.g._, using a reserved keyword or starting a label with a digit -/// -/// For a version that returns a `Result`, see [`Program::from_code()`][from_code]. -/// -/// [tasm]: https://triton-vm.org/spec/instructions.html -/// [from_code]: Program::from_code -#[macro_export] -macro_rules! triton_program { - {$($source_code:tt)*} => {{ - let labelled_instructions = $crate::triton_asm!($($source_code)*); - $crate::program::Program::new(&labelled_instructions) - }}; -} - -/// Compile [Triton assembly][tasm] into a list of labelled -/// [`Instruction`](instruction::LabelledInstruction)s. -/// Similar to [`triton_program!`](triton_program), it is possible to use string-like -/// interpolation to insert instructions, arguments, labels, or other expressions. -/// -/// Similar to [`vec!`], a single instruction can be repeated a specified number of times. -/// -/// Furthermore, a list of [`LabelledInstruction`](instruction::LabelledInstruction)s -/// can be inserted like so: `{&list}`. -/// -/// The labels for instruction `call`, if any, are also parsed. Instruction `call` can refer to -/// a label defined later in the program, _i.e.,_ labels are not checked for existence or -/// uniqueness by this parser. -/// -/// # Examples -/// -/// ``` -/// # use triton_vm::triton_asm; -/// let push_argument = 42; -/// let instructions = triton_asm!( -/// push 1 call some_label -/// push {push_argument} -/// some_other_label: skiz halt return -/// ); -/// assert_eq!(7, instructions.len()); -/// ``` -/// -/// One instruction repeated several times: -/// -/// ``` -/// # use triton_vm::triton_asm; -/// # use triton_vm::instruction::LabelledInstruction; -/// # use triton_vm::instruction::AnInstruction::SpongeAbsorb; -/// let instructions = triton_asm![sponge_absorb; 3]; -/// assert_eq!(3, instructions.len()); -/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[0]); -/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[1]); -/// assert_eq!(LabelledInstruction::Instruction(SpongeAbsorb), instructions[2]); -/// ``` -/// -/// Inserting substring of labelled instructions: -/// -/// ``` -/// # use triton_vm::prelude::*; -/// # use triton_vm::instruction::AnInstruction::Push; -/// # use triton_vm::instruction::AnInstruction::Pop; -/// # use triton_vm::op_stack::NumberOfWords::N1; -/// let insert_me = triton_asm!( -/// pop 1 -/// nop -/// pop 1 -/// ); -/// let surrounding_code = triton_asm!( -/// push 0 -/// {&insert_me} -/// push 1 -/// ); -/// # let zero = bfe!(0); -/// # assert_eq!(LabelledInstruction::Instruction(Push(zero)), surrounding_code[0]); -/// assert_eq!(LabelledInstruction::Instruction(Pop(N1)), surrounding_code[1]); -/// assert_eq!(LabelledInstruction::Instruction(Pop(N1)), surrounding_code[3]); -/// # let one = bfe!(1); -/// # assert_eq!(LabelledInstruction::Instruction(Push(one)), surrounding_code[4]); -///``` -/// -/// # Panics -/// -/// **Panics** if the instructions cannot be parsed. -/// For examples, see [`triton_program!`](triton_program), with the exception that -/// labels are not checked for existence or uniqueness. -/// -/// [tasm]: https://triton-vm.org/spec/instructions.html -#[macro_export] -macro_rules! triton_asm { - (@fmt $fmt:expr, $($args:expr,)*; ) => { - format_args!($fmt $(,$args)*).to_string() - }; - (@fmt $fmt:expr, $($args:expr,)*; - hint $var:ident: $ty:ident = stack[$start:literal..$end:literal] $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " hint {}: {} = stack[{}..{}] "), - $($args,)* stringify!($var), stringify!($ty), $start, $end,; - $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; - hint $var:ident = stack[$start:literal..$end:literal] $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " hint {} = stack[{}..{}] "), - $($args,)* stringify!($var), $start, $end,; - $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; - hint $var:ident: $ty:ident = stack[$index:literal] $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " hint {}: {} = stack[{}] "), - $($args,)* stringify!($var), stringify!($ty), $index,; - $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; - hint $var:ident = stack[$index:literal] $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " hint {} = stack[{}] "), - $($args,)* stringify!($var), $index,; - $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; $label_declaration:ident: $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " ", stringify!($label_declaration), ": "), $($args,)*; $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; $instruction:ident $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " ", stringify!($instruction), " "), $($args,)*; $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; $instruction_argument:literal $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, " ", stringify!($instruction_argument), " "), $($args,)*; $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; {$label_declaration:expr}: $($tail:tt)*) => { - $crate::triton_asm!(@fmt concat!($fmt, "{}: "), $($args,)* $label_declaration,; $($tail)*) - }; - (@fmt $fmt:expr, $($args:expr,)*; {&$instruction_list:expr} $($tail:tt)*) => { - $crate::triton_asm!(@fmt - concat!($fmt, "{} "), $($args,)* - $instruction_list.iter().map(|instr| instr.to_string()).collect::>().join(" "),; - $($tail)* - ) - }; - (@fmt $fmt:expr, $($args:expr,)*; {$expression:expr} $($tail:tt)*) => { - $crate::triton_asm!(@fmt concat!($fmt, "{} "), $($args,)* $expression,; $($tail)*) - }; - - // repeated instructions - [pop $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(pop $arg); $num ] }; - [push $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(push $arg); $num ] }; - [divine $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(divine $arg); $num ] }; - [dup $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(dup $arg); $num ] }; - [swap $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(swap $arg); $num ] }; - [call $arg:ident; $num:expr] => { vec![ $crate::triton_instr!(call $arg); $num ] }; - [read_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_mem $arg); $num ] }; - [write_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_mem $arg); $num ] }; - [read_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_io $arg); $num ] }; - [write_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_io $arg); $num ] }; - [$instr:ident; $num:expr] => { vec![ $crate::triton_instr!($instr); $num ] }; - - // entry point - {$($source_code:tt)*} => {{ - let source_code = $crate::triton_asm!(@fmt "",; $($source_code)*); - let (_, instructions) = $crate::parser::tokenize(&source_code).unwrap(); - $crate::parser::to_labelled_instructions(&instructions) - }}; -} - -/// Compile a single [Triton assembly][tasm] instruction into a -/// [`LabelledInstruction`](instruction::LabelledInstruction). -/// -/// # Examples -/// -/// ``` -/// # use triton_vm::triton_instr; -/// # use triton_vm::instruction::LabelledInstruction; -/// # use triton_vm::instruction::AnInstruction::Call; -/// let instruction = triton_instr!(call my_label); -/// assert_eq!(LabelledInstruction::Instruction(Call("my_label".to_string())), instruction); -/// ``` -/// -/// [tasm]: https://triton-vm.org/spec/instructions.html -#[macro_export] -macro_rules! triton_instr { - (pop $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::Pop(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (push $arg:expr) => {{ - let argument = $crate::prelude::BFieldElement::from($arg); - let instruction = $crate::instruction::AnInstruction::::Push(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (divine $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::Divine(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (dup $arg:literal) => {{ - let argument = $crate::op_stack::OpStackElement::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::Dup(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (swap $arg:literal) => {{ - let argument = $crate::op_stack::OpStackElement::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::Swap(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (call $arg:ident) => {{ - let argument = stringify!($arg).to_string(); - let instruction = $crate::instruction::AnInstruction::::Call(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (read_mem $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::ReadMem(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (write_mem $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::WriteMem(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (addi $arg:expr) => {{ - let argument = $crate::prelude::BFieldElement::from($arg); - let instruction = $crate::instruction::AnInstruction::::AddI(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (read_io $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::ReadIo(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - (write_io $arg:literal) => {{ - let argument = $crate::op_stack::NumberOfWords::try_from($arg).unwrap(); - let instruction = $crate::instruction::AnInstruction::::WriteIo(argument); - $crate::instruction::LabelledInstruction::Instruction(instruction) - }}; - ($instr:ident) => {{ - let (_, instructions) = $crate::parser::tokenize(stringify!($instr)).unwrap(); - instructions[0].to_labelled_instruction() - }}; -} - /// Prove correct execution of a program written in Triton assembly. /// This is a convenience function, abstracting away the details of the STARK construction. /// If you want to have more control over the STARK construction, this method can serve as a @@ -527,7 +216,7 @@ pub fn prove_program( // - if any of the two inputs does not conform to the program, // - because of a bug in the program, among other things. // If the VM crashes, proof generation will fail. - let (aet, public_output) = program.trace_execution(public_input.clone(), non_determinism)?; + let (aet, public_output) = VM::trace_execution(program, public_input.clone(), non_determinism)?; // Set up the claim that is to be proven. The claim contains all public information. The // proof is zero-knowledge with respect to everything else. @@ -562,7 +251,8 @@ pub fn prove( if program_digest != claim.program_digest { return Err(ProvingError::ProgramDigestMismatch); } - let (aet, public_output) = program.trace_execution((&claim.input).into(), non_determinism)?; + let (aet, public_output) = + VM::trace_execution(program, (&claim.input).into(), non_determinism)?; if public_output != claim.output { return Err(ProvingError::PublicOutputMismatch); } @@ -582,15 +272,15 @@ pub fn verify(stark: Stark, claim: &Claim, proof: &Proof) -> bool { mod tests { use assert2::assert; use assert2::let_assert; + use isa::instruction::LabelledInstruction; + use isa::instruction::TypeHint; use proptest::prelude::*; use proptest_arbitrary_interop::arb; use test_strategy::proptest; - use twenty_first::prelude::tip5::Tip5; use twenty_first::prelude::*; use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; - use crate::instruction::LabelledInstruction; - use crate::instruction::TypeHint; + use crate::prelude::*; use super::*; @@ -625,21 +315,18 @@ mod tests { implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); // errors implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); // table things implements_auto_traits::(); @@ -703,19 +390,10 @@ mod tests { implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::>(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); @@ -820,7 +498,7 @@ mod tests { let source_code = triton_asm!(push 6 {&snippet_0} {&snippet_1} halt); let program = triton_program!({ &source_code }); - let public_output = program.run([].into(), [].into()).unwrap(); + let public_output = VM::run(&program, [].into(), [].into()).unwrap(); let expected_output = bfe_vec![9, 8, 7, 6]; assert_eq!(expected_output, public_output); @@ -831,7 +509,7 @@ mod tests { let push_25 = triton_asm![push 0; 25]; let pop_25 = triton_asm![pop 5; 5]; let program = triton_program! { push 1 { &push_25 } { &pop_25 } assert halt }; - program.run([].into(), [].into()).unwrap(); + VM::run(&program, [].into(), [].into()).unwrap(); } #[test] diff --git a/triton-vm/src/prelude.rs b/triton-vm/src/prelude.rs index de5164e35..7f4a583cf 100644 --- a/triton-vm/src/prelude.rs +++ b/triton-vm/src/prelude.rs @@ -19,15 +19,18 @@ pub use twenty_first::prelude::Digest; pub use twenty_first::prelude::Tip5; pub use twenty_first::prelude::XFieldElement; -pub use crate::error::InstructionError; -pub use crate::instruction::LabelledInstruction; -pub use crate::program::NonDeterminism; -pub use crate::program::Program; -pub use crate::program::PublicInput; +pub use isa as triton_isa; +pub use isa::error::InstructionError; +pub use isa::instruction::LabelledInstruction; +pub use isa::program::Program; +pub use isa::triton_asm; +pub use isa::triton_instr; +pub use isa::triton_program; + pub use crate::proof::Claim; pub use crate::proof::Proof; pub use crate::stark::Stark; -pub use crate::triton_asm; -pub use crate::triton_instr; -pub use crate::triton_program; +pub use crate::vm::NonDeterminism; +pub use crate::vm::PublicInput; pub use crate::vm::VMState; +pub use crate::vm::VM; diff --git a/triton-vm/src/program.rs b/triton-vm/src/program.rs deleted file mode 100644 index a50734bac..000000000 --- a/triton-vm/src/program.rs +++ /dev/null @@ -1,1119 +0,0 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::HashSet; -use std::fmt::Display; -use std::fmt::Formatter; -use std::fmt::Result as FmtResult; -use std::hash::Hash; -use std::io::Cursor; -use std::ops::Add; -use std::ops::AddAssign; -use std::ops::Sub; - -use arbitrary::Arbitrary; -use get_size::GetSize; -use itertools::Itertools; -use serde::Deserialize; -use serde::Serialize; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::error::ProgramDecodingError; -use crate::error::VMError; -use crate::instruction::AnInstruction; -use crate::instruction::Instruction; -use crate::instruction::LabelledInstruction; -use crate::instruction::TypeHint; -use crate::parser::parse; -use crate::parser::to_labelled_instructions; -use crate::parser::ParseError; -use crate::profiler::profiler; -use crate::table::hash_table::PERMUTATION_TRACE_LENGTH; -use crate::table::u32_table::U32TableEntry; -use crate::vm::CoProcessorCall; -use crate::vm::VMState; - -type Result = std::result::Result; - -/// A program for Triton VM. -/// It can be -/// [`run`](Program::run), -/// [`profiled`](Program::profile), -/// and its execution can be [`traced`](Program::trace_execution). -/// -/// [`Hashing`](Program::hash) a program yields a [`Digest`] that can be used -/// in a [`Claim`](crate::Claim), _i.e._, is consistent with Triton VM's -/// [program attestation]. -/// -/// A program may contain debug information, such as label names and breakpoints. -/// Access this information through methods [`label_for_address()`][label_for_address] and -/// [`is_breakpoint()`][is_breakpoint]. Some operations, most notably -/// [BField-encoding](BFieldCodec::encode), discard this debug information. -/// -/// [program attestation]: https://triton-vm.org/spec/program-attestation.html -/// [label_for_address]: Program::label_for_address -/// [is_breakpoint]: Program::is_breakpoint -#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)] -pub struct Program { - pub instructions: Vec, - address_to_label: HashMap, - breakpoints: Vec, - type_hints: HashMap>, -} - -impl Display for Program { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - for instruction in self.labelled_instructions() { - writeln!(f, "{instruction}")?; - } - Ok(()) - } -} - -impl PartialEq for Program { - fn eq(&self, other: &Program) -> bool { - self.instructions.eq(&other.instructions) - } -} - -impl BFieldCodec for Program { - type Error = ProgramDecodingError; - - fn decode(sequence: &[BFieldElement]) -> std::result::Result, Self::Error> { - if sequence.is_empty() { - return Err(Self::Error::EmptySequence); - } - let program_length = sequence[0].value() as usize; - let sequence = &sequence[1..]; - if sequence.len() < program_length { - return Err(Self::Error::SequenceTooShort); - } - if sequence.len() > program_length { - return Err(Self::Error::SequenceTooLong); - } - - // instantiating with claimed capacity is a potential DOS vector - let mut instructions = vec![]; - let mut read_idx = 0; - while read_idx < program_length { - let opcode = sequence[read_idx]; - let mut instruction = Instruction::try_from(opcode) - .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?; - let instruction_has_arg = instruction.arg().is_some(); - if instruction_has_arg && instructions.len() + instruction.size() > program_length { - return Err(Self::Error::MissingArgument(read_idx, instruction)); - } - if instruction_has_arg { - let arg = sequence[read_idx + 1]; - instruction = instruction - .change_arg(arg) - .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?; - } - - instructions.extend(vec![instruction; instruction.size()]); - read_idx += instruction.size(); - } - - if read_idx != program_length { - return Err(Self::Error::LengthMismatch); - } - if instructions.len() != program_length { - return Err(Self::Error::LengthMismatch); - } - - Ok(Box::new(Program { - instructions, - address_to_label: HashMap::default(), - breakpoints: vec![], - type_hints: HashMap::default(), - })) - } - - fn encode(&self) -> Vec { - let mut sequence = Vec::with_capacity(self.len_bwords() + 1); - sequence.push(bfe!(self.len_bwords() as u64)); - sequence.extend(self.to_bwords()); - sequence - } - - fn static_length() -> Option { - None - } -} - -impl<'a> Arbitrary<'a> for Program { - fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let contains_label = |labelled_instructions: &[_], maybe_label: &_| { - let LabelledInstruction::Label(label) = maybe_label else { - return false; - }; - labelled_instructions - .iter() - .any(|labelled_instruction| match labelled_instruction { - LabelledInstruction::Label(l) => l == label, - _ => false, - }) - }; - - let mut labelled_instructions = vec![]; - for _ in 0..u.arbitrary_len::()? { - let labelled_instruction = u.arbitrary()?; - if contains_label(&labelled_instructions, &labelled_instruction) { - continue; - } - labelled_instructions.push(labelled_instruction); - } - - let call_targets = labelled_instructions - .iter() - .filter_map(|instruction| match instruction { - LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target), - _ => None, - }) - .unique(); - let additional_labels = call_targets - .map(|target| LabelledInstruction::Label(target.clone())) - .collect_vec(); - - for additional_label in additional_labels { - if contains_label(&labelled_instructions, &additional_label) { - continue; - } - let insertion_index = u.choose_index(labelled_instructions.len() + 1)?; - labelled_instructions.insert(insertion_index, additional_label); - } - - Ok(Program::new(&labelled_instructions)) - } -} - -/// An `InstructionIter` loops the instructions of a `Program` by skipping duplicate placeholders. -#[derive(Debug, Default, Clone, Eq, PartialEq)] -pub struct InstructionIter { - cursor: Cursor>, -} - -impl Iterator for InstructionIter { - type Item = Instruction; - - fn next(&mut self) -> Option { - let pos = self.cursor.position() as usize; - let instructions = self.cursor.get_ref(); - let instruction = *instructions.get(pos)?; - self.cursor.set_position((pos + instruction.size()) as u64); - - Some(instruction) - } -} - -impl IntoIterator for Program { - type Item = Instruction; - - type IntoIter = InstructionIter; - - fn into_iter(self) -> Self::IntoIter { - let cursor = Cursor::new(self.instructions); - InstructionIter { cursor } - } -} - -impl Program { - pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self { - let label_to_address = Self::build_label_to_address_map(labelled_instructions); - let instructions = - Self::turn_labels_into_addresses(labelled_instructions, &label_to_address); - let address_to_label = Self::flip_map(label_to_address); - let (breakpoints, type_hints) = Self::extract_debug_information(labelled_instructions); - - assert_eq!(instructions.len(), breakpoints.len()); - Program { - instructions, - address_to_label, - breakpoints, - type_hints, - } - } - - fn build_label_to_address_map(program: &[LabelledInstruction]) -> HashMap { - let mut label_map = HashMap::new(); - let mut instruction_pointer = 0; - - for labelled_instruction in program { - if let LabelledInstruction::Instruction(instruction) = labelled_instruction { - instruction_pointer += instruction.size() as u64; - continue; - } - - let LabelledInstruction::Label(label) = labelled_instruction else { - continue; - }; - let Entry::Vacant(new_label_map_entry) = label_map.entry(label.clone()) else { - panic!("Duplicate label: {label}"); - }; - new_label_map_entry.insert(instruction_pointer); - } - - label_map - } - - fn turn_labels_into_addresses( - labelled_instructions: &[LabelledInstruction], - label_to_address: &HashMap, - ) -> Vec { - labelled_instructions - .iter() - .filter_map(|inst| Self::turn_label_to_address_for_instruction(inst, label_to_address)) - .flat_map(|inst| vec![inst; inst.size()]) - .collect() - } - - fn turn_label_to_address_for_instruction( - labelled_instruction: &LabelledInstruction, - label_map: &HashMap, - ) -> Option { - let LabelledInstruction::Instruction(instruction) = labelled_instruction else { - return None; - }; - - let instruction_with_absolute_address = - instruction.map_call_address(|label| Self::address_for_label(label, label_map)); - Some(instruction_with_absolute_address) - } - - fn address_for_label(label: &str, label_map: &HashMap) -> BFieldElement { - let maybe_address = label_map.get(label).map(|&a| bfe!(a)); - maybe_address.unwrap_or_else(|| panic!("Label not found: {label}")) - } - - fn flip_map(map: HashMap) -> HashMap { - map.into_iter().map(|(key, value)| (value, key)).collect() - } - - fn extract_debug_information( - labelled_instructions: &[LabelledInstruction], - ) -> (Vec, HashMap>) { - let mut breakpoints = vec![]; - let mut type_hints = HashMap::<_, Vec<_>>::new(); - let mut break_before_next_instruction = false; - - let mut address = 0; - for instruction in labelled_instructions { - match instruction { - LabelledInstruction::Instruction(instruction) => { - breakpoints.extend(vec![break_before_next_instruction; instruction.size()]); - break_before_next_instruction = false; - address += instruction.size() as u64; - } - LabelledInstruction::Label(_) => (), - LabelledInstruction::Breakpoint => break_before_next_instruction = true, - LabelledInstruction::TypeHint(type_hint) => match type_hints.entry(address) { - Entry::Occupied(mut entry) => entry.get_mut().push(type_hint.clone()), - Entry::Vacant(entry) => _ = entry.insert(vec![type_hint.clone()]), - }, - } - } - - (breakpoints, type_hints) - } - - /// Create a `Program` by parsing source code. - pub fn from_code(code: &str) -> std::result::Result { - parse(code) - .map(|tokens| to_labelled_instructions(&tokens)) - .map(|instructions| Program::new(&instructions)) - } - - pub fn labelled_instructions(&self) -> Vec { - let call_targets = self.call_targets(); - let instructions_with_labels = self.instructions.iter().map(|instruction| { - instruction.map_call_address(|&address| self.label_for_address(address.value())) - }); - - let mut labelled_instructions = vec![]; - let mut address = 0; - let mut instruction_stream = instructions_with_labels.into_iter(); - while let Some(instruction) = instruction_stream.next() { - let instruction_size = instruction.size() as u64; - if call_targets.contains(&address) { - let label = self.label_for_address(address); - let label = LabelledInstruction::Label(label); - labelled_instructions.push(label); - } - for type_hint in self.type_hints_at(address) { - labelled_instructions.push(LabelledInstruction::TypeHint(type_hint)); - } - if self.is_breakpoint(address) { - labelled_instructions.push(LabelledInstruction::Breakpoint); - } - labelled_instructions.push(LabelledInstruction::Instruction(instruction)); - - for _ in 1..instruction_size { - instruction_stream.next(); - } - address += instruction_size; - } - - let leftover_labels = self - .address_to_label - .iter() - .filter(|(&labels_address, _)| labels_address >= address) - .sorted(); - for (_, label) in leftover_labels { - labelled_instructions.push(LabelledInstruction::Label(label.clone())); - } - - labelled_instructions - } - - fn call_targets(&self) -> HashSet { - self.instructions - .iter() - .filter_map(|instruction| match instruction { - Instruction::Call(address) => Some(address.value()), - _ => None, - }) - .collect() - } - - pub fn is_breakpoint(&self, address: u64) -> bool { - let address: usize = address.try_into().unwrap(); - self.breakpoints.get(address).unwrap_or(&false).to_owned() - } - - pub fn type_hints_at(&self, address: u64) -> Vec { - self.type_hints.get(&address).cloned().unwrap_or_default() - } - - /// Turn the program into a sequence of `BFieldElement`s. Each instruction is encoded as its - /// opcode, followed by its argument (if any). - /// - /// **Note**: This is _almost_ (but not quite!) equivalent to [encoding](BFieldCodec::encode) - /// the program. For that, use [`encode()`](Self::encode()) instead. - pub fn to_bwords(&self) -> Vec { - self.clone() - .into_iter() - .flat_map(|instruction| { - let opcode = instruction.opcode_b(); - if let Some(arg) = instruction.arg() { - vec![opcode, arg] - } else { - vec![opcode] - } - }) - .collect() - } - - /// The total length of the program as `BFieldElement`s. Double-word instructions contribute - /// two `BFieldElement`s. - pub fn len_bwords(&self) -> usize { - self.instructions.len() - } - - pub fn is_empty(&self) -> bool { - self.instructions.is_empty() - } - - /// Produces the program's canonical hash digest. Uses [`Tip5`], the - /// canonical hash function for Triton VM. - pub fn hash(&self) -> Digest { - // not encoded using `BFieldCodec` because that would prepend the length - Tip5::hash_varlen(&self.to_bwords()) - } - - /// Run Triton VM on the [`Program`] with the given public input and non-determinism. - /// If an error is encountered, the returned [`VMError`] contains the [`VMState`] at the point - /// of execution failure. - /// - /// See also [`trace_execution`][trace_execution] and [`profile`][profile]. - /// - /// [trace_execution]: Self::trace_execution - /// [profile]: Self::profile - pub fn run( - &self, - public_input: PublicInput, - non_determinism: NonDeterminism, - ) -> Result> { - let mut state = VMState::new(self, public_input, non_determinism); - if let Err(err) = state.run() { - return Err(VMError::new(err, state)); - } - Ok(state.public_output) - } - - /// Trace the execution of a [`Program`]. That is, [`run`][run] the [`Program`] and additionally - /// record that part of every encountered state that is necessary for proving correct execution. - /// If execution succeeds, returns - /// 1. an [`AlgebraicExecutionTrace`], and - /// 1. the output of the program. - /// - /// See also [`run`][run] and [`profile`][profile]. - /// - /// [run]: Self::run - /// [profile]: Self::profile - pub fn trace_execution( - &self, - public_input: PublicInput, - non_determinism: NonDeterminism, - ) -> Result<(AlgebraicExecutionTrace, Vec)> { - profiler!(start "trace execution" ("gen")); - let state = VMState::new(self, public_input, non_determinism); - let (aet, terminal_state) = self.trace_execution_of_state(state)?; - profiler!(stop "trace execution"); - Ok((aet, terminal_state.public_output)) - } - - /// Trace the execution of a [`Program`] from a given [`VMState`]. Consider - /// using [`trace_execution`][Self::trace_execution], unless you know this is - /// what you want. - /// - /// Returns the [`AlgebraicExecutionTrace`] and the terminal [`VMState`] if - /// execution succeeds. - /// - /// # Panics - /// - /// - if the given [`VMState`] is not about to `self` - /// - if the given [`VMState`] is incorrectly initialized - pub fn trace_execution_of_state( - &self, - mut state: VMState, - ) -> Result<(AlgebraicExecutionTrace, VMState)> { - let mut aet = AlgebraicExecutionTrace::new(self.clone()); - assert_eq!(self.instructions, state.program); - assert_eq!(self.len_bwords(), aet.instruction_multiplicities.len()); - - while !state.halting { - if let Err(err) = aet.record_state(&state) { - return Err(VMError::new(err, state)); - }; - let co_processor_calls = match state.step() { - Ok(calls) => calls, - Err(err) => return Err(VMError::new(err, state)), - }; - for call in co_processor_calls { - aet.record_co_processor_call(call); - } - } - - Ok((aet, state)) - } - - /// Run Triton VM with the given public and secret input, recording the - /// influence of a callable block of instructions on the - /// [`AlgebraicExecutionTrace`]. For example, this can be used to identify the - /// number of clock cycles spent in some block of instructions, or how many rows - /// it contributes to the U32 Table. - /// - /// See also [`run`][run] and [`trace_execution`][trace_execution]. - /// - /// [run]: Self::run - /// [trace_execution]: Self::trace_execution - pub fn profile( - &self, - public_input: PublicInput, - non_determinism: NonDeterminism, - ) -> Result<(Vec, ExecutionTraceProfile)> { - let mut profiler = ExecutionTraceProfiler::new(self.instructions.len()); - let mut state = VMState::new(self, public_input, non_determinism); - let mut previous_jump_stack_len = state.jump_stack.len(); - while !state.halting { - if let Ok(Instruction::Call(address)) = state.current_instruction() { - let label = self.label_for_address(address.value()); - profiler.enter_span(label); - } - - match state.step() { - Ok(calls) => profiler.handle_co_processor_calls(calls), - Err(err) => return Err(VMError::new(err, state)), - }; - - if state.jump_stack.len() < previous_jump_stack_len { - profiler.exit_span(); - } - previous_jump_stack_len = state.jump_stack.len(); - } - - Ok((state.public_output, profiler.finish())) - } - - /// The label for the given address, or a deterministic, unique substitute if no label is found. - pub fn label_for_address(&self, address: u64) -> String { - // Uniqueness of the label is relevant for printing and subsequent parsing: - // Parsing fails on duplicate labels. - self.address_to_label - .get(&address) - .cloned() - .unwrap_or_else(|| format!("address_{address}")) - } -} - -#[derive(Debug, Default, Clone, Eq, PartialEq, Arbitrary)] -struct ExecutionTraceProfiler { - call_stack: Vec, - profile: Vec, - table_heights: VMTableHeights, - u32_table_entries: HashSet, -} - -/// A single line in a [profile report](ExecutionTraceProfile) for profiling -/// [Triton](crate) programs. -#[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Arbitrary)] -pub struct ProfileLine { - pub label: String, - pub call_depth: usize, - - /// Table heights at the start of this span, _i.e._, right before the corresponding - /// [`call`](Instruction::Call) instruction was executed. - pub table_heights_start: VMTableHeights, - - table_heights_stop: VMTableHeights, -} - -/// A report for the completed execution of a [Triton](crate) program. -/// -/// Offers a human-readable [`Display`] implementation and can be processed -/// programmatically. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Arbitrary)] -pub struct ExecutionTraceProfile { - pub total: VMTableHeights, - pub profile: Vec, -} - -/// The heights of various [tables](AlgebraicExecutionTrace) relevant for -/// proving the correct execution in [Triton VM](crate). -#[non_exhaustive] -#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] -pub struct VMTableHeights { - pub processor: u32, - pub op_stack: u32, - pub ram: u32, - pub hash: u32, - pub u32: u32, -} - -impl ExecutionTraceProfiler { - fn new(num_instructions: usize) -> Self { - Self { - call_stack: vec![], - profile: vec![], - table_heights: VMTableHeights::new(num_instructions), - u32_table_entries: HashSet::default(), - } - } - - fn enter_span(&mut self, label: impl Into) { - let call_stack_len = self.call_stack.len(); - let line_number = self.profile.len(); - - let profile_line = ProfileLine { - label: label.into(), - call_depth: call_stack_len, - table_heights_start: self.table_heights, - table_heights_stop: VMTableHeights::default(), - }; - - self.profile.push(profile_line); - self.call_stack.push(line_number); - } - - fn exit_span(&mut self) { - if let Some(line_number) = self.call_stack.pop() { - self.profile[line_number].table_heights_stop = self.table_heights; - }; - } - - fn handle_co_processor_calls(&mut self, calls: Vec) { - self.table_heights.processor += 1; - for call in calls { - match call { - CoProcessorCall::SpongeStateReset => self.table_heights.hash += 1, - CoProcessorCall::Tip5Trace(_, trace) => { - self.table_heights.hash += u32::try_from(trace.len()).unwrap(); - } - CoProcessorCall::U32Call(c) => { - self.u32_table_entries.insert(c); - let contribution = U32TableEntry::table_height_contribution; - self.table_heights.u32 = self.u32_table_entries.iter().map(contribution).sum(); - } - CoProcessorCall::OpStackCall(_) => self.table_heights.op_stack += 1, - CoProcessorCall::RamCall(_) => self.table_heights.ram += 1, - } - } - } - - fn finish(mut self) -> ExecutionTraceProfile { - for &line_number in &self.call_stack { - self.profile[line_number].table_heights_stop = self.table_heights; - } - - ExecutionTraceProfile { - total: self.table_heights, - profile: self.profile, - } - } -} - -impl VMTableHeights { - fn new(num_instructions: usize) -> Self { - let padded_program_len = (num_instructions + 1).next_multiple_of(Tip5::RATE); - let num_absorbs = padded_program_len / Tip5::RATE; - let initial_hash_table_len = num_absorbs * PERMUTATION_TRACE_LENGTH; - - Self { - hash: initial_hash_table_len.try_into().unwrap(), - ..Default::default() - } - } -} - -impl Sub for VMTableHeights { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self { - processor: self.processor.saturating_sub(rhs.processor), - op_stack: self.op_stack.saturating_sub(rhs.op_stack), - ram: self.ram.saturating_sub(rhs.ram), - hash: self.hash.saturating_sub(rhs.hash), - u32: self.u32.saturating_sub(rhs.u32), - } - } -} - -impl Add for VMTableHeights { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self { - processor: self.processor + rhs.processor, - op_stack: self.op_stack + rhs.op_stack, - ram: self.ram + rhs.ram, - hash: self.hash + rhs.hash, - u32: self.u32 + rhs.u32, - } - } -} - -impl AddAssign for VMTableHeights { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} - -impl ProfileLine { - fn table_height_contributions(&self) -> VMTableHeights { - self.table_heights_stop - self.table_heights_start - } -} - -impl Display for ProfileLine { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - let indentation = " ".repeat(self.call_depth); - let label = &self.label; - let cycle_count = self.table_height_contributions().processor; - write!(f, "{indentation}{label}: {cycle_count}") - } -} - -impl Display for ExecutionTraceProfile { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - struct AggregateLine { - label: String, - call_depth: usize, - table_heights: VMTableHeights, - } - - const COL_WIDTH: usize = 20; - - let mut aggregated: Vec = vec![]; - for line in &self.profile { - if let Some(agg) = aggregated - .iter_mut() - .find(|agg| agg.label == line.label && agg.call_depth == line.call_depth) - { - agg.table_heights += line.table_height_contributions(); - } else { - aggregated.push(AggregateLine { - label: line.label.clone(), - call_depth: line.call_depth, - table_heights: line.table_height_contributions(), - }); - } - } - aggregated.push(AggregateLine { - label: "Total".to_string(), - call_depth: 0, - table_heights: self.total, - }); - - let label = |line: &AggregateLine| "··".repeat(line.call_depth) + &line.label; - let label_len = |line| label(line).len(); - - let max_label_len = aggregated.iter().map(label_len).max(); - let max_label_len = max_label_len.unwrap_or_default().max(COL_WIDTH); - - let [soubroutine, processor, op_stack, ram, hash, u32_title] = - ["Subroutine", "Processor", "Op Stack", "RAM", "Hash", "U32"]; - - write!(f, "| {soubroutine:COL_WIDTH$} ")?; - write!(f, "| {op_stack:>COL_WIDTH$} ")?; - write!(f, "| {ram:>COL_WIDTH$} ")?; - write!(f, "| {hash:>COL_WIDTH$} ")?; - write!(f, "| {u32_title:>COL_WIDTH$} ")?; - writeln!(f, "|")?; - - let dash = "-"; - write!(f, "|:{dash:-COL_WIDTH$}:")?; - write!(f, "|-{dash:->COL_WIDTH$}:")?; - write!(f, "|-{dash:->COL_WIDTH$}:")?; - write!(f, "|-{dash:->COL_WIDTH$}:")?; - write!(f, "|-{dash:->COL_WIDTH$}:")?; - writeln!(f, "|")?; - - for line in &aggregated { - let rel_precision = 1; - let rel_width = 3 + 1 + rel_precision; // eg '100.0' - let abs_width = COL_WIDTH - rel_width - 4; // ' (' and '%)' - - let label = label(line); - let proc_abs = line.table_heights.processor; - let proc_rel = 100.0 * f64::from(proc_abs) / f64::from(self.total.processor); - let proc_rel = format!("{proc_rel:.rel_precision$}"); - let stack_abs = line.table_heights.op_stack; - let stack_rel = 100.0 * f64::from(stack_abs) / f64::from(self.total.op_stack); - let stack_rel = format!("{stack_rel:.rel_precision$}"); - let ram_abs = line.table_heights.ram; - let ram_rel = 100.0 * f64::from(ram_abs) / f64::from(self.total.ram); - let ram_rel = format!("{ram_rel:.rel_precision$}"); - let hash_abs = line.table_heights.hash; - let hash_rel = 100.0 * f64::from(hash_abs) / f64::from(self.total.hash); - let hash_rel = format!("{hash_rel:.rel_precision$}"); - let u32_abs = line.table_heights.u32; - let u32_rel = 100.0 * f64::from(u32_abs) / f64::from(self.total.u32); - let u32_rel = format!("{u32_rel:.rel_precision$}"); - - write!(f, "| {label:abs_width$} ({proc_rel:>rel_width$}%) ")?; - write!(f, "| {stack_abs:>abs_width$} ({stack_rel:>rel_width$}%) ")?; - write!(f, "| {ram_abs:>abs_width$} ({ram_rel:>rel_width$}%) ")?; - write!(f, "| {hash_abs:>abs_width$} ({hash_rel:>rel_width$}%) ")?; - write!(f, "| {u32_abs:>abs_width$} ({u32_rel:>rel_width$}%) ")?; - writeln!(f, "|")?; - } - - Ok(()) - } -} - -#[derive(Debug, Default, Clone, Eq, PartialEq, BFieldCodec, Arbitrary)] -pub struct PublicInput { - pub individual_tokens: Vec, -} - -impl From> for PublicInput { - fn from(individual_tokens: Vec) -> Self { - Self::new(individual_tokens) - } -} - -impl From<&Vec> for PublicInput { - fn from(tokens: &Vec) -> Self { - Self::new(tokens.to_owned()) - } -} - -impl From<[BFieldElement; N]> for PublicInput { - fn from(tokens: [BFieldElement; N]) -> Self { - Self::new(tokens.to_vec()) - } -} - -impl From<&[BFieldElement]> for PublicInput { - fn from(tokens: &[BFieldElement]) -> Self { - Self::new(tokens.to_vec()) - } -} - -impl PublicInput { - pub fn new(individual_tokens: Vec) -> Self { - Self { individual_tokens } - } -} - -/// All sources of non-determinism for a program. This includes elements that -/// can be read using instruction `divine`, digests that can be read using -/// instruction `merkle_step`, and an initial state of random-access memory. -#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary)] -pub struct NonDeterminism { - pub individual_tokens: Vec, - pub digests: Vec, - pub ram: HashMap, -} - -impl From> for NonDeterminism { - fn from(tokens: Vec) -> Self { - Self::new(tokens) - } -} - -impl From<&Vec> for NonDeterminism { - fn from(tokens: &Vec) -> Self { - Self::new(tokens.to_owned()) - } -} - -impl From<[BFieldElement; N]> for NonDeterminism { - fn from(tokens: [BFieldElement; N]) -> Self { - Self::new(tokens.to_vec()) - } -} - -impl From<&[BFieldElement]> for NonDeterminism { - fn from(tokens: &[BFieldElement]) -> Self { - Self::new(tokens.to_vec()) - } -} - -impl NonDeterminism { - pub fn new>>(individual_tokens: V) -> Self { - Self { - individual_tokens: individual_tokens.into(), - digests: vec![], - ram: HashMap::new(), - } - } - - #[must_use] - pub fn with_digests>>(mut self, digests: V) -> Self { - self.digests = digests.into(); - self - } - - #[must_use] - pub fn with_ram>>(mut self, ram: H) -> Self { - self.ram = ram.into(); - self - } -} - -#[cfg(test)] -mod tests { - use assert2::assert; - use assert2::let_assert; - use proptest::prelude::*; - use proptest_arbitrary_interop::arb; - use rand::thread_rng; - use rand::Rng; - use test_strategy::proptest; - - use crate::error::InstructionError; - use crate::example_programs::CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS; - use crate::table::master_table::TableId; - use crate::triton_program; - - use super::*; - - #[proptest] - fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) { - let encoding = program.encode(); - let decoding = *Program::decode(&encoding).unwrap(); - prop_assert_eq!(program, decoding); - } - - #[test] - fn decode_program_with_missing_argument_as_last_instruction() { - let program = triton_program!(push 3 push 3 eq assert push 3); - let program_length = program.len_bwords() as u64; - let encoded = program.encode(); - - let mut encoded = encoded[0..encoded.len() - 1].to_vec(); - encoded[0] = bfe!(program_length - 1); - - let_assert!(Err(err) = Program::decode(&encoded)); - let_assert!(ProgramDecodingError::MissingArgument(6, _) = err); - } - - #[test] - fn decode_program_with_shorter_than_indicated_sequence() { - let program = triton_program!(nop nop hash push 0 skiz end: halt call end); - let mut encoded = program.encode(); - encoded[0] += bfe!(1); - let_assert!(Err(err) = Program::decode(&encoded)); - let_assert!(ProgramDecodingError::SequenceTooShort = err); - } - - #[test] - fn decode_program_with_longer_than_indicated_sequence() { - let program = triton_program!(nop nop hash push 0 skiz end: halt call end); - let mut encoded = program.encode(); - encoded[0] -= bfe!(1); - let_assert!(Err(err) = Program::decode(&encoded)); - let_assert!(ProgramDecodingError::SequenceTooLong = err); - } - - #[test] - fn decode_program_from_empty_sequence() { - let encoded = vec![]; - let_assert!(Err(err) = Program::decode(&encoded)); - let_assert!(ProgramDecodingError::EmptySequence = err); - } - - #[test] - fn hash_simple_program() { - let program = triton_program!(halt); - let digest = program.hash(); - - let expected_digest = bfe_array![ - 0x4338_de79_520b_3949_u64, - 0xe6a2_129b_2885_0dc9_u64, - 0xfd3c_d098_6a86_0450_u64, - 0x69fd_ba91_0ceb_a7bc_u64, - 0x7e5b_118c_9594_c062_u64, - ]; - let expected_digest = Digest::new(expected_digest); - - assert!(expected_digest == digest); - } - - #[test] - fn empty_program_is_empty() { - let program = triton_program!(); - assert!(program.is_empty()); - } - - #[proptest] - fn from_various_types_to_public_input(#[strategy(arb())] tokens: Vec) { - let public_input = PublicInput::new(tokens.clone()); - - assert!(public_input == tokens.clone().into()); - assert!(public_input == (&tokens).into()); - assert!(public_input == tokens[..].into()); - assert!(public_input == (&tokens[..]).into()); - - assert!(PublicInput::new(vec![]) == [].into()); - } - - #[proptest] - fn from_various_types_to_non_determinism(#[strategy(arb())] tokens: Vec) { - let non_determinism = NonDeterminism::new(tokens.clone()); - - assert!(non_determinism == tokens.clone().into()); - assert!(non_determinism == tokens[..].into()); - assert!(non_determinism == (&tokens[..]).into()); - - assert!(NonDeterminism::new(vec![]) == [].into()); - } - - #[test] - fn create_program_from_code() { - let element_3 = thread_rng().gen_range(0_u64..BFieldElement::P); - let element_2 = 1337_usize; - let element_1 = "17"; - let element_0 = bfe!(0); - let instruction_push = Instruction::Push(bfe!(42)); - let dup_arg = 1; - let label = "my_label".to_string(); - - let source_code = format!( - "push {element_3} push {element_2} push {element_1} push {element_0} - call {label} halt - {label}: - {instruction_push} - dup {dup_arg} - skiz - recurse - return" - ); - let program_from_code = Program::from_code(&source_code).unwrap(); - let program_from_macro = triton_program!({ source_code }); - assert!(program_from_code == program_from_macro); - } - - #[test] - fn parser_macro_with_interpolated_label_as_first_argument() { - let label = "my_label"; - let program = triton_program!( - {label}: push 1 assert halt - ); - program.run([].into(), [].into()).unwrap(); - } - - #[test] - fn profile_can_be_created_and_agrees_with_regular_vm_run() { - let program = CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS.clone(); - let (profile_output, profile) = program.profile([].into(), [].into()).unwrap(); - let mut vm_state = VMState::new(&program, [].into(), [].into()); - let_assert!(Ok(()) = vm_state.run()); - assert!(profile_output == vm_state.public_output); - assert!(profile.total.processor == vm_state.cycle_count); - - let_assert!(Ok((aet, trace_output)) = program.trace_execution([].into(), [].into())); - assert!(profile_output == trace_output); - let proc_height = u32::try_from(aet.height_of_table(TableId::Processor)).unwrap(); - assert!(proc_height == profile.total.processor); - - let op_stack_height = u32::try_from(aet.height_of_table(TableId::OpStack)).unwrap(); - assert!(op_stack_height == profile.total.op_stack); - - let ram_height = u32::try_from(aet.height_of_table(TableId::Ram)).unwrap(); - assert!(ram_height == profile.total.ram); - - let hash_height = u32::try_from(aet.height_of_table(TableId::Hash)).unwrap(); - assert!(hash_height == profile.total.hash); - - let u32_height = u32::try_from(aet.height_of_table(TableId::U32)).unwrap(); - assert!(u32_height == profile.total.u32); - - println!("{profile}"); - } - - #[test] - fn program_with_too_many_returns_crashes_vm_but_not_profiler() { - let program = triton_program! { - call foo return halt - foo: return - }; - let_assert!(Err(err) = program.profile([].into(), [].into())); - let_assert!(InstructionError::JumpStackIsEmpty = err.source); - } - - #[test] - fn breakpoints_propagate_to_debug_information_as_expected() { - let program = triton_program! { - break push 1 push 2 - break break break break - pop 2 hash halt - break // no effect - }; - - assert!(program.is_breakpoint(0)); - assert!(program.is_breakpoint(1)); - assert!(!program.is_breakpoint(2)); - assert!(!program.is_breakpoint(3)); - assert!(program.is_breakpoint(4)); - assert!(program.is_breakpoint(5)); - assert!(!program.is_breakpoint(6)); - assert!(!program.is_breakpoint(7)); - - // going beyond the length of the program must not break things - assert!(!program.is_breakpoint(8)); - assert!(!program.is_breakpoint(9)); - } - - #[test] - fn print_program_without_any_debug_information() { - let program = triton_program! { - call foo - call bar - call baz - halt - foo: nop nop return - bar: call baz return - baz: push 1 return - }; - let encoding = program.encode(); - let program = Program::decode(&encoding).unwrap(); - println!("{program}"); - } -} diff --git a/triton-vm/src/proof.rs b/triton-vm/src/proof.rs index 61195e131..a2c455bf3 100644 --- a/triton-vm/src/proof.rs +++ b/triton-vm/src/proof.rs @@ -1,12 +1,12 @@ use arbitrary::Arbitrary; use get_size::GetSize; +use isa::program::Program; use itertools::Itertools; use serde::Deserialize; use serde::Serialize; use twenty_first::prelude::*; use crate::error::ProofStreamError; -use crate::program::Program; use crate::proof_stream::ProofStream; /// Contains the necessary cryptographic information to verify a computation. diff --git a/triton-vm/src/shared_tests.rs b/triton-vm/src/shared_tests.rs index c1fe25686..11715d2ac 100644 --- a/triton-vm/src/shared_tests.rs +++ b/triton-vm/src/shared_tests.rs @@ -1,5 +1,6 @@ use assert2::assert; use assert2::let_assert; +use isa::program::Program; use num_traits::Zero; use proptest::collection::vec; use proptest::prelude::*; @@ -10,14 +11,10 @@ use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; use crate::error::VMError; use crate::fri::AuthenticationStructure; +use crate::prelude::*; use crate::profiler::profiler; -use crate::program::Program; -use crate::proof::Claim; use crate::proof_item::FriResponse; -use crate::stark::Stark; use crate::table::master_table::MasterBaseTable; -use crate::NonDeterminism; -use crate::PublicInput; pub(crate) const DEFAULT_LOG2_FRI_EXPANSION_FACTOR_FOR_TESTS: usize = 2; @@ -112,9 +109,8 @@ pub(crate) fn prove_and_verify( } = program_and_input; profiler!(start "Pre-flight"); - let (aet, public_output) = program - .trace_execution(public_input.clone(), non_determinism.clone()) - .unwrap(); + let (aet, public_output) = + VM::trace_execution(&program, public_input.clone(), non_determinism.clone()).unwrap(); let claim = Claim::about_program(&program) .with_input(public_input.individual_tokens.clone()) @@ -197,9 +193,8 @@ impl ProgramAndInput { self.non_determinism.clone() } - /// A thin wrapper around [`Program::run`]. + /// A thin wrapper around [`VM::run`]. pub fn run(&self) -> Result, VMError> { - self.program - .run(self.public_input(), self.non_determinism()) + VM::run(&self.program, self.public_input(), self.non_determinism()) } } diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 030415d93..52fc1de36 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1313,6 +1313,10 @@ pub(crate) mod tests { use assert2::assert; use assert2::check; use assert2::let_assert; + use constraint_builder::ConstraintCircuitBuilder; + use isa::error::OpStackError; + use isa::instruction::Instruction; + use isa::op_stack::OpStackElement; use itertools::izip; use num_traits::Zero; use proptest::collection::vec; @@ -1327,9 +1331,6 @@ pub(crate) mod tests { use crate::error::InstructionError; use crate::example_programs::*; - use crate::instruction::Instruction; - use crate::op_stack::OpStackElement; - use crate::program::NonDeterminism; use crate::shared_tests::*; use crate::table::cascade_table::ExtCascadeTable; use crate::table::challenges::ChallengeId::StandardInputIndeterminate; @@ -1360,8 +1361,9 @@ pub(crate) mod tests { use crate::table::u32_table::ExtU32Table; use crate::triton_program; use crate::vm::tests::*; + use crate::vm::NonDeterminism; + use crate::vm::VM; use crate::PublicInput; - use constraint_builder::ConstraintCircuitBuilder; use super::*; @@ -1374,9 +1376,8 @@ pub(crate) mod tests { non_determinism, } = program_and_input; - let (aet, stdout) = program - .trace_execution(public_input.clone(), non_determinism) - .unwrap(); + let (aet, stdout) = + VM::trace_execution(&program, public_input.clone(), non_determinism).unwrap(); let stark = low_security_stark(DEFAULT_LOG2_FRI_EXPANSION_FACTOR_FOR_TESTS); let claim = Claim::about_program(&aet.program) .with_input(public_input.individual_tokens) @@ -2074,7 +2075,7 @@ pub(crate) mod tests { xx_dot_step halt }; - let result = program.run(PublicInput::default(), NonDeterminism::default()); + let result = VM::run(&program, PublicInput::default(), NonDeterminism::default()); assert!(result.is_ok()); let program_and_input = ProgramAndInput::new(program); triton_constraints_evaluate_to_zero(program_and_input); @@ -2346,15 +2347,16 @@ pub(crate) mod tests { st0: BFieldElement, ) { let program = triton_program!(push {st0} log_2_floor halt); - let_assert!(Err(err) = program.run([].into(), [].into())); - let_assert!(InstructionError::FailedU32Conversion(element) = err.source); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); + let_assert!(InstructionError::OpStackError(err) = err.source); + let_assert!(OpStackError::FailedU32Conversion(element) = err); assert!(st0 == element); } #[test] fn negative_log_2_floor_of_0() { let program = triton_program!(push 0 log_2_floor halt); - let_assert!(Err(err) = program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&program, [].into(), [].into())); let_assert!(InstructionError::LogarithmOfZero = err.source); } @@ -2776,9 +2778,7 @@ pub(crate) mod tests { public_input, non_determinism, } = program_executing_every_instruction(); - let (aet, _) = program - .trace_execution(public_input, non_determinism) - .unwrap(); + let (aet, _) = VM::trace_execution(&program, public_input, non_determinism).unwrap(); let opcodes_of_all_executed_instructions = aet .processor_trace .column(ProcessorBaseTableColumn::CI.base_table_index()) diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 7e8411809..8ae834836 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -5,6 +5,11 @@ use constraint_builder::DualRowIndicator::*; use constraint_builder::InputIndicator; use constraint_builder::SingleRowIndicator; use constraint_builder::SingleRowIndicator::*; +use isa::instruction::AnInstruction::Hash; +use isa::instruction::AnInstruction::SpongeAbsorb; +use isa::instruction::AnInstruction::SpongeInit; +use isa::instruction::AnInstruction::SpongeSqueeze; +use isa::instruction::Instruction; use itertools::Itertools; use ndarray::*; use num_traits::Zero; @@ -21,11 +26,6 @@ use twenty_first::prelude::tip5::STATE_SIZE; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; -use crate::instruction::AnInstruction::Hash; -use crate::instruction::AnInstruction::SpongeAbsorb; -use crate::instruction::AnInstruction::SpongeInit; -use crate::instruction::AnInstruction::SpongeSqueeze; -use crate::instruction::Instruction; use crate::profiler::profiler; use crate::table::cascade_table::CascadeTable; use crate::table::challenges::ChallengeId::*; @@ -84,7 +84,7 @@ pub struct ExtHashTable; /// The empty program is not valid since any valid [`Program`][program] must execute /// instruction `halt`. /// -/// [program]: crate::program::Program +/// [program]: isa::program::Program /// [prog_hash]: HashTableMode::ProgramHashing /// [sponge]: HashTableMode::Sponge /// [hash]: type@HashTableMode::Hash @@ -93,7 +93,7 @@ pub struct ExtHashTable; pub enum HashTableMode { /// The mode in which the [`Program`][program] is hashed. This is part of program attestation. /// - /// [program]: crate::program::Program + /// [program]: isa::program::Program ProgramHashing, /// The mode in which Sponge instructions, _i.e._, `sponge_init`, @@ -554,7 +554,7 @@ impl ExtHashTable { Self::round_number_deselector(circuit_builder, &round_number, round_idx); round_constant_constraint_circuit = round_constant_constraint_circuit + round_deselector_circuit - * (round_constant_column_circuit.clone() - round_constant); + * (round_constant_column_circuit.clone() - round_constant); } constraints.push(round_constant_constraint_circuit); } @@ -709,7 +709,7 @@ impl ExtHashTable { StackWeight14, StackWeight15, ] - .map(challenge); + .map(challenge); let round_number_is_not_num_rounds = Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS); @@ -820,7 +820,7 @@ impl ExtHashTable { * running_evaluation_hash_input_updates + round_number_next.clone() * running_evaluation_hash_input_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_input_remains; + * running_evaluation_hash_input_remains; // If (and only if) the row number in the next row is NUM_ROUNDS and the current instruction // in the next row corresponds to `hash`, update running evaluation “hash digest.” @@ -843,7 +843,7 @@ impl ExtHashTable { * running_evaluation_hash_digest_updates + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_digest_remains; + * running_evaluation_hash_digest_remains; // The running evaluation for “Sponge” updates correctly. let compressed_row_next = state_weights[..RATE] @@ -893,7 +893,7 @@ impl ExtHashTable { * receive_chunk_running_evaluation_absorbs_chunk_of_instructions + round_number_next * receive_chunk_running_evaluation_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) - * receive_chunk_running_evaluation_remains; + * receive_chunk_running_evaluation_remains; let constraints = vec![ round_number_is_0_through_4_or_round_number_next_is_0, @@ -1013,7 +1013,7 @@ impl ExtHashTable { constraints, hash_function_round_correctly_performs_update.to_vec(), ] - .concat() + .concat() } fn indicate_column_index_in_base_row(column: HashBaseTableColumn) -> SingleRowIndicator { @@ -1112,7 +1112,7 @@ impl ExtHashTable { State4, State5, State6, State7, State8, State9, State10, State11, State12, State13, State14, State15, ] - .map(current_base_row); + .map(current_base_row); let state_part_after_power_map = { let mut exponentiation_accumulator = state_part_before_power_map.clone(); @@ -1156,7 +1156,7 @@ impl ExtHashTable { Constant8, Constant9, Constant10, Constant11, Constant12, Constant13, Constant14, Constant15, ] - .map(current_base_row); + .map(current_base_row); let state_after_round_constant_addition = state_after_matrix_multiplication .into_iter() @@ -1857,6 +1857,7 @@ pub(crate) mod tests { use crate::table::master_table::TableId; use crate::triton_asm; use crate::triton_program; + use crate::vm::VM; use super::*; @@ -1885,7 +1886,7 @@ pub(crate) mod tests { halt }; - let (aet, _) = program.trace_execution([].into(), [].into()).unwrap(); + let (aet, _) = VM::trace_execution(&program, [].into(), [].into()).unwrap(); dbg!(aet.height()); dbg!(aet.padded_height()); dbg!(aet.height_of_table(TableId::Hash)); diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs index 0547ab894..334a79ed9 100644 --- a/triton-vm/src/table/jump_stack_table.rs +++ b/triton-vm/src/table/jump_stack_table.rs @@ -5,6 +5,7 @@ use std::ops::Range; use constraint_builder::DualRowIndicator::*; use constraint_builder::SingleRowIndicator::*; use constraint_builder::*; +use isa::instruction::Instruction; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; @@ -14,7 +15,6 @@ use twenty_first::math::traits::FiniteField; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; -use crate::instruction::Instruction; use crate::ndarray_helper::contiguous_column_slices; use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 8b10d4ca9..74e00accc 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1278,6 +1278,8 @@ mod tests { use constraint_builder::DegreeLoweringInfo; use constraint_builder::DualRowIndicator; use constraint_builder::SingleRowIndicator; + use isa::instruction::Instruction; + use isa::instruction::InstructionBit; use master_table::cross_table_argument::GrandCrossTableArg; use ndarray::s; use ndarray::Array2; @@ -1297,9 +1299,6 @@ mod tests { use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; use crate::arithmetic_domain::ArithmeticDomain; - use crate::instruction::tests::InstructionBucket; - use crate::instruction::Instruction; - use crate::instruction::InstructionBit; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::*; use crate::table::degree_lowering_table::DegreeLoweringBaseTableColumn; @@ -1951,10 +1950,43 @@ mod tests { } fn generate_opcode_pressure_overview() -> SpecSnippet { + // todo: de-duplicate this from `triton_isa::instruction::tests` + #[derive(Debug, Copy, Clone, EnumCount, EnumIter, VariantNames)] + enum InstructionBucket { + HasArg, + ShrinksStack, + IsU32, + } + + impl InstructionBucket { + pub fn contains(self, instruction: Instruction) -> bool { + match self { + InstructionBucket::HasArg => instruction.arg().is_some(), + InstructionBucket::ShrinksStack => instruction.op_stack_size_influence() < 0, + InstructionBucket::IsU32 => instruction.is_u32_instruction(), + } + } + + pub fn flag(self) -> usize { + match self { + InstructionBucket::HasArg => 1, + InstructionBucket::ShrinksStack => 1 << 1, + InstructionBucket::IsU32 => 1 << 2, + } + } + } + + fn flag_set(instruction: Instruction) -> usize { + InstructionBucket::iter() + .map(|bucket| usize::from(bucket.contains(instruction)) * bucket.flag()) + .fold(0, |acc, bit_flag| acc | bit_flag) + } + // todo: end of duplication + const NUM_FLAG_SETS: usize = 1 << InstructionBucket::COUNT; let mut num_opcodes_per_flag_set = [0; NUM_FLAG_SETS]; for instruction in Instruction::iter() { - num_opcodes_per_flag_set[instruction.flag_set() as usize] += 1; + num_opcodes_per_flag_set[flag_set(instruction)] += 1; } let cell_width = InstructionBucket::VARIANTS diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index ac63e2259..7b40bd3b5 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -6,6 +6,8 @@ use arbitrary::Arbitrary; use constraint_builder::DualRowIndicator::*; use constraint_builder::SingleRowIndicator::*; use constraint_builder::*; +use isa::op_stack::OpStackElement; +use isa::op_stack::UnderflowIO; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; @@ -17,8 +19,6 @@ use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; use crate::ndarray_helper::contiguous_column_slices; use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::op_stack::OpStackElement; -use crate::op_stack::UnderflowIO; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; @@ -443,14 +443,13 @@ impl OpStackTable { #[cfg(test)] pub(crate) mod tests { use assert2::assert; + use isa::op_stack::OpStackElement; use itertools::Itertools; use proptest::collection::vec; use proptest::prelude::*; use proptest_arbitrary_interop::arb; use test_strategy::proptest; - use crate::op_stack::OpStackElement; - use super::*; #[proptest] diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index d7780a738..f58083fc6 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -4,6 +4,13 @@ use std::ops::Mul; use constraint_builder::DualRowIndicator::*; use constraint_builder::SingleRowIndicator::*; use constraint_builder::*; +use isa::instruction::AnInstruction::*; +use isa::instruction::Instruction; +use isa::instruction::InstructionBit; +use isa::instruction::ALL_INSTRUCTIONS; +use isa::op_stack::NumberOfWords; +use isa::op_stack::OpStackElement; +use isa::op_stack::NUM_OP_STACK_REGISTERS; use itertools::izip; use itertools::Itertools; use ndarray::parallel::prelude::*; @@ -18,15 +25,8 @@ use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; -use crate::instruction::AnInstruction::*; -use crate::instruction::Instruction; -use crate::instruction::InstructionBit; -use crate::instruction::ALL_INSTRUCTIONS; use crate::ndarray_helper::contiguous_column_slices; use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::op_stack::NumberOfWords; -use crate::op_stack::OpStackElement; -use crate::op_stack::NUM_OP_STACK_REGISTERS; use crate::profiler::profiler; use crate::table::challenges::ChallengeId; use crate::table::challenges::ChallengeId::*; @@ -3916,6 +3916,12 @@ pub(crate) mod tests { use std::collections::HashMap; use assert2::assert; + use isa::instruction::Instruction; + use isa::op_stack::NumberOfWords::*; + use isa::op_stack::OpStackElement; + use isa::program::Program; + use isa::triton_asm; + use isa::triton_program; use ndarray::Array2; use proptest::collection::vec; use proptest::prop_assert_eq; @@ -3926,18 +3932,13 @@ pub(crate) mod tests { use test_strategy::proptest; use crate::error::InstructionError::DivisionByZero; - use crate::instruction::Instruction; - use crate::op_stack::NumberOfWords::*; - use crate::op_stack::OpStackElement; use crate::prelude::PublicInput; - use crate::program::Program; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::master_tables_for_low_security_level; use crate::table::master_table::*; - use crate::triton_asm; - use crate::triton_program; use crate::vm::VMState; use crate::vm::NUM_HELPER_VARIABLE_REGISTERS; + use crate::vm::VM; use crate::NonDeterminism; use super::*; @@ -3946,7 +3947,7 @@ pub(crate) mod tests { #[test] fn print_simple_processor_table_row() { let program = triton_program!(push 2 sponge_init assert halt); - let err = program.run([].into(), [].into()).unwrap_err(); + let err = VM::run(&program, [].into(), [].into()).unwrap_err(); println!("\n{}", err.vm_state); } diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index 748850dc4..d5475061a 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -9,6 +9,7 @@ use constraint_builder::DualRowIndicator::*; use constraint_builder::InputIndicator; use constraint_builder::SingleRowIndicator; use constraint_builder::SingleRowIndicator::*; +use isa::instruction::Instruction; use ndarray::parallel::prelude::*; use ndarray::s; use ndarray::Array1; @@ -22,7 +23,6 @@ use strum::EnumCount; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; -use crate::instruction::Instruction; use crate::profiler::profiler; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index c9407a676..e2d349548 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -6,6 +6,12 @@ use std::fmt::Result as FmtResult; use std::ops::Range; use arbitrary::Arbitrary; +use isa::error::InstructionError; +use isa::instruction::AnInstruction::*; +use isa::instruction::Instruction; +use isa::op_stack::OpStackElement::*; +use isa::op_stack::*; +use isa::program::Program; use itertools::Itertools; use ndarray::Array1; use num_traits::ConstZero; @@ -17,13 +23,11 @@ use twenty_first::math::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; use twenty_first::util_types::algebraic_hasher::Domain; -use crate::error::InstructionError; -use crate::error::InstructionError::*; -use crate::instruction::AnInstruction::*; -use crate::instruction::Instruction; -use crate::op_stack::OpStackElement::*; -use crate::op_stack::*; -use crate::program::*; +use crate::aet::AlgebraicExecutionTrace; +use crate::error::VMError; +use crate::execution_trace_profiler::ExecutionTraceProfile; +use crate::execution_trace_profiler::ExecutionTraceProfiler; +use crate::profiler::profiler; use crate::table::hash_table::PermutationTrace; use crate::table::op_stack_table::OpStackTableEntry; use crate::table::processor_table; @@ -32,11 +36,15 @@ use crate::table::table_column::*; use crate::table::u32_table::U32TableEntry; use crate::vm::CoProcessorCall::*; -type Result = std::result::Result; +type VMResult = Result; +type InstructionResult = Result; /// The number of helper variable registers pub const NUM_HELPER_VARIABLE_REGISTERS: usize = 6; +#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary)] +pub struct VM; + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary)] pub struct VMState { /// The **program memory** stores the instructions (and their arguments) of the program @@ -104,6 +112,123 @@ pub enum CoProcessorCall { RamCall(RamTableCall), } +impl VM { + /// Run Triton VM on the [`Program`] with the given public input and non-determinism. + /// If an error is encountered, the returned [`VMError`] contains the [`VMState`] at the point + /// of execution failure. + /// + /// See also [`trace_execution`][trace_execution] and [`profile`][profile]. + /// + /// [trace_execution]: Self::trace_execution + /// [profile]: Self::profile + pub fn run( + program: &Program, + public_input: PublicInput, + non_determinism: NonDeterminism, + ) -> VMResult> { + let mut state = VMState::new(program, public_input, non_determinism); + if let Err(err) = state.run() { + return Err(VMError::new(err, state)); + } + Ok(state.public_output) + } + + /// Trace the execution of a [`Program`]. That is, [`run`][run] the [`Program`] and additionally + /// record that part of every encountered state that is necessary for proving correct execution. + /// If execution succeeds, returns + /// 1. an [`AlgebraicExecutionTrace`], and + /// 1. the output of the program. + /// + /// See also [`run`][run] and [`profile`][profile]. + /// + /// [run]: Self::run + /// [profile]: Self::profile + pub fn trace_execution( + program: &Program, + public_input: PublicInput, + non_determinism: NonDeterminism, + ) -> VMResult<(AlgebraicExecutionTrace, Vec)> { + profiler!(start "trace execution" ("gen")); + let state = VMState::new(program, public_input, non_determinism); + let (aet, terminal_state) = Self::trace_execution_of_state(program, state)?; + profiler!(stop "trace execution"); + Ok((aet, terminal_state.public_output)) + } + + /// Trace the execution of a [`Program`] from a given [`VMState`]. Consider + /// using [`trace_execution`][Self::trace_execution], unless you know this is + /// what you want. + /// + /// Returns the [`AlgebraicExecutionTrace`] and the terminal [`VMState`] if + /// execution succeeds. + /// + /// # Panics + /// + /// - if the given [`VMState`] is not about to `self` + /// - if the given [`VMState`] is incorrectly initialized + pub fn trace_execution_of_state( + program: &Program, + mut state: VMState, + ) -> VMResult<(AlgebraicExecutionTrace, VMState)> { + let mut aet = AlgebraicExecutionTrace::new(program.clone()); + assert_eq!(program.instructions, state.program); + assert_eq!(program.len_bwords(), aet.instruction_multiplicities.len()); + + while !state.halting { + if let Err(err) = aet.record_state(&state) { + return Err(VMError::new(err, state)); + }; + let co_processor_calls = match state.step() { + Ok(calls) => calls, + Err(err) => return Err(VMError::new(err, state)), + }; + for call in co_processor_calls { + aet.record_co_processor_call(call); + } + } + + Ok((aet, state)) + } + + /// Run Triton VM with the given public and secret input, recording the + /// influence of a callable block of instructions on the + /// [`AlgebraicExecutionTrace`]. For example, this can be used to identify the + /// number of clock cycles spent in some block of instructions, or how many rows + /// it contributes to the U32 Table. + /// + /// See also [`run`][run] and [`trace_execution`][trace_execution]. + /// + /// [run]: Self::run + /// [trace_execution]: Self::trace_execution + pub fn profile( + program: &Program, + public_input: PublicInput, + non_determinism: NonDeterminism, + ) -> VMResult<(Vec, ExecutionTraceProfile)> { + let mut profiler = ExecutionTraceProfiler::new(program.instructions.len()); + let mut state = VMState::new(program, public_input, non_determinism); + let mut previous_jump_stack_len = state.jump_stack.len(); + while !state.halting { + if let Ok(Instruction::Call(address)) = state.current_instruction() { + let label = program.label_for_address(address.value()); + profiler.enter_span(label); + } + + match state.step() { + Ok(calls) => profiler.handle_co_processor_calls(calls), + Err(err) => return Err(VMError::new(err, state)), + }; + + if state.jump_stack.len() < previous_jump_stack_len { + profiler.exit_span(); + } + previous_jump_stack_len = state.jump_stack.len(); + } + + Ok((state.public_output, profiler.finish())) + } +} + impl VMState { /// Create initial `VMState` for a given `program` /// @@ -222,15 +347,15 @@ impl VMState { } /// Perform the state transition as a mutable operation on `self`. - pub fn step(&mut self) -> Result> { + pub fn step(&mut self) -> InstructionResult> { if self.halting { - return Err(MachineHalted); + return Err(InstructionError::MachineHalted); } let current_instruction = self.current_instruction()?; let op_stack_delta = current_instruction.op_stack_size_influence(); if self.op_stack.would_be_too_shallow(op_stack_delta) { - return Err(OpStackTooShallow); + return Err(InstructionError::OpStackError(OpStackError::TooShallow)); } self.start_recording_op_stack_calls(); @@ -320,7 +445,7 @@ impl VMState { self.ram_calls.drain(..).map(RamCall).collect() } - fn pop(&mut self, n: NumberOfWords) -> Result> { + fn pop(&mut self, n: NumberOfWords) -> InstructionResult> { for _ in 0..n.num_words() { self.op_stack.pop()?; } @@ -336,10 +461,10 @@ impl VMState { vec![] } - fn divine(&mut self, n: NumberOfWords) -> Result> { + fn divine(&mut self, n: NumberOfWords) -> InstructionResult> { let input_len = self.secret_individual_tokens.len(); if input_len < n.num_words() { - return Err(EmptySecretInput(input_len)); + return Err(InstructionError::EmptySecretInput(input_len)); } for _ in 0..n.num_words() { let element = self.secret_individual_tokens.pop_front().unwrap(); @@ -369,7 +494,7 @@ impl VMState { vec![] } - fn skiz(&mut self) -> Result> { + fn skiz(&mut self) -> InstructionResult> { let top_of_stack = self.op_stack.pop()?; self.instruction_pointer += match top_of_stack.is_zero() { true => 1 + self.next_instruction()?.size(), @@ -388,21 +513,21 @@ impl VMState { vec![] } - fn return_from_call(&mut self) -> Result> { + fn return_from_call(&mut self) -> InstructionResult> { let (call_origin, _) = self.jump_stack_pop()?; self.instruction_pointer = call_origin.value().try_into().unwrap(); Ok(vec![]) } - fn recurse(&mut self) -> Result> { + fn recurse(&mut self) -> InstructionResult> { let (_, call_destination) = self.jump_stack_peek()?; self.instruction_pointer = call_destination.value().try_into().unwrap(); Ok(vec![]) } - fn recurse_or_return(&mut self) -> Result> { + fn recurse_or_return(&mut self) -> InstructionResult> { if self.jump_stack.is_empty() { - return Err(JumpStackIsEmpty); + return Err(InstructionError::JumpStackIsEmpty); } let new_ip = if self.op_stack[ST5] == self.op_stack[ST6] { @@ -418,9 +543,9 @@ impl VMState { Ok(vec![]) } - fn assert(&mut self) -> Result> { + fn assert(&mut self) -> InstructionResult> { if !self.op_stack[ST0].is_one() { - return Err(AssertionFailed); + return Err(InstructionError::AssertionFailed); } let _ = self.op_stack.pop()?; @@ -434,7 +559,7 @@ impl VMState { vec![] } - fn read_mem(&mut self, n: NumberOfWords) -> Result> { + fn read_mem(&mut self, n: NumberOfWords) -> InstructionResult> { self.start_recording_ram_calls(); let mut ram_pointer = self.op_stack.pop()?; for _ in 0..n.num_words() { @@ -449,7 +574,7 @@ impl VMState { Ok(ram_calls) } - fn write_mem(&mut self, n: NumberOfWords) -> Result> { + fn write_mem(&mut self, n: NumberOfWords) -> InstructionResult> { self.start_recording_ram_calls(); let mut ram_pointer = self.op_stack.pop()?; for _ in 0..n.num_words() { @@ -494,7 +619,7 @@ impl VMState { self.ram.insert(ram_pointer, ram_value); } - fn hash(&mut self) -> Result> { + fn hash(&mut self) -> InstructionResult> { let to_hash = self.op_stack.pop_multiple::<{ tip5::RATE }>()?; let mut hash_input = Tip5::new(Domain::FixedLength); @@ -518,9 +643,9 @@ impl VMState { vec![SpongeStateReset] } - fn sponge_absorb(&mut self) -> Result> { + fn sponge_absorb(&mut self) -> InstructionResult> { let Some(ref mut sponge) = self.sponge else { - return Err(SpongeNotInitialized); + return Err(InstructionError::SpongeNotInitialized); }; let to_absorb = self.op_stack.pop_multiple::<{ tip5::RATE }>()?; sponge.state[..tip5::RATE].copy_from_slice(&to_absorb); @@ -532,9 +657,9 @@ impl VMState { Ok(co_processor_calls) } - fn sponge_absorb_mem(&mut self) -> Result> { + fn sponge_absorb_mem(&mut self) -> InstructionResult> { let Some(mut sponge) = self.sponge.take() else { - return Err(SpongeNotInitialized); + return Err(InstructionError::SpongeNotInitialized); }; self.start_recording_ram_calls(); @@ -561,9 +686,9 @@ impl VMState { Ok(co_processor_calls) } - fn sponge_squeeze(&mut self) -> Result> { + fn sponge_squeeze(&mut self) -> InstructionResult> { let Some(ref mut sponge) = self.sponge else { - return Err(SpongeNotInitialized); + return Err(InstructionError::SpongeNotInitialized); }; for i in (0..tip5::RATE).rev() { self.op_stack.push(sponge.state[i]); @@ -576,10 +701,10 @@ impl VMState { Ok(co_processor_calls) } - fn assert_vector(&mut self) -> Result> { + fn assert_vector(&mut self) -> InstructionResult> { for i in 0..Digest::LEN { if self.op_stack[i] != self.op_stack[i + Digest::LEN] { - return Err(VectorAssertionFailed(i)); + return Err(InstructionError::VectorAssertionFailed(i)); } } self.op_stack.pop_multiple::<{ Digest::LEN }>()?; @@ -587,7 +712,7 @@ impl VMState { Ok(vec![]) } - fn add(&mut self) -> Result> { + fn add(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop()?; let rhs = self.op_stack.pop()?; self.op_stack.push(lhs + rhs); @@ -602,7 +727,7 @@ impl VMState { vec![] } - fn mul(&mut self) -> Result> { + fn mul(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop()?; let rhs = self.op_stack.pop()?; self.op_stack.push(lhs * rhs); @@ -611,10 +736,10 @@ impl VMState { Ok(vec![]) } - fn invert(&mut self) -> Result> { + fn invert(&mut self) -> InstructionResult> { let top_of_stack = self.op_stack[ST0]; if top_of_stack.is_zero() { - return Err(InverseOfZero); + return Err(InstructionError::InverseOfZero); } let _ = self.op_stack.pop()?; self.op_stack.push(top_of_stack.inverse()); @@ -622,7 +747,7 @@ impl VMState { Ok(vec![]) } - fn eq(&mut self) -> Result> { + fn eq(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop()?; let rhs = self.op_stack.pop()?; let eq: u32 = (lhs == rhs).into(); @@ -632,7 +757,7 @@ impl VMState { Ok(vec![]) } - fn split(&mut self) -> Result> { + fn split(&mut self) -> InstructionResult> { let top_of_stack = self.op_stack.pop()?; let lo = bfe!(top_of_stack.value() & 0xffff_ffff); let hi = bfe!(top_of_stack.value() >> 32); @@ -646,7 +771,7 @@ impl VMState { Ok(co_processor_calls) } - fn lt(&mut self) -> Result> { + fn lt(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; self.op_stack.is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; @@ -661,7 +786,7 @@ impl VMState { Ok(co_processor_calls) } - fn and(&mut self) -> Result> { + fn and(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; self.op_stack.is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; @@ -676,7 +801,7 @@ impl VMState { Ok(co_processor_calls) } - fn xor(&mut self) -> Result> { + fn xor(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; self.op_stack.is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; @@ -694,11 +819,11 @@ impl VMState { Ok(co_processor_calls) } - fn log_2_floor(&mut self) -> Result> { + fn log_2_floor(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; let top_of_stack = self.op_stack[ST0]; if top_of_stack.is_zero() { - return Err(LogarithmOfZero); + return Err(InstructionError::LogarithmOfZero); } let top_of_stack = self.op_stack.pop_u32()?; let log_2_floor = top_of_stack.ilog2(); @@ -711,7 +836,7 @@ impl VMState { Ok(co_processor_calls) } - fn pow(&mut self) -> Result> { + fn pow(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST1)?; let base = self.op_stack.pop()?; let exponent = self.op_stack.pop_u32()?; @@ -725,12 +850,12 @@ impl VMState { Ok(co_processor_calls) } - fn div_mod(&mut self) -> Result> { + fn div_mod(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; self.op_stack.is_u32(ST1)?; let denominator = self.op_stack[ST1]; if denominator.is_zero() { - return Err(DivisionByZero); + return Err(InstructionError::DivisionByZero); } let numerator = self.op_stack.pop_u32()?; @@ -752,7 +877,7 @@ impl VMState { Ok(co_processor_calls) } - fn pop_count(&mut self) -> Result> { + fn pop_count(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST0)?; let top_of_stack = self.op_stack.pop_u32()?; let pop_count = top_of_stack.count_ones(); @@ -765,7 +890,7 @@ impl VMState { Ok(co_processor_calls) } - fn xx_add(&mut self) -> Result> { + fn xx_add(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop_extension_field_element()?; let rhs = self.op_stack.pop_extension_field_element()?; self.op_stack.push_extension_field_element(lhs + rhs); @@ -773,7 +898,7 @@ impl VMState { Ok(vec![]) } - fn xx_mul(&mut self) -> Result> { + fn xx_mul(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop_extension_field_element()?; let rhs = self.op_stack.pop_extension_field_element()?; self.op_stack.push_extension_field_element(lhs * rhs); @@ -781,10 +906,10 @@ impl VMState { Ok(vec![]) } - fn x_invert(&mut self) -> Result> { + fn x_invert(&mut self) -> InstructionResult> { let top_of_stack = self.op_stack.peek_at_top_extension_field_element(); if top_of_stack.is_zero() { - return Err(InverseOfZero); + return Err(InstructionError::InverseOfZero); } let inverse = top_of_stack.inverse(); let _ = self.op_stack.pop_extension_field_element()?; @@ -793,7 +918,7 @@ impl VMState { Ok(vec![]) } - fn xb_mul(&mut self) -> Result> { + fn xb_mul(&mut self) -> InstructionResult> { let lhs = self.op_stack.pop()?; let rhs = self.op_stack.pop_extension_field_element()?; self.op_stack.push_extension_field_element(lhs.lift() * rhs); @@ -802,7 +927,7 @@ impl VMState { Ok(vec![]) } - fn write_io(&mut self, n: NumberOfWords) -> Result> { + fn write_io(&mut self, n: NumberOfWords) -> InstructionResult> { for _ in 0..n.num_words() { let top_of_stack = self.op_stack.pop()?; self.public_output.push(top_of_stack); @@ -812,10 +937,10 @@ impl VMState { Ok(vec![]) } - fn read_io(&mut self, n: NumberOfWords) -> Result> { + fn read_io(&mut self, n: NumberOfWords) -> InstructionResult> { let input_len = self.public_input.len(); if input_len < n.num_words() { - return Err(EmptyPublicInput(input_len)); + return Err(InstructionError::EmptyPublicInput(input_len)); } for _ in 0..n.num_words() { let read_element = self.public_input.pop_front().unwrap(); @@ -826,13 +951,13 @@ impl VMState { Ok(vec![]) } - fn merkle_step_non_determinism(&mut self) -> Result> { + fn merkle_step_non_determinism(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST5)?; let sibling_digest = self.pop_secret_digest()?; self.merkle_step(sibling_digest) } - fn merkle_step_mem(&mut self) -> Result> { + fn merkle_step_mem(&mut self) -> InstructionResult> { self.op_stack.is_u32(ST5)?; self.start_recording_ram_calls(); let mut ram_pointer = self.op_stack[ST7]; @@ -851,7 +976,7 @@ impl VMState { fn merkle_step( &mut self, sibling_digest: [BFieldElement; Digest::LEN], - ) -> Result> { + ) -> InstructionResult> { let node_index = self.op_stack.get_u32(ST5)?; let parent_node_index = node_index / 2; @@ -881,7 +1006,7 @@ impl VMState { Ok(co_processor_calls) } - fn xx_dot_step(&mut self) -> Result> { + fn xx_dot_step(&mut self) -> InstructionResult> { self.start_recording_ram_calls(); let mut rhs_address = self.op_stack.pop()?; let mut lhs_address = self.op_stack.pop()?; @@ -902,7 +1027,7 @@ impl VMState { Ok(ram_calls) } - fn xb_dot_step(&mut self) -> Result> { + fn xb_dot_step(&mut self) -> InstructionResult> { self.start_recording_ram_calls(); let mut rhs_address = self.op_stack.pop()?; let mut lhs_address = self.op_stack.pop()?; @@ -923,7 +1048,7 @@ impl VMState { } pub fn to_processor_row(&self) -> Array1 { - use crate::instruction::InstructionBit; + use isa::instruction::InstructionBit; use ProcessorBaseTableColumn::*; let mut processor_row = Array1::zeros(processor_table::BASE_WIDTH); @@ -1006,9 +1131,9 @@ impl VMState { maybe_destination.unwrap_or_else(BFieldElement::zero) } - pub fn current_instruction(&self) -> Result { + pub fn current_instruction(&self) -> InstructionResult { let maybe_current_instruction = self.program.get(self.instruction_pointer).copied(); - maybe_current_instruction.ok_or(InstructionPointerOverflow) + maybe_current_instruction.ok_or(InstructionError::InstructionPointerOverflow) } /// Return the next instruction on the tape, skipping arguments. @@ -1016,31 +1141,36 @@ impl VMState { /// Note that this is not necessarily the next instruction to execute, since the current /// instruction could be a jump, but it is either program[ip + 1] or program[ip + 2], /// depending on whether the current instruction takes an argument. - pub fn next_instruction(&self) -> Result { + pub fn next_instruction(&self) -> InstructionResult { let current_instruction = self.current_instruction()?; let next_instruction_pointer = self.instruction_pointer + current_instruction.size(); let maybe_next_instruction = self.program.get(next_instruction_pointer).copied(); - maybe_next_instruction.ok_or(InstructionPointerOverflow) + maybe_next_instruction.ok_or(InstructionError::InstructionPointerOverflow) } - fn jump_stack_pop(&mut self) -> Result<(BFieldElement, BFieldElement)> { - self.jump_stack.pop().ok_or(JumpStackIsEmpty) + fn jump_stack_pop(&mut self) -> InstructionResult<(BFieldElement, BFieldElement)> { + self.jump_stack + .pop() + .ok_or(InstructionError::JumpStackIsEmpty) } - fn jump_stack_peek(&mut self) -> Result<(BFieldElement, BFieldElement)> { - self.jump_stack.last().copied().ok_or(JumpStackIsEmpty) + fn jump_stack_peek(&mut self) -> InstructionResult<(BFieldElement, BFieldElement)> { + self.jump_stack + .last() + .copied() + .ok_or(InstructionError::JumpStackIsEmpty) } - fn pop_secret_digest(&mut self) -> Result<[BFieldElement; Digest::LEN]> { + fn pop_secret_digest(&mut self) -> InstructionResult<[BFieldElement; Digest::LEN]> { let digest = self .secret_digests .pop_front() - .ok_or(EmptySecretDigestInput)?; + .ok_or(InstructionError::EmptySecretDigestInput)?; Ok(digest.values()) } /// Run Triton VM on this state to completion, or until an error occurs. - pub fn run(&mut self) -> Result<()> { + pub fn run(&mut self) -> InstructionResult<()> { while !self.halting { self.step()?; } @@ -1164,6 +1294,97 @@ impl Display for VMState { } } +#[derive(Debug, Default, Clone, Eq, PartialEq, BFieldCodec, Arbitrary)] +pub struct PublicInput { + pub individual_tokens: Vec, +} + +impl From> for PublicInput { + fn from(individual_tokens: Vec) -> Self { + Self::new(individual_tokens) + } +} + +impl From<&Vec> for PublicInput { + fn from(tokens: &Vec) -> Self { + Self::new(tokens.to_owned()) + } +} + +impl From<[BFieldElement; N]> for PublicInput { + fn from(tokens: [BFieldElement; N]) -> Self { + Self::new(tokens.to_vec()) + } +} + +impl From<&[BFieldElement]> for PublicInput { + fn from(tokens: &[BFieldElement]) -> Self { + Self::new(tokens.to_vec()) + } +} + +impl PublicInput { + pub fn new(individual_tokens: Vec) -> Self { + Self { individual_tokens } + } +} + +/// All sources of non-determinism for a program. This includes elements that +/// can be read using instruction `divine`, digests that can be read using +/// instruction `merkle_step`, and an initial state of random-access memory. +#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary)] +pub struct NonDeterminism { + pub individual_tokens: Vec, + pub digests: Vec, + pub ram: HashMap, +} + +impl From> for NonDeterminism { + fn from(tokens: Vec) -> Self { + Self::new(tokens) + } +} + +impl From<&Vec> for NonDeterminism { + fn from(tokens: &Vec) -> Self { + Self::new(tokens.to_owned()) + } +} + +impl From<[BFieldElement; N]> for NonDeterminism { + fn from(tokens: [BFieldElement; N]) -> Self { + Self::new(tokens.to_vec()) + } +} + +impl From<&[BFieldElement]> for NonDeterminism { + fn from(tokens: &[BFieldElement]) -> Self { + Self::new(tokens.to_vec()) + } +} + +impl NonDeterminism { + pub fn new>>(individual_tokens: V) -> Self { + Self { + individual_tokens: individual_tokens.into(), + digests: vec![], + ram: HashMap::new(), + } + } + + #[must_use] + pub fn with_digests>>(mut self, digests: V) -> Self { + self.digests = digests.into(); + self + } + + #[must_use] + pub fn with_ram>>(mut self, ram: H) -> Self { + self.ram = ram.into(); + self + } +} + #[cfg(test)] pub(crate) mod tests { use std::ops::BitAnd; @@ -1171,6 +1392,13 @@ pub(crate) mod tests { use assert2::assert; use assert2::let_assert; + use isa::instruction::AnInstruction; + use isa::instruction::LabelledInstruction; + use isa::instruction::ALL_INSTRUCTIONS; + use isa::program::Program; + use isa::triton_asm; + use isa::triton_instr; + use isa::triton_program; use itertools::izip; use proptest::collection::vec; use proptest::prelude::*; @@ -1188,19 +1416,119 @@ pub(crate) mod tests { use crate::shared_tests::LeavedMerkleTreeTestData; use crate::shared_tests::ProgramAndInput; use crate::shared_tests::DEFAULT_LOG2_FRI_EXPANSION_FACTOR_FOR_TESTS; - use crate::triton_asm; - use crate::triton_instr; - use crate::triton_program; - use crate::LabelledInstruction; + use crate::table::master_table::TableId; use super::*; + #[test] + fn instructions_act_on_op_stack_as_indicated() { + for test_instruction in ALL_INSTRUCTIONS { + let (program, stack_size_before_test_instruction) = + construct_test_program_for_instruction(test_instruction); + let public_input = PublicInput::from(bfe_array![0]); + let mock_digests = [Digest::default()]; + let non_determinism = NonDeterminism::from(bfe_array![0]).with_digests(mock_digests); + + let mut vm_state = VMState::new(&program, public_input, non_determinism); + let_assert!(Ok(()) = vm_state.run()); + let stack_size_after_test_instruction = vm_state.op_stack.len(); + + let stack_size_difference = (stack_size_after_test_instruction as i32) + - (stack_size_before_test_instruction as i32); + assert!( + test_instruction.op_stack_size_influence() == stack_size_difference, + "{test_instruction}" + ); + } + } + + fn construct_test_program_for_instruction( + instruction: AnInstruction, + ) -> (Program, usize) { + if matches!(instruction, Call(_) | Return | Recurse | RecurseOrReturn) { + // need jump stack setup + let program = test_program_for_call_recurse_return().program; + let stack_size = NUM_OP_STACK_REGISTERS; + (program, stack_size) + } else { + let num_push_instructions = 10; + let push_instructions = triton_asm![push 1; num_push_instructions]; + let program = triton_program!(sponge_init {&push_instructions} {instruction} nop halt); + + let stack_size_when_reaching_test_instruction = + NUM_OP_STACK_REGISTERS + num_push_instructions; + (program, stack_size_when_reaching_test_instruction) + } + } + + #[test] + fn profile_can_be_created_and_agrees_with_regular_vm_run() { + let program = CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS.clone(); + let (profile_output, profile) = VM::profile(&program, [].into(), [].into()).unwrap(); + let mut vm_state = VMState::new(&program, [].into(), [].into()); + let_assert!(Ok(()) = vm_state.run()); + assert!(profile_output == vm_state.public_output); + assert!(profile.total.processor == vm_state.cycle_count); + + let_assert!(Ok((aet, trace_output)) = VM::trace_execution(&program, [].into(), [].into())); + assert!(profile_output == trace_output); + let proc_height = u32::try_from(aet.height_of_table(TableId::Processor)).unwrap(); + assert!(proc_height == profile.total.processor); + + let op_stack_height = u32::try_from(aet.height_of_table(TableId::OpStack)).unwrap(); + assert!(op_stack_height == profile.total.op_stack); + + let ram_height = u32::try_from(aet.height_of_table(TableId::Ram)).unwrap(); + assert!(ram_height == profile.total.ram); + + let hash_height = u32::try_from(aet.height_of_table(TableId::Hash)).unwrap(); + assert!(hash_height == profile.total.hash); + + let u32_height = u32::try_from(aet.height_of_table(TableId::U32)).unwrap(); + assert!(u32_height == profile.total.u32); + + println!("{profile}"); + } + + #[test] + fn program_with_too_many_returns_crashes_vm_but_not_profiler() { + let program = triton_program! { + call foo return halt + foo: return + }; + let_assert!(Err(err) = VM::profile(&program, [].into(), [].into())); + let_assert!(InstructionError::JumpStackIsEmpty = err.source); + } + + #[proptest] + fn from_various_types_to_public_input(#[strategy(arb())] tokens: Vec) { + let public_input = PublicInput::new(tokens.clone()); + + assert!(public_input == tokens.clone().into()); + assert!(public_input == (&tokens).into()); + assert!(public_input == tokens[..].into()); + assert!(public_input == (&tokens[..]).into()); + + assert!(PublicInput::new(vec![]) == [].into()); + } + + #[proptest] + fn from_various_types_to_non_determinism(#[strategy(arb())] tokens: Vec) { + let non_determinism = NonDeterminism::new(tokens.clone()); + + assert!(non_determinism == tokens.clone().into()); + assert!(non_determinism == tokens[..].into()); + assert!(non_determinism == (&tokens[..]).into()); + + assert!(NonDeterminism::new(vec![]) == [].into()); + } + #[test] fn initialise_table() { let program = GREATEST_COMMON_DIVISOR.clone(); let stdin = PublicInput::from([42, 56].map(|b| bfe!(b))); let secret_in = NonDeterminism::default(); - program.trace_execution(stdin, secret_in).unwrap(); + VM::trace_execution(&program, stdin, secret_in).unwrap(); } #[test] @@ -1208,7 +1536,7 @@ pub(crate) mod tests { let program = GREATEST_COMMON_DIVISOR.clone(); let stdin = PublicInput::from([42, 56].map(|b| bfe!(b))); let secret_in = NonDeterminism::default(); - let_assert!(Ok(stdout) = program.run(stdin, secret_in)); + let_assert!(Ok(stdout) = VM::run(&program, stdin, secret_in)); let output = stdout.iter().map(|o| format!("{o}")).join(", "); println!("VM output: [{output}]"); @@ -1219,7 +1547,7 @@ pub(crate) mod tests { #[test] fn crash_triton_vm_and_print_vm_error() { let crashing_program = triton_program!(push 2 assert halt); - let_assert!(Err(err) = crashing_program.run([].into(), [].into())); + let_assert!(Err(err) = VM::run(&crashing_program, [].into(), [].into())); println!("{err}"); } @@ -1351,8 +1679,10 @@ pub(crate) mod tests { #[test] fn vm_crashes_when_executing_recurse_or_return_with_empty_jump_stack() { let program = triton_program!(recurse_or_return halt); - let_assert!(Err(err) = program.run(PublicInput::default(), NonDeterminism::default())); - assert!(JumpStackIsEmpty == err.source); + let_assert!( + Err(err) = VM::run(&program, PublicInput::default(), NonDeterminism::default()) + ); + assert!(InstructionError::JumpStackIsEmpty == err.source); } pub(crate) fn test_program_for_write_mem_read_mem() -> ProgramAndInput { @@ -2051,9 +2381,7 @@ pub(crate) mod tests { #[test] fn can_compute_dot_product_from_uninitialized_ram() { let program = triton_program!(xx_dot_step xb_dot_step halt); - program - .run(PublicInput::default(), NonDeterminism::default()) - .unwrap(); + VM::run(&program, PublicInput::default(), NonDeterminism::default()).unwrap(); } pub(crate) fn property_based_test_program_for_xx_dot_step() -> ProgramAndInput { @@ -2208,7 +2536,7 @@ pub(crate) mod tests { ); let program_and_input = ProgramAndInput::new(program); let_assert!(Err(err) = program_and_input.run()); - let_assert!(AssertionFailed = err.source); + let_assert!(InstructionError::AssertionFailed = err.source); } pub(crate) fn test_program_for_split() -> ProgramAndInput { @@ -2289,7 +2617,7 @@ pub(crate) mod tests { write_io 3 halt ); - let actual_stdout = program.run([].into(), [].into())?; + let actual_stdout = VM::run(&program, [].into(), [].into())?; let expected_stdout = (left_operand + right_operand).coefficients.to_vec(); prop_assert_eq!(expected_stdout, actual_stdout); } @@ -2310,7 +2638,7 @@ pub(crate) mod tests { write_io 3 halt ); - let actual_stdout = program.run([].into(), [].into())?; + let actual_stdout = VM::run(&program, [].into(), [].into())?; let expected_stdout = (left_operand * right_operand).coefficients.to_vec(); prop_assert_eq!(expected_stdout, actual_stdout); } @@ -2329,7 +2657,7 @@ pub(crate) mod tests { write_io 3 halt ); - let actual_stdout = program.run([].into(), [].into())?; + let actual_stdout = VM::run(&program, [].into(), [].into())?; let expected_stdout = operand.inverse().coefficients.to_vec(); prop_assert_eq!(expected_stdout, actual_stdout); } @@ -2345,7 +2673,7 @@ pub(crate) mod tests { write_io 3 halt ); - let actual_stdout = program.run([].into(), [].into())?; + let actual_stdout = VM::run(&program, [].into(), [].into())?; let expected_stdout = (scalar * operand).coefficients.to_vec(); prop_assert_eq!(expected_stdout, actual_stdout); } @@ -2421,7 +2749,7 @@ pub(crate) mod tests { #[test] fn run_tvm_halt_then_do_stuff() { let program = triton_program!(halt push 1 push 2 add invert write_io 5); - let_assert!(Ok((aet, _)) = program.trace_execution([].into(), [].into())); + let_assert!(Ok((aet, _)) = VM::trace_execution(&program, [].into(), [].into())); let_assert!(Some(last_processor_row) = aet.processor_trace.rows().into_iter().last()); let clk_count = last_processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; @@ -2509,7 +2837,7 @@ pub(crate) mod tests { } let program = MERKLE_TREE_AUTHENTICATION_PATH_VERIFY.clone(); - assert!(let Ok(_) = program.run(public_input.into(), non_determinism)); + assert!(let Ok(_) = VM::run(&program, public_input.into(), non_determinism)); } #[proptest] @@ -2551,7 +2879,7 @@ pub(crate) mod tests { ); let public_input = vec![p2_x, p1.1, p1.0, p0.1, p0.0]; - let_assert!(Ok(output) = get_collinear_y_program.run(public_input.into(), [].into())); + let_assert!(Ok(output) = VM::run(&get_collinear_y_program, public_input.into(), [].into())); prop_assert_eq!(p2_y, output[0]); } @@ -2573,7 +2901,7 @@ pub(crate) mod tests { halt ); - let_assert!(Ok(standard_out) = countdown_program.run([].into(), [].into())); + let_assert!(Ok(standard_out) = VM::run(&countdown_program, [].into(), [].into())); let expected = (0..=10).map(BFieldElement::new).rev().collect_vec(); assert!(expected == standard_out); } @@ -2592,21 +2920,21 @@ pub(crate) mod tests { #[test] fn run_tvm_swap() { let program = triton_program!(push 1 push 2 swap 1 assert write_io 1 halt); - let_assert!(Ok(standard_out) = program.run([].into(), [].into())); + let_assert!(Ok(standard_out) = VM::run(&program, [].into(), [].into())); assert!(bfe!(2) == standard_out[0]); } #[test] fn swap_st0_is_like_no_op() { let program = triton_program!(push 42 swap 0 write_io 1 halt); - let_assert!(Ok(standard_out) = program.run([].into(), [].into())); + let_assert!(Ok(standard_out) = VM::run(&program, [].into(), [].into())); assert!(bfe!(42) == standard_out[0]); } #[test] fn read_mem_uninitialized() { let program = triton_program!(read_mem 3 halt); - let_assert!(Ok((aet, _)) = program.trace_execution([].into(), [].into())); + let_assert!(Ok((aet, _)) = VM::trace_execution(&program, [].into(), [].into())); assert!(2 == aet.processor_trace.nrows()); } @@ -2662,8 +2990,8 @@ pub(crate) mod tests { #[test] fn program_without_halt() { let program = triton_program!(nop); - let_assert!(Err(err) = program.trace_execution([].into(), [].into())); - let_assert!(InstructionPointerOverflow = err.source); + let_assert!(Err(err) = VM::trace_execution(&program, [].into(), [].into())); + let_assert!(InstructionError::InstructionPointerOverflow = err.source); } #[test] @@ -2685,7 +3013,7 @@ pub(crate) mod tests { let std_in = PublicInput::from(sudoku.map(|b| bfe!(b))); let secret_in = NonDeterminism::default(); - assert!(let Ok(_) = program.trace_execution(std_in, secret_in)); + assert!(let Ok(_) = VM::trace_execution(&program, std_in, secret_in)); // rows and columns adhere to Sudoku rules, boxes do not let bad_sudoku = [ @@ -2703,8 +3031,8 @@ pub(crate) mod tests { ]; let bad_std_in = PublicInput::from(bad_sudoku.map(|b| bfe!(b))); let secret_in = NonDeterminism::default(); - let_assert!(Err(err) = program.trace_execution(bad_std_in, secret_in)); - let_assert!(AssertionFailed = err.source); + let_assert!(Err(err) = VM::trace_execution(&program, bad_std_in, secret_in)); + let_assert!(InstructionError::AssertionFailed = err.source); } fn instruction_does_not_change_vm_state_when_crashing_vm( From 59e2c143e24ed4c895e937e70f9b4d882ddf6a57 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Fri, 30 Aug 2024 10:56:27 +0200 Subject: [PATCH 05/15] =?UTF-8?q?refactor:=20Rename=20constraint=20?= =?UTF-8?q?=E2=80=9Cbuilder=E2=80=9D=20to=20=E2=80=9Ccircuit=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 4 ++- .../Cargo.toml | 2 +- .../src/lib.rs | 0 triton-vm/Cargo.toml | 4 +-- triton-vm/src/codegen/constraints.rs | 12 ++++---- triton-vm/src/codegen/mod.rs | 14 ++++----- triton-vm/src/codegen/substitutions.rs | 16 +++++----- triton-vm/src/stark.rs | 2 +- triton-vm/src/table.rs | 22 +++++++------- triton-vm/src/table/cascade_table.rs | 12 ++++---- triton-vm/src/table/cross_table_argument.rs | 10 +++---- triton-vm/src/table/hash_table.rs | 30 +++++++++---------- triton-vm/src/table/jump_stack_table.rs | 6 ++-- triton-vm/src/table/lookup_table.rs | 12 ++++---- triton-vm/src/table/master_table.rs | 10 +++---- triton-vm/src/table/op_stack_table.rs | 6 ++-- triton-vm/src/table/processor_table.rs | 6 ++-- triton-vm/src/table/program_table.rs | 6 ++-- triton-vm/src/table/ram_table.rs | 6 ++-- triton-vm/src/table/u32_table.rs | 14 ++++----- 20 files changed, 98 insertions(+), 96 deletions(-) rename {triton-constraint-builder => triton-constraint-circuit}/Cargo.toml (94%) rename {triton-constraint-builder => triton-constraint-circuit}/src/lib.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 11cc3d4e2..462f3789e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["triton-vm", "triton-constraint-builder", "triton-isa"] +members = ["triton-vm", "triton-constraint-circuit", "triton-isa"] resolver = "2" [profile.test] @@ -26,12 +26,14 @@ anyhow = "1.0" arbitrary = { version = "1", features = ["derive"] } assert2 = "0.3" colored = "2.1" +constraint-circuit = { path = "triton-constraint-circuit", package = "triton-constraint-circuit" } clap = { version = "4", features = ["derive", "cargo", "wrap_help", "unicode", "string"] } criterion = { version = "0.5", features = ["html_reports"] } directories = "5" fs-err = "2.11.0" get-size = "0.1.4" indexmap = "2.2.6" +isa = { path = "triton-isa", package = "triton-isa" } itertools = "0.13" lazy_static = "1.5" ndarray = { version = "0.16", features = ["rayon"] } diff --git a/triton-constraint-builder/Cargo.toml b/triton-constraint-circuit/Cargo.toml similarity index 94% rename from triton-constraint-builder/Cargo.toml rename to triton-constraint-circuit/Cargo.toml index 8dbd03f54..1ef2d0506 100644 --- a/triton-constraint-builder/Cargo.toml +++ b/triton-constraint-circuit/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "triton-constraint-builder" +name = "triton-constraint-circuit" description = """ AIR constraints build helper for Triton VM. """ diff --git a/triton-constraint-builder/src/lib.rs b/triton-constraint-circuit/src/lib.rs similarity index 100% rename from triton-constraint-builder/src/lib.rs rename to triton-constraint-circuit/src/lib.rs diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index a3e89f28e..94b94fc22 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -20,11 +20,11 @@ readme.workspace = true [dependencies] arbitrary.workspace = true colored.workspace = true -constraint-builder = { path = "../triton-constraint-builder", package = "triton-constraint-builder" } +constraint-circuit.workspace = true criterion.workspace = true get-size.workspace = true indexmap.workspace = true -isa = { path = "../triton-isa", package = "triton-isa" } +isa.workspace = true itertools.workspace = true lazy_static.workspace = true ndarray.workspace = true diff --git a/triton-vm/src/codegen/constraints.rs b/triton-vm/src/codegen/constraints.rs index 38719c27a..2bdbf33d7 100644 --- a/triton-vm/src/codegen/constraints.rs +++ b/triton-vm/src/codegen/constraints.rs @@ -3,10 +3,10 @@ use std::collections::HashSet; -use constraint_builder::BinOp; -use constraint_builder::CircuitExpression; -use constraint_builder::ConstraintCircuit; -use constraint_builder::InputIndicator; +use constraint_circuit::BinOp; +use constraint_circuit::CircuitExpression; +use constraint_circuit::ConstraintCircuit; +use constraint_circuit::InputIndicator; use isa::instruction::Instruction; use isa::op_stack::NumberOfWords; use itertools::Itertools; @@ -893,8 +893,8 @@ impl ToTokens for IOList { #[cfg(test)] mod tests { - use constraint_builder::ConstraintCircuitBuilder; - use constraint_builder::SingleRowIndicator; + use constraint_circuit::ConstraintCircuitBuilder; + use constraint_circuit::SingleRowIndicator; use twenty_first::prelude::*; use super::*; diff --git a/triton-vm/src/codegen/mod.rs b/triton-vm/src/codegen/mod.rs index 8531de53c..507caa23a 100644 --- a/triton-vm/src/codegen/mod.rs +++ b/triton-vm/src/codegen/mod.rs @@ -1,9 +1,9 @@ -use constraint_builder::ConstraintCircuit; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DegreeLoweringInfo; -use constraint_builder::DualRowIndicator; -use constraint_builder::InputIndicator; -use constraint_builder::SingleRowIndicator; +use constraint_circuit::ConstraintCircuit; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DegreeLoweringInfo; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; use itertools::Itertools; use proc_macro2::TokenStream; use std::fs::write; @@ -130,7 +130,7 @@ impl Constraints { #[cfg(test)] mod tests { - use constraint_builder::ConstraintCircuitBuilder; + use constraint_circuit::ConstraintCircuitBuilder; use twenty_first::prelude::*; use crate::table; diff --git a/triton-vm/src/codegen/substitutions.rs b/triton-vm/src/codegen/substitutions.rs index dc76fface..ea6da7ef4 100644 --- a/triton-vm/src/codegen/substitutions.rs +++ b/triton-vm/src/codegen/substitutions.rs @@ -1,11 +1,11 @@ -use constraint_builder::BinOp; -use constraint_builder::CircuitExpression; -use constraint_builder::ConstraintCircuit; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DegreeLoweringInfo; -use constraint_builder::DualRowIndicator; -use constraint_builder::InputIndicator; -use constraint_builder::SingleRowIndicator; +use constraint_circuit::BinOp; +use constraint_circuit::CircuitExpression; +use constraint_circuit::ConstraintCircuit; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DegreeLoweringInfo; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; use itertools::Itertools; use proc_macro2::TokenStream; use quote::format_ident; diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 52fc1de36..bd7da17ec 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1313,7 +1313,7 @@ pub(crate) mod tests { use assert2::assert; use assert2::check; use assert2::let_assert; - use constraint_builder::ConstraintCircuitBuilder; + use constraint_circuit::ConstraintCircuitBuilder; use isa::error::OpStackError; use isa::instruction::Instruction; use isa::op_stack::OpStackElement; diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index 113f7c1d2..c1c15dd4a 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -3,10 +3,10 @@ pub use crate::table::master_table::NUM_BASE_COLUMNS; pub use crate::table::master_table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::SingleRowIndicator; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::SingleRowIndicator; use strum::Display; use strum::EnumCount; use strum::EnumIter; @@ -161,13 +161,13 @@ fn terminal_constraints() -> Vec> { mod tests { use std::collections::HashMap; - use constraint_builder::BinOp; - use constraint_builder::CircuitExpression; - use constraint_builder::ConstraintCircuit; - use constraint_builder::ConstraintCircuitBuilder; - use constraint_builder::ConstraintCircuitMonad; - use constraint_builder::DegreeLoweringInfo; - use constraint_builder::InputIndicator; + use constraint_circuit::BinOp; + use constraint_circuit::CircuitExpression; + use constraint_circuit::ConstraintCircuit; + use constraint_circuit::ConstraintCircuitBuilder; + use constraint_circuit::ConstraintCircuitMonad; + use constraint_circuit::DegreeLoweringInfo; + use constraint_circuit::InputIndicator; use itertools::Itertools; use ndarray::Array2; use ndarray::ArrayView2; diff --git a/triton-vm/src/table/cascade_table.rs b/triton-vm/src/table/cascade_table.rs index 463919564..770f272e7 100644 --- a/triton-vm/src/table/cascade_table.rs +++ b/triton-vm/src/table/cascade_table.rs @@ -1,9 +1,9 @@ -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator; -use constraint_builder::SingleRowIndicator::*; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; use ndarray::s; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; diff --git a/triton-vm/src/table/cross_table_argument.rs b/triton-vm/src/table/cross_table_argument.rs index 43e5ab83d..dab2a3ee9 100644 --- a/triton-vm/src/table/cross_table_argument.rs +++ b/triton-vm/src/table/cross_table_argument.rs @@ -1,11 +1,11 @@ use std::ops::Add; use std::ops::Mul; -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::SingleRowIndicator; -use constraint_builder::SingleRowIndicator::ExtRow; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::ExtRow; use twenty_first::prelude::*; use crate::table::challenges::ChallengeId::*; diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs index 8ae834836..3f2c10792 100644 --- a/triton-vm/src/table/hash_table.rs +++ b/triton-vm/src/table/hash_table.rs @@ -1,10 +1,10 @@ -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::InputIndicator; -use constraint_builder::SingleRowIndicator; -use constraint_builder::SingleRowIndicator::*; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; use isa::instruction::AnInstruction::Hash; use isa::instruction::AnInstruction::SpongeAbsorb; use isa::instruction::AnInstruction::SpongeInit; @@ -554,7 +554,7 @@ impl ExtHashTable { Self::round_number_deselector(circuit_builder, &round_number, round_idx); round_constant_constraint_circuit = round_constant_constraint_circuit + round_deselector_circuit - * (round_constant_column_circuit.clone() - round_constant); + * (round_constant_column_circuit.clone() - round_constant); } constraints.push(round_constant_constraint_circuit); } @@ -709,7 +709,7 @@ impl ExtHashTable { StackWeight14, StackWeight15, ] - .map(challenge); + .map(challenge); let round_number_is_not_num_rounds = Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS); @@ -820,7 +820,7 @@ impl ExtHashTable { * running_evaluation_hash_input_updates + round_number_next.clone() * running_evaluation_hash_input_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_input_remains; + * running_evaluation_hash_input_remains; // If (and only if) the row number in the next row is NUM_ROUNDS and the current instruction // in the next row corresponds to `hash`, update running evaluation “hash digest.” @@ -843,7 +843,7 @@ impl ExtHashTable { * running_evaluation_hash_digest_updates + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_digest_remains; + * running_evaluation_hash_digest_remains; // The running evaluation for “Sponge” updates correctly. let compressed_row_next = state_weights[..RATE] @@ -893,7 +893,7 @@ impl ExtHashTable { * receive_chunk_running_evaluation_absorbs_chunk_of_instructions + round_number_next * receive_chunk_running_evaluation_remains.clone() + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) - * receive_chunk_running_evaluation_remains; + * receive_chunk_running_evaluation_remains; let constraints = vec![ round_number_is_0_through_4_or_round_number_next_is_0, @@ -1013,7 +1013,7 @@ impl ExtHashTable { constraints, hash_function_round_correctly_performs_update.to_vec(), ] - .concat() + .concat() } fn indicate_column_index_in_base_row(column: HashBaseTableColumn) -> SingleRowIndicator { @@ -1112,7 +1112,7 @@ impl ExtHashTable { State4, State5, State6, State7, State8, State9, State10, State11, State12, State13, State14, State15, ] - .map(current_base_row); + .map(current_base_row); let state_part_after_power_map = { let mut exponentiation_accumulator = state_part_before_power_map.clone(); @@ -1156,7 +1156,7 @@ impl ExtHashTable { Constant8, Constant9, Constant10, Constant11, Constant12, Constant13, Constant14, Constant15, ] - .map(current_base_row); + .map(current_base_row); let state_after_round_constant_addition = state_after_matrix_multiplication .into_iter() diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs index 334a79ed9..ab3ee0914 100644 --- a/triton-vm/src/table/jump_stack_table.rs +++ b/triton-vm/src/table/jump_stack_table.rs @@ -2,9 +2,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::ops::Range; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator::*; -use constraint_builder::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; use isa::instruction::Instruction; use itertools::Itertools; use ndarray::parallel::prelude::*; diff --git a/triton-vm/src/table/lookup_table.rs b/triton-vm/src/table/lookup_table.rs index 7ecb95491..37881e5fd 100644 --- a/triton-vm/src/table/lookup_table.rs +++ b/triton-vm/src/table/lookup_table.rs @@ -1,9 +1,9 @@ -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator; -use constraint_builder::SingleRowIndicator::*; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; use itertools::Itertools; use ndarray::prelude::*; use num_traits::ConstOne; diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 74e00accc..42fdf6d6a 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1273,11 +1273,11 @@ mod tests { use fs_err as fs; use std::path::Path; - use constraint_builder::ConstraintCircuitBuilder; - use constraint_builder::ConstraintCircuitMonad; - use constraint_builder::DegreeLoweringInfo; - use constraint_builder::DualRowIndicator; - use constraint_builder::SingleRowIndicator; + use constraint_circuit::ConstraintCircuitBuilder; + use constraint_circuit::ConstraintCircuitMonad; + use constraint_circuit::DegreeLoweringInfo; + use constraint_circuit::DualRowIndicator; + use constraint_circuit::SingleRowIndicator; use isa::instruction::Instruction; use isa::instruction::InstructionBit; use master_table::cross_table_argument::GrandCrossTableArg; diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index 7b40bd3b5..65f6f5847 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -3,9 +3,9 @@ use std::collections::HashMap; use std::ops::Range; use arbitrary::Arbitrary; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator::*; -use constraint_builder::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; use isa::op_stack::OpStackElement; use isa::op_stack::UnderflowIO; use itertools::Itertools; diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index f58083fc6..ce95fece0 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -1,9 +1,9 @@ use std::cmp::max; use std::ops::Mul; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator::*; -use constraint_builder::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; use isa::instruction::AnInstruction::*; use isa::instruction::Instruction; use isa::instruction::InstructionBit; diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs index cd2f2736d..568fd8884 100644 --- a/triton-vm/src/table/program_table.rs +++ b/triton-vm/src/table/program_table.rs @@ -1,8 +1,8 @@ use std::cmp::Ordering; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator::*; -use constraint_builder::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; use ndarray::s; use ndarray::Array1; use ndarray::ArrayView1; diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index ae285647e..e6065f969 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -1,9 +1,9 @@ use std::cmp::Ordering; use arbitrary::Arbitrary; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::SingleRowIndicator::*; -use constraint_builder::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs index d5475061a..da015dff7 100644 --- a/triton-vm/src/table/u32_table.rs +++ b/triton-vm/src/table/u32_table.rs @@ -2,13 +2,13 @@ use std::cmp::max; use std::ops::Mul; use arbitrary::Arbitrary; -use constraint_builder::ConstraintCircuitBuilder; -use constraint_builder::ConstraintCircuitMonad; -use constraint_builder::DualRowIndicator; -use constraint_builder::DualRowIndicator::*; -use constraint_builder::InputIndicator; -use constraint_builder::SingleRowIndicator; -use constraint_builder::SingleRowIndicator::*; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; use isa::instruction::Instruction; use ndarray::parallel::prelude::*; use ndarray::s; From e23f78d8c2aa95e558c352051983b374bfcca4b4 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Fri, 30 Aug 2024 16:24:42 +0200 Subject: [PATCH 06/15] refactor!: Move AIR to own crate --- Cargo.toml | 3 +- triton-air/Cargo.toml | 32 + triton-air/src/challenge_id.rs | 241 + .../src}/cross_table_argument.rs | 107 +- triton-air/src/lib.rs | 104 + triton-air/src/table.rs | 111 + .../src/table/cascade.rs | 216 +- triton-air/src/table/hash.rs | 1393 +++++ triton-air/src/table/jump_stack.rs | 164 + triton-air/src/table/lookup.rs | 189 + triton-air/src/table/op_stack.rs | 202 + triton-air/src/table/processor.rs | 3594 ++++++++++++ triton-air/src/table/program.rs | 279 + triton-air/src/table/ram.rs | 275 + triton-air/src/table/u32.rs | 395 ++ .../table => triton-air/src}/table_column.rs | 241 +- triton-vm/Cargo.toml | 1 + triton-vm/benches/bezout_coeffs.rs | 4 +- triton-vm/benches/prove_halt.rs | 2 +- triton-vm/benches/verify_halt.rs | 19 +- triton-vm/src/aet.rs | 58 +- triton-vm/src/air.rs | 6 +- triton-vm/src/air/memory_layout.rs | 6 +- triton-vm/src/challenges.rs | 186 + triton-vm/src/execution_trace_profiler.rs | 4 +- triton-vm/src/lib.rs | 52 +- triton-vm/src/stark.rs | 156 +- triton-vm/src/table.rs | 235 +- triton-vm/src/table/cascade.rs | 124 + triton-vm/src/table/challenges.rs | 396 -- triton-vm/src/table/constraints.rs | 2 +- triton-vm/src/table/degree_lowering_table.rs | 6 +- triton-vm/src/table/extension_table.rs | 2 +- triton-vm/src/table/hash.rs | 649 +++ triton-vm/src/table/hash_table.rs | 1926 ------- triton-vm/src/table/jump_stack.rs | 229 + triton-vm/src/table/jump_stack_table.rs | 366 -- triton-vm/src/table/lookup.rs | 156 + triton-vm/src/table/lookup_table.rs | 316 -- triton-vm/src/table/master_table.rs | 426 +- triton-vm/src/table/op_stack.rs | 394 ++ triton-vm/src/table/op_stack_table.rs | 563 -- triton-vm/src/table/processor.rs | 1501 +++++ triton-vm/src/table/processor_table.rs | 4987 ----------------- triton-vm/src/table/program.rs | 259 + triton-vm/src/table/program_table.rs | 510 -- triton-vm/src/table/ram.rs | 468 ++ triton-vm/src/table/ram_table.rs | 711 --- triton-vm/src/table/u32.rs | 261 + triton-vm/src/table/u32_table.rs | 617 -- triton-vm/src/vm.rs | 23 +- 51 files changed, 11763 insertions(+), 11404 deletions(-) create mode 100644 triton-air/Cargo.toml create mode 100644 triton-air/src/challenge_id.rs rename {triton-vm/src/table => triton-air/src}/cross_table_argument.rs (67%) create mode 100644 triton-air/src/lib.rs create mode 100644 triton-air/src/table.rs rename triton-vm/src/table/cascade_table.rs => triton-air/src/table/cascade.rs (51%) create mode 100644 triton-air/src/table/hash.rs create mode 100644 triton-air/src/table/jump_stack.rs create mode 100644 triton-air/src/table/lookup.rs create mode 100644 triton-air/src/table/op_stack.rs create mode 100644 triton-air/src/table/processor.rs create mode 100644 triton-air/src/table/program.rs create mode 100644 triton-air/src/table/ram.rs create mode 100644 triton-air/src/table/u32.rs rename {triton-vm/src/table => triton-air/src}/table_column.rs (76%) create mode 100644 triton-vm/src/challenges.rs create mode 100644 triton-vm/src/table/cascade.rs delete mode 100644 triton-vm/src/table/challenges.rs create mode 100644 triton-vm/src/table/hash.rs delete mode 100644 triton-vm/src/table/hash_table.rs create mode 100644 triton-vm/src/table/jump_stack.rs delete mode 100644 triton-vm/src/table/jump_stack_table.rs create mode 100644 triton-vm/src/table/lookup.rs delete mode 100644 triton-vm/src/table/lookup_table.rs create mode 100644 triton-vm/src/table/op_stack.rs delete mode 100644 triton-vm/src/table/op_stack_table.rs create mode 100644 triton-vm/src/table/processor.rs delete mode 100644 triton-vm/src/table/processor_table.rs create mode 100644 triton-vm/src/table/program.rs delete mode 100644 triton-vm/src/table/program_table.rs create mode 100644 triton-vm/src/table/ram.rs delete mode 100644 triton-vm/src/table/ram_table.rs create mode 100644 triton-vm/src/table/u32.rs delete mode 100644 triton-vm/src/table/u32_table.rs diff --git a/Cargo.toml b/Cargo.toml index 462f3789e..83c85a7b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["triton-vm", "triton-constraint-circuit", "triton-isa"] +members = ["triton-air", "triton-constraint-circuit", "triton-isa", "triton-vm"] resolver = "2" [profile.test] @@ -22,6 +22,7 @@ readme = "README.md" documentation = "https://triton-vm.org/spec/" [workspace.dependencies] +air = { path = "triton-air", package = "triton-air" } anyhow = "1.0" arbitrary = { version = "1", features = ["derive"] } assert2 = "0.3" diff --git a/triton-air/Cargo.toml b/triton-air/Cargo.toml new file mode 100644 index 000000000..1ba8fb87d --- /dev/null +++ b/triton-air/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "triton-air" +description = """ +The Arithmetic Intermediate Representation (AIR) for Triton VM. +""" + +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +readme.workspace = true + +[dependencies] +arbitrary.workspace = true +constraint-circuit.workspace = true +isa.workspace = true +itertools.workspace = true +strum.workspace = true +twenty-first.workspace = true + +[dev-dependencies] +ndarray.workspace = true +num-traits.workspace = true +proptest.workspace = true +proptest-arbitrary-interop.workspace = true +test-strategy.workspace = true + +[lints] +workspace = true diff --git a/triton-air/src/challenge_id.rs b/triton-air/src/challenge_id.rs new file mode 100644 index 000000000..044c6414e --- /dev/null +++ b/triton-air/src/challenge_id.rs @@ -0,0 +1,241 @@ +use std::fmt::Debug; +use std::hash::Hash; + +use arbitrary::Arbitrary; +use strum::Display; +use strum::EnumCount; +use strum::EnumIter; + +/// A `ChallengeId` is a unique, symbolic identifier for a challenge used in +/// Triton VM. +/// +/// Since almost all challenges relate to the Processor Table in some form, the +/// words “Processor Table” are usually omitted from the `ChallengeId`'s name. +#[repr(usize)] +#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] +pub enum ChallengeId { + /// The indeterminate for the [Evaluation Argument](EvalArg) compressing the program digest + /// into a single extension field element, _i.e._, [`CompressedProgramDigest`]. + /// Relates to program attestation. + CompressProgramDigestIndeterminate, + + /// The indeterminate for the [Evaluation Argument](EvalArg) with standard input. + StandardInputIndeterminate, + + /// The indeterminate for the [Evaluation Argument](EvalArg) with standard output. + StandardOutputIndeterminate, + + /// The indeterminate for the instruction + /// [Lookup Argument](crate::table::cross_table_argument::LookupArg) + /// between the [Processor Table](crate::table::processor_table) and the + /// [Program Table](crate::table::program_table) guaranteeing that the instructions and their + /// arguments are copied correctly. + InstructionLookupIndeterminate, + + HashInputIndeterminate, + HashDigestIndeterminate, + SpongeIndeterminate, + + OpStackIndeterminate, + RamIndeterminate, + JumpStackIndeterminate, + + U32Indeterminate, + + /// The indeterminate for the Lookup Argument between the Processor Table and all memory-like + /// tables, _i.e._, the OpStack Table, the Ram Table, and the JumpStack Table, guaranteeing + /// that all clock jump differences are directed forward. + ClockJumpDifferenceLookupIndeterminate, + + /// The indeterminate for the Contiguity Argument within the Ram Table. + RamTableBezoutRelationIndeterminate, + + /// A weight for linearly combining multiple elements. Applies to + /// - `Address` in the Program Table + /// - `IP` in the Processor Table + ProgramAddressWeight, + + /// A weight for linearly combining multiple elements. Applies to + /// - `Instruction` in the Program Table + /// - `CI` in the Processor Table + ProgramInstructionWeight, + + /// A weight for linearly combining multiple elements. Applies to + /// - `Instruction` in the next row in the Program Table + /// - `NIA` in the Processor Table + ProgramNextInstructionWeight, + + OpStackClkWeight, + OpStackIb1Weight, + OpStackPointerWeight, + OpStackFirstUnderflowElementWeight, + + RamClkWeight, + RamPointerWeight, + RamValueWeight, + RamInstructionTypeWeight, + + JumpStackClkWeight, + JumpStackCiWeight, + JumpStackJspWeight, + JumpStackJsoWeight, + JumpStackJsdWeight, + + /// The indeterminate for compressing a [`RATE`][rate]-sized chunk of instructions into a + /// single extension field element. + /// Relates to program attestation. + /// + /// Used by the evaluation argument [`PrepareChunkEvalArg`][prep] and in the Hash Table. + /// + /// [rate]: tip5::RATE + /// [prep]: crate::table::table_column::ProgramExtTableColumn::PrepareChunkRunningEvaluation + ProgramAttestationPrepareChunkIndeterminate, + + /// The indeterminate for the bus over which the [`RATE`][rate]-sized chunks of instructions + /// are sent. Relates to program attestation. + /// Used by the evaluation arguments [`SendChunkEvalArg`][send] and + /// [`ReceiveChunkEvalArg`][recv]. See also: [`ProgramAttestationPrepareChunkIndeterminate`]. + /// + /// [rate]: tip5::RATE + /// [send]: crate::table::table_column::ProgramExtTableColumn::SendChunkRunningEvaluation + /// [recv]: crate::table::table_column::HashExtTableColumn::ReceiveChunkRunningEvaluation + ProgramAttestationSendChunkIndeterminate, + + HashCIWeight, + + StackWeight0, + StackWeight1, + StackWeight2, + StackWeight3, + StackWeight4, + StackWeight5, + StackWeight6, + StackWeight7, + StackWeight8, + StackWeight9, + StackWeight10, + StackWeight11, + StackWeight12, + StackWeight13, + StackWeight14, + StackWeight15, + + /// The indeterminate for the Lookup Argument between the Hash Table and the Cascade Table. + HashCascadeLookupIndeterminate, + + /// A weight for linearly combining multiple elements. Applies to + /// - `*LkIn` in the Hash Table, and + /// - `2^16·LookInHi + LookInLo` in the Cascade Table. + HashCascadeLookInWeight, + + /// A weight for linearly combining multiple elements. Applies to + /// - `*LkOut` in the Hash Table, and + /// - `2^16·LookOutHi + LookOutLo` in the Cascade Table. + HashCascadeLookOutWeight, + + /// The indeterminate for the Lookup Argument between the Cascade Table and the Lookup Table. + CascadeLookupIndeterminate, + + /// A weight for linearly combining multiple elements. Applies to + /// - `LkIn*` in the Cascade Table, and + /// - `LookIn` in the Lookup Table. + LookupTableInputWeight, + + /// A weight for linearly combining multiple elements. Applies to + /// - `LkOut*` in the Cascade Table, and + /// - `LookOut` in the Lookup Table. + LookupTableOutputWeight, + + /// The indeterminate for the public Evaluation Argument establishing correctness of the + /// Lookup Table. + LookupTablePublicIndeterminate, + + U32LhsWeight, + U32RhsWeight, + U32CiWeight, + U32ResultWeight, + + // Derived challenges. + // + // When modifying this, be sure to add to the compile-time assertions in the + // `#[test] const fn compile_time_index_assertions() { … }` + // at the end of this file. + /// The terminal for the [`EvaluationArgument`](EvalArg) with standard input. + /// Makes use of challenge [`StandardInputIndeterminate`]. + StandardInputTerminal, + + /// The terminal for the [`EvaluationArgument`](EvalArg) with standard output. + /// Makes use of challenge [`StandardOutputIndeterminate`]. + StandardOutputTerminal, + + /// The terminal for the [`EvaluationArgument`](EvalArg) establishing correctness of the + /// [Lookup Table](crate::table::lookup_table::LookupTable). + /// Makes use of challenge [`LookupTablePublicIndeterminate`]. + LookupTablePublicTerminal, + + /// The digest of the program to be executed, compressed into a single extension field element. + /// The compression happens using an [`EvaluationArgument`](EvalArg) under challenge + /// [`CompressProgramDigestIndeterminate`]. + /// Relates to program attestation. + CompressedProgramDigest, +} + +impl ChallengeId { + /// The number of challenges derived from other challenges. + /// + /// The IDs of the derived challenges are guaranteed to be larger than the + /// challenges they are derived from. + pub const NUM_DERIVED_CHALLENGES: usize = 4; + + pub const fn index(&self) -> usize { + *self as usize + } +} + +impl From for usize { + fn from(id: ChallengeId) -> Self { + id.index() + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + /// Terminal challenges are computed from public information, such as public + /// input or public output, and other challenges. Because these other challenges + /// are used to compute the terminal challenges, the terminal challenges must be + /// inserted into the challenges vector after the used challenges. + #[test] + const fn compile_time_index_assertions() { + const DERIVED: [ChallengeId; ChallengeId::NUM_DERIVED_CHALLENGES] = [ + ChallengeId::StandardInputTerminal, + ChallengeId::StandardOutputTerminal, + ChallengeId::LookupTablePublicTerminal, + ChallengeId::CompressedProgramDigest, + ]; + + assert!(ChallengeId::StandardInputIndeterminate.index() < DERIVED[0].index()); + assert!(ChallengeId::StandardInputIndeterminate.index() < DERIVED[1].index()); + assert!(ChallengeId::StandardInputIndeterminate.index() < DERIVED[2].index()); + assert!(ChallengeId::StandardInputIndeterminate.index() < DERIVED[3].index()); + + assert!(ChallengeId::StandardOutputIndeterminate.index() < DERIVED[0].index()); + assert!(ChallengeId::StandardOutputIndeterminate.index() < DERIVED[1].index()); + assert!(ChallengeId::StandardOutputIndeterminate.index() < DERIVED[2].index()); + assert!(ChallengeId::StandardOutputIndeterminate.index() < DERIVED[3].index()); + + assert!(ChallengeId::CompressProgramDigestIndeterminate.index() < DERIVED[0].index()); + assert!(ChallengeId::CompressProgramDigestIndeterminate.index() < DERIVED[1].index()); + assert!(ChallengeId::CompressProgramDigestIndeterminate.index() < DERIVED[2].index()); + assert!(ChallengeId::CompressProgramDigestIndeterminate.index() < DERIVED[3].index()); + + assert!(ChallengeId::LookupTablePublicIndeterminate.index() < DERIVED[0].index()); + assert!(ChallengeId::LookupTablePublicIndeterminate.index() < DERIVED[1].index()); + assert!(ChallengeId::LookupTablePublicIndeterminate.index() < DERIVED[2].index()); + assert!(ChallengeId::LookupTablePublicIndeterminate.index() < DERIVED[3].index()); + } + + // Ensure the compile-time assertions are actually executed by the compiler. + const _: () = compile_time_index_assertions(); +} diff --git a/triton-vm/src/table/cross_table_argument.rs b/triton-air/src/cross_table_argument.rs similarity index 67% rename from triton-vm/src/table/cross_table_argument.rs rename to triton-air/src/cross_table_argument.rs index dab2a3ee9..575087b98 100644 --- a/triton-vm/src/table/cross_table_argument.rs +++ b/triton-air/src/cross_table_argument.rs @@ -8,22 +8,18 @@ use constraint_circuit::SingleRowIndicator; use constraint_circuit::SingleRowIndicator::ExtRow; use twenty_first::prelude::*; -use crate::table::challenges::ChallengeId::*; -use crate::table::table_column::CascadeExtTableColumn; -use crate::table::table_column::HashExtTableColumn; -use crate::table::table_column::HashExtTableColumn::*; -use crate::table::table_column::JumpStackExtTableColumn; -use crate::table::table_column::LookupExtTableColumn; -use crate::table::table_column::LookupExtTableColumn::*; -use crate::table::table_column::MasterExtTableColumn; -use crate::table::table_column::OpStackExtTableColumn; -use crate::table::table_column::ProcessorExtTableColumn; -use crate::table::table_column::ProcessorExtTableColumn::*; -use crate::table::table_column::ProgramExtTableColumn; -use crate::table::table_column::ProgramExtTableColumn::*; -use crate::table::table_column::RamExtTableColumn; -use crate::table::table_column::U32ExtTableColumn; -use crate::table::table_column::U32ExtTableColumn::*; +use crate::challenge_id::ChallengeId; + +use crate::table_column::CascadeExtTableColumn; +use crate::table_column::HashExtTableColumn; +use crate::table_column::JumpStackExtTableColumn; +use crate::table_column::LookupExtTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::table_column::OpStackExtTableColumn; +use crate::table_column::ProcessorExtTableColumn; +use crate::table_column::ProgramExtTableColumn; +use crate::table_column::RamExtTableColumn; +use crate::table_column::U32ExtTableColumn; pub trait CrossTableArg { fn default_initial() -> XFieldElement @@ -154,50 +150,55 @@ impl GrandCrossTableArg { |column: LookupExtTableColumn| ext_row(column.master_ext_table_index()); let u32_ext_row = |column: U32ExtTableColumn| ext_row(column.master_ext_table_index()); - let program_attestation = program_ext_row(SendChunkRunningEvaluation) - - hash_ext_row(ReceiveChunkRunningEvaluation); - let input_to_processor = - challenge(StandardInputTerminal) - processor_ext_row(InputTableEvalArg); - let processor_to_output = - processor_ext_row(OutputTableEvalArg) - challenge(StandardOutputTerminal); - let instruction_lookup = processor_ext_row(InstructionLookupClientLogDerivative) - - program_ext_row(InstructionLookupServerLogDerivative); - let processor_to_op_stack = processor_ext_row(OpStackTablePermArg) + let program_attestation = + program_ext_row(ProgramExtTableColumn::SendChunkRunningEvaluation) + - hash_ext_row(HashExtTableColumn::ReceiveChunkRunningEvaluation); + let input_to_processor = challenge(ChallengeId::StandardInputTerminal) + - processor_ext_row(ProcessorExtTableColumn::InputTableEvalArg); + let processor_to_output = processor_ext_row(ProcessorExtTableColumn::OutputTableEvalArg) + - challenge(ChallengeId::StandardOutputTerminal); + let instruction_lookup = + processor_ext_row(ProcessorExtTableColumn::InstructionLookupClientLogDerivative) + - program_ext_row(ProgramExtTableColumn::InstructionLookupServerLogDerivative); + let processor_to_op_stack = processor_ext_row(ProcessorExtTableColumn::OpStackTablePermArg) - op_stack_ext_row(OpStackExtTableColumn::RunningProductPermArg); - let processor_to_ram = processor_ext_row(RamTablePermArg) + let processor_to_ram = processor_ext_row(ProcessorExtTableColumn::RamTablePermArg) - ram_ext_row(RamExtTableColumn::RunningProductPermArg); - let processor_to_jump_stack = processor_ext_row(JumpStackTablePermArg) - - jump_stack_ext_row(JumpStackExtTableColumn::RunningProductPermArg); - let hash_input = - processor_ext_row(HashInputEvalArg) - hash_ext_row(HashInputRunningEvaluation); - let hash_digest = - hash_ext_row(HashDigestRunningEvaluation) - processor_ext_row(HashDigestEvalArg); - let sponge = processor_ext_row(SpongeEvalArg) - hash_ext_row(SpongeRunningEvaluation); + let processor_to_jump_stack = + processor_ext_row(ProcessorExtTableColumn::JumpStackTablePermArg) + - jump_stack_ext_row(JumpStackExtTableColumn::RunningProductPermArg); + let hash_input = processor_ext_row(ProcessorExtTableColumn::HashInputEvalArg) + - hash_ext_row(HashExtTableColumn::HashInputRunningEvaluation); + let hash_digest = hash_ext_row(HashExtTableColumn::HashDigestRunningEvaluation) + - processor_ext_row(ProcessorExtTableColumn::HashDigestEvalArg); + let sponge = processor_ext_row(ProcessorExtTableColumn::SpongeEvalArg) + - hash_ext_row(HashExtTableColumn::SpongeRunningEvaluation); let hash_to_cascade = cascade_ext_row(CascadeExtTableColumn::HashTableServerLogDerivative) - - hash_ext_row(CascadeState0HighestClientLogDerivative) - - hash_ext_row(CascadeState0MidHighClientLogDerivative) - - hash_ext_row(CascadeState0MidLowClientLogDerivative) - - hash_ext_row(CascadeState0LowestClientLogDerivative) - - hash_ext_row(CascadeState1HighestClientLogDerivative) - - hash_ext_row(CascadeState1MidHighClientLogDerivative) - - hash_ext_row(CascadeState1MidLowClientLogDerivative) - - hash_ext_row(CascadeState1LowestClientLogDerivative) - - hash_ext_row(CascadeState2HighestClientLogDerivative) - - hash_ext_row(CascadeState2MidHighClientLogDerivative) - - hash_ext_row(CascadeState2MidLowClientLogDerivative) - - hash_ext_row(CascadeState2LowestClientLogDerivative) - - hash_ext_row(CascadeState3HighestClientLogDerivative) - - hash_ext_row(CascadeState3MidHighClientLogDerivative) - - hash_ext_row(CascadeState3MidLowClientLogDerivative) - - hash_ext_row(CascadeState3LowestClientLogDerivative); + - hash_ext_row(HashExtTableColumn::CascadeState0HighestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState0MidHighClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState0MidLowClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState0LowestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState1HighestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState1MidHighClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState1MidLowClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState1LowestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState2HighestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState2MidHighClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState2MidLowClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState2LowestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState3HighestClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState3MidHighClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState3MidLowClientLogDerivative) + - hash_ext_row(HashExtTableColumn::CascadeState3LowestClientLogDerivative); let cascade_to_lookup = cascade_ext_row(CascadeExtTableColumn::LookupTableClientLogDerivative) - - lookup_ext_row(CascadeTableServerLogDerivative); - let processor_to_u32 = processor_ext_row(U32LookupClientLogDerivative) - - u32_ext_row(LookupServerLogDerivative); + - lookup_ext_row(LookupExtTableColumn::CascadeTableServerLogDerivative); + let processor_to_u32 = + processor_ext_row(ProcessorExtTableColumn::U32LookupClientLogDerivative) + - u32_ext_row(U32ExtTableColumn::LookupServerLogDerivative); // Introduce new variable names to increase readability. Potentially opinionated. - let processor_cjdld = ClockJumpDifferenceLookupServerLogDerivative; + let processor_cjdld = ProcessorExtTableColumn::ClockJumpDifferenceLookupServerLogDerivative; let op_stack_cjdld = OpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; let ram_cjdld = RamExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; let j_stack_cjdld = JumpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; diff --git a/triton-air/src/lib.rs b/triton-air/src/lib.rs new file mode 100644 index 000000000..1f11a29fe --- /dev/null +++ b/triton-air/src/lib.rs @@ -0,0 +1,104 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::SingleRowIndicator; +use strum::EnumCount; + +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; + +pub mod challenge_id; +pub mod cross_table_argument; +pub mod table; +pub mod table_column; + +/// The degree of the AIR after the degree lowering step. +/// +/// Using substitution and the introduction of new variables, the degree of the AIR as specified +/// in the respective tables +/// (e.g., in [`processor_table::ExtProcessorTable::transition_constraints`]) +/// is lowered to this value. +/// For example, with a target degree of 2 and a (fictional) constraint of the form +/// `a = b²·c²·d`, +/// the degree lowering step could (as one among multiple possibilities) +/// - introduce new variables `e`, `f`, and `g`, +/// - introduce new constraints `e = b²`, `f = c²`, and `g = e·f`, +/// - replace the original constraint with `a = g·d`. +/// +/// The degree lowering happens in the constraint evaluation generator. +/// It can be executed by running `cargo run --bin constraint-evaluation-generator`. +/// Executing the constraint evaluator is a prerequisite for running both the Stark prover +/// and the Stark verifier. +/// +/// The new variables introduced by the degree lowering step are called “derived columns.” +/// They are added to the [`DegreeLoweringTable`], whose sole purpose is to store the values +/// of these derived columns. +pub const TARGET_DEGREE: isize = 4; + +pub trait AIR { + type MainColumn: MasterBaseTableColumn + EnumCount; + type AuxColumn: MasterExtTableColumn + EnumCount; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec>; + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec>; + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec>; + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec>; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn public_types_implement_usual_auto_traits() { + fn implements_auto_traits() {} + + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + } +} diff --git a/triton-air/src/table.rs b/triton-air/src/table.rs new file mode 100644 index 000000000..7e17ad4c8 --- /dev/null +++ b/triton-air/src/table.rs @@ -0,0 +1,111 @@ +use arbitrary::Arbitrary; +use strum::Display; +use strum::EnumCount; +use strum::EnumIter; + +use crate::table::cascade::CascadeTable; +use crate::table::hash::HashTable; +use crate::table::jump_stack::JumpStackTable; +use crate::table::lookup::LookupTable; +use crate::table::op_stack::OpStackTable; +use crate::table::processor::ProcessorTable; +use crate::table::program::ProgramTable; +use crate::table::ram::RamTable; +use crate::table::u32::U32Table; +use crate::AIR; + +pub mod cascade; +pub mod hash; +pub mod jump_stack; +pub mod lookup; +pub mod op_stack; +pub mod processor; +pub mod program; +pub mod ram; +pub mod u32; + +/// The total number of main columns across all tables. +/// The degree lowering columns are _not_ included. +pub const NUM_BASE_COLUMNS: usize = ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT + + ::MainColumn::COUNT; + +/// The total number of auxiliary columns across all tables. +/// The degree lowering columns as well as any randomizer polynomials are _not_ +/// included. +pub const NUM_EXT_COLUMNS: usize = ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT + + ::AuxColumn::COUNT; + +pub const PROGRAM_TABLE_START: usize = 0; +pub const PROGRAM_TABLE_END: usize = PROGRAM_TABLE_START + ::MainColumn::COUNT; +pub const PROCESSOR_TABLE_START: usize = PROGRAM_TABLE_END; +pub const PROCESSOR_TABLE_END: usize = + PROCESSOR_TABLE_START + ::MainColumn::COUNT; +pub const OP_STACK_TABLE_START: usize = PROCESSOR_TABLE_END; +pub const OP_STACK_TABLE_END: usize = + OP_STACK_TABLE_START + ::MainColumn::COUNT; +pub const RAM_TABLE_START: usize = OP_STACK_TABLE_END; +pub const RAM_TABLE_END: usize = RAM_TABLE_START + ::MainColumn::COUNT; +pub const JUMP_STACK_TABLE_START: usize = RAM_TABLE_END; +pub const JUMP_STACK_TABLE_END: usize = + JUMP_STACK_TABLE_START + ::MainColumn::COUNT; +pub const HASH_TABLE_START: usize = JUMP_STACK_TABLE_END; +pub const HASH_TABLE_END: usize = HASH_TABLE_START + ::MainColumn::COUNT; +pub const CASCADE_TABLE_START: usize = HASH_TABLE_END; +pub const CASCADE_TABLE_END: usize = CASCADE_TABLE_START + ::MainColumn::COUNT; +pub const LOOKUP_TABLE_START: usize = CASCADE_TABLE_END; +pub const LOOKUP_TABLE_END: usize = LOOKUP_TABLE_START + ::MainColumn::COUNT; +pub const U32_TABLE_START: usize = LOOKUP_TABLE_END; +pub const U32_TABLE_END: usize = U32_TABLE_START + ::MainColumn::COUNT; + +pub const EXT_PROGRAM_TABLE_START: usize = 0; +pub const EXT_PROGRAM_TABLE_END: usize = + EXT_PROGRAM_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_PROCESSOR_TABLE_START: usize = EXT_PROGRAM_TABLE_END; +pub const EXT_PROCESSOR_TABLE_END: usize = + EXT_PROCESSOR_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_OP_STACK_TABLE_START: usize = EXT_PROCESSOR_TABLE_END; +pub const EXT_OP_STACK_TABLE_END: usize = + EXT_OP_STACK_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_RAM_TABLE_START: usize = EXT_OP_STACK_TABLE_END; +pub const EXT_RAM_TABLE_END: usize = EXT_RAM_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_JUMP_STACK_TABLE_START: usize = EXT_RAM_TABLE_END; +pub const EXT_JUMP_STACK_TABLE_END: usize = + EXT_JUMP_STACK_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_HASH_TABLE_START: usize = EXT_JUMP_STACK_TABLE_END; +pub const EXT_HASH_TABLE_END: usize = EXT_HASH_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_CASCADE_TABLE_START: usize = EXT_HASH_TABLE_END; +pub const EXT_CASCADE_TABLE_END: usize = + EXT_CASCADE_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_LOOKUP_TABLE_START: usize = EXT_CASCADE_TABLE_END; +pub const EXT_LOOKUP_TABLE_END: usize = + EXT_LOOKUP_TABLE_START + ::AuxColumn::COUNT; +pub const EXT_U32_TABLE_START: usize = EXT_LOOKUP_TABLE_END; +pub const EXT_U32_TABLE_END: usize = EXT_U32_TABLE_START + ::AuxColumn::COUNT; + +/// Uniquely determines one of Triton VM's tables. +#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] +pub enum TableId { + Program, + Processor, + OpStack, + Ram, + JumpStack, + Hash, + Cascade, + Lookup, + U32, +} diff --git a/triton-vm/src/table/cascade_table.rs b/triton-air/src/table/cascade.rs similarity index 51% rename from triton-vm/src/table/cascade_table.rs rename to triton-air/src/table/cascade.rs index 770f272e7..e9beeb93b 100644 --- a/triton-vm/src/table/cascade_table.rs +++ b/triton-air/src/table/cascade.rs @@ -1,143 +1,41 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; -use ndarray::s; -use ndarray::ArrayView2; -use ndarray::ArrayViewMut2; -use num_traits::ConstOne; -use num_traits::One; -use strum::EnumCount; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::LookupArg; -use crate::table::table_column::CascadeBaseTableColumn; -use crate::table::table_column::CascadeBaseTableColumn::*; -use crate::table::table_column::CascadeExtTableColumn; -use crate::table::table_column::CascadeExtTableColumn::*; -use crate::table::table_column::MasterBaseTableColumn; -use crate::table::table_column::MasterExtTableColumn; - -pub const BASE_WIDTH: usize = CascadeBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = CascadeExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; + +use crate::challenge_id::ChallengeId; +use crate::challenge_id::ChallengeId::CascadeLookupIndeterminate; +use crate::challenge_id::ChallengeId::HashCascadeLookInWeight; +use crate::challenge_id::ChallengeId::HashCascadeLookOutWeight; +use crate::challenge_id::ChallengeId::HashCascadeLookupIndeterminate; +use crate::challenge_id::ChallengeId::LookupTableInputWeight; +use crate::challenge_id::ChallengeId::LookupTableOutputWeight; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct CascadeTable; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtCascadeTable; - -impl CascadeTable { - pub fn fill_trace( - cascade_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) { - for (row_idx, (&to_look_up, &multiplicity)) in - aet.cascade_table_lookup_multiplicities.iter().enumerate() - { - let to_look_up_lo = (to_look_up & 0xff) as u8; - let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; - - let mut row = cascade_table.row_mut(row_idx); - row[LookInLo.base_table_index()] = bfe!(to_look_up_lo); - row[LookInHi.base_table_index()] = bfe!(to_look_up_hi); - row[LookOutLo.base_table_index()] = Self::lookup_8_bit_limb(to_look_up_lo); - row[LookOutHi.base_table_index()] = Self::lookup_8_bit_limb(to_look_up_hi); - row[LookupMultiplicity.base_table_index()] = bfe!(multiplicity); - } - } - - pub fn pad_trace(mut cascade_table: ArrayViewMut2, cascade_table_length: usize) { - cascade_table - .slice_mut(s![cascade_table_length.., IsPadding.base_table_index()]) - .fill(BFieldElement::ONE); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "cascade table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let mut hash_table_log_derivative = LookupArg::default_initial(); - let mut lookup_table_log_derivative = LookupArg::default_initial(); - - let two_pow_8 = bfe!(1 << 8); - - let hash_indeterminate = challenges[HashCascadeLookupIndeterminate]; - let hash_input_weight = challenges[HashCascadeLookInWeight]; - let hash_output_weight = challenges[HashCascadeLookOutWeight]; +impl AIR for CascadeTable { + type MainColumn = crate::table_column::CascadeBaseTableColumn; + type AuxColumn = crate::table_column::CascadeExtTableColumn; - let lookup_indeterminate = challenges[CascadeLookupIndeterminate]; - let lookup_input_weight = challenges[LookupTableInputWeight]; - let lookup_output_weight = challenges[LookupTableOutputWeight]; - - for row_idx in 0..base_table.nrows() { - let base_row = base_table.row(row_idx); - let is_padding = base_row[IsPadding.base_table_index()].is_one(); - - if !is_padding { - let look_in = two_pow_8 * base_row[LookInHi.base_table_index()] - + base_row[LookInLo.base_table_index()]; - let look_out = two_pow_8 * base_row[LookOutHi.base_table_index()] - + base_row[LookOutLo.base_table_index()]; - let compressed_row_hash = - hash_input_weight * look_in + hash_output_weight * look_out; - let lookup_multiplicity = base_row[LookupMultiplicity.base_table_index()]; - hash_table_log_derivative += - (hash_indeterminate - compressed_row_hash).inverse() * lookup_multiplicity; - - let compressed_row_lo = lookup_input_weight * base_row[LookInLo.base_table_index()] - + lookup_output_weight * base_row[LookOutLo.base_table_index()]; - let compressed_row_hi = lookup_input_weight * base_row[LookInHi.base_table_index()] - + lookup_output_weight * base_row[LookOutHi.base_table_index()]; - lookup_table_log_derivative += (lookup_indeterminate - compressed_row_lo).inverse(); - lookup_table_log_derivative += (lookup_indeterminate - compressed_row_hi).inverse(); - } - - let mut extension_row = ext_table.row_mut(row_idx); - extension_row[HashTableServerLogDerivative.ext_table_index()] = - hash_table_log_derivative; - extension_row[LookupTableClientLogDerivative.ext_table_index()] = - lookup_table_log_derivative; - } - profiler!(stop "cascade table"); - } - - fn lookup_8_bit_limb(to_look_up: u8) -> BFieldElement { - tip5::LOOKUP_TABLE[to_look_up as usize].into() - } - - pub fn lookup_16_bit_limb(to_look_up: u16) -> BFieldElement { - let to_look_up_lo = (to_look_up & 0xff) as u8; - let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; - let looked_up_lo = Self::lookup_8_bit_limb(to_look_up_lo); - let looked_up_hi = Self::lookup_8_bit_limb(to_look_up_hi); - bfe!(1 << 8) * looked_up_hi + looked_up_lo - } -} - -impl ExtCascadeTable { - pub fn initial_constraints( + fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let base_row = |col_id: CascadeBaseTableColumn| { + let main_row = |col_id: Self::MainColumn| { circuit_builder.input(BaseRow(col_id.master_base_table_index())) }; - let ext_row = |col_id: CascadeExtTableColumn| { + let aux_row = |col_id: Self::AuxColumn| { circuit_builder.input(ExtRow(col_id.master_ext_table_index())) }; let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); @@ -147,14 +45,16 @@ impl ExtCascadeTable { let two_pow_8 = circuit_builder.b_constant(1 << 8); let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial()); - let is_padding = base_row(IsPadding); - let look_in_hi = base_row(LookInHi); - let look_in_lo = base_row(LookInLo); - let look_out_hi = base_row(LookOutHi); - let look_out_lo = base_row(LookOutLo); - let lookup_multiplicity = base_row(LookupMultiplicity); - let hash_table_server_log_derivative = ext_row(HashTableServerLogDerivative); - let lookup_table_client_log_derivative = ext_row(LookupTableClientLogDerivative); + let is_padding = main_row(Self::MainColumn::IsPadding); + let look_in_hi = main_row(Self::MainColumn::LookInHi); + let look_in_lo = main_row(Self::MainColumn::LookInLo); + let look_out_hi = main_row(Self::MainColumn::LookOutHi); + let look_out_lo = main_row(Self::MainColumn::LookOutLo); + let lookup_multiplicity = main_row(Self::MainColumn::LookupMultiplicity); + let hash_table_server_log_derivative = + aux_row(Self::AuxColumn::HashTableServerLogDerivative); + let lookup_table_client_log_derivative = + aux_row(Self::AuxColumn::LookupTableClientLogDerivative); let hash_indeterminate = challenge(HashCascadeLookupIndeterminate); let hash_input_weight = challenge(HashCascadeLookInWeight); @@ -202,36 +102,36 @@ impl ExtCascadeTable { ] } - pub fn consistency_constraints( + fn consistency_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let base_row = |col_id: CascadeBaseTableColumn| { + let row = |col_id: Self::MainColumn| { circuit_builder.input(BaseRow(col_id.master_base_table_index())) }; let one = circuit_builder.b_constant(1); - let is_padding = base_row(IsPadding); + let is_padding = row(Self::MainColumn::IsPadding); let is_padding_is_0_or_1 = is_padding.clone() * (one - is_padding); vec![is_padding_is_0_or_1] } - pub fn transition_constraints( + fn transition_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let challenge = |c| circuit_builder.challenge(c); let constant = |c: u64| circuit_builder.b_constant(c); - let current_base_row = |column_idx: CascadeBaseTableColumn| { + let curr_main_row = |column_idx: Self::MainColumn| { circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) }; - let next_base_row = |column_idx: CascadeBaseTableColumn| { + let next_main_row = |column_idx: Self::MainColumn| { circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) }; - let current_ext_row = |column_idx: CascadeExtTableColumn| { + let curr_aux_row = |column_idx: Self::AuxColumn| { circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) }; - let next_ext_row = |column_idx: CascadeExtTableColumn| { + let next_aux_row = |column_idx: Self::AuxColumn| { circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) }; @@ -239,18 +139,22 @@ impl ExtCascadeTable { let two = constant(2); let two_pow_8 = constant(1 << 8); - let is_padding = current_base_row(IsPadding); - let hash_table_server_log_derivative = current_ext_row(HashTableServerLogDerivative); - let lookup_table_client_log_derivative = current_ext_row(LookupTableClientLogDerivative); - - let is_padding_next = next_base_row(IsPadding); - let look_in_hi_next = next_base_row(LookInHi); - let look_in_lo_next = next_base_row(LookInLo); - let look_out_hi_next = next_base_row(LookOutHi); - let look_out_lo_next = next_base_row(LookOutLo); - let lookup_multiplicity_next = next_base_row(LookupMultiplicity); - let hash_table_server_log_derivative_next = next_ext_row(HashTableServerLogDerivative); - let lookup_table_client_log_derivative_next = next_ext_row(LookupTableClientLogDerivative); + let is_padding = curr_main_row(Self::MainColumn::IsPadding); + let hash_table_server_log_derivative = + curr_aux_row(Self::AuxColumn::HashTableServerLogDerivative); + let lookup_table_client_log_derivative = + curr_aux_row(Self::AuxColumn::LookupTableClientLogDerivative); + + let is_padding_next = next_main_row(Self::MainColumn::IsPadding); + let look_in_hi_next = next_main_row(Self::MainColumn::LookInHi); + let look_in_lo_next = next_main_row(Self::MainColumn::LookInLo); + let look_out_hi_next = next_main_row(Self::MainColumn::LookOutHi); + let look_out_lo_next = next_main_row(Self::MainColumn::LookOutLo); + let lookup_multiplicity_next = next_main_row(Self::MainColumn::LookupMultiplicity); + let hash_table_server_log_derivative_next = + next_aux_row(Self::AuxColumn::HashTableServerLogDerivative); + let lookup_table_client_log_derivative_next = + next_aux_row(Self::AuxColumn::LookupTableClientLogDerivative); let hash_indeterminate = challenge(HashCascadeLookupIndeterminate); let hash_input_weight = challenge(HashCascadeLookInWeight); @@ -304,7 +208,7 @@ impl ExtCascadeTable { ] } - pub fn terminal_constraints( + fn terminal_constraints( _circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { // no further constraints diff --git a/triton-air/src/table/hash.rs b/triton-air/src/table/hash.rs new file mode 100644 index 000000000..b7ce592a1 --- /dev/null +++ b/triton-air/src/table/hash.rs @@ -0,0 +1,1393 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use isa::instruction::Instruction; +use itertools::Itertools; +use strum::Display; +use strum::EnumCount; +use strum::EnumIter; +use strum::IntoEnumIterator; +use twenty_first::prelude::tip5::NUM_ROUNDS; +use twenty_first::prelude::*; + +use crate::challenge_id::ChallengeId; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::EvalArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +/// See [`HashTable::base_field_element_into_16_bit_limbs`] for more details. +pub const MONTGOMERY_MODULUS: BFieldElement = + BFieldElement::new(((1_u128 << 64) % BFieldElement::P as u128) as u64); + +pub const POWER_MAP_EXPONENT: u64 = 7; +pub const NUM_ROUND_CONSTANTS: usize = tip5::STATE_SIZE; + +pub const PERMUTATION_TRACE_LENGTH: usize = NUM_ROUNDS + 1; + +pub type PermutationTrace = [[BFieldElement; tip5::STATE_SIZE]; PERMUTATION_TRACE_LENGTH]; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct HashTable; + +impl HashTable { + /// Get the MDS matrix's entry in row `row_idx` and column `col_idx`. + const fn mds_matrix_entry(row_idx: usize, col_idx: usize) -> BFieldElement { + assert!(row_idx < tip5::STATE_SIZE); + assert!(col_idx < tip5::STATE_SIZE); + let index_in_matrix_defining_column = + (tip5::STATE_SIZE + row_idx - col_idx) % tip5::STATE_SIZE; + let mds_matrix_entry = tip5::MDS_MATRIX_FIRST_COLUMN[index_in_matrix_defining_column]; + BFieldElement::new(mds_matrix_entry as u64) + } + + /// The round constants for round `r` if it is a valid round number in the Tip5 permutation, + /// and the zero vector otherwise. + pub fn tip5_round_constants_by_round_number(r: usize) -> [BFieldElement; NUM_ROUND_CONSTANTS] { + if r >= NUM_ROUNDS { + return bfe_array![0; NUM_ROUND_CONSTANTS]; + } + + let range_start = NUM_ROUND_CONSTANTS * r; + let range_end = NUM_ROUND_CONSTANTS * (r + 1); + tip5::ROUND_CONSTANTS[range_start..range_end] + .try_into() + .unwrap() + } + + /// Construct one of the states 0 through 3 from its constituent limbs. + /// For example, state 0 (prior to it being looked up in the split-and-lookup S-Box, which is + /// usually the desired version of the state) is constructed from limbs + /// [`State0HighestLkIn`] through [`State0LowestLkIn`]. + /// + /// States 4 through 15 are directly accessible. See also the slightly related + /// [`Self::state_column_by_index`]. + fn re_compose_16_bit_limbs( + circuit_builder: &ConstraintCircuitBuilder, + highest: ConstraintCircuitMonad, + mid_high: ConstraintCircuitMonad, + mid_low: ConstraintCircuitMonad, + lowest: ConstraintCircuitMonad, + ) -> ConstraintCircuitMonad { + let constant = |c: u64| circuit_builder.b_constant(c); + let montgomery_modulus_inv = circuit_builder.b_constant(MONTGOMERY_MODULUS.inverse()); + + let sum_of_shifted_limbs = highest * constant(1 << 48) + + mid_high * constant(1 << 32) + + mid_low * constant(1 << 16) + + lowest; + sum_of_shifted_limbs * montgomery_modulus_inv + } + + /// A constraint circuit evaluating to zero if and only if the given + /// `round_number_circuit_node` is not equal to the given `round_number_to_deselect`. + fn round_number_deselector( + circuit_builder: &ConstraintCircuitBuilder, + round_number_circuit_node: &ConstraintCircuitMonad, + round_number_to_deselect: usize, + ) -> ConstraintCircuitMonad { + assert!( + round_number_to_deselect <= NUM_ROUNDS, + "Round number must be in [0, {NUM_ROUNDS}] but got {round_number_to_deselect}." + ); + let constant = |c: u64| circuit_builder.b_constant(c); + + // To not subtract zero from the first factor: some special casing. + let first_factor = match round_number_to_deselect { + 0 => constant(1), + _ => round_number_circuit_node.clone(), + }; + (1..=NUM_ROUNDS) + .filter(|&r| r != round_number_to_deselect) + .map(|r| round_number_circuit_node.clone() - constant(r as u64)) + .fold(first_factor, |a, b| a * b) + } + + /// A constraint circuit evaluating to zero if and only if the given `mode_circuit_node` is + /// equal to the given `mode_to_select`. + fn select_mode( + circuit_builder: &ConstraintCircuitBuilder, + mode_circuit_node: &ConstraintCircuitMonad, + mode_to_select: HashTableMode, + ) -> ConstraintCircuitMonad { + mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select) + } + + /// A constraint circuit evaluating to zero if and only if the given `mode_circuit_node` is + /// not equal to the given `mode_to_deselect`. + fn mode_deselector( + circuit_builder: &ConstraintCircuitBuilder, + mode_circuit_node: &ConstraintCircuitMonad, + mode_to_deselect: HashTableMode, + ) -> ConstraintCircuitMonad { + let constant = |c: u64| circuit_builder.b_constant(c); + HashTableMode::iter() + .filter(|&mode| mode != mode_to_deselect) + .map(|mode| mode_circuit_node.clone() - constant(mode.into())) + .fold(constant(1), |accumulator, factor| accumulator * factor) + } + + fn instruction_deselector( + circuit_builder: &ConstraintCircuitBuilder, + current_instruction_node: &ConstraintCircuitMonad, + instruction_to_deselect: Instruction, + ) -> ConstraintCircuitMonad { + let constant = |c: u64| circuit_builder.b_constant(c); + let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); + let relevant_instructions = [ + Instruction::Hash, + Instruction::SpongeInit, + Instruction::SpongeAbsorb, + Instruction::SpongeSqueeze, + ]; + assert!(relevant_instructions.contains(&instruction_to_deselect)); + + relevant_instructions + .iter() + .filter(|&instruction| instruction != &instruction_to_deselect) + .map(|&instruction| current_instruction_node.clone() - opcode(instruction)) + .fold(constant(1), |accumulator, factor| accumulator * factor) + } + + /// The [`HashBaseTableColumn`] for the round constant corresponding to the given index. + /// Valid indices are 0 through 15, corresponding to the 16 round constants + /// [`Constant0`] through [`Constant15`]. + pub fn round_constant_column_by_index(index: usize) -> ::MainColumn { + match index { + 0 => ::MainColumn::Constant0, + 1 => ::MainColumn::Constant1, + 2 => ::MainColumn::Constant2, + 3 => ::MainColumn::Constant3, + 4 => ::MainColumn::Constant4, + 5 => ::MainColumn::Constant5, + 6 => ::MainColumn::Constant6, + 7 => ::MainColumn::Constant7, + 8 => ::MainColumn::Constant8, + 9 => ::MainColumn::Constant9, + 10 => ::MainColumn::Constant10, + 11 => ::MainColumn::Constant11, + 12 => ::MainColumn::Constant12, + 13 => ::MainColumn::Constant13, + 14 => ::MainColumn::Constant14, + 15 => ::MainColumn::Constant15, + _ => panic!("invalid constant column index"), + } + } + + /// The [`HashBaseTableColumn`] for the state corresponding to the given index. + /// Valid indices are 4 through 15, corresponding to the 12 state columns + /// [`State4`] through [`State15`]. + /// + /// States with indices 0 through 3 have to be assembled from the respective limbs; + /// see [`Self::re_compose_states_0_through_3_before_lookup`] + /// or [`Self::re_compose_16_bit_limbs`]. + fn state_column_by_index(index: usize) -> ::MainColumn { + match index { + 4 => ::MainColumn::State4, + 5 => ::MainColumn::State5, + 6 => ::MainColumn::State6, + 7 => ::MainColumn::State7, + 8 => ::MainColumn::State8, + 9 => ::MainColumn::State9, + 10 => ::MainColumn::State10, + 11 => ::MainColumn::State11, + 12 => ::MainColumn::State12, + 13 => ::MainColumn::State13, + 14 => ::MainColumn::State14, + 15 => ::MainColumn::State15, + _ => panic!("invalid state column index"), + } + } + + fn indicate_column_index_in_base_row(column: ::MainColumn) -> SingleRowIndicator { + BaseRow(column.master_base_table_index()) + } + + fn indicate_column_index_in_current_base_row( + column: ::MainColumn, + ) -> DualRowIndicator { + CurrentBaseRow(column.master_base_table_index()) + } + + fn indicate_column_index_in_next_base_row( + column: ::MainColumn, + ) -> DualRowIndicator { + NextBaseRow(column.master_base_table_index()) + } + + fn re_compose_states_0_through_3_before_lookup( + circuit_builder: &ConstraintCircuitBuilder, + main_row_to_input_indicator: fn(::MainColumn) -> II, + ) -> [ConstraintCircuitMonad; 4] { + let input = |input_indicator: II| circuit_builder.input(input_indicator); + let state_0 = Self::re_compose_16_bit_limbs( + circuit_builder, + input(main_row_to_input_indicator( + ::MainColumn::State0HighestLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State0MidHighLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State0MidLowLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State0LowestLkIn, + )), + ); + let state_1 = Self::re_compose_16_bit_limbs( + circuit_builder, + input(main_row_to_input_indicator( + ::MainColumn::State1HighestLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State1MidHighLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State1MidLowLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State1LowestLkIn, + )), + ); + let state_2 = Self::re_compose_16_bit_limbs( + circuit_builder, + input(main_row_to_input_indicator( + ::MainColumn::State2HighestLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State2MidHighLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State2MidLowLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State2LowestLkIn, + )), + ); + let state_3 = Self::re_compose_16_bit_limbs( + circuit_builder, + input(main_row_to_input_indicator( + ::MainColumn::State3HighestLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State3MidHighLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State3MidLowLkIn, + )), + input(main_row_to_input_indicator( + ::MainColumn::State3LowestLkIn, + )), + ); + [state_0, state_1, state_2, state_3] + } + + fn tip5_constraints_as_circuits( + circuit_builder: &ConstraintCircuitBuilder, + ) -> ( + [ConstraintCircuitMonad; tip5::STATE_SIZE], + [ConstraintCircuitMonad; tip5::STATE_SIZE], + ) { + let constant = |c: u64| circuit_builder.b_constant(c); + let b_constant = |c| circuit_builder.b_constant(c); + let current_main_row = |column_idx: ::MainColumn| { + circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) + }; + let next_main_row = |column_idx: ::MainColumn| { + circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + }; + + let state_0_after_lookup = Self::re_compose_16_bit_limbs( + circuit_builder, + current_main_row(::MainColumn::State0HighestLkOut), + current_main_row(::MainColumn::State0MidHighLkOut), + current_main_row(::MainColumn::State0MidLowLkOut), + current_main_row(::MainColumn::State0LowestLkOut), + ); + let state_1_after_lookup = Self::re_compose_16_bit_limbs( + circuit_builder, + current_main_row(::MainColumn::State1HighestLkOut), + current_main_row(::MainColumn::State1MidHighLkOut), + current_main_row(::MainColumn::State1MidLowLkOut), + current_main_row(::MainColumn::State1LowestLkOut), + ); + let state_2_after_lookup = Self::re_compose_16_bit_limbs( + circuit_builder, + current_main_row(::MainColumn::State2HighestLkOut), + current_main_row(::MainColumn::State2MidHighLkOut), + current_main_row(::MainColumn::State2MidLowLkOut), + current_main_row(::MainColumn::State2LowestLkOut), + ); + let state_3_after_lookup = Self::re_compose_16_bit_limbs( + circuit_builder, + current_main_row(::MainColumn::State3HighestLkOut), + current_main_row(::MainColumn::State3MidHighLkOut), + current_main_row(::MainColumn::State3MidLowLkOut), + current_main_row(::MainColumn::State3LowestLkOut), + ); + + let state_part_before_power_map: [_; tip5::STATE_SIZE - tip5::NUM_SPLIT_AND_LOOKUP] = [ + ::MainColumn::State4, + ::MainColumn::State5, + ::MainColumn::State6, + ::MainColumn::State7, + ::MainColumn::State8, + ::MainColumn::State9, + ::MainColumn::State10, + ::MainColumn::State11, + ::MainColumn::State12, + ::MainColumn::State13, + ::MainColumn::State14, + ::MainColumn::State15, + ] + .map(current_main_row); + + let state_part_after_power_map = { + let mut exponentiation_accumulator = state_part_before_power_map.clone(); + for _ in 1..POWER_MAP_EXPONENT { + for (i, state) in exponentiation_accumulator.iter_mut().enumerate() { + *state = state.clone() * state_part_before_power_map[i].clone(); + } + } + exponentiation_accumulator + }; + + let state_after_s_box_application = [ + state_0_after_lookup, + state_1_after_lookup, + state_2_after_lookup, + state_3_after_lookup, + state_part_after_power_map[0].clone(), + state_part_after_power_map[1].clone(), + state_part_after_power_map[2].clone(), + state_part_after_power_map[3].clone(), + state_part_after_power_map[4].clone(), + state_part_after_power_map[5].clone(), + state_part_after_power_map[6].clone(), + state_part_after_power_map[7].clone(), + state_part_after_power_map[8].clone(), + state_part_after_power_map[9].clone(), + state_part_after_power_map[10].clone(), + state_part_after_power_map[11].clone(), + ]; + + let mut state_after_matrix_multiplication = vec![constant(0); tip5::STATE_SIZE]; + for (row_idx, acc) in state_after_matrix_multiplication.iter_mut().enumerate() { + for (col_idx, state) in state_after_s_box_application.iter().enumerate() { + let matrix_entry = b_constant(Self::mds_matrix_entry(row_idx, col_idx)); + *acc = acc.clone() + matrix_entry * state.clone(); + } + } + + let round_constants: [_; tip5::STATE_SIZE] = [ + ::MainColumn::Constant0, + ::MainColumn::Constant1, + ::MainColumn::Constant2, + ::MainColumn::Constant3, + ::MainColumn::Constant4, + ::MainColumn::Constant5, + ::MainColumn::Constant6, + ::MainColumn::Constant7, + ::MainColumn::Constant8, + ::MainColumn::Constant9, + ::MainColumn::Constant10, + ::MainColumn::Constant11, + ::MainColumn::Constant12, + ::MainColumn::Constant13, + ::MainColumn::Constant14, + ::MainColumn::Constant15, + ] + .map(current_main_row); + + let state_after_round_constant_addition = state_after_matrix_multiplication + .into_iter() + .zip_eq(round_constants) + .map(|(st, rndc)| st + rndc) + .collect_vec(); + + let [state_0_next, state_1_next, state_2_next, state_3_next] = + Self::re_compose_states_0_through_3_before_lookup( + circuit_builder, + Self::indicate_column_index_in_next_base_row, + ); + let state_next = [ + state_0_next, + state_1_next, + state_2_next, + state_3_next, + next_main_row(::MainColumn::State4), + next_main_row(::MainColumn::State5), + next_main_row(::MainColumn::State6), + next_main_row(::MainColumn::State7), + next_main_row(::MainColumn::State8), + next_main_row(::MainColumn::State9), + next_main_row(::MainColumn::State10), + next_main_row(::MainColumn::State11), + next_main_row(::MainColumn::State12), + next_main_row(::MainColumn::State13), + next_main_row(::MainColumn::State14), + next_main_row(::MainColumn::State15), + ]; + + let round_number_next = next_main_row(::MainColumn::RoundNumber); + let hash_function_round_correctly_performs_update = state_after_round_constant_addition + .into_iter() + .zip_eq(state_next.clone()) + .map(|(state_element, state_element_next)| { + round_number_next.clone() * (state_element - state_element_next) + }) + .collect_vec() + .try_into() + .unwrap(); + + (state_next, hash_function_round_correctly_performs_update) + } + + fn cascade_log_derivative_update_circuit( + circuit_builder: &ConstraintCircuitBuilder, + look_in_column: ::MainColumn, + look_out_column: ::MainColumn, + cascade_log_derivative_column: ::AuxColumn, + ) -> ConstraintCircuitMonad { + let challenge = |c| circuit_builder.challenge(c); + let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); + let constant = |c: u32| circuit_builder.b_constant(c); + let next_main_row = |column_idx: ::MainColumn| { + circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + }; + let current_aux_row = |column_idx: ::AuxColumn| { + circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) + }; + let next_aux_row = |column_idx: ::AuxColumn| { + circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) + }; + + let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate); + let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight); + let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight); + + let ci_next = next_main_row(::MainColumn::CI); + let mode_next = next_main_row(::MainColumn::Mode); + let round_number_next = next_main_row(::MainColumn::RoundNumber); + let cascade_log_derivative = current_aux_row(cascade_log_derivative_column); + let cascade_log_derivative_next = next_aux_row(cascade_log_derivative_column); + + let compressed_row = look_in_weight * next_main_row(look_in_column) + + look_out_weight * next_main_row(look_out_column); + + let cascade_log_derivative_remains = + cascade_log_derivative_next.clone() - cascade_log_derivative.clone(); + let cascade_log_derivative_updates = (cascade_log_derivative_next - cascade_log_derivative) + * (cascade_indeterminate - compressed_row) + - constant(1); + + let next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init = + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) + * (round_number_next.clone() - constant(NUM_ROUNDS as u32)) + * (ci_next.clone() - opcode(Instruction::SpongeInit)); + let round_number_next_is_not_num_rounds = + Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS); + let current_instruction_next_is_not_sponge_init = + Self::instruction_deselector(circuit_builder, &ci_next, Instruction::SpongeInit); + + next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init + * cascade_log_derivative_updates + + round_number_next_is_not_num_rounds * cascade_log_derivative_remains.clone() + + current_instruction_next_is_not_sponge_init * cascade_log_derivative_remains + } +} + +impl AIR for HashTable { + type MainColumn = crate::table_column::HashBaseTableColumn; + type AuxColumn = crate::table_column::HashExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let constant = |c: u64| circuit_builder.b_constant(c); + + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let aux_row = |column: Self::AuxColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + + let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); + let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial()); + + let mode = main_row(Self::MainColumn::Mode); + let running_evaluation_hash_input = aux_row(Self::AuxColumn::HashInputRunningEvaluation); + let running_evaluation_hash_digest = aux_row(Self::AuxColumn::HashDigestRunningEvaluation); + let running_evaluation_sponge = aux_row(Self::AuxColumn::SpongeRunningEvaluation); + let running_evaluation_receive_chunk = + aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); + + let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate); + let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight); + let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight); + let prepare_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate); + let receive_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate); + + // First chunk of the program is received correctly. Relates to program attestation. + let [state_0, state_1, state_2, state_3] = + Self::re_compose_states_0_through_3_before_lookup( + circuit_builder, + Self::indicate_column_index_in_base_row, + ); + let state_rate_part: [_; tip5::RATE] = [ + state_0, + state_1, + state_2, + state_3, + main_row(Self::MainColumn::State4), + main_row(Self::MainColumn::State5), + main_row(Self::MainColumn::State6), + main_row(Self::MainColumn::State7), + main_row(Self::MainColumn::State8), + main_row(Self::MainColumn::State9), + ]; + let compressed_chunk = state_rate_part + .into_iter() + .fold(running_evaluation_initial.clone(), |acc, state_element| { + acc * prepare_chunk_indeterminate.clone() + state_element + }); + let running_evaluation_receive_chunk_is_initialized_correctly = + running_evaluation_receive_chunk + - receive_chunk_indeterminate * running_evaluation_initial.clone() + - compressed_chunk; + + // The lookup arguments with the Cascade Table for the S-Boxes are initialized correctly. + let cascade_log_derivative_init_circuit = + |look_in_column, look_out_column, cascade_log_derivative_column| { + let look_in = main_row(look_in_column); + let look_out = main_row(look_out_column); + let compressed_row = + look_in_weight.clone() * look_in + look_out_weight.clone() * look_out; + let cascade_log_derivative = aux_row(cascade_log_derivative_column); + (cascade_log_derivative - lookup_arg_default_initial.clone()) + * (cascade_indeterminate.clone() - compressed_row) + - constant(1) + }; + + // miscellaneous initial constraints + let mode_is_program_hashing = + Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing); + let round_number_is_0 = main_row(Self::MainColumn::RoundNumber); + let running_evaluation_hash_input_is_default_initial = + running_evaluation_hash_input - running_evaluation_initial.clone(); + let running_evaluation_hash_digest_is_default_initial = + running_evaluation_hash_digest - running_evaluation_initial.clone(); + let running_evaluation_sponge_is_default_initial = + running_evaluation_sponge - running_evaluation_initial; + + vec![ + mode_is_program_hashing, + round_number_is_0, + running_evaluation_hash_input_is_default_initial, + running_evaluation_hash_digest_is_default_initial, + running_evaluation_sponge_is_default_initial, + running_evaluation_receive_chunk_is_initialized_correctly, + cascade_log_derivative_init_circuit( + Self::MainColumn::State0HighestLkIn, + Self::MainColumn::State0HighestLkOut, + Self::AuxColumn::CascadeState0HighestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State0MidHighLkIn, + Self::MainColumn::State0MidHighLkOut, + Self::AuxColumn::CascadeState0MidHighClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State0MidLowLkIn, + Self::MainColumn::State0MidLowLkOut, + Self::AuxColumn::CascadeState0MidLowClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State0LowestLkIn, + Self::MainColumn::State0LowestLkOut, + Self::AuxColumn::CascadeState0LowestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State1HighestLkIn, + Self::MainColumn::State1HighestLkOut, + Self::AuxColumn::CascadeState1HighestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State1MidHighLkIn, + Self::MainColumn::State1MidHighLkOut, + Self::AuxColumn::CascadeState1MidHighClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State1MidLowLkIn, + Self::MainColumn::State2MidHighLkIn, + Self::AuxColumn::CascadeState1MidLowClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State1LowestLkIn, + Self::MainColumn::State1LowestLkOut, + Self::AuxColumn::CascadeState1LowestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State2LowestLkIn, + Self::MainColumn::State2HighestLkOut, + Self::AuxColumn::CascadeState2HighestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State2MidHighLkIn, + Self::MainColumn::State2MidHighLkOut, + Self::AuxColumn::CascadeState2MidHighClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State2MidLowLkIn, + Self::MainColumn::State2MidLowLkOut, + Self::AuxColumn::CascadeState2MidLowClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State2LowestLkIn, + Self::MainColumn::State2LowestLkOut, + Self::AuxColumn::CascadeState2LowestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State3HighestLkIn, + Self::MainColumn::State3HighestLkOut, + Self::AuxColumn::CascadeState3HighestClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State3MidHighLkIn, + Self::MainColumn::State3MidHighLkOut, + Self::AuxColumn::CascadeState3MidHighClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State3MidLowLkIn, + Self::MainColumn::State3MidLowLkOut, + Self::AuxColumn::CascadeState3MidLowClientLogDerivative, + ), + cascade_log_derivative_init_circuit( + Self::MainColumn::State3LowestLkIn, + Self::MainColumn::State3LowestLkOut, + Self::AuxColumn::CascadeState3LowestClientLogDerivative, + ), + ] + } + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); + let constant = |c: u64| circuit_builder.b_constant(c); + let main_row = |column_id: Self::MainColumn| { + circuit_builder.input(BaseRow(column_id.master_base_table_index())) + }; + + let mode = main_row(Self::MainColumn::Mode); + let ci = main_row(Self::MainColumn::CI); + let round_number = main_row(Self::MainColumn::RoundNumber); + + let ci_is_hash = ci.clone() - opcode(Instruction::Hash); + let ci_is_sponge_init = ci.clone() - opcode(Instruction::SpongeInit); + let ci_is_sponge_absorb = ci.clone() - opcode(Instruction::SpongeAbsorb); + let ci_is_sponge_squeeze = ci - opcode(Instruction::SpongeSqueeze); + + let mode_is_not_hash = Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash); + let round_number_is_not_0 = + Self::round_number_deselector(circuit_builder, &round_number, 0); + + let mode_is_a_valid_mode = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) + * Self::select_mode(circuit_builder, &mode, HashTableMode::Pad); + + let if_mode_is_not_sponge_then_ci_is_hash = + Self::select_mode(circuit_builder, &mode, HashTableMode::Sponge) * ci_is_hash.clone(); + + let if_mode_is_sponge_then_ci_is_a_sponge_instruction = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge) + * ci_is_sponge_init + * ci_is_sponge_absorb.clone() + * ci_is_sponge_squeeze.clone(); + + let if_padding_mode_then_round_number_is_0 = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) + * round_number.clone(); + + let if_ci_is_sponge_init_then_ = ci_is_hash * ci_is_sponge_absorb * ci_is_sponge_squeeze; + let if_ci_is_sponge_init_then_round_number_is_0 = + if_ci_is_sponge_init_then_.clone() * round_number.clone(); + + let if_ci_is_sponge_init_then_rate_is_0 = (10..=15).map(|state_index| { + let state_element = main_row(Self::state_column_by_index(state_index)); + if_ci_is_sponge_init_then_.clone() * state_element + }); + + let if_mode_is_hash_and_round_no_is_0_then_ = round_number_is_not_0 * mode_is_not_hash; + let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 = + (10..=15).map(|state_index| { + let state_element = main_row(Self::state_column_by_index(state_index)); + if_mode_is_hash_and_round_no_is_0_then_.clone() * (state_element - constant(1)) + }); + + // consistency of the inverse of the highest 2 limbs minus 2^32 - 1 + let one = constant(1); + let two_pow_16 = constant(1 << 16); + let two_pow_32 = constant(1 << 32); + let state_0_hi_limbs_minus_2_pow_32 = two_pow_32.clone() + - one.clone() + - main_row(Self::MainColumn::State0HighestLkIn) * two_pow_16.clone() + - main_row(Self::MainColumn::State0MidHighLkIn); + let state_1_hi_limbs_minus_2_pow_32 = two_pow_32.clone() + - one.clone() + - main_row(Self::MainColumn::State1HighestLkIn) * two_pow_16.clone() + - main_row(Self::MainColumn::State1MidHighLkIn); + let state_2_hi_limbs_minus_2_pow_32 = two_pow_32.clone() + - one.clone() + - main_row(Self::MainColumn::State2HighestLkIn) * two_pow_16.clone() + - main_row(Self::MainColumn::State2MidHighLkIn); + let state_3_hi_limbs_minus_2_pow_32 = two_pow_32 + - one.clone() + - main_row(Self::MainColumn::State3HighestLkIn) * two_pow_16.clone() + - main_row(Self::MainColumn::State3MidHighLkIn); + + let state_0_hi_limbs_inv = main_row(Self::MainColumn::State0Inv); + let state_1_hi_limbs_inv = main_row(Self::MainColumn::State1Inv); + let state_2_hi_limbs_inv = main_row(Self::MainColumn::State2Inv); + let state_3_hi_limbs_inv = main_row(Self::MainColumn::State3Inv); + + let state_0_hi_limbs_are_not_all_1s = + state_0_hi_limbs_minus_2_pow_32.clone() * state_0_hi_limbs_inv.clone() - one.clone(); + let state_1_hi_limbs_are_not_all_1s = + state_1_hi_limbs_minus_2_pow_32.clone() * state_1_hi_limbs_inv.clone() - one.clone(); + let state_2_hi_limbs_are_not_all_1s = + state_2_hi_limbs_minus_2_pow_32.clone() * state_2_hi_limbs_inv.clone() - one.clone(); + let state_3_hi_limbs_are_not_all_1s = + state_3_hi_limbs_minus_2_pow_32.clone() * state_3_hi_limbs_inv.clone() - one; + + let state_0_hi_limbs_inv_is_inv_or_is_zero = + state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_inv; + let state_1_hi_limbs_inv_is_inv_or_is_zero = + state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_inv; + let state_2_hi_limbs_inv_is_inv_or_is_zero = + state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_inv; + let state_3_hi_limbs_inv_is_inv_or_is_zero = + state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_inv; + + let state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero = + state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_minus_2_pow_32; + let state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero = + state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_minus_2_pow_32; + let state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero = + state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_minus_2_pow_32; + let state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero = + state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_minus_2_pow_32; + + // consistent decomposition into limbs + let state_0_lo_limbs = main_row(Self::MainColumn::State0MidLowLkIn) * two_pow_16.clone() + + main_row(Self::MainColumn::State0LowestLkIn); + let state_1_lo_limbs = main_row(Self::MainColumn::State1MidLowLkIn) * two_pow_16.clone() + + main_row(Self::MainColumn::State1LowestLkIn); + let state_2_lo_limbs = main_row(Self::MainColumn::State2MidLowLkIn) * two_pow_16.clone() + + main_row(Self::MainColumn::State2LowestLkIn); + let state_3_lo_limbs = main_row(Self::MainColumn::State3MidLowLkIn) * two_pow_16 + + main_row(Self::MainColumn::State3LowestLkIn); + + let if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0 = + state_0_hi_limbs_are_not_all_1s * state_0_lo_limbs; + let if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0 = + state_1_hi_limbs_are_not_all_1s * state_1_lo_limbs; + let if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0 = + state_2_hi_limbs_are_not_all_1s * state_2_lo_limbs; + let if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0 = + state_3_hi_limbs_are_not_all_1s * state_3_lo_limbs; + + let mut constraints = vec![ + mode_is_a_valid_mode, + if_mode_is_not_sponge_then_ci_is_hash, + if_mode_is_sponge_then_ci_is_a_sponge_instruction, + if_padding_mode_then_round_number_is_0, + if_ci_is_sponge_init_then_round_number_is_0, + state_0_hi_limbs_inv_is_inv_or_is_zero, + state_1_hi_limbs_inv_is_inv_or_is_zero, + state_2_hi_limbs_inv_is_inv_or_is_zero, + state_3_hi_limbs_inv_is_inv_or_is_zero, + state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero, + state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero, + state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero, + state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero, + if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0, + if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0, + if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0, + if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0, + ]; + + constraints.extend(if_ci_is_sponge_init_then_rate_is_0); + constraints.extend(if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1); + + for round_constant_column_idx in 0..NUM_ROUND_CONSTANTS { + let round_constant_column = + Self::round_constant_column_by_index(round_constant_column_idx); + let round_constant_column_circuit = main_row(round_constant_column); + let mut round_constant_constraint_circuit = constant(0); + for round_idx in 0..NUM_ROUNDS { + let round_constants = Self::tip5_round_constants_by_round_number(round_idx); + let round_constant = round_constants[round_constant_column_idx]; + let round_constant = circuit_builder.b_constant(round_constant); + let round_deselector_circuit = + Self::round_number_deselector(circuit_builder, &round_number, round_idx); + round_constant_constraint_circuit = round_constant_constraint_circuit + + round_deselector_circuit + * (round_constant_column_circuit.clone() - round_constant); + } + constraints.push(round_constant_constraint_circuit); + } + + constraints + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); + let constant = |c: u64| circuit_builder.b_constant(c); + + let opcode_hash = opcode(Instruction::Hash); + let opcode_sponge_init = opcode(Instruction::SpongeInit); + let opcode_sponge_absorb = opcode(Instruction::SpongeAbsorb); + let opcode_sponge_squeeze = opcode(Instruction::SpongeSqueeze); + + let current_main_row = |column_idx: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) + }; + let next_base_row = |column_idx: Self::MainColumn| { + circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + }; + let current_ext_row = |column_idx: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) + }; + let next_ext_row = |column_idx: Self::AuxColumn| { + circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) + }; + + let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); + + let prepare_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate); + let receive_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate); + let compress_program_digest_indeterminate = + challenge(ChallengeId::CompressProgramDigestIndeterminate); + let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest); + let hash_input_eval_indeterminate = challenge(ChallengeId::HashInputIndeterminate); + let hash_digest_eval_indeterminate = challenge(ChallengeId::HashDigestIndeterminate); + let sponge_indeterminate = challenge(ChallengeId::SpongeIndeterminate); + + let mode = current_main_row(Self::MainColumn::Mode); + let ci = current_main_row(Self::MainColumn::CI); + let round_number = current_main_row(Self::MainColumn::RoundNumber); + let running_evaluation_receive_chunk = + current_ext_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); + let running_evaluation_hash_input = + current_ext_row(Self::AuxColumn::HashInputRunningEvaluation); + let running_evaluation_hash_digest = + current_ext_row(Self::AuxColumn::HashDigestRunningEvaluation); + let running_evaluation_sponge = current_ext_row(Self::AuxColumn::SpongeRunningEvaluation); + + let mode_next = next_base_row(Self::MainColumn::Mode); + let ci_next = next_base_row(Self::MainColumn::CI); + let round_number_next = next_base_row(Self::MainColumn::RoundNumber); + let running_evaluation_receive_chunk_next = + next_ext_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); + let running_evaluation_hash_input_next = + next_ext_row(Self::AuxColumn::HashInputRunningEvaluation); + let running_evaluation_hash_digest_next = + next_ext_row(Self::AuxColumn::HashDigestRunningEvaluation); + let running_evaluation_sponge_next = next_ext_row(Self::AuxColumn::SpongeRunningEvaluation); + + let [state_0, state_1, state_2, state_3] = + Self::re_compose_states_0_through_3_before_lookup( + circuit_builder, + Self::indicate_column_index_in_current_base_row, + ); + + let state_current = [ + state_0, + state_1, + state_2, + state_3, + current_main_row(Self::MainColumn::State4), + current_main_row(Self::MainColumn::State5), + current_main_row(Self::MainColumn::State6), + current_main_row(Self::MainColumn::State7), + current_main_row(Self::MainColumn::State8), + current_main_row(Self::MainColumn::State9), + current_main_row(Self::MainColumn::State10), + current_main_row(Self::MainColumn::State11), + current_main_row(Self::MainColumn::State12), + current_main_row(Self::MainColumn::State13), + current_main_row(Self::MainColumn::State14), + current_main_row(Self::MainColumn::State15), + ]; + + let (state_next, hash_function_round_correctly_performs_update) = + Self::tip5_constraints_as_circuits(circuit_builder); + + let state_weights = [ + ChallengeId::StackWeight0, + ChallengeId::StackWeight1, + ChallengeId::StackWeight2, + ChallengeId::StackWeight3, + ChallengeId::StackWeight4, + ChallengeId::StackWeight5, + ChallengeId::StackWeight6, + ChallengeId::StackWeight7, + ChallengeId::StackWeight8, + ChallengeId::StackWeight9, + ChallengeId::StackWeight10, + ChallengeId::StackWeight11, + ChallengeId::StackWeight12, + ChallengeId::StackWeight13, + ChallengeId::StackWeight14, + ChallengeId::StackWeight15, + ] + .map(challenge); + + let round_number_is_not_num_rounds = + Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS); + + let round_number_is_0_through_4_or_round_number_next_is_0 = + round_number_is_not_num_rounds * round_number_next.clone(); + + let next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one = + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) + * (ci.clone() - opcode_sponge_init.clone()) + * (round_number.clone() - constant(NUM_ROUNDS as u64)) + * (round_number_next.clone() - round_number.clone() - constant(1)); + + // compress the digest by computing the terminal of an evaluation argument + let compressed_digest = state_current[..Digest::LEN].iter().fold( + running_evaluation_initial.clone(), + |acc, digest_element| { + acc * compress_program_digest_indeterminate.clone() + digest_element.clone() + }, + ); + let if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) + * (compressed_digest - expected_program_digest); + + let if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) + * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Sponge) + * (ci_next.clone() - opcode_sponge_init.clone()); + + let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change = + (round_number.clone() - constant(NUM_ROUNDS as u64)) + * (ci.clone() - opcode_sponge_init.clone()) + * (ci_next.clone() - ci.clone()); + + let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change = + (round_number - constant(NUM_ROUNDS as u64)) + * (ci.clone() - opcode_sponge_init.clone()) + * (mode_next.clone() - mode.clone()); + + let if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Sponge) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); + + let if_mode_is_hash_then_mode_next_is_hash_or_pad = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); + + let if_mode_is_pad_then_mode_next_is_pad = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); + + let difference_of_capacity_registers = state_current[tip5::RATE..] + .iter() + .zip_eq(state_next[tip5::RATE..].iter()) + .map(|(current, next)| next.clone() - current.clone()) + .collect_vec(); + let randomized_sum_of_capacity_differences = state_weights[tip5::RATE..] + .iter() + .zip_eq(difference_of_capacity_registers) + .map(|(weight, state_difference)| weight.clone() * state_difference) + .sum::>(); + + let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing = + Self::round_number_deselector(circuit_builder, &round_number_next, 0) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) + * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) + * (ci_next.clone() - opcode_sponge_init.clone()) + * randomized_sum_of_capacity_differences.clone(); + + let difference_of_state_registers = state_current + .iter() + .zip_eq(state_next.iter()) + .map(|(current, next)| next.clone() - current.clone()) + .collect_vec(); + let randomized_sum_of_state_differences = state_weights + .iter() + .zip_eq(difference_of_state_registers.iter()) + .map(|(weight, state_difference)| weight.clone() * state_difference.clone()) + .sum(); + let if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change = + Self::round_number_deselector(circuit_builder, &round_number_next, 0) + * Self::instruction_deselector( + circuit_builder, + &ci_next, + Instruction::SpongeSqueeze, + ) + * randomized_sum_of_state_differences; + + // Evaluation Arguments + + // If (and only if) the row number in the next row is 0 and the mode in the next row is + // `hash`, update running evaluation “hash input.” + let running_evaluation_hash_input_remains = + running_evaluation_hash_input_next.clone() - running_evaluation_hash_input.clone(); + let tip5_input = state_next[..tip5::RATE].to_owned(); + let compressed_row_from_processor = tip5_input + .into_iter() + .zip_eq(state_weights[..tip5::RATE].iter()) + .map(|(state, weight)| weight.clone() * state) + .sum(); + + let running_evaluation_hash_input_updates = running_evaluation_hash_input_next + - hash_input_eval_indeterminate * running_evaluation_hash_input + - compressed_row_from_processor; + let running_evaluation_hash_input_is_updated_correctly = + Self::round_number_deselector(circuit_builder, &round_number_next, 0) + * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash) + * running_evaluation_hash_input_updates + + round_number_next.clone() * running_evaluation_hash_input_remains.clone() + + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) + * running_evaluation_hash_input_remains; + + // If (and only if) the row number in the next row is NUM_ROUNDS and the current instruction + // in the next row corresponds to `hash`, update running evaluation “hash digest.” + let round_number_next_is_num_rounds = + round_number_next.clone() - constant(NUM_ROUNDS as u64); + let running_evaluation_hash_digest_remains = + running_evaluation_hash_digest_next.clone() - running_evaluation_hash_digest.clone(); + let hash_digest = state_next[..Digest::LEN].to_owned(); + let compressed_row_hash_digest = hash_digest + .into_iter() + .zip_eq(state_weights[..Digest::LEN].iter()) + .map(|(state, weight)| weight.clone() * state) + .sum(); + let running_evaluation_hash_digest_updates = running_evaluation_hash_digest_next + - hash_digest_eval_indeterminate * running_evaluation_hash_digest + - compressed_row_hash_digest; + let running_evaluation_hash_digest_is_updated_correctly = + Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS) + * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash) + * running_evaluation_hash_digest_updates + + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone() + + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) + * running_evaluation_hash_digest_remains; + + // The running evaluation for “Sponge” updates correctly. + let compressed_row_next = state_weights[..tip5::RATE] + .iter() + .zip_eq(state_next[..tip5::RATE].iter()) + .map(|(weight, st_next)| weight.clone() * st_next.clone()) + .sum(); + let running_evaluation_sponge_has_accumulated_ci = running_evaluation_sponge_next.clone() + - sponge_indeterminate * running_evaluation_sponge.clone() + - challenge(ChallengeId::HashCIWeight) * ci_next.clone(); + let running_evaluation_sponge_has_accumulated_next_row = + running_evaluation_sponge_has_accumulated_ci.clone() - compressed_row_next; + let if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates = + Self::round_number_deselector(circuit_builder, &round_number_next, 0) + * (ci_next.clone() - opcode_hash) + * running_evaluation_sponge_has_accumulated_next_row; + + let running_evaluation_sponge_remains = + running_evaluation_sponge_next - running_evaluation_sponge; + let if_round_no_next_is_not_0_then_running_evaluation_sponge_remains = + round_number_next.clone() * running_evaluation_sponge_remains.clone(); + let if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains = (ci_next.clone() + - opcode_sponge_init) + * (ci_next.clone() - opcode_sponge_absorb) + * (ci_next - opcode_sponge_squeeze) + * running_evaluation_sponge_remains; + let running_evaluation_sponge_is_updated_correctly = + if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates + + if_round_no_next_is_not_0_then_running_evaluation_sponge_remains + + if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains; + + // program attestation: absorb RATE instructions if in the right mode on the right row + let compressed_chunk = state_next[..tip5::RATE] + .iter() + .fold(running_evaluation_initial, |acc, rate_element| { + acc * prepare_chunk_indeterminate.clone() + rate_element.clone() + }); + let receive_chunk_running_evaluation_absorbs_chunk_of_instructions = + running_evaluation_receive_chunk_next.clone() + - receive_chunk_indeterminate * running_evaluation_receive_chunk.clone() + - compressed_chunk; + let receive_chunk_running_evaluation_remains = + running_evaluation_receive_chunk_next - running_evaluation_receive_chunk; + let receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0 = + Self::round_number_deselector(circuit_builder, &round_number_next, 0) + * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::ProgramHashing) + * receive_chunk_running_evaluation_absorbs_chunk_of_instructions + + round_number_next * receive_chunk_running_evaluation_remains.clone() + + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) + * receive_chunk_running_evaluation_remains; + + let constraints = vec![ + round_number_is_0_through_4_or_round_number_next_is_0, + next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one, + receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0, + if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest, + if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init, + if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change, + if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change, + if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad, + if_mode_is_hash_then_mode_next_is_hash_or_pad, + if_mode_is_pad_then_mode_next_is_pad, + capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing, + if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change, + running_evaluation_hash_input_is_updated_correctly, + running_evaluation_hash_digest_is_updated_correctly, + running_evaluation_sponge_is_updated_correctly, + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State0HighestLkIn, + Self::MainColumn::State0HighestLkOut, + Self::AuxColumn::CascadeState0HighestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State0MidHighLkIn, + Self::MainColumn::State0MidHighLkOut, + Self::AuxColumn::CascadeState0MidHighClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State0MidLowLkIn, + Self::MainColumn::State0MidLowLkOut, + Self::AuxColumn::CascadeState0MidLowClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State0LowestLkIn, + Self::MainColumn::State0LowestLkOut, + Self::AuxColumn::CascadeState0LowestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State1HighestLkIn, + Self::MainColumn::State1HighestLkOut, + Self::AuxColumn::CascadeState1HighestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State1MidHighLkIn, + Self::MainColumn::State1MidHighLkOut, + Self::AuxColumn::CascadeState1MidHighClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State1MidLowLkIn, + Self::MainColumn::State1MidLowLkOut, + Self::AuxColumn::CascadeState1MidLowClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State1LowestLkIn, + Self::MainColumn::State1LowestLkOut, + Self::AuxColumn::CascadeState1LowestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State2HighestLkIn, + Self::MainColumn::State2HighestLkOut, + Self::AuxColumn::CascadeState2HighestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State2MidHighLkIn, + Self::MainColumn::State2MidHighLkOut, + Self::AuxColumn::CascadeState2MidHighClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State2MidLowLkIn, + Self::MainColumn::State2MidLowLkOut, + Self::AuxColumn::CascadeState2MidLowClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State2LowestLkIn, + Self::MainColumn::State2LowestLkOut, + Self::AuxColumn::CascadeState2LowestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State3HighestLkIn, + Self::MainColumn::State3HighestLkOut, + Self::AuxColumn::CascadeState3HighestClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State3MidHighLkIn, + Self::MainColumn::State3MidHighLkOut, + Self::AuxColumn::CascadeState3MidHighClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State3MidLowLkIn, + Self::MainColumn::State3MidLowLkOut, + Self::AuxColumn::CascadeState3MidLowClientLogDerivative, + ), + Self::cascade_log_derivative_update_circuit( + circuit_builder, + Self::MainColumn::State3LowestLkIn, + Self::MainColumn::State3LowestLkOut, + Self::AuxColumn::CascadeState3LowestClientLogDerivative, + ), + ]; + + [ + constraints, + hash_function_round_correctly_performs_update.to_vec(), + ] + .concat() + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); + let constant = |c: u64| circuit_builder.b_constant(c); + let main_row = |column_idx: Self::MainColumn| { + circuit_builder.input(BaseRow(column_idx.master_base_table_index())) + }; + + let mode = main_row(Self::MainColumn::Mode); + let round_number = main_row(Self::MainColumn::RoundNumber); + + let compress_program_digest_indeterminate = + challenge(ChallengeId::CompressProgramDigestIndeterminate); + let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest); + + let max_round_number = constant(NUM_ROUNDS as u64); + + let [state_0, state_1, state_2, state_3] = + Self::re_compose_states_0_through_3_before_lookup( + circuit_builder, + Self::indicate_column_index_in_base_row, + ); + let state_4 = main_row(Self::MainColumn::State4); + let program_digest = [state_0, state_1, state_2, state_3, state_4]; + let compressed_digest = program_digest.into_iter().fold( + circuit_builder.x_constant(EvalArg::default_initial()), + |acc, digest_element| { + acc * compress_program_digest_indeterminate.clone() + digest_element + }, + ); + let if_mode_is_program_hashing_then_current_digest_is_expected_program_digest = + Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) + * (compressed_digest - expected_program_digest); + + let if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number = + Self::select_mode(circuit_builder, &mode, HashTableMode::Pad) + * (main_row(Self::MainColumn::CI) - opcode(Instruction::SpongeInit)) + * (round_number - max_round_number); + + vec![ + if_mode_is_program_hashing_then_current_digest_is_expected_program_digest, + if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number, + ] + } +} + +/// The current “mode” of the Hash Table. The Hash Table can be in one of four distinct modes: +/// +/// 1. Hashing the [`Program`][program]. This is part of program attestation. +/// 1. Processing all Sponge instructions, _i.e._, `sponge_init`, +/// `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`. +/// 1. Processing the `hash` instruction. +/// 1. Padding mode. +/// +/// Changing the mode is only possible when the current [`RoundNumber`] is [`NUM_ROUNDS`]. +/// The mode evolves as +/// [`ProgramHashing`][prog_hash] → [`Sponge`][sponge] → [`Hash`][hash] → [`Pad`][pad]. +/// Once mode [`Pad`][pad] is reached, it is not possible to change the mode anymore. +/// Skipping any or all of the modes [`Sponge`][sponge], [`Hash`][hash], or [`Pad`][pad] +/// is possible in principle: +/// - if no Sponge instructions are executed, mode [`Sponge`][sponge] will be skipped, +/// - if no `hash` instruction is executed, mode [`Hash`][hash] will be skipped, and +/// - if the Hash Table does not require any padding, mode [`Pad`][pad] will be skipped. +/// +/// It is not possible to skip mode [`ProgramHashing`][prog_hash]: +/// the [`Program`][program] is always hashed. +/// The empty program is not valid since any valid [`Program`][program] must execute +/// instruction `halt`. +/// +/// [program]: isa::program::Program +/// [prog_hash]: HashTableMode::ProgramHashing +/// [sponge]: HashTableMode::Sponge +/// [hash]: type@HashTableMode::Hash +/// [pad]: HashTableMode::Pad +#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] +pub enum HashTableMode { + /// The mode in which the [`Program`][program] is hashed. This is part of program attestation. + /// + /// [program]: isa::program::Program + ProgramHashing, + + /// The mode in which Sponge instructions, _i.e._, `sponge_init`, + /// `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`, are processed. + Sponge, + + /// The mode in which the `hash` instruction is processed. + Hash, + + /// Indicator for padding rows. + Pad, +} + +impl From for u32 { + fn from(mode: HashTableMode) -> Self { + match mode { + HashTableMode::ProgramHashing => 1, + HashTableMode::Sponge => 2, + HashTableMode::Hash => 3, + HashTableMode::Pad => 0, + } + } +} + +impl From for u64 { + fn from(mode: HashTableMode) -> Self { + let discriminant: u32 = mode.into(); + discriminant.into() + } +} + +impl From for BFieldElement { + fn from(mode: HashTableMode) -> Self { + let discriminant: u32 = mode.into(); + discriminant.into() + } +} diff --git a/triton-air/src/table/jump_stack.rs b/triton-air/src/table/jump_stack.rs new file mode 100644 index 000000000..cc2b3e23e --- /dev/null +++ b/triton-air/src/table/jump_stack.rs @@ -0,0 +1,164 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use isa::instruction::Instruction; +use twenty_first::prelude::BFieldElement; + +use crate::challenge_id::ChallengeId::ClockJumpDifferenceLookupIndeterminate; +use crate::challenge_id::ChallengeId::JumpStackCiWeight; +use crate::challenge_id::ChallengeId::JumpStackClkWeight; +use crate::challenge_id::ChallengeId::JumpStackIndeterminate; +use crate::challenge_id::ChallengeId::JumpStackJsdWeight; +use crate::challenge_id::ChallengeId::JumpStackJsoWeight; +use crate::challenge_id::ChallengeId::JumpStackJspWeight; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::JumpStackBaseTableColumn::CI; +use crate::table_column::JumpStackBaseTableColumn::CLK; +use crate::table_column::JumpStackBaseTableColumn::JSD; +use crate::table_column::JumpStackBaseTableColumn::JSO; +use crate::table_column::JumpStackBaseTableColumn::JSP; +use crate::table_column::JumpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; +use crate::table_column::JumpStackExtTableColumn::RunningProductPermArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct JumpStackTable; + +impl AIR for JumpStackTable { + type MainColumn = crate::table_column::JumpStackBaseTableColumn; + type AuxColumn = crate::table_column::JumpStackExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let clk = circuit_builder.input(BaseRow(CLK.master_base_table_index())); + let jsp = circuit_builder.input(BaseRow(JSP.master_base_table_index())); + let jso = circuit_builder.input(BaseRow(JSO.master_base_table_index())); + let jsd = circuit_builder.input(BaseRow(JSD.master_base_table_index())); + let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); + let rppa = circuit_builder.input(ExtRow(RunningProductPermArg.master_ext_table_index())); + let clock_jump_diff_log_derivative = circuit_builder.input(ExtRow( + ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + )); + + let processor_perm_indeterminate = circuit_builder.challenge(JumpStackIndeterminate); + // note: `clk`, `jsp`, `jso`, and `jsd` are all constrained to be 0 and can thus be omitted. + let compressed_row = circuit_builder.challenge(JumpStackCiWeight) * ci; + let rppa_starts_correctly = rppa - (processor_perm_indeterminate - compressed_row); + + // A clock jump difference of 0 is not allowed. Hence, the initial is recorded. + let clock_jump_diff_log_derivative_starts_correctly = clock_jump_diff_log_derivative + - circuit_builder.x_constant(LookupArg::default_initial()); + + vec![ + clk, + jsp, + jso, + jsd, + rppa_starts_correctly, + clock_jump_diff_log_derivative_starts_correctly, + ] + } + + fn consistency_constraints( + _circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + // no further constraints + vec![] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let one = || circuit_builder.b_constant(1); + let call_opcode = + circuit_builder.b_constant(Instruction::Call(BFieldElement::default()).opcode_b()); + let return_opcode = circuit_builder.b_constant(Instruction::Return.opcode_b()); + let recurse_or_return_opcode = + circuit_builder.b_constant(Instruction::RecurseOrReturn.opcode_b()); + + let clk = circuit_builder.input(CurrentBaseRow(CLK.master_base_table_index())); + let ci = circuit_builder.input(CurrentBaseRow(CI.master_base_table_index())); + let jsp = circuit_builder.input(CurrentBaseRow(JSP.master_base_table_index())); + let jso = circuit_builder.input(CurrentBaseRow(JSO.master_base_table_index())); + let jsd = circuit_builder.input(CurrentBaseRow(JSD.master_base_table_index())); + let rppa = circuit_builder.input(CurrentExtRow( + RunningProductPermArg.master_ext_table_index(), + )); + let clock_jump_diff_log_derivative = circuit_builder.input(CurrentExtRow( + ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + )); + + let clk_next = circuit_builder.input(NextBaseRow(CLK.master_base_table_index())); + let ci_next = circuit_builder.input(NextBaseRow(CI.master_base_table_index())); + let jsp_next = circuit_builder.input(NextBaseRow(JSP.master_base_table_index())); + let jso_next = circuit_builder.input(NextBaseRow(JSO.master_base_table_index())); + let jsd_next = circuit_builder.input(NextBaseRow(JSD.master_base_table_index())); + let rppa_next = + circuit_builder.input(NextExtRow(RunningProductPermArg.master_ext_table_index())); + let clock_jump_diff_log_derivative_next = circuit_builder.input(NextExtRow( + ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + )); + + let jsp_inc_or_stays = + (jsp_next.clone() - jsp.clone() - one()) * (jsp_next.clone() - jsp.clone()); + + let jsp_inc_by_one_or_ci_can_return = (jsp_next.clone() - jsp.clone() - one()) + * (ci.clone() - return_opcode) + * (ci.clone() - recurse_or_return_opcode); + let jsp_inc_or_jso_stays_or_ci_can_ret = + jsp_inc_by_one_or_ci_can_return.clone() * (jso_next.clone() - jso); + + let jsp_inc_or_jsd_stays_or_ci_can_ret = + jsp_inc_by_one_or_ci_can_return.clone() * (jsd_next.clone() - jsd); + + let jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret = jsp_inc_by_one_or_ci_can_return + * (clk_next.clone() - clk.clone() - one()) + * (ci.clone() - call_opcode); + + let compressed_row = circuit_builder.challenge(JumpStackClkWeight) * clk_next.clone() + + circuit_builder.challenge(JumpStackCiWeight) * ci_next + + circuit_builder.challenge(JumpStackJspWeight) * jsp_next.clone() + + circuit_builder.challenge(JumpStackJsoWeight) * jso_next + + circuit_builder.challenge(JumpStackJsdWeight) * jsd_next; + let rppa_updates_correctly = + rppa_next - rppa * (circuit_builder.challenge(JumpStackIndeterminate) - compressed_row); + + let log_derivative_remains = + clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone(); + let clk_diff = clk_next - clk; + let log_derivative_accumulates = (clock_jump_diff_log_derivative_next + - clock_jump_diff_log_derivative) + * (circuit_builder.challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff) + - one(); + let log_derivative_updates_correctly = (jsp_next.clone() - jsp.clone() - one()) + * log_derivative_accumulates + + (jsp_next - jsp) * log_derivative_remains; + + vec![ + jsp_inc_or_stays, + jsp_inc_or_jso_stays_or_ci_can_ret, + jsp_inc_or_jsd_stays_or_ci_can_ret, + jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret, + rppa_updates_correctly, + log_derivative_updates_correctly, + ] + } + + fn terminal_constraints( + _circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + // no further constraints + vec![] + } +} diff --git a/triton-air/src/table/lookup.rs b/triton-air/src/table/lookup.rs new file mode 100644 index 000000000..a296a4ede --- /dev/null +++ b/triton-air/src/table/lookup.rs @@ -0,0 +1,189 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; + +use crate::challenge_id::ChallengeId; +use crate::challenge_id::ChallengeId::CascadeLookupIndeterminate; +use crate::challenge_id::ChallengeId::LookupTableInputWeight; +use crate::challenge_id::ChallengeId::LookupTableOutputWeight; +use crate::challenge_id::ChallengeId::LookupTablePublicIndeterminate; +use crate::challenge_id::ChallengeId::LookupTablePublicTerminal; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::EvalArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::LookupBaseTableColumn; +use crate::table_column::LookupBaseTableColumn::IsPadding; +use crate::table_column::LookupBaseTableColumn::LookIn; +use crate::table_column::LookupBaseTableColumn::LookOut; +use crate::table_column::LookupBaseTableColumn::LookupMultiplicity; +use crate::table_column::LookupExtTableColumn; +use crate::table_column::LookupExtTableColumn::CascadeTableServerLogDerivative; +use crate::table_column::LookupExtTableColumn::PublicEvaluationArgument; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct LookupTable; + +impl AIR for LookupTable { + type MainColumn = LookupBaseTableColumn; + type AuxColumn = LookupExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let main_row = |col_id: Self::MainColumn| { + circuit_builder.input(BaseRow(col_id.master_base_table_index())) + }; + let aux_row = |col_id: Self::AuxColumn| { + circuit_builder.input(ExtRow(col_id.master_ext_table_index())) + }; + let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); + + let lookup_input = main_row(LookIn); + let lookup_output = main_row(LookOut); + let lookup_multiplicity = main_row(LookupMultiplicity); + let cascade_table_server_log_derivative = aux_row(CascadeTableServerLogDerivative); + let public_evaluation_argument = aux_row(PublicEvaluationArgument); + + let lookup_input_is_0 = lookup_input; + + // Lookup Argument with Cascade Table + // note: `lookup_input` is known to be 0 and thus doesn't appear in the compressed row + let lookup_argument_default_initial = + circuit_builder.x_constant(LookupArg::default_initial()); + let cascade_table_indeterminate = challenge(CascadeLookupIndeterminate); + let compressed_row = lookup_output.clone() * challenge(LookupTableOutputWeight); + let cascade_table_log_derivative_is_initialized_correctly = + (cascade_table_server_log_derivative - lookup_argument_default_initial) + * (cascade_table_indeterminate - compressed_row) + - lookup_multiplicity; + + // public Evaluation Argument + let eval_argument_default_initial = circuit_builder.x_constant(EvalArg::default_initial()); + let public_indeterminate = challenge(LookupTablePublicIndeterminate); + let public_evaluation_argument_is_initialized_correctly = public_evaluation_argument + - eval_argument_default_initial * public_indeterminate + - lookup_output; + + vec![ + lookup_input_is_0, + cascade_table_log_derivative_is_initialized_correctly, + public_evaluation_argument_is_initialized_correctly, + ] + } + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let main_row = |col_id: Self::MainColumn| { + circuit_builder.input(BaseRow(col_id.master_base_table_index())) + }; + + let padding_is_0_or_1 = main_row(IsPadding) * (constant(1) - main_row(IsPadding)); + + vec![padding_is_0_or_1] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let one = || circuit_builder.b_constant(1); + + let current_main_row = |col_id: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(col_id.master_base_table_index())) + }; + let next_main_row = |col_id: Self::MainColumn| { + circuit_builder.input(NextBaseRow(col_id.master_base_table_index())) + }; + let current_aux_row = |col_id: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(col_id.master_ext_table_index())) + }; + let next_aux_row = |col_id: Self::AuxColumn| { + circuit_builder.input(NextExtRow(col_id.master_ext_table_index())) + }; + let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); + + let lookup_input = current_main_row(LookIn); + let is_padding = current_main_row(IsPadding); + let cascade_table_server_log_derivative = current_aux_row(CascadeTableServerLogDerivative); + let public_evaluation_argument = current_aux_row(PublicEvaluationArgument); + + let lookup_input_next = next_main_row(LookIn); + let lookup_output_next = next_main_row(LookOut); + let lookup_multiplicity_next = next_main_row(LookupMultiplicity); + let is_padding_next = next_main_row(IsPadding); + let cascade_table_server_log_derivative_next = + next_aux_row(CascadeTableServerLogDerivative); + let public_evaluation_argument_next = next_aux_row(PublicEvaluationArgument); + + // Padding section is contiguous: if the current row is a padding row, then the next row + // is also a padding row. + let if_current_row_is_padding_row_then_next_row_is_padding_row = + is_padding * (one() - is_padding_next.clone()); + + // Lookup Table's input increments by 1 if and only if the next row is not a padding row + let if_next_row_is_padding_row_then_lookup_input_next_is_0 = + is_padding_next.clone() * lookup_input_next.clone(); + let if_next_row_is_not_padding_row_then_lookup_input_next_increments_by_1 = + (one() - is_padding_next.clone()) * (lookup_input_next.clone() - lookup_input - one()); + let lookup_input_increments_if_and_only_if_next_row_is_not_padding_row = + if_next_row_is_padding_row_then_lookup_input_next_is_0 + + if_next_row_is_not_padding_row_then_lookup_input_next_increments_by_1; + + // Lookup Argument with Cascade Table + let cascade_table_indeterminate = challenge(CascadeLookupIndeterminate); + let compressed_row = lookup_input_next * challenge(LookupTableInputWeight) + + lookup_output_next.clone() * challenge(LookupTableOutputWeight); + let cascade_table_log_derivative_remains = cascade_table_server_log_derivative_next.clone() + - cascade_table_server_log_derivative.clone(); + let cascade_table_log_derivative_updates = (cascade_table_server_log_derivative_next + - cascade_table_server_log_derivative) + * (cascade_table_indeterminate - compressed_row) + - lookup_multiplicity_next; + let cascade_table_log_derivative_updates_if_and_only_if_next_row_is_not_padding_row = + (one() - is_padding_next.clone()) * cascade_table_log_derivative_updates + + is_padding_next.clone() * cascade_table_log_derivative_remains; + + // public Evaluation Argument + let public_indeterminate = challenge(LookupTablePublicIndeterminate); + let public_evaluation_argument_remains = + public_evaluation_argument_next.clone() - public_evaluation_argument.clone(); + let public_evaluation_argument_updates = public_evaluation_argument_next + - public_evaluation_argument * public_indeterminate + - lookup_output_next; + let public_evaluation_argument_updates_if_and_only_if_next_row_is_not_padding_row = + (one() - is_padding_next.clone()) * public_evaluation_argument_updates + + is_padding_next * public_evaluation_argument_remains; + + vec![ + if_current_row_is_padding_row_then_next_row_is_padding_row, + lookup_input_increments_if_and_only_if_next_row_is_not_padding_row, + cascade_table_log_derivative_updates_if_and_only_if_next_row_is_not_padding_row, + public_evaluation_argument_updates_if_and_only_if_next_row_is_not_padding_row, + ] + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); + let aux_row = |col_id: Self::AuxColumn| { + circuit_builder.input(ExtRow(col_id.master_ext_table_index())) + }; + + let narrow_table_terminal_matches_user_supplied_terminal = + aux_row(PublicEvaluationArgument) - challenge(LookupTablePublicTerminal); + + vec![narrow_table_terminal_matches_user_supplied_terminal] + } +} diff --git a/triton-air/src/table/op_stack.rs b/triton-air/src/table/op_stack.rs new file mode 100644 index 000000000..6a822cc08 --- /dev/null +++ b/triton-air/src/table/op_stack.rs @@ -0,0 +1,202 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use isa::op_stack::OpStackElement; +use strum::EnumCount; +use twenty_first::prelude::*; + +use crate::challenge_id::ChallengeId; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::LookupArg; +use crate::cross_table_argument::PermArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +/// The value indicating a padding row in the op stack table. Stored in the +/// `ib1_shrink_stack` column. +pub const PADDING_VALUE: BFieldElement = BFieldElement::new(2); + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct OpStackTable; + +impl AIR for OpStackTable { + type MainColumn = crate::table_column::OpStackBaseTableColumn; + type AuxColumn = crate::table_column::OpStackExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let constant = |c| circuit_builder.b_constant(c); + let x_constant = |c| circuit_builder.x_constant(c); + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let aux_row = |column: Self::AuxColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + + let initial_stack_length = u32::try_from(OpStackElement::COUNT).unwrap(); + let initial_stack_length = constant(initial_stack_length.into()); + let padding_indicator = constant(PADDING_VALUE); + + let stack_pointer_is_16 = + main_row(Self::MainColumn::StackPointer) - initial_stack_length.clone(); + + let compressed_row = challenge(ChallengeId::OpStackClkWeight) + * main_row(Self::MainColumn::CLK) + + challenge(ChallengeId::OpStackIb1Weight) * main_row(Self::MainColumn::IB1ShrinkStack) + + challenge(ChallengeId::OpStackPointerWeight) * initial_stack_length + + challenge(ChallengeId::OpStackFirstUnderflowElementWeight) + * main_row(Self::MainColumn::FirstUnderflowElement); + let rppa_initial = challenge(ChallengeId::OpStackIndeterminate) - compressed_row; + let rppa_has_accumulated_first_row = + aux_row(Self::AuxColumn::RunningProductPermArg) - rppa_initial; + + let rppa_is_default_initial = aux_row(Self::AuxColumn::RunningProductPermArg) + - x_constant(PermArg::default_initial()); + + let first_row_is_padding_row = + main_row(Self::MainColumn::IB1ShrinkStack) - padding_indicator; + let first_row_is_not_padding_row = main_row(Self::MainColumn::IB1ShrinkStack) + * (main_row(Self::MainColumn::IB1ShrinkStack) - constant(bfe!(1))); + + let rppa_starts_correctly = rppa_has_accumulated_first_row * first_row_is_padding_row + + rppa_is_default_initial * first_row_is_not_padding_row; + + let lookup_argument_initial = x_constant(LookupArg::default_initial()); + let clock_jump_diff_log_derivative_is_initialized_correctly = + aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative) + - lookup_argument_initial; + + vec![ + stack_pointer_is_16, + rppa_starts_correctly, + clock_jump_diff_log_derivative_is_initialized_correctly, + ] + } + + fn consistency_constraints( + _circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + // no further constraints + vec![] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c| circuit_builder.b_constant(c); + let challenge = |c| circuit_builder.challenge(c); + let current_main_row = |column: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) + }; + let current_aux_row = |column: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) + }; + let next_main_row = |column: Self::MainColumn| { + circuit_builder.input(NextBaseRow(column.master_base_table_index())) + }; + let next_aux_row = |column: Self::AuxColumn| { + circuit_builder.input(NextExtRow(column.master_ext_table_index())) + }; + + let one = constant(1_u32.into()); + let padding_indicator = constant(PADDING_VALUE); + + let clk = current_main_row(Self::MainColumn::CLK); + let ib1_shrink_stack = current_main_row(Self::MainColumn::IB1ShrinkStack); + let stack_pointer = current_main_row(Self::MainColumn::StackPointer); + let first_underflow_element = current_main_row(Self::MainColumn::FirstUnderflowElement); + let rppa = current_aux_row(Self::AuxColumn::RunningProductPermArg); + let clock_jump_diff_log_derivative = + current_aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + + let clk_next = next_main_row(Self::MainColumn::CLK); + let ib1_shrink_stack_next = next_main_row(Self::MainColumn::IB1ShrinkStack); + let stack_pointer_next = next_main_row(Self::MainColumn::StackPointer); + let first_underflow_element_next = next_main_row(Self::MainColumn::FirstUnderflowElement); + let rppa_next = next_aux_row(Self::AuxColumn::RunningProductPermArg); + let clock_jump_diff_log_derivative_next = + next_aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + + let stack_pointer_increases_by_1_or_does_not_change = + (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) + * (stack_pointer_next.clone() - stack_pointer.clone()); + + let stack_pointer_inc_by_1_or_underflow_element_doesnt_change_or_next_ci_grows_stack = + (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) + * (first_underflow_element_next.clone() - first_underflow_element.clone()) + * ib1_shrink_stack_next.clone(); + + let next_row_is_padding_row = ib1_shrink_stack_next.clone() - padding_indicator.clone(); + let if_current_row_is_padding_row_then_next_row_is_padding_row = ib1_shrink_stack.clone() + * (ib1_shrink_stack - one.clone()) + * next_row_is_padding_row.clone(); + + // The running product for the permutation argument `rppa` is updated correctly. + let compressed_row = circuit_builder.challenge(ChallengeId::OpStackClkWeight) + * clk_next.clone() + + circuit_builder.challenge(ChallengeId::OpStackIb1Weight) + * ib1_shrink_stack_next.clone() + + circuit_builder.challenge(ChallengeId::OpStackPointerWeight) + * stack_pointer_next.clone() + + circuit_builder.challenge(ChallengeId::OpStackFirstUnderflowElementWeight) + * first_underflow_element_next; + + let rppa_updates = rppa_next.clone() + - rppa.clone() * (challenge(ChallengeId::OpStackIndeterminate) - compressed_row); + + let next_row_is_not_padding_row = + ib1_shrink_stack_next.clone() * (ib1_shrink_stack_next.clone() - one.clone()); + let rppa_remains = rppa_next - rppa; + + let rppa_updates_correctly = rppa_updates * next_row_is_padding_row.clone() + + rppa_remains * next_row_is_not_padding_row.clone(); + + let clk_diff = clk_next - clk; + let log_derivative_accumulates = (clock_jump_diff_log_derivative_next.clone() + - clock_jump_diff_log_derivative.clone()) + * (challenge(ChallengeId::ClockJumpDifferenceLookupIndeterminate) - clk_diff) + - one.clone(); + let log_derivative_remains = + clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone(); + + let log_derivative_accumulates_or_stack_pointer_changes_or_next_row_is_padding_row = + log_derivative_accumulates + * (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) + * next_row_is_padding_row; + let log_derivative_remains_or_stack_pointer_doesnt_change = + log_derivative_remains.clone() * (stack_pointer_next.clone() - stack_pointer.clone()); + let log_derivatve_remains_or_next_row_is_not_padding_row = + log_derivative_remains * next_row_is_not_padding_row; + + let log_derivative_updates_correctly = + log_derivative_accumulates_or_stack_pointer_changes_or_next_row_is_padding_row + + log_derivative_remains_or_stack_pointer_doesnt_change + + log_derivatve_remains_or_next_row_is_not_padding_row; + + vec![ + stack_pointer_increases_by_1_or_does_not_change, + stack_pointer_inc_by_1_or_underflow_element_doesnt_change_or_next_ci_grows_stack, + if_current_row_is_padding_row_then_next_row_is_padding_row, + rppa_updates_correctly, + log_derivative_updates_correctly, + ] + } + + fn terminal_constraints( + _circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + // no further constraints + vec![] + } +} diff --git a/triton-air/src/table/processor.rs b/triton-air/src/table/processor.rs new file mode 100644 index 000000000..628ec3537 --- /dev/null +++ b/triton-air/src/table/processor.rs @@ -0,0 +1,3594 @@ +use std::cmp::max; +use std::ops::Mul; + +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use isa::instruction::Instruction; +use isa::instruction::InstructionBit; +use isa::instruction::ALL_INSTRUCTIONS; +use isa::op_stack::NumberOfWords; +use isa::op_stack::OpStackElement; +use isa::op_stack::NUM_OP_STACK_REGISTERS; +use itertools::izip; +use itertools::Itertools; +use strum::EnumCount; +use twenty_first::math::x_field_element::EXTENSION_DEGREE; +use twenty_first::prelude::*; + +use crate::challenge_id::ChallengeId; +use crate::challenge_id::ChallengeId::ClockJumpDifferenceLookupIndeterminate; +use crate::challenge_id::ChallengeId::CompressProgramDigestIndeterminate; +use crate::challenge_id::ChallengeId::CompressedProgramDigest; +use crate::challenge_id::ChallengeId::HashCIWeight; +use crate::challenge_id::ChallengeId::HashDigestIndeterminate; +use crate::challenge_id::ChallengeId::HashInputIndeterminate; +use crate::challenge_id::ChallengeId::InstructionLookupIndeterminate; +use crate::challenge_id::ChallengeId::JumpStackCiWeight; +use crate::challenge_id::ChallengeId::JumpStackClkWeight; +use crate::challenge_id::ChallengeId::JumpStackIndeterminate; +use crate::challenge_id::ChallengeId::JumpStackJsdWeight; +use crate::challenge_id::ChallengeId::JumpStackJsoWeight; +use crate::challenge_id::ChallengeId::JumpStackJspWeight; +use crate::challenge_id::ChallengeId::OpStackClkWeight; +use crate::challenge_id::ChallengeId::OpStackFirstUnderflowElementWeight; +use crate::challenge_id::ChallengeId::OpStackIb1Weight; +use crate::challenge_id::ChallengeId::OpStackIndeterminate; +use crate::challenge_id::ChallengeId::OpStackPointerWeight; +use crate::challenge_id::ChallengeId::ProgramAddressWeight; +use crate::challenge_id::ChallengeId::ProgramInstructionWeight; +use crate::challenge_id::ChallengeId::ProgramNextInstructionWeight; +use crate::challenge_id::ChallengeId::RamClkWeight; +use crate::challenge_id::ChallengeId::RamIndeterminate; +use crate::challenge_id::ChallengeId::RamInstructionTypeWeight; +use crate::challenge_id::ChallengeId::RamPointerWeight; +use crate::challenge_id::ChallengeId::RamValueWeight; +use crate::challenge_id::ChallengeId::SpongeIndeterminate; +use crate::challenge_id::ChallengeId::StackWeight0; +use crate::challenge_id::ChallengeId::StackWeight1; +use crate::challenge_id::ChallengeId::StackWeight10; +use crate::challenge_id::ChallengeId::StackWeight11; +use crate::challenge_id::ChallengeId::StackWeight12; +use crate::challenge_id::ChallengeId::StackWeight13; +use crate::challenge_id::ChallengeId::StackWeight14; +use crate::challenge_id::ChallengeId::StackWeight15; +use crate::challenge_id::ChallengeId::StackWeight2; +use crate::challenge_id::ChallengeId::StackWeight3; +use crate::challenge_id::ChallengeId::StackWeight4; +use crate::challenge_id::ChallengeId::StackWeight5; +use crate::challenge_id::ChallengeId::StackWeight6; +use crate::challenge_id::ChallengeId::StackWeight7; +use crate::challenge_id::ChallengeId::StackWeight8; +use crate::challenge_id::ChallengeId::StackWeight9; +use crate::challenge_id::ChallengeId::StandardInputIndeterminate; +use crate::challenge_id::ChallengeId::StandardOutputIndeterminate; +use crate::challenge_id::ChallengeId::U32CiWeight; +use crate::challenge_id::ChallengeId::U32Indeterminate; +use crate::challenge_id::ChallengeId::U32LhsWeight; +use crate::challenge_id::ChallengeId::U32ResultWeight; +use crate::challenge_id::ChallengeId::U32RhsWeight; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::EvalArg; +use crate::cross_table_argument::LookupArg; +use crate::cross_table_argument::PermArg; +use crate::table; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::table_column::ProcessorBaseTableColumn; +use crate::table_column::ProcessorBaseTableColumn::ClockJumpDifferenceLookupMultiplicity; +use crate::table_column::ProcessorBaseTableColumn::IsPadding; +use crate::table_column::ProcessorBaseTableColumn::OpStackPointer; +use crate::table_column::ProcessorBaseTableColumn::CI; +use crate::table_column::ProcessorBaseTableColumn::CLK; +use crate::table_column::ProcessorBaseTableColumn::HV0; +use crate::table_column::ProcessorBaseTableColumn::HV1; +use crate::table_column::ProcessorBaseTableColumn::HV2; +use crate::table_column::ProcessorBaseTableColumn::HV3; +use crate::table_column::ProcessorBaseTableColumn::HV4; +use crate::table_column::ProcessorBaseTableColumn::HV5; +use crate::table_column::ProcessorBaseTableColumn::IB0; +use crate::table_column::ProcessorBaseTableColumn::IB1; +use crate::table_column::ProcessorBaseTableColumn::IB2; +use crate::table_column::ProcessorBaseTableColumn::IB3; +use crate::table_column::ProcessorBaseTableColumn::IB4; +use crate::table_column::ProcessorBaseTableColumn::IB5; +use crate::table_column::ProcessorBaseTableColumn::IB6; +use crate::table_column::ProcessorBaseTableColumn::IP; +use crate::table_column::ProcessorBaseTableColumn::JSD; +use crate::table_column::ProcessorBaseTableColumn::JSO; +use crate::table_column::ProcessorBaseTableColumn::JSP; +use crate::table_column::ProcessorBaseTableColumn::NIA; +use crate::table_column::ProcessorBaseTableColumn::ST0; +use crate::table_column::ProcessorBaseTableColumn::ST1; +use crate::table_column::ProcessorBaseTableColumn::ST10; +use crate::table_column::ProcessorBaseTableColumn::ST11; +use crate::table_column::ProcessorBaseTableColumn::ST12; +use crate::table_column::ProcessorBaseTableColumn::ST13; +use crate::table_column::ProcessorBaseTableColumn::ST14; +use crate::table_column::ProcessorBaseTableColumn::ST15; +use crate::table_column::ProcessorBaseTableColumn::ST2; +use crate::table_column::ProcessorBaseTableColumn::ST3; +use crate::table_column::ProcessorBaseTableColumn::ST4; +use crate::table_column::ProcessorBaseTableColumn::ST5; +use crate::table_column::ProcessorBaseTableColumn::ST6; +use crate::table_column::ProcessorBaseTableColumn::ST7; +use crate::table_column::ProcessorBaseTableColumn::ST8; +use crate::table_column::ProcessorBaseTableColumn::ST9; +use crate::table_column::ProcessorExtTableColumn; +use crate::table_column::ProcessorExtTableColumn::ClockJumpDifferenceLookupServerLogDerivative; +use crate::table_column::ProcessorExtTableColumn::HashDigestEvalArg; +use crate::table_column::ProcessorExtTableColumn::HashInputEvalArg; +use crate::table_column::ProcessorExtTableColumn::InputTableEvalArg; +use crate::table_column::ProcessorExtTableColumn::InstructionLookupClientLogDerivative; +use crate::table_column::ProcessorExtTableColumn::JumpStackTablePermArg; +use crate::table_column::ProcessorExtTableColumn::OpStackTablePermArg; +use crate::table_column::ProcessorExtTableColumn::OutputTableEvalArg; +use crate::table_column::ProcessorExtTableColumn::RamTablePermArg; +use crate::table_column::ProcessorExtTableColumn::SpongeEvalArg; +use crate::table_column::ProcessorExtTableColumn::U32LookupClientLogDerivative; +use crate::AIR; + +/// The number of helper variable registers +pub const NUM_HELPER_VARIABLE_REGISTERS: usize = 6; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct ProcessorTable; + +impl ProcessorTable { + /// # Panics + /// + /// Panics if the index is out of bounds. + pub fn op_stack_column_by_index(index: usize) -> ProcessorBaseTableColumn { + assert!( + OpStackElement::COUNT < index, + "Op Stack column index must be in [0, 15], not {index}" + ); + + match index { + 0 => ST0, + 1 => ST1, + 2 => ST2, + 3 => ST3, + 4 => ST4, + 5 => ST5, + 6 => ST6, + 7 => ST7, + 8 => ST8, + 9 => ST9, + 10 => ST10, + 11 => ST11, + 12 => ST12, + 13 => ST13, + 14 => ST14, + 15 => ST15, + _ => unreachable!(), + } + } +} + +impl AIR for ProcessorTable { + type MainColumn = crate::table_column::ProcessorBaseTableColumn; + type AuxColumn = crate::table_column::ProcessorExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let x_constant = |x| circuit_builder.x_constant(x); + let challenge = |c| circuit_builder.challenge(c); + let base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(BaseRow(col.master_base_table_index())) + }; + let ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(ExtRow(col.master_ext_table_index())) + }; + + let clk_is_0 = base_row(CLK); + let ip_is_0 = base_row(IP); + let jsp_is_0 = base_row(JSP); + let jso_is_0 = base_row(JSO); + let jsd_is_0 = base_row(JSD); + let st0_is_0 = base_row(ST0); + let st1_is_0 = base_row(ST1); + let st2_is_0 = base_row(ST2); + let st3_is_0 = base_row(ST3); + let st4_is_0 = base_row(ST4); + let st5_is_0 = base_row(ST5); + let st6_is_0 = base_row(ST6); + let st7_is_0 = base_row(ST7); + let st8_is_0 = base_row(ST8); + let st9_is_0 = base_row(ST9); + let st10_is_0 = base_row(ST10); + let op_stack_pointer_is_16 = base_row(OpStackPointer) - constant(16); + + // Compress the program digest using an Evaluation Argument. + // Lowest index in the digest corresponds to lowest index on the stack. + let program_digest: [_; Digest::LEN] = [ + base_row(ST11), + base_row(ST12), + base_row(ST13), + base_row(ST14), + base_row(ST15), + ]; + let compressed_program_digest = program_digest.into_iter().fold( + circuit_builder.x_constant(EvalArg::default_initial()), + |acc, digest_element| { + acc * challenge(CompressProgramDigestIndeterminate) + digest_element + }, + ); + let compressed_program_digest_is_expected_program_digest = + compressed_program_digest - challenge(CompressedProgramDigest); + + // Permutation and Evaluation Arguments with all tables the Processor Table relates to + + // standard input + let running_evaluation_for_standard_input_is_initialized_correctly = + ext_row(InputTableEvalArg) - x_constant(EvalArg::default_initial()); + + // program table + let instruction_lookup_indeterminate = challenge(InstructionLookupIndeterminate); + let instruction_ci_weight = challenge(ProgramInstructionWeight); + let instruction_nia_weight = challenge(ProgramNextInstructionWeight); + let compressed_row_for_instruction_lookup = + instruction_ci_weight * base_row(CI) + instruction_nia_weight * base_row(NIA); + let instruction_lookup_log_derivative_is_initialized_correctly = + (ext_row(InstructionLookupClientLogDerivative) + - x_constant(LookupArg::default_initial())) + * (instruction_lookup_indeterminate - compressed_row_for_instruction_lookup) + - constant(1); + + // standard output + let running_evaluation_for_standard_output_is_initialized_correctly = + ext_row(OutputTableEvalArg) - x_constant(EvalArg::default_initial()); + + let running_product_for_op_stack_table_is_initialized_correctly = + ext_row(OpStackTablePermArg) - x_constant(PermArg::default_initial()); + + // ram table + let running_product_for_ram_table_is_initialized_correctly = + ext_row(RamTablePermArg) - x_constant(PermArg::default_initial()); + + // jump-stack table + let jump_stack_indeterminate = challenge(JumpStackIndeterminate); + let jump_stack_ci_weight = challenge(JumpStackCiWeight); + // note: `clk`, `jsp`, `jso`, and `jsd` are already constrained to be 0. + let compressed_row_for_jump_stack_table = jump_stack_ci_weight * base_row(CI); + let running_product_for_jump_stack_table_is_initialized_correctly = + ext_row(JumpStackTablePermArg) + - x_constant(PermArg::default_initial()) + * (jump_stack_indeterminate - compressed_row_for_jump_stack_table); + + // clock jump difference lookup argument + // The clock jump difference logarithmic derivative accumulator starts + // off having accumulated the contribution from the first row. + // Note that (challenge(ClockJumpDifferenceLookupIndeterminate) - base_row(CLK)) + // collapses to challenge(ClockJumpDifferenceLookupIndeterminate) + // because base_row(CLK) = 0 is already a constraint. + let clock_jump_diff_lookup_log_derivative_is_initialized_correctly = + ext_row(ClockJumpDifferenceLookupServerLogDerivative) + * challenge(ClockJumpDifferenceLookupIndeterminate) + - base_row(ClockJumpDifferenceLookupMultiplicity); + + // from processor to hash table + let hash_selector = base_row(CI) - constant(Instruction::Hash.opcode()); + let hash_deselector = instruction_deselector_single_row(circuit_builder, Instruction::Hash); + let hash_input_indeterminate = challenge(HashInputIndeterminate); + // the opStack is guaranteed to be initialized to 0 by virtue of other initial constraints + let compressed_row = constant(0); + let running_evaluation_hash_input_has_absorbed_first_row = ext_row(HashInputEvalArg) + - hash_input_indeterminate * x_constant(EvalArg::default_initial()) + - compressed_row; + let running_evaluation_hash_input_is_default_initial = + ext_row(HashInputEvalArg) - x_constant(EvalArg::default_initial()); + let running_evaluation_hash_input_is_initialized_correctly = hash_selector + * running_evaluation_hash_input_is_default_initial + + hash_deselector * running_evaluation_hash_input_has_absorbed_first_row; + + // from hash table to processor + let running_evaluation_hash_digest_is_initialized_correctly = + ext_row(HashDigestEvalArg) - x_constant(EvalArg::default_initial()); + + // Hash Table – Sponge + let running_evaluation_sponge_absorb_is_initialized_correctly = + ext_row(SpongeEvalArg) - x_constant(EvalArg::default_initial()); + + // u32 table + let running_sum_log_derivative_for_u32_table_is_initialized_correctly = + ext_row(U32LookupClientLogDerivative) - x_constant(LookupArg::default_initial()); + + vec![ + clk_is_0, + ip_is_0, + jsp_is_0, + jso_is_0, + jsd_is_0, + st0_is_0, + st1_is_0, + st2_is_0, + st3_is_0, + st4_is_0, + st5_is_0, + st6_is_0, + st7_is_0, + st8_is_0, + st9_is_0, + st10_is_0, + compressed_program_digest_is_expected_program_digest, + op_stack_pointer_is_16, + running_evaluation_for_standard_input_is_initialized_correctly, + instruction_lookup_log_derivative_is_initialized_correctly, + running_evaluation_for_standard_output_is_initialized_correctly, + running_product_for_op_stack_table_is_initialized_correctly, + running_product_for_ram_table_is_initialized_correctly, + running_product_for_jump_stack_table_is_initialized_correctly, + clock_jump_diff_lookup_log_derivative_is_initialized_correctly, + running_evaluation_hash_input_is_initialized_correctly, + running_evaluation_hash_digest_is_initialized_correctly, + running_evaluation_sponge_absorb_is_initialized_correctly, + running_sum_log_derivative_for_u32_table_is_initialized_correctly, + ] + } + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(BaseRow(col.master_base_table_index())) + }; + + // The composition of instruction bits ib0-ib7 corresponds the current instruction ci. + let ib_composition = base_row(IB0) + + constant(1 << 1) * base_row(IB1) + + constant(1 << 2) * base_row(IB2) + + constant(1 << 3) * base_row(IB3) + + constant(1 << 4) * base_row(IB4) + + constant(1 << 5) * base_row(IB5) + + constant(1 << 6) * base_row(IB6); + let ci_corresponds_to_ib0_thru_ib7 = base_row(CI) - ib_composition; + + let ib0_is_bit = base_row(IB0) * (base_row(IB0) - constant(1)); + let ib1_is_bit = base_row(IB1) * (base_row(IB1) - constant(1)); + let ib2_is_bit = base_row(IB2) * (base_row(IB2) - constant(1)); + let ib3_is_bit = base_row(IB3) * (base_row(IB3) - constant(1)); + let ib4_is_bit = base_row(IB4) * (base_row(IB4) - constant(1)); + let ib5_is_bit = base_row(IB5) * (base_row(IB5) - constant(1)); + let ib6_is_bit = base_row(IB6) * (base_row(IB6) - constant(1)); + let is_padding_is_bit = base_row(IsPadding) * (base_row(IsPadding) - constant(1)); + + // In padding rows, the clock jump difference lookup multiplicity is 0. The one row + // exempt from this rule is the row wth CLK == 1: since the memory-like tables don't have + // an “awareness” of padding rows, they keep looking up clock jump differences of + // magnitude 1. + let clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows = base_row(IsPadding) + * (base_row(CLK) - constant(1)) + * base_row(ClockJumpDifferenceLookupMultiplicity); + + vec![ + ib0_is_bit, + ib1_is_bit, + ib2_is_bit, + ib3_is_bit, + ib4_is_bit, + ib5_is_bit, + ib6_is_bit, + is_padding_is_bit, + ci_corresponds_to_ib0_thru_ib7, + clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows, + ] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // constraints common to all instructions + let clk_increases_by_1 = next_base_row(CLK) - curr_base_row(CLK) - constant(1); + let is_padding_is_0_or_does_not_change = + curr_base_row(IsPadding) * (next_base_row(IsPadding) - curr_base_row(IsPadding)); + + let instruction_independent_constraints = + vec![clk_increases_by_1, is_padding_is_0_or_does_not_change]; + + // instruction-specific constraints + let transition_constraints_for_instruction = + |instr| transition_constraints_for_instruction(circuit_builder, instr); + let all_instructions_and_their_transition_constraints = + ALL_INSTRUCTIONS.map(|instr| (instr, transition_constraints_for_instruction(instr))); + let deselected_transition_constraints = combine_instruction_constraints_with_deselectors( + circuit_builder, + all_instructions_and_their_transition_constraints, + ); + + // if next row is padding row: disable transition constraints, enable padding constraints + let doubly_deselected_transition_constraints = + combine_transition_constraints_with_padding_constraints( + circuit_builder, + deselected_transition_constraints, + ); + + let table_linking_constraints = vec![ + log_derivative_accumulates_clk_next(circuit_builder), + log_derivative_for_instruction_lookup_updates_correctly(circuit_builder), + running_product_for_jump_stack_table_updates_correctly(circuit_builder), + running_evaluation_hash_input_updates_correctly(circuit_builder), + running_evaluation_hash_digest_updates_correctly(circuit_builder), + running_evaluation_sponge_updates_correctly(circuit_builder), + log_derivative_with_u32_table_updates_correctly(circuit_builder), + ]; + + [ + instruction_independent_constraints, + doubly_deselected_transition_constraints, + table_linking_constraints, + ] + .concat() + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let main_row = + |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + let constant = |c| circuit_builder.b_constant(c); + + let last_ci_is_halt = + main_row(Self::MainColumn::CI) - constant(Instruction::Halt.opcode_b()); + + vec![last_ci_is_halt] + } +} + +/// Instruction-specific transition constraints are combined with deselectors in such a way +/// that arbitrary sets of mutually exclusive combinations are summed, i.e., +/// +/// ```py +/// [ deselector_pop * tc_pop_0 + deselector_push * tc_push_0 + ..., +/// deselector_pop * tc_pop_1 + deselector_push * tc_push_1 + ..., +/// ..., +/// deselector_pop * tc_pop_i + deselector_push * tc_push_i + ..., +/// deselector_pop * 0 + deselector_push * tc_push_{i+1} + ..., +/// ..., +/// ] +/// ``` +/// For instructions that have fewer transition constraints than the maximal number of +/// transition constraints among all instructions, the deselector is multiplied with a zero, +/// causing no additional terms in the final sets of combined transition constraint polynomials. +fn combine_instruction_constraints_with_deselectors( + circuit_builder: &ConstraintCircuitBuilder, + instr_tc_polys_tuples: [(Instruction, Vec>); + Instruction::COUNT], +) -> Vec> { + let (all_instructions, all_tc_polys_for_all_instructions): (Vec<_>, Vec<_>) = + instr_tc_polys_tuples.into_iter().unzip(); + + let all_instruction_deselectors = all_instructions + .into_iter() + .map(|instr| instruction_deselector_current_row(circuit_builder, instr)) + .collect_vec(); + + let max_number_of_constraints = all_tc_polys_for_all_instructions + .iter() + .map(|tc_polys_for_instr| tc_polys_for_instr.len()) + .max() + .unwrap(); + + let zero_poly = circuit_builder.b_constant(0); + let all_tc_polys_for_all_instructions_transposed = (0..max_number_of_constraints) + .map(|idx| { + all_tc_polys_for_all_instructions + .iter() + .map(|tc_polys_for_instr| tc_polys_for_instr.get(idx).unwrap_or(&zero_poly)) + .collect_vec() + }) + .collect_vec(); + + all_tc_polys_for_all_instructions_transposed + .into_iter() + .map(|row| { + all_instruction_deselectors + .clone() + .into_iter() + .zip(row) + .map(|(deselector, instruction_tc)| deselector * instruction_tc.to_owned()) + .sum() + }) + .collect_vec() +} + +fn combine_transition_constraints_with_padding_constraints( + circuit_builder: &ConstraintCircuitBuilder, + instruction_transition_constraints: Vec>, +) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let padding_row_transition_constraints = [ + vec![ + next_base_row(IP) - curr_base_row(IP), + next_base_row(CI) - curr_base_row(CI), + next_base_row(NIA) - curr_base_row(NIA), + ], + instruction_group_keep_jump_stack(circuit_builder), + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat(); + + let padding_row_deselector = constant(1) - next_base_row(IsPadding); + let padding_row_selector = next_base_row(IsPadding); + + let max_number_of_constraints = max( + instruction_transition_constraints.len(), + padding_row_transition_constraints.len(), + ); + + (0..max_number_of_constraints) + .map(|idx| { + let instruction_constraint = instruction_transition_constraints + .get(idx) + .unwrap_or(&constant(0)) + .to_owned(); + let padding_constraint = padding_row_transition_constraints + .get(idx) + .unwrap_or(&constant(0)) + .to_owned(); + + instruction_constraint * padding_row_deselector.clone() + + padding_constraint * padding_row_selector.clone() + }) + .collect_vec() +} + +fn instruction_group_decompose_arg( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + + let hv0_is_a_bit = curr_base_row(HV0) * (curr_base_row(HV0) - constant(1)); + let hv1_is_a_bit = curr_base_row(HV1) * (curr_base_row(HV1) - constant(1)); + let hv2_is_a_bit = curr_base_row(HV2) * (curr_base_row(HV2) - constant(1)); + let hv3_is_a_bit = curr_base_row(HV3) * (curr_base_row(HV3) - constant(1)); + + let helper_variables_are_binary_decomposition_of_nia = curr_base_row(NIA) + - constant(8) * curr_base_row(HV3) + - constant(4) * curr_base_row(HV2) + - constant(2) * curr_base_row(HV1) + - curr_base_row(HV0); + + vec![ + hv0_is_a_bit, + hv1_is_a_bit, + hv2_is_a_bit, + hv3_is_a_bit, + helper_variables_are_binary_decomposition_of_nia, + ] +} + +/// The permutation argument accumulator with the RAM table does +/// not change, because there is no RAM access. +fn instruction_group_no_ram( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + vec![next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg)] +} + +fn instruction_group_no_io( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + vec![ + running_evaluation_for_standard_input_remains_unchanged(circuit_builder), + running_evaluation_for_standard_output_remains_unchanged(circuit_builder), + ] +} + +/// Op Stack height does not change and except for the top n elements, +/// the values remain also. +fn instruction_group_op_stack_remains_except_top_n( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + assert!(n <= NUM_OP_STACK_REGISTERS); + + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let stack = (0..OpStackElement::COUNT) + .map(ProcessorTable::op_stack_column_by_index) + .collect_vec(); + let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); + let curr_stack = stack.iter().map(|&st| curr_row(st)).collect_vec(); + + let compress_stack_except_top_n = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { + assert_eq!(NUM_OP_STACK_REGISTERS, stack.len()); + let weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); + stack + .into_iter() + .enumerate() + .skip(n) + .map(|(i, st)| weight(i) * st) + .sum() + }; + + let all_but_n_top_elements_remain = + compress_stack_except_top_n(next_stack) - compress_stack_except_top_n(curr_stack); + + let mut constraints = instruction_group_keep_op_stack_height(circuit_builder); + constraints.push(all_but_n_top_elements_remain); + constraints +} + +/// Op stack does not change, _i.e._, all stack elements persist +fn instruction_group_keep_op_stack( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + instruction_group_op_stack_remains_except_top_n(circuit_builder, 0) +} + +/// Op stack *height* does not change, _i.e._, the accumulator for the +/// permutation argument with the op stack table remains the same as does +/// the op stack pointer. +fn instruction_group_keep_op_stack_height( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let op_stack_pointer_curr = + circuit_builder.input(CurrentBaseRow(OpStackPointer.master_base_table_index())); + let op_stack_pointer_next = + circuit_builder.input(NextBaseRow(OpStackPointer.master_base_table_index())); + let osp_remains_unchanged = op_stack_pointer_next - op_stack_pointer_curr; + + let op_stack_table_perm_arg_curr = + circuit_builder.input(CurrentExtRow(OpStackTablePermArg.master_ext_table_index())); + let op_stack_table_perm_arg_next = + circuit_builder.input(NextExtRow(OpStackTablePermArg.master_ext_table_index())); + let perm_arg_remains_unchanged = op_stack_table_perm_arg_next - op_stack_table_perm_arg_curr; + + vec![osp_remains_unchanged, perm_arg_remains_unchanged] +} + +fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + vec![ + next_base_row(ST2) - curr_base_row(ST1), + next_base_row(ST3) - curr_base_row(ST2), + next_base_row(ST4) - curr_base_row(ST3), + next_base_row(ST5) - curr_base_row(ST4), + next_base_row(ST6) - curr_base_row(ST5), + next_base_row(ST7) - curr_base_row(ST6), + next_base_row(ST8) - curr_base_row(ST7), + next_base_row(ST9) - curr_base_row(ST8), + next_base_row(ST10) - curr_base_row(ST9), + next_base_row(ST11) - curr_base_row(ST10), + next_base_row(ST12) - curr_base_row(ST11), + next_base_row(ST13) - curr_base_row(ST12), + next_base_row(ST14) - curr_base_row(ST13), + next_base_row(ST15) - curr_base_row(ST14), + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(1), + running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, 1), + ] +} + +fn instruction_group_grow_op_stack( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST1) - curr_base_row(ST0)]; + let inherited_constraints = + instruction_group_grow_op_stack_and_top_two_elements_unconstrained(circuit_builder); + + [specific_constraints, inherited_constraints].concat() +} + +fn instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + vec![ + next_base_row(ST3) - curr_base_row(ST4), + next_base_row(ST4) - curr_base_row(ST5), + next_base_row(ST5) - curr_base_row(ST6), + next_base_row(ST6) - curr_base_row(ST7), + next_base_row(ST7) - curr_base_row(ST8), + next_base_row(ST8) - curr_base_row(ST9), + next_base_row(ST9) - curr_base_row(ST10), + next_base_row(ST10) - curr_base_row(ST11), + next_base_row(ST11) - curr_base_row(ST12), + next_base_row(ST12) - curr_base_row(ST13), + next_base_row(ST13) - curr_base_row(ST14), + next_base_row(ST14) - curr_base_row(ST15), + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(1), + running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 1), + ] +} + +fn instruction_group_binop( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![ + next_base_row(ST1) - curr_base_row(ST2), + next_base_row(ST2) - curr_base_row(ST3), + ]; + let inherited_constraints = + instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained(circuit_builder); + + [specific_constraints, inherited_constraints].concat() +} + +fn instruction_group_shrink_op_stack( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST1)]; + let inherited_constraints = instruction_group_binop(circuit_builder); + + [specific_constraints, inherited_constraints].concat() +} + +fn instruction_group_keep_jump_stack( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let jsp_does_not_change = next_base_row(JSP) - curr_base_row(JSP); + let jso_does_not_change = next_base_row(JSO) - curr_base_row(JSO); + let jsd_does_not_change = next_base_row(JSD) - curr_base_row(JSD); + + vec![ + jsp_does_not_change, + jso_does_not_change, + jsd_does_not_change, + ] +} + +/// Increase the instruction pointer by 1. +fn instruction_group_step_1( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let instruction_pointer_increases_by_one = next_base_row(IP) - curr_base_row(IP) - constant(1); + [ + instruction_group_keep_jump_stack(circuit_builder), + vec![instruction_pointer_increases_by_one], + ] + .concat() +} + +/// Increase the instruction pointer by 2. +fn instruction_group_step_2( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let instruction_pointer_increases_by_two = next_base_row(IP) - curr_base_row(IP) - constant(2); + [ + instruction_group_keep_jump_stack(circuit_builder), + vec![instruction_pointer_increases_by_two], + ] + .concat() +} + +/// Internal helper function to de-duplicate functionality common between the similar (but +/// different on a type level) functions for construction deselectors. +fn instruction_deselector_common_functionality( + circuit_builder: &ConstraintCircuitBuilder, + instruction: Instruction, + instruction_bit_polynomials: [ConstraintCircuitMonad; InstructionBit::COUNT], +) -> ConstraintCircuitMonad { + let one = || circuit_builder.b_constant(1); + + let selector_bits: [_; InstructionBit::COUNT] = [ + instruction.ib(InstructionBit::IB0), + instruction.ib(InstructionBit::IB1), + instruction.ib(InstructionBit::IB2), + instruction.ib(InstructionBit::IB3), + instruction.ib(InstructionBit::IB4), + instruction.ib(InstructionBit::IB5), + instruction.ib(InstructionBit::IB6), + ]; + let deselector_polynomials = selector_bits.map(|b| one() - circuit_builder.b_constant(b)); + + instruction_bit_polynomials + .into_iter() + .zip_eq(deselector_polynomials) + .map(|(instruction_bit_poly, deselector_poly)| instruction_bit_poly - deselector_poly) + .fold(one(), ConstraintCircuitMonad::mul) +} + +/// A polynomial that has no solutions when `ci` is `instruction`. +/// The number of variables in the polynomial corresponds to two rows. +fn instruction_deselector_current_row( + circuit_builder: &ConstraintCircuitBuilder, + instruction: Instruction, +) -> ConstraintCircuitMonad { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + + let instruction_bit_polynomials = [ + curr_base_row(IB0), + curr_base_row(IB1), + curr_base_row(IB2), + curr_base_row(IB3), + curr_base_row(IB4), + curr_base_row(IB5), + curr_base_row(IB6), + ]; + + instruction_deselector_common_functionality( + circuit_builder, + instruction, + instruction_bit_polynomials, + ) +} + +/// A polynomial that has no solutions when `ci_next` is `instruction`. +/// The number of variables in the polynomial corresponds to two rows. +fn instruction_deselector_next_row( + circuit_builder: &ConstraintCircuitBuilder, + instruction: Instruction, +) -> ConstraintCircuitMonad { + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let instruction_bit_polynomials = [ + next_base_row(IB0), + next_base_row(IB1), + next_base_row(IB2), + next_base_row(IB3), + next_base_row(IB4), + next_base_row(IB5), + next_base_row(IB6), + ]; + + instruction_deselector_common_functionality( + circuit_builder, + instruction, + instruction_bit_polynomials, + ) +} + +/// A polynomial that has no solutions when `ci` is `instruction`. +/// The number of variables in the polynomial corresponds to a single row. +fn instruction_deselector_single_row( + circuit_builder: &ConstraintCircuitBuilder, + instruction: Instruction, +) -> ConstraintCircuitMonad { + let base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(BaseRow(col.master_base_table_index())) + }; + + let instruction_bit_polynomials = [ + base_row(IB0), + base_row(IB1), + base_row(IB2), + base_row(IB3), + base_row(IB4), + base_row(IB5), + base_row(IB6), + ]; + + instruction_deselector_common_functionality( + circuit_builder, + instruction, + instruction_bit_polynomials, + ) +} + +fn instruction_pop( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + stack_shrinks_by_any_of(circuit_builder, &NumberOfWords::legal_values()), + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_push( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) - curr_base_row(NIA)]; + [ + specific_constraints, + instruction_group_grow_op_stack(circuit_builder), + instruction_group_step_2(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_divine( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + stack_grows_by_any_of(circuit_builder, &NumberOfWords::legal_values()), + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_dup( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let indicator_poly = |idx| indicator_polynomial(circuit_builder, idx); + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let st_column = ProcessorTable::op_stack_column_by_index; + let duplicate_element = |i| indicator_poly(i) * (next_row(ST0) - curr_row(st_column(i))); + let duplicate_indicated_element = (0..OpStackElement::COUNT).map(duplicate_element).sum(); + + [ + vec![duplicate_indicated_element], + instruction_group_decompose_arg(circuit_builder), + instruction_group_step_2(circuit_builder), + instruction_group_grow_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_swap( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let stack = (0..OpStackElement::COUNT) + .map(ProcessorTable::op_stack_column_by_index) + .collect_vec(); + let stack_with_swapped_i = |i| { + let mut stack = stack.clone(); + stack.swap(0, i); + stack.into_iter() + }; + + let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); + let curr_stack_with_swapped_i = |i| stack_with_swapped_i(i).map(curr_row).collect_vec(); + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { + assert_eq!(OpStackElement::COUNT, stack.len()); + let weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() + }; + + let next_stack_is_current_stack_with_swapped_i = |i| { + indicator_polynomial(circuit_builder, i) + * (compress(next_stack.clone()) - compress(curr_stack_with_swapped_i(i))) + }; + let next_stack_is_current_stack_with_correct_element_swapped = (0..OpStackElement::COUNT) + .map(next_stack_is_current_stack_with_swapped_i) + .sum(); + + [ + vec![next_stack_is_current_stack_with_correct_element_swapped], + instruction_group_decompose_arg(circuit_builder), + instruction_group_step_2(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + instruction_group_keep_op_stack_height(circuit_builder), + ] + .concat() +} + +fn instruction_nop( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_skiz( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let hv0_is_inverse_of_st0 = curr_base_row(HV0) * curr_base_row(ST0) - one(); + let hv0_is_inverse_of_st0_or_hv0_is_0 = hv0_is_inverse_of_st0.clone() * curr_base_row(HV0); + let hv0_is_inverse_of_st0_or_st0_is_0 = hv0_is_inverse_of_st0 * curr_base_row(ST0); + + // The next instruction nia is decomposed into helper variables hv. + let nia_decomposes_to_hvs = curr_base_row(NIA) + - curr_base_row(HV1) + - constant(1 << 1) * curr_base_row(HV2) + - constant(1 << 3) * curr_base_row(HV3) + - constant(1 << 5) * curr_base_row(HV4) + - constant(1 << 7) * curr_base_row(HV5); + + // If `st0` is non-zero, register `ip` is incremented by 1. + // If `st0` is 0 and `nia` takes no argument, register `ip` is incremented by 2. + // If `st0` is 0 and `nia` takes an argument, register `ip` is incremented by 3. + // + // The opcodes are constructed such that hv1 == 1 means that nia takes an argument. + // + // Written as Disjunctive Normal Form, the constraint can be expressed as: + // (Register `st0` is 0 or `ip` is incremented by 1), and + // (`st0` has a multiplicative inverse or `hv1` is 1 or `ip` is incremented by 2), and + // (`st0` has a multiplicative inverse or `hv1` is 0 or `ip` is incremented by 3). + let ip_case_1 = (next_base_row(IP) - curr_base_row(IP) - constant(1)) * curr_base_row(ST0); + let ip_case_2 = (next_base_row(IP) - curr_base_row(IP) - constant(2)) + * (curr_base_row(ST0) * curr_base_row(HV0) - one()) + * (curr_base_row(HV1) - one()); + let ip_case_3 = (next_base_row(IP) - curr_base_row(IP) - constant(3)) + * (curr_base_row(ST0) * curr_base_row(HV0) - one()) + * curr_base_row(HV1); + let ip_incr_by_1_or_2_or_3 = ip_case_1 + ip_case_2 + ip_case_3; + + let specific_constraints = vec![ + hv0_is_inverse_of_st0_or_hv0_is_0, + hv0_is_inverse_of_st0_or_st0_is_0, + nia_decomposes_to_hvs, + ip_incr_by_1_or_2_or_3, + ]; + [ + specific_constraints, + next_instruction_range_check_constraints_for_instruction_skiz(circuit_builder), + instruction_group_keep_jump_stack(circuit_builder), + instruction_group_shrink_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn next_instruction_range_check_constraints_for_instruction_skiz( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + + let is_0_or_1 = + |var: ProcessorBaseTableColumn| curr_base_row(var) * (curr_base_row(var) - constant(1)); + let is_0_or_1_or_2_or_3 = |var: ProcessorBaseTableColumn| { + curr_base_row(var) + * (curr_base_row(var) - constant(1)) + * (curr_base_row(var) - constant(2)) + * (curr_base_row(var) - constant(3)) + }; + + vec![ + is_0_or_1(HV1), + is_0_or_1_or_2_or_3(HV2), + is_0_or_1_or_2_or_3(HV3), + is_0_or_1_or_2_or_3(HV4), + is_0_or_1_or_2_or_3(HV5), + ] +} + +fn instruction_call( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // The jump stack pointer jsp is incremented by 1. + let jsp_incr_1 = next_base_row(JSP) - curr_base_row(JSP) - constant(1); + + // The jump's origin jso is set to the current instruction pointer ip plus 2. + let jso_becomes_ip_plus_2 = next_base_row(JSO) - curr_base_row(IP) - constant(2); + + // The jump's destination jsd is set to the instruction's argument. + let jsd_becomes_nia = next_base_row(JSD) - curr_base_row(NIA); + + // The instruction pointer ip is set to the instruction's argument. + let ip_becomes_nia = next_base_row(IP) - curr_base_row(NIA); + + let specific_constraints = vec![ + jsp_incr_1, + jso_becomes_ip_plus_2, + jsd_becomes_nia, + ip_becomes_nia, + ]; + [ + specific_constraints, + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_return( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let jsp_decrements_by_1 = next_base_row(JSP) - curr_base_row(JSP) + constant(1); + let ip_is_set_to_jso = next_base_row(IP) - curr_base_row(JSO); + let specific_constraints = vec![jsp_decrements_by_1, ip_is_set_to_jso]; + + [ + specific_constraints, + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_recurse( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // The instruction pointer ip is set to the last jump's destination jsd. + let ip_becomes_jsd = next_base_row(IP) - curr_base_row(JSD); + let specific_constraints = vec![ip_becomes_jsd]; + [ + specific_constraints, + instruction_group_keep_jump_stack(circuit_builder), + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_recurse_or_return( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let one = || circuit_builder.b_constant(1); + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // Zero if the ST5 equals ST6. One if they are not equal. + let st5_eq_st6 = || curr_row(HV0) * (curr_row(ST6) - curr_row(ST5)); + let st5_neq_st6 = || one() - st5_eq_st6(); + + let maybe_return = vec![ + // hv0 is inverse-or-zero of the difference of ST6 and ST5. + st5_neq_st6() * curr_row(HV0), + st5_neq_st6() * (curr_row(ST6) - curr_row(ST5)), + st5_neq_st6() * (next_row(IP) - curr_row(JSO)), + st5_neq_st6() * (next_row(JSP) - curr_row(JSP) + one()), + ]; + let maybe_recurse = vec![ + // constraints are ordered to line up nicely with group “maybe_return” + st5_eq_st6() * (next_row(JSO) - curr_row(JSO)), + st5_eq_st6() * (next_row(JSD) - curr_row(JSD)), + st5_eq_st6() * (next_row(IP) - curr_row(JSD)), + st5_eq_st6() * (next_row(JSP) - curr_row(JSP)), + ]; + + // The two constraint groups are mutually exclusive: the stack element is either + // equal to its successor or not, indicated by `st5_eq_st6` and `st5_neq_st6`. + // Therefore, it is safe (and sound) to combine the groups into a single set of + // constraints. + let constraint_groups = vec![maybe_return, maybe_recurse]; + let specific_constraints = + combine_mutually_exclusive_constraint_groups(circuit_builder, constraint_groups); + + [ + specific_constraints, + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_assert( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + + // The current top of the stack st0 is 1. + let st_0_is_1 = curr_base_row(ST0) - constant(1); + + let specific_constraints = vec![st_0_is_1]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_shrink_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_halt( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // The instruction executed in the following step is instruction halt. + let halt_is_followed_by_halt = next_base_row(CI) - curr_base_row(CI); + + let specific_constraints = vec![halt_is_followed_by_halt]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_read_mem( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + read_from_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_write_mem( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + write_to_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +/// Two Evaluation Arguments with the Hash Table guarantee correct transition. +fn instruction_hash( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let op_stack_shrinks_by_5_and_top_5_unconstrained = vec![ + next_base_row(ST5) - curr_base_row(ST10), + next_base_row(ST6) - curr_base_row(ST11), + next_base_row(ST7) - curr_base_row(ST12), + next_base_row(ST8) - curr_base_row(ST13), + next_base_row(ST9) - curr_base_row(ST14), + next_base_row(ST10) - curr_base_row(ST15), + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(5), + running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 5), + ]; + + [ + instruction_group_step_1(circuit_builder), + op_stack_shrinks_by_5_and_top_5_unconstrained, + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_merkle_step( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_merkle_step_shared_constraints(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 6), + instruction_group_no_ram(circuit_builder), + ] + .concat() +} + +fn instruction_merkle_step_mem( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let stack_weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); + let curr = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let ram_pointers = [0, 1, 2, 3, 4].map(|i| curr(ST7) + constant(i)); + let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4].map(curr); + let read_from_ram_to_hvs = + read_from_ram_to(circuit_builder, ram_pointers, ram_read_destinations); + + let st6_does_not_change = next(ST6) - curr(ST6); + let st7_increments_by_5 = next(ST7) - curr(ST7) - constant(5); + let st6_and_st7_update_correctly = + stack_weight(6) * st6_does_not_change + stack_weight(7) * st7_increments_by_5; + + [ + vec![st6_and_st7_update_correctly, read_from_ram_to_hvs], + instruction_merkle_step_shared_constraints(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 8), + ] + .concat() +} + +/// Recall that in a Merkle tree, the indices of left (respectively right) +/// leaves have least-significant bit 0 (respectively 1). +/// +/// Two Evaluation Arguments with the Hash Table guarantee correct transition of +/// stack elements ST0 through ST4. +fn instruction_merkle_step_shared_constraints( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let curr = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let hv5_is_0_or_1 = curr(HV5) * (curr(HV5) - one()); + let new_st5_is_previous_st5_div_2 = constant(2) * next(ST5) + curr(HV5) - curr(ST5); + let update_merkle_tree_node_index = vec![hv5_is_0_or_1, new_st5_is_previous_st5_div_2]; + + [ + update_merkle_tree_node_index, + instruction_group_step_1(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_assert_vector( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![ + curr_base_row(ST5) - curr_base_row(ST0), + curr_base_row(ST6) - curr_base_row(ST1), + curr_base_row(ST7) - curr_base_row(ST2), + curr_base_row(ST8) - curr_base_row(ST3), + curr_base_row(ST9) - curr_base_row(ST4), + ]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + constraints_for_shrinking_stack_by(circuit_builder, 5), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_sponge_init( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_keep_op_stack(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_sponge_absorb( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + constraints_for_shrinking_stack_by(circuit_builder, 10), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_sponge_absorb_mem( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let constant = |c| circuit_builder.b_constant(c); + + let increment_ram_pointer = + next_base_row(ST0) - curr_base_row(ST0) - constant(tip5::RATE as u32); + + [ + vec![increment_ram_pointer], + instruction_group_step_1(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_sponge_squeeze( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + constraints_for_growing_stack_by(circuit_builder, 10), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_add( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST1)]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_addi( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(NIA)]; + [ + specific_constraints, + instruction_group_step_2(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_mul( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) * curr_base_row(ST1)]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_invert( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let specific_constraints = vec![next_base_row(ST0) * curr_base_row(ST0) - one()]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_eq( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let st0_eq_st1 = || one() - curr_base_row(HV0) * (curr_base_row(ST1) - curr_base_row(ST0)); + + // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two top-most + // elements: `hv0·(1 - hv0·(st1 - st0))` + let hv0_is_inverse_of_diff_or_hv0_is_0 = curr_base_row(HV0) * st0_eq_st1(); + + // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two + // top-most elements: `(st1 - st0)·(1 - hv0·(st1 - st0))` + let hv0_is_inverse_of_diff_or_diff_is_0 = + (curr_base_row(ST1) - curr_base_row(ST0)) * st0_eq_st1(); + + // The new top of the stack is 1 if the difference between the stack's two top-most + // elements is not invertible, 0 otherwise: `st0' - (1 - hv0·(st1 - st0))` + let st0_becomes_1_if_diff_is_not_invertible = next_base_row(ST0) - st0_eq_st1(); + + let specific_constraints = vec![ + hv0_is_inverse_of_diff_or_hv0_is_0, + hv0_is_inverse_of_diff_or_diff_is_0, + st0_becomes_1_if_diff_is_not_invertible, + ]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_split( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let one = || constant(1); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // The top of the stack is decomposed as 32-bit chunks into the stack's top-most elements: + // st0 - (2^32·st0' + st1') = 0$ + let st0_decomposes_to_two_32_bit_chunks = + curr_base_row(ST0) - (constant(1 << 32) * next_base_row(ST1) + next_base_row(ST0)); + + // Helper variable `hv0` = 0 if either + // 1. `hv0` is the difference between (2^32 - 1) and the high 32 bits (`st0'`), or + // 1. the low 32 bits (`st1'`) are 0. + // + // st1'·(hv0·(st0' - (2^32 - 1)) - 1) + // lo·(hv0·(hi - 0xffff_ffff)) - 1) + let hv0_holds_inverse_of_chunk_difference_or_low_bits_are_0 = { + let hv0 = curr_base_row(HV0); + let hi = next_base_row(ST1); + let lo = next_base_row(ST0); + let ffff_ffff = constant(0xffff_ffff); + + lo * (hv0 * (hi - ffff_ffff) - one()) + }; + + let specific_constraints = vec![ + st0_decomposes_to_two_32_bit_chunks, + hv0_holds_inverse_of_chunk_difference_or_low_bits_are_0, + ]; + [ + specific_constraints, + instruction_group_grow_op_stack_and_top_two_elements_unconstrained(circuit_builder), + instruction_group_step_1(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_lt( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_and( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_xor( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_log_2_floor( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_pow( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_binop(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_div_mod( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + // `n == d·q + r` means `st0 - st1·st1' - st0'` + let numerator_is_quotient_times_denominator_plus_remainder = + curr_base_row(ST0) - curr_base_row(ST1) * next_base_row(ST1) - next_base_row(ST0); + + let specific_constraints = vec![numerator_is_quotient_times_denominator_plus_remainder]; + [ + specific_constraints, + instruction_group_step_1(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 2), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_pop_count( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + [ + instruction_group_step_1(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_xx_add( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let st0_becomes_st0_plus_st3 = next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST3); + let st1_becomes_st1_plus_st4 = next_base_row(ST1) - curr_base_row(ST1) - curr_base_row(ST4); + let st2_becomes_st2_plus_st5 = next_base_row(ST2) - curr_base_row(ST2) - curr_base_row(ST5); + let specific_constraints = vec![ + st0_becomes_st0_plus_st3, + st1_becomes_st1_plus_st4, + st2_becomes_st2_plus_st5, + ]; + + [ + specific_constraints, + constraints_for_shrinking_stack_by_3_and_top_3_unconstrained(circuit_builder), + instruction_group_step_1(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_xx_mul( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(curr_base_row); + let [c0, c1, c2] = xx_product([x0, x1, x2], [y0, y1, y2]); + + let specific_constraints = vec![ + next_base_row(ST0) - c0, + next_base_row(ST1) - c1, + next_base_row(ST2) - c2, + ]; + [ + specific_constraints, + constraints_for_shrinking_stack_by_3_and_top_3_unconstrained(circuit_builder), + instruction_group_step_1(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_xinv( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let first_coefficient_of_product_of_element_and_inverse_is_1 = curr_base_row(ST0) + * next_base_row(ST0) + - curr_base_row(ST2) * next_base_row(ST1) + - curr_base_row(ST1) * next_base_row(ST2) + - constant(1); + + let second_coefficient_of_product_of_element_and_inverse_is_0 = + curr_base_row(ST1) * next_base_row(ST0) + curr_base_row(ST0) * next_base_row(ST1) + - curr_base_row(ST2) * next_base_row(ST2) + + curr_base_row(ST2) * next_base_row(ST1) + + curr_base_row(ST1) * next_base_row(ST2); + + let third_coefficient_of_product_of_element_and_inverse_is_0 = curr_base_row(ST2) + * next_base_row(ST0) + + curr_base_row(ST1) * next_base_row(ST1) + + curr_base_row(ST0) * next_base_row(ST2) + + curr_base_row(ST2) * next_base_row(ST2); + + let specific_constraints = vec![ + first_coefficient_of_product_of_element_and_inverse_is_1, + second_coefficient_of_product_of_element_and_inverse_is_0, + third_coefficient_of_product_of_element_and_inverse_is_0, + ]; + [ + specific_constraints, + instruction_group_op_stack_remains_except_top_n(circuit_builder, 3), + instruction_group_step_1(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_xb_mul( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let [x, y0, y1, y2] = [ST0, ST1, ST2, ST3].map(curr_base_row); + let [c0, c1, c2] = xb_product([y0, y1, y2], x); + + let specific_constraints = vec![ + next_base_row(ST0) - c0, + next_base_row(ST1) - c1, + next_base_row(ST2) - c2, + ]; + [ + specific_constraints, + instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained(circuit_builder), + instruction_group_step_1(circuit_builder), + instruction_group_no_ram(circuit_builder), + instruction_group_no_io(circuit_builder), + ] + .concat() +} + +fn instruction_read_io( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constraint_groups_for_legal_arguments = NumberOfWords::legal_values() + .map(|n| grow_stack_by_n_and_read_n_symbols_from_input(circuit_builder, n)) + .to_vec(); + let read_any_legal_number_of_words = combine_mutually_exclusive_constraint_groups( + circuit_builder, + constraint_groups_for_legal_arguments, + ); + + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + read_any_legal_number_of_words, + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_ram(circuit_builder), + vec![running_evaluation_for_standard_output_remains_unchanged( + circuit_builder, + )], + ] + .concat() +} + +fn instruction_write_io( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constraint_groups_for_legal_arguments = NumberOfWords::legal_values() + .map(|n| shrink_stack_by_n_and_write_n_symbols_to_output(circuit_builder, n)) + .to_vec(); + let write_any_of_1_through_5_elements = combine_mutually_exclusive_constraint_groups( + circuit_builder, + constraint_groups_for_legal_arguments, + ); + + [ + instruction_group_step_2(circuit_builder), + instruction_group_decompose_arg(circuit_builder), + write_any_of_1_through_5_elements, + prohibit_any_illegal_number_of_words(circuit_builder), + instruction_group_no_ram(circuit_builder), + vec![running_evaluation_for_standard_input_remains_unchanged( + circuit_builder, + )], + ] + .concat() +} + +/// Update the accumulator for the Permutation Argument with the RAM table in +/// accordance with reading a bunch of words from the indicated ram pointers to +/// the indicated destination registers. +/// +/// Does not constrain the op stack by default.[^stack] For that, see: +/// [`read_from_ram_any_of`]. +/// +/// [^stack]: Op stack registers used in arguments will be constrained. +fn read_from_ram_to( + circuit_builder: &ConstraintCircuitBuilder, + ram_pointers: [ConstraintCircuitMonad; N], + destinations: [ConstraintCircuitMonad; N], +) -> ConstraintCircuitMonad { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let constant = |bfe| circuit_builder.b_constant(bfe); + + let compress_row = |(ram_pointer, destination)| { + curr_base_row(CLK) * challenge(RamClkWeight) + + constant(table::ram::INSTRUCTION_TYPE_READ) * challenge(RamInstructionTypeWeight) + + ram_pointer * challenge(RamPointerWeight) + + destination * challenge(RamValueWeight) + }; + + let factor = ram_pointers + .into_iter() + .zip(destinations) + .map(compress_row) + .map(|compressed_row| challenge(RamIndeterminate) - compressed_row) + .reduce(|l, r| l * r) + .unwrap_or_else(|| constant(bfe!(1))); + curr_ext_row(RamTablePermArg) * factor - next_ext_row(RamTablePermArg) +} + +fn xx_product( + [x_0, x_1, x_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], + [y_0, y_1, y_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], +) -> [ConstraintCircuitMonad; EXTENSION_DEGREE] { + let z0 = x_0.clone() * y_0.clone(); + let z1 = x_1.clone() * y_0.clone() + x_0.clone() * y_1.clone(); + let z2 = x_2.clone() * y_0 + x_1.clone() * y_1.clone() + x_0 * y_2.clone(); + let z3 = x_2.clone() * y_1 + x_1 * y_2.clone(); + let z4 = x_2 * y_2; + + // reduce modulo x³ - x + 1 + [z0 - z3.clone(), z1 - z4.clone() + z3, z2 + z4] +} + +fn xb_product( + [x_0, x_1, x_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], + y: ConstraintCircuitMonad, +) -> [ConstraintCircuitMonad; EXTENSION_DEGREE] { + let z0 = x_0 * y.clone(); + let z1 = x_1 * y.clone(); + let z2 = x_2 * y; + [z0, z1, z2] +} + +fn update_dotstep_accumulator( + circuit_builder: &ConstraintCircuitBuilder, + accumulator_indices: [ProcessorBaseTableColumn; EXTENSION_DEGREE], + difference: [ConstraintCircuitMonad; EXTENSION_DEGREE], +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr = accumulator_indices.map(curr_base_row); + let next = accumulator_indices.map(next_base_row); + izip!(curr, next, difference) + .map(|(c, n, d)| n - c - d) + .collect() +} + +fn instruction_xx_dot_step( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let constant = |c| circuit_builder.b_constant(c); + + let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(3); + let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); + + let rhs_ptr0 = curr_base_row(ST0); + let rhs_ptr1 = rhs_ptr0.clone() + constant(1); + let rhs_ptr2 = rhs_ptr0.clone() + constant(2); + let lhs_ptr0 = curr_base_row(ST1); + let lhs_ptr1 = lhs_ptr0.clone() + constant(1); + let lhs_ptr2 = lhs_ptr0.clone() + constant(2); + let ram_read_sources = [rhs_ptr0, rhs_ptr1, rhs_ptr2, lhs_ptr0, lhs_ptr1, lhs_ptr2]; + let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let read_two_xfes_from_ram = + read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); + + let ram_pointer_constraints = vec![ + increment_ram_pointer_st0, + increment_ram_pointer_st1, + read_two_xfes_from_ram, + ]; + + let [hv0, hv1, hv2, hv3, hv4, hv5] = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let hv_product = xx_product([hv0, hv1, hv2], [hv3, hv4, hv5]); + + [ + ram_pointer_constraints, + update_dotstep_accumulator(circuit_builder, [ST2, ST3, ST4], hv_product), + instruction_group_step_1(circuit_builder), + instruction_group_no_io(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), + ] + .concat() +} + +fn instruction_xb_dot_step( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let constant = |c| circuit_builder.b_constant(c); + + let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(1); + let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); + + let rhs_ptr0 = curr_base_row(ST0); + let lhs_ptr0 = curr_base_row(ST1); + let lhs_ptr1 = lhs_ptr0.clone() + constant(1); + let lhs_ptr2 = lhs_ptr0.clone() + constant(2); + let ram_read_sources = [rhs_ptr0, lhs_ptr0, lhs_ptr1, lhs_ptr2]; + let ram_read_destinations = [HV0, HV1, HV2, HV3].map(curr_base_row); + let read_bfe_and_xfe_from_ram = + read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); + + let ram_pointer_constraints = vec![ + increment_ram_pointer_st0, + increment_ram_pointer_st1, + read_bfe_and_xfe_from_ram, + ]; + + let [hv0, hv1, hv2, hv3] = [HV0, HV1, HV2, HV3].map(curr_base_row); + let hv_product = xb_product([hv1, hv2, hv3], hv0); + + [ + ram_pointer_constraints, + update_dotstep_accumulator(circuit_builder, [ST2, ST3, ST4], hv_product), + instruction_group_step_1(circuit_builder), + instruction_group_no_io(circuit_builder), + instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), + ] + .concat() +} + +#[doc(hidden)] // allows testing in different crate +pub fn transition_constraints_for_instruction( + circuit_builder: &ConstraintCircuitBuilder, + instruction: Instruction, +) -> Vec> { + match instruction { + Instruction::Pop(_) => instruction_pop(circuit_builder), + Instruction::Push(_) => instruction_push(circuit_builder), + Instruction::Divine(_) => instruction_divine(circuit_builder), + Instruction::Dup(_) => instruction_dup(circuit_builder), + Instruction::Swap(_) => instruction_swap(circuit_builder), + Instruction::Halt => instruction_halt(circuit_builder), + Instruction::Nop => instruction_nop(circuit_builder), + Instruction::Skiz => instruction_skiz(circuit_builder), + Instruction::Call(_) => instruction_call(circuit_builder), + Instruction::Return => instruction_return(circuit_builder), + Instruction::Recurse => instruction_recurse(circuit_builder), + Instruction::RecurseOrReturn => instruction_recurse_or_return(circuit_builder), + Instruction::Assert => instruction_assert(circuit_builder), + Instruction::ReadMem(_) => instruction_read_mem(circuit_builder), + Instruction::WriteMem(_) => instruction_write_mem(circuit_builder), + Instruction::Hash => instruction_hash(circuit_builder), + Instruction::AssertVector => instruction_assert_vector(circuit_builder), + Instruction::SpongeInit => instruction_sponge_init(circuit_builder), + Instruction::SpongeAbsorb => instruction_sponge_absorb(circuit_builder), + Instruction::SpongeAbsorbMem => instruction_sponge_absorb_mem(circuit_builder), + Instruction::SpongeSqueeze => instruction_sponge_squeeze(circuit_builder), + Instruction::Add => instruction_add(circuit_builder), + Instruction::AddI(_) => instruction_addi(circuit_builder), + Instruction::Mul => instruction_mul(circuit_builder), + Instruction::Invert => instruction_invert(circuit_builder), + Instruction::Eq => instruction_eq(circuit_builder), + Instruction::Split => instruction_split(circuit_builder), + Instruction::Lt => instruction_lt(circuit_builder), + Instruction::And => instruction_and(circuit_builder), + Instruction::Xor => instruction_xor(circuit_builder), + Instruction::Log2Floor => instruction_log_2_floor(circuit_builder), + Instruction::Pow => instruction_pow(circuit_builder), + Instruction::DivMod => instruction_div_mod(circuit_builder), + Instruction::PopCount => instruction_pop_count(circuit_builder), + Instruction::XxAdd => instruction_xx_add(circuit_builder), + Instruction::XxMul => instruction_xx_mul(circuit_builder), + Instruction::XInvert => instruction_xinv(circuit_builder), + Instruction::XbMul => instruction_xb_mul(circuit_builder), + Instruction::ReadIo(_) => instruction_read_io(circuit_builder), + Instruction::WriteIo(_) => instruction_write_io(circuit_builder), + Instruction::MerkleStep => instruction_merkle_step(circuit_builder), + Instruction::MerkleStepMem => instruction_merkle_step_mem(circuit_builder), + Instruction::XxDotStep => instruction_xx_dot_step(circuit_builder), + Instruction::XbDotStep => instruction_xb_dot_step(circuit_builder), + } +} + +/// Constrains instruction argument `nia` such that 0 < nia <= 5. +fn prohibit_any_illegal_number_of_words( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + vec![NumberOfWords::illegal_values() + .map(|n| indicator_polynomial(circuit_builder, n)) + .into_iter() + .sum()] +} + +fn log_derivative_accumulates_clk_next( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + (next_ext_row(ClockJumpDifferenceLookupServerLogDerivative) + - curr_ext_row(ClockJumpDifferenceLookupServerLogDerivative)) + * (challenge(ClockJumpDifferenceLookupIndeterminate) - next_base_row(CLK)) + - next_base_row(ClockJumpDifferenceLookupMultiplicity) +} + +fn running_evaluation_for_standard_input_remains_unchanged( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + next_ext_row(InputTableEvalArg) - curr_ext_row(InputTableEvalArg) +} + +fn running_evaluation_for_standard_output_remains_unchanged( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + next_ext_row(OutputTableEvalArg) - curr_ext_row(OutputTableEvalArg) +} + +fn grow_stack_by_n_and_read_n_symbols_from_input( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let indeterminate = || circuit_builder.challenge(StandardInputIndeterminate); + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let mut running_evaluation = curr_ext_row(InputTableEvalArg); + for i in (0..n).rev() { + let stack_element = ProcessorTable::op_stack_column_by_index(i); + running_evaluation = indeterminate() * running_evaluation + next_base_row(stack_element); + } + let running_evaluation_update = next_ext_row(InputTableEvalArg) - running_evaluation; + let conditional_running_evaluation_update = + indicator_polynomial(circuit_builder, n) * running_evaluation_update; + + let mut constraints = conditional_constraints_for_growing_stack_by(circuit_builder, n); + constraints.push(conditional_running_evaluation_update); + constraints +} + +fn shrink_stack_by_n_and_write_n_symbols_to_output( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let indeterminate = || circuit_builder.challenge(StandardOutputIndeterminate); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let mut running_evaluation = curr_ext_row(OutputTableEvalArg); + for i in 0..n { + let stack_element = ProcessorTable::op_stack_column_by_index(i); + running_evaluation = indeterminate() * running_evaluation + curr_base_row(stack_element); + } + let running_evaluation_update = next_ext_row(OutputTableEvalArg) - running_evaluation; + let conditional_running_evaluation_update = + indicator_polynomial(circuit_builder, n) * running_evaluation_update; + + let mut constraints = conditional_constraints_for_shrinking_stack_by(circuit_builder, n); + constraints.push(conditional_running_evaluation_update); + constraints +} + +fn log_derivative_for_instruction_lookup_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let one = || circuit_builder.b_constant(1); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let compressed_row = challenge(ProgramAddressWeight) * next_base_row(IP) + + challenge(ProgramInstructionWeight) * next_base_row(CI) + + challenge(ProgramNextInstructionWeight) * next_base_row(NIA); + let log_derivative_updates = (next_ext_row(InstructionLookupClientLogDerivative) + - curr_ext_row(InstructionLookupClientLogDerivative)) + * (challenge(InstructionLookupIndeterminate) - compressed_row) + - one(); + let log_derivative_remains = next_ext_row(InstructionLookupClientLogDerivative) + - curr_ext_row(InstructionLookupClientLogDerivative); + + (one() - next_base_row(IsPadding)) * log_derivative_updates + + next_base_row(IsPadding) * log_derivative_remains +} + +fn constraints_for_shrinking_stack_by_3_and_top_3_unconstrained( + circuit_builder: &ConstraintCircuitBuilder, +) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + vec![ + next_base_row(ST3) - curr_base_row(ST6), + next_base_row(ST4) - curr_base_row(ST7), + next_base_row(ST5) - curr_base_row(ST8), + next_base_row(ST6) - curr_base_row(ST9), + next_base_row(ST7) - curr_base_row(ST10), + next_base_row(ST8) - curr_base_row(ST11), + next_base_row(ST9) - curr_base_row(ST12), + next_base_row(ST10) - curr_base_row(ST13), + next_base_row(ST11) - curr_base_row(ST14), + next_base_row(ST12) - curr_base_row(ST15), + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(3), + running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 3), + ] +} + +fn stack_shrinks_by_any_of( + circuit_builder: &ConstraintCircuitBuilder, + shrinkages: &[usize], +) -> Vec> { + let all_constraints_for_all_shrinkages = shrinkages + .iter() + .map(|&n| conditional_constraints_for_shrinking_stack_by(circuit_builder, n)) + .collect_vec(); + + combine_mutually_exclusive_constraint_groups( + circuit_builder, + all_constraints_for_all_shrinkages, + ) +} + +fn stack_grows_by_any_of( + circuit_builder: &ConstraintCircuitBuilder, + growths: &[usize], +) -> Vec> { + let all_constraints_for_all_growths = growths + .iter() + .map(|&n| conditional_constraints_for_growing_stack_by(circuit_builder, n)) + .collect_vec(); + + combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraints_for_all_growths) +} + +/// Reduces the number of constraints by summing mutually exclusive constraints. The mutual +/// exclusion is due to the conditional nature of the constraints, which has to be guaranteed by +/// the caller. +/// +/// For example, the constraints for shrinking the stack by 2, 3, and 4 elements are: +/// +/// ```markdown +/// | shrink by 2 | shrink by 3 | shrink by 4 | +/// |-----------------------:|-----------------------:|-----------------------:| +/// | ind_2·(st0' - st2) | ind_3·(st0' - st3) | ind_4·(st0' - st4) | +/// | ind_2·(st1' - st3) | ind_3·(st1' - st4) | ind_4·(st1' - st5) | +/// | … | … | … | +/// | ind_2·(st11' - st13) | ind_3·(st11' - st14) | ind_4·(st11' - st15) | +/// | ind_2·(st12' - st14) | ind_3·(st12' - st15) | ind_4·(osp' - osp + 4) | +/// | ind_2·(st13' - st15) | ind_3·(osp' - osp + 3) | ind_4·(rp' - rp·fac_4) | +/// | ind_2·(osp' - osp + 2) | ind_3·(rp' - rp·fac_3) | | +/// | ind_2·(rp' - rp·fac_2) | | | +/// ``` +/// +/// This method sums these constraints “per row”. That is, the resulting constraints are: +/// +/// ```markdown +/// | shrink by 2 or 3 or 4 | +/// |-----------------------------------------------------------------------:| +/// | ind_2·(st0' - st2) + ind_3·(st0' - st3) + ind_4·(st0' - st4) | +/// | ind_2·(st1' - st3) + ind_3·(st1' - st4) + ind_4·(st1' - st5) | +/// | … | +/// | ind_2·(st11' - st13) + ind_3·(st11' - st14) + ind_4·(st11' - st15) | +/// | ind_2·(st12' - st14) + ind_3·(st12' - st15) + ind_4·(osp' - osp + 4) | +/// | ind_2·(st13' - st15) + ind_3·(osp' - osp + 3) + ind_4·(rp' - rp·fac_4) | +/// | ind_2·(osp' - osp + 2) + ind_3·(rp' - rp·fac_3) | +/// | ind_2·(rp' - rp·fac_2) | +/// ``` +/// +/// Syntax in above example: +/// - `ind_n` is the [indicator polynomial](indicator_polynomial) for `n` +/// - `osp` is the [op stack pointer](OpStackPointer) +/// - `rp` is the running product for the permutation argument +/// - `fac_n` is the factor for the running product +fn combine_mutually_exclusive_constraint_groups( + circuit_builder: &ConstraintCircuitBuilder, + all_constraint_groups: Vec>>, +) -> Vec> { + let constraint_group_lengths = all_constraint_groups.iter().map(|x| x.len()); + let num_constraints = constraint_group_lengths.max().unwrap_or(0); + + let zero_constraint = || circuit_builder.b_constant(0); + let mut combined_constraints = vec![]; + for i in 0..num_constraints { + let combined_constraint = all_constraint_groups + .iter() + .filter_map(|constraint_group| constraint_group.get(i)) + .fold(zero_constraint(), |acc, summand| acc + summand.clone()); + combined_constraints.push(combined_constraint); + } + combined_constraints +} + +fn constraints_for_shrinking_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); + let new_stack = stack().dropping_back(n).map(next_row).collect_vec(); + let old_stack_with_top_n_removed = stack().skip(n).map(curr_row).collect_vec(); + + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { + assert_eq!(OpStackElement::COUNT - n, stack.len()); + let weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() + }; + let compressed_new_stack = compress(new_stack); + let compressed_old_stack = compress(old_stack_with_top_n_removed); + + let op_stack_pointer_shrinks_by_n = + next_row(OpStackPointer) - curr_row(OpStackPointer) + constant(n); + let new_stack_is_old_stack_with_top_n_removed = compressed_new_stack - compressed_old_stack; + + vec![ + op_stack_pointer_shrinks_by_n, + new_stack_is_old_stack_with_top_n_removed, + running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, n), + ] +} + +fn constraints_for_growing_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); + let curr_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); + let new_stack = stack().skip(n).map(next_row).collect_vec(); + let old_stack_with_top_n_added = stack().map(curr_row).dropping_back(n).collect_vec(); + + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { + assert_eq!(OpStackElement::COUNT - n, stack.len()); + let weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() + }; + let compressed_new_stack = compress(new_stack); + let compressed_old_stack = compress(old_stack_with_top_n_added); + + let op_stack_pointer_grows_by_n = + next_row(OpStackPointer) - curr_row(OpStackPointer) - constant(n); + let new_stack_is_old_stack_with_top_n_added = compressed_new_stack - compressed_old_stack; + + vec![ + op_stack_pointer_grows_by_n, + new_stack_is_old_stack_with_top_n_added, + running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, n), + ] +} + +fn conditional_constraints_for_shrinking_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + constraints_for_shrinking_stack_by(circuit_builder, n) + .into_iter() + .map(|constraint| indicator_polynomial(circuit_builder, n) * constraint) + .collect() +} + +fn conditional_constraints_for_growing_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + constraints_for_growing_stack_by(circuit_builder, n) + .into_iter() + .map(|constraint| indicator_polynomial(circuit_builder, n) * constraint) + .collect() +} + +fn running_product_op_stack_accounts_for_growing_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let single_grow_factor = |op_stack_pointer_offset| { + single_factor_for_permutation_argument_with_op_stack_table( + circuit_builder, + CurrentBaseRow, + op_stack_pointer_offset, + ) + }; + + let mut factor = constant(1); + for op_stack_pointer_offset in 0..n { + factor = factor * single_grow_factor(op_stack_pointer_offset); + } + + next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor +} + +fn running_product_op_stack_accounts_for_shrinking_stack_by( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let single_shrink_factor = |op_stack_pointer_offset| { + single_factor_for_permutation_argument_with_op_stack_table( + circuit_builder, + NextBaseRow, + op_stack_pointer_offset, + ) + }; + + let mut factor = constant(1); + for op_stack_pointer_offset in 0..n { + factor = factor * single_shrink_factor(op_stack_pointer_offset); + } + + next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor +} + +fn single_factor_for_permutation_argument_with_op_stack_table( + circuit_builder: &ConstraintCircuitBuilder, + row_with_shorter_stack_indicator: fn(usize) -> DualRowIndicator, + op_stack_pointer_offset: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let row_with_shorter_stack = |col: ProcessorBaseTableColumn| { + circuit_builder.input(row_with_shorter_stack_indicator( + col.master_base_table_index(), + )) + }; + + let max_stack_element_index = OpStackElement::COUNT - 1; + let stack_element_index = max_stack_element_index - op_stack_pointer_offset; + let stack_element = ProcessorTable::op_stack_column_by_index(stack_element_index); + let underflow_element = row_with_shorter_stack(stack_element); + + let op_stack_pointer = row_with_shorter_stack(OpStackPointer); + let offset = constant(op_stack_pointer_offset as u32); + let offset_op_stack_pointer = op_stack_pointer + offset; + + let compressed_row = challenge(OpStackClkWeight) * curr_base_row(CLK) + + challenge(OpStackIb1Weight) * curr_base_row(IB1) + + challenge(OpStackPointerWeight) * offset_op_stack_pointer + + challenge(OpStackFirstUnderflowElementWeight) * underflow_element; + challenge(OpStackIndeterminate) - compressed_row +} + +/// Build constraints for popping `n` elements from the top of the stack and +/// writing them to RAM. The reciprocal of [`read_from_ram_any_of`]. +fn write_to_ram_any_of( + circuit_builder: &ConstraintCircuitBuilder, + number_of_words: &[usize], +) -> Vec> { + let all_constraint_groups = number_of_words + .iter() + .map(|&n| conditional_constraints_for_writing_n_elements_to_ram(circuit_builder, n)) + .collect_vec(); + combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) +} + +/// Build constraints for reading `n` elements from RAM and putting them on top +/// of the stack. The reciprocal of [`write_to_ram_any_of`]. +/// +/// To constrain RAM reads with more flexible target locations, see +/// [`read_from_ram_to`]. +fn read_from_ram_any_of( + circuit_builder: &ConstraintCircuitBuilder, + number_of_words: &[usize], +) -> Vec> { + let all_constraint_groups = number_of_words + .iter() + .map(|&n| conditional_constraints_for_reading_n_elements_from_ram(circuit_builder, n)) + .collect_vec(); + combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) +} + +fn conditional_constraints_for_writing_n_elements_to_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + shrink_stack_by_n_and_write_n_elements_to_ram(circuit_builder, n) + .into_iter() + .map(|constraint| indicator_polynomial(circuit_builder, n) * constraint) + .collect() +} + +fn conditional_constraints_for_reading_n_elements_from_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + grow_stack_by_n_and_read_n_elements_from_ram(circuit_builder, n) + .into_iter() + .map(|constraint| indicator_polynomial(circuit_builder, n) * constraint) + .collect() +} + +fn shrink_stack_by_n_and_write_n_elements_to_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let op_stack_pointer_shrinks_by_n = + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(n); + let ram_pointer_grows_by_n = next_base_row(ST0) - curr_base_row(ST0) - constant(n); + + let mut constraints = vec![ + op_stack_pointer_shrinks_by_n, + ram_pointer_grows_by_n, + running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, n), + running_product_ram_accounts_for_writing_n_elements(circuit_builder, n), + ]; + + let num_ram_pointers = 1; + for i in n + num_ram_pointers..OpStackElement::COUNT { + let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); + let next_stack_element = ProcessorTable::op_stack_column_by_index(i - n); + let element_i_is_shifted_by_n = + next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + constraints.push(element_i_is_shifted_by_n); + } + constraints +} + +fn grow_stack_by_n_and_read_n_elements_from_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let op_stack_pointer_grows_by_n = + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(n); + let ram_pointer_shrinks_by_n = next_base_row(ST0) - curr_base_row(ST0) + constant(n); + + let mut constraints = vec![ + op_stack_pointer_grows_by_n, + ram_pointer_shrinks_by_n, + running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, n), + running_product_ram_accounts_for_reading_n_elements(circuit_builder, n), + ]; + + let num_ram_pointers = 1; + for i in num_ram_pointers..OpStackElement::COUNT - n { + let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); + let next_stack_element = ProcessorTable::op_stack_column_by_index(i + n); + let element_i_is_shifted_by_n = + next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + constraints.push(element_i_is_shifted_by_n); + } + constraints +} + +fn running_product_ram_accounts_for_writing_n_elements( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let single_write_factor = |ram_pointer_offset| { + single_factor_for_permutation_argument_with_ram_table( + circuit_builder, + CurrentBaseRow, + table::ram::INSTRUCTION_TYPE_WRITE, + ram_pointer_offset, + ) + }; + + let mut factor = constant(1); + for ram_pointer_offset in 0..n { + factor = factor * single_write_factor(ram_pointer_offset); + } + + next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor +} + +fn running_product_ram_accounts_for_reading_n_elements( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let single_read_factor = |ram_pointer_offset| { + single_factor_for_permutation_argument_with_ram_table( + circuit_builder, + NextBaseRow, + table::ram::INSTRUCTION_TYPE_READ, + ram_pointer_offset, + ) + }; + + let mut factor = constant(1); + for ram_pointer_offset in 0..n { + factor = factor * single_read_factor(ram_pointer_offset); + } + + next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor +} + +fn single_factor_for_permutation_argument_with_ram_table( + circuit_builder: &ConstraintCircuitBuilder, + row_with_longer_stack_indicator: fn(usize) -> DualRowIndicator, + instruction_type: BFieldElement, + ram_pointer_offset: usize, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let b_constant = |c| circuit_builder.b_constant(c); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let row_with_longer_stack = |col: ProcessorBaseTableColumn| { + circuit_builder.input(row_with_longer_stack_indicator( + col.master_base_table_index(), + )) + }; + + let num_ram_pointers = 1; + let ram_value_index = ram_pointer_offset + num_ram_pointers; + let ram_value_column = ProcessorTable::op_stack_column_by_index(ram_value_index); + let ram_value = row_with_longer_stack(ram_value_column); + + let additional_offset = match instruction_type { + table::ram::INSTRUCTION_TYPE_READ => 1, + table::ram::INSTRUCTION_TYPE_WRITE => 0, + _ => panic!("Invalid instruction type"), + }; + + let ram_pointer = row_with_longer_stack(ST0); + let offset = constant(additional_offset + ram_pointer_offset as u32); + let offset_ram_pointer = ram_pointer + offset; + + let compressed_row = curr_base_row(CLK) * challenge(RamClkWeight) + + b_constant(instruction_type) * challenge(RamInstructionTypeWeight) + + offset_ram_pointer * challenge(RamPointerWeight) + + ram_value * challenge(RamValueWeight); + challenge(RamIndeterminate) - compressed_row +} + +fn running_product_for_jump_stack_table_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let compressed_row = challenge(JumpStackClkWeight) * next_base_row(CLK) + + challenge(JumpStackCiWeight) * next_base_row(CI) + + challenge(JumpStackJspWeight) * next_base_row(JSP) + + challenge(JumpStackJsoWeight) * next_base_row(JSO) + + challenge(JumpStackJsdWeight) * next_base_row(JSD); + + next_ext_row(JumpStackTablePermArg) + - curr_ext_row(JumpStackTablePermArg) * (challenge(JumpStackIndeterminate) - compressed_row) +} + +/// Deal with instructions `hash` and `merkle_step`. The registers from which +/// the preimage is loaded differs between the two instructions: +/// 1. `hash` always loads the stack's 10 top elements, +/// 1. `merkle_step` loads the stack's 5 top elements and helper variables 0 +/// through 4. The order of those two quintuplets depends on helper variable +/// hv5. +/// +/// The Hash Table does not “know” about instruction `merkle_step`. +/// +/// Note that using `next_row` might be confusing at first glance; See the +/// [specification](https://triton-vm.org/spec/processor-table.html). +fn running_evaluation_hash_input_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let hash_deselector = instruction_deselector_next_row(circuit_builder, Instruction::Hash); + let merkle_step_deselector = + instruction_deselector_next_row(circuit_builder, Instruction::MerkleStep); + let merkle_step_mem_deselector = + instruction_deselector_next_row(circuit_builder, Instruction::MerkleStepMem); + let hash_and_merkle_step_selector = (next_base_row(CI) - constant(Instruction::Hash.opcode())) + * (next_base_row(CI) - constant(Instruction::MerkleStep.opcode())) + * (next_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); + + let weights = [ + StackWeight0, + StackWeight1, + StackWeight2, + StackWeight3, + StackWeight4, + StackWeight5, + StackWeight6, + StackWeight7, + StackWeight8, + StackWeight9, + ] + .map(challenge); + + // hash + let state_for_hash = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9].map(next_base_row); + let compressed_hash_row = weights + .iter() + .zip_eq(state_for_hash) + .map(|(weight, state)| weight.clone() * state) + .sum(); + + // merkle step + let is_left_sibling = || next_base_row(HV5); + let is_right_sibling = || one() - next_base_row(HV5); + let merkle_step_state_element = + |l, r| is_right_sibling() * next_base_row(l) + is_left_sibling() * next_base_row(r); + let state_for_merkle_step = [ + merkle_step_state_element(ST0, HV0), + merkle_step_state_element(ST1, HV1), + merkle_step_state_element(ST2, HV2), + merkle_step_state_element(ST3, HV3), + merkle_step_state_element(ST4, HV4), + merkle_step_state_element(HV0, ST0), + merkle_step_state_element(HV1, ST1), + merkle_step_state_element(HV2, ST2), + merkle_step_state_element(HV3, ST3), + merkle_step_state_element(HV4, ST4), + ]; + let compressed_merkle_step_row = weights + .into_iter() + .zip_eq(state_for_merkle_step) + .map(|(weight, state)| weight * state) + .sum::>(); + + let running_evaluation_updates_with = |compressed_row| { + next_ext_row(HashInputEvalArg) + - challenge(HashInputIndeterminate) * curr_ext_row(HashInputEvalArg) + - compressed_row + }; + let running_evaluation_remains = + next_ext_row(HashInputEvalArg) - curr_ext_row(HashInputEvalArg); + + hash_and_merkle_step_selector * running_evaluation_remains + + hash_deselector * running_evaluation_updates_with(compressed_hash_row) + + merkle_step_deselector + * running_evaluation_updates_with(compressed_merkle_step_row.clone()) + + merkle_step_mem_deselector * running_evaluation_updates_with(compressed_merkle_step_row) +} + +fn running_evaluation_hash_digest_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let hash_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Hash); + let merkle_step_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::MerkleStep); + let merkle_step_mem_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); + let hash_and_merkle_step_selector = (curr_base_row(CI) - constant(Instruction::Hash.opcode())) + * (curr_base_row(CI) - constant(Instruction::MerkleStep.opcode())) + * (curr_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); + + let weights = [ + StackWeight0, + StackWeight1, + StackWeight2, + StackWeight3, + StackWeight4, + ] + .map(challenge); + let state = [ST0, ST1, ST2, ST3, ST4].map(next_base_row); + let compressed_row = weights + .into_iter() + .zip_eq(state) + .map(|(weight, state)| weight * state) + .sum(); + + let running_evaluation_updates = next_ext_row(HashDigestEvalArg) + - challenge(HashDigestIndeterminate) * curr_ext_row(HashDigestEvalArg) + - compressed_row; + let running_evaluation_remains = + next_ext_row(HashDigestEvalArg) - curr_ext_row(HashDigestEvalArg); + + hash_and_merkle_step_selector * running_evaluation_remains + + (hash_deselector + merkle_step_deselector + merkle_step_mem_deselector) + * running_evaluation_updates +} + +fn running_evaluation_sponge_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let sponge_init_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::SpongeInit); + let sponge_absorb_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::SpongeAbsorb); + let sponge_absorb_mem_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::SpongeAbsorbMem); + let sponge_squeeze_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::SpongeSqueeze); + + let sponge_instruction_selector = (curr_base_row(CI) + - constant(Instruction::SpongeInit.opcode())) + * (curr_base_row(CI) - constant(Instruction::SpongeAbsorb.opcode())) + * (curr_base_row(CI) - constant(Instruction::SpongeAbsorbMem.opcode())) + * (curr_base_row(CI) - constant(Instruction::SpongeSqueeze.opcode())); + + let weighted_sum = |state| { + let weights = [ + StackWeight0, + StackWeight1, + StackWeight2, + StackWeight3, + StackWeight4, + StackWeight5, + StackWeight6, + StackWeight7, + StackWeight8, + StackWeight9, + ]; + let weights = weights.map(challenge).into_iter(); + weights.zip_eq(state).map(|(w, st)| w * st).sum() + }; + + let state = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; + let compressed_row_current = weighted_sum(state.map(curr_base_row)); + let compressed_row_next = weighted_sum(state.map(next_base_row)); + + // Use domain-specific knowledge: the compressed row (i.e., random linear sum) + // of the initial Sponge state is 0. + let running_evaluation_updates_for_sponge_init = next_ext_row(SpongeEvalArg) + - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) + - challenge(HashCIWeight) * curr_base_row(CI); + let running_evaluation_updates_for_absorb = + running_evaluation_updates_for_sponge_init.clone() - compressed_row_current; + let running_evaluation_updates_for_squeeze = + running_evaluation_updates_for_sponge_init.clone() - compressed_row_next; + let running_evaluation_remains = next_ext_row(SpongeEvalArg) - curr_ext_row(SpongeEvalArg); + + // `sponge_absorb_mem` + let stack_elements = [ST1, ST2, ST3, ST4].map(next_base_row); + let hv_elements = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let absorb_mem_elements = stack_elements.into_iter().chain(hv_elements); + let absorb_mem_elements = absorb_mem_elements.collect_vec().try_into().unwrap(); + let compressed_row_absorb_mem = weighted_sum(absorb_mem_elements); + let running_evaluation_updates_for_absorb_mem = next_ext_row(SpongeEvalArg) + - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) + - challenge(HashCIWeight) * constant(Instruction::SpongeAbsorb.opcode()) + - compressed_row_absorb_mem; + + sponge_instruction_selector * running_evaluation_remains + + sponge_init_deselector * running_evaluation_updates_for_sponge_init + + sponge_absorb_deselector * running_evaluation_updates_for_absorb + + sponge_absorb_mem_deselector * running_evaluation_updates_for_absorb_mem + + sponge_squeeze_deselector * running_evaluation_updates_for_squeeze +} + +fn log_derivative_with_u32_table_updates_correctly( + circuit_builder: &ConstraintCircuitBuilder, +) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c); + let one = || constant(1); + let two_inverse = circuit_builder.b_constant(bfe!(2).inverse()); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + + let split_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Split); + let lt_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Lt); + let and_deselector = instruction_deselector_current_row(circuit_builder, Instruction::And); + let xor_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Xor); + let pow_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Pow); + let log_2_floor_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::Log2Floor); + let div_mod_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::DivMod); + let pop_count_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::PopCount); + let merkle_step_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::MerkleStep); + let merkle_step_mem_deselector = + instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); + + let running_sum = curr_ext_row(U32LookupClientLogDerivative); + let running_sum_next = next_ext_row(U32LookupClientLogDerivative); + + let split_factor = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * next_base_row(ST0) + - challenge(U32RhsWeight) * next_base_row(ST1) + - challenge(U32CiWeight) * curr_base_row(CI); + let binop_factor = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * curr_base_row(ST0) + - challenge(U32RhsWeight) * curr_base_row(ST1) + - challenge(U32CiWeight) * curr_base_row(CI) + - challenge(U32ResultWeight) * next_base_row(ST0); + let xor_factor = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * curr_base_row(ST0) + - challenge(U32RhsWeight) * curr_base_row(ST1) + - challenge(U32CiWeight) * constant(Instruction::And.opcode()) + - challenge(U32ResultWeight) + * (curr_base_row(ST0) + curr_base_row(ST1) - next_base_row(ST0)) + * two_inverse; + let unop_factor = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * curr_base_row(ST0) + - challenge(U32CiWeight) * curr_base_row(CI) + - challenge(U32ResultWeight) * next_base_row(ST0); + let div_mod_factor_for_lt = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * next_base_row(ST0) + - challenge(U32RhsWeight) * curr_base_row(ST1) + - challenge(U32CiWeight) * constant(Instruction::Lt.opcode()) + - challenge(U32ResultWeight); + let div_mod_factor_for_range_check = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * curr_base_row(ST0) + - challenge(U32RhsWeight) * next_base_row(ST1) + - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); + let merkle_step_range_check_factor = challenge(U32Indeterminate) + - challenge(U32LhsWeight) * curr_base_row(ST5) + - challenge(U32RhsWeight) * next_base_row(ST5) + - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); + + let running_sum_absorbs_split_factor = + (running_sum_next.clone() - running_sum.clone()) * split_factor - one(); + let running_sum_absorbs_binop_factor = + (running_sum_next.clone() - running_sum.clone()) * binop_factor - one(); + let running_sum_absorb_xor_factor = + (running_sum_next.clone() - running_sum.clone()) * xor_factor - one(); + let running_sum_absorbs_unop_factor = + (running_sum_next.clone() - running_sum.clone()) * unop_factor - one(); + let running_sum_absorbs_merkle_step_factor = + (running_sum_next.clone() - running_sum.clone()) * merkle_step_range_check_factor - one(); + + let split_summand = split_deselector * running_sum_absorbs_split_factor; + let lt_summand = lt_deselector * running_sum_absorbs_binop_factor.clone(); + let and_summand = and_deselector * running_sum_absorbs_binop_factor.clone(); + let xor_summand = xor_deselector * running_sum_absorb_xor_factor; + let pow_summand = pow_deselector * running_sum_absorbs_binop_factor; + let log_2_floor_summand = log_2_floor_deselector * running_sum_absorbs_unop_factor.clone(); + let div_mod_summand = div_mod_deselector + * ((running_sum_next.clone() - running_sum.clone()) + * div_mod_factor_for_lt.clone() + * div_mod_factor_for_range_check.clone() + - div_mod_factor_for_lt + - div_mod_factor_for_range_check); + let pop_count_summand = pop_count_deselector * running_sum_absorbs_unop_factor; + let merkle_step_summand = + merkle_step_deselector * running_sum_absorbs_merkle_step_factor.clone(); + let merkle_step_mem_summand = + merkle_step_mem_deselector * running_sum_absorbs_merkle_step_factor; + let no_update_summand = (one() - curr_base_row(IB2)) * (running_sum_next - running_sum); + + split_summand + + lt_summand + + and_summand + + xor_summand + + pow_summand + + log_2_floor_summand + + div_mod_summand + + pop_count_summand + + merkle_step_summand + + merkle_step_mem_summand + + no_update_summand +} + +fn stack_weight_by_index(i: usize) -> ChallengeId { + match i { + 0 => StackWeight0, + 1 => StackWeight1, + 2 => StackWeight2, + 3 => StackWeight3, + 4 => StackWeight4, + 5 => StackWeight5, + 6 => StackWeight6, + 7 => StackWeight7, + 8 => StackWeight8, + 9 => StackWeight9, + 10 => StackWeight10, + 11 => StackWeight11, + 12 => StackWeight12, + 13 => StackWeight13, + 14 => StackWeight14, + 15 => StackWeight15, + i => panic!("Op Stack weight index must be in [0, 15], not {i}."), + } +} + +/// A polynomial that is 1 when evaluated on the given index, and 0 otherwise. +fn indicator_polynomial( + circuit_builder: &ConstraintCircuitBuilder, + index: usize, +) -> ConstraintCircuitMonad { + let one = || circuit_builder.b_constant(1); + let hv = |idx| helper_variable(circuit_builder, idx); + + match index { + 0 => (one() - hv(3)) * (one() - hv(2)) * (one() - hv(1)) * (one() - hv(0)), + 1 => (one() - hv(3)) * (one() - hv(2)) * (one() - hv(1)) * hv(0), + 2 => (one() - hv(3)) * (one() - hv(2)) * hv(1) * (one() - hv(0)), + 3 => (one() - hv(3)) * (one() - hv(2)) * hv(1) * hv(0), + 4 => (one() - hv(3)) * hv(2) * (one() - hv(1)) * (one() - hv(0)), + 5 => (one() - hv(3)) * hv(2) * (one() - hv(1)) * hv(0), + 6 => (one() - hv(3)) * hv(2) * hv(1) * (one() - hv(0)), + 7 => (one() - hv(3)) * hv(2) * hv(1) * hv(0), + 8 => hv(3) * (one() - hv(2)) * (one() - hv(1)) * (one() - hv(0)), + 9 => hv(3) * (one() - hv(2)) * (one() - hv(1)) * hv(0), + 10 => hv(3) * (one() - hv(2)) * hv(1) * (one() - hv(0)), + 11 => hv(3) * (one() - hv(2)) * hv(1) * hv(0), + 12 => hv(3) * hv(2) * (one() - hv(1)) * (one() - hv(0)), + 13 => hv(3) * hv(2) * (one() - hv(1)) * hv(0), + 14 => hv(3) * hv(2) * hv(1) * (one() - hv(0)), + 15 => hv(3) * hv(2) * hv(1) * hv(0), + i => panic!("indicator polynomial index {i} out of bounds"), + } +} + +fn helper_variable( + circuit_builder: &ConstraintCircuitBuilder, + index: usize, +) -> ConstraintCircuitMonad { + match index { + 0 => circuit_builder.input(CurrentBaseRow(HV0.master_base_table_index())), + 1 => circuit_builder.input(CurrentBaseRow(HV1.master_base_table_index())), + 2 => circuit_builder.input(CurrentBaseRow(HV2.master_base_table_index())), + 3 => circuit_builder.input(CurrentBaseRow(HV3.master_base_table_index())), + 4 => circuit_builder.input(CurrentBaseRow(HV4.master_base_table_index())), + 5 => circuit_builder.input(CurrentBaseRow(HV5.master_base_table_index())), + i => unimplemented!("Helper variable index {i} out of bounds."), + } +} + +#[cfg(test)] +mod tests { + use ndarray::s; + use ndarray::Array2; + use num_traits::identities::Zero; + use proptest::prop_assert_eq; + use proptest_arbitrary_interop::arb; + use test_strategy::proptest; + + use crate::table::NUM_BASE_COLUMNS; + use crate::table::NUM_EXT_COLUMNS; + + use super::*; + + #[test] + fn instruction_deselector_gives_0_for_all_other_instructions() { + let circuit_builder = ConstraintCircuitBuilder::new(); + + let mut master_base_table = Array2::zeros([2, NUM_BASE_COLUMNS]); + let master_ext_table = Array2::zeros([2, NUM_EXT_COLUMNS]); + + // For this test, dummy challenges suffice to evaluate the constraints. + let dummy_challenges = (0..ChallengeId::COUNT) + .map(|i| XFieldElement::from(i as u64)) + .collect_vec(); + for instruction in ALL_INSTRUCTIONS { + use ProcessorBaseTableColumn::*; + let deselector = instruction_deselector_current_row(&circuit_builder, instruction); + + println!("\n\nThe Deselector for instruction {instruction} is:\n{deselector}",); + + // Negative tests + for other_instruction in ALL_INSTRUCTIONS + .into_iter() + .filter(|other_instruction| *other_instruction != instruction) + { + let mut curr_row = master_base_table.slice_mut(s![0, ..]); + curr_row[IB0.master_base_table_index()] = other_instruction.ib(InstructionBit::IB0); + curr_row[IB1.master_base_table_index()] = other_instruction.ib(InstructionBit::IB1); + curr_row[IB2.master_base_table_index()] = other_instruction.ib(InstructionBit::IB2); + curr_row[IB3.master_base_table_index()] = other_instruction.ib(InstructionBit::IB3); + curr_row[IB4.master_base_table_index()] = other_instruction.ib(InstructionBit::IB4); + curr_row[IB5.master_base_table_index()] = other_instruction.ib(InstructionBit::IB5); + curr_row[IB6.master_base_table_index()] = other_instruction.ib(InstructionBit::IB6); + let result = deselector.clone().consume().evaluate( + master_base_table.view(), + master_ext_table.view(), + &dummy_challenges, + ); + + assert!( + result.is_zero(), + "Deselector for {instruction} should return 0 for all other instructions, \ + including {other_instruction} whose opcode is {}", + other_instruction.opcode() + ) + } + + // Positive tests + let mut curr_row = master_base_table.slice_mut(s![0, ..]); + curr_row[IB0.master_base_table_index()] = instruction.ib(InstructionBit::IB0); + curr_row[IB1.master_base_table_index()] = instruction.ib(InstructionBit::IB1); + curr_row[IB2.master_base_table_index()] = instruction.ib(InstructionBit::IB2); + curr_row[IB3.master_base_table_index()] = instruction.ib(InstructionBit::IB3); + curr_row[IB4.master_base_table_index()] = instruction.ib(InstructionBit::IB4); + curr_row[IB5.master_base_table_index()] = instruction.ib(InstructionBit::IB5); + curr_row[IB6.master_base_table_index()] = instruction.ib(InstructionBit::IB6); + let result = deselector.consume().evaluate( + master_base_table.view(), + master_ext_table.view(), + &dummy_challenges, + ); + assert!( + !result.is_zero(), + "Deselector for {instruction} should be non-zero when CI is {}", + instruction.opcode() + ) + } + } + + #[test] + fn print_number_and_degrees_of_transition_constraints_for_all_instructions() { + println!(); + println!("| Instruction | #polys | max deg | Degrees"); + println!("|:--------------------|-------:|--------:|:------------"); + let circuit_builder = ConstraintCircuitBuilder::new(); + for instruction in ALL_INSTRUCTIONS { + let constraints = transition_constraints_for_instruction(&circuit_builder, instruction); + let degrees = constraints + .iter() + .map(|circuit| circuit.clone().consume().degree()) + .collect_vec(); + let max_degree = degrees.iter().max().unwrap_or(&0); + let degrees_str = degrees.iter().join(", "); + println!( + "| {:<19} | {:>6} | {max_degree:>7} | [{degrees_str}]", + format!("{instruction}"), + constraints.len(), + ); + } + } + #[test] + fn range_check_for_skiz_is_as_efficient_as_possible() { + let range_check_constraints = next_instruction_range_check_constraints_for_instruction_skiz( + &ConstraintCircuitBuilder::new(), + ); + let range_check_constraints = range_check_constraints.iter(); + let all_degrees = range_check_constraints.map(|c| c.clone().consume().degree()); + let max_constraint_degree = all_degrees.max().unwrap_or(0); + assert!( + crate::TARGET_DEGREE <= max_constraint_degree, + "Can the range check constraints be of a higher degree, saving columns?" + ); + } + + #[test] + fn helper_variables_in_bounds() { + let circuit_builder = ConstraintCircuitBuilder::new(); + for index in 0..NUM_HELPER_VARIABLE_REGISTERS { + helper_variable(&circuit_builder, index); + } + } + + #[proptest] + #[should_panic(expected = "out of bounds")] + fn cannot_get_helper_variable_for_out_of_range_index( + #[strategy(NUM_HELPER_VARIABLE_REGISTERS..)] index: usize, + ) { + let circuit_builder = ConstraintCircuitBuilder::new(); + helper_variable(&circuit_builder, index); + } + + #[test] + fn indicator_polynomial_in_bounds() { + let circuit_builder = ConstraintCircuitBuilder::new(); + for index in 0..16 { + indicator_polynomial(&circuit_builder, index); + } + } + + #[proptest] + #[should_panic(expected = "out of bounds")] + fn cannot_get_indicator_polynomial_for_out_of_range_index( + #[strategy(16_usize.. + )] + index: usize, + ) { + let circuit_builder = ConstraintCircuitBuilder::new(); + indicator_polynomial(&circuit_builder, index); + } + + #[proptest] + fn indicator_polynomial_is_one_on_indicated_index_and_zero_on_all_other_indices( + #[strategy(0_usize..16)] indicator_poly_index: usize, + #[strategy(0_u64..16)] query_index: u64, + ) { + let mut base_table = Array2::ones([2, NUM_BASE_COLUMNS]); + let aux_table = Array2::ones([2, NUM_EXT_COLUMNS]); + + base_table[[0, HV0.master_base_table_index()]] = bfe!(query_index % 2); + base_table[[0, HV1.master_base_table_index()]] = bfe!((query_index >> 1) % 2); + base_table[[0, HV2.master_base_table_index()]] = bfe!((query_index >> 2) % 2); + base_table[[0, HV3.master_base_table_index()]] = bfe!((query_index >> 3) % 2); + + let builder = ConstraintCircuitBuilder::new(); + let indicator_poly = indicator_polynomial(&builder, indicator_poly_index).consume(); + let evaluation = indicator_poly.evaluate(base_table.view(), aux_table.view(), &[]); + + if indicator_poly_index as u64 == query_index { + prop_assert_eq!(xfe!(1), evaluation); + } else { + prop_assert_eq!(xfe!(0), evaluation); + } + } + + #[test] + fn can_get_op_stack_column_for_in_range_index() { + for index in 0..OpStackElement::COUNT { + let _ = ProcessorTable::op_stack_column_by_index(index); + } + } + + #[proptest] + #[should_panic(expected = "[0, 15]")] + fn cannot_get_op_stack_column_for_out_of_range_index( + #[strategy(OpStackElement::COUNT..)] index: usize, + ) { + let _ = ProcessorTable::op_stack_column_by_index(index); + } + + #[test] + fn can_get_stack_weight_for_in_range_index() { + for index in 0..OpStackElement::COUNT { + let _ = stack_weight_by_index(index); + } + } + + #[proptest] + #[should_panic(expected = "[0, 15]")] + fn cannot_get_stack_weight_for_out_of_range_index( + #[strategy(OpStackElement::COUNT..)] index: usize, + ) { + let _ = stack_weight_by_index(index); + } + + #[proptest] + fn xx_product_is_accurate( + #[strategy(arb())] a: XFieldElement, + #[strategy(arb())] b: XFieldElement, + ) { + let circuit_builder = ConstraintCircuitBuilder::new(); + let main_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(BaseRow(col.master_base_table_index())) + }; + let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(main_row); + + let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); + let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); + base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; + base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; + base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; + base_table[[0, ST3.master_base_table_index()]] = b.coefficients[0]; + base_table[[0, ST4.master_base_table_index()]] = b.coefficients[1]; + base_table[[0, ST5.master_base_table_index()]] = b.coefficients[2]; + + let [c0, c1, c2] = xx_product([x0, x1, x2], [y0, y1, y2]) + .map(|c| c.consume()) + .map(|c| c.evaluate(base_table.view(), ext_table.view(), &[])) + .map(|xfe| xfe.unlift().unwrap()); + + let c = a * b; + prop_assert_eq!(c.coefficients[0], c0); + prop_assert_eq!(c.coefficients[1], c1); + prop_assert_eq!(c.coefficients[2], c2); + } + + #[proptest] + fn xb_product_is_accurate( + #[strategy(arb())] a: XFieldElement, + #[strategy(arb())] b: BFieldElement, + ) { + let circuit_builder = ConstraintCircuitBuilder::new(); + let base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(BaseRow(col.master_base_table_index())) + }; + let [x0, x1, x2, y] = [ST0, ST1, ST2, ST3].map(base_row); + + let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); + let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); + base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; + base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; + base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; + base_table[[0, ST3.master_base_table_index()]] = b; + + let [c0, c1, c2] = xb_product([x0, x1, x2], y) + .map(|c| c.consume()) + .map(|c| c.evaluate(base_table.view(), ext_table.view(), &[])) + .map(|xfe| xfe.unlift().unwrap()); + + let c = a * b; + prop_assert_eq!(c.coefficients[0], c0); + prop_assert_eq!(c.coefficients[1], c1); + prop_assert_eq!(c.coefficients[2], c2); + } +} diff --git a/triton-air/src/table/program.rs b/triton-air/src/table/program.rs new file mode 100644 index 000000000..f4e67490b --- /dev/null +++ b/triton-air/src/table/program.rs @@ -0,0 +1,279 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use twenty_first::prelude::*; + +use crate::challenge_id::ChallengeId; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::EvalArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct ProgramTable; + +impl AIR for ProgramTable { + type MainColumn = crate::table_column::ProgramBaseTableColumn; + type AuxColumn = crate::table_column::ProgramExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let x_constant = |xfe| circuit_builder.x_constant(xfe); + let base_row = + |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + let ext_row = + |col: Self::AuxColumn| circuit_builder.input(ExtRow(col.master_ext_table_index())); + + let address = base_row(Self::MainColumn::Address); + let instruction = base_row(Self::MainColumn::Instruction); + let index_in_chunk = base_row(Self::MainColumn::IndexInChunk); + let is_hash_input_padding = base_row(Self::MainColumn::IsHashInputPadding); + let instruction_lookup_log_derivative = + ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + let prepare_chunk_running_evaluation = + ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + let send_chunk_running_evaluation = ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + + let lookup_arg_initial = x_constant(LookupArg::default_initial()); + let eval_arg_initial = x_constant(EvalArg::default_initial()); + + let program_attestation_prepare_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate); + + let first_address_is_zero = address; + let index_in_chunk_is_zero = index_in_chunk; + let hash_input_padding_indicator_is_zero = is_hash_input_padding; + + let instruction_lookup_log_derivative_is_initialized_correctly = + instruction_lookup_log_derivative - lookup_arg_initial; + + let prepare_chunk_running_evaluation_has_absorbed_first_instruction = + prepare_chunk_running_evaluation + - eval_arg_initial.clone() * program_attestation_prepare_chunk_indeterminate + - instruction; + + let send_chunk_running_evaluation_is_default_initial = + send_chunk_running_evaluation - eval_arg_initial; + + vec![ + first_address_is_zero, + index_in_chunk_is_zero, + hash_input_padding_indicator_is_zero, + instruction_lookup_log_derivative_is_initialized_correctly, + prepare_chunk_running_evaluation_has_absorbed_first_instruction, + send_chunk_running_evaluation_is_default_initial, + ] + } + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let main_row = + |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + + let one = constant(1); + let max_index_in_chunk = constant((Tip5::RATE - 1).try_into().unwrap()); + + let index_in_chunk = main_row(Self::MainColumn::IndexInChunk); + let max_minus_index_in_chunk_inv = main_row(Self::MainColumn::MaxMinusIndexInChunkInv); + let is_hash_input_padding = main_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding = main_row(Self::MainColumn::IsTablePadding); + + let max_minus_index_in_chunk = max_index_in_chunk - index_in_chunk; + let max_minus_index_in_chunk_inv_is_zero_or_the_inverse_of_max_minus_index_in_chunk = + (one.clone() - max_minus_index_in_chunk.clone() * max_minus_index_in_chunk_inv.clone()) + * max_minus_index_in_chunk_inv.clone(); + let max_minus_index_in_chunk_is_zero_or_the_inverse_of_max_minus_index_in_chunk_inv = + (one.clone() - max_minus_index_in_chunk.clone() * max_minus_index_in_chunk_inv) + * max_minus_index_in_chunk; + + let is_hash_input_padding_is_bit = + is_hash_input_padding.clone() * (is_hash_input_padding - one.clone()); + let is_table_padding_is_bit = is_table_padding.clone() * (is_table_padding - one); + + vec![ + max_minus_index_in_chunk_inv_is_zero_or_the_inverse_of_max_minus_index_in_chunk, + max_minus_index_in_chunk_is_zero_or_the_inverse_of_max_minus_index_in_chunk_inv, + is_hash_input_padding_is_bit, + is_table_padding_is_bit, + ] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let constant = |c: u64| circuit_builder.b_constant(c); + + let current_base_row = |col: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: Self::MainColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + let current_ext_row = |col: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = + |col: Self::AuxColumn| circuit_builder.input(NextExtRow(col.master_ext_table_index())); + + let one = constant(1); + let rate_minus_one = constant(u64::try_from(Tip5::RATE).unwrap() - 1); + + let prepare_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate); + let send_chunk_indeterminate = + challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate); + + let address = current_base_row(Self::MainColumn::Address); + let instruction = current_base_row(Self::MainColumn::Instruction); + let lookup_multiplicity = current_base_row(Self::MainColumn::LookupMultiplicity); + let index_in_chunk = current_base_row(Self::MainColumn::IndexInChunk); + let max_minus_index_in_chunk_inv = + current_base_row(Self::MainColumn::MaxMinusIndexInChunkInv); + let is_hash_input_padding = current_base_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding = current_base_row(Self::MainColumn::IsTablePadding); + let log_derivative = current_ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + let prepare_chunk_running_evaluation = + current_ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + let send_chunk_running_evaluation = + current_ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + + let address_next = next_base_row(Self::MainColumn::Address); + let instruction_next = next_base_row(Self::MainColumn::Instruction); + let index_in_chunk_next = next_base_row(Self::MainColumn::IndexInChunk); + let max_minus_index_in_chunk_inv_next = + next_base_row(Self::MainColumn::MaxMinusIndexInChunkInv); + let is_hash_input_padding_next = next_base_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding_next = next_base_row(Self::MainColumn::IsTablePadding); + let log_derivative_next = + next_ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + let prepare_chunk_running_evaluation_next = + next_ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + let send_chunk_running_evaluation_next = + next_ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + + let address_increases_by_one = address_next - (address.clone() + one.clone()); + let is_table_padding_is_0_or_remains_unchanged = + is_table_padding.clone() * (is_table_padding_next.clone() - is_table_padding); + + let index_in_chunk_cycles_correctly = (one.clone() + - max_minus_index_in_chunk_inv.clone() + * (rate_minus_one.clone() - index_in_chunk.clone())) + * index_in_chunk_next.clone() + + max_minus_index_in_chunk_inv.clone() + * (index_in_chunk_next.clone() - index_in_chunk.clone() - one.clone()); + + let hash_input_indicator_is_0_or_remains_unchanged = + is_hash_input_padding.clone() * (is_hash_input_padding_next.clone() - one.clone()); + + let first_hash_input_padding_is_1 = (is_hash_input_padding.clone() - one.clone()) + * is_hash_input_padding_next + * (instruction_next.clone() - one.clone()); + + let hash_input_padding_is_0_after_the_first_1 = + is_hash_input_padding.clone() * instruction_next.clone(); + + let next_row_is_table_padding_row = is_table_padding_next.clone() - one.clone(); + let table_padding_starts_when_hash_input_padding_is_active_and_index_in_chunk_is_zero = + is_hash_input_padding.clone() + * (one.clone() + - max_minus_index_in_chunk_inv.clone() + * (rate_minus_one.clone() - index_in_chunk.clone())) + * next_row_is_table_padding_row.clone(); + + let log_derivative_remains = log_derivative_next.clone() - log_derivative.clone(); + let compressed_row = challenge(ChallengeId::ProgramAddressWeight) * address + + challenge(ChallengeId::ProgramInstructionWeight) * instruction + + challenge(ChallengeId::ProgramNextInstructionWeight) * instruction_next.clone(); + + let indeterminate = challenge(ChallengeId::InstructionLookupIndeterminate); + let log_derivative_updates = (log_derivative_next - log_derivative) + * (indeterminate - compressed_row) + - lookup_multiplicity; + let log_derivative_updates_if_and_only_if_not_a_padding_row = + (one.clone() - is_hash_input_padding.clone()) * log_derivative_updates + + is_hash_input_padding * log_derivative_remains; + + let prepare_chunk_running_evaluation_absorbs_next_instruction = + prepare_chunk_running_evaluation_next.clone() + - prepare_chunk_indeterminate.clone() * prepare_chunk_running_evaluation + - instruction_next.clone(); + let prepare_chunk_running_evaluation_resets_and_absorbs_next_instruction = + prepare_chunk_running_evaluation_next.clone() + - prepare_chunk_indeterminate + - instruction_next; + let index_in_chunk_is_max = rate_minus_one.clone() - index_in_chunk.clone(); + let index_in_chunk_is_not_max = + one.clone() - max_minus_index_in_chunk_inv * (rate_minus_one.clone() - index_in_chunk); + let prepare_chunk_running_evaluation_resets_every_rate_rows_and_absorbs_next_instruction = + index_in_chunk_is_max * prepare_chunk_running_evaluation_absorbs_next_instruction + + index_in_chunk_is_not_max + * prepare_chunk_running_evaluation_resets_and_absorbs_next_instruction; + + let send_chunk_running_evaluation_absorbs_next_chunk = send_chunk_running_evaluation_next + .clone() + - send_chunk_indeterminate * send_chunk_running_evaluation.clone() + - prepare_chunk_running_evaluation_next; + let send_chunk_running_evaluation_does_not_change = + send_chunk_running_evaluation_next - send_chunk_running_evaluation; + let index_in_chunk_next_is_max = rate_minus_one - index_in_chunk_next; + let index_in_chunk_next_is_not_max = + one - max_minus_index_in_chunk_inv_next * index_in_chunk_next_is_max.clone(); + + let send_chunk_running_eval_absorbs_chunk_iff_index_in_chunk_next_is_max_and_not_padding_row = + send_chunk_running_evaluation_absorbs_next_chunk + * next_row_is_table_padding_row + * index_in_chunk_next_is_not_max + + send_chunk_running_evaluation_does_not_change.clone() * is_table_padding_next + + send_chunk_running_evaluation_does_not_change * index_in_chunk_next_is_max; + + vec![ + address_increases_by_one, + is_table_padding_is_0_or_remains_unchanged, + index_in_chunk_cycles_correctly, + hash_input_indicator_is_0_or_remains_unchanged, + first_hash_input_padding_is_1, + hash_input_padding_is_0_after_the_first_1, + table_padding_starts_when_hash_input_padding_is_active_and_index_in_chunk_is_zero, + log_derivative_updates_if_and_only_if_not_a_padding_row, + prepare_chunk_running_evaluation_resets_every_rate_rows_and_absorbs_next_instruction, + send_chunk_running_eval_absorbs_chunk_iff_index_in_chunk_next_is_max_and_not_padding_row, + ] + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u64| circuit_builder.b_constant(c); + let main_row = + |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + + let index_in_chunk = main_row(Self::MainColumn::IndexInChunk); + let is_hash_input_padding = main_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding = main_row(Self::MainColumn::IsTablePadding); + + let hash_input_padding_is_one = is_hash_input_padding - constant(1); + + let max_index_in_chunk = Tip5::RATE as u64 - 1; + let index_in_chunk_is_max_or_row_is_padding_row = + (index_in_chunk - constant(max_index_in_chunk)) * (is_table_padding - constant(1)); + + vec![ + hash_input_padding_is_one, + index_in_chunk_is_max_or_row_is_padding_row, + ] + } +} diff --git a/triton-air/src/table/ram.rs b/triton-air/src/table/ram.rs new file mode 100644 index 000000000..ba29141a3 --- /dev/null +++ b/triton-air/src/table/ram.rs @@ -0,0 +1,275 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use twenty_first::prelude::*; + +use crate::challenge_id::ChallengeId; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::LookupArg; +use crate::cross_table_argument::PermArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +pub const INSTRUCTION_TYPE_WRITE: BFieldElement = BFieldElement::new(0); +pub const INSTRUCTION_TYPE_READ: BFieldElement = BFieldElement::new(1); +pub const PADDING_INDICATOR: BFieldElement = BFieldElement::new(2); + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct RamTable; + +impl AIR for RamTable { + type MainColumn = crate::table_column::RamBaseTableColumn; + type AuxColumn = crate::table_column::RamExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let challenge = |c| circuit_builder.challenge(c); + let constant = |c| circuit_builder.b_constant(c); + let x_constant = |c| circuit_builder.x_constant(c); + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let aux_row = |column: Self::AuxColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + + let first_row_is_padding_row = + main_row(Self::MainColumn::InstructionType) - constant(PADDING_INDICATOR); + let first_row_is_not_padding_row = (main_row(Self::MainColumn::InstructionType) + - constant(INSTRUCTION_TYPE_READ)) + * (main_row(Self::MainColumn::InstructionType) - constant(INSTRUCTION_TYPE_WRITE)); + + let bezout_coefficient_polynomial_coefficient_0_is_0 = + main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); + let bezout_coefficient_0_is_0 = aux_row(Self::AuxColumn::BezoutCoefficient0); + let bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1 = + aux_row(Self::AuxColumn::BezoutCoefficient1) + - main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); + let formal_derivative_is_1 = + aux_row(Self::AuxColumn::FormalDerivative) - constant(1_u32.into()); + let running_product_polynomial_is_initialized_correctly = + aux_row(Self::AuxColumn::RunningProductOfRAMP) + - challenge(ChallengeId::RamTableBezoutRelationIndeterminate) + + main_row(Self::MainColumn::RamPointer); + + let clock_jump_diff_log_derivative_is_default_initial = + aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative) + - x_constant(LookupArg::default_initial()); + + let compressed_row_for_permutation_argument = main_row(Self::MainColumn::CLK) + * challenge(ChallengeId::RamClkWeight) + + main_row(Self::MainColumn::InstructionType) + * challenge(ChallengeId::RamInstructionTypeWeight) + + main_row(Self::MainColumn::RamPointer) * challenge(ChallengeId::RamPointerWeight) + + main_row(Self::MainColumn::RamValue) * challenge(ChallengeId::RamValueWeight); + let running_product_permutation_argument_has_accumulated_first_row = + aux_row(Self::AuxColumn::RunningProductPermArg) + - challenge(ChallengeId::RamIndeterminate) + + compressed_row_for_permutation_argument; + let running_product_permutation_argument_is_default_initial = + aux_row(Self::AuxColumn::RunningProductPermArg) + - x_constant(PermArg::default_initial()); + + let running_product_permutation_argument_starts_correctly = + running_product_permutation_argument_has_accumulated_first_row + * first_row_is_padding_row + + running_product_permutation_argument_is_default_initial + * first_row_is_not_padding_row; + + vec![ + bezout_coefficient_polynomial_coefficient_0_is_0, + bezout_coefficient_0_is_0, + bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1, + running_product_polynomial_is_initialized_correctly, + formal_derivative_is_1, + running_product_permutation_argument_starts_correctly, + clock_jump_diff_log_derivative_is_default_initial, + ] + } + + fn consistency_constraints( + _circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + // no further constraints + vec![] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c| circuit_builder.b_constant(c); + let challenge = |c| circuit_builder.challenge(c); + let curr_base_row = |column: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) + }; + let curr_ext_row = |column: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) + }; + let next_base_row = |column: Self::MainColumn| { + circuit_builder.input(NextBaseRow(column.master_base_table_index())) + }; + let next_ext_row = |column: Self::AuxColumn| { + circuit_builder.input(NextExtRow(column.master_ext_table_index())) + }; + + let one = constant(1_u32.into()); + + let bezout_challenge = challenge(ChallengeId::RamTableBezoutRelationIndeterminate); + + let clock = curr_base_row(Self::MainColumn::CLK); + let ram_pointer = curr_base_row(Self::MainColumn::RamPointer); + let ram_value = curr_base_row(Self::MainColumn::RamValue); + let instruction_type = curr_base_row(Self::MainColumn::InstructionType); + let inverse_of_ram_pointer_difference = + curr_base_row(Self::MainColumn::InverseOfRampDifference); + let bcpc0 = curr_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); + let bcpc1 = curr_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer = curr_ext_row(Self::AuxColumn::RunningProductOfRAMP); + let fd = curr_ext_row(Self::AuxColumn::FormalDerivative); + let bc0 = curr_ext_row(Self::AuxColumn::BezoutCoefficient0); + let bc1 = curr_ext_row(Self::AuxColumn::BezoutCoefficient1); + let rppa = curr_ext_row(Self::AuxColumn::RunningProductPermArg); + let clock_jump_diff_log_derivative = + curr_ext_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + + let clock_next = next_base_row(Self::MainColumn::CLK); + let ram_pointer_next = next_base_row(Self::MainColumn::RamPointer); + let ram_value_next = next_base_row(Self::MainColumn::RamValue); + let instruction_type_next = next_base_row(Self::MainColumn::InstructionType); + let bcpc0_next = next_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); + let bcpc1_next = next_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer_next = next_ext_row(Self::AuxColumn::RunningProductOfRAMP); + let fd_next = next_ext_row(Self::AuxColumn::FormalDerivative); + let bc0_next = next_ext_row(Self::AuxColumn::BezoutCoefficient0); + let bc1_next = next_ext_row(Self::AuxColumn::BezoutCoefficient1); + let rppa_next = next_ext_row(Self::AuxColumn::RunningProductPermArg); + let clock_jump_diff_log_derivative_next = + next_ext_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + + let next_row_is_padding_row = + instruction_type_next.clone() - constant(PADDING_INDICATOR).clone(); + let if_current_row_is_padding_row_then_next_row_is_padding_row = (instruction_type.clone() + - constant(INSTRUCTION_TYPE_READ)) + * (instruction_type - constant(INSTRUCTION_TYPE_WRITE)) + * next_row_is_padding_row.clone(); + + let ram_pointer_difference = ram_pointer_next.clone() - ram_pointer; + let ram_pointer_changes = one.clone() + - ram_pointer_difference.clone() * inverse_of_ram_pointer_difference.clone(); + + let iord_is_0_or_iord_is_inverse_of_ram_pointer_difference = + inverse_of_ram_pointer_difference * ram_pointer_changes.clone(); + + let ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference = + ram_pointer_difference.clone() * ram_pointer_changes.clone(); + + let ram_pointer_changes_or_write_mem_or_ram_value_stays = ram_pointer_changes.clone() + * (constant(INSTRUCTION_TYPE_WRITE) - instruction_type_next.clone()) + * (ram_value_next.clone() - ram_value); + + let bcbp0_only_changes_if_ram_pointer_changes = + ram_pointer_changes.clone() * (bcpc0_next.clone() - bcpc0); + + let bcbp1_only_changes_if_ram_pointer_changes = + ram_pointer_changes.clone() * (bcpc1_next.clone() - bcpc1); + + let running_product_ram_pointer_updates_correctly = ram_pointer_difference.clone() + * (running_product_ram_pointer_next.clone() + - running_product_ram_pointer.clone() + * (bezout_challenge.clone() - ram_pointer_next.clone())) + + ram_pointer_changes.clone() + * (running_product_ram_pointer_next - running_product_ram_pointer.clone()); + + let formal_derivative_updates_correctly = ram_pointer_difference.clone() + * (fd_next.clone() + - running_product_ram_pointer + - (bezout_challenge.clone() - ram_pointer_next.clone()) * fd.clone()) + + ram_pointer_changes.clone() * (fd_next - fd); + + let bezout_coefficient_0_is_constructed_correctly = ram_pointer_difference.clone() + * (bc0_next.clone() - bezout_challenge.clone() * bc0.clone() - bcpc0_next) + + ram_pointer_changes.clone() * (bc0_next - bc0); + + let bezout_coefficient_1_is_constructed_correctly = ram_pointer_difference.clone() + * (bc1_next.clone() - bezout_challenge * bc1.clone() - bcpc1_next) + + ram_pointer_changes.clone() * (bc1_next - bc1); + + let compressed_row = clock_next.clone() * challenge(ChallengeId::RamClkWeight) + + ram_pointer_next * challenge(ChallengeId::RamPointerWeight) + + ram_value_next * challenge(ChallengeId::RamValueWeight) + + instruction_type_next.clone() * challenge(ChallengeId::RamInstructionTypeWeight); + let rppa_accumulates_next_row = rppa_next.clone() + - rppa.clone() * (challenge(ChallengeId::RamIndeterminate) - compressed_row); + + let next_row_is_not_padding_row = (instruction_type_next.clone() + - constant(INSTRUCTION_TYPE_READ)) + * (instruction_type_next - constant(INSTRUCTION_TYPE_WRITE)); + let rppa_remains_unchanged = rppa_next - rppa; + + let rppa_updates_correctly = rppa_accumulates_next_row * next_row_is_padding_row.clone() + + rppa_remains_unchanged * next_row_is_not_padding_row.clone(); + + let clock_difference = clock_next - clock; + let log_derivative_accumulates = (clock_jump_diff_log_derivative_next.clone() + - clock_jump_diff_log_derivative.clone()) + * (challenge(ChallengeId::ClockJumpDifferenceLookupIndeterminate) - clock_difference) + - one.clone(); + let log_derivative_remains = + clock_jump_diff_log_derivative_next - clock_jump_diff_log_derivative.clone(); + + let log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row = + log_derivative_accumulates * ram_pointer_changes.clone() * next_row_is_padding_row; + let log_derivative_remains_or_ram_pointer_doesnt_change = + log_derivative_remains.clone() * ram_pointer_difference.clone(); + let log_derivative_remains_or_next_row_is_not_padding_row = + log_derivative_remains * next_row_is_not_padding_row; + + let log_derivative_updates_correctly = + log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row + + log_derivative_remains_or_ram_pointer_doesnt_change + + log_derivative_remains_or_next_row_is_not_padding_row; + + vec![ + if_current_row_is_padding_row_then_next_row_is_padding_row, + iord_is_0_or_iord_is_inverse_of_ram_pointer_difference, + ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference, + ram_pointer_changes_or_write_mem_or_ram_value_stays, + bcbp0_only_changes_if_ram_pointer_changes, + bcbp1_only_changes_if_ram_pointer_changes, + running_product_ram_pointer_updates_correctly, + formal_derivative_updates_correctly, + bezout_coefficient_0_is_constructed_correctly, + bezout_coefficient_1_is_constructed_correctly, + rppa_updates_correctly, + log_derivative_updates_correctly, + ] + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let constant = |c: u32| circuit_builder.b_constant(c); + let ext_row = |column: Self::AuxColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + + let bezout_relation_holds = ext_row(Self::AuxColumn::BezoutCoefficient0) + * ext_row(Self::AuxColumn::RunningProductOfRAMP) + + ext_row(Self::AuxColumn::BezoutCoefficient1) + * ext_row(Self::AuxColumn::FormalDerivative) + - constant(1); + + vec![bezout_relation_holds] + } +} diff --git a/triton-air/src/table/u32.rs b/triton-air/src/table/u32.rs new file mode 100644 index 000000000..ea3c37464 --- /dev/null +++ b/triton-air/src/table/u32.rs @@ -0,0 +1,395 @@ +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::CurrentBaseRow; +use constraint_circuit::DualRowIndicator::CurrentExtRow; +use constraint_circuit::DualRowIndicator::NextBaseRow; +use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::BaseRow; +use constraint_circuit::SingleRowIndicator::ExtRow; +use isa::instruction::Instruction; +use std::ops::Mul; + +use crate::challenge_id::ChallengeId; +use crate::cross_table_argument::CrossTableArg; +use crate::cross_table_argument::LookupArg; +use crate::table_column::MasterBaseTableColumn; +use crate::table_column::MasterExtTableColumn; +use crate::AIR; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct U32Table; + +impl AIR for U32Table { + type MainColumn = crate::table_column::U32BaseTableColumn; + type AuxColumn = crate::table_column::U32ExtTableColumn; + + fn initial_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let aux_row = |column: Self::AuxColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + let challenge = |c| circuit_builder.challenge(c); + let one = circuit_builder.b_constant(1); + + let copy_flag = main_row(Self::MainColumn::CopyFlag); + let lhs = main_row(Self::MainColumn::LHS); + let rhs = main_row(Self::MainColumn::RHS); + let ci = main_row(Self::MainColumn::CI); + let result = main_row(Self::MainColumn::Result); + let lookup_multiplicity = main_row(Self::MainColumn::LookupMultiplicity); + + let running_sum_log_derivative = aux_row(Self::AuxColumn::LookupServerLogDerivative); + + let compressed_row = challenge(ChallengeId::U32LhsWeight) * lhs + + challenge(ChallengeId::U32RhsWeight) * rhs + + challenge(ChallengeId::U32CiWeight) * ci + + challenge(ChallengeId::U32ResultWeight) * result; + let if_copy_flag_is_1_then_log_derivative_has_accumulated_first_row = copy_flag.clone() + * (running_sum_log_derivative.clone() + * (challenge(ChallengeId::U32Indeterminate) - compressed_row) + - lookup_multiplicity); + + let default_initial = circuit_builder.x_constant(LookupArg::default_initial()); + let if_copy_flag_is_0_then_log_derivative_is_default_initial = + (copy_flag - one) * (running_sum_log_derivative - default_initial); + + let running_sum_log_derivative_starts_correctly = + if_copy_flag_is_0_then_log_derivative_is_default_initial + + if_copy_flag_is_1_then_log_derivative_has_accumulated_first_row; + + vec![running_sum_log_derivative_starts_correctly] + } + + fn consistency_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let one = || circuit_builder.b_constant(1); + let two = || circuit_builder.b_constant(2); + + let copy_flag = main_row(Self::MainColumn::CopyFlag); + let bits = main_row(Self::MainColumn::Bits); + let bits_minus_33_inv = main_row(Self::MainColumn::BitsMinus33Inv); + let ci = main_row(Self::MainColumn::CI); + let lhs = main_row(Self::MainColumn::LHS); + let lhs_inv = main_row(Self::MainColumn::LhsInv); + let rhs = main_row(Self::MainColumn::RHS); + let rhs_inv = main_row(Self::MainColumn::RhsInv); + let result = main_row(Self::MainColumn::Result); + let lookup_multiplicity = main_row(Self::MainColumn::LookupMultiplicity); + + let instruction_deselector = |instruction_to_select| { + instruction_deselector(instruction_to_select, circuit_builder, &ci) + }; + + let copy_flag_is_bit = copy_flag.clone() * (one() - copy_flag.clone()); + let copy_flag_is_0_or_bits_is_0 = copy_flag.clone() * bits.clone(); + let bits_minus_33_inv_is_inverse_of_bits_minus_33 = + one() - bits_minus_33_inv * (bits - circuit_builder.b_constant(33)); + let lhs_inv_is_0_or_the_inverse_of_lhs = + lhs_inv.clone() * (one() - lhs.clone() * lhs_inv.clone()); + let lhs_is_0_or_lhs_inverse_is_the_inverse_of_lhs = + lhs.clone() * (one() - lhs.clone() * lhs_inv.clone()); + let rhs_inv_is_0_or_the_inverse_of_rhs = + rhs_inv.clone() * (one() - rhs.clone() * rhs_inv.clone()); + let rhs_is_0_or_rhs_inverse_is_the_inverse_of_rhs = + rhs.clone() * (one() - rhs.clone() * rhs_inv.clone()); + let result_is_initialized_correctly_for_lt_with_copy_flag_0 = + instruction_deselector(Instruction::Lt) + * (copy_flag.clone() - one()) + * (one() - lhs.clone() * lhs_inv.clone()) + * (one() - rhs.clone() * rhs_inv.clone()) + * (result.clone() - two()); + let result_is_initialized_correctly_for_lt_with_copy_flag_1 = + instruction_deselector(Instruction::Lt) + * copy_flag.clone() + * (one() - lhs.clone() * lhs_inv.clone()) + * (one() - rhs.clone() * rhs_inv.clone()) + * result.clone(); + let result_is_initialized_correctly_for_and = instruction_deselector(Instruction::And) + * (one() - lhs.clone() * lhs_inv.clone()) + * (one() - rhs.clone() * rhs_inv.clone()) + * result.clone(); + let result_is_initialized_correctly_for_pow = instruction_deselector(Instruction::Pow) + * (one() - rhs * rhs_inv) + * (result.clone() - one()); + let result_is_initialized_correctly_for_log_2_floor = + instruction_deselector(Instruction::Log2Floor) + * (copy_flag.clone() - one()) + * (one() - lhs.clone() * lhs_inv.clone()) + * (result.clone() + one()); + let result_is_initialized_correctly_for_pop_count = + instruction_deselector(Instruction::PopCount) + * (one() - lhs.clone() * lhs_inv.clone()) + * result; + let if_log_2_floor_on_0_then_vm_crashes = instruction_deselector(Instruction::Log2Floor) + * copy_flag.clone() + * (one() - lhs * lhs_inv); + let if_copy_flag_is_0_then_lookup_multiplicity_is_0 = + (copy_flag - one()) * lookup_multiplicity; + + vec![ + copy_flag_is_bit, + copy_flag_is_0_or_bits_is_0, + bits_minus_33_inv_is_inverse_of_bits_minus_33, + lhs_inv_is_0_or_the_inverse_of_lhs, + lhs_is_0_or_lhs_inverse_is_the_inverse_of_lhs, + rhs_inv_is_0_or_the_inverse_of_rhs, + rhs_is_0_or_rhs_inverse_is_the_inverse_of_rhs, + result_is_initialized_correctly_for_lt_with_copy_flag_0, + result_is_initialized_correctly_for_lt_with_copy_flag_1, + result_is_initialized_correctly_for_and, + result_is_initialized_correctly_for_pow, + result_is_initialized_correctly_for_log_2_floor, + result_is_initialized_correctly_for_pop_count, + if_log_2_floor_on_0_then_vm_crashes, + if_copy_flag_is_0_then_lookup_multiplicity_is_0, + ] + } + + fn transition_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let curr_main_row = |column: Self::MainColumn| { + circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) + }; + let next_main_row = |column: Self::MainColumn| { + circuit_builder.input(NextBaseRow(column.master_base_table_index())) + }; + let curr_aux_row = |column: Self::AuxColumn| { + circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) + }; + let next_aux_row = |column: Self::AuxColumn| { + circuit_builder.input(NextExtRow(column.master_ext_table_index())) + }; + let challenge = |c| circuit_builder.challenge(c); + let one = || circuit_builder.b_constant(1); + let two = || circuit_builder.b_constant(2); + + let copy_flag = curr_main_row(Self::MainColumn::CopyFlag); + let bits = curr_main_row(Self::MainColumn::Bits); + let ci = curr_main_row(Self::MainColumn::CI); + let lhs = curr_main_row(Self::MainColumn::LHS); + let rhs = curr_main_row(Self::MainColumn::RHS); + let result = curr_main_row(Self::MainColumn::Result); + let running_sum_log_derivative = curr_aux_row(Self::AuxColumn::LookupServerLogDerivative); + + let copy_flag_next = next_main_row(Self::MainColumn::CopyFlag); + let bits_next = next_main_row(Self::MainColumn::Bits); + let ci_next = next_main_row(Self::MainColumn::CI); + let lhs_next = next_main_row(Self::MainColumn::LHS); + let rhs_next = next_main_row(Self::MainColumn::RHS); + let result_next = next_main_row(Self::MainColumn::Result); + let lhs_inv_next = next_main_row(Self::MainColumn::LhsInv); + let lookup_multiplicity_next = next_main_row(Self::MainColumn::LookupMultiplicity); + let running_sum_log_derivative_next = + next_aux_row(Self::AuxColumn::LookupServerLogDerivative); + + let instruction_deselector = |instruction_to_select: Instruction| { + instruction_deselector(instruction_to_select, circuit_builder, &ci_next) + }; + + // helpful aliases + let ci_is_pow = ci.clone() - circuit_builder.b_constant(Instruction::Pow.opcode_b()); + let lhs_lsb = lhs.clone() - two() * lhs_next.clone(); + let rhs_lsb = rhs.clone() - two() * rhs_next.clone(); + + // general constraints + let if_copy_flag_next_is_1_then_lhs_is_0_or_ci_is_pow = + copy_flag_next.clone() * lhs.clone() * ci_is_pow.clone(); + let if_copy_flag_next_is_1_then_rhs_is_0 = copy_flag_next.clone() * rhs.clone(); + let if_copy_flag_next_is_0_then_ci_stays = + (copy_flag_next.clone() - one()) * (ci_next.clone() - ci); + let if_copy_flag_next_is_0_and_lhs_next_is_nonzero_and_ci_not_pow_then_bits_increases_by_1 = + (copy_flag_next.clone() - one()) + * lhs.clone() + * ci_is_pow.clone() + * (bits_next.clone() - bits.clone() - one()); + let if_copy_flag_next_is_0_and_rhs_next_is_nonzero_then_bits_increases_by_1 = + (copy_flag_next.clone() - one()) * rhs * (bits_next - bits.clone() - one()); + let if_copy_flag_next_is_0_and_ci_not_pow_then_lhs_lsb_is_a_bit = (copy_flag_next.clone() + - one()) + * ci_is_pow + * lhs_lsb.clone() + * (lhs_lsb.clone() - one()); + let if_copy_flag_next_is_0_then_rhs_lsb_is_a_bit = + (copy_flag_next.clone() - one()) * rhs_lsb.clone() * (rhs_lsb.clone() - one()); + + // instruction lt + let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_0_then_result_is_0 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * (result_next.clone() - one()) + * (result_next.clone() - two()) + * result.clone(); + let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_1_then_result_is_1 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * result_next.clone() + * (result_next.clone() - two()) + * (result.clone() - one()); + let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_0_then_result_is_0 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * result_next.clone() + * (result_next.clone() - one()) + * (lhs_lsb.clone() - one()) + * rhs_lsb.clone() + * (result.clone() - one()); + let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_1_then_result_is_1 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * result_next.clone() + * (result_next.clone() - one()) + * lhs_lsb.clone() + * (rhs_lsb.clone() - one()) + * result.clone(); + let if_copy_flag_next_is_0_and_ci_is_lt_and_result_still_not_known_then_result_is_2 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * result_next.clone() + * (result_next.clone() - one()) + * (one() - lhs_lsb.clone() - rhs_lsb.clone() + + two() * lhs_lsb.clone() * rhs_lsb.clone()) + * (copy_flag.clone() - one()) + * (result.clone() - two()); + let if_copy_flag_next_is_0_and_ci_is_lt_and_copy_flag_dictates_result_then_result_is_0 = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Lt) + * result_next.clone() + * (result_next.clone() - one()) + * (one() - lhs_lsb.clone() - rhs_lsb.clone() + + two() * lhs_lsb.clone() * rhs_lsb.clone()) + * copy_flag + * result.clone(); + + // instruction and + let if_copy_flag_next_is_0_and_ci_is_and_then_results_updates_correctly = (copy_flag_next + .clone() + - one()) + * instruction_deselector(Instruction::And) + * (result.clone() - two() * result_next.clone() - lhs_lsb.clone() * rhs_lsb.clone()); + + // instruction log_2_floor + let if_copy_flag_next_is_0_and_ci_is_log_2_floor_lhs_next_0_for_first_time_then_set_result = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Log2Floor) + * (one() - lhs_next.clone() * lhs_inv_next) + * lhs.clone() + * (result.clone() - bits); + let if_copy_flag_next_is_0_and_ci_is_log_2_floor_and_lhs_next_not_0_then_copy_result = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Log2Floor) + * lhs_next.clone() + * (result_next.clone() - result.clone()); + + // instruction pow + let if_copy_flag_next_is_0_and_ci_is_pow_then_lhs_remains_unchanged = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Pow) + * (lhs_next.clone() - lhs.clone()); + + let if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_0_then_result_squares = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Pow) + * (rhs_lsb.clone() - one()) + * (result.clone() - result_next.clone() * result_next.clone()); + + let if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_1_then_result_squares_and_mults = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::Pow) + * rhs_lsb + * (result.clone() - result_next.clone() * result_next.clone() * lhs); + + let if_copy_flag_next_is_0_and_ci_is_pop_count_then_result_increases_by_lhs_lsb = + (copy_flag_next.clone() - one()) + * instruction_deselector(Instruction::PopCount) + * (result - result_next.clone() - lhs_lsb); + + // running sum for Lookup Argument with Processor Table + let if_copy_flag_next_is_0_then_running_sum_log_derivative_stays = (copy_flag_next.clone() + - one()) + * (running_sum_log_derivative_next.clone() - running_sum_log_derivative.clone()); + + let compressed_row_next = challenge(ChallengeId::U32CiWeight) * ci_next + + challenge(ChallengeId::U32LhsWeight) * lhs_next + + challenge(ChallengeId::U32RhsWeight) * rhs_next + + challenge(ChallengeId::U32ResultWeight) * result_next; + let if_copy_flag_next_is_1_then_running_sum_log_derivative_accumulates_next_row = + copy_flag_next + * ((running_sum_log_derivative_next - running_sum_log_derivative) + * (challenge(ChallengeId::U32Indeterminate) - compressed_row_next) + - lookup_multiplicity_next); + + vec![ + if_copy_flag_next_is_1_then_lhs_is_0_or_ci_is_pow, + if_copy_flag_next_is_1_then_rhs_is_0, + if_copy_flag_next_is_0_then_ci_stays, + if_copy_flag_next_is_0_and_lhs_next_is_nonzero_and_ci_not_pow_then_bits_increases_by_1, + if_copy_flag_next_is_0_and_rhs_next_is_nonzero_then_bits_increases_by_1, + if_copy_flag_next_is_0_and_ci_not_pow_then_lhs_lsb_is_a_bit, + if_copy_flag_next_is_0_then_rhs_lsb_is_a_bit, + if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_0_then_result_is_0, + if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_1_then_result_is_1, + if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_0_then_result_is_0, + if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_1_then_result_is_1, + if_copy_flag_next_is_0_and_ci_is_lt_and_result_still_not_known_then_result_is_2, + if_copy_flag_next_is_0_and_ci_is_lt_and_copy_flag_dictates_result_then_result_is_0, + if_copy_flag_next_is_0_and_ci_is_and_then_results_updates_correctly, + if_copy_flag_next_is_0_and_ci_is_log_2_floor_lhs_next_0_for_first_time_then_set_result, + if_copy_flag_next_is_0_and_ci_is_log_2_floor_and_lhs_next_not_0_then_copy_result, + if_copy_flag_next_is_0_and_ci_is_pow_then_lhs_remains_unchanged, + if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_0_then_result_squares, + if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_1_then_result_squares_and_mults, + if_copy_flag_next_is_0_and_ci_is_pop_count_then_result_increases_by_lhs_lsb, + if_copy_flag_next_is_0_then_running_sum_log_derivative_stays, + if_copy_flag_next_is_1_then_running_sum_log_derivative_accumulates_next_row, + ] + } + + fn terminal_constraints( + circuit_builder: &ConstraintCircuitBuilder, + ) -> Vec> { + let main_row = |column: Self::MainColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let constant = |c| circuit_builder.b_constant(c); + + let ci = main_row(Self::MainColumn::CI); + let lhs = main_row(Self::MainColumn::LHS); + let rhs = main_row(Self::MainColumn::RHS); + + let lhs_is_0_or_ci_is_pow = lhs * (ci - constant(Instruction::Pow.opcode_b())); + let rhs_is_0 = rhs; + + vec![lhs_is_0_or_ci_is_pow, rhs_is_0] + } +} + +fn instruction_deselector( + instruction_to_select: Instruction, + circuit_builder: &ConstraintCircuitBuilder, + current_instruction: &ConstraintCircuitMonad, +) -> ConstraintCircuitMonad { + [ + Instruction::Split, + Instruction::Lt, + Instruction::And, + Instruction::Log2Floor, + Instruction::Pow, + Instruction::PopCount, + ] + .into_iter() + .filter(|&instruction| instruction != instruction_to_select) + .map(|instr| current_instruction.clone() - circuit_builder.b_constant(instr.opcode_b())) + .fold(circuit_builder.b_constant(1), ConstraintCircuitMonad::mul) +} diff --git a/triton-vm/src/table/table_column.rs b/triton-air/src/table_column.rs similarity index 76% rename from triton-vm/src/table/table_column.rs rename to triton-air/src/table_column.rs index e1792a42a..7031da9c8 100644 --- a/triton-vm/src/table/table_column.rs +++ b/triton-air/src/table_column.rs @@ -1,5 +1,5 @@ -//! Enums that convert table column names into `usize` indices. Allows addressing columns by name -//! rather than their hard-to-remember index. +//! Enums that convert table column names into `usize` indices. Allows +//! addressing columns by name rather than their hard-to-remember index. use std::hash::Hash; @@ -7,28 +7,24 @@ use strum::Display; use strum::EnumCount; use strum::EnumIter; -use crate::table::degree_lowering_table::DegreeLoweringBaseTableColumn; -use crate::table::degree_lowering_table::DegreeLoweringExtTableColumn; -use crate::table::master_table::CASCADE_TABLE_START; -use crate::table::master_table::DEGREE_LOWERING_TABLE_START; -use crate::table::master_table::EXT_CASCADE_TABLE_START; -use crate::table::master_table::EXT_DEGREE_LOWERING_TABLE_START; -use crate::table::master_table::EXT_HASH_TABLE_START; -use crate::table::master_table::EXT_JUMP_STACK_TABLE_START; -use crate::table::master_table::EXT_LOOKUP_TABLE_START; -use crate::table::master_table::EXT_OP_STACK_TABLE_START; -use crate::table::master_table::EXT_PROCESSOR_TABLE_START; -use crate::table::master_table::EXT_PROGRAM_TABLE_START; -use crate::table::master_table::EXT_RAM_TABLE_START; -use crate::table::master_table::EXT_U32_TABLE_START; -use crate::table::master_table::HASH_TABLE_START; -use crate::table::master_table::JUMP_STACK_TABLE_START; -use crate::table::master_table::LOOKUP_TABLE_START; -use crate::table::master_table::OP_STACK_TABLE_START; -use crate::table::master_table::PROCESSOR_TABLE_START; -use crate::table::master_table::PROGRAM_TABLE_START; -use crate::table::master_table::RAM_TABLE_START; -use crate::table::master_table::U32_TABLE_START; +use crate::table::CASCADE_TABLE_START; +use crate::table::EXT_CASCADE_TABLE_START; +use crate::table::EXT_HASH_TABLE_START; +use crate::table::EXT_JUMP_STACK_TABLE_START; +use crate::table::EXT_LOOKUP_TABLE_START; +use crate::table::EXT_OP_STACK_TABLE_START; +use crate::table::EXT_PROCESSOR_TABLE_START; +use crate::table::EXT_PROGRAM_TABLE_START; +use crate::table::EXT_RAM_TABLE_START; +use crate::table::EXT_U32_TABLE_START; +use crate::table::HASH_TABLE_START; +use crate::table::JUMP_STACK_TABLE_START; +use crate::table::LOOKUP_TABLE_START; +use crate::table::OP_STACK_TABLE_START; +use crate::table::PROCESSOR_TABLE_START; +use crate::table::PROGRAM_TABLE_START; +use crate::table::RAM_TABLE_START; +use crate::table::U32_TABLE_START; #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] @@ -608,18 +604,6 @@ impl MasterBaseTableColumn for U32BaseTableColumn { } } -impl MasterBaseTableColumn for DegreeLoweringBaseTableColumn { - #[inline] - fn base_table_index(&self) -> usize { - (*self) as usize - } - - #[inline] - fn master_base_table_index(&self) -> usize { - DEGREE_LOWERING_TABLE_START + self.base_table_index() - } -} - /// A trait for the columns in the master extension table. This trait is implemented for all enums /// relating to the extension tables. The trait provides two methods: /// - one to get the index of the column in the “local” extension table, _i.e._, not the master @@ -742,189 +726,12 @@ impl MasterExtTableColumn for U32ExtTableColumn { } } -impl MasterExtTableColumn for DegreeLoweringExtTableColumn { - #[inline] - fn ext_table_index(&self) -> usize { - (*self) as usize - } - - #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_DEGREE_LOWERING_TABLE_START + self.ext_table_index() - } -} - #[cfg(test)] mod tests { use strum::IntoEnumIterator; - use crate::table::cascade_table; - use crate::table::hash_table; - use crate::table::jump_stack_table; - use crate::table::lookup_table; - use crate::table::op_stack_table; - use crate::table::processor_table; - use crate::table::program_table; - use crate::table::ram_table; - use crate::table::u32_table; - use super::*; - #[test] - fn column_max_bound_matches_table_width() { - assert_eq!( - program_table::BASE_WIDTH, - ProgramBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "ProgramTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - processor_table::BASE_WIDTH, - ProcessorBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "ProcessorTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - op_stack_table::BASE_WIDTH, - OpStackBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "OpStackTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - ram_table::BASE_WIDTH, - RamBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "RamTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - jump_stack_table::BASE_WIDTH, - JumpStackBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "JumpStackTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - hash_table::BASE_WIDTH, - HashBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "HashTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - cascade_table::BASE_WIDTH, - CascadeBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "CascadeTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - lookup_table::BASE_WIDTH, - LookupBaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "LookupTable's BASE_WIDTH is 1 + its max column index", - ); - assert_eq!( - u32_table::BASE_WIDTH, - U32BaseTableColumn::iter() - .last() - .unwrap() - .base_table_index() - + 1, - "U32Table's BASE_WIDTH is 1 + its max column index", - ); - - assert_eq!( - program_table::EXT_WIDTH, - ProgramExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "ProgramTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - processor_table::EXT_WIDTH, - ProcessorExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "ProcessorTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - op_stack_table::EXT_WIDTH, - OpStackExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "OpStack:Table's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - ram_table::EXT_WIDTH, - RamExtTableColumn::iter().last().unwrap().ext_table_index() + 1, - "RamTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - jump_stack_table::EXT_WIDTH, - JumpStackExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "JumpStack:Table's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - hash_table::EXT_WIDTH, - HashExtTableColumn::iter().last().unwrap().ext_table_index() + 1, - "HashTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - cascade_table::EXT_WIDTH, - CascadeExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "CascadeTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - lookup_table::EXT_WIDTH, - LookupExtTableColumn::iter() - .last() - .unwrap() - .ext_table_index() - + 1, - "LookupTable's EXT_WIDTH is 1 + its max column index", - ); - assert_eq!( - u32_table::EXT_WIDTH, - U32ExtTableColumn::iter().last().unwrap().ext_table_index() + 1, - "U32Table's EXT_WIDTH is 1 + its max column index", - ); - } - #[test] fn master_base_table_is_contiguous() { let mut expected_column_index = 0; @@ -964,10 +771,6 @@ mod tests { assert_eq!(expected_column_index, column.master_base_table_index()); expected_column_index += 1; } - for column in DegreeLoweringBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); - expected_column_index += 1; - } } #[test] @@ -1009,9 +812,5 @@ mod tests { assert_eq!(expected_column_index, column.master_ext_table_index()); expected_column_index += 1; } - for column in DegreeLoweringExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); - expected_column_index += 1; - } } } diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 94b94fc22..31622881e 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -18,6 +18,7 @@ repository.workspace = true readme.workspace = true [dependencies] +air.workspace = true arbitrary.workspace = true colored.workspace = true constraint-circuit.workspace = true diff --git a/triton-vm/benches/bezout_coeffs.rs b/triton-vm/benches/bezout_coeffs.rs index 4427ef33d..8c62ebb59 100644 --- a/triton-vm/benches/bezout_coeffs.rs +++ b/triton-vm/benches/bezout_coeffs.rs @@ -6,7 +6,7 @@ use num_traits::Zero; use twenty_first::prelude::*; use triton_vm::prelude::*; -use triton_vm::table::ram_table::RamTable; +use triton_vm::table::ram::bezout_coefficient_polynomials_coefficients; criterion_main!(benches); criterion_group!( @@ -32,7 +32,7 @@ fn current_design(c: &mut Criterion) { let roots = unique_roots::(); let bench_id = format!("Bézout coefficients (current design) (degree {N})"); c.bench_function(&bench_id, |b| { - b.iter(|| RamTable::bezout_coefficient_polynomials_coefficients(&roots)) + b.iter(|| bezout_coefficient_polynomials_coefficients(&roots)) }); } diff --git a/triton-vm/benches/prove_halt.rs b/triton-vm/benches/prove_halt.rs index 598304677..9f14e4a54 100644 --- a/triton-vm/benches/prove_halt.rs +++ b/triton-vm/benches/prove_halt.rs @@ -1,9 +1,9 @@ +use air::table::TableId; use criterion::criterion_group; use criterion::criterion_main; use criterion::Criterion; use triton_vm::prelude::*; -use triton_vm::table::master_table::TableId; criterion_main!(benches); diff --git a/triton-vm/benches/verify_halt.rs b/triton-vm/benches/verify_halt.rs index 0f83fe2a1..445ed0ded 100644 --- a/triton-vm/benches/verify_halt.rs +++ b/triton-vm/benches/verify_halt.rs @@ -1,6 +1,5 @@ use criterion::criterion_group; use criterion::criterion_main; -use criterion::BenchmarkId; use criterion::Criterion; use triton_vm::prelude::*; @@ -11,12 +10,13 @@ fn verify_halt(criterion: &mut Criterion) { let stark = Stark::default(); let claim = Claim::about_program(&program); - let (aet, _) = VM::trace_execution(&program, [].into(), [].into()).unwrap(); let proof = stark.prove(&claim, &aet).unwrap(); triton_vm::profiler::start("Verify Halt"); - stark.verify(&claim, &proof).unwrap(); + criterion.bench_function("Verify Halt", |b| { + b.iter(|| stark.verify(&claim, &proof).unwrap()) + }); let profile = triton_vm::profiler::finish(); let padded_height = proof.padded_height().unwrap(); @@ -25,23 +25,12 @@ fn verify_halt(criterion: &mut Criterion) { .with_cycle_count(aet.processor_trace.nrows()) .with_padded_height(padded_height) .with_fri_domain_len(fri.domain.length); - - let bench_id = BenchmarkId::new("VerifyHalt", 0); - let mut group = criterion.benchmark_group("verify_halt"); - group.sample_size(10); - group.bench_function(bench_id, |bencher| { - bencher.iter(|| { - let _ = stark.verify(&claim, &proof); - }); - }); - group.finish(); - eprintln!("{profile}"); } criterion_group! { name = benches; - config = Criterion::default(); + config = Criterion::default().sample_size(10); targets = verify_halt } diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index e0abb5c89..38a2997cf 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -3,6 +3,15 @@ use std::collections::hash_map::Entry::Vacant; use std::collections::HashMap; use std::ops::AddAssign; +use air::table::hash::HashTable; +use air::table::hash::PermutationTrace; +use air::table::op_stack; +use air::table::processor; +use air::table::ram; +use air::table::TableId; +use air::table_column::HashBaseTableColumn::CI; +use air::table_column::MasterBaseTableColumn; +use air::AIR; use arbitrary::Arbitrary; use isa::error::InstructionError; use isa::error::InstructionError::InstructionPointerOverflow; @@ -12,18 +21,14 @@ use itertools::Itertools; use ndarray::s; use ndarray::Array2; use ndarray::Axis; +use strum::EnumCount; use strum::IntoEnumIterator; use twenty_first::prelude::*; -use crate::table::hash_table::HashTable; -use crate::table::hash_table::PermutationTrace; -use crate::table::master_table::TableId; -use crate::table::op_stack_table::OpStackTableEntry; -use crate::table::ram_table::RamTableCall; -use crate::table::table_column::HashBaseTableColumn::CI; -use crate::table::table_column::MasterBaseTableColumn; -use crate::table::u32_table::U32TableEntry; -use crate::table::*; +use crate::table; +use crate::table::op_stack::OpStackTableEntry; +use crate::table::ram::RamTableCall; +use crate::table::u32::U32TableEntry; use crate::vm::CoProcessorCall; use crate::vm::CoProcessorCall::*; use crate::vm::VMState; @@ -88,17 +93,22 @@ impl AlgebraicExecutionTrace { pub(crate) const LOOKUP_TABLE_HEIGHT: usize = 1 << 8; pub fn new(program: Program) -> Self { + const PROCESSOR_WIDTH: usize = ::MainColumn::COUNT; + const OP_STACK_WIDTH: usize = ::MainColumn::COUNT; + const RAM_WIDTH: usize = ::MainColumn::COUNT; + const HASH_WIDTH: usize = ::MainColumn::COUNT; + let program_len = program.len_bwords(); let mut aet = Self { program, instruction_multiplicities: vec![0_u32; program_len], - processor_trace: Array2::default([0, processor_table::BASE_WIDTH]), - op_stack_underflow_trace: Array2::default([0, op_stack_table::BASE_WIDTH]), - ram_trace: Array2::default([0, ram_table::BASE_WIDTH]), - program_hash_trace: Array2::default([0, hash_table::BASE_WIDTH]), - hash_trace: Array2::default([0, hash_table::BASE_WIDTH]), - sponge_trace: Array2::default([0, hash_table::BASE_WIDTH]), + processor_trace: Array2::default([0, PROCESSOR_WIDTH]), + op_stack_underflow_trace: Array2::default([0, OP_STACK_WIDTH]), + ram_trace: Array2::default([0, RAM_WIDTH]), + program_hash_trace: Array2::default([0, HASH_WIDTH]), + hash_trace: Array2::default([0, HASH_WIDTH]), + sponge_trace: Array2::default([0, HASH_WIDTH]), u32_entries: HashMap::new(), cascade_table_lookup_multiplicities: HashMap::new(), lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT], @@ -121,9 +131,10 @@ impl AlgebraicExecutionTrace { /// /// [pad]: master_table::MasterBaseTable::pad pub fn height(&self) -> TableHeight { - let relevant_tables = TableId::iter().filter(|&t| t != TableId::DegreeLowering); - let heights = relevant_tables.map(|t| TableHeight::new(t, self.height_of_table(t))); - heights.max().unwrap() + TableId::iter() + .map(|t| TableHeight::new(t, self.height_of_table(t))) + .max() + .unwrap() } pub fn height_of_table(&self, table: TableId) -> usize { @@ -141,7 +152,6 @@ impl AlgebraicExecutionTrace { TableId::Cascade => self.cascade_table_lookup_multiplicities.len(), TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT, TableId::U32 => self.u32_table_height(), - TableId::DegreeLowering => self.height().height, } } @@ -173,7 +183,7 @@ impl AlgebraicExecutionTrace { .zip_eq(chunk_to_absorb) .for_each(|(sponge_state_elem, &absorb_elem)| *sponge_state_elem = absorb_elem); let hash_trace = program_sponge.trace(); - let trace_addendum = HashTable::trace_to_table_rows(hash_trace); + let trace_addendum = table::hash::trace_to_table_rows(hash_trace); self.increase_lookup_multiplicities(hash_trace); self.program_hash_trace @@ -242,7 +252,7 @@ impl AlgebraicExecutionTrace { fn append_hash_trace(&mut self, trace: PermutationTrace) { self.increase_lookup_multiplicities(trace); - let mut hash_trace_addendum = HashTable::trace_to_table_rows(trace); + let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace); hash_trace_addendum .slice_mut(s![.., CI.base_table_index()]) .fill(Instruction::Hash.opcode_b()); @@ -254,7 +264,7 @@ impl AlgebraicExecutionTrace { fn append_initial_sponge_state(&mut self) { let round_number = 0; let initial_state = Tip5::init().state; - let mut hash_table_row = HashTable::trace_row_to_table_row(initial_state, round_number); + let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number); hash_table_row[CI.base_table_index()] = Instruction::SpongeInit.opcode_b(); self.sponge_trace.push_row(hash_table_row.view()).unwrap(); } @@ -265,7 +275,7 @@ impl AlgebraicExecutionTrace { Instruction::SpongeAbsorb | Instruction::SpongeSqueeze )); self.increase_lookup_multiplicities(trace); - let mut sponge_trace_addendum = HashTable::trace_to_table_rows(trace); + let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace); sponge_trace_addendum .slice_mut(s![.., CI.base_table_index()]) .fill(instruction.opcode_b()); @@ -298,7 +308,7 @@ impl AlgebraicExecutionTrace { /// Given one state element, increase the multiplicities of the corresponding entries in the /// cascade table and/or the lookup table. fn increase_lookup_multiplicities_for_state_element(&mut self, state_element: BFieldElement) { - for limb in HashTable::base_field_element_into_16_bit_limbs(state_element) { + for limb in table::hash::base_field_element_into_16_bit_limbs(state_element) { match self.cascade_table_lookup_multiplicities.entry(limb) { Occupied(mut cascade_table_entry) => *cascade_table_entry.get_mut() += 1, Vacant(cascade_table_entry) => { diff --git a/triton-vm/src/air.rs b/triton-vm/src/air.rs index 00c403a15..4d1fcf1ae 100644 --- a/triton-vm/src/air.rs +++ b/triton-vm/src/air.rs @@ -4,6 +4,8 @@ pub mod tasm_air_constraints; #[cfg(test)] mod test { + use air::table::NUM_BASE_COLUMNS; + use air::table::NUM_EXT_COLUMNS; use isa::instruction::AnInstruction; use itertools::Itertools; use ndarray::Array1; @@ -17,13 +19,11 @@ mod test { use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; + use crate::challenges::Challenges; use crate::prelude::*; - use crate::table::challenges::Challenges; use crate::table::extension_table::Evaluable; use crate::table::extension_table::Quotientable; use crate::table::master_table::MasterExtTable; - use crate::table::NUM_BASE_COLUMNS; - use crate::table::NUM_EXT_COLUMNS; use super::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; use super::memory_layout::IntegralMemoryLayout; diff --git a/triton-vm/src/air/memory_layout.rs b/triton-vm/src/air/memory_layout.rs index 7a8e4ff58..c19879b0f 100644 --- a/triton-vm/src/air/memory_layout.rs +++ b/triton-vm/src/air/memory_layout.rs @@ -1,10 +1,10 @@ +use air::table::NUM_BASE_COLUMNS; +use air::table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; use itertools::Itertools; use twenty_first::prelude::*; -use crate::table::challenges::Challenges; -use crate::table::NUM_BASE_COLUMNS; -use crate::table::NUM_EXT_COLUMNS; +use crate::challenges::Challenges; /// The minimal required size of a memory page in [`BFieldElement`]s. pub const MEM_PAGE_SIZE: usize = 1 << 32; diff --git a/triton-vm/src/challenges.rs b/triton-vm/src/challenges.rs new file mode 100644 index 000000000..ca2e99c5f --- /dev/null +++ b/triton-vm/src/challenges.rs @@ -0,0 +1,186 @@ +//! Challenges are needed for the [cross-table arguments](CrossTableArg), _i.e._, +//! [Permutation Arguments](crate::cross_table_argument::PermArg), +//! [Evaluation Arguments](crate::cross_table_argument::EvalArg), and +//! [Lookup Arguments](crate::cross_table_argument::LookupArg), +//! as well as for the RAM Table's Contiguity Argument. +//! +//! There are three types of challenges: +//! - **Weights**. Weights are used to linearly combine multiple elements into one element. The +//! resulting single element can then be used in a cross-table argument. +//! - **Indeterminates**. All cross-table arguments work by checking the equality of polynomials (or +//! rational functions). Through the Schwartz-Zippel lemma, this equality check can be performed +//! by evaluating the polynomials (or rational functions) in a single point. The challenges that +//! are indeterminates are exactly this evaluation point. The polynomials (or rational functions) +//! are never stored explicitly. Instead, they are directly evaluated at the point indicated by a +//! challenge of “type” `Indeterminate`, giving rise to “running products”, “running +//! evaluations”, _et cetera_. +//! - **Terminals**. The public input (respectively output) of the program is not stored in any +//! table. Instead, the terminal of the Evaluation Argument is computed directly from the +//! public input (respectively output) and the indeterminate. + +use arbitrary::Arbitrary; +use std::ops::Index; +use std::ops::Range; +use std::ops::RangeInclusive; +use strum::EnumCount; +use twenty_first::math::tip5; +use twenty_first::prelude::*; + +use air::challenge_id::ChallengeId; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::EvalArg; + +use crate::prelude::Claim; + +/// The `Challenges` struct holds the challenges used in Triton VM. The concrete +/// challenges are known only at runtime. The challenges are indexed using enum +/// [`ChallengeId`]. The `Challenges` struct is essentially a thin wrapper +/// around an array of [`XFieldElement`]s, providing convenience methods. +#[derive(Debug, Clone, Arbitrary)] +pub struct Challenges { + pub challenges: [XFieldElement; Self::COUNT], +} + +impl Challenges { + /// The total number of challenges used in Triton VM. + pub const COUNT: usize = ChallengeId::COUNT; + + /// The number of weights to sample using the Fiat-Shamir heuristic. This number is lower + /// than the number of challenges because several challenges are not sampled, but computed + /// from publicly known values and other, sampled challenges. + /// + /// Concretely: + /// - The [`StandardInputTerminal`] is computed from Triton VM's public input and the sampled + /// indeterminate [`StandardInputIndeterminate`]. + /// - The [`StandardOutputTerminal`] is computed from Triton VM's public output and the sampled + /// indeterminate [`StandardOutputIndeterminate`]. + /// - The [`LookupTablePublicTerminal`] is computed from the publicly known and constant + /// lookup table and the sampled indeterminate [`LookupTablePublicIndeterminate`]. + /// - The [`CompressedProgramDigest`] is computed from the program to be executed and the + /// sampled indeterminate [`CompressProgramDigestIndeterminate`]. + pub const SAMPLE_COUNT: usize = Self::COUNT - ChallengeId::NUM_DERIVED_CHALLENGES; + + pub fn new(mut challenges: Vec, claim: &Claim) -> Self { + assert_eq!(Self::SAMPLE_COUNT, challenges.len()); + + let compressed_digest = EvalArg::compute_terminal( + &claim.program_digest.values(), + EvalArg::default_initial(), + challenges[ChallengeId::CompressProgramDigestIndeterminate.index()], + ); + let input_terminal = EvalArg::compute_terminal( + &claim.input, + EvalArg::default_initial(), + challenges[ChallengeId::StandardInputIndeterminate.index()], + ); + let output_terminal = EvalArg::compute_terminal( + &claim.output, + EvalArg::default_initial(), + challenges[ChallengeId::StandardOutputIndeterminate.index()], + ); + let lookup_terminal = EvalArg::compute_terminal( + &tip5::LOOKUP_TABLE.map(BFieldElement::from), + EvalArg::default_initial(), + challenges[ChallengeId::LookupTablePublicIndeterminate.index()], + ); + + challenges.insert(ChallengeId::StandardInputTerminal.index(), input_terminal); + challenges.insert(ChallengeId::StandardOutputTerminal.index(), output_terminal); + challenges.insert( + ChallengeId::LookupTablePublicTerminal.index(), + lookup_terminal, + ); + challenges.insert( + ChallengeId::CompressedProgramDigest.index(), + compressed_digest, + ); + assert_eq!(Self::COUNT, challenges.len()); + let challenges = challenges.try_into().unwrap(); + + Self { challenges } + } +} + +impl Index for Challenges { + type Output = XFieldElement; + + fn index(&self, id: usize) -> &Self::Output { + &self.challenges[id] + } +} + +impl Index> for Challenges { + type Output = [XFieldElement]; + + fn index(&self, indices: Range) -> &Self::Output { + &self.challenges[indices.start..indices.end] + } +} + +impl Index> for Challenges { + type Output = [XFieldElement]; + + fn index(&self, indices: RangeInclusive) -> &Self::Output { + &self.challenges[*indices.start()..=*indices.end()] + } +} + +impl Index for Challenges { + type Output = XFieldElement; + + fn index(&self, id: ChallengeId) -> &Self::Output { + &self[id.index()] + } +} + +impl Index> for Challenges { + type Output = [XFieldElement]; + + fn index(&self, indices: Range) -> &Self::Output { + &self[indices.start.index()..indices.end.index()] + } +} + +impl Index> for Challenges { + type Output = [XFieldElement]; + + fn index(&self, indices: RangeInclusive) -> &Self::Output { + &self[indices.start().index()..=indices.end().index()] + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prelude::Claim; + use twenty_first::xfe; + + // For testing purposes only. + impl Default for Challenges { + fn default() -> Self { + Self::placeholder(&Claim::default()) + } + } + + impl Challenges { + /// Stand-in challenges for use in tests. For non-interactive STARKs, use the + /// Fiat-Shamir heuristic to derive the actual challenges. + pub fn placeholder(claim: &Claim) -> Self { + let stand_in_challenges = (1..=Self::SAMPLE_COUNT) + .map(|i| xfe!([42, i as u64, 24])) + .collect(); + Self::new(stand_in_challenges, claim) + } + } + + #[test] + fn various_challenge_indexing_operations_are_possible() { + let challenges = Challenges::default(); + let _ = challenges[ChallengeId::StackWeight0]; + let _ = challenges[ChallengeId::StackWeight0..ChallengeId::StackWeight8]; + let _ = challenges[ChallengeId::StackWeight0..=ChallengeId::StackWeight8]; + let _ = challenges[0]; + let _ = challenges[0..8]; + let _ = challenges[0..=8]; + } +} diff --git a/triton-vm/src/execution_trace_profiler.rs b/triton-vm/src/execution_trace_profiler.rs index d05eaf3e7..5c01df9bd 100644 --- a/triton-vm/src/execution_trace_profiler.rs +++ b/triton-vm/src/execution_trace_profiler.rs @@ -6,11 +6,11 @@ use std::ops::Add; use std::ops::AddAssign; use std::ops::Sub; +use air::table::hash::PERMUTATION_TRACE_LENGTH; use arbitrary::Arbitrary; use twenty_first::prelude::*; -use crate::table::hash_table::PERMUTATION_TRACE_LENGTH; -use crate::table::u32_table::U32TableEntry; +use crate::table::u32::U32TableEntry; use crate::vm::CoProcessorCall; #[derive(Debug, Default, Clone, Eq, PartialEq, Arbitrary)] diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 6bdac5ea3..32652675f 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -167,6 +167,7 @@ use crate::prelude::*; pub mod aet; pub mod air; pub mod arithmetic_domain; +pub mod challenges; mod codegen; pub mod config; pub mod error; @@ -329,58 +330,15 @@ mod tests { implements_auto_traits::(); // table things - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index bd7da17ec..ddeb32862 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1,6 +1,8 @@ use std::ops::Mul; use std::ops::MulAssign; +use air::table::NUM_BASE_COLUMNS; +use air::table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; use arbitrary::Unstructured; use itertools::izip; @@ -18,6 +20,7 @@ use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; use crate::arithmetic_domain::ArithmeticDomain; +use crate::challenges::Challenges; use crate::error::ProvingError; use crate::error::VerificationError; use crate::fri; @@ -27,7 +30,6 @@ use crate::proof::Claim; use crate::proof::Proof; use crate::proof_item::ProofItem; use crate::proof_stream::ProofStream; -use crate::table::challenges::Challenges; use crate::table::extension_table::Evaluable; use crate::table::extension_table::Quotientable; use crate::table::master_table::all_quotients_combined; @@ -36,14 +38,11 @@ use crate::table::master_table::max_degree_with_origin; use crate::table::master_table::MasterBaseTable; use crate::table::master_table::MasterExtTable; use crate::table::master_table::MasterTable; -use crate::table::master_table::AIR_TARGET_DEGREE; use crate::table::QuotientSegments; -use crate::table::NUM_BASE_COLUMNS; -use crate::table::NUM_EXT_COLUMNS; /// The number of segments the quotient polynomial is split into. /// Helps keeping the FRI domain small. -pub const NUM_QUOTIENT_SEGMENTS: usize = AIR_TARGET_DEGREE as usize; +pub const NUM_QUOTIENT_SEGMENTS: usize = air::TARGET_DEGREE as usize; /// The number of randomizer polynomials over the [extension field](XFieldElement) used in the /// [`STARK`](Stark). Integral for achieving zero-knowledge in [FRI](Fri). @@ -567,8 +566,8 @@ impl Stark { /// length of the execution trace and the FRI expansion factor, a security parameter. /// /// In principle, the FRI domain is also influenced by the AIR's degree - /// (see [`AIR_TARGET_DEGREE`]). However, by segmenting the quotient polynomial into - /// [`AIR_TARGET_DEGREE`]-many parts, that influence is mitigated. + /// (see [`TARGET_DEGREE`]). However, by segmenting the quotient polynomial into + /// [`TARGET_DEGREE`]-many parts, that influence is mitigated. pub fn derive_fri(&self, padded_height: usize) -> fri::SetupResult { let interpolant_degree = interpolant_degree(padded_height, self.num_trace_randomizers); let interpolant_codeword_length = interpolant_degree as usize + 1; @@ -1310,6 +1309,30 @@ pub(crate) mod tests { use std::collections::HashMap; use std::collections::HashSet; + use air::challenge_id::ChallengeId::StandardInputIndeterminate; + use air::challenge_id::ChallengeId::StandardOutputIndeterminate; + use air::cross_table_argument::CrossTableArg; + use air::cross_table_argument::EvalArg; + use air::cross_table_argument::GrandCrossTableArg; + use air::table::cascade::CascadeTable; + use air::table::hash::HashTable; + use air::table::jump_stack::JumpStackTable; + use air::table::lookup::LookupTable; + use air::table::op_stack::OpStackTable; + use air::table::processor::ProcessorTable; + use air::table::program::ProgramTable; + use air::table::ram; + use air::table::ram::RamTable; + use air::table::u32::U32Table; + use air::table::TableId; + use air::table_column::MasterBaseTableColumn; + use air::table_column::MasterExtTableColumn; + use air::table_column::OpStackBaseTableColumn; + use air::table_column::ProcessorBaseTableColumn; + use air::table_column::ProcessorExtTableColumn::InputTableEvalArg; + use air::table_column::ProcessorExtTableColumn::OutputTableEvalArg; + use air::table_column::RamBaseTableColumn; + use air::AIR; use assert2::assert; use assert2::check; use assert2::let_assert; @@ -1332,33 +1355,10 @@ pub(crate) mod tests { use crate::error::InstructionError; use crate::example_programs::*; use crate::shared_tests::*; - use crate::table::cascade_table::ExtCascadeTable; - use crate::table::challenges::ChallengeId::StandardInputIndeterminate; - use crate::table::challenges::ChallengeId::StandardOutputIndeterminate; - use crate::table::cross_table_argument::CrossTableArg; - use crate::table::cross_table_argument::EvalArg; - use crate::table::cross_table_argument::GrandCrossTableArg; use crate::table::extension_table; use crate::table::extension_table::Evaluable; use crate::table::extension_table::Quotientable; - use crate::table::hash_table::ExtHashTable; - use crate::table::jump_stack_table::ExtJumpStackTable; - use crate::table::lookup_table::ExtLookupTable; use crate::table::master_table::MasterExtTable; - use crate::table::master_table::TableId; - use crate::table::op_stack_table::ExtOpStackTable; - use crate::table::processor_table::ExtProcessorTable; - use crate::table::program_table::ExtProgramTable; - use crate::table::ram_table; - use crate::table::ram_table::ExtRamTable; - use crate::table::table_column::MasterBaseTableColumn; - use crate::table::table_column::MasterExtTableColumn; - use crate::table::table_column::OpStackBaseTableColumn; - use crate::table::table_column::ProcessorBaseTableColumn; - use crate::table::table_column::ProcessorExtTableColumn::InputTableEvalArg; - use crate::table::table_column::ProcessorExtTableColumn::OutputTableEvalArg; - use crate::table::table_column::RamBaseTableColumn; - use crate::table::u32_table::ExtU32Table; use crate::triton_program; use crate::vm::tests::*; use crate::vm::NonDeterminism; @@ -1469,9 +1469,9 @@ pub(crate) mod tests { let instruction_type = match row[RamBaseTableColumn::InstructionType.base_table_index()] { - ram_table::INSTRUCTION_TYPE_READ => "read", - ram_table::INSTRUCTION_TYPE_WRITE => "write", - ram_table::PADDING_INDICATOR => "pad", + ram::INSTRUCTION_TYPE_READ => "read", + ram::INSTRUCTION_TYPE_WRITE => "write", + ram::PADDING_INDICATOR => "pad", _ => "-", } .to_string(); @@ -1656,57 +1656,57 @@ pub(crate) mod tests { ]; let circuit_builder = ConstraintCircuitBuilder::new(); let all_init = [ - ExtProgramTable::initial_constraints(&circuit_builder), - ExtProcessorTable::initial_constraints(&circuit_builder), - ExtOpStackTable::initial_constraints(&circuit_builder), - ExtRamTable::initial_constraints(&circuit_builder), - ExtJumpStackTable::initial_constraints(&circuit_builder), - ExtHashTable::initial_constraints(&circuit_builder), - ExtCascadeTable::initial_constraints(&circuit_builder), - ExtLookupTable::initial_constraints(&circuit_builder), - ExtU32Table::initial_constraints(&circuit_builder), + ProgramTable::initial_constraints(&circuit_builder), + ProcessorTable::initial_constraints(&circuit_builder), + OpStackTable::initial_constraints(&circuit_builder), + RamTable::initial_constraints(&circuit_builder), + JumpStackTable::initial_constraints(&circuit_builder), + HashTable::initial_constraints(&circuit_builder), + CascadeTable::initial_constraints(&circuit_builder), + LookupTable::initial_constraints(&circuit_builder), + U32Table::initial_constraints(&circuit_builder), GrandCrossTableArg::initial_constraints(&circuit_builder), ] .map(|vec| vec.len()); let circuit_builder = ConstraintCircuitBuilder::new(); let all_cons = [ - ExtProgramTable::consistency_constraints(&circuit_builder), - ExtProcessorTable::consistency_constraints(&circuit_builder), - ExtOpStackTable::consistency_constraints(&circuit_builder), - ExtRamTable::consistency_constraints(&circuit_builder), - ExtJumpStackTable::consistency_constraints(&circuit_builder), - ExtHashTable::consistency_constraints(&circuit_builder), - ExtCascadeTable::consistency_constraints(&circuit_builder), - ExtLookupTable::consistency_constraints(&circuit_builder), - ExtU32Table::consistency_constraints(&circuit_builder), + ProgramTable::consistency_constraints(&circuit_builder), + ProcessorTable::consistency_constraints(&circuit_builder), + OpStackTable::consistency_constraints(&circuit_builder), + RamTable::consistency_constraints(&circuit_builder), + JumpStackTable::consistency_constraints(&circuit_builder), + HashTable::consistency_constraints(&circuit_builder), + CascadeTable::consistency_constraints(&circuit_builder), + LookupTable::consistency_constraints(&circuit_builder), + U32Table::consistency_constraints(&circuit_builder), GrandCrossTableArg::consistency_constraints(&circuit_builder), ] .map(|vec| vec.len()); let circuit_builder = ConstraintCircuitBuilder::new(); let all_trans = [ - ExtProgramTable::transition_constraints(&circuit_builder), - ExtProcessorTable::transition_constraints(&circuit_builder), - ExtOpStackTable::transition_constraints(&circuit_builder), - ExtRamTable::transition_constraints(&circuit_builder), - ExtJumpStackTable::transition_constraints(&circuit_builder), - ExtHashTable::transition_constraints(&circuit_builder), - ExtCascadeTable::transition_constraints(&circuit_builder), - ExtLookupTable::transition_constraints(&circuit_builder), - ExtU32Table::transition_constraints(&circuit_builder), + ProgramTable::transition_constraints(&circuit_builder), + ProcessorTable::transition_constraints(&circuit_builder), + OpStackTable::transition_constraints(&circuit_builder), + RamTable::transition_constraints(&circuit_builder), + JumpStackTable::transition_constraints(&circuit_builder), + HashTable::transition_constraints(&circuit_builder), + CascadeTable::transition_constraints(&circuit_builder), + LookupTable::transition_constraints(&circuit_builder), + U32Table::transition_constraints(&circuit_builder), GrandCrossTableArg::transition_constraints(&circuit_builder), ] .map(|vec| vec.len()); let circuit_builder = ConstraintCircuitBuilder::new(); let all_term = [ - ExtProgramTable::terminal_constraints(&circuit_builder), - ExtProcessorTable::terminal_constraints(&circuit_builder), - ExtOpStackTable::terminal_constraints(&circuit_builder), - ExtRamTable::terminal_constraints(&circuit_builder), - ExtJumpStackTable::terminal_constraints(&circuit_builder), - ExtHashTable::terminal_constraints(&circuit_builder), - ExtCascadeTable::terminal_constraints(&circuit_builder), - ExtLookupTable::terminal_constraints(&circuit_builder), - ExtU32Table::terminal_constraints(&circuit_builder), + ProgramTable::terminal_constraints(&circuit_builder), + ProcessorTable::terminal_constraints(&circuit_builder), + OpStackTable::terminal_constraints(&circuit_builder), + RamTable::terminal_constraints(&circuit_builder), + JumpStackTable::terminal_constraints(&circuit_builder), + HashTable::terminal_constraints(&circuit_builder), + CascadeTable::terminal_constraints(&circuit_builder), + LookupTable::terminal_constraints(&circuit_builder), + U32Table::terminal_constraints(&circuit_builder), GrandCrossTableArg::terminal_constraints(&circuit_builder), ] .map(|vec| vec.len()); @@ -2178,15 +2178,15 @@ pub(crate) mod tests { }; } - check_constraints_fn!(fn check_program_table_constraints for ExtProgramTable); - check_constraints_fn!(fn check_processor_table_constraints for ExtProcessorTable); - check_constraints_fn!(fn check_op_stack_table_constraints for ExtOpStackTable); - check_constraints_fn!(fn check_ram_table_constraints for ExtRamTable); - check_constraints_fn!(fn check_jump_stack_table_constraints for ExtJumpStackTable); - check_constraints_fn!(fn check_hash_table_constraints for ExtHashTable); - check_constraints_fn!(fn check_cascade_table_constraints for ExtCascadeTable); - check_constraints_fn!(fn check_lookup_table_constraints for ExtLookupTable); - check_constraints_fn!(fn check_u32_table_constraints for ExtU32Table); + check_constraints_fn!(fn check_program_table_constraints for ProgramTable); + check_constraints_fn!(fn check_processor_table_constraints for ProcessorTable); + check_constraints_fn!(fn check_op_stack_table_constraints for OpStackTable); + check_constraints_fn!(fn check_ram_table_constraints for RamTable); + check_constraints_fn!(fn check_jump_stack_table_constraints for JumpStackTable); + check_constraints_fn!(fn check_hash_table_constraints for HashTable); + check_constraints_fn!(fn check_cascade_table_constraints for CascadeTable); + check_constraints_fn!(fn check_lookup_table_constraints for LookupTable); + check_constraints_fn!(fn check_u32_table_constraints for U32Table); check_constraints_fn!(fn check_cross_table_constraints for GrandCrossTableArg); fn triton_constraints_evaluate_to_zero(program_and_input: ProgramAndInput) { diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index c1c15dd4a..8bb7e8321 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -1,47 +1,69 @@ -pub use crate::stark::NUM_QUOTIENT_SEGMENTS; -pub use crate::table::master_table::NUM_BASE_COLUMNS; -pub use crate::table::master_table::NUM_EXT_COLUMNS; - +use air::cross_table_argument::GrandCrossTableArg; +use air::table::cascade::CascadeTable; +use air::table::hash::HashTable; +use air::table::jump_stack::JumpStackTable; +use air::table::lookup::LookupTable; +use air::table::op_stack::OpStackTable; +use air::table::processor::ProcessorTable; +use air::table::program::ProgramTable; +use air::table::ram::RamTable; +use air::table::u32::U32Table; +use air::table::NUM_BASE_COLUMNS; +use air::table::NUM_EXT_COLUMNS; +use air::AIR; use arbitrary::Arbitrary; use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; use constraint_circuit::SingleRowIndicator; +use ndarray::ArrayView2; +use ndarray::ArrayViewMut2; use strum::Display; use strum::EnumCount; use strum::EnumIter; -use twenty_first::prelude::XFieldElement; +use twenty_first::prelude::*; +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; use crate::codegen::Constraints; -use crate::table::cascade_table::ExtCascadeTable; -use crate::table::cross_table_argument::GrandCrossTableArg; -use crate::table::hash_table::ExtHashTable; -use crate::table::jump_stack_table::ExtJumpStackTable; -use crate::table::lookup_table::ExtLookupTable; -use crate::table::op_stack_table::ExtOpStackTable; -use crate::table::processor_table::ExtProcessorTable; -use crate::table::program_table::ExtProgramTable; -use crate::table::ram_table::ExtRamTable; -use crate::table::u32_table::ExtU32Table; - -pub mod cascade_table; -pub mod challenges; +pub use crate::stark::NUM_QUOTIENT_SEGMENTS; + #[rustfmt::skip] pub mod constraints; -pub mod cross_table_argument; #[rustfmt::skip] pub mod degree_lowering_table; + +pub mod cascade; pub mod extension_table; -pub mod hash_table; -pub mod jump_stack_table; -pub mod lookup_table; +pub mod hash; +pub mod jump_stack; +pub mod lookup; pub mod master_table; -pub mod op_stack_table; -pub mod processor_table; -pub mod program_table; -pub mod ram_table; -pub mod table_column; -pub mod u32_table; +pub mod op_stack; +pub mod processor; +pub mod program; +pub mod ram; +pub mod u32; + +trait TraceTable: AIR { + // a nicer design is in order + type FillParam; + type FillReturnInfo; + + fn fill( + main_table: ArrayViewMut2, + aet: &AlgebraicExecutionTrace, + _: Self::FillParam, + ) -> Self::FillReturnInfo; + + fn pad(main_table: ArrayViewMut2, table_length: usize); + + fn extend( + base_table: ArrayView2, + ext_table: ArrayViewMut2, + challenges: &Challenges, + ); +} #[derive( Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, EnumCount, EnumIter, @@ -62,12 +84,11 @@ pub enum ConstraintType { /// A single row of a [`MasterBaseTable`][table]. /// -/// Usually, the elements in the table are [`BFieldElement`][bfe]s. For out-of-domain rows, which is +/// Usually, the elements in the table are [`BFieldElement`]s. For out-of-domain rows, which is /// relevant for “Domain Extension to Eliminate Pretenders” (DEEP), the elements are /// [`XFieldElement`]s. /// /// [table]: master_table::MasterBaseTable -/// [bfe]: crate::prelude::BFieldElement pub type BaseRow = [T; NUM_BASE_COLUMNS]; /// A single row of a [`MasterExtensionTable`][table]. @@ -92,15 +113,15 @@ pub(crate) fn constraints() -> Constraints { fn initial_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); vec![ - ExtProgramTable::initial_constraints(&circuit_builder), - ExtProcessorTable::initial_constraints(&circuit_builder), - ExtOpStackTable::initial_constraints(&circuit_builder), - ExtRamTable::initial_constraints(&circuit_builder), - ExtJumpStackTable::initial_constraints(&circuit_builder), - ExtHashTable::initial_constraints(&circuit_builder), - ExtCascadeTable::initial_constraints(&circuit_builder), - ExtLookupTable::initial_constraints(&circuit_builder), - ExtU32Table::initial_constraints(&circuit_builder), + ProgramTable::initial_constraints(&circuit_builder), + ProcessorTable::initial_constraints(&circuit_builder), + OpStackTable::initial_constraints(&circuit_builder), + RamTable::initial_constraints(&circuit_builder), + JumpStackTable::initial_constraints(&circuit_builder), + HashTable::initial_constraints(&circuit_builder), + CascadeTable::initial_constraints(&circuit_builder), + LookupTable::initial_constraints(&circuit_builder), + U32Table::initial_constraints(&circuit_builder), GrandCrossTableArg::initial_constraints(&circuit_builder), ] .concat() @@ -109,15 +130,15 @@ fn initial_constraints() -> Vec> { fn consistency_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); vec![ - ExtProgramTable::consistency_constraints(&circuit_builder), - ExtProcessorTable::consistency_constraints(&circuit_builder), - ExtOpStackTable::consistency_constraints(&circuit_builder), - ExtRamTable::consistency_constraints(&circuit_builder), - ExtJumpStackTable::consistency_constraints(&circuit_builder), - ExtHashTable::consistency_constraints(&circuit_builder), - ExtCascadeTable::consistency_constraints(&circuit_builder), - ExtLookupTable::consistency_constraints(&circuit_builder), - ExtU32Table::consistency_constraints(&circuit_builder), + ProgramTable::consistency_constraints(&circuit_builder), + ProcessorTable::consistency_constraints(&circuit_builder), + OpStackTable::consistency_constraints(&circuit_builder), + RamTable::consistency_constraints(&circuit_builder), + JumpStackTable::consistency_constraints(&circuit_builder), + HashTable::consistency_constraints(&circuit_builder), + CascadeTable::consistency_constraints(&circuit_builder), + LookupTable::consistency_constraints(&circuit_builder), + U32Table::consistency_constraints(&circuit_builder), GrandCrossTableArg::consistency_constraints(&circuit_builder), ] .concat() @@ -126,15 +147,15 @@ fn consistency_constraints() -> Vec> fn transition_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); vec![ - ExtProgramTable::transition_constraints(&circuit_builder), - ExtProcessorTable::transition_constraints(&circuit_builder), - ExtOpStackTable::transition_constraints(&circuit_builder), - ExtRamTable::transition_constraints(&circuit_builder), - ExtJumpStackTable::transition_constraints(&circuit_builder), - ExtHashTable::transition_constraints(&circuit_builder), - ExtCascadeTable::transition_constraints(&circuit_builder), - ExtLookupTable::transition_constraints(&circuit_builder), - ExtU32Table::transition_constraints(&circuit_builder), + ProgramTable::transition_constraints(&circuit_builder), + ProcessorTable::transition_constraints(&circuit_builder), + OpStackTable::transition_constraints(&circuit_builder), + RamTable::transition_constraints(&circuit_builder), + JumpStackTable::transition_constraints(&circuit_builder), + HashTable::transition_constraints(&circuit_builder), + CascadeTable::transition_constraints(&circuit_builder), + LookupTable::transition_constraints(&circuit_builder), + U32Table::transition_constraints(&circuit_builder), GrandCrossTableArg::transition_constraints(&circuit_builder), ] .concat() @@ -143,15 +164,15 @@ fn transition_constraints() -> Vec> { fn terminal_constraints() -> Vec> { let circuit_builder = ConstraintCircuitBuilder::new(); vec![ - ExtProgramTable::terminal_constraints(&circuit_builder), - ExtProcessorTable::terminal_constraints(&circuit_builder), - ExtOpStackTable::terminal_constraints(&circuit_builder), - ExtRamTable::terminal_constraints(&circuit_builder), - ExtJumpStackTable::terminal_constraints(&circuit_builder), - ExtHashTable::terminal_constraints(&circuit_builder), - ExtCascadeTable::terminal_constraints(&circuit_builder), - ExtLookupTable::terminal_constraints(&circuit_builder), - ExtU32Table::terminal_constraints(&circuit_builder), + ProgramTable::terminal_constraints(&circuit_builder), + ProcessorTable::terminal_constraints(&circuit_builder), + OpStackTable::terminal_constraints(&circuit_builder), + RamTable::terminal_constraints(&circuit_builder), + JumpStackTable::terminal_constraints(&circuit_builder), + HashTable::terminal_constraints(&circuit_builder), + CascadeTable::terminal_constraints(&circuit_builder), + LookupTable::terminal_constraints(&circuit_builder), + U32Table::terminal_constraints(&circuit_builder), GrandCrossTableArg::terminal_constraints(&circuit_builder), ] .concat() @@ -161,6 +182,26 @@ fn terminal_constraints() -> Vec> { mod tests { use std::collections::HashMap; + use air::table::hash::HashTable; + use air::table::op_stack::OpStackTable; + use air::table::CASCADE_TABLE_END; + use air::table::EXT_CASCADE_TABLE_END; + use air::table::EXT_HASH_TABLE_END; + use air::table::EXT_JUMP_STACK_TABLE_END; + use air::table::EXT_LOOKUP_TABLE_END; + use air::table::EXT_OP_STACK_TABLE_END; + use air::table::EXT_PROCESSOR_TABLE_END; + use air::table::EXT_PROGRAM_TABLE_END; + use air::table::EXT_RAM_TABLE_END; + use air::table::EXT_U32_TABLE_END; + use air::table::HASH_TABLE_END; + use air::table::JUMP_STACK_TABLE_END; + use air::table::LOOKUP_TABLE_END; + use air::table::OP_STACK_TABLE_END; + use air::table::PROCESSOR_TABLE_END; + use air::table::PROGRAM_TABLE_END; + use air::table::RAM_TABLE_END; + use air::table::U32_TABLE_END; use constraint_circuit::BinOp; use constraint_circuit::CircuitExpression; use constraint_circuit::ConstraintCircuit; @@ -177,28 +218,9 @@ mod tests { use rand_core::SeedableRng; use twenty_first::prelude::BFieldElement; + use crate::challenges::Challenges; use crate::prelude::Claim; - use crate::table::challenges::Challenges; use crate::table::degree_lowering_table::DegreeLoweringTable; - use crate::table::master_table::AIR_TARGET_DEGREE; - use crate::table::master_table::CASCADE_TABLE_END; - use crate::table::master_table::EXT_CASCADE_TABLE_END; - use crate::table::master_table::EXT_HASH_TABLE_END; - use crate::table::master_table::EXT_JUMP_STACK_TABLE_END; - use crate::table::master_table::EXT_LOOKUP_TABLE_END; - use crate::table::master_table::EXT_OP_STACK_TABLE_END; - use crate::table::master_table::EXT_PROCESSOR_TABLE_END; - use crate::table::master_table::EXT_PROGRAM_TABLE_END; - use crate::table::master_table::EXT_RAM_TABLE_END; - use crate::table::master_table::EXT_U32_TABLE_END; - use crate::table::master_table::HASH_TABLE_END; - use crate::table::master_table::JUMP_STACK_TABLE_END; - use crate::table::master_table::LOOKUP_TABLE_END; - use crate::table::master_table::OP_STACK_TABLE_END; - use crate::table::master_table::PROCESSOR_TABLE_END; - use crate::table::master_table::PROGRAM_TABLE_END; - use crate::table::master_table::RAM_TABLE_END; - use crate::table::master_table::U32_TABLE_END; use super::*; @@ -304,15 +326,15 @@ mod tests { }}; } - assert_constraint_properties!(ExtProcessorTable); - assert_constraint_properties!(ExtProgramTable); - assert_constraint_properties!(ExtJumpStackTable); - assert_constraint_properties!(ExtOpStackTable); - assert_constraint_properties!(ExtRamTable); - assert_constraint_properties!(ExtHashTable); - assert_constraint_properties!(ExtU32Table); - assert_constraint_properties!(ExtCascadeTable); - assert_constraint_properties!(ExtLookupTable); + assert_constraint_properties!(ProcessorTable); + assert_constraint_properties!(ProgramTable); + assert_constraint_properties!(JumpStackTable); + assert_constraint_properties!(OpStackTable); + assert_constraint_properties!(RamTable); + assert_constraint_properties!(HashTable); + assert_constraint_properties!(U32Table); + assert_constraint_properties!(CascadeTable); + assert_constraint_properties!(LookupTable); } /// Like [`ConstraintCircuitMonad::lower_to_degree`] with additional assertion of expected @@ -458,7 +480,7 @@ mod tests { macro_rules! assert_degree_lowering { ($table:ident ($base_end:ident, $ext_end:ident)) => {{ let degree_lowering_info = DegreeLoweringInfo { - target_degree: AIR_TARGET_DEGREE, + target_degree: air::TARGET_DEGREE, num_base_cols: $base_end, num_ext_cols: $ext_end, }; @@ -480,21 +502,18 @@ mod tests { }}; } - assert_degree_lowering!(ExtProgramTable(PROGRAM_TABLE_END, EXT_PROGRAM_TABLE_END)); - assert_degree_lowering!(ExtProcessorTable( - PROCESSOR_TABLE_END, - EXT_PROCESSOR_TABLE_END - )); - assert_degree_lowering!(ExtOpStackTable(OP_STACK_TABLE_END, EXT_OP_STACK_TABLE_END)); - assert_degree_lowering!(ExtRamTable(RAM_TABLE_END, EXT_RAM_TABLE_END)); - assert_degree_lowering!(ExtJumpStackTable( + assert_degree_lowering!(ProgramTable(PROGRAM_TABLE_END, EXT_PROGRAM_TABLE_END)); + assert_degree_lowering!(ProcessorTable(PROCESSOR_TABLE_END, EXT_PROCESSOR_TABLE_END)); + assert_degree_lowering!(OpStackTable(OP_STACK_TABLE_END, EXT_OP_STACK_TABLE_END)); + assert_degree_lowering!(RamTable(RAM_TABLE_END, EXT_RAM_TABLE_END)); + assert_degree_lowering!(JumpStackTable( JUMP_STACK_TABLE_END, EXT_JUMP_STACK_TABLE_END )); - assert_degree_lowering!(ExtHashTable(HASH_TABLE_END, EXT_HASH_TABLE_END)); - assert_degree_lowering!(ExtCascadeTable(CASCADE_TABLE_END, EXT_CASCADE_TABLE_END)); - assert_degree_lowering!(ExtLookupTable(LOOKUP_TABLE_END, EXT_LOOKUP_TABLE_END)); - assert_degree_lowering!(ExtU32Table(U32_TABLE_END, EXT_U32_TABLE_END)); + assert_degree_lowering!(HashTable(HASH_TABLE_END, EXT_HASH_TABLE_END)); + assert_degree_lowering!(CascadeTable(CASCADE_TABLE_END, EXT_CASCADE_TABLE_END)); + assert_degree_lowering!(LookupTable(LOOKUP_TABLE_END, EXT_LOOKUP_TABLE_END)); + assert_degree_lowering!(U32Table(U32_TABLE_END, EXT_U32_TABLE_END)); } /// Fills the derived columns of the degree-lowering table using randomly generated rows and diff --git a/triton-vm/src/table/cascade.rs b/triton-vm/src/table/cascade.rs new file mode 100644 index 000000000..a12eb5d18 --- /dev/null +++ b/triton-vm/src/table/cascade.rs @@ -0,0 +1,124 @@ +use air::challenge_id::ChallengeId; +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::LookupArg; +use air::table::cascade::CascadeTable; +use air::table_column::CascadeBaseTableColumn; +use air::table_column::CascadeBaseTableColumn::*; +use air::table_column::CascadeExtTableColumn; +use air::table_column::CascadeExtTableColumn::*; +use air::table_column::MasterBaseTableColumn; +use air::table_column::MasterExtTableColumn; +use air::AIR; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; +use ndarray::s; +use ndarray::ArrayView2; +use ndarray::ArrayViewMut2; +use num_traits::ConstOne; +use num_traits::One; +use strum::EnumCount; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::profiler::profiler; +use crate::table::TraceTable; + +fn lookup_8_bit_limb(to_look_up: u8) -> BFieldElement { + tip5::LOOKUP_TABLE[usize::from(to_look_up)].into() +} + +pub(crate) fn lookup_16_bit_limb(to_look_up: u16) -> BFieldElement { + let to_look_up_lo = (to_look_up & 0xff) as u8; + let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; + let looked_up_lo = lookup_8_bit_limb(to_look_up_lo); + let looked_up_hi = lookup_8_bit_limb(to_look_up_hi); + bfe!(1 << 8) * looked_up_hi + looked_up_lo +} + +impl TraceTable for CascadeTable { + type FillParam = (); + type FillReturnInfo = (); + + fn fill(mut main_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { + for (row_idx, (&to_look_up, &multiplicity)) in + aet.cascade_table_lookup_multiplicities.iter().enumerate() + { + let to_look_up_lo = (to_look_up & 0xff) as u8; + let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; + + let mut row = main_table.row_mut(row_idx); + row[LookInLo.base_table_index()] = bfe!(to_look_up_lo); + row[LookInHi.base_table_index()] = bfe!(to_look_up_hi); + row[LookOutLo.base_table_index()] = lookup_8_bit_limb(to_look_up_lo); + row[LookOutHi.base_table_index()] = lookup_8_bit_limb(to_look_up_hi); + row[LookupMultiplicity.base_table_index()] = bfe!(multiplicity); + } + } + + fn pad(mut main_table: ArrayViewMut2, cascade_table_length: usize) { + main_table + .slice_mut(s![cascade_table_length.., IsPadding.base_table_index()]) + .fill(BFieldElement::ONE); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "cascade table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let mut hash_table_log_derivative = LookupArg::default_initial(); + let mut lookup_table_log_derivative = LookupArg::default_initial(); + + let two_pow_8 = bfe!(1 << 8); + + let hash_indeterminate = challenges[HashCascadeLookupIndeterminate]; + let hash_input_weight = challenges[HashCascadeLookInWeight]; + let hash_output_weight = challenges[HashCascadeLookOutWeight]; + + let lookup_indeterminate = challenges[CascadeLookupIndeterminate]; + let lookup_input_weight = challenges[LookupTableInputWeight]; + let lookup_output_weight = challenges[LookupTableOutputWeight]; + + for row_idx in 0..main_table.nrows() { + let base_row = main_table.row(row_idx); + let is_padding = base_row[IsPadding.base_table_index()].is_one(); + + if !is_padding { + let look_in = two_pow_8 * base_row[LookInHi.base_table_index()] + + base_row[LookInLo.base_table_index()]; + let look_out = two_pow_8 * base_row[LookOutHi.base_table_index()] + + base_row[LookOutLo.base_table_index()]; + let compressed_row_hash = + hash_input_weight * look_in + hash_output_weight * look_out; + let lookup_multiplicity = base_row[LookupMultiplicity.base_table_index()]; + hash_table_log_derivative += + (hash_indeterminate - compressed_row_hash).inverse() * lookup_multiplicity; + + let compressed_row_lo = lookup_input_weight * base_row[LookInLo.base_table_index()] + + lookup_output_weight * base_row[LookOutLo.base_table_index()]; + let compressed_row_hi = lookup_input_weight * base_row[LookInHi.base_table_index()] + + lookup_output_weight * base_row[LookOutHi.base_table_index()]; + lookup_table_log_derivative += (lookup_indeterminate - compressed_row_lo).inverse(); + lookup_table_log_derivative += (lookup_indeterminate - compressed_row_hi).inverse(); + } + + let mut extension_row = aux_table.row_mut(row_idx); + extension_row[HashTableServerLogDerivative.ext_table_index()] = + hash_table_log_derivative; + extension_row[LookupTableClientLogDerivative.ext_table_index()] = + lookup_table_log_derivative; + } + profiler!(stop "cascade table"); + } +} diff --git a/triton-vm/src/table/challenges.rs b/triton-vm/src/table/challenges.rs deleted file mode 100644 index e5b59b855..000000000 --- a/triton-vm/src/table/challenges.rs +++ /dev/null @@ -1,396 +0,0 @@ -//! Challenges are needed for the [cross-table arguments](CrossTableArg), _i.e._, -//! [Permutation Arguments](crate::table::cross_table_argument::PermArg), -//! [Evaluation Arguments](crate::table::cross_table_argument::EvalArg), and -//! [Lookup Arguments](crate::table::cross_table_argument::LookupArg), -//! as well as for the RAM Table's Contiguity Argument. -//! -//! There are three types of challenges: -//! - **Weights**. Weights are used to linearly combine multiple elements into one element. The -//! resulting single element can then be used in a cross-table argument. -//! - **Indeterminates**. All cross-table arguments work by checking the equality of polynomials (or -//! rational functions). Through the Schwartz-Zippel lemma, this equality check can be performed -//! by evaluating the polynomials (or rational functions) in a single point. The challenges that -//! are indeterminates are exactly this evaluation point. The polynomials (or rational functions) -//! are never stored explicitly. Instead, they are directly evaluated at the point indicated by a -//! challenge of “type” `Indeterminate`, giving rise to “running products”, “running -//! evaluations”, _et cetera_. -//! - **Terminals**. The public input (respectively output) of the program is not stored in any -//! table. Instead, the terminal of the Evaluation Argument is computed directly from the -//! public input (respectively output) and the indeterminate. - -use std::fmt::Debug; -use std::hash::Hash; -use std::ops::Index; -use std::ops::Range; -use std::ops::RangeInclusive; - -use arbitrary::Arbitrary; -use strum::Display; -use strum::EnumCount; -use strum::EnumIter; -use twenty_first::prelude::*; - -use crate::table::challenges::ChallengeId::*; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::EvalArg; -use crate::Claim; - -/// A `ChallengeId` is a unique, symbolic identifier for a challenge used in Triton VM. The -/// `ChallengeId` enum works in tandem with the struct [`Challenges`], which can be -/// instantiated to hold actual challenges that can be indexed by some `ChallengeId`. -/// -/// Since almost all challenges relate to the Processor Table in some form, the words “Processor -/// Table” are usually omitted from the `ChallengeId`'s name. -#[repr(usize)] -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] -pub enum ChallengeId { - /// The indeterminate for the [Evaluation Argument](EvalArg) compressing the program digest - /// into a single extension field element, _i.e._, [`CompressedProgramDigest`]. - /// Relates to program attestation. - CompressProgramDigestIndeterminate, - - /// The indeterminate for the [Evaluation Argument](EvalArg) with standard input. - StandardInputIndeterminate, - - /// The indeterminate for the [Evaluation Argument](EvalArg) with standard output. - StandardOutputIndeterminate, - - /// The indeterminate for the instruction - /// [Lookup Argument](crate::table::cross_table_argument::LookupArg) - /// between the [Processor Table](crate::table::processor_table) and the - /// [Program Table](crate::table::program_table) guaranteeing that the instructions and their - /// arguments are copied correctly. - InstructionLookupIndeterminate, - - HashInputIndeterminate, - HashDigestIndeterminate, - SpongeIndeterminate, - - OpStackIndeterminate, - RamIndeterminate, - JumpStackIndeterminate, - - U32Indeterminate, - - /// The indeterminate for the Lookup Argument between the Processor Table and all memory-like - /// tables, _i.e._, the OpStack Table, the Ram Table, and the JumpStack Table, guaranteeing - /// that all clock jump differences are directed forward. - ClockJumpDifferenceLookupIndeterminate, - - /// The indeterminate for the Contiguity Argument within the Ram Table. - RamTableBezoutRelationIndeterminate, - - /// A weight for linearly combining multiple elements. Applies to - /// - `Address` in the Program Table - /// - `IP` in the Processor Table - ProgramAddressWeight, - - /// A weight for linearly combining multiple elements. Applies to - /// - `Instruction` in the Program Table - /// - `CI` in the Processor Table - ProgramInstructionWeight, - - /// A weight for linearly combining multiple elements. Applies to - /// - `Instruction'` (_i.e._, in the next row) in the Program Table - /// - `NIA` in the Processor Table - ProgramNextInstructionWeight, - - OpStackClkWeight, - OpStackIb1Weight, - OpStackPointerWeight, - OpStackFirstUnderflowElementWeight, - - RamClkWeight, - RamPointerWeight, - RamValueWeight, - RamInstructionTypeWeight, - - JumpStackClkWeight, - JumpStackCiWeight, - JumpStackJspWeight, - JumpStackJsoWeight, - JumpStackJsdWeight, - - /// The indeterminate for compressing a [`RATE`][rate]-sized chunk of instructions into a - /// single extension field element. - /// Relates to program attestation. - /// - /// Used by the evaluation argument [`PrepareChunkEvalArg`][prep] and in the Hash Table. - /// - /// [rate]: twenty_first::math::tip5::RATE - /// [prep]: crate::table::table_column::ProgramExtTableColumn::PrepareChunkRunningEvaluation - ProgramAttestationPrepareChunkIndeterminate, - - /// The indeterminate for the bus over which the [`RATE`][rate]-sized chunks of instructions - /// are sent. Relates to program attestation. - /// Used by the evaluation arguments [`SendChunkEvalArg`][send] and - /// [`ReceiveChunkEvalArg`][recv]. See also: [`ProgramAttestationPrepareChunkIndeterminate`]. - /// - /// [rate]: twenty_first::math::tip5::RATE - /// [send]: crate::table::table_column::ProgramExtTableColumn::SendChunkRunningEvaluation - /// [recv]: crate::table::table_column::HashExtTableColumn::ReceiveChunkRunningEvaluation - ProgramAttestationSendChunkIndeterminate, - - HashCIWeight, - - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - StackWeight5, - StackWeight6, - StackWeight7, - StackWeight8, - StackWeight9, - StackWeight10, - StackWeight11, - StackWeight12, - StackWeight13, - StackWeight14, - StackWeight15, - - /// The indeterminate for the Lookup Argument between the Hash Table and the Cascade Table. - HashCascadeLookupIndeterminate, - - /// A weight for linearly combining multiple elements. Applies to - /// - `*LkIn` in the Hash Table, and - /// - `2^16·LookInHi + LookInLo` in the Cascade Table. - HashCascadeLookInWeight, - - /// A weight for linearly combining multiple elements. Applies to - /// - `*LkOut` in the Hash Table, and - /// - `2^16·LookOutHi + LookOutLo` in the Cascade Table. - HashCascadeLookOutWeight, - - /// The indeterminate for the Lookup Argument between the Cascade Table and the Lookup Table. - CascadeLookupIndeterminate, - - /// A weight for linearly combining multiple elements. Applies to - /// - `LkIn*` in the Cascade Table, and - /// - `LookIn` in the Lookup Table. - LookupTableInputWeight, - - /// A weight for linearly combining multiple elements. Applies to - /// - `LkOut*` in the Cascade Table, and - /// - `LookOut` in the Lookup Table. - LookupTableOutputWeight, - - /// The indeterminate for the public Evaluation Argument establishing correctness of the - /// Lookup Table. - LookupTablePublicIndeterminate, - - U32LhsWeight, - U32RhsWeight, - U32CiWeight, - U32ResultWeight, - - /// The terminal for the [`EvaluationArgument`](EvalArg) with standard input. - /// Makes use of challenge [`StandardInputIndeterminate`]. - StandardInputTerminal, - - /// The terminal for the [`EvaluationArgument`](EvalArg) with standard output. - /// Makes use of challenge [`StandardOutputIndeterminate`]. - StandardOutputTerminal, - - /// The terminal for the [`EvaluationArgument`](EvalArg) establishing correctness of the - /// [Lookup Table](crate::table::lookup_table::LookupTable). - /// Makes use of challenge [`LookupTablePublicIndeterminate`]. - LookupTablePublicTerminal, - - /// The digest of the program to be executed, compressed into a single extension field element. - /// The compression happens using an [`EvaluationArgument`](EvalArg) under challenge - /// [`CompressProgramDigestIndeterminate`]. - /// Relates to program attestation. - CompressedProgramDigest, -} - -impl ChallengeId { - pub const fn index(&self) -> usize { - *self as usize - } -} - -impl From for usize { - fn from(id: ChallengeId) -> Self { - id.index() - } -} - -/// The `Challenges` struct holds the challenges used in Triton VM. The concrete challenges are -/// known only at runtime. The challenges are indexed using enum [`ChallengeId`]. The `Challenges` -/// struct is essentially a thin wrapper around an array of [`XFieldElement`]s, providing -/// convenience methods. -#[derive(Debug, Clone, Arbitrary)] -pub struct Challenges { - pub challenges: [XFieldElement; Self::COUNT], -} - -impl Challenges { - /// The total number of challenges used in Triton VM. - pub const COUNT: usize = ChallengeId::COUNT; - - /// The number of weights to sample using the Fiat-Shamir heuristic. This number is lower - /// than the number of challenges because several challenges are not sampled, but computed - /// from publicly known values and other, sampled challenges. - /// - /// Concretely: - /// - The [`StandardInputTerminal`] is computed from Triton VM's public input and the sampled - /// indeterminate [`StandardInputIndeterminate`]. - /// - The [`StandardOutputTerminal`] is computed from Triton VM's public output and the sampled - /// indeterminate [`StandardOutputIndeterminate`]. - /// - The [`LookupTablePublicTerminal`] is computed from the publicly known and constant - /// lookup table and the sampled indeterminate [`LookupTablePublicIndeterminate`]. - /// - The [`CompressedProgramDigest`] is computed from the program to be executed and the - /// sampled indeterminate [`CompressProgramDigestIndeterminate`]. - // When modifying this, be sure to add to the compile-time assertions in the - // `#[test] const fn compile_time_index_assertions() { … }` - // at the end of this file. - pub const SAMPLE_COUNT: usize = Self::COUNT - 4; - - pub fn new(mut challenges: Vec, claim: &Claim) -> Self { - assert_eq!(Self::SAMPLE_COUNT, challenges.len()); - - let compressed_digest = EvalArg::compute_terminal( - &claim.program_digest.values(), - EvalArg::default_initial(), - challenges[CompressProgramDigestIndeterminate.index()], - ); - let input_terminal = EvalArg::compute_terminal( - &claim.input, - EvalArg::default_initial(), - challenges[StandardInputIndeterminate.index()], - ); - let output_terminal = EvalArg::compute_terminal( - &claim.output, - EvalArg::default_initial(), - challenges[StandardOutputIndeterminate.index()], - ); - let lookup_terminal = EvalArg::compute_terminal( - &tip5::LOOKUP_TABLE.map(BFieldElement::from), - EvalArg::default_initial(), - challenges[LookupTablePublicIndeterminate.index()], - ); - - challenges.insert(StandardInputTerminal.index(), input_terminal); - challenges.insert(StandardOutputTerminal.index(), output_terminal); - challenges.insert(LookupTablePublicTerminal.index(), lookup_terminal); - challenges.insert(CompressedProgramDigest.index(), compressed_digest); - assert_eq!(Self::COUNT, challenges.len()); - let challenges = challenges.try_into().unwrap(); - - Self { challenges } - } -} - -impl Index for Challenges { - type Output = XFieldElement; - - fn index(&self, id: usize) -> &Self::Output { - &self.challenges[id] - } -} - -impl Index> for Challenges { - type Output = [XFieldElement]; - - fn index(&self, indices: Range) -> &Self::Output { - &self.challenges[indices.start..indices.end] - } -} - -impl Index> for Challenges { - type Output = [XFieldElement]; - - fn index(&self, indices: RangeInclusive) -> &Self::Output { - &self.challenges[*indices.start()..=*indices.end()] - } -} - -impl Index for Challenges { - type Output = XFieldElement; - - fn index(&self, id: ChallengeId) -> &Self::Output { - &self[id.index()] - } -} - -impl Index> for Challenges { - type Output = [XFieldElement]; - - fn index(&self, indices: Range) -> &Self::Output { - &self[indices.start.index()..indices.end.index()] - } -} - -impl Index> for Challenges { - type Output = [XFieldElement]; - - fn index(&self, indices: RangeInclusive) -> &Self::Output { - &self[indices.start().index()..=indices.end().index()] - } -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - - // For testing purposes only. - impl Default for Challenges { - fn default() -> Self { - Self::placeholder(&Claim::default()) - } - } - - impl Challenges { - /// Stand-in challenges for use in tests. For non-interactive STARKs, use the - /// Fiat-Shamir heuristic to derive the actual challenges. - pub fn placeholder(claim: &Claim) -> Self { - let stand_in_challenges = (1..=Self::SAMPLE_COUNT) - .map(|i| xfe!([42, i as u64, 24])) - .collect(); - Self::new(stand_in_challenges, claim) - } - } - - #[test] - const fn compile_time_index_assertions() { - // Terminal challenges are computed from public information, such as public input or - // public output, and other challenges. Because these other challenges are used to compute - // the terminal challenges, the terminal challenges must be inserted into the challenges - // vector after the used challenges. - assert!(StandardInputIndeterminate.index() < StandardInputTerminal.index()); - assert!(StandardInputIndeterminate.index() < StandardOutputTerminal.index()); - assert!(StandardInputIndeterminate.index() < LookupTablePublicTerminal.index()); - assert!(StandardInputIndeterminate.index() < CompressedProgramDigest.index()); - - assert!(StandardOutputIndeterminate.index() < StandardInputTerminal.index()); - assert!(StandardOutputIndeterminate.index() < StandardOutputTerminal.index()); - assert!(StandardOutputIndeterminate.index() < LookupTablePublicTerminal.index()); - assert!(StandardOutputIndeterminate.index() < CompressedProgramDigest.index()); - - assert!(CompressProgramDigestIndeterminate.index() < StandardInputTerminal.index()); - assert!(CompressProgramDigestIndeterminate.index() < StandardOutputTerminal.index()); - assert!(CompressProgramDigestIndeterminate.index() < LookupTablePublicTerminal.index()); - assert!(CompressProgramDigestIndeterminate.index() < CompressedProgramDigest.index()); - - assert!(LookupTablePublicIndeterminate.index() < StandardInputTerminal.index()); - assert!(LookupTablePublicIndeterminate.index() < StandardOutputTerminal.index()); - assert!(LookupTablePublicIndeterminate.index() < LookupTablePublicTerminal.index()); - assert!(LookupTablePublicIndeterminate.index() < CompressedProgramDigest.index()); - } - - // Ensure the compile-time assertions are actually executed by the compiler. - const _: () = compile_time_index_assertions(); - - #[test] - fn various_challenge_indexing_operations_are_possible() { - let challenges = Challenges::placeholder(&Claim::default()); - let _ = challenges[StackWeight0]; - let _ = challenges[StackWeight0..StackWeight8]; - let _ = challenges[StackWeight0..=StackWeight8]; - let _ = challenges[0]; - let _ = challenges[0..8]; - let _ = challenges[0..=8]; - } -} diff --git a/triton-vm/src/table/constraints.rs b/triton-vm/src/table/constraints.rs index 4933815c2..abd5383de 100644 --- a/triton-vm/src/table/constraints.rs +++ b/triton-vm/src/table/constraints.rs @@ -6,7 +6,7 @@ use ndarray::ArrayView1; use twenty_first::prelude::BFieldElement; use twenty_first::prelude::XFieldElement; -use crate::table::challenges::Challenges; +use crate::challenges::Challenges; use crate::table::extension_table::Evaluable; use crate::table::extension_table::Quotientable; use crate::table::master_table::MasterExtTable; diff --git a/triton-vm/src/table/degree_lowering_table.rs b/triton-vm/src/table/degree_lowering_table.rs index b097844b9..697567e20 100644 --- a/triton-vm/src/table/degree_lowering_table.rs +++ b/triton-vm/src/table/degree_lowering_table.rs @@ -9,11 +9,7 @@ use strum::EnumIter; use twenty_first::prelude::BFieldElement; use twenty_first::prelude::XFieldElement; -use crate::table::challenges::Challenges; - -pub const BASE_WIDTH: usize = DegreeLoweringBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = DegreeLoweringExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; +use crate::challenges::Challenges; #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum DegreeLoweringBaseTableColumn {} diff --git a/triton-vm/src/table/extension_table.rs b/triton-vm/src/table/extension_table.rs index 4cfbbc49c..c43485f44 100644 --- a/triton-vm/src/table/extension_table.rs +++ b/triton-vm/src/table/extension_table.rs @@ -7,7 +7,7 @@ use ndarray::ArrayView1; use twenty_first::math::traits::FiniteField; use twenty_first::prelude::*; -use crate::table::challenges::Challenges; +use crate::challenges::Challenges; use crate::table::master_table::MasterExtTable; use crate::table::ConstraintType; diff --git a/triton-vm/src/table/hash.rs b/triton-vm/src/table/hash.rs new file mode 100644 index 000000000..d0a3df342 --- /dev/null +++ b/triton-vm/src/table/hash.rs @@ -0,0 +1,649 @@ +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::EvalArg; +use air::cross_table_argument::LookupArg; +use air::table::hash::HashTable; +use air::table::hash::HashTableMode; +use air::table::hash::PermutationTrace; +use air::table::hash::MONTGOMERY_MODULUS; +use air::table::hash::NUM_ROUND_CONSTANTS; +use air::table_column::HashBaseTableColumn::*; +use air::table_column::HashExtTableColumn::*; +use air::table_column::MasterBaseTableColumn; +use air::table_column::MasterExtTableColumn; +use air::AIR; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; +use isa::instruction::AnInstruction::Hash; +use isa::instruction::AnInstruction::SpongeAbsorb; +use isa::instruction::AnInstruction::SpongeInit; +use isa::instruction::AnInstruction::SpongeSqueeze; +use isa::instruction::Instruction; +use itertools::Itertools; +use ndarray::*; +use num_traits::Zero; +use strum::Display; +use strum::EnumCount; +use strum::EnumIter; +use strum::IntoEnumIterator; +use twenty_first::prelude::tip5::NUM_ROUNDS; +use twenty_first::prelude::tip5::NUM_SPLIT_AND_LOOKUP; +use twenty_first::prelude::tip5::RATE; +use twenty_first::prelude::tip5::ROUND_CONSTANTS; +use twenty_first::prelude::tip5::STATE_SIZE; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::profiler::profiler; +use crate::table::TraceTable; + +/// Return the 16-bit chunks of the “un-Montgomery'd” representation, in little-endian chunk +/// order. This (basically) translates to the application of `σ(R·x)` for input `x`, which +/// are the first two steps in Tip5's split-and-lookup S-Box. +/// `R` is the Montgomery modulus, _i.e._, `R = 2^64 mod p`. +/// `σ` as described in the paper decomposes the 64-bit input into 8-bit limbs, whereas +/// this method decomposes into 16-bit limbs for arithmetization reasons; the 16-bit limbs +/// are split into 8-bit limbs in the Cascade Table. +/// For a more in-depth explanation of all the necessary steps in the split-and-lookup S-Box, +/// see the [Tip5 paper](https://eprint.iacr.org/2023/107.pdf). +/// +/// Note: this is distinct from the seemingly similar [`raw_u16s`](BFieldElement::raw_u16s). +pub(crate) fn base_field_element_into_16_bit_limbs(x: BFieldElement) -> [u16; 4] { + let r_times_x = (MONTGOMERY_MODULUS * x).value(); + [0, 16, 32, 48].map(|shift| ((r_times_x >> shift) & 0xffff) as u16) +} + +/// Convert a permutation trace to a segment in the Hash Table. +/// +/// **Note**: The current instruction [`CI`] is _not_ set. +pub(crate) fn trace_to_table_rows(trace: PermutationTrace) -> Array2 { + let mut table_rows = Array2::default([0, ::MainColumn::COUNT]); + for (round_number, &trace_row) in trace.iter().enumerate() { + let table_row = trace_row_to_table_row(trace_row, round_number); + table_rows.push_row(table_row.view()).unwrap(); + } + table_rows +} + +pub(crate) fn trace_row_to_table_row( + trace_row: [BFieldElement; STATE_SIZE], + round_number: usize, +) -> Array1 { + let row = Array1::zeros([::MainColumn::COUNT]); + let row = fill_row_with_round_number(row, round_number); + let row = fill_row_with_split_state_elements_using_trace_row(row, trace_row); + let row = fill_row_with_unsplit_state_elements_using_trace_row(row, trace_row); + let row = fill_row_with_state_inverses_using_trace_row(row, trace_row); + fill_row_with_round_constants_for_round(row, round_number) +} + +fn fill_row_with_round_number( + mut row: Array1, + round_number: usize, +) -> Array1 { + row[RoundNumber.base_table_index()] = bfe!(round_number as u64); + row +} + +fn fill_row_with_split_state_elements_using_trace_row( + row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + let row = fill_split_state_element_0_of_row_using_trace_row(row, trace_row); + let row = fill_split_state_element_1_of_row_using_trace_row(row, trace_row); + let row = fill_split_state_element_2_of_row_using_trace_row(row, trace_row); + fill_split_state_element_3_of_row_using_trace_row(row, trace_row) +} + +fn fill_split_state_element_0_of_row_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + let limbs = base_field_element_into_16_bit_limbs(trace_row[0]); + let look_in_split = limbs.map(|limb| bfe!(limb)); + row[State0LowestLkIn.base_table_index()] = look_in_split[0]; + row[State0MidLowLkIn.base_table_index()] = look_in_split[1]; + row[State0MidHighLkIn.base_table_index()] = look_in_split[2]; + row[State0HighestLkIn.base_table_index()] = look_in_split[3]; + + let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); + row[State0LowestLkOut.base_table_index()] = look_out_split[0]; + row[State0MidLowLkOut.base_table_index()] = look_out_split[1]; + row[State0MidHighLkOut.base_table_index()] = look_out_split[2]; + row[State0HighestLkOut.base_table_index()] = look_out_split[3]; + + row +} + +fn fill_split_state_element_1_of_row_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + let limbs = base_field_element_into_16_bit_limbs(trace_row[1]); + let look_in_split = limbs.map(|limb| bfe!(limb)); + row[State1LowestLkIn.base_table_index()] = look_in_split[0]; + row[State1MidLowLkIn.base_table_index()] = look_in_split[1]; + row[State1MidHighLkIn.base_table_index()] = look_in_split[2]; + row[State1HighestLkIn.base_table_index()] = look_in_split[3]; + + let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); + row[State1LowestLkOut.base_table_index()] = look_out_split[0]; + row[State1MidLowLkOut.base_table_index()] = look_out_split[1]; + row[State1MidHighLkOut.base_table_index()] = look_out_split[2]; + row[State1HighestLkOut.base_table_index()] = look_out_split[3]; + + row +} + +fn fill_split_state_element_2_of_row_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + let limbs = base_field_element_into_16_bit_limbs(trace_row[2]); + let look_in_split = limbs.map(|limb| bfe!(limb)); + row[State2LowestLkIn.base_table_index()] = look_in_split[0]; + row[State2MidLowLkIn.base_table_index()] = look_in_split[1]; + row[State2MidHighLkIn.base_table_index()] = look_in_split[2]; + row[State2HighestLkIn.base_table_index()] = look_in_split[3]; + + let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); + row[State2LowestLkOut.base_table_index()] = look_out_split[0]; + row[State2MidLowLkOut.base_table_index()] = look_out_split[1]; + row[State2MidHighLkOut.base_table_index()] = look_out_split[2]; + row[State2HighestLkOut.base_table_index()] = look_out_split[3]; + + row +} + +fn fill_split_state_element_3_of_row_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + let limbs = base_field_element_into_16_bit_limbs(trace_row[3]); + let look_in_split = limbs.map(|limb| bfe!(limb)); + row[State3LowestLkIn.base_table_index()] = look_in_split[0]; + row[State3MidLowLkIn.base_table_index()] = look_in_split[1]; + row[State3MidHighLkIn.base_table_index()] = look_in_split[2]; + row[State3HighestLkIn.base_table_index()] = look_in_split[3]; + + let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); + row[State3LowestLkOut.base_table_index()] = look_out_split[0]; + row[State3MidLowLkOut.base_table_index()] = look_out_split[1]; + row[State3MidHighLkOut.base_table_index()] = look_out_split[2]; + row[State3HighestLkOut.base_table_index()] = look_out_split[3]; + + row +} + +fn fill_row_with_unsplit_state_elements_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + row[State4.base_table_index()] = trace_row[4]; + row[State5.base_table_index()] = trace_row[5]; + row[State6.base_table_index()] = trace_row[6]; + row[State7.base_table_index()] = trace_row[7]; + row[State8.base_table_index()] = trace_row[8]; + row[State9.base_table_index()] = trace_row[9]; + row[State10.base_table_index()] = trace_row[10]; + row[State11.base_table_index()] = trace_row[11]; + row[State12.base_table_index()] = trace_row[12]; + row[State13.base_table_index()] = trace_row[13]; + row[State14.base_table_index()] = trace_row[14]; + row[State15.base_table_index()] = trace_row[15]; + row +} + +fn fill_row_with_state_inverses_using_trace_row( + mut row: Array1, + trace_row: [BFieldElement; STATE_SIZE], +) -> Array1 { + row[State0Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[0]); + row[State1Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[1]); + row[State2Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[2]); + row[State3Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[3]); + row +} + +/// The inverse-or-zero of (2^32 - 1 - 2^16·`highest` - `mid_high`) where `highest` +/// is the most significant limb of the given `state_element`, and `mid_high` the second-most +/// significant limb. +fn inverse_or_zero_of_highest_2_limbs(state_element: BFieldElement) -> BFieldElement { + let limbs = base_field_element_into_16_bit_limbs(state_element); + let highest: u64 = limbs[3].into(); + let mid_high: u64 = limbs[2].into(); + let high_limbs = bfe!((highest << 16) + mid_high); + let two_pow_32_minus_1 = bfe!((1_u64 << 32) - 1); + let to_invert = two_pow_32_minus_1 - high_limbs; + to_invert.inverse_or_zero() +} + +fn fill_row_with_round_constants_for_round( + mut row: Array1, + round_number: usize, +) -> Array1 { + let round_constants = HashTable::tip5_round_constants_by_round_number(round_number); + let [r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15] = round_constants; + row[Constant0.base_table_index()] = r0; + row[Constant1.base_table_index()] = r1; + row[Constant2.base_table_index()] = r2; + row[Constant3.base_table_index()] = r3; + row[Constant4.base_table_index()] = r4; + row[Constant5.base_table_index()] = r5; + row[Constant6.base_table_index()] = r6; + row[Constant7.base_table_index()] = r7; + row[Constant8.base_table_index()] = r8; + row[Constant9.base_table_index()] = r9; + row[Constant10.base_table_index()] = r10; + row[Constant11.base_table_index()] = r11; + row[Constant12.base_table_index()] = r12; + row[Constant13.base_table_index()] = r13; + row[Constant14.base_table_index()] = r14; + row[Constant15.base_table_index()] = r15; + row +} + +impl TraceTable for HashTable { + type FillParam = (); + type FillReturnInfo = (); + + fn fill(mut main_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { + let program_hash_part_start = 0; + let program_hash_part_end = program_hash_part_start + aet.program_hash_trace.nrows(); + let sponge_part_start = program_hash_part_end; + let sponge_part_end = sponge_part_start + aet.sponge_trace.nrows(); + let hash_part_start = sponge_part_end; + let hash_part_end = hash_part_start + aet.hash_trace.nrows(); + + let (mut program_hash_part, mut sponge_part, mut hash_part) = main_table.multi_slice_mut(( + s![program_hash_part_start..program_hash_part_end, ..], + s![sponge_part_start..sponge_part_end, ..], + s![hash_part_start..hash_part_end, ..], + )); + + program_hash_part.assign(&aet.program_hash_trace); + sponge_part.assign(&aet.sponge_trace); + hash_part.assign(&aet.hash_trace); + + let mode_column_idx = Mode.base_table_index(); + let mut program_hash_mode_column = program_hash_part.column_mut(mode_column_idx); + let mut sponge_mode_column = sponge_part.column_mut(mode_column_idx); + let mut hash_mode_column = hash_part.column_mut(mode_column_idx); + + program_hash_mode_column.fill(HashTableMode::ProgramHashing.into()); + sponge_mode_column.fill(HashTableMode::Sponge.into()); + hash_mode_column.fill(HashTableMode::Hash.into()); + } + + fn pad(mut main_table: ArrayViewMut2, table_length: usize) { + let inverse_of_high_limbs = inverse_or_zero_of_highest_2_limbs(bfe!(0)); + for column_id in [State0Inv, State1Inv, State2Inv, State3Inv] { + let column_index = column_id.base_table_index(); + let slice_info = s![table_length.., column_index]; + let mut column = main_table.slice_mut(slice_info); + column.fill(inverse_of_high_limbs); + } + + let round_constants = Self::tip5_round_constants_by_round_number(0); + for (round_constant_idx, &round_constant) in round_constants.iter().enumerate() { + let round_constant_column = + HashTable::round_constant_column_by_index(round_constant_idx); + let round_constant_column_idx = round_constant_column.base_table_index(); + let slice_info = s![table_length.., round_constant_column_idx]; + let mut column = main_table.slice_mut(slice_info); + column.fill(round_constant); + } + + let mode_column_index = Mode.base_table_index(); + let mode_column_slice_info = s![table_length.., mode_column_index]; + let mut mode_column = main_table.slice_mut(mode_column_slice_info); + mode_column.fill(HashTableMode::Pad.into()); + + let instruction_column_index = CI.base_table_index(); + let instruction_column_slice_info = s![table_length.., instruction_column_index]; + let mut instruction_column = main_table.slice_mut(instruction_column_slice_info); + instruction_column.fill(Instruction::Hash.opcode_b()); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "hash table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let ci_weight = challenges[HashCIWeight]; + let hash_digest_eval_indeterminate = challenges[HashDigestIndeterminate]; + let hash_input_eval_indeterminate = challenges[HashInputIndeterminate]; + let sponge_eval_indeterminate = challenges[SpongeIndeterminate]; + let cascade_indeterminate = challenges[HashCascadeLookupIndeterminate]; + let send_chunk_indeterminate = challenges[ProgramAttestationSendChunkIndeterminate]; + + let mut hash_input_running_evaluation = EvalArg::default_initial(); + let mut hash_digest_running_evaluation = EvalArg::default_initial(); + let mut sponge_running_evaluation = EvalArg::default_initial(); + let mut cascade_state_0_highest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_0_mid_high_log_derivative = LookupArg::default_initial(); + let mut cascade_state_0_mid_low_log_derivative = LookupArg::default_initial(); + let mut cascade_state_0_lowest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_1_highest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_1_mid_high_log_derivative = LookupArg::default_initial(); + let mut cascade_state_1_mid_low_log_derivative = LookupArg::default_initial(); + let mut cascade_state_1_lowest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_2_highest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_2_mid_high_log_derivative = LookupArg::default_initial(); + let mut cascade_state_2_mid_low_log_derivative = LookupArg::default_initial(); + let mut cascade_state_2_lowest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_3_highest_log_derivative = LookupArg::default_initial(); + let mut cascade_state_3_mid_high_log_derivative = LookupArg::default_initial(); + let mut cascade_state_3_mid_low_log_derivative = LookupArg::default_initial(); + let mut cascade_state_3_lowest_log_derivative = LookupArg::default_initial(); + let mut receive_chunk_running_evaluation = EvalArg::default_initial(); + + let two_pow_16 = bfe!(1_u64 << 16); + let two_pow_32 = bfe!(1_u64 << 32); + let two_pow_48 = bfe!(1_u64 << 48); + + let montgomery_modulus_inverse = MONTGOMERY_MODULUS.inverse(); + let re_compose_state_element = + |row: ArrayView1, + highest: Self::MainColumn, + mid_high: Self::MainColumn, + mid_low: Self::MainColumn, + lowest: Self::MainColumn| { + (row[highest.base_table_index()] * two_pow_48 + + row[mid_high.base_table_index()] * two_pow_32 + + row[mid_low.base_table_index()] * two_pow_16 + + row[lowest.base_table_index()]) + * montgomery_modulus_inverse + }; + + let rate_registers = |row: ArrayView1| { + let state_0 = re_compose_state_element( + row, + State0HighestLkIn, + State0MidHighLkIn, + State0MidLowLkIn, + State0LowestLkIn, + ); + let state_1 = re_compose_state_element( + row, + State1HighestLkIn, + State1MidHighLkIn, + State1MidLowLkIn, + State1LowestLkIn, + ); + let state_2 = re_compose_state_element( + row, + State2HighestLkIn, + State2MidHighLkIn, + State2MidLowLkIn, + State2LowestLkIn, + ); + let state_3 = re_compose_state_element( + row, + State3HighestLkIn, + State3MidHighLkIn, + State3MidLowLkIn, + State3LowestLkIn, + ); + [ + state_0, + state_1, + state_2, + state_3, + row[State4.base_table_index()], + row[State5.base_table_index()], + row[State6.base_table_index()], + row[State7.base_table_index()], + row[State8.base_table_index()], + row[State9.base_table_index()], + ] + }; + + let state_weights = &challenges[StackWeight0..StackWeight10]; + let compressed_row = |row: ArrayView1| -> XFieldElement { + rate_registers(row) + .iter() + .zip_eq(state_weights.iter()) + .map(|(&state, &weight)| weight * state) + .sum() + }; + + let cascade_look_in_weight = challenges[HashCascadeLookInWeight]; + let cascade_look_out_weight = challenges[HashCascadeLookOutWeight]; + + let log_derivative_summand = + |row: ArrayView1, + lk_in_col: Self::MainColumn, + lk_out_col: Self::MainColumn| { + let compressed_elements = cascade_indeterminate + - cascade_look_in_weight * row[lk_in_col.base_table_index()] + - cascade_look_out_weight * row[lk_out_col.base_table_index()]; + compressed_elements.inverse() + }; + + for row_idx in 0..main_table.nrows() { + let row = main_table.row(row_idx); + + let mode = row[Mode.base_table_index()]; + let in_program_hashing_mode = mode == HashTableMode::ProgramHashing.into(); + let in_sponge_mode = mode == HashTableMode::Sponge.into(); + let in_hash_mode = mode == HashTableMode::Hash.into(); + let in_pad_mode = mode == HashTableMode::Pad.into(); + + let round_number = row[RoundNumber.base_table_index()]; + let in_round_0 = round_number.is_zero(); + let in_last_round = round_number == (NUM_ROUNDS as u64).into(); + + let current_instruction = row[CI.base_table_index()]; + let current_instruction_is_sponge_init = + current_instruction == Instruction::SpongeInit.opcode_b(); + + if in_program_hashing_mode && in_round_0 { + let compressed_chunk_of_instructions = EvalArg::compute_terminal( + &rate_registers(row), + EvalArg::default_initial(), + challenges[ProgramAttestationPrepareChunkIndeterminate], + ); + receive_chunk_running_evaluation = receive_chunk_running_evaluation + * send_chunk_indeterminate + + compressed_chunk_of_instructions + } + + if in_sponge_mode && in_round_0 && current_instruction_is_sponge_init { + sponge_running_evaluation = sponge_running_evaluation * sponge_eval_indeterminate + + ci_weight * current_instruction + } + + if in_sponge_mode && in_round_0 && !current_instruction_is_sponge_init { + sponge_running_evaluation = sponge_running_evaluation * sponge_eval_indeterminate + + ci_weight * current_instruction + + compressed_row(row) + } + + if in_hash_mode && in_round_0 { + hash_input_running_evaluation = hash_input_running_evaluation + * hash_input_eval_indeterminate + + compressed_row(row) + } + + if in_hash_mode && in_last_round { + let compressed_digest: XFieldElement = rate_registers(row)[..Digest::LEN] + .iter() + .zip_eq(state_weights[..Digest::LEN].iter()) + .map(|(&state, &weight)| weight * state) + .sum(); + hash_digest_running_evaluation = hash_digest_running_evaluation + * hash_digest_eval_indeterminate + + compressed_digest + } + + if !in_pad_mode && !in_last_round && !current_instruction_is_sponge_init { + cascade_state_0_highest_log_derivative += + log_derivative_summand(row, State0HighestLkIn, State0HighestLkOut); + cascade_state_0_mid_high_log_derivative += + log_derivative_summand(row, State0MidHighLkIn, State0MidHighLkOut); + cascade_state_0_mid_low_log_derivative += + log_derivative_summand(row, State0MidLowLkIn, State0MidLowLkOut); + cascade_state_0_lowest_log_derivative += + log_derivative_summand(row, State0LowestLkIn, State0LowestLkOut); + cascade_state_1_highest_log_derivative += + log_derivative_summand(row, State1HighestLkIn, State1HighestLkOut); + cascade_state_1_mid_high_log_derivative += + log_derivative_summand(row, State1MidHighLkIn, State1MidHighLkOut); + cascade_state_1_mid_low_log_derivative += + log_derivative_summand(row, State1MidLowLkIn, State1MidLowLkOut); + cascade_state_1_lowest_log_derivative += + log_derivative_summand(row, State1LowestLkIn, State1LowestLkOut); + cascade_state_2_highest_log_derivative += + log_derivative_summand(row, State2HighestLkIn, State2HighestLkOut); + cascade_state_2_mid_high_log_derivative += + log_derivative_summand(row, State2MidHighLkIn, State2MidHighLkOut); + cascade_state_2_mid_low_log_derivative += + log_derivative_summand(row, State2MidLowLkIn, State2MidLowLkOut); + cascade_state_2_lowest_log_derivative += + log_derivative_summand(row, State2LowestLkIn, State2LowestLkOut); + cascade_state_3_highest_log_derivative += + log_derivative_summand(row, State3HighestLkIn, State3HighestLkOut); + cascade_state_3_mid_high_log_derivative += + log_derivative_summand(row, State3MidHighLkIn, State3MidHighLkOut); + cascade_state_3_mid_low_log_derivative += + log_derivative_summand(row, State3MidLowLkIn, State3MidLowLkOut); + cascade_state_3_lowest_log_derivative += + log_derivative_summand(row, State3LowestLkIn, State3LowestLkOut); + } + + let mut extension_row = aux_table.row_mut(row_idx); + extension_row[ReceiveChunkRunningEvaluation.ext_table_index()] = + receive_chunk_running_evaluation; + extension_row[HashInputRunningEvaluation.ext_table_index()] = + hash_input_running_evaluation; + extension_row[HashDigestRunningEvaluation.ext_table_index()] = + hash_digest_running_evaluation; + extension_row[SpongeRunningEvaluation.ext_table_index()] = sponge_running_evaluation; + extension_row[CascadeState0HighestClientLogDerivative.ext_table_index()] = + cascade_state_0_highest_log_derivative; + extension_row[CascadeState0MidHighClientLogDerivative.ext_table_index()] = + cascade_state_0_mid_high_log_derivative; + extension_row[CascadeState0MidLowClientLogDerivative.ext_table_index()] = + cascade_state_0_mid_low_log_derivative; + extension_row[CascadeState0LowestClientLogDerivative.ext_table_index()] = + cascade_state_0_lowest_log_derivative; + extension_row[CascadeState1HighestClientLogDerivative.ext_table_index()] = + cascade_state_1_highest_log_derivative; + extension_row[CascadeState1MidHighClientLogDerivative.ext_table_index()] = + cascade_state_1_mid_high_log_derivative; + extension_row[CascadeState1MidLowClientLogDerivative.ext_table_index()] = + cascade_state_1_mid_low_log_derivative; + extension_row[CascadeState1LowestClientLogDerivative.ext_table_index()] = + cascade_state_1_lowest_log_derivative; + extension_row[CascadeState2HighestClientLogDerivative.ext_table_index()] = + cascade_state_2_highest_log_derivative; + extension_row[CascadeState2MidHighClientLogDerivative.ext_table_index()] = + cascade_state_2_mid_high_log_derivative; + extension_row[CascadeState2MidLowClientLogDerivative.ext_table_index()] = + cascade_state_2_mid_low_log_derivative; + extension_row[CascadeState2LowestClientLogDerivative.ext_table_index()] = + cascade_state_2_lowest_log_derivative; + extension_row[CascadeState3HighestClientLogDerivative.ext_table_index()] = + cascade_state_3_highest_log_derivative; + extension_row[CascadeState3MidHighClientLogDerivative.ext_table_index()] = + cascade_state_3_mid_high_log_derivative; + extension_row[CascadeState3MidLowClientLogDerivative.ext_table_index()] = + cascade_state_3_mid_low_log_derivative; + extension_row[CascadeState3LowestClientLogDerivative.ext_table_index()] = + cascade_state_3_lowest_log_derivative; + } + profiler!(stop "hash table"); + } +} + +#[cfg(test)] +pub(crate) mod tests { + use air::table::TableId; + use air::table_column::HashBaseTableColumn; + use air::AIR; + use std::collections::HashMap; + + use crate::shared_tests::ProgramAndInput; + use crate::stark::tests::master_tables_for_low_security_level; + use crate::table::master_table::MasterTable; + use crate::triton_asm; + use crate::triton_program; + use crate::vm::VM; + + use super::*; + + #[test] + fn hash_table_mode_discriminant_is_unique() { + let mut discriminants_and_modes = HashMap::new(); + for mode in HashTableMode::iter() { + let discriminant = u32::from(mode); + let maybe_entry = discriminants_and_modes.insert(discriminant, mode); + if let Some(entry) = maybe_entry { + panic!("Discriminant collision for {discriminant} between {entry} and {mode}."); + } + } + } + + #[test] + fn terminal_constraints_hold_for_sponge_init_edge_case() { + let many_sponge_inits = triton_asm![sponge_init; 23_631]; + let many_squeeze_absorbs = (0..2_100) + .flat_map(|_| triton_asm!(sponge_squeeze sponge_absorb)) + .collect_vec(); + let program = triton_program! { + {&many_sponge_inits} + {&many_squeeze_absorbs} + sponge_init + halt + }; + + let (aet, _) = VM::trace_execution(&program, [].into(), [].into()).unwrap(); + dbg!(aet.height()); + dbg!(aet.padded_height()); + dbg!(aet.height_of_table(TableId::Hash)); + dbg!(aet.height_of_table(TableId::OpStack)); + dbg!(aet.height_of_table(TableId::Cascade)); + + let (_, _, master_base_table, master_ext_table, challenges) = + master_tables_for_low_security_level(ProgramAndInput::new(program)); + let challenges = &challenges.challenges; + + let master_base_trace_table = master_base_table.trace_table(); + let master_ext_trace_table = master_ext_table.trace_table(); + + let last_row = master_base_trace_table.slice(s![-1.., ..]); + let last_opcode = last_row[[0, HashBaseTableColumn::CI.master_base_table_index()]]; + let last_instruction: Instruction = last_opcode.value().try_into().unwrap(); + assert_eq!(Instruction::SpongeInit, last_instruction); + + let circuit_builder = ConstraintCircuitBuilder::new(); + for (constraint_idx, constraint) in HashTable::terminal_constraints(&circuit_builder) + .into_iter() + .map(|constraint_monad| constraint_monad.consume()) + .enumerate() + { + let evaluated_constraint = constraint.evaluate( + master_base_trace_table.slice(s![-1.., ..]), + master_ext_trace_table.slice(s![-1.., ..]), + challenges, + ); + assert_eq!( + xfe!(0), + evaluated_constraint, + "Terminal constraint {constraint_idx} failed." + ); + } + } +} diff --git a/triton-vm/src/table/hash_table.rs b/triton-vm/src/table/hash_table.rs deleted file mode 100644 index 3f2c10792..000000000 --- a/triton-vm/src/table/hash_table.rs +++ /dev/null @@ -1,1926 +0,0 @@ -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::InputIndicator; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; -use isa::instruction::AnInstruction::Hash; -use isa::instruction::AnInstruction::SpongeAbsorb; -use isa::instruction::AnInstruction::SpongeInit; -use isa::instruction::AnInstruction::SpongeSqueeze; -use isa::instruction::Instruction; -use itertools::Itertools; -use ndarray::*; -use num_traits::Zero; -use strum::Display; -use strum::EnumCount; -use strum::EnumIter; -use strum::IntoEnumIterator; -use twenty_first::prelude::tip5::MDS_MATRIX_FIRST_COLUMN; -use twenty_first::prelude::tip5::NUM_ROUNDS; -use twenty_first::prelude::tip5::NUM_SPLIT_AND_LOOKUP; -use twenty_first::prelude::tip5::RATE; -use twenty_first::prelude::tip5::ROUND_CONSTANTS; -use twenty_first::prelude::tip5::STATE_SIZE; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::profiler::profiler; -use crate::table::cascade_table::CascadeTable; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::EvalArg; -use crate::table::cross_table_argument::LookupArg; -use crate::table::table_column::HashBaseTableColumn; -use crate::table::table_column::HashBaseTableColumn::*; -use crate::table::table_column::HashExtTableColumn; -use crate::table::table_column::HashExtTableColumn::*; -use crate::table::table_column::MasterBaseTableColumn; -use crate::table::table_column::MasterExtTableColumn; - -pub const BASE_WIDTH: usize = HashBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = HashExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -/// See [`HashTable::base_field_element_into_16_bit_limbs`] for more details. -const MONTGOMERY_MODULUS: BFieldElement = - BFieldElement::new(((1_u128 << 64) % BFieldElement::P as u128) as u64); - -pub const POWER_MAP_EXPONENT: u64 = 7; -pub const NUM_ROUND_CONSTANTS: usize = STATE_SIZE; - -pub(crate) const PERMUTATION_TRACE_LENGTH: usize = NUM_ROUNDS + 1; - -pub type PermutationTrace = [[BFieldElement; STATE_SIZE]; PERMUTATION_TRACE_LENGTH]; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct HashTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtHashTable; - -/// The current “mode” of the Hash Table. The Hash Table can be in one of four distinct modes: -/// -/// 1. Hashing the [`Program`][program]. This is part of program attestation. -/// 1. Processing all Sponge instructions, _i.e._, `sponge_init`, -/// `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`. -/// 1. Processing the `hash` instruction. -/// 1. Padding mode. -/// -/// Changing the mode is only possible when the current [`RoundNumber`] is [`NUM_ROUNDS`]. -/// The mode evolves as -/// [`ProgramHashing`][prog_hash] → [`Sponge`][sponge] → [`Hash`][hash] → [`Pad`][pad]. -/// Once mode [`Pad`][pad] is reached, it is not possible to change the mode anymore. -/// Skipping any or all of the modes [`Sponge`][sponge], [`Hash`][hash], or [`Pad`][pad] -/// is possible in principle: -/// - if no Sponge instructions are executed, mode [`Sponge`][sponge] will be skipped, -/// - if no `hash` instruction is executed, mode [`Hash`][hash] will be skipped, and -/// - if the Hash Table does not require any padding, mode [`Pad`][pad] will be skipped. -/// -/// It is not possible to skip mode [`ProgramHashing`][prog_hash]: -/// the [`Program`][program] is always hashed. -/// The empty program is not valid since any valid [`Program`][program] must execute -/// instruction `halt`. -/// -/// [program]: isa::program::Program -/// [prog_hash]: HashTableMode::ProgramHashing -/// [sponge]: HashTableMode::Sponge -/// [hash]: type@HashTableMode::Hash -/// [pad]: HashTableMode::Pad -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum HashTableMode { - /// The mode in which the [`Program`][program] is hashed. This is part of program attestation. - /// - /// [program]: isa::program::Program - ProgramHashing, - - /// The mode in which Sponge instructions, _i.e._, `sponge_init`, - /// `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`, are processed. - Sponge, - - /// The mode in which the `hash` instruction is processed. - Hash, - - /// Indicator for padding rows. - Pad, -} - -impl From for u32 { - fn from(mode: HashTableMode) -> Self { - match mode { - HashTableMode::ProgramHashing => 1, - HashTableMode::Sponge => 2, - HashTableMode::Hash => 3, - HashTableMode::Pad => 0, - } - } -} - -impl From for u64 { - fn from(mode: HashTableMode) -> Self { - let discriminant: u32 = mode.into(); - discriminant.into() - } -} - -impl From for BFieldElement { - fn from(mode: HashTableMode) -> Self { - let discriminant: u32 = mode.into(); - discriminant.into() - } -} - -impl ExtHashTable { - /// Construct one of the states 0 through 3 from its constituent limbs. - /// For example, state 0 (prior to it being looked up in the split-and-lookup S-Box, which is - /// usually the desired version of the state) is constructed from limbs - /// [`State0HighestLkIn`] through [`State0LowestLkIn`]. - /// - /// States 4 through 15 are directly accessible. See also the slightly related - /// [`Self::state_column_by_index`]. - fn re_compose_16_bit_limbs( - circuit_builder: &ConstraintCircuitBuilder, - highest: ConstraintCircuitMonad, - mid_high: ConstraintCircuitMonad, - mid_low: ConstraintCircuitMonad, - lowest: ConstraintCircuitMonad, - ) -> ConstraintCircuitMonad { - let constant = |c: u64| circuit_builder.b_constant(c); - let montgomery_modulus_inv = circuit_builder.b_constant(MONTGOMERY_MODULUS.inverse()); - - let sum_of_shifted_limbs = highest * constant(1 << 48) - + mid_high * constant(1 << 32) - + mid_low * constant(1 << 16) - + lowest; - sum_of_shifted_limbs * montgomery_modulus_inv - } - - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let constant = |c: u64| circuit_builder.b_constant(c); - - let base_row = |column: HashBaseTableColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let ext_row = |column: HashExtTableColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; - - let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); - let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial()); - - let mode = base_row(Mode); - let running_evaluation_hash_input = ext_row(HashInputRunningEvaluation); - let running_evaluation_hash_digest = ext_row(HashDigestRunningEvaluation); - let running_evaluation_sponge = ext_row(SpongeRunningEvaluation); - let running_evaluation_receive_chunk = ext_row(ReceiveChunkRunningEvaluation); - - let cascade_indeterminate = challenge(HashCascadeLookupIndeterminate); - let look_in_weight = challenge(HashCascadeLookInWeight); - let look_out_weight = challenge(HashCascadeLookOutWeight); - let prepare_chunk_indeterminate = challenge(ProgramAttestationPrepareChunkIndeterminate); - let receive_chunk_indeterminate = challenge(ProgramAttestationSendChunkIndeterminate); - - // First chunk of the program is received correctly. Relates to program attestation. - let [state_0, state_1, state_2, state_3] = - Self::re_compose_states_0_through_3_before_lookup( - circuit_builder, - Self::indicate_column_index_in_base_row, - ); - let state_rate_part: [_; RATE] = [ - state_0, - state_1, - state_2, - state_3, - base_row(State4), - base_row(State5), - base_row(State6), - base_row(State7), - base_row(State8), - base_row(State9), - ]; - let compressed_chunk = state_rate_part - .into_iter() - .fold(running_evaluation_initial.clone(), |acc, state_element| { - acc * prepare_chunk_indeterminate.clone() + state_element - }); - let running_evaluation_receive_chunk_is_initialized_correctly = - running_evaluation_receive_chunk - - receive_chunk_indeterminate * running_evaluation_initial.clone() - - compressed_chunk; - - // The lookup arguments with the Cascade Table for the S-Boxes are initialized correctly. - let cascade_log_derivative_init_circuit = - |look_in_column, look_out_column, cascade_log_derivative_column| { - let look_in = base_row(look_in_column); - let look_out = base_row(look_out_column); - let compressed_row = - look_in_weight.clone() * look_in + look_out_weight.clone() * look_out; - let cascade_log_derivative = ext_row(cascade_log_derivative_column); - (cascade_log_derivative - lookup_arg_default_initial.clone()) - * (cascade_indeterminate.clone() - compressed_row) - - constant(1) - }; - - // miscellaneous initial constraints - let mode_is_program_hashing = - Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing); - let round_number_is_0 = base_row(RoundNumber); - let running_evaluation_hash_input_is_default_initial = - running_evaluation_hash_input - running_evaluation_initial.clone(); - let running_evaluation_hash_digest_is_default_initial = - running_evaluation_hash_digest - running_evaluation_initial.clone(); - let running_evaluation_sponge_is_default_initial = - running_evaluation_sponge - running_evaluation_initial; - - vec![ - mode_is_program_hashing, - round_number_is_0, - running_evaluation_hash_input_is_default_initial, - running_evaluation_hash_digest_is_default_initial, - running_evaluation_sponge_is_default_initial, - running_evaluation_receive_chunk_is_initialized_correctly, - cascade_log_derivative_init_circuit( - State0HighestLkIn, - State0HighestLkOut, - CascadeState0HighestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State0MidHighLkIn, - State0MidHighLkOut, - CascadeState0MidHighClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State0MidLowLkIn, - State0MidLowLkOut, - CascadeState0MidLowClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State0LowestLkIn, - State0LowestLkOut, - CascadeState0LowestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State1HighestLkIn, - State1HighestLkOut, - CascadeState1HighestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State1MidHighLkIn, - State1MidHighLkOut, - CascadeState1MidHighClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State1MidLowLkIn, - State1MidLowLkOut, - CascadeState1MidLowClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State1LowestLkIn, - State1LowestLkOut, - CascadeState1LowestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State2HighestLkIn, - State2HighestLkOut, - CascadeState2HighestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State2MidHighLkIn, - State2MidHighLkOut, - CascadeState2MidHighClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State2MidLowLkIn, - State2MidLowLkOut, - CascadeState2MidLowClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State2LowestLkIn, - State2LowestLkOut, - CascadeState2LowestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State3HighestLkIn, - State3HighestLkOut, - CascadeState3HighestClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State3MidHighLkIn, - State3MidHighLkOut, - CascadeState3MidHighClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State3MidLowLkIn, - State3MidLowLkOut, - CascadeState3MidLowClientLogDerivative, - ), - cascade_log_derivative_init_circuit( - State3LowestLkIn, - State3LowestLkOut, - CascadeState3LowestClientLogDerivative, - ), - ] - } - - /// A constraint circuit evaluating to zero if and only if the given - /// `round_number_circuit_node` is not equal to the given `round_number_to_deselect`. - fn round_number_deselector( - circuit_builder: &ConstraintCircuitBuilder, - round_number_circuit_node: &ConstraintCircuitMonad, - round_number_to_deselect: usize, - ) -> ConstraintCircuitMonad { - assert!( - round_number_to_deselect <= NUM_ROUNDS, - "Round number must be in [0, {NUM_ROUNDS}] but got {round_number_to_deselect}." - ); - let constant = |c: u64| circuit_builder.b_constant(c); - - // To not subtract zero from the first factor: some special casing. - let first_factor = match round_number_to_deselect { - 0 => constant(1), - _ => round_number_circuit_node.clone(), - }; - (1..=NUM_ROUNDS) - .filter(|&r| r != round_number_to_deselect) - .map(|r| round_number_circuit_node.clone() - constant(r as u64)) - .fold(first_factor, |a, b| a * b) - } - - /// A constraint circuit evaluating to zero if and only if the given `mode_circuit_node` is - /// equal to the given `mode_to_select`. - fn select_mode( - circuit_builder: &ConstraintCircuitBuilder, - mode_circuit_node: &ConstraintCircuitMonad, - mode_to_select: HashTableMode, - ) -> ConstraintCircuitMonad { - mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select) - } - - /// A constraint circuit evaluating to zero if and only if the given `mode_circuit_node` is - /// not equal to the given `mode_to_deselect`. - fn mode_deselector( - circuit_builder: &ConstraintCircuitBuilder, - mode_circuit_node: &ConstraintCircuitMonad, - mode_to_deselect: HashTableMode, - ) -> ConstraintCircuitMonad { - let constant = |c: u64| circuit_builder.b_constant(c); - HashTableMode::iter() - .filter(|&mode| mode != mode_to_deselect) - .map(|mode| mode_circuit_node.clone() - constant(mode.into())) - .fold(constant(1), |accumulator, factor| accumulator * factor) - } - - fn instruction_deselector( - circuit_builder: &ConstraintCircuitBuilder, - current_instruction_node: &ConstraintCircuitMonad, - instruction_to_deselect: Instruction, - ) -> ConstraintCircuitMonad { - let constant = |c: u64| circuit_builder.b_constant(c); - let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); - let relevant_instructions = [Hash, SpongeInit, SpongeAbsorb, SpongeSqueeze]; - assert!(relevant_instructions.contains(&instruction_to_deselect)); - - relevant_instructions - .iter() - .filter(|&instruction| instruction != &instruction_to_deselect) - .map(|&instruction| current_instruction_node.clone() - opcode(instruction)) - .fold(constant(1), |accumulator, factor| accumulator * factor) - } - - pub fn consistency_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); - let constant = |c: u64| circuit_builder.b_constant(c); - let base_row = |column_id: HashBaseTableColumn| { - circuit_builder.input(BaseRow(column_id.master_base_table_index())) - }; - - let mode = base_row(Mode); - let ci = base_row(CI); - let round_number = base_row(RoundNumber); - - let ci_is_hash = ci.clone() - opcode(Hash); - let ci_is_sponge_init = ci.clone() - opcode(SpongeInit); - let ci_is_sponge_absorb = ci.clone() - opcode(SpongeAbsorb); - let ci_is_sponge_squeeze = ci - opcode(SpongeSqueeze); - - let mode_is_not_hash = Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash); - let round_number_is_not_0 = - Self::round_number_deselector(circuit_builder, &round_number, 0); - - let mode_is_a_valid_mode = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) - * Self::select_mode(circuit_builder, &mode, HashTableMode::Pad); - - let if_mode_is_not_sponge_then_ci_is_hash = - Self::select_mode(circuit_builder, &mode, HashTableMode::Sponge) * ci_is_hash.clone(); - - let if_mode_is_sponge_then_ci_is_a_sponge_instruction = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge) - * ci_is_sponge_init - * ci_is_sponge_absorb.clone() - * ci_is_sponge_squeeze.clone(); - - let if_padding_mode_then_round_number_is_0 = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) - * round_number.clone(); - - let if_ci_is_sponge_init_then_ = ci_is_hash * ci_is_sponge_absorb * ci_is_sponge_squeeze; - let if_ci_is_sponge_init_then_round_number_is_0 = - if_ci_is_sponge_init_then_.clone() * round_number.clone(); - - let if_ci_is_sponge_init_then_rate_is_0 = (10..=15).map(|state_index| { - let state_element = base_row(Self::state_column_by_index(state_index)); - if_ci_is_sponge_init_then_.clone() * state_element - }); - - let if_mode_is_hash_and_round_no_is_0_then_ = round_number_is_not_0 * mode_is_not_hash; - let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 = - (10..=15).map(|state_index| { - let state_element = base_row(Self::state_column_by_index(state_index)); - if_mode_is_hash_and_round_no_is_0_then_.clone() * (state_element - constant(1)) - }); - - // consistency of the inverse of the highest 2 limbs minus 2^32 - 1 - let one = constant(1); - let two_pow_16 = constant(1 << 16); - let two_pow_32 = constant(1 << 32); - let state_0_hi_limbs_minus_2_pow_32 = two_pow_32.clone() - - one.clone() - - base_row(State0HighestLkIn) * two_pow_16.clone() - - base_row(State0MidHighLkIn); - let state_1_hi_limbs_minus_2_pow_32 = two_pow_32.clone() - - one.clone() - - base_row(State1HighestLkIn) * two_pow_16.clone() - - base_row(State1MidHighLkIn); - let state_2_hi_limbs_minus_2_pow_32 = two_pow_32.clone() - - one.clone() - - base_row(State2HighestLkIn) * two_pow_16.clone() - - base_row(State2MidHighLkIn); - let state_3_hi_limbs_minus_2_pow_32 = two_pow_32 - - one.clone() - - base_row(State3HighestLkIn) * two_pow_16.clone() - - base_row(State3MidHighLkIn); - - let state_0_hi_limbs_inv = base_row(State0Inv); - let state_1_hi_limbs_inv = base_row(State1Inv); - let state_2_hi_limbs_inv = base_row(State2Inv); - let state_3_hi_limbs_inv = base_row(State3Inv); - - let state_0_hi_limbs_are_not_all_1s = - state_0_hi_limbs_minus_2_pow_32.clone() * state_0_hi_limbs_inv.clone() - one.clone(); - let state_1_hi_limbs_are_not_all_1s = - state_1_hi_limbs_minus_2_pow_32.clone() * state_1_hi_limbs_inv.clone() - one.clone(); - let state_2_hi_limbs_are_not_all_1s = - state_2_hi_limbs_minus_2_pow_32.clone() * state_2_hi_limbs_inv.clone() - one.clone(); - let state_3_hi_limbs_are_not_all_1s = - state_3_hi_limbs_minus_2_pow_32.clone() * state_3_hi_limbs_inv.clone() - one; - - let state_0_hi_limbs_inv_is_inv_or_is_zero = - state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_inv; - let state_1_hi_limbs_inv_is_inv_or_is_zero = - state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_inv; - let state_2_hi_limbs_inv_is_inv_or_is_zero = - state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_inv; - let state_3_hi_limbs_inv_is_inv_or_is_zero = - state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_inv; - - let state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero = - state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_minus_2_pow_32; - let state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero = - state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_minus_2_pow_32; - let state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero = - state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_minus_2_pow_32; - let state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero = - state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_minus_2_pow_32; - - // consistent decomposition into limbs - let state_0_lo_limbs = - base_row(State0MidLowLkIn) * two_pow_16.clone() + base_row(State0LowestLkIn); - let state_1_lo_limbs = - base_row(State1MidLowLkIn) * two_pow_16.clone() + base_row(State1LowestLkIn); - let state_2_lo_limbs = - base_row(State2MidLowLkIn) * two_pow_16.clone() + base_row(State2LowestLkIn); - let state_3_lo_limbs = base_row(State3MidLowLkIn) * two_pow_16 + base_row(State3LowestLkIn); - - let if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0 = - state_0_hi_limbs_are_not_all_1s * state_0_lo_limbs; - let if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0 = - state_1_hi_limbs_are_not_all_1s * state_1_lo_limbs; - let if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0 = - state_2_hi_limbs_are_not_all_1s * state_2_lo_limbs; - let if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0 = - state_3_hi_limbs_are_not_all_1s * state_3_lo_limbs; - - let mut constraints = vec![ - mode_is_a_valid_mode, - if_mode_is_not_sponge_then_ci_is_hash, - if_mode_is_sponge_then_ci_is_a_sponge_instruction, - if_padding_mode_then_round_number_is_0, - if_ci_is_sponge_init_then_round_number_is_0, - state_0_hi_limbs_inv_is_inv_or_is_zero, - state_1_hi_limbs_inv_is_inv_or_is_zero, - state_2_hi_limbs_inv_is_inv_or_is_zero, - state_3_hi_limbs_inv_is_inv_or_is_zero, - state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero, - state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero, - state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero, - state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero, - if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0, - if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0, - if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0, - if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0, - ]; - - constraints.extend(if_ci_is_sponge_init_then_rate_is_0); - constraints.extend(if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1); - - for round_constant_column_idx in 0..NUM_ROUND_CONSTANTS { - let round_constant_column = - Self::round_constant_column_by_index(round_constant_column_idx); - let round_constant_column_circuit = base_row(round_constant_column); - let mut round_constant_constraint_circuit = constant(0); - for round_idx in 0..NUM_ROUNDS { - let round_constants = HashTable::tip5_round_constants_by_round_number(round_idx); - let round_constant = round_constants[round_constant_column_idx]; - let round_constant = circuit_builder.b_constant(round_constant); - let round_deselector_circuit = - Self::round_number_deselector(circuit_builder, &round_number, round_idx); - round_constant_constraint_circuit = round_constant_constraint_circuit - + round_deselector_circuit - * (round_constant_column_circuit.clone() - round_constant); - } - constraints.push(round_constant_constraint_circuit); - } - - constraints - } - - /// The [`HashBaseTableColumn`] for the round constant corresponding to the given index. - /// Valid indices are 0 through 15, corresponding to the 16 round constants - /// [`Constant0`] through [`Constant15`]. - fn round_constant_column_by_index(index: usize) -> HashBaseTableColumn { - match index { - 0 => Constant0, - 1 => Constant1, - 2 => Constant2, - 3 => Constant3, - 4 => Constant4, - 5 => Constant5, - 6 => Constant6, - 7 => Constant7, - 8 => Constant8, - 9 => Constant9, - 10 => Constant10, - 11 => Constant11, - 12 => Constant12, - 13 => Constant13, - 14 => Constant14, - 15 => Constant15, - _ => panic!("invalid constant column index"), - } - } - - /// The [`HashBaseTableColumn`] for the state corresponding to the given index. - /// Valid indices are 4 through 15, corresponding to the 12 state columns - /// [`State4`] through [`State15`]. - /// - /// States with indices 0 through 3 have to be assembled from the respective limbs; - /// see [`Self::re_compose_states_0_through_3_before_lookup`] - /// or [`Self::re_compose_16_bit_limbs`]. - fn state_column_by_index(index: usize) -> HashBaseTableColumn { - match index { - 4 => State4, - 5 => State5, - 6 => State6, - 7 => State7, - 8 => State8, - 9 => State9, - 10 => State10, - 11 => State11, - 12 => State12, - 13 => State13, - 14 => State14, - 15 => State15, - _ => panic!("invalid state column index"), - } - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); - let constant = |c: u64| circuit_builder.b_constant(c); - - let opcode_hash = opcode(Hash); - let opcode_sponge_init = opcode(SpongeInit); - let opcode_sponge_absorb = opcode(SpongeAbsorb); - let opcode_sponge_squeeze = opcode(SpongeSqueeze); - - let current_base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) - }; - let next_base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) - }; - let current_ext_row = |column_idx: HashExtTableColumn| { - circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) - }; - let next_ext_row = |column_idx: HashExtTableColumn| { - circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) - }; - - let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); - - let prepare_chunk_indeterminate = challenge(ProgramAttestationPrepareChunkIndeterminate); - let receive_chunk_indeterminate = challenge(ProgramAttestationSendChunkIndeterminate); - let compress_program_digest_indeterminate = challenge(CompressProgramDigestIndeterminate); - let expected_program_digest = challenge(CompressedProgramDigest); - let hash_input_eval_indeterminate = challenge(HashInputIndeterminate); - let hash_digest_eval_indeterminate = challenge(HashDigestIndeterminate); - let sponge_indeterminate = challenge(SpongeIndeterminate); - - let mode = current_base_row(Mode); - let ci = current_base_row(CI); - let round_number = current_base_row(RoundNumber); - let running_evaluation_receive_chunk = current_ext_row(ReceiveChunkRunningEvaluation); - let running_evaluation_hash_input = current_ext_row(HashInputRunningEvaluation); - let running_evaluation_hash_digest = current_ext_row(HashDigestRunningEvaluation); - let running_evaluation_sponge = current_ext_row(SpongeRunningEvaluation); - - let mode_next = next_base_row(Mode); - let ci_next = next_base_row(CI); - let round_number_next = next_base_row(RoundNumber); - let running_evaluation_receive_chunk_next = next_ext_row(ReceiveChunkRunningEvaluation); - let running_evaluation_hash_input_next = next_ext_row(HashInputRunningEvaluation); - let running_evaluation_hash_digest_next = next_ext_row(HashDigestRunningEvaluation); - let running_evaluation_sponge_next = next_ext_row(SpongeRunningEvaluation); - - let [state_0, state_1, state_2, state_3] = - Self::re_compose_states_0_through_3_before_lookup( - circuit_builder, - Self::indicate_column_index_in_current_base_row, - ); - - let state_current = [ - state_0, - state_1, - state_2, - state_3, - current_base_row(State4), - current_base_row(State5), - current_base_row(State6), - current_base_row(State7), - current_base_row(State8), - current_base_row(State9), - current_base_row(State10), - current_base_row(State11), - current_base_row(State12), - current_base_row(State13), - current_base_row(State14), - current_base_row(State15), - ]; - - let (state_next, hash_function_round_correctly_performs_update) = - Self::tip5_constraints_as_circuits(circuit_builder); - - let state_weights = [ - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - StackWeight5, - StackWeight6, - StackWeight7, - StackWeight8, - StackWeight9, - StackWeight10, - StackWeight11, - StackWeight12, - StackWeight13, - StackWeight14, - StackWeight15, - ] - .map(challenge); - - let round_number_is_not_num_rounds = - Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS); - - let round_number_is_0_through_4_or_round_number_next_is_0 = - round_number_is_not_num_rounds * round_number_next.clone(); - - let next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one = - Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) - * (ci.clone() - opcode_sponge_init.clone()) - * (round_number.clone() - constant(NUM_ROUNDS as u64)) - * (round_number_next.clone() - round_number.clone() - constant(1)); - - // compress the digest by computing the terminal of an evaluation argument - let compressed_digest = state_current[..Digest::LEN].iter().fold( - running_evaluation_initial.clone(), - |acc, digest_element| { - acc * compress_program_digest_indeterminate.clone() + digest_element.clone() - }, - ); - let if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) - * (compressed_digest - expected_program_digest); - - let if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) - * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Sponge) - * (ci_next.clone() - opcode_sponge_init.clone()); - - let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change = - (round_number.clone() - constant(NUM_ROUNDS as u64)) - * (ci.clone() - opcode_sponge_init.clone()) - * (ci_next.clone() - ci.clone()); - - let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change = - (round_number - constant(NUM_ROUNDS as u64)) - * (ci.clone() - opcode_sponge_init.clone()) - * (mode_next.clone() - mode.clone()); - - let if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Sponge) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); - - let if_mode_is_hash_then_mode_next_is_hash_or_pad = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); - - let if_mode_is_pad_then_mode_next_is_pad = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad); - - let difference_of_capacity_registers = state_current[RATE..] - .iter() - .zip_eq(state_next[RATE..].iter()) - .map(|(current, next)| next.clone() - current.clone()) - .collect_vec(); - let randomized_sum_of_capacity_differences = state_weights[RATE..] - .iter() - .zip_eq(difference_of_capacity_registers) - .map(|(weight, state_difference)| weight.clone() * state_difference) - .sum::>(); - - let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing = - Self::round_number_deselector(circuit_builder, &round_number_next, 0) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) - * (ci_next.clone() - opcode_sponge_init.clone()) - * randomized_sum_of_capacity_differences.clone(); - - let difference_of_state_registers = state_current - .iter() - .zip_eq(state_next.iter()) - .map(|(current, next)| next.clone() - current.clone()) - .collect_vec(); - let randomized_sum_of_state_differences = state_weights - .iter() - .zip_eq(difference_of_state_registers.iter()) - .map(|(weight, state_difference)| weight.clone() * state_difference.clone()) - .sum(); - let if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change = - Self::round_number_deselector(circuit_builder, &round_number_next, 0) - * Self::instruction_deselector(circuit_builder, &ci_next, SpongeSqueeze) - * randomized_sum_of_state_differences; - - // Evaluation Arguments - - // If (and only if) the row number in the next row is 0 and the mode in the next row is - // `hash`, update running evaluation “hash input.” - let running_evaluation_hash_input_remains = - running_evaluation_hash_input_next.clone() - running_evaluation_hash_input.clone(); - let tip5_input = state_next[..RATE].to_owned(); - let compressed_row_from_processor = tip5_input - .into_iter() - .zip_eq(state_weights[..RATE].iter()) - .map(|(state, weight)| weight.clone() * state) - .sum(); - - let running_evaluation_hash_input_updates = running_evaluation_hash_input_next - - hash_input_eval_indeterminate * running_evaluation_hash_input - - compressed_row_from_processor; - let running_evaluation_hash_input_is_updated_correctly = - Self::round_number_deselector(circuit_builder, &round_number_next, 0) - * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_input_updates - + round_number_next.clone() * running_evaluation_hash_input_remains.clone() - + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_input_remains; - - // If (and only if) the row number in the next row is NUM_ROUNDS and the current instruction - // in the next row corresponds to `hash`, update running evaluation “hash digest.” - let round_number_next_is_num_rounds = - round_number_next.clone() - constant(NUM_ROUNDS as u64); - let running_evaluation_hash_digest_remains = - running_evaluation_hash_digest_next.clone() - running_evaluation_hash_digest.clone(); - let hash_digest = state_next[..Digest::LEN].to_owned(); - let compressed_row_hash_digest = hash_digest - .into_iter() - .zip_eq(state_weights[..Digest::LEN].iter()) - .map(|(state, weight)| weight.clone() * state) - .sum(); - let running_evaluation_hash_digest_updates = running_evaluation_hash_digest_next - - hash_digest_eval_indeterminate * running_evaluation_hash_digest - - compressed_row_hash_digest; - let running_evaluation_hash_digest_is_updated_correctly = - Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS) - * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_digest_updates - + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone() - + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash) - * running_evaluation_hash_digest_remains; - - // The running evaluation for “Sponge” updates correctly. - let compressed_row_next = state_weights[..RATE] - .iter() - .zip_eq(state_next[..RATE].iter()) - .map(|(weight, st_next)| weight.clone() * st_next.clone()) - .sum(); - let running_evaluation_sponge_has_accumulated_ci = running_evaluation_sponge_next.clone() - - sponge_indeterminate * running_evaluation_sponge.clone() - - challenge(HashCIWeight) * ci_next.clone(); - let running_evaluation_sponge_has_accumulated_next_row = - running_evaluation_sponge_has_accumulated_ci.clone() - compressed_row_next; - let if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates = - Self::round_number_deselector(circuit_builder, &round_number_next, 0) - * (ci_next.clone() - opcode_hash) - * running_evaluation_sponge_has_accumulated_next_row; - - let running_evaluation_sponge_remains = - running_evaluation_sponge_next - running_evaluation_sponge; - let if_round_no_next_is_not_0_then_running_evaluation_sponge_remains = - round_number_next.clone() * running_evaluation_sponge_remains.clone(); - let if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains = (ci_next.clone() - - opcode_sponge_init) - * (ci_next.clone() - opcode_sponge_absorb) - * (ci_next - opcode_sponge_squeeze) - * running_evaluation_sponge_remains; - let running_evaluation_sponge_is_updated_correctly = - if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates - + if_round_no_next_is_not_0_then_running_evaluation_sponge_remains - + if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains; - - // program attestation: absorb RATE instructions if in the right mode on the right row - let compressed_chunk = state_next[..RATE] - .iter() - .fold(running_evaluation_initial, |acc, rate_element| { - acc * prepare_chunk_indeterminate.clone() + rate_element.clone() - }); - let receive_chunk_running_evaluation_absorbs_chunk_of_instructions = - running_evaluation_receive_chunk_next.clone() - - receive_chunk_indeterminate * running_evaluation_receive_chunk.clone() - - compressed_chunk; - let receive_chunk_running_evaluation_remains = - running_evaluation_receive_chunk_next - running_evaluation_receive_chunk; - let receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0 = - Self::round_number_deselector(circuit_builder, &round_number_next, 0) - * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::ProgramHashing) - * receive_chunk_running_evaluation_absorbs_chunk_of_instructions - + round_number_next * receive_chunk_running_evaluation_remains.clone() - + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing) - * receive_chunk_running_evaluation_remains; - - let constraints = vec![ - round_number_is_0_through_4_or_round_number_next_is_0, - next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one, - receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0, - if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest, - if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init, - if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change, - if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change, - if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad, - if_mode_is_hash_then_mode_next_is_hash_or_pad, - if_mode_is_pad_then_mode_next_is_pad, - capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing, - if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change, - running_evaluation_hash_input_is_updated_correctly, - running_evaluation_hash_digest_is_updated_correctly, - running_evaluation_sponge_is_updated_correctly, - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State0HighestLkIn, - State0HighestLkOut, - CascadeState0HighestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State0MidHighLkIn, - State0MidHighLkOut, - CascadeState0MidHighClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State0MidLowLkIn, - State0MidLowLkOut, - CascadeState0MidLowClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State0LowestLkIn, - State0LowestLkOut, - CascadeState0LowestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State1HighestLkIn, - State1HighestLkOut, - CascadeState1HighestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State1MidHighLkIn, - State1MidHighLkOut, - CascadeState1MidHighClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State1MidLowLkIn, - State1MidLowLkOut, - CascadeState1MidLowClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State1LowestLkIn, - State1LowestLkOut, - CascadeState1LowestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State2HighestLkIn, - State2HighestLkOut, - CascadeState2HighestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State2MidHighLkIn, - State2MidHighLkOut, - CascadeState2MidHighClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State2MidLowLkIn, - State2MidLowLkOut, - CascadeState2MidLowClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State2LowestLkIn, - State2LowestLkOut, - CascadeState2LowestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State3HighestLkIn, - State3HighestLkOut, - CascadeState3HighestClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State3MidHighLkIn, - State3MidHighLkOut, - CascadeState3MidHighClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State3MidLowLkIn, - State3MidLowLkOut, - CascadeState3MidLowClientLogDerivative, - ), - Self::cascade_log_derivative_update_circuit( - circuit_builder, - State3LowestLkIn, - State3LowestLkOut, - CascadeState3LowestClientLogDerivative, - ), - ]; - - [ - constraints, - hash_function_round_correctly_performs_update.to_vec(), - ] - .concat() - } - - fn indicate_column_index_in_base_row(column: HashBaseTableColumn) -> SingleRowIndicator { - BaseRow(column.master_base_table_index()) - } - - fn indicate_column_index_in_current_base_row(column: HashBaseTableColumn) -> DualRowIndicator { - CurrentBaseRow(column.master_base_table_index()) - } - - fn indicate_column_index_in_next_base_row(column: HashBaseTableColumn) -> DualRowIndicator { - NextBaseRow(column.master_base_table_index()) - } - - fn re_compose_states_0_through_3_before_lookup( - circuit_builder: &ConstraintCircuitBuilder, - base_row_to_input_indicator: fn(HashBaseTableColumn) -> II, - ) -> [ConstraintCircuitMonad; 4] { - let input = |input_indicator: II| circuit_builder.input(input_indicator); - let state_0 = Self::re_compose_16_bit_limbs( - circuit_builder, - input(base_row_to_input_indicator(State0HighestLkIn)), - input(base_row_to_input_indicator(State0MidHighLkIn)), - input(base_row_to_input_indicator(State0MidLowLkIn)), - input(base_row_to_input_indicator(State0LowestLkIn)), - ); - let state_1 = Self::re_compose_16_bit_limbs( - circuit_builder, - input(base_row_to_input_indicator(State1HighestLkIn)), - input(base_row_to_input_indicator(State1MidHighLkIn)), - input(base_row_to_input_indicator(State1MidLowLkIn)), - input(base_row_to_input_indicator(State1LowestLkIn)), - ); - let state_2 = Self::re_compose_16_bit_limbs( - circuit_builder, - input(base_row_to_input_indicator(State2HighestLkIn)), - input(base_row_to_input_indicator(State2MidHighLkIn)), - input(base_row_to_input_indicator(State2MidLowLkIn)), - input(base_row_to_input_indicator(State2LowestLkIn)), - ); - let state_3 = Self::re_compose_16_bit_limbs( - circuit_builder, - input(base_row_to_input_indicator(State3HighestLkIn)), - input(base_row_to_input_indicator(State3MidHighLkIn)), - input(base_row_to_input_indicator(State3MidLowLkIn)), - input(base_row_to_input_indicator(State3LowestLkIn)), - ); - [state_0, state_1, state_2, state_3] - } - - fn tip5_constraints_as_circuits( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ( - [ConstraintCircuitMonad; STATE_SIZE], - [ConstraintCircuitMonad; STATE_SIZE], - ) { - let constant = |c: u64| circuit_builder.b_constant(c); - let b_constant = |c| circuit_builder.b_constant(c); - let current_base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) - }; - let next_base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) - }; - - let state_0_after_lookup = Self::re_compose_16_bit_limbs( - circuit_builder, - current_base_row(State0HighestLkOut), - current_base_row(State0MidHighLkOut), - current_base_row(State0MidLowLkOut), - current_base_row(State0LowestLkOut), - ); - let state_1_after_lookup = Self::re_compose_16_bit_limbs( - circuit_builder, - current_base_row(State1HighestLkOut), - current_base_row(State1MidHighLkOut), - current_base_row(State1MidLowLkOut), - current_base_row(State1LowestLkOut), - ); - let state_2_after_lookup = Self::re_compose_16_bit_limbs( - circuit_builder, - current_base_row(State2HighestLkOut), - current_base_row(State2MidHighLkOut), - current_base_row(State2MidLowLkOut), - current_base_row(State2LowestLkOut), - ); - let state_3_after_lookup = Self::re_compose_16_bit_limbs( - circuit_builder, - current_base_row(State3HighestLkOut), - current_base_row(State3MidHighLkOut), - current_base_row(State3MidLowLkOut), - current_base_row(State3LowestLkOut), - ); - - let state_part_before_power_map: [_; STATE_SIZE - NUM_SPLIT_AND_LOOKUP] = [ - State4, State5, State6, State7, State8, State9, State10, State11, State12, State13, - State14, State15, - ] - .map(current_base_row); - - let state_part_after_power_map = { - let mut exponentiation_accumulator = state_part_before_power_map.clone(); - for _ in 1..POWER_MAP_EXPONENT { - for (i, state) in exponentiation_accumulator.iter_mut().enumerate() { - *state = state.clone() * state_part_before_power_map[i].clone(); - } - } - exponentiation_accumulator - }; - - let state_after_s_box_application = [ - state_0_after_lookup, - state_1_after_lookup, - state_2_after_lookup, - state_3_after_lookup, - state_part_after_power_map[0].clone(), - state_part_after_power_map[1].clone(), - state_part_after_power_map[2].clone(), - state_part_after_power_map[3].clone(), - state_part_after_power_map[4].clone(), - state_part_after_power_map[5].clone(), - state_part_after_power_map[6].clone(), - state_part_after_power_map[7].clone(), - state_part_after_power_map[8].clone(), - state_part_after_power_map[9].clone(), - state_part_after_power_map[10].clone(), - state_part_after_power_map[11].clone(), - ]; - - let mut state_after_matrix_multiplication = vec![constant(0); STATE_SIZE]; - for (row_idx, acc) in state_after_matrix_multiplication.iter_mut().enumerate() { - for (col_idx, state) in state_after_s_box_application.iter().enumerate() { - let matrix_entry = b_constant(HashTable::mds_matrix_entry(row_idx, col_idx)); - *acc = acc.clone() + matrix_entry * state.clone(); - } - } - - let round_constants: [_; STATE_SIZE] = [ - Constant0, Constant1, Constant2, Constant3, Constant4, Constant5, Constant6, Constant7, - Constant8, Constant9, Constant10, Constant11, Constant12, Constant13, Constant14, - Constant15, - ] - .map(current_base_row); - - let state_after_round_constant_addition = state_after_matrix_multiplication - .into_iter() - .zip_eq(round_constants) - .map(|(st, rndc)| st + rndc) - .collect_vec(); - - let [state_0_next, state_1_next, state_2_next, state_3_next] = - Self::re_compose_states_0_through_3_before_lookup( - circuit_builder, - Self::indicate_column_index_in_next_base_row, - ); - let state_next = [ - state_0_next, - state_1_next, - state_2_next, - state_3_next, - next_base_row(State4), - next_base_row(State5), - next_base_row(State6), - next_base_row(State7), - next_base_row(State8), - next_base_row(State9), - next_base_row(State10), - next_base_row(State11), - next_base_row(State12), - next_base_row(State13), - next_base_row(State14), - next_base_row(State15), - ]; - - let round_number_next = next_base_row(RoundNumber); - let hash_function_round_correctly_performs_update = state_after_round_constant_addition - .into_iter() - .zip_eq(state_next.clone()) - .map(|(state_element, state_element_next)| { - round_number_next.clone() * (state_element - state_element_next) - }) - .collect_vec() - .try_into() - .unwrap(); - - (state_next, hash_function_round_correctly_performs_update) - } - - fn cascade_log_derivative_update_circuit( - circuit_builder: &ConstraintCircuitBuilder, - look_in_column: HashBaseTableColumn, - look_out_column: HashBaseTableColumn, - cascade_log_derivative_column: HashExtTableColumn, - ) -> ConstraintCircuitMonad { - let challenge = |c| circuit_builder.challenge(c); - let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); - let constant = |c: u32| circuit_builder.b_constant(c); - let next_base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) - }; - let current_ext_row = |column_idx: HashExtTableColumn| { - circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) - }; - let next_ext_row = |column_idx: HashExtTableColumn| { - circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) - }; - - let cascade_indeterminate = challenge(HashCascadeLookupIndeterminate); - let look_in_weight = challenge(HashCascadeLookInWeight); - let look_out_weight = challenge(HashCascadeLookOutWeight); - - let ci_next = next_base_row(CI); - let mode_next = next_base_row(Mode); - let round_number_next = next_base_row(RoundNumber); - let cascade_log_derivative = current_ext_row(cascade_log_derivative_column); - let cascade_log_derivative_next = next_ext_row(cascade_log_derivative_column); - - let compressed_row = look_in_weight * next_base_row(look_in_column) - + look_out_weight * next_base_row(look_out_column); - - let cascade_log_derivative_remains = - cascade_log_derivative_next.clone() - cascade_log_derivative.clone(); - let cascade_log_derivative_updates = (cascade_log_derivative_next - cascade_log_derivative) - * (cascade_indeterminate - compressed_row) - - constant(1); - - let next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init = - Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad) - * (round_number_next.clone() - constant(NUM_ROUNDS as u32)) - * (ci_next.clone() - opcode(SpongeInit)); - let round_number_next_is_not_num_rounds = - Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS); - let current_instruction_next_is_not_sponge_init = - Self::instruction_deselector(circuit_builder, &ci_next, SpongeInit); - - next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init - * cascade_log_derivative_updates - + round_number_next_is_not_num_rounds * cascade_log_derivative_remains.clone() - + current_instruction_next_is_not_sponge_init * cascade_log_derivative_remains - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); - let constant = |c: u64| circuit_builder.b_constant(c); - let base_row = |column_idx: HashBaseTableColumn| { - circuit_builder.input(BaseRow(column_idx.master_base_table_index())) - }; - - let mode = base_row(Mode); - let round_number = base_row(RoundNumber); - - let compress_program_digest_indeterminate = challenge(CompressProgramDigestIndeterminate); - let expected_program_digest = challenge(CompressedProgramDigest); - - let max_round_number = constant(NUM_ROUNDS as u64); - - let [state_0, state_1, state_2, state_3] = - Self::re_compose_states_0_through_3_before_lookup( - circuit_builder, - Self::indicate_column_index_in_base_row, - ); - let state_4 = base_row(State4); - let program_digest = [state_0, state_1, state_2, state_3, state_4]; - let compressed_digest = program_digest.into_iter().fold( - circuit_builder.x_constant(EvalArg::default_initial()), - |acc, digest_element| { - acc * compress_program_digest_indeterminate.clone() + digest_element - }, - ); - let if_mode_is_program_hashing_then_current_digest_is_expected_program_digest = - Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing) - * (compressed_digest - expected_program_digest); - - let if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number = - Self::select_mode(circuit_builder, &mode, HashTableMode::Pad) - * (base_row(CI) - opcode(SpongeInit)) - * (round_number - max_round_number); - - vec![ - if_mode_is_program_hashing_then_current_digest_is_expected_program_digest, - if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number, - ] - } -} - -impl HashTable { - /// Get the MDS matrix's entry in row `row_idx` and column `col_idx`. - pub const fn mds_matrix_entry(row_idx: usize, col_idx: usize) -> BFieldElement { - assert!(row_idx < STATE_SIZE); - assert!(col_idx < STATE_SIZE); - let index_in_matrix_defining_column = (STATE_SIZE + row_idx - col_idx) % STATE_SIZE; - let mds_matrix_entry = MDS_MATRIX_FIRST_COLUMN[index_in_matrix_defining_column]; - BFieldElement::new(mds_matrix_entry as u64) - } - - /// The round constants for round `r` if it is a valid round number in the Tip5 permutation, - /// and the zero vector otherwise. - pub fn tip5_round_constants_by_round_number(r: usize) -> [BFieldElement; NUM_ROUND_CONSTANTS] { - if r >= NUM_ROUNDS { - return bfe_array![0; NUM_ROUND_CONSTANTS]; - } - - let range_start = NUM_ROUND_CONSTANTS * r; - let range_end = NUM_ROUND_CONSTANTS * (r + 1); - ROUND_CONSTANTS[range_start..range_end].try_into().unwrap() - } - - /// Return the 16-bit chunks of the “un-Montgomery'd” representation, in little-endian chunk - /// order. This (basically) translates to the application of `σ(R·x)` for input `x`, which - /// are the first two steps in Tip5's split-and-lookup S-Box. - /// `R` is the Montgomery modulus, _i.e._, `R = 2^64 mod p`. - /// `σ` as described in the paper decomposes the 64-bit input into 8-bit limbs, whereas - /// this method decomposes into 16-bit limbs for arithmetization reasons; the 16-bit limbs - /// are split into 8-bit limbs in the Cascade Table. - /// For a more in-depth explanation of all the necessary steps in the split-and-lookup S-Box, - /// see the [Tip5 paper](https://eprint.iacr.org/2023/107.pdf). - /// - /// Note: this is distinct from the seemingly similar [`raw_u16s`](BFieldElement::raw_u16s). - pub fn base_field_element_into_16_bit_limbs(x: BFieldElement) -> [u16; 4] { - let r_times_x = (MONTGOMERY_MODULUS * x).value(); - [0, 16, 32, 48].map(|shift| ((r_times_x >> shift) & 0xffff) as u16) - } - - /// Convert a permutation trace to a segment in the Hash Table. - /// - /// **Note**: The current instruction [`CI`] is _not_ set. - pub fn trace_to_table_rows(trace: PermutationTrace) -> Array2 { - let mut table_rows = Array2::default([0, BASE_WIDTH]); - for (round_number, &trace_row) in trace.iter().enumerate() { - let table_row = Self::trace_row_to_table_row(trace_row, round_number); - table_rows.push_row(table_row.view()).unwrap(); - } - table_rows - } - - pub fn trace_row_to_table_row( - trace_row: [BFieldElement; STATE_SIZE], - round_number: usize, - ) -> Array1 { - let row = Array1::zeros([BASE_WIDTH]); - let row = Self::fill_row_with_round_number(row, round_number); - let row = Self::fill_row_with_split_state_elements_using_trace_row(row, trace_row); - let row = Self::fill_row_with_unsplit_state_elements_using_trace_row(row, trace_row); - let row = Self::fill_row_with_state_inverses_using_trace_row(row, trace_row); - Self::fill_row_with_round_constants_for_round(row, round_number) - } - - fn fill_row_with_round_number( - mut row: Array1, - round_number: usize, - ) -> Array1 { - row[RoundNumber.base_table_index()] = bfe!(round_number as u64); - row - } - - fn fill_row_with_split_state_elements_using_trace_row( - row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - let row = Self::fill_split_state_element_0_of_row_using_trace_row(row, trace_row); - let row = Self::fill_split_state_element_1_of_row_using_trace_row(row, trace_row); - let row = Self::fill_split_state_element_2_of_row_using_trace_row(row, trace_row); - Self::fill_split_state_element_3_of_row_using_trace_row(row, trace_row) - } - - fn fill_split_state_element_0_of_row_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - let limbs = Self::base_field_element_into_16_bit_limbs(trace_row[0]); - let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State0LowestLkIn.base_table_index()] = look_in_split[0]; - row[State0MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State0MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State0HighestLkIn.base_table_index()] = look_in_split[3]; - - let look_out_split = limbs.map(CascadeTable::lookup_16_bit_limb); - row[State0LowestLkOut.base_table_index()] = look_out_split[0]; - row[State0MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State0MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State0HighestLkOut.base_table_index()] = look_out_split[3]; - - row - } - - fn fill_split_state_element_1_of_row_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - let limbs = Self::base_field_element_into_16_bit_limbs(trace_row[1]); - let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State1LowestLkIn.base_table_index()] = look_in_split[0]; - row[State1MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State1MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State1HighestLkIn.base_table_index()] = look_in_split[3]; - - let look_out_split = limbs.map(CascadeTable::lookup_16_bit_limb); - row[State1LowestLkOut.base_table_index()] = look_out_split[0]; - row[State1MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State1MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State1HighestLkOut.base_table_index()] = look_out_split[3]; - - row - } - - fn fill_split_state_element_2_of_row_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - let limbs = Self::base_field_element_into_16_bit_limbs(trace_row[2]); - let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State2LowestLkIn.base_table_index()] = look_in_split[0]; - row[State2MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State2MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State2HighestLkIn.base_table_index()] = look_in_split[3]; - - let look_out_split = limbs.map(CascadeTable::lookup_16_bit_limb); - row[State2LowestLkOut.base_table_index()] = look_out_split[0]; - row[State2MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State2MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State2HighestLkOut.base_table_index()] = look_out_split[3]; - - row - } - - fn fill_split_state_element_3_of_row_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - let limbs = Self::base_field_element_into_16_bit_limbs(trace_row[3]); - let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State3LowestLkIn.base_table_index()] = look_in_split[0]; - row[State3MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State3MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State3HighestLkIn.base_table_index()] = look_in_split[3]; - - let look_out_split = limbs.map(CascadeTable::lookup_16_bit_limb); - row[State3LowestLkOut.base_table_index()] = look_out_split[0]; - row[State3MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State3MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State3HighestLkOut.base_table_index()] = look_out_split[3]; - - row - } - - fn fill_row_with_unsplit_state_elements_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - row[State4.base_table_index()] = trace_row[4]; - row[State5.base_table_index()] = trace_row[5]; - row[State6.base_table_index()] = trace_row[6]; - row[State7.base_table_index()] = trace_row[7]; - row[State8.base_table_index()] = trace_row[8]; - row[State9.base_table_index()] = trace_row[9]; - row[State10.base_table_index()] = trace_row[10]; - row[State11.base_table_index()] = trace_row[11]; - row[State12.base_table_index()] = trace_row[12]; - row[State13.base_table_index()] = trace_row[13]; - row[State14.base_table_index()] = trace_row[14]; - row[State15.base_table_index()] = trace_row[15]; - row - } - - fn fill_row_with_state_inverses_using_trace_row( - mut row: Array1, - trace_row: [BFieldElement; STATE_SIZE], - ) -> Array1 { - row[State0Inv.base_table_index()] = Self::inverse_or_zero_of_highest_2_limbs(trace_row[0]); - row[State1Inv.base_table_index()] = Self::inverse_or_zero_of_highest_2_limbs(trace_row[1]); - row[State2Inv.base_table_index()] = Self::inverse_or_zero_of_highest_2_limbs(trace_row[2]); - row[State3Inv.base_table_index()] = Self::inverse_or_zero_of_highest_2_limbs(trace_row[3]); - row - } - - /// The inverse-or-zero of (2^32 - 1 - 2^16·`highest` - `mid_high`) where `highest` - /// is the most significant limb of the given `state_element`, and `mid_high` the second-most - /// significant limb. - fn inverse_or_zero_of_highest_2_limbs(state_element: BFieldElement) -> BFieldElement { - let limbs = Self::base_field_element_into_16_bit_limbs(state_element); - let highest: u64 = limbs[3].into(); - let mid_high: u64 = limbs[2].into(); - let high_limbs = bfe!((highest << 16) + mid_high); - let two_pow_32_minus_1 = bfe!((1_u64 << 32) - 1); - let to_invert = two_pow_32_minus_1 - high_limbs; - to_invert.inverse_or_zero() - } - - fn fill_row_with_round_constants_for_round( - mut row: Array1, - round_number: usize, - ) -> Array1 { - let round_constants = Self::tip5_round_constants_by_round_number(round_number); - row[Constant0.base_table_index()] = round_constants[0]; - row[Constant1.base_table_index()] = round_constants[1]; - row[Constant2.base_table_index()] = round_constants[2]; - row[Constant3.base_table_index()] = round_constants[3]; - row[Constant4.base_table_index()] = round_constants[4]; - row[Constant5.base_table_index()] = round_constants[5]; - row[Constant6.base_table_index()] = round_constants[6]; - row[Constant7.base_table_index()] = round_constants[7]; - row[Constant8.base_table_index()] = round_constants[8]; - row[Constant9.base_table_index()] = round_constants[9]; - row[Constant10.base_table_index()] = round_constants[10]; - row[Constant11.base_table_index()] = round_constants[11]; - row[Constant12.base_table_index()] = round_constants[12]; - row[Constant13.base_table_index()] = round_constants[13]; - row[Constant14.base_table_index()] = round_constants[14]; - row[Constant15.base_table_index()] = round_constants[15]; - row - } - - pub fn fill_trace( - hash_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) { - let program_hash_part_start = 0; - let program_hash_part_end = program_hash_part_start + aet.program_hash_trace.nrows(); - let sponge_part_start = program_hash_part_end; - let sponge_part_end = sponge_part_start + aet.sponge_trace.nrows(); - let hash_part_start = sponge_part_end; - let hash_part_end = hash_part_start + aet.hash_trace.nrows(); - - let (mut program_hash_part, mut sponge_part, mut hash_part) = hash_table.multi_slice_mut(( - s![program_hash_part_start..program_hash_part_end, ..], - s![sponge_part_start..sponge_part_end, ..], - s![hash_part_start..hash_part_end, ..], - )); - - program_hash_part.assign(&aet.program_hash_trace); - sponge_part.assign(&aet.sponge_trace); - hash_part.assign(&aet.hash_trace); - - let mode_column_idx = Mode.base_table_index(); - let mut program_hash_mode_column = program_hash_part.column_mut(mode_column_idx); - let mut sponge_mode_column = sponge_part.column_mut(mode_column_idx); - let mut hash_mode_column = hash_part.column_mut(mode_column_idx); - - program_hash_mode_column.fill(HashTableMode::ProgramHashing.into()); - sponge_mode_column.fill(HashTableMode::Sponge.into()); - hash_mode_column.fill(HashTableMode::Hash.into()); - } - - pub fn pad_trace(mut hash_table: ArrayViewMut2, hash_table_length: usize) { - let inverse_of_high_limbs = Self::inverse_or_zero_of_highest_2_limbs(bfe!(0)); - for column_id in [State0Inv, State1Inv, State2Inv, State3Inv] { - let column_index = column_id.base_table_index(); - let slice_info = s![hash_table_length.., column_index]; - let mut column = hash_table.slice_mut(slice_info); - column.fill(inverse_of_high_limbs); - } - - let round_constants = Self::tip5_round_constants_by_round_number(0); - for (round_constant_idx, &round_constant) in round_constants.iter().enumerate() { - let round_constant_column = - ExtHashTable::round_constant_column_by_index(round_constant_idx); - let round_constant_column_idx = round_constant_column.base_table_index(); - let slice_info = s![hash_table_length.., round_constant_column_idx]; - let mut column = hash_table.slice_mut(slice_info); - column.fill(round_constant); - } - - let mode_column_index = Mode.base_table_index(); - let mode_column_slice_info = s![hash_table_length.., mode_column_index]; - let mut mode_column = hash_table.slice_mut(mode_column_slice_info); - mode_column.fill(HashTableMode::Pad.into()); - - let instruction_column_index = CI.base_table_index(); - let instruction_column_slice_info = s![hash_table_length.., instruction_column_index]; - let mut instruction_column = hash_table.slice_mut(instruction_column_slice_info); - instruction_column.fill(Instruction::Hash.opcode_b()); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "hash table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let ci_weight = challenges[HashCIWeight]; - let hash_digest_eval_indeterminate = challenges[HashDigestIndeterminate]; - let hash_input_eval_indeterminate = challenges[HashInputIndeterminate]; - let sponge_eval_indeterminate = challenges[SpongeIndeterminate]; - let cascade_indeterminate = challenges[HashCascadeLookupIndeterminate]; - let send_chunk_indeterminate = challenges[ProgramAttestationSendChunkIndeterminate]; - - let mut hash_input_running_evaluation = EvalArg::default_initial(); - let mut hash_digest_running_evaluation = EvalArg::default_initial(); - let mut sponge_running_evaluation = EvalArg::default_initial(); - let mut cascade_state_0_highest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_0_mid_high_log_derivative = LookupArg::default_initial(); - let mut cascade_state_0_mid_low_log_derivative = LookupArg::default_initial(); - let mut cascade_state_0_lowest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_1_highest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_1_mid_high_log_derivative = LookupArg::default_initial(); - let mut cascade_state_1_mid_low_log_derivative = LookupArg::default_initial(); - let mut cascade_state_1_lowest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_2_highest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_2_mid_high_log_derivative = LookupArg::default_initial(); - let mut cascade_state_2_mid_low_log_derivative = LookupArg::default_initial(); - let mut cascade_state_2_lowest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_3_highest_log_derivative = LookupArg::default_initial(); - let mut cascade_state_3_mid_high_log_derivative = LookupArg::default_initial(); - let mut cascade_state_3_mid_low_log_derivative = LookupArg::default_initial(); - let mut cascade_state_3_lowest_log_derivative = LookupArg::default_initial(); - let mut receive_chunk_running_evaluation = EvalArg::default_initial(); - - let two_pow_16 = bfe!(1_u64 << 16); - let two_pow_32 = bfe!(1_u64 << 32); - let two_pow_48 = bfe!(1_u64 << 48); - - let montgomery_modulus_inverse = MONTGOMERY_MODULUS.inverse(); - let re_compose_state_element = - |row: ArrayView1, - highest: HashBaseTableColumn, - mid_high: HashBaseTableColumn, - mid_low: HashBaseTableColumn, - lowest: HashBaseTableColumn| { - (row[highest.base_table_index()] * two_pow_48 - + row[mid_high.base_table_index()] * two_pow_32 - + row[mid_low.base_table_index()] * two_pow_16 - + row[lowest.base_table_index()]) - * montgomery_modulus_inverse - }; - - let rate_registers = |row: ArrayView1| { - let state_0 = re_compose_state_element( - row, - State0HighestLkIn, - State0MidHighLkIn, - State0MidLowLkIn, - State0LowestLkIn, - ); - let state_1 = re_compose_state_element( - row, - State1HighestLkIn, - State1MidHighLkIn, - State1MidLowLkIn, - State1LowestLkIn, - ); - let state_2 = re_compose_state_element( - row, - State2HighestLkIn, - State2MidHighLkIn, - State2MidLowLkIn, - State2LowestLkIn, - ); - let state_3 = re_compose_state_element( - row, - State3HighestLkIn, - State3MidHighLkIn, - State3MidLowLkIn, - State3LowestLkIn, - ); - [ - state_0, - state_1, - state_2, - state_3, - row[State4.base_table_index()], - row[State5.base_table_index()], - row[State6.base_table_index()], - row[State7.base_table_index()], - row[State8.base_table_index()], - row[State9.base_table_index()], - ] - }; - - let state_weights = &challenges[StackWeight0..StackWeight10]; - let compressed_row = |row: ArrayView1| -> XFieldElement { - rate_registers(row) - .iter() - .zip_eq(state_weights.iter()) - .map(|(&state, &weight)| weight * state) - .sum() - }; - - let cascade_look_in_weight = challenges[HashCascadeLookInWeight]; - let cascade_look_out_weight = challenges[HashCascadeLookOutWeight]; - - let log_derivative_summand = - |row: ArrayView1, - lk_in_col: HashBaseTableColumn, - lk_out_col: HashBaseTableColumn| { - let compressed_elements = cascade_indeterminate - - cascade_look_in_weight * row[lk_in_col.base_table_index()] - - cascade_look_out_weight * row[lk_out_col.base_table_index()]; - compressed_elements.inverse() - }; - - for row_idx in 0..base_table.nrows() { - let row = base_table.row(row_idx); - - let mode = row[Mode.base_table_index()]; - let in_program_hashing_mode = mode == HashTableMode::ProgramHashing.into(); - let in_sponge_mode = mode == HashTableMode::Sponge.into(); - let in_hash_mode = mode == HashTableMode::Hash.into(); - let in_pad_mode = mode == HashTableMode::Pad.into(); - - let round_number = row[RoundNumber.base_table_index()]; - let in_round_0 = round_number.is_zero(); - let in_last_round = round_number == (NUM_ROUNDS as u64).into(); - - let current_instruction = row[CI.base_table_index()]; - let current_instruction_is_sponge_init = - current_instruction == Instruction::SpongeInit.opcode_b(); - - if in_program_hashing_mode && in_round_0 { - let compressed_chunk_of_instructions = EvalArg::compute_terminal( - &rate_registers(row), - EvalArg::default_initial(), - challenges[ProgramAttestationPrepareChunkIndeterminate], - ); - receive_chunk_running_evaluation = receive_chunk_running_evaluation - * send_chunk_indeterminate - + compressed_chunk_of_instructions - } - - if in_sponge_mode && in_round_0 && current_instruction_is_sponge_init { - sponge_running_evaluation = sponge_running_evaluation * sponge_eval_indeterminate - + ci_weight * current_instruction - } - - if in_sponge_mode && in_round_0 && !current_instruction_is_sponge_init { - sponge_running_evaluation = sponge_running_evaluation * sponge_eval_indeterminate - + ci_weight * current_instruction - + compressed_row(row) - } - - if in_hash_mode && in_round_0 { - hash_input_running_evaluation = hash_input_running_evaluation - * hash_input_eval_indeterminate - + compressed_row(row) - } - - if in_hash_mode && in_last_round { - let compressed_digest: XFieldElement = rate_registers(row)[..Digest::LEN] - .iter() - .zip_eq(state_weights[..Digest::LEN].iter()) - .map(|(&state, &weight)| weight * state) - .sum(); - hash_digest_running_evaluation = hash_digest_running_evaluation - * hash_digest_eval_indeterminate - + compressed_digest - } - - if !in_pad_mode && !in_last_round && !current_instruction_is_sponge_init { - cascade_state_0_highest_log_derivative += - log_derivative_summand(row, State0HighestLkIn, State0HighestLkOut); - cascade_state_0_mid_high_log_derivative += - log_derivative_summand(row, State0MidHighLkIn, State0MidHighLkOut); - cascade_state_0_mid_low_log_derivative += - log_derivative_summand(row, State0MidLowLkIn, State0MidLowLkOut); - cascade_state_0_lowest_log_derivative += - log_derivative_summand(row, State0LowestLkIn, State0LowestLkOut); - cascade_state_1_highest_log_derivative += - log_derivative_summand(row, State1HighestLkIn, State1HighestLkOut); - cascade_state_1_mid_high_log_derivative += - log_derivative_summand(row, State1MidHighLkIn, State1MidHighLkOut); - cascade_state_1_mid_low_log_derivative += - log_derivative_summand(row, State1MidLowLkIn, State1MidLowLkOut); - cascade_state_1_lowest_log_derivative += - log_derivative_summand(row, State1LowestLkIn, State1LowestLkOut); - cascade_state_2_highest_log_derivative += - log_derivative_summand(row, State2HighestLkIn, State2HighestLkOut); - cascade_state_2_mid_high_log_derivative += - log_derivative_summand(row, State2MidHighLkIn, State2MidHighLkOut); - cascade_state_2_mid_low_log_derivative += - log_derivative_summand(row, State2MidLowLkIn, State2MidLowLkOut); - cascade_state_2_lowest_log_derivative += - log_derivative_summand(row, State2LowestLkIn, State2LowestLkOut); - cascade_state_3_highest_log_derivative += - log_derivative_summand(row, State3HighestLkIn, State3HighestLkOut); - cascade_state_3_mid_high_log_derivative += - log_derivative_summand(row, State3MidHighLkIn, State3MidHighLkOut); - cascade_state_3_mid_low_log_derivative += - log_derivative_summand(row, State3MidLowLkIn, State3MidLowLkOut); - cascade_state_3_lowest_log_derivative += - log_derivative_summand(row, State3LowestLkIn, State3LowestLkOut); - } - - let mut extension_row = ext_table.row_mut(row_idx); - extension_row[ReceiveChunkRunningEvaluation.ext_table_index()] = - receive_chunk_running_evaluation; - extension_row[HashInputRunningEvaluation.ext_table_index()] = - hash_input_running_evaluation; - extension_row[HashDigestRunningEvaluation.ext_table_index()] = - hash_digest_running_evaluation; - extension_row[SpongeRunningEvaluation.ext_table_index()] = sponge_running_evaluation; - extension_row[CascadeState0HighestClientLogDerivative.ext_table_index()] = - cascade_state_0_highest_log_derivative; - extension_row[CascadeState0MidHighClientLogDerivative.ext_table_index()] = - cascade_state_0_mid_high_log_derivative; - extension_row[CascadeState0MidLowClientLogDerivative.ext_table_index()] = - cascade_state_0_mid_low_log_derivative; - extension_row[CascadeState0LowestClientLogDerivative.ext_table_index()] = - cascade_state_0_lowest_log_derivative; - extension_row[CascadeState1HighestClientLogDerivative.ext_table_index()] = - cascade_state_1_highest_log_derivative; - extension_row[CascadeState1MidHighClientLogDerivative.ext_table_index()] = - cascade_state_1_mid_high_log_derivative; - extension_row[CascadeState1MidLowClientLogDerivative.ext_table_index()] = - cascade_state_1_mid_low_log_derivative; - extension_row[CascadeState1LowestClientLogDerivative.ext_table_index()] = - cascade_state_1_lowest_log_derivative; - extension_row[CascadeState2HighestClientLogDerivative.ext_table_index()] = - cascade_state_2_highest_log_derivative; - extension_row[CascadeState2MidHighClientLogDerivative.ext_table_index()] = - cascade_state_2_mid_high_log_derivative; - extension_row[CascadeState2MidLowClientLogDerivative.ext_table_index()] = - cascade_state_2_mid_low_log_derivative; - extension_row[CascadeState2LowestClientLogDerivative.ext_table_index()] = - cascade_state_2_lowest_log_derivative; - extension_row[CascadeState3HighestClientLogDerivative.ext_table_index()] = - cascade_state_3_highest_log_derivative; - extension_row[CascadeState3MidHighClientLogDerivative.ext_table_index()] = - cascade_state_3_mid_high_log_derivative; - extension_row[CascadeState3MidLowClientLogDerivative.ext_table_index()] = - cascade_state_3_mid_low_log_derivative; - extension_row[CascadeState3LowestClientLogDerivative.ext_table_index()] = - cascade_state_3_lowest_log_derivative; - } - profiler!(stop "hash table"); - } -} - -#[cfg(test)] -pub(crate) mod tests { - use std::collections::HashMap; - - use crate::shared_tests::ProgramAndInput; - use crate::stark::tests::master_tables_for_low_security_level; - use crate::table::master_table::MasterTable; - use crate::table::master_table::TableId; - use crate::triton_asm; - use crate::triton_program; - use crate::vm::VM; - - use super::*; - - #[test] - fn hash_table_mode_discriminant_is_unique() { - let mut discriminants_and_modes = HashMap::new(); - for mode in HashTableMode::iter() { - let discriminant = u32::from(mode); - let maybe_entry = discriminants_and_modes.insert(discriminant, mode); - if let Some(entry) = maybe_entry { - panic!("Discriminant collision for {discriminant} between {entry} and {mode}."); - } - } - } - - #[test] - fn terminal_constraints_hold_for_sponge_init_edge_case() { - let many_sponge_inits = triton_asm![sponge_init; 23_631]; - let many_squeeze_absorbs = (0..2_100) - .flat_map(|_| triton_asm!(sponge_squeeze sponge_absorb)) - .collect_vec(); - let program = triton_program! { - {&many_sponge_inits} - {&many_squeeze_absorbs} - sponge_init - halt - }; - - let (aet, _) = VM::trace_execution(&program, [].into(), [].into()).unwrap(); - dbg!(aet.height()); - dbg!(aet.padded_height()); - dbg!(aet.height_of_table(TableId::Hash)); - dbg!(aet.height_of_table(TableId::OpStack)); - dbg!(aet.height_of_table(TableId::Cascade)); - - let (_, _, master_base_table, master_ext_table, challenges) = - master_tables_for_low_security_level(ProgramAndInput::new(program)); - let challenges = &challenges.challenges; - - let master_base_trace_table = master_base_table.trace_table(); - let master_ext_trace_table = master_ext_table.trace_table(); - - let last_row = master_base_trace_table.slice(s![-1.., ..]); - let last_opcode = last_row[[0, HashBaseTableColumn::CI.master_base_table_index()]]; - let last_instruction: Instruction = last_opcode.value().try_into().unwrap(); - assert_eq!(Instruction::SpongeInit, last_instruction); - - let circuit_builder = ConstraintCircuitBuilder::new(); - for (constraint_idx, constraint) in ExtHashTable::terminal_constraints(&circuit_builder) - .into_iter() - .map(|constraint_monad| constraint_monad.consume()) - .enumerate() - { - let evaluated_constraint = constraint.evaluate( - master_base_trace_table.slice(s![-1.., ..]), - master_ext_trace_table.slice(s![-1.., ..]), - challenges, - ); - assert_eq!( - xfe!(0), - evaluated_constraint, - "Terminal constraint {constraint_idx} failed." - ); - } - } -} diff --git a/triton-vm/src/table/jump_stack.rs b/triton-vm/src/table/jump_stack.rs new file mode 100644 index 000000000..43afe3da5 --- /dev/null +++ b/triton-vm/src/table/jump_stack.rs @@ -0,0 +1,229 @@ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::ops::Range; + +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::*; +use air::table::jump_stack::JumpStackTable; +use air::table_column::JumpStackBaseTableColumn::*; +use air::table_column::JumpStackExtTableColumn::*; +use air::table_column::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; +use isa::instruction::Instruction; +use itertools::Itertools; +use ndarray::parallel::prelude::*; +use ndarray::prelude::*; +use strum::EnumCount; +use strum::IntoEnumIterator; +use twenty_first::math::traits::FiniteField; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::ndarray_helper::contiguous_column_slices; +use crate::ndarray_helper::horizontal_multi_slice_mut; +use crate::profiler::profiler; +use crate::table::TraceTable; + +fn extension_column_running_product_permutation_argument( + main_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut running_product = PermArg::default_initial(); + let mut extension_column = Vec::with_capacity(main_table.nrows()); + for row in main_table.rows() { + let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] + + row[CI.base_table_index()] * challenges[JumpStackCiWeight] + + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] + + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] + + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; + running_product *= challenges[JumpStackIndeterminate] - compressed_row; + extension_column.push(running_product); + } + Array2::from_shape_vec((main_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_clock_jump_diff_lookup_log_derivative( + main_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + // - use memoization to avoid recomputing inverses + // - precompute common values through batch inversion + const PRECOMPUTE_INVERSES_OF: Range = 0..100; + let indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; + let to_invert = PRECOMPUTE_INVERSES_OF + .map(|i| indeterminate - bfe!(i)) + .collect(); + let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF + .zip_eq(XFieldElement::batch_inversion(to_invert)) + .map(|(i, inv)| (bfe!(i), inv)) + .collect::>(); + + // populate extension column using memoization + let mut cjd_lookup_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(main_table.nrows()); + extension_column.push(cjd_lookup_log_derivative); + for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() { + if previous_row[JSP.base_table_index()] == current_row[JSP.base_table_index()] { + let previous_clock = previous_row[CLK.base_table_index()]; + let current_clock = current_row[CLK.base_table_index()]; + let clock_jump_difference = current_clock - previous_clock; + let &mut inverse = inverses_dictionary + .entry(clock_jump_difference) + .or_insert_with(|| (indeterminate - clock_jump_difference).inverse()); + cjd_lookup_log_derivative += inverse; + } + extension_column.push(cjd_lookup_log_derivative); + } + Array2::from_shape_vec((main_table.nrows(), 1), extension_column).unwrap() +} + +impl TraceTable for JumpStackTable { + type FillParam = (); + type FillReturnInfo = Vec; + + fn fill( + mut jump_stack_table: ArrayViewMut2, + aet: &AlgebraicExecutionTrace, + _: Self::FillParam, + ) -> Self::FillReturnInfo { + // Store the registers relevant for the Jump Stack Table, i.e., CLK, CI, JSP, JSO, JSD, + // with JSP as the key. Preserves, thus allows reusing, the order of the processor's + // rows, which are sorted by CLK. + let mut pre_processed_jump_stack_table: Vec> = vec![]; + for processor_row in aet.processor_trace.rows() { + let clk = processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; + let ci = processor_row[ProcessorBaseTableColumn::CI.base_table_index()]; + let jsp = processor_row[ProcessorBaseTableColumn::JSP.base_table_index()]; + let jso = processor_row[ProcessorBaseTableColumn::JSO.base_table_index()]; + let jsd = processor_row[ProcessorBaseTableColumn::JSD.base_table_index()]; + // The (honest) prover can only grow the Jump Stack's size by at most 1 per execution + // step. Hence, the following (a) works, and (b) sorts. + let jsp_val = jsp.value() as usize; + let jump_stack_row = (clk, ci, jso, jsd); + match jsp_val.cmp(&pre_processed_jump_stack_table.len()) { + Ordering::Less => pre_processed_jump_stack_table[jsp_val].push(jump_stack_row), + Ordering::Equal => pre_processed_jump_stack_table.push(vec![jump_stack_row]), + Ordering::Greater => panic!("JSP must increase by at most 1 per execution step."), + } + } + + // Move the rows into the Jump Stack Table, sorted by JSP first, CLK second. + let mut jump_stack_table_row = 0; + for (jsp_val, rows_with_this_jsp) in pre_processed_jump_stack_table.into_iter().enumerate() + { + let jsp = bfe!(jsp_val as u64); + for (clk, ci, jso, jsd) in rows_with_this_jsp { + jump_stack_table[(jump_stack_table_row, CLK.base_table_index())] = clk; + jump_stack_table[(jump_stack_table_row, CI.base_table_index())] = ci; + jump_stack_table[(jump_stack_table_row, JSP.base_table_index())] = jsp; + jump_stack_table[(jump_stack_table_row, JSO.base_table_index())] = jso; + jump_stack_table[(jump_stack_table_row, JSD.base_table_index())] = jsd; + jump_stack_table_row += 1; + } + } + assert_eq!(aet.processor_trace.nrows(), jump_stack_table_row); + + // Collect all clock jump differences. + // The Jump Stack Table and the Processor Table have the same length. + let mut clock_jump_differences = vec![]; + for row_idx in 0..aet.processor_trace.nrows() - 1 { + let curr_row = jump_stack_table.row(row_idx); + let next_row = jump_stack_table.row(row_idx + 1); + let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; + if curr_row[JSP.base_table_index()] == next_row[JSP.base_table_index()] { + clock_jump_differences.push(clk_diff); + } + } + clock_jump_differences + } + + fn pad(mut jump_stack_table: ArrayViewMut2, table_len: usize) { + assert!(table_len > 0, "Processor Table must have at least 1 row."); + + // Set up indices for relevant sections of the table. + let padded_height = jump_stack_table.nrows(); + let num_padding_rows = padded_height - table_len; + let max_clk_before_padding = table_len - 1; + let max_clk_before_padding_row_idx = jump_stack_table + .rows() + .into_iter() + .enumerate() + .find(|(_, row)| row[CLK.base_table_index()].value() as usize == max_clk_before_padding) + .map(|(idx, _)| idx) + .expect("Jump Stack Table must contain row with clock cycle equal to max cycle."); + let rows_to_move_source_section_start = max_clk_before_padding_row_idx + 1; + let rows_to_move_source_section_end = table_len; + let num_rows_to_move = rows_to_move_source_section_end - rows_to_move_source_section_start; + let rows_to_move_dest_section_start = rows_to_move_source_section_start + num_padding_rows; + let rows_to_move_dest_section_end = rows_to_move_dest_section_start + num_rows_to_move; + let padding_section_start = rows_to_move_source_section_start; + let padding_section_end = padding_section_start + num_padding_rows; + assert_eq!(padded_height, rows_to_move_dest_section_end); + + // Move all rows below the row with highest CLK to the end of the table – if they exist. + if num_rows_to_move > 0 { + let rows_to_move_source_range = + rows_to_move_source_section_start..rows_to_move_source_section_end; + let rows_to_move_dest_range = + rows_to_move_dest_section_start..rows_to_move_dest_section_end; + let rows_to_move = jump_stack_table + .slice(s![rows_to_move_source_range, ..]) + .to_owned(); + rows_to_move + .move_into(&mut jump_stack_table.slice_mut(s![rows_to_move_dest_range, ..])); + } + + // Fill the created gap with padding rows, i.e., with copies of the last row before the + // gap. This is the padding section. + let padding_row_template = jump_stack_table + .row(max_clk_before_padding_row_idx) + .to_owned(); + let mut padding_section = + jump_stack_table.slice_mut(s![padding_section_start..padding_section_end, ..]); + padding_section + .axis_iter_mut(Axis(0)) + .into_par_iter() + .for_each(|padding_row| padding_row_template.clone().move_into(padding_row)); + + // CLK keeps increasing by 1 also in the padding section. + let new_clk_values = + Array1::from_iter((table_len..padded_height).map(|clk| bfe!(clk as u64))); + new_clk_values.move_into(padding_section.slice_mut(s![.., CLK.base_table_index()])); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "jump stack table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + // use strum::IntoEnumIterator; + let extension_column_indices = JumpStackExtTableColumn::iter() + .map(|column| column.ext_table_index()) + .collect_vec(); + let extension_column_slices = horizontal_multi_slice_mut( + aux_table.view_mut(), + &contiguous_column_slices(&extension_column_indices), + ); + let extension_functions = [ + extension_column_running_product_permutation_argument, + extension_column_clock_jump_diff_lookup_log_derivative, + ]; + + extension_functions + .into_par_iter() + .zip_eq(extension_column_slices) + .for_each(|(generator, slice)| { + generator(main_table, challenges).move_into(slice); + }); + + profiler!(stop "jump stack table"); + } +} diff --git a/triton-vm/src/table/jump_stack_table.rs b/triton-vm/src/table/jump_stack_table.rs deleted file mode 100644 index ab3ee0914..000000000 --- a/triton-vm/src/table/jump_stack_table.rs +++ /dev/null @@ -1,366 +0,0 @@ -use std::cmp::Ordering; -use std::collections::HashMap; -use std::ops::Range; - -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use isa::instruction::Instruction; -use itertools::Itertools; -use ndarray::parallel::prelude::*; -use ndarray::prelude::*; -use strum::EnumCount; -use strum::IntoEnumIterator; -use twenty_first::math::traits::FiniteField; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::ndarray_helper::contiguous_column_slices; -use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::*; -use crate::table::table_column::JumpStackBaseTableColumn::*; -use crate::table::table_column::JumpStackExtTableColumn::*; -use crate::table::table_column::*; - -pub const BASE_WIDTH: usize = JumpStackBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = JumpStackExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct JumpStackTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtJumpStackTable; - -impl ExtJumpStackTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let clk = circuit_builder.input(BaseRow(CLK.master_base_table_index())); - let jsp = circuit_builder.input(BaseRow(JSP.master_base_table_index())); - let jso = circuit_builder.input(BaseRow(JSO.master_base_table_index())); - let jsd = circuit_builder.input(BaseRow(JSD.master_base_table_index())); - let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); - let rppa = circuit_builder.input(ExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative = circuit_builder.input(ExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let processor_perm_indeterminate = circuit_builder.challenge(JumpStackIndeterminate); - // note: `clk`, `jsp`, `jso`, and `jsd` are all constrained to be 0 and can thus be omitted. - let compressed_row = circuit_builder.challenge(JumpStackCiWeight) * ci; - let rppa_starts_correctly = rppa - (processor_perm_indeterminate - compressed_row); - - // A clock jump difference of 0 is not allowed. Hence, the initial is recorded. - let clock_jump_diff_log_derivative_starts_correctly = clock_jump_diff_log_derivative - - circuit_builder.x_constant(LookupArg::default_initial()); - - vec![ - clk, - jsp, - jso, - jsd, - rppa_starts_correctly, - clock_jump_diff_log_derivative_starts_correctly, - ] - } - - pub fn consistency_constraints( - _circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - // no further constraints - vec![] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let one = || circuit_builder.b_constant(1); - let call_opcode = - circuit_builder.b_constant(Instruction::Call(BFieldElement::default()).opcode_b()); - let return_opcode = circuit_builder.b_constant(Instruction::Return.opcode_b()); - let recurse_or_return_opcode = - circuit_builder.b_constant(Instruction::RecurseOrReturn.opcode_b()); - - let clk = circuit_builder.input(CurrentBaseRow(CLK.master_base_table_index())); - let ci = circuit_builder.input(CurrentBaseRow(CI.master_base_table_index())); - let jsp = circuit_builder.input(CurrentBaseRow(JSP.master_base_table_index())); - let jso = circuit_builder.input(CurrentBaseRow(JSO.master_base_table_index())); - let jsd = circuit_builder.input(CurrentBaseRow(JSD.master_base_table_index())); - let rppa = circuit_builder.input(CurrentExtRow( - RunningProductPermArg.master_ext_table_index(), - )); - let clock_jump_diff_log_derivative = circuit_builder.input(CurrentExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let clk_next = circuit_builder.input(NextBaseRow(CLK.master_base_table_index())); - let ci_next = circuit_builder.input(NextBaseRow(CI.master_base_table_index())); - let jsp_next = circuit_builder.input(NextBaseRow(JSP.master_base_table_index())); - let jso_next = circuit_builder.input(NextBaseRow(JSO.master_base_table_index())); - let jsd_next = circuit_builder.input(NextBaseRow(JSD.master_base_table_index())); - let rppa_next = - circuit_builder.input(NextExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative_next = circuit_builder.input(NextExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let jsp_inc_or_stays = - (jsp_next.clone() - jsp.clone() - one()) * (jsp_next.clone() - jsp.clone()); - - let jsp_inc_by_one_or_ci_can_return = (jsp_next.clone() - jsp.clone() - one()) - * (ci.clone() - return_opcode) - * (ci.clone() - recurse_or_return_opcode); - let jsp_inc_or_jso_stays_or_ci_can_ret = - jsp_inc_by_one_or_ci_can_return.clone() * (jso_next.clone() - jso); - - let jsp_inc_or_jsd_stays_or_ci_can_ret = - jsp_inc_by_one_or_ci_can_return.clone() * (jsd_next.clone() - jsd); - - let jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret = jsp_inc_by_one_or_ci_can_return - * (clk_next.clone() - clk.clone() - one()) - * (ci.clone() - call_opcode); - - let compressed_row = circuit_builder.challenge(JumpStackClkWeight) * clk_next.clone() - + circuit_builder.challenge(JumpStackCiWeight) * ci_next - + circuit_builder.challenge(JumpStackJspWeight) * jsp_next.clone() - + circuit_builder.challenge(JumpStackJsoWeight) * jso_next - + circuit_builder.challenge(JumpStackJsdWeight) * jsd_next; - let rppa_updates_correctly = - rppa_next - rppa * (circuit_builder.challenge(JumpStackIndeterminate) - compressed_row); - - let log_derivative_remains = - clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone(); - let clk_diff = clk_next - clk; - let log_derivative_accumulates = (clock_jump_diff_log_derivative_next - - clock_jump_diff_log_derivative) - * (circuit_builder.challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff) - - one(); - let log_derivative_updates_correctly = (jsp_next.clone() - jsp.clone() - one()) - * log_derivative_accumulates - + (jsp_next - jsp) * log_derivative_remains; - - vec![ - jsp_inc_or_stays, - jsp_inc_or_jso_stays_or_ci_can_ret, - jsp_inc_or_jsd_stays_or_ci_can_ret, - jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret, - rppa_updates_correctly, - log_derivative_updates_correctly, - ] - } - - pub fn terminal_constraints( - _circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - // no further constraints - vec![] - } -} - -impl JumpStackTable { - /// Fills the trace table in-place and returns all clock jump differences. - pub fn fill_trace( - jump_stack_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) -> Vec { - // Store the registers relevant for the Jump Stack Table, i.e., CLK, CI, JSP, JSO, JSD, - // with JSP as the key. Preserves, thus allows reusing, the order of the processor's - // rows, which are sorted by CLK. - let mut pre_processed_jump_stack_table: Vec> = vec![]; - for processor_row in aet.processor_trace.rows() { - let clk = processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; - let ci = processor_row[ProcessorBaseTableColumn::CI.base_table_index()]; - let jsp = processor_row[ProcessorBaseTableColumn::JSP.base_table_index()]; - let jso = processor_row[ProcessorBaseTableColumn::JSO.base_table_index()]; - let jsd = processor_row[ProcessorBaseTableColumn::JSD.base_table_index()]; - // The (honest) prover can only grow the Jump Stack's size by at most 1 per execution - // step. Hence, the following (a) works, and (b) sorts. - let jsp_val = jsp.value() as usize; - let jump_stack_row = (clk, ci, jso, jsd); - match jsp_val.cmp(&pre_processed_jump_stack_table.len()) { - Ordering::Less => pre_processed_jump_stack_table[jsp_val].push(jump_stack_row), - Ordering::Equal => pre_processed_jump_stack_table.push(vec![jump_stack_row]), - Ordering::Greater => panic!("JSP must increase by at most 1 per execution step."), - } - } - - // Move the rows into the Jump Stack Table, sorted by JSP first, CLK second. - let mut jump_stack_table_row = 0; - for (jsp_val, rows_with_this_jsp) in pre_processed_jump_stack_table.into_iter().enumerate() - { - let jsp = bfe!(jsp_val as u64); - for (clk, ci, jso, jsd) in rows_with_this_jsp { - jump_stack_table[(jump_stack_table_row, CLK.base_table_index())] = clk; - jump_stack_table[(jump_stack_table_row, CI.base_table_index())] = ci; - jump_stack_table[(jump_stack_table_row, JSP.base_table_index())] = jsp; - jump_stack_table[(jump_stack_table_row, JSO.base_table_index())] = jso; - jump_stack_table[(jump_stack_table_row, JSD.base_table_index())] = jsd; - jump_stack_table_row += 1; - } - } - assert_eq!(aet.processor_trace.nrows(), jump_stack_table_row); - - // Collect all clock jump differences. - // The Jump Stack Table and the Processor Table have the same length. - let mut clock_jump_differences = vec![]; - for row_idx in 0..aet.processor_trace.nrows() - 1 { - let curr_row = jump_stack_table.row(row_idx); - let next_row = jump_stack_table.row(row_idx + 1); - let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; - if curr_row[JSP.base_table_index()] == next_row[JSP.base_table_index()] { - clock_jump_differences.push(clk_diff); - } - } - clock_jump_differences - } - - pub fn pad_trace( - mut jump_stack_table: ArrayViewMut2, - processor_table_len: usize, - ) { - assert!( - processor_table_len > 0, - "Processor Table must have at least 1 row." - ); - - // Set up indices for relevant sections of the table. - let padded_height = jump_stack_table.nrows(); - let num_padding_rows = padded_height - processor_table_len; - let max_clk_before_padding = processor_table_len - 1; - let max_clk_before_padding_row_idx = jump_stack_table - .rows() - .into_iter() - .enumerate() - .find(|(_, row)| row[CLK.base_table_index()].value() as usize == max_clk_before_padding) - .map(|(idx, _)| idx) - .expect("Jump Stack Table must contain row with clock cycle equal to max cycle."); - let rows_to_move_source_section_start = max_clk_before_padding_row_idx + 1; - let rows_to_move_source_section_end = processor_table_len; - let num_rows_to_move = rows_to_move_source_section_end - rows_to_move_source_section_start; - let rows_to_move_dest_section_start = rows_to_move_source_section_start + num_padding_rows; - let rows_to_move_dest_section_end = rows_to_move_dest_section_start + num_rows_to_move; - let padding_section_start = rows_to_move_source_section_start; - let padding_section_end = padding_section_start + num_padding_rows; - assert_eq!(padded_height, rows_to_move_dest_section_end); - - // Move all rows below the row with highest CLK to the end of the table – if they exist. - if num_rows_to_move > 0 { - let rows_to_move_source_range = - rows_to_move_source_section_start..rows_to_move_source_section_end; - let rows_to_move_dest_range = - rows_to_move_dest_section_start..rows_to_move_dest_section_end; - let rows_to_move = jump_stack_table - .slice(s![rows_to_move_source_range, ..]) - .to_owned(); - rows_to_move - .move_into(&mut jump_stack_table.slice_mut(s![rows_to_move_dest_range, ..])); - } - - // Fill the created gap with padding rows, i.e., with copies of the last row before the - // gap. This is the padding section. - let padding_row_template = jump_stack_table - .row(max_clk_before_padding_row_idx) - .to_owned(); - let mut padding_section = - jump_stack_table.slice_mut(s![padding_section_start..padding_section_end, ..]); - padding_section - .axis_iter_mut(Axis(0)) - .into_par_iter() - .for_each(|padding_row| padding_row_template.clone().move_into(padding_row)); - - // CLK keeps increasing by 1 also in the padding section. - let new_clk_values = - Array1::from_iter((processor_table_len..padded_height).map(|clk| bfe!(clk as u64))); - new_clk_values.move_into(padding_section.slice_mut(s![.., CLK.base_table_index()])); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "jump stack table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - // use strum::IntoEnumIterator; - let extension_column_indices = JumpStackExtTableColumn::iter() - .map(|column| column.ext_table_index()) - .collect_vec(); - let extension_column_slices = horizontal_multi_slice_mut( - ext_table.view_mut(), - &contiguous_column_slices(&extension_column_indices), - ); - let extension_functions = [ - Self::extension_column_running_product_permutation_argument, - Self::extension_column_clock_jump_diff_lookup_log_derivative, - ]; - - extension_functions - .into_par_iter() - .zip_eq(extension_column_slices) - .for_each(|(generator, slice)| { - generator(base_table, challenges).move_into(slice); - }); - - profiler!(stop "jump stack table"); - } - - fn extension_column_running_product_permutation_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut running_product = PermArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] - + row[CI.base_table_index()] * challenges[JumpStackCiWeight] - + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] - + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] - + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; - running_product *= challenges[JumpStackIndeterminate] - compressed_row; - extension_column.push(running_product); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_clock_jump_diff_lookup_log_derivative( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - // - use memoization to avoid recomputing inverses - // - precompute common values through batch inversion - const PRECOMPUTE_INVERSES_OF: Range = 0..100; - let indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; - let to_invert = PRECOMPUTE_INVERSES_OF - .map(|i| indeterminate - bfe!(i)) - .collect(); - let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF - .zip_eq(XFieldElement::batch_inversion(to_invert)) - .map(|(i, inv)| (bfe!(i), inv)) - .collect::>(); - - // populate extension column using memoization - let mut cjd_lookup_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(cjd_lookup_log_derivative); - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if previous_row[JSP.base_table_index()] == current_row[JSP.base_table_index()] { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; - let clock_jump_difference = current_clock - previous_clock; - let &mut inverse = inverses_dictionary - .entry(clock_jump_difference) - .or_insert_with(|| (indeterminate - clock_jump_difference).inverse()); - cjd_lookup_log_derivative += inverse; - } - extension_column.push(cjd_lookup_log_derivative); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } -} diff --git a/triton-vm/src/table/lookup.rs b/triton-vm/src/table/lookup.rs new file mode 100644 index 000000000..17d0d2853 --- /dev/null +++ b/triton-vm/src/table/lookup.rs @@ -0,0 +1,156 @@ +use air::challenge_id::ChallengeId; +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::EvalArg; +use air::cross_table_argument::LookupArg; +use air::table::lookup::LookupTable; +use air::table_column::LookupBaseTableColumn; +use air::table_column::LookupBaseTableColumn::*; +use air::table_column::LookupExtTableColumn; +use air::table_column::LookupExtTableColumn::*; +use air::table_column::MasterBaseTableColumn; +use air::table_column::MasterExtTableColumn; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; +use itertools::Itertools; +use ndarray::prelude::*; +use num_traits::ConstOne; +use num_traits::One; +use rayon::iter::*; +use strum::EnumCount; +use strum::IntoEnumIterator; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::ndarray_helper::contiguous_column_slices; +use crate::ndarray_helper::horizontal_multi_slice_mut; +use crate::profiler::profiler; +use crate::table::TraceTable; + +fn extension_column_cascade_running_sum_log_derivative( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let look_in_weight = challenges[LookupTableInputWeight]; + let look_out_weight = challenges[LookupTableOutputWeight]; + let indeterminate = challenges[CascadeLookupIndeterminate]; + + let mut cascade_table_running_sum_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + if row[IsPadding.base_table_index()].is_one() { + break; + } + + let lookup_input = row[LookIn.base_table_index()]; + let lookup_output = row[LookOut.base_table_index()]; + let compressed_row = lookup_input * look_in_weight + lookup_output * look_out_weight; + + let lookup_multiplicity = row[LookupMultiplicity.base_table_index()]; + cascade_table_running_sum_log_derivative += + (indeterminate - compressed_row).inverse() * lookup_multiplicity; + + extension_column.push(cascade_table_running_sum_log_derivative); + } + + // fill padding section + extension_column.resize(base_table.nrows(), cascade_table_running_sum_log_derivative); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_public_running_evaluation( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + if row[IsPadding.base_table_index()].is_one() { + break; + } + + running_evaluation = running_evaluation * challenges[LookupTablePublicIndeterminate] + + row[LookOut.base_table_index()]; + extension_column.push(running_evaluation); + } + + // fill padding section + extension_column.resize(base_table.nrows(), running_evaluation); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +impl TraceTable for LookupTable { + type FillParam = (); + type FillReturnInfo = (); + + fn fill(mut main_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { + const LOOKUP_TABLE_LEN: usize = tip5::LOOKUP_TABLE.len(); + assert!(main_table.nrows() >= LOOKUP_TABLE_LEN); + + // Lookup Table input + let lookup_input = Array1::from_iter((0..LOOKUP_TABLE_LEN).map(|i| bfe!(i as u64))); + let lookup_input_column = + main_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookIn.base_table_index()]); + lookup_input.move_into(lookup_input_column); + + // Lookup Table output + let lookup_output = Array1::from_iter(tip5::LOOKUP_TABLE.map(BFieldElement::from)); + let lookup_output_column = + main_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookOut.base_table_index()]); + lookup_output.move_into(lookup_output_column); + + // Lookup Table multiplicities + let lookup_multiplicities = Array1::from_iter( + aet.lookup_table_lookup_multiplicities + .map(BFieldElement::new), + ); + let lookup_multiplicities_column = main_table.slice_mut(s![ + ..LOOKUP_TABLE_LEN, + LookupMultiplicity.base_table_index() + ]); + lookup_multiplicities.move_into(lookup_multiplicities_column); + } + + fn pad(mut lookup_table: ArrayViewMut2, table_length: usize) { + lookup_table + .slice_mut(s![table_length.., IsPadding.base_table_index()]) + .fill(BFieldElement::ONE); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "lookup table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let extension_column_indices = LookupExtTableColumn::iter() + .map(|column| column.ext_table_index()) + .collect_vec(); + let extension_column_slices = horizontal_multi_slice_mut( + aux_table.view_mut(), + &contiguous_column_slices(&extension_column_indices), + ); + let extension_functions = [ + extension_column_cascade_running_sum_log_derivative, + extension_column_public_running_evaluation, + ]; + + extension_functions + .into_par_iter() + .zip_eq(extension_column_slices) + .for_each(|(generator, slice)| { + generator(main_table, challenges).move_into(slice); + }); + + profiler!(stop "lookup table"); + } +} diff --git a/triton-vm/src/table/lookup_table.rs b/triton-vm/src/table/lookup_table.rs deleted file mode 100644 index 37881e5fd..000000000 --- a/triton-vm/src/table/lookup_table.rs +++ /dev/null @@ -1,316 +0,0 @@ -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; -use itertools::Itertools; -use ndarray::prelude::*; -use num_traits::ConstOne; -use num_traits::One; -use rayon::iter::*; -use strum::EnumCount; -use strum::IntoEnumIterator; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::ndarray_helper::contiguous_column_slices; -use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::EvalArg; -use crate::table::cross_table_argument::LookupArg; -use crate::table::table_column::LookupBaseTableColumn; -use crate::table::table_column::LookupBaseTableColumn::*; -use crate::table::table_column::LookupExtTableColumn; -use crate::table::table_column::LookupExtTableColumn::*; -use crate::table::table_column::MasterBaseTableColumn; -use crate::table::table_column::MasterExtTableColumn; - -pub const BASE_WIDTH: usize = LookupBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = LookupExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct LookupTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtLookupTable; - -impl LookupTable { - pub fn fill_trace( - lookup_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) { - const LOOKUP_TABLE_LEN: usize = tip5::LOOKUP_TABLE.len(); - assert!(lookup_table.nrows() >= LOOKUP_TABLE_LEN); - - // Lookup Table input - let lookup_input = Array1::from_iter((0..LOOKUP_TABLE_LEN).map(|i| bfe!(i as u64))); - let lookup_input_column = - lookup_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookIn.base_table_index()]); - lookup_input.move_into(lookup_input_column); - - // Lookup Table output - let lookup_output = Array1::from_iter(tip5::LOOKUP_TABLE.map(BFieldElement::from)); - let lookup_output_column = - lookup_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookOut.base_table_index()]); - lookup_output.move_into(lookup_output_column); - - // Lookup Table multiplicities - let lookup_multiplicities = Array1::from_iter( - aet.lookup_table_lookup_multiplicities - .map(BFieldElement::new), - ); - let lookup_multiplicities_column = lookup_table.slice_mut(s![ - ..LOOKUP_TABLE_LEN, - LookupMultiplicity.base_table_index() - ]); - lookup_multiplicities.move_into(lookup_multiplicities_column); - } - - pub fn pad_trace(mut lookup_table: ArrayViewMut2, lookup_table_length: usize) { - lookup_table - .slice_mut(s![lookup_table_length.., IsPadding.base_table_index()]) - .fill(BFieldElement::ONE); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "lookup table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let extension_column_indices = LookupExtTableColumn::iter() - .map(|column| column.ext_table_index()) - .collect_vec(); - let extension_column_slices = horizontal_multi_slice_mut( - ext_table.view_mut(), - &contiguous_column_slices(&extension_column_indices), - ); - let extension_functions = [ - Self::extension_column_cascade_running_sum_log_derivative, - Self::extension_column_public_running_evaluation, - ]; - - extension_functions - .into_par_iter() - .zip_eq(extension_column_slices) - .for_each(|(generator, slice)| { - generator(base_table, challenges).move_into(slice); - }); - - profiler!(stop "lookup table"); - } - - fn extension_column_cascade_running_sum_log_derivative( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let look_in_weight = challenges[LookupTableInputWeight]; - let look_out_weight = challenges[LookupTableOutputWeight]; - let indeterminate = challenges[CascadeLookupIndeterminate]; - - let mut cascade_table_running_sum_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { - break; - } - - let lookup_input = row[LookIn.base_table_index()]; - let lookup_output = row[LookOut.base_table_index()]; - let compressed_row = lookup_input * look_in_weight + lookup_output * look_out_weight; - - let lookup_multiplicity = row[LookupMultiplicity.base_table_index()]; - cascade_table_running_sum_log_derivative += - (indeterminate - compressed_row).inverse() * lookup_multiplicity; - - extension_column.push(cascade_table_running_sum_log_derivative); - } - - // fill padding section - extension_column.resize(base_table.nrows(), cascade_table_running_sum_log_derivative); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_public_running_evaluation( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { - break; - } - - running_evaluation = running_evaluation * challenges[LookupTablePublicIndeterminate] - + row[LookOut.base_table_index()]; - extension_column.push(running_evaluation); - } - - // fill padding section - extension_column.resize(base_table.nrows(), running_evaluation); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } -} - -impl ExtLookupTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let base_row = |col_id: LookupBaseTableColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; - let ext_row = |col_id: LookupExtTableColumn| { - circuit_builder.input(ExtRow(col_id.master_ext_table_index())) - }; - let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); - - let lookup_input = base_row(LookIn); - let lookup_output = base_row(LookOut); - let lookup_multiplicity = base_row(LookupMultiplicity); - let cascade_table_server_log_derivative = ext_row(CascadeTableServerLogDerivative); - let public_evaluation_argument = ext_row(PublicEvaluationArgument); - - let lookup_input_is_0 = lookup_input; - - // Lookup Argument with Cascade Table - // note: `lookup_input` is known to be 0 and thus doesn't appear in the compressed row - let lookup_argument_default_initial = - circuit_builder.x_constant(LookupArg::default_initial()); - let cascade_table_indeterminate = challenge(CascadeLookupIndeterminate); - let compressed_row = lookup_output.clone() * challenge(LookupTableOutputWeight); - let cascade_table_log_derivative_is_initialized_correctly = - (cascade_table_server_log_derivative - lookup_argument_default_initial) - * (cascade_table_indeterminate - compressed_row) - - lookup_multiplicity; - - // public Evaluation Argument - let eval_argument_default_initial = circuit_builder.x_constant(EvalArg::default_initial()); - let public_indeterminate = challenge(LookupTablePublicIndeterminate); - let public_evaluation_argument_is_initialized_correctly = public_evaluation_argument - - eval_argument_default_initial * public_indeterminate - - lookup_output; - - vec![ - lookup_input_is_0, - cascade_table_log_derivative_is_initialized_correctly, - public_evaluation_argument_is_initialized_correctly, - ] - } - - pub fn consistency_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let base_row = |col_id: LookupBaseTableColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; - - let padding_is_0_or_1 = base_row(IsPadding) * (constant(1) - base_row(IsPadding)); - - vec![padding_is_0_or_1] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let one = || circuit_builder.b_constant(1); - - let current_base_row = |col_id: LookupBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col_id.master_base_table_index())) - }; - let next_base_row = |col_id: LookupBaseTableColumn| { - circuit_builder.input(NextBaseRow(col_id.master_base_table_index())) - }; - let current_ext_row = |col_id: LookupExtTableColumn| { - circuit_builder.input(CurrentExtRow(col_id.master_ext_table_index())) - }; - let next_ext_row = |col_id: LookupExtTableColumn| { - circuit_builder.input(NextExtRow(col_id.master_ext_table_index())) - }; - let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); - - let lookup_input = current_base_row(LookIn); - let is_padding = current_base_row(IsPadding); - let cascade_table_server_log_derivative = current_ext_row(CascadeTableServerLogDerivative); - let public_evaluation_argument = current_ext_row(PublicEvaluationArgument); - - let lookup_input_next = next_base_row(LookIn); - let lookup_output_next = next_base_row(LookOut); - let lookup_multiplicity_next = next_base_row(LookupMultiplicity); - let is_padding_next = next_base_row(IsPadding); - let cascade_table_server_log_derivative_next = - next_ext_row(CascadeTableServerLogDerivative); - let public_evaluation_argument_next = next_ext_row(PublicEvaluationArgument); - - // Padding section is contiguous: if the current row is a padding row, then the next row - // is also a padding row. - let if_current_row_is_padding_row_then_next_row_is_padding_row = - is_padding * (one() - is_padding_next.clone()); - - // Lookup Table's input increments by 1 if and only if the next row is not a padding row - let if_next_row_is_padding_row_then_lookup_input_next_is_0 = - is_padding_next.clone() * lookup_input_next.clone(); - let if_next_row_is_not_padding_row_then_lookup_input_next_increments_by_1 = - (one() - is_padding_next.clone()) * (lookup_input_next.clone() - lookup_input - one()); - let lookup_input_increments_if_and_only_if_next_row_is_not_padding_row = - if_next_row_is_padding_row_then_lookup_input_next_is_0 - + if_next_row_is_not_padding_row_then_lookup_input_next_increments_by_1; - - // Lookup Argument with Cascade Table - let cascade_table_indeterminate = challenge(CascadeLookupIndeterminate); - let compressed_row = lookup_input_next * challenge(LookupTableInputWeight) - + lookup_output_next.clone() * challenge(LookupTableOutputWeight); - let cascade_table_log_derivative_remains = cascade_table_server_log_derivative_next.clone() - - cascade_table_server_log_derivative.clone(); - let cascade_table_log_derivative_updates = (cascade_table_server_log_derivative_next - - cascade_table_server_log_derivative) - * (cascade_table_indeterminate - compressed_row) - - lookup_multiplicity_next; - let cascade_table_log_derivative_updates_if_and_only_if_next_row_is_not_padding_row = - (one() - is_padding_next.clone()) * cascade_table_log_derivative_updates - + is_padding_next.clone() * cascade_table_log_derivative_remains; - - // public Evaluation Argument - let public_indeterminate = challenge(LookupTablePublicIndeterminate); - let public_evaluation_argument_remains = - public_evaluation_argument_next.clone() - public_evaluation_argument.clone(); - let public_evaluation_argument_updates = public_evaluation_argument_next - - public_evaluation_argument * public_indeterminate - - lookup_output_next; - let public_evaluation_argument_updates_if_and_only_if_next_row_is_not_padding_row = - (one() - is_padding_next.clone()) * public_evaluation_argument_updates - + is_padding_next * public_evaluation_argument_remains; - - vec![ - if_current_row_is_padding_row_then_next_row_is_padding_row, - lookup_input_increments_if_and_only_if_next_row_is_not_padding_row, - cascade_table_log_derivative_updates_if_and_only_if_next_row_is_not_padding_row, - public_evaluation_argument_updates_if_and_only_if_next_row_is_not_padding_row, - ] - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); - let ext_row = |col_id: LookupExtTableColumn| { - circuit_builder.input(ExtRow(col_id.master_ext_table_index())) - }; - - let narrow_table_terminal_matches_user_supplied_terminal = - ext_row(PublicEvaluationArgument) - challenge(LookupTablePublicTerminal); - - vec![narrow_table_terminal_matches_user_supplied_terminal] - } -} diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 42fdf6d6a..ddbf67d18 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -4,6 +4,50 @@ use std::ops::Mul; use std::ops::MulAssign; use std::ops::Range; +use air::table::lookup::LookupTable; +use air::table::op_stack::OpStackTable; +use air::table::processor::ProcessorTable; +use air::table::program::ProgramTable; +use air::table::ram::RamTable; +use air::table::u32::U32Table; +use air::table::TableId; +use air::table::CASCADE_TABLE_END; +use air::table::CASCADE_TABLE_START; +use air::table::EXT_CASCADE_TABLE_END; +use air::table::EXT_CASCADE_TABLE_START; +use air::table::EXT_HASH_TABLE_END; +use air::table::EXT_HASH_TABLE_START; +use air::table::EXT_JUMP_STACK_TABLE_END; +use air::table::EXT_JUMP_STACK_TABLE_START; +use air::table::EXT_LOOKUP_TABLE_END; +use air::table::EXT_LOOKUP_TABLE_START; +use air::table::EXT_OP_STACK_TABLE_END; +use air::table::EXT_OP_STACK_TABLE_START; +use air::table::EXT_PROCESSOR_TABLE_END; +use air::table::EXT_PROCESSOR_TABLE_START; +use air::table::EXT_PROGRAM_TABLE_END; +use air::table::EXT_PROGRAM_TABLE_START; +use air::table::EXT_RAM_TABLE_END; +use air::table::EXT_RAM_TABLE_START; +use air::table::EXT_U32_TABLE_END; +use air::table::EXT_U32_TABLE_START; +use air::table::HASH_TABLE_END; +use air::table::HASH_TABLE_START; +use air::table::JUMP_STACK_TABLE_END; +use air::table::JUMP_STACK_TABLE_START; +use air::table::LOOKUP_TABLE_END; +use air::table::LOOKUP_TABLE_START; +use air::table::OP_STACK_TABLE_END; +use air::table::OP_STACK_TABLE_START; +use air::table::PROCESSOR_TABLE_END; +use air::table::PROCESSOR_TABLE_START; +use air::table::PROGRAM_TABLE_END; +use air::table::PROGRAM_TABLE_START; +use air::table::RAM_TABLE_END; +use air::table::RAM_TABLE_START; +use air::table::U32_TABLE_END; +use air::table::U32_TABLE_START; +use air::table_column::*; use itertools::Itertools; use master_table::extension_table::Evaluable; use ndarray::parallel::prelude::*; @@ -28,6 +72,7 @@ use twenty_first::util_types::algebraic_hasher; use crate::aet::AlgebraicExecutionTrace; use crate::arithmetic_domain::ArithmeticDomain; +use crate::challenges::Challenges; use crate::config::CacheDecision; use crate::error::ProvingError; use crate::ndarray_helper::fast_zeros_column_major; @@ -35,139 +80,13 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::ndarray_helper::partial_sums; use crate::profiler::profiler; use crate::stark::NUM_RANDOMIZER_POLYNOMIALS; -use crate::table::cascade_table::CascadeTable; -use crate::table::challenges::Challenges; use crate::table::degree_lowering_table::DegreeLoweringTable; use crate::table::extension_table::all_degrees_with_origin; use crate::table::extension_table::DegreeWithOrigin; use crate::table::extension_table::Quotientable; -use crate::table::hash_table::HashTable; -use crate::table::jump_stack_table::JumpStackTable; -use crate::table::lookup_table::LookupTable; -use crate::table::op_stack_table::OpStackTable; -use crate::table::processor_table::ProcessorTable; -use crate::table::program_table::ProgramTable; -use crate::table::ram_table::RamTable; -use crate::table::table_column::*; -use crate::table::u32_table::U32Table; +use crate::table::processor::ClkJumpDiffs; use crate::table::*; -/// The degree of the AIR after the degree lowering step. -/// -/// Using substitution and the introduction of new variables, the degree of the AIR as specified -/// in the respective tables -/// (e.g., in [`processor_table::ExtProcessorTable::transition_constraints`]) -/// is lowered to this value. -/// For example, with a target degree of 2 and a (fictional) constraint of the form -/// `a = b²·c²·d`, -/// the degree lowering step could (as one among multiple possibilities) -/// - introduce new variables `e`, `f`, and `g`, -/// - introduce new constraints `e = b²`, `f = c²`, and `g = e·f`, -/// - replace the original constraint with `a = g·d`. -/// -/// The degree lowering happens in the constraint evaluation generator. -/// It can be executed by running `cargo run --bin constraint-evaluation-generator`. -/// Executing the constraint evaluator is a prerequisite for running both the Stark prover -/// and the Stark verifier. -/// -/// The new variables introduced by the degree lowering step are called “derived columns.” -/// They are added to the [`DegreeLoweringTable`], whose sole purpose is to store the values -/// of these derived columns. -pub const AIR_TARGET_DEGREE: isize = 4; - -/// The total number of base columns across all tables. -pub const NUM_BASE_COLUMNS: usize = program_table::BASE_WIDTH - + processor_table::BASE_WIDTH - + op_stack_table::BASE_WIDTH - + ram_table::BASE_WIDTH - + jump_stack_table::BASE_WIDTH - + hash_table::BASE_WIDTH - + cascade_table::BASE_WIDTH - + lookup_table::BASE_WIDTH - + u32_table::BASE_WIDTH - + degree_lowering_table::BASE_WIDTH; - -const NUM_EXT_COLUMNS_WITHOUT_RANDOMIZER_POLYS: usize = program_table::EXT_WIDTH - + processor_table::EXT_WIDTH - + op_stack_table::EXT_WIDTH - + ram_table::EXT_WIDTH - + jump_stack_table::EXT_WIDTH - + hash_table::EXT_WIDTH - + cascade_table::EXT_WIDTH - + lookup_table::EXT_WIDTH - + u32_table::EXT_WIDTH - + degree_lowering_table::EXT_WIDTH; - -/// The total number of extension columns across all tables. -/// Includes the columns required for [randomizer polynomials](NUM_RANDOMIZER_POLYNOMIALS). -pub const NUM_EXT_COLUMNS: usize = - NUM_EXT_COLUMNS_WITHOUT_RANDOMIZER_POLYS + NUM_RANDOMIZER_POLYNOMIALS; - -/// The total number of columns across all tables. -pub const NUM_COLUMNS: usize = NUM_BASE_COLUMNS + NUM_EXT_COLUMNS; - -pub const PROGRAM_TABLE_START: usize = 0; -pub const PROGRAM_TABLE_END: usize = PROGRAM_TABLE_START + program_table::BASE_WIDTH; -pub const PROCESSOR_TABLE_START: usize = PROGRAM_TABLE_END; -pub const PROCESSOR_TABLE_END: usize = PROCESSOR_TABLE_START + processor_table::BASE_WIDTH; -pub const OP_STACK_TABLE_START: usize = PROCESSOR_TABLE_END; -pub const OP_STACK_TABLE_END: usize = OP_STACK_TABLE_START + op_stack_table::BASE_WIDTH; -pub const RAM_TABLE_START: usize = OP_STACK_TABLE_END; -pub const RAM_TABLE_END: usize = RAM_TABLE_START + ram_table::BASE_WIDTH; -pub const JUMP_STACK_TABLE_START: usize = RAM_TABLE_END; -pub const JUMP_STACK_TABLE_END: usize = JUMP_STACK_TABLE_START + jump_stack_table::BASE_WIDTH; -pub const HASH_TABLE_START: usize = JUMP_STACK_TABLE_END; -pub const HASH_TABLE_END: usize = HASH_TABLE_START + hash_table::BASE_WIDTH; -pub const CASCADE_TABLE_START: usize = HASH_TABLE_END; -pub const CASCADE_TABLE_END: usize = CASCADE_TABLE_START + cascade_table::BASE_WIDTH; -pub const LOOKUP_TABLE_START: usize = CASCADE_TABLE_END; -pub const LOOKUP_TABLE_END: usize = LOOKUP_TABLE_START + lookup_table::BASE_WIDTH; -pub const U32_TABLE_START: usize = LOOKUP_TABLE_END; -pub const U32_TABLE_END: usize = U32_TABLE_START + u32_table::BASE_WIDTH; -pub const DEGREE_LOWERING_TABLE_START: usize = U32_TABLE_END; -pub const DEGREE_LOWERING_TABLE_END: usize = - DEGREE_LOWERING_TABLE_START + degree_lowering_table::BASE_WIDTH; - -pub const EXT_PROGRAM_TABLE_START: usize = 0; -pub const EXT_PROGRAM_TABLE_END: usize = EXT_PROGRAM_TABLE_START + program_table::EXT_WIDTH; -pub const EXT_PROCESSOR_TABLE_START: usize = EXT_PROGRAM_TABLE_END; -pub const EXT_PROCESSOR_TABLE_END: usize = EXT_PROCESSOR_TABLE_START + processor_table::EXT_WIDTH; -pub const EXT_OP_STACK_TABLE_START: usize = EXT_PROCESSOR_TABLE_END; -pub const EXT_OP_STACK_TABLE_END: usize = EXT_OP_STACK_TABLE_START + op_stack_table::EXT_WIDTH; -pub const EXT_RAM_TABLE_START: usize = EXT_OP_STACK_TABLE_END; -pub const EXT_RAM_TABLE_END: usize = EXT_RAM_TABLE_START + ram_table::EXT_WIDTH; -pub const EXT_JUMP_STACK_TABLE_START: usize = EXT_RAM_TABLE_END; -pub const EXT_JUMP_STACK_TABLE_END: usize = - EXT_JUMP_STACK_TABLE_START + jump_stack_table::EXT_WIDTH; -pub const EXT_HASH_TABLE_START: usize = EXT_JUMP_STACK_TABLE_END; -pub const EXT_HASH_TABLE_END: usize = EXT_HASH_TABLE_START + hash_table::EXT_WIDTH; -pub const EXT_CASCADE_TABLE_START: usize = EXT_HASH_TABLE_END; -pub const EXT_CASCADE_TABLE_END: usize = EXT_CASCADE_TABLE_START + cascade_table::EXT_WIDTH; -pub const EXT_LOOKUP_TABLE_START: usize = EXT_CASCADE_TABLE_END; -pub const EXT_LOOKUP_TABLE_END: usize = EXT_LOOKUP_TABLE_START + lookup_table::EXT_WIDTH; -pub const EXT_U32_TABLE_START: usize = EXT_LOOKUP_TABLE_END; -pub const EXT_U32_TABLE_END: usize = EXT_U32_TABLE_START + u32_table::EXT_WIDTH; -pub const EXT_DEGREE_LOWERING_TABLE_START: usize = EXT_U32_TABLE_END; -pub const EXT_DEGREE_LOWERING_TABLE_END: usize = - EXT_DEGREE_LOWERING_TABLE_START + degree_lowering_table::EXT_WIDTH; - -const NUM_TABLES_WITHOUT_DEGREE_LOWERING: usize = TableId::COUNT - 1; - -/// A `TableId` uniquely determines one of Triton VM's tables. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] -pub enum TableId { - Program, - Processor, - OpStack, - Ram, - JumpStack, - Hash, - Cascade, - Lookup, - U32, - DegreeLowering, -} - /// A Master Table is, in some sense, a top-level table of Triton VM. It contains all the data /// but little logic beyond bookkeeping and presenting the data in useful ways. Conversely, the /// individual tables contain no data but all the respective logic. Master Tables are @@ -374,7 +293,7 @@ where ); /// Requires having called - /// [`low_degree_extend_all_columns`](Self::low_degree_extend_all_columns) first. + /// [`low_degree_extend_all_columns`](Self::low_degree_extend_all_columns) first. fn interpolation_polynomials(&self) -> ArrayView1>; /// Get one row of the table at an arbitrary index. Notably, the index does not have to be in @@ -787,26 +706,24 @@ impl MasterBaseTable { // memory-like tables must be filled in before clock jump differences are known, hence // the break from the usual order let clk_jump_diffs_op_stack = - OpStackTable::fill_trace(&mut master_base_table.table_mut(TableId::OpStack), aet); - let clk_jump_diffs_ram = - RamTable::fill_trace(&mut master_base_table.table_mut(TableId::Ram), aet); + OpStackTable::fill(master_base_table.table_mut(TableId::OpStack), aet, ()); + let clk_jump_diffs_ram = RamTable::fill(master_base_table.table_mut(TableId::Ram), aet, ()); let clk_jump_diffs_jump_stack = - JumpStackTable::fill_trace(&mut master_base_table.table_mut(TableId::JumpStack), aet); - - let processor_table = &mut master_base_table.table_mut(TableId::Processor); - ProcessorTable::fill_trace( - processor_table, - aet, - &clk_jump_diffs_op_stack, - &clk_jump_diffs_ram, - &clk_jump_diffs_jump_stack, - ); + JumpStackTable::fill(master_base_table.table_mut(TableId::JumpStack), aet, ()); + + let clk_jump_diffs = ClkJumpDiffs { + op_stack: clk_jump_diffs_op_stack, + ram: clk_jump_diffs_ram, + jump_stack: clk_jump_diffs_jump_stack, + }; + let processor_table = master_base_table.table_mut(TableId::Processor); + ProcessorTable::fill(processor_table, aet, clk_jump_diffs); - ProgramTable::fill_trace(&mut master_base_table.table_mut(TableId::Program), aet); - HashTable::fill_trace(&mut master_base_table.table_mut(TableId::Hash), aet); - CascadeTable::fill_trace(&mut master_base_table.table_mut(TableId::Cascade), aet); - LookupTable::fill_trace(&mut master_base_table.table_mut(TableId::Lookup), aet); - U32Table::fill_trace(&mut master_base_table.table_mut(TableId::U32), aet); + ProgramTable::fill(master_base_table.table_mut(TableId::Program), aet, ()); + HashTable::fill(master_base_table.table_mut(TableId::Hash), aet, ()); + CascadeTable::fill(master_base_table.table_mut(TableId::Cascade), aet, ()); + LookupTable::fill(master_base_table.table_mut(TableId::Lookup), aet, ()); + U32Table::fill(master_base_table.table_mut(TableId::U32), aet, ()); // Filling the degree-lowering table only makes sense after padding has happened. // Hence, this table is omitted here. @@ -827,7 +744,7 @@ impl MasterBaseTable { .randomized_trace_table .slice_mut(s![..; unit_distance, ..]); - let base_tables: [_; NUM_TABLES_WITHOUT_DEGREE_LOWERING] = horizontal_multi_slice_mut( + let base_tables: [_; TableId::COUNT] = horizontal_multi_slice_mut( master_table_without_randomizers, &partial_sums(&[ ProgramBaseTableColumn::COUNT, @@ -845,7 +762,18 @@ impl MasterBaseTable { .unwrap(); profiler!(start "pad original tables"); - Self::all_pad_functions() + let all_pad_functions: [PadFunction; TableId::COUNT] = [ + ProgramTable::pad, + ProcessorTable::pad, + OpStackTable::pad, + RamTable::pad, + JumpStackTable::pad, + HashTable::pad, + CascadeTable::pad, + LookupTable::pad, + U32Table::pad, + ]; + all_pad_functions .into_par_iter() .zip_eq(base_tables.into_par_iter()) .zip_eq(table_lengths.into_par_iter()) @@ -859,21 +787,7 @@ impl MasterBaseTable { profiler!(stop "fill degree-lowering table"); } - fn all_pad_functions() -> [PadFunction; NUM_TABLES_WITHOUT_DEGREE_LOWERING] { - [ - ProgramTable::pad_trace, - ProcessorTable::pad_trace, - OpStackTable::pad_trace, - RamTable::pad_trace, - JumpStackTable::pad_trace, - HashTable::pad_trace, - CascadeTable::pad_trace, - LookupTable::pad_trace, - U32Table::pad_trace, - ] - } - - fn all_table_lengths(&self) -> [usize; NUM_TABLES_WITHOUT_DEGREE_LOWERING] { + fn all_table_lengths(&self) -> [usize; TableId::COUNT] { let processor_table_len = self.main_execution_len; let jump_stack_table_len = self.main_execution_len; @@ -898,10 +812,10 @@ impl MasterBaseTable { let num_rows = self.randomized_trace_table().nrows(); profiler!(start "initialize master table"); let mut randomized_trace_extension_table = - fast_zeros_column_major::(num_rows, NUM_EXT_COLUMNS); + fast_zeros_column_major(num_rows, NUM_EXT_COLUMNS + NUM_RANDOMIZER_POLYNOMIALS); randomized_trace_extension_table - .slice_mut(s![.., NUM_EXT_COLUMNS_WITHOUT_RANDOMIZER_POLYS..]) + .slice_mut(s![.., NUM_EXT_COLUMNS..]) .par_mapv_inplace(|_| random::()); profiler!(stop "initialize master table"); @@ -921,7 +835,7 @@ impl MasterBaseTable { let master_ext_table_without_randomizers = master_ext_table .randomized_trace_table .slice_mut(s![..; unit_distance, ..NUM_EXT_COLUMNS]); - let extension_tables: [_; NUM_TABLES_WITHOUT_DEGREE_LOWERING] = horizontal_multi_slice_mut( + let extension_tables: [_; TableId::COUNT] = horizontal_multi_slice_mut( master_ext_table_without_randomizers, &partial_sums(&[ ProgramExtTableColumn::COUNT, @@ -960,7 +874,7 @@ impl MasterBaseTable { master_ext_table } - fn all_extend_functions() -> [ExtendFunction; NUM_TABLES_WITHOUT_DEGREE_LOWERING] { + fn all_extend_functions() -> [ExtendFunction; TableId::COUNT] { [ ProgramTable::extend, ProcessorTable::extend, @@ -974,9 +888,7 @@ impl MasterBaseTable { ] } - fn base_tables_for_extending( - &self, - ) -> [ArrayView2; NUM_TABLES_WITHOUT_DEGREE_LOWERING] { + fn base_tables_for_extending(&self) -> [ArrayView2; TableId::COUNT] { [ self.table(TableId::Program), self.table(TableId::Processor), @@ -1002,7 +914,6 @@ impl MasterBaseTable { Cascade => CASCADE_TABLE_START..CASCADE_TABLE_END, Lookup => LOOKUP_TABLE_START..LOOKUP_TABLE_END, U32 => U32_TABLE_START..U32_TABLE_END, - DegreeLowering => DEGREE_LOWERING_TABLE_START..DEGREE_LOWERING_TABLE_END, } } @@ -1046,7 +957,6 @@ impl MasterExtTable { Cascade => EXT_CASCADE_TABLE_START..EXT_CASCADE_TABLE_END, Lookup => EXT_LOOKUP_TABLE_START..EXT_LOOKUP_TABLE_END, U32 => EXT_U32_TABLE_START..EXT_U32_TABLE_END, - DegreeLowering => EXT_DEGREE_LOWERING_TABLE_START..EXT_DEGREE_LOWERING_TABLE_END, } } @@ -1273,6 +1183,10 @@ mod tests { use fs_err as fs; use std::path::Path; + use air::cross_table_argument::GrandCrossTableArg; + use air::table::cascade::CascadeTable; + use air::table::hash::HashTable; + use air::table::jump_stack::JumpStackTable; use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DegreeLoweringInfo; @@ -1280,7 +1194,6 @@ mod tests { use constraint_circuit::SingleRowIndicator; use isa::instruction::Instruction; use isa::instruction::InstructionBit; - use master_table::cross_table_argument::GrandCrossTableArg; use ndarray::s; use ndarray::Array2; use num_traits::Zero; @@ -1303,20 +1216,9 @@ mod tests { use crate::stark::tests::*; use crate::table::degree_lowering_table::DegreeLoweringBaseTableColumn; use crate::table::degree_lowering_table::DegreeLoweringExtTableColumn; - use crate::table::table_column::*; use crate::table::*; use crate::triton_program; - use self::cascade_table::ExtCascadeTable; - use self::hash_table::ExtHashTable; - use self::jump_stack_table::ExtJumpStackTable; - use self::lookup_table::ExtLookupTable; - use self::op_stack_table::ExtOpStackTable; - use self::processor_table::ExtProcessorTable; - use self::program_table::ExtProgramTable; - use self::ram_table::ExtRamTable; - use self::u32_table::ExtU32Table; - use super::*; #[test] @@ -1325,45 +1227,41 @@ mod tests { let (_, _, master_base_table) = master_base_table_for_low_security_level(program); assert_eq!( - program_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Program).ncols() ); assert_eq!( - processor_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Processor).ncols() ); assert_eq!( - op_stack_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::OpStack).ncols() ); assert_eq!( - ram_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Ram).ncols() ); assert_eq!( - jump_stack_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::JumpStack).ncols() ); assert_eq!( - hash_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Hash).ncols() ); assert_eq!( - cascade_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Cascade).ncols() ); assert_eq!( - lookup_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::Lookup).ncols() ); assert_eq!( - u32_table::BASE_WIDTH, + ::MainColumn::COUNT, master_base_table.table(TableId::U32).ncols() ); - assert_eq!( - degree_lowering_table::BASE_WIDTH, - master_base_table.table(TableId::DegreeLowering).ncols() - ); } #[test] @@ -1372,51 +1270,48 @@ mod tests { let (_, _, _, master_ext_table, _) = master_tables_for_low_security_level(program); assert_eq!( - program_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Program).ncols() ); assert_eq!( - processor_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Processor).ncols() ); assert_eq!( - op_stack_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::OpStack).ncols() ); assert_eq!( - ram_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Ram).ncols() ); assert_eq!( - jump_stack_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::JumpStack).ncols() ); assert_eq!( - hash_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Hash).ncols() ); assert_eq!( - cascade_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Cascade).ncols() ); assert_eq!( - lookup_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::Lookup).ncols() ); assert_eq!( - u32_table::EXT_WIDTH, + ::AuxColumn::COUNT, master_ext_table.table(TableId::U32).ncols() ); - assert_eq!( - degree_lowering_table::EXT_WIDTH, - master_ext_table.table(TableId::DegreeLowering).ncols() - ); + // use some domain-specific knowledge to also check for the randomizer columns assert_eq!( NUM_RANDOMIZER_POLYNOMIALS, master_ext_table .randomized_trace_table() - .slice(s![.., EXT_DEGREE_LOWERING_TABLE_END..]) + .slice(s![.., EXT_U32_TABLE_END..]) .ncols() ); } @@ -1476,25 +1371,6 @@ mod tests { .is_zero()); } - macro_rules! constraints_without_degree_lowering { - ($constraint_type: ident) => {{ - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ExtProgramTable::$constraint_type(&circuit_builder), - ExtProcessorTable::$constraint_type(&circuit_builder), - ExtOpStackTable::$constraint_type(&circuit_builder), - ExtRamTable::$constraint_type(&circuit_builder), - ExtJumpStackTable::$constraint_type(&circuit_builder), - ExtHashTable::$constraint_type(&circuit_builder), - ExtCascadeTable::$constraint_type(&circuit_builder), - ExtLookupTable::$constraint_type(&circuit_builder), - ExtU32Table::$constraint_type(&circuit_builder), - GrandCrossTableArg::$constraint_type(&circuit_builder), - ] - .concat() - }}; - } - struct SpecSnippet { pub start_marker: &'static str, pub stop_marker: &'static str, @@ -1537,36 +1413,32 @@ mod tests { const NUM_DEGREE_LOWERING_TARGETS: usize = 3; const DEGREE_LOWERING_TARGETS: [Option; NUM_DEGREE_LOWERING_TARGETS] = [None, Some(8), Some(4)]; - assert!(DEGREE_LOWERING_TARGETS.contains(&Some(AIR_TARGET_DEGREE))); - macro_rules! table_info { - ($($module:ident: $name:literal at $location:literal),* $(,)?) => {{ - let mut info = vec![]; - $( - let name = format!("[{}]({})", $name, $location); - info.push( - ( - name, - [$module::BASE_WIDTH; NUM_DEGREE_LOWERING_TARGETS], - [$module::EXT_WIDTH; NUM_DEGREE_LOWERING_TARGETS] - ) - ); - )* - info - }}; + fn table_widths() -> ((usize, usize)) { + (A::MainColumn::COUNT, A::AuxColumn::COUNT) } - let mut all_table_info = table_info![ - program_table: "ProgramTable" at "program-table.md", - processor_table: "ProcessorTable" at "processor-table.md", - op_stack_table: "OpStackTable" at "operational-stack-table.md", - ram_table: "RamTable" at "random-access-memory-table.md", - jump_stack_table: "JumpStackTable" at "jump-stack-table.md", - hash_table: "HashTable" at "hash-table.md", - cascade_table: "CascadeTable" at "cascade-table.md", - lookup_table: "LookupTable" at "lookup-table.md", - u32_table: "U32Table" at "u32-table.md", - ]; + assert!(DEGREE_LOWERING_TARGETS.contains(&Some(air::TARGET_DEGREE))); + + let mut all_table_info = [ + ("program-table.md", table_widths::()), + ("processor-table.md", table_widths::()), + ("operational-stack-table.md", table_widths::()), + ("random-access-memory-table.md", table_widths::()), + ("jump-stack-table.md", table_widths::()), + ("hash-table.md", table_widths::()), + ("cascade-table.md", table_widths::()), + ("lookup-table.md", table_widths::()), + ("u32-table.md", table_widths::()), + ] + .map(|(description, (main_width, aux_width))| { + ( + description.to_string(), + [main_width; NUM_DEGREE_LOWERING_TARGETS], + [aux_width; NUM_DEGREE_LOWERING_TARGETS], + ) + }) + .to_vec(); let mut deg_low_main = vec![]; let mut deg_low_aux = vec![]; @@ -1583,13 +1455,6 @@ mod tests { num_ext_cols: 0, }; - let initial_constraints = constraints_without_degree_lowering!(initial_constraints); - let consistency_constraints = - constraints_without_degree_lowering!(consistency_constraints); - let transition_constraints = - constraints_without_degree_lowering!(transition_constraints); - let terminal_constraints = constraints_without_degree_lowering!(terminal_constraints); - // generic closures are not possible; define two variants :( let lower_to_target_degree_single_row = |mut constraints: Vec<_>| { ConstraintCircuitMonad::lower_to_degree(&mut constraints, degree_lowering_info) @@ -1598,10 +1463,11 @@ mod tests { ConstraintCircuitMonad::lower_to_degree(&mut constraints, degree_lowering_info) }; - let (init_main, init_aux) = lower_to_target_degree_single_row(initial_constraints); - let (cons_main, cons_aux) = lower_to_target_degree_single_row(consistency_constraints); - let (tran_main, tran_aux) = lower_to_target_degree_double_row(transition_constraints); - let (term_main, term_aux) = lower_to_target_degree_single_row(terminal_constraints); + let constraints = crate::table::constraints(); + let (init_main, init_aux) = lower_to_target_degree_single_row(constraints.init); + let (cons_main, cons_aux) = lower_to_target_degree_single_row(constraints.cons); + let (tran_main, tran_aux) = lower_to_target_degree_double_row(constraints.tran); + let (term_main, term_aux) = lower_to_target_degree_single_row(constraints.term); deg_low_main .push(init_main.len() + cons_main.len() + tran_main.len() + term_main.len()); @@ -1725,7 +1591,7 @@ mod tests { const ZERO: usize = 0; let degree_lowering_targets = [None, Some(8), Some(4)]; - assert!(degree_lowering_targets.contains(&Some(AIR_TARGET_DEGREE))); + assert!(degree_lowering_targets.contains(&Some(air::TARGET_DEGREE))); let mut ft = String::new(); for target_degree in degree_lowering_targets { @@ -1759,23 +1625,23 @@ mod tests { let mut total_max_degree = 0; let mut tables = constraint_overview_rows!( - ExtProgramTable ends at PROGRAM_TABLE_END and EXT_PROGRAM_TABLE_END. + ProgramTable ends at PROGRAM_TABLE_END and EXT_PROGRAM_TABLE_END. Spec: ["ProgramTable"]("program-table.md"), - ExtProcessorTable ends at PROCESSOR_TABLE_END and EXT_PROCESSOR_TABLE_END. + ProcessorTable ends at PROCESSOR_TABLE_END and EXT_PROCESSOR_TABLE_END. Spec: ["ProcessorTable"]("processor-table.md"), - ExtOpStackTable ends at OP_STACK_TABLE_END and EXT_OP_STACK_TABLE_END. + OpStackTable ends at OP_STACK_TABLE_END and EXT_OP_STACK_TABLE_END. Spec: ["OpStackTable"]("operational-stack-table.md"), - ExtRamTable ends at RAM_TABLE_END and EXT_RAM_TABLE_END. + RamTable ends at RAM_TABLE_END and EXT_RAM_TABLE_END. Spec: ["RamTable"]("random-access-memory-table.md"), - ExtJumpStackTable ends at JUMP_STACK_TABLE_END and EXT_JUMP_STACK_TABLE_END. + JumpStackTable ends at JUMP_STACK_TABLE_END and EXT_JUMP_STACK_TABLE_END. Spec: ["JumpStackTable"]("jump-stack-table.md"), - ExtHashTable ends at HASH_TABLE_END and EXT_HASH_TABLE_END. + HashTable ends at HASH_TABLE_END and EXT_HASH_TABLE_END. Spec: ["HashTable"]("hash-table.md"), - ExtCascadeTable ends at CASCADE_TABLE_END and EXT_CASCADE_TABLE_END. + CascadeTable ends at CASCADE_TABLE_END and EXT_CASCADE_TABLE_END. Spec: ["CascadeTable"]("cascade-table.md"), - ExtLookupTable ends at LOOKUP_TABLE_END and EXT_LOOKUP_TABLE_END. + LookupTable ends at LOOKUP_TABLE_END and EXT_LOOKUP_TABLE_END. Spec: ["LookupTable"]("lookup-table.md"), - ExtU32Table ends at U32_TABLE_END and EXT_U32_TABLE_END. + U32Table ends at U32_TABLE_END and EXT_U32_TABLE_END. Spec: ["U32Table"]("u32-table.md"), GrandCrossTableArg ends at ZERO and ZERO. Spec: ["Grand Cross-Table Argument"]("table-linking.md"), @@ -2068,7 +1934,7 @@ mod tests { print_columns!(base CascadeBaseTableColumn for "cascade"); print_columns!(base LookupBaseTableColumn for "lookup"); print_columns!(base U32BaseTableColumn for "u32"); - print_columns!(base DegreeLoweringBaseTableColumn for "degree low."); + // print_columns!(base DegreeLoweringBaseTableColumn for "degree low."); // todo println!(); println!("| idx | table | extension column"); @@ -2082,7 +1948,7 @@ mod tests { print_columns!(ext CascadeExtTableColumn for "cascade"); print_columns!(ext LookupExtTableColumn for "lookup"); print_columns!(ext U32ExtTableColumn for "u32"); - print_columns!(ext DegreeLoweringExtTableColumn for "degree low."); + // print_columns!(ext DegreeLoweringExtTableColumn for "degree low."); // todo } #[test] diff --git a/triton-vm/src/table/op_stack.rs b/triton-vm/src/table/op_stack.rs new file mode 100644 index 000000000..6c128f6f1 --- /dev/null +++ b/triton-vm/src/table/op_stack.rs @@ -0,0 +1,394 @@ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::ops::Range; + +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::*; +use air::table::op_stack::OpStackTable; +use air::table::op_stack::PADDING_VALUE; +use air::table::TableId; +use air::table_column::OpStackBaseTableColumn::*; +use air::table_column::OpStackExtTableColumn::*; +use air::table_column::*; +use air::AIR; +use arbitrary::Arbitrary; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; +use isa::op_stack::OpStackElement; +use isa::op_stack::UnderflowIO; +use itertools::Itertools; +use ndarray::parallel::prelude::*; +use ndarray::prelude::*; +use strum::EnumCount; +use strum::IntoEnumIterator; +use twenty_first::math::traits::FiniteField; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::ndarray_helper::contiguous_column_slices; +use crate::ndarray_helper::horizontal_multi_slice_mut; +use crate::profiler::profiler; +use crate::table::TraceTable; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct OpStackTableEntry { + pub clk: u32, + pub op_stack_pointer: BFieldElement, + pub underflow_io: UnderflowIO, +} + +impl OpStackTableEntry { + pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self { + Self { + clk, + op_stack_pointer, + underflow_io, + } + } + + pub fn shrinks_stack(&self) -> bool { + self.underflow_io.shrinks_stack() + } + + pub fn grows_stack(&self) -> bool { + self.underflow_io.grows_stack() + } + + pub fn from_underflow_io_sequence( + clk: u32, + op_stack_pointer_after_sequence_execution: BFieldElement, + mut underflow_io_sequence: Vec, + ) -> Vec { + UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence); + assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence)); + + let sequence_length: BFieldElement = + u32::try_from(underflow_io_sequence.len()).unwrap().into(); + let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) { + true => op_stack_pointer_after_sequence_execution - sequence_length, + false => op_stack_pointer_after_sequence_execution + sequence_length, + }; + let mut op_stack_table_entries = vec![]; + for underflow_io in underflow_io_sequence { + if underflow_io.shrinks_stack() { + op_stack_pointer.decrement(); + } + let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io); + op_stack_table_entries.push(op_stack_table_entry); + if underflow_io.grows_stack() { + op_stack_pointer.increment(); + } + } + op_stack_table_entries + } + + pub fn to_base_table_row(self) -> Array1 { + let shrink_stack_indicator = if self.shrinks_stack() { + bfe!(1) + } else { + bfe!(0) + }; + + let mut row = Array1::zeros(::MainColumn::COUNT); + row[CLK.base_table_index()] = self.clk.into(); + row[IB1ShrinkStack.base_table_index()] = shrink_stack_indicator; + row[StackPointer.base_table_index()] = self.op_stack_pointer; + row[FirstUnderflowElement.base_table_index()] = self.underflow_io.payload(); + row + } +} + +fn extension_column_running_product_permutation_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let perm_arg_indeterminate = challenges[OpStackIndeterminate]; + + let mut running_product = PermArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + if row[IB1ShrinkStack.base_table_index()] != PADDING_VALUE { + let compressed_row = row[CLK.base_table_index()] * challenges[OpStackClkWeight] + + row[IB1ShrinkStack.base_table_index()] * challenges[OpStackIb1Weight] + + row[StackPointer.base_table_index()] * challenges[OpStackPointerWeight] + + row[FirstUnderflowElement.base_table_index()] + * challenges[OpStackFirstUnderflowElementWeight]; + running_product *= perm_arg_indeterminate - compressed_row; + } + extension_column.push(running_product); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_clock_jump_diff_lookup_log_derivative( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + // - use memoization to avoid recomputing inverses + // - precompute common values through batch inversion + const PRECOMPUTE_INVERSES_OF: Range = 0..100; + let cjd_lookup_indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; + let to_invert = PRECOMPUTE_INVERSES_OF + .map(|i| cjd_lookup_indeterminate - bfe!(i)) + .collect_vec(); + let inverses = XFieldElement::batch_inversion(to_invert); + let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF + .zip_eq(inverses) + .map(|(i, inv)| (bfe!(i), inv)) + .collect::>(); + + // populate extension column using memoization + let mut cjd_lookup_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(cjd_lookup_log_derivative); + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + if current_row[IB1ShrinkStack.base_table_index()] == PADDING_VALUE { + break; + }; + + let previous_stack_pointer = previous_row[StackPointer.base_table_index()]; + let current_stack_pointer = current_row[StackPointer.base_table_index()]; + if previous_stack_pointer == current_stack_pointer { + let previous_clock = previous_row[CLK.base_table_index()]; + let current_clock = current_row[CLK.base_table_index()]; + let clock_jump_difference = current_clock - previous_clock; + let &mut inverse = inverses_dictionary + .entry(clock_jump_difference) + .or_insert_with(|| (cjd_lookup_indeterminate - clock_jump_difference).inverse()); + cjd_lookup_log_derivative += inverse; + } + extension_column.push(cjd_lookup_log_derivative); + } + + // fill padding section + extension_column.resize(base_table.nrows(), cjd_lookup_log_derivative); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +impl TraceTable for OpStackTable { + type FillParam = (); + type FillReturnInfo = Vec; + + fn fill( + mut op_stack_table: ArrayViewMut2, + aet: &AlgebraicExecutionTrace, + _: Self::FillParam, + ) -> Vec { + let mut op_stack_table = + op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]); + let trace_iter = aet.op_stack_underflow_trace.rows().into_iter(); + + let sorted_rows = + trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view())); + for (row_index, row) in sorted_rows.enumerate() { + op_stack_table.row_mut(row_index).assign(&row); + } + + clock_jump_differences(op_stack_table.view()) + } + + fn pad(mut op_stack_table: ArrayViewMut2, op_stack_table_len: usize) { + let last_row_index = op_stack_table_len.saturating_sub(1); + let mut padding_row = op_stack_table.row(last_row_index).to_owned(); + padding_row[IB1ShrinkStack.base_table_index()] = PADDING_VALUE; + if op_stack_table_len == 0 { + let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into(); + padding_row[StackPointer.base_table_index()] = first_stack_pointer; + } + + let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]); + padding_section + .axis_iter_mut(Axis(0)) + .into_par_iter() + .for_each(|mut row| row.assign(&padding_row)); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "op stack table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let extension_column_indices = OpStackExtTableColumn::iter() + .map(|column| column.ext_table_index()) + .collect_vec(); + let extension_column_slices = horizontal_multi_slice_mut( + aux_table.view_mut(), + &contiguous_column_slices(&extension_column_indices), + ); + let extension_functions = [ + extension_column_running_product_permutation_argument, + extension_column_clock_jump_diff_lookup_log_derivative, + ]; + + extension_functions + .into_par_iter() + .zip_eq(extension_column_slices) + .for_each(|(generator, slice)| { + generator(main_table, challenges).move_into(slice); + }); + + profiler!(stop "op stack table"); + } +} + +fn compare_rows(row_0: ArrayView1, row_1: ArrayView1) -> Ordering { + let stack_pointer_0 = row_0[StackPointer.base_table_index()].value(); + let stack_pointer_1 = row_1[StackPointer.base_table_index()].value(); + let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1); + + let clk_0 = row_0[CLK.base_table_index()].value(); + let clk_1 = row_1[CLK.base_table_index()].value(); + let compare_clocks = clk_0.cmp(&clk_1); + + compare_stack_pointers.then(compare_clocks) +} + +fn clock_jump_differences(op_stack_table: ArrayView2) -> Vec { + let mut clock_jump_differences = vec![]; + for consecutive_rows in op_stack_table.axis_windows(Axis(0), 2) { + let current_row = consecutive_rows.row(0); + let next_row = consecutive_rows.row(1); + let current_stack_pointer = current_row[StackPointer.base_table_index()]; + let next_stack_pointer = next_row[StackPointer.base_table_index()]; + if current_stack_pointer == next_stack_pointer { + let current_clk = current_row[CLK.base_table_index()]; + let next_clk = next_row[CLK.base_table_index()]; + let clk_difference = next_clk - current_clk; + clock_jump_differences.push(clk_difference); + } + } + clock_jump_differences +} + +#[cfg(test)] +pub(crate) mod tests { + use assert2::assert; + use isa::op_stack::OpStackElement; + use itertools::Itertools; + use proptest::collection::vec; + use proptest::prelude::*; + use proptest_arbitrary_interop::arb; + use test_strategy::proptest; + + use super::*; + + #[proptest] + fn op_stack_table_entry_either_shrinks_stack_or_grows_stack( + #[strategy(arb())] entry: OpStackTableEntry, + ) { + let shrinks_stack = entry.shrinks_stack(); + let grows_stack = entry.grows_stack(); + assert!(shrinks_stack ^ grows_stack); + } + + #[proptest] + fn op_stack_pointer_in_sequence_of_op_stack_table_entries( + clk: u32, + #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize, + #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec, + sequence_of_writes: bool, + ) { + let sequence_length = u64::try_from(base_field_elements.len()).unwrap(); + let stack_pointer = u64::try_from(stack_pointer).unwrap(); + + let underflow_io_operation = match sequence_of_writes { + true => UnderflowIO::Write, + false => UnderflowIO::Read, + }; + let underflow_io = base_field_elements + .into_iter() + .map(underflow_io_operation) + .collect(); + + let op_stack_pointer = stack_pointer.into(); + let entries = + OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io); + let op_stack_pointers = entries + .iter() + .map(|entry| entry.op_stack_pointer.value()) + .sorted() + .collect_vec(); + + let expected_stack_pointer_range = match sequence_of_writes { + true => stack_pointer - sequence_length..stack_pointer, + false => stack_pointer..stack_pointer + sequence_length, + }; + let expected_op_stack_pointers = expected_stack_pointer_range.collect_vec(); + prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers); + } + + #[proptest] + fn clk_stays_same_in_sequence_of_op_stack_table_entries( + clk: u32, + #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize, + #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec, + sequence_of_writes: bool, + ) { + let underflow_io_operation = match sequence_of_writes { + true => UnderflowIO::Write, + false => UnderflowIO::Read, + }; + let underflow_io = base_field_elements + .into_iter() + .map(underflow_io_operation) + .collect(); + + let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into(); + let entries = + OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io); + let clk_values = entries.iter().map(|entry| entry.clk).collect_vec(); + let all_clk_values_are_clk = clk_values.iter().all(|&c| c == clk); + prop_assert!(all_clk_values_are_clk); + } + + #[proptest] + fn compare_rows_with_unequal_stack_pointer_and_equal_clk( + stack_pointer_0: u64, + stack_pointer_1: u64, + clk: u64, + ) { + const BASE_WIDTH: usize = ::MainColumn::COUNT; + + let mut row_0 = Array1::zeros(BASE_WIDTH); + row_0[StackPointer.base_table_index()] = stack_pointer_0.into(); + row_0[CLK.base_table_index()] = clk.into(); + + let mut row_1 = Array1::zeros(BASE_WIDTH); + row_1[StackPointer.base_table_index()] = stack_pointer_1.into(); + row_1[CLK.base_table_index()] = clk.into(); + + let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1); + let row_comparison = compare_rows(row_0.view(), row_1.view()); + + prop_assert_eq!(stack_pointer_comparison, row_comparison); + } + + #[proptest] + fn compare_rows_with_equal_stack_pointer_and_unequal_clk( + stack_pointer: u64, + clk_0: u64, + clk_1: u64, + ) { + const BASE_WIDTH: usize = ::MainColumn::COUNT; + + let mut row_0 = Array1::zeros(BASE_WIDTH); + row_0[StackPointer.base_table_index()] = stack_pointer.into(); + row_0[CLK.base_table_index()] = clk_0.into(); + + let mut row_1 = Array1::zeros(BASE_WIDTH); + row_1[StackPointer.base_table_index()] = stack_pointer.into(); + row_1[CLK.base_table_index()] = clk_1.into(); + + let clk_comparison = clk_0.cmp(&clk_1); + let row_comparison = compare_rows(row_0.view(), row_1.view()); + + prop_assert_eq!(clk_comparison, row_comparison); + } +} diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs deleted file mode 100644 index 65f6f5847..000000000 --- a/triton-vm/src/table/op_stack_table.rs +++ /dev/null @@ -1,563 +0,0 @@ -use std::cmp::Ordering; -use std::collections::HashMap; -use std::ops::Range; - -use arbitrary::Arbitrary; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use isa::op_stack::OpStackElement; -use isa::op_stack::UnderflowIO; -use itertools::Itertools; -use ndarray::parallel::prelude::*; -use ndarray::prelude::*; -use strum::EnumCount; -use strum::IntoEnumIterator; -use twenty_first::math::traits::FiniteField; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::ndarray_helper::contiguous_column_slices; -use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::*; -use crate::table::master_table::TableId; -use crate::table::table_column::OpStackBaseTableColumn::*; -use crate::table::table_column::OpStackExtTableColumn::*; -use crate::table::table_column::*; - -pub const BASE_WIDTH: usize = OpStackBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = OpStackExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -/// The value indicating a padding row in the op stack table. Stored in the `ib1_shrink_stack` -/// column. -pub(crate) const PADDING_VALUE: BFieldElement = BFieldElement::new(2); - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct OpStackTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtOpStackTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] -pub struct OpStackTableEntry { - pub clk: u32, - pub op_stack_pointer: BFieldElement, - pub underflow_io: UnderflowIO, -} - -impl OpStackTableEntry { - pub fn new(clk: u32, op_stack_pointer: BFieldElement, underflow_io: UnderflowIO) -> Self { - Self { - clk, - op_stack_pointer, - underflow_io, - } - } - - pub fn shrinks_stack(&self) -> bool { - self.underflow_io.shrinks_stack() - } - - pub fn grows_stack(&self) -> bool { - self.underflow_io.grows_stack() - } - - pub fn from_underflow_io_sequence( - clk: u32, - op_stack_pointer_after_sequence_execution: BFieldElement, - mut underflow_io_sequence: Vec, - ) -> Vec { - UnderflowIO::canonicalize_sequence(&mut underflow_io_sequence); - assert!(UnderflowIO::is_uniform_sequence(&underflow_io_sequence)); - - let sequence_length: BFieldElement = - u32::try_from(underflow_io_sequence.len()).unwrap().into(); - let mut op_stack_pointer = match UnderflowIO::is_writing_sequence(&underflow_io_sequence) { - true => op_stack_pointer_after_sequence_execution - sequence_length, - false => op_stack_pointer_after_sequence_execution + sequence_length, - }; - let mut op_stack_table_entries = vec![]; - for underflow_io in underflow_io_sequence { - if underflow_io.shrinks_stack() { - op_stack_pointer.decrement(); - } - let op_stack_table_entry = Self::new(clk, op_stack_pointer, underflow_io); - op_stack_table_entries.push(op_stack_table_entry); - if underflow_io.grows_stack() { - op_stack_pointer.increment(); - } - } - op_stack_table_entries - } - - pub fn to_base_table_row(self) -> Array1 { - let shrink_stack_indicator = match self.shrinks_stack() { - true => bfe!(1), - false => bfe!(0), - }; - - let mut row = Array1::zeros(BASE_WIDTH); - row[CLK.base_table_index()] = self.clk.into(); - row[IB1ShrinkStack.base_table_index()] = shrink_stack_indicator; - row[StackPointer.base_table_index()] = self.op_stack_pointer; - row[FirstUnderflowElement.base_table_index()] = self.underflow_io.payload(); - row - } -} - -impl ExtOpStackTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let constant = |c| circuit_builder.b_constant(c); - let x_constant = |c| circuit_builder.x_constant(c); - let base_row = |column: OpStackBaseTableColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let ext_row = |column: OpStackExtTableColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; - - let initial_stack_length = u32::try_from(OpStackElement::COUNT).unwrap(); - let initial_stack_length = constant(initial_stack_length.into()); - let padding_indicator = constant(PADDING_VALUE); - - let stack_pointer_is_16 = base_row(StackPointer) - initial_stack_length.clone(); - - let compressed_row = challenge(OpStackClkWeight) * base_row(CLK) - + challenge(OpStackIb1Weight) * base_row(IB1ShrinkStack) - + challenge(OpStackPointerWeight) * initial_stack_length - + challenge(OpStackFirstUnderflowElementWeight) * base_row(FirstUnderflowElement); - let rppa_initial = challenge(OpStackIndeterminate) - compressed_row; - let rppa_has_accumulated_first_row = ext_row(RunningProductPermArg) - rppa_initial; - - let rppa_is_default_initial = - ext_row(RunningProductPermArg) - x_constant(PermArg::default_initial()); - - let first_row_is_padding_row = base_row(IB1ShrinkStack) - padding_indicator; - let first_row_is_not_padding_row = - base_row(IB1ShrinkStack) * (base_row(IB1ShrinkStack) - constant(bfe!(1))); - - let rppa_starts_correctly = rppa_has_accumulated_first_row * first_row_is_padding_row - + rppa_is_default_initial * first_row_is_not_padding_row; - - let lookup_argument_initial = x_constant(LookupArg::default_initial()); - let clock_jump_diff_log_derivative_is_initialized_correctly = - ext_row(ClockJumpDifferenceLookupClientLogDerivative) - lookup_argument_initial; - - vec![ - stack_pointer_is_16, - rppa_starts_correctly, - clock_jump_diff_log_derivative_is_initialized_correctly, - ] - } - - pub fn consistency_constraints( - _circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - // no further constraints - vec![] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c| circuit_builder.b_constant(c); - let challenge = |c| circuit_builder.challenge(c); - let current_base_row = |column: OpStackBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) - }; - let current_ext_row = |column: OpStackExtTableColumn| { - circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) - }; - let next_base_row = |column: OpStackBaseTableColumn| { - circuit_builder.input(NextBaseRow(column.master_base_table_index())) - }; - let next_ext_row = |column: OpStackExtTableColumn| { - circuit_builder.input(NextExtRow(column.master_ext_table_index())) - }; - - let one = constant(1_u32.into()); - let padding_indicator = constant(PADDING_VALUE); - - let clk = current_base_row(CLK); - let ib1_shrink_stack = current_base_row(IB1ShrinkStack); - let stack_pointer = current_base_row(StackPointer); - let first_underflow_element = current_base_row(FirstUnderflowElement); - let rppa = current_ext_row(RunningProductPermArg); - let clock_jump_diff_log_derivative = - current_ext_row(ClockJumpDifferenceLookupClientLogDerivative); - - let clk_next = next_base_row(CLK); - let ib1_shrink_stack_next = next_base_row(IB1ShrinkStack); - let stack_pointer_next = next_base_row(StackPointer); - let first_underflow_element_next = next_base_row(FirstUnderflowElement); - let rppa_next = next_ext_row(RunningProductPermArg); - let clock_jump_diff_log_derivative_next = - next_ext_row(ClockJumpDifferenceLookupClientLogDerivative); - - let stack_pointer_increases_by_1_or_does_not_change = - (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) - * (stack_pointer_next.clone() - stack_pointer.clone()); - - let stack_pointer_inc_by_1_or_underflow_element_doesnt_change_or_next_ci_grows_stack = - (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) - * (first_underflow_element_next.clone() - first_underflow_element.clone()) - * ib1_shrink_stack_next.clone(); - - let next_row_is_padding_row = ib1_shrink_stack_next.clone() - padding_indicator.clone(); - let if_current_row_is_padding_row_then_next_row_is_padding_row = ib1_shrink_stack.clone() - * (ib1_shrink_stack - one.clone()) - * next_row_is_padding_row.clone(); - - // The running product for the permutation argument `rppa` is updated correctly. - let compressed_row = circuit_builder.challenge(OpStackClkWeight) * clk_next.clone() - + circuit_builder.challenge(OpStackIb1Weight) * ib1_shrink_stack_next.clone() - + circuit_builder.challenge(OpStackPointerWeight) * stack_pointer_next.clone() - + circuit_builder.challenge(OpStackFirstUnderflowElementWeight) - * first_underflow_element_next; - - let rppa_updates = - rppa_next.clone() - rppa.clone() * (challenge(OpStackIndeterminate) - compressed_row); - - let next_row_is_not_padding_row = - ib1_shrink_stack_next.clone() * (ib1_shrink_stack_next.clone() - one.clone()); - let rppa_remains = rppa_next - rppa; - - let rppa_updates_correctly = rppa_updates * next_row_is_padding_row.clone() - + rppa_remains * next_row_is_not_padding_row.clone(); - - let clk_diff = clk_next - clk; - let log_derivative_accumulates = (clock_jump_diff_log_derivative_next.clone() - - clock_jump_diff_log_derivative.clone()) - * (challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff) - - one.clone(); - let log_derivative_remains = - clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone(); - - let log_derivative_accumulates_or_stack_pointer_changes_or_next_row_is_padding_row = - log_derivative_accumulates - * (stack_pointer_next.clone() - stack_pointer.clone() - one.clone()) - * next_row_is_padding_row; - let log_derivative_remains_or_stack_pointer_doesnt_change = - log_derivative_remains.clone() * (stack_pointer_next.clone() - stack_pointer.clone()); - let log_derivatve_remains_or_next_row_is_not_padding_row = - log_derivative_remains * next_row_is_not_padding_row; - - let log_derivative_updates_correctly = - log_derivative_accumulates_or_stack_pointer_changes_or_next_row_is_padding_row - + log_derivative_remains_or_stack_pointer_doesnt_change - + log_derivatve_remains_or_next_row_is_not_padding_row; - - vec![ - stack_pointer_increases_by_1_or_does_not_change, - stack_pointer_inc_by_1_or_underflow_element_doesnt_change_or_next_ci_grows_stack, - if_current_row_is_padding_row_then_next_row_is_padding_row, - rppa_updates_correctly, - log_derivative_updates_correctly, - ] - } - - pub fn terminal_constraints( - _circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - // no further constraints - vec![] - } -} - -impl OpStackTable { - /// Fills the trace table in-place and returns all clock jump differences. - pub fn fill_trace( - op_stack_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) -> Vec { - let mut op_stack_table = - op_stack_table.slice_mut(s![0..aet.height_of_table(TableId::OpStack), ..]); - let trace_iter = aet.op_stack_underflow_trace.rows().into_iter(); - - let sorted_rows = - trace_iter.sorted_by(|row_0, row_1| Self::compare_rows(row_0.view(), row_1.view())); - for (row_index, row) in sorted_rows.enumerate() { - op_stack_table.row_mut(row_index).assign(&row); - } - - Self::clock_jump_differences(op_stack_table.view()) - } - - fn compare_rows( - row_0: ArrayView1, - row_1: ArrayView1, - ) -> Ordering { - let stack_pointer_0 = row_0[StackPointer.base_table_index()].value(); - let stack_pointer_1 = row_1[StackPointer.base_table_index()].value(); - let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1); - - let clk_0 = row_0[CLK.base_table_index()].value(); - let clk_1 = row_1[CLK.base_table_index()].value(); - let compare_clocks = clk_0.cmp(&clk_1); - - compare_stack_pointers.then(compare_clocks) - } - - fn clock_jump_differences(op_stack_table: ArrayView2) -> Vec { - let mut clock_jump_differences = vec![]; - for consecutive_rows in op_stack_table.axis_windows(Axis(0), 2) { - let current_row = consecutive_rows.row(0); - let next_row = consecutive_rows.row(1); - let current_stack_pointer = current_row[StackPointer.base_table_index()]; - let next_stack_pointer = next_row[StackPointer.base_table_index()]; - if current_stack_pointer == next_stack_pointer { - let current_clk = current_row[CLK.base_table_index()]; - let next_clk = next_row[CLK.base_table_index()]; - let clk_difference = next_clk - current_clk; - clock_jump_differences.push(clk_difference); - } - } - clock_jump_differences - } - - pub fn pad_trace(mut op_stack_table: ArrayViewMut2, op_stack_table_len: usize) { - let last_row_index = op_stack_table_len.saturating_sub(1); - let mut padding_row = op_stack_table.row(last_row_index).to_owned(); - padding_row[IB1ShrinkStack.base_table_index()] = PADDING_VALUE; - if op_stack_table_len == 0 { - let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into(); - padding_row[StackPointer.base_table_index()] = first_stack_pointer; - } - - let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]); - padding_section - .axis_iter_mut(Axis(0)) - .into_par_iter() - .for_each(|mut row| row.assign(&padding_row)); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "op stack table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let extension_column_indices = OpStackExtTableColumn::iter() - .map(|column| column.ext_table_index()) - .collect_vec(); - let extension_column_slices = horizontal_multi_slice_mut( - ext_table.view_mut(), - &contiguous_column_slices(&extension_column_indices), - ); - let extension_functions = [ - Self::extension_column_running_product_permutation_argument, - Self::extension_column_clock_jump_diff_lookup_log_derivative, - ]; - - extension_functions - .into_par_iter() - .zip_eq(extension_column_slices) - .for_each(|(generator, slice)| { - generator(base_table, challenges).move_into(slice); - }); - - profiler!(stop "op stack table"); - } - - fn extension_column_running_product_permutation_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let perm_arg_indeterminate = challenges[OpStackIndeterminate]; - - let mut running_product = PermArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - if row[IB1ShrinkStack.base_table_index()] != PADDING_VALUE { - let compressed_row = row[CLK.base_table_index()] * challenges[OpStackClkWeight] - + row[IB1ShrinkStack.base_table_index()] * challenges[OpStackIb1Weight] - + row[StackPointer.base_table_index()] * challenges[OpStackPointerWeight] - + row[FirstUnderflowElement.base_table_index()] - * challenges[OpStackFirstUnderflowElementWeight]; - running_product *= perm_arg_indeterminate - compressed_row; - } - extension_column.push(running_product); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_clock_jump_diff_lookup_log_derivative( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - // - use memoization to avoid recomputing inverses - // - precompute common values through batch inversion - const PRECOMPUTE_INVERSES_OF: Range = 0..100; - let cjd_lookup_indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; - let to_invert = PRECOMPUTE_INVERSES_OF - .map(|i| cjd_lookup_indeterminate - bfe!(i)) - .collect_vec(); - let inverses = XFieldElement::batch_inversion(to_invert); - let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF - .zip_eq(inverses) - .map(|(i, inv)| (bfe!(i), inv)) - .collect::>(); - - // populate extension column using memoization - let mut cjd_lookup_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(cjd_lookup_log_derivative); - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[IB1ShrinkStack.base_table_index()] == PADDING_VALUE { - break; - }; - - let previous_stack_pointer = previous_row[StackPointer.base_table_index()]; - let current_stack_pointer = current_row[StackPointer.base_table_index()]; - if previous_stack_pointer == current_stack_pointer { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; - let clock_jump_difference = current_clock - previous_clock; - let &mut inverse = inverses_dictionary - .entry(clock_jump_difference) - .or_insert_with(|| { - (cjd_lookup_indeterminate - clock_jump_difference).inverse() - }); - cjd_lookup_log_derivative += inverse; - } - extension_column.push(cjd_lookup_log_derivative); - } - - // fill padding section - extension_column.resize(base_table.nrows(), cjd_lookup_log_derivative); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } -} - -#[cfg(test)] -pub(crate) mod tests { - use assert2::assert; - use isa::op_stack::OpStackElement; - use itertools::Itertools; - use proptest::collection::vec; - use proptest::prelude::*; - use proptest_arbitrary_interop::arb; - use test_strategy::proptest; - - use super::*; - - #[proptest] - fn op_stack_table_entry_either_shrinks_stack_or_grows_stack( - #[strategy(arb())] entry: OpStackTableEntry, - ) { - let shrinks_stack = entry.shrinks_stack(); - let grows_stack = entry.grows_stack(); - assert!(shrinks_stack ^ grows_stack); - } - - #[proptest] - fn op_stack_pointer_in_sequence_of_op_stack_table_entries( - clk: u32, - #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize, - #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec, - sequence_of_writes: bool, - ) { - let sequence_length = u64::try_from(base_field_elements.len()).unwrap(); - let stack_pointer = u64::try_from(stack_pointer).unwrap(); - - let underflow_io_operation = match sequence_of_writes { - true => UnderflowIO::Write, - false => UnderflowIO::Read, - }; - let underflow_io = base_field_elements - .into_iter() - .map(underflow_io_operation) - .collect(); - - let op_stack_pointer = stack_pointer.into(); - let entries = - OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io); - let op_stack_pointers = entries - .iter() - .map(|entry| entry.op_stack_pointer.value()) - .sorted() - .collect_vec(); - - let expected_stack_pointer_range = match sequence_of_writes { - true => stack_pointer - sequence_length..stack_pointer, - false => stack_pointer..stack_pointer + sequence_length, - }; - let expected_op_stack_pointers = expected_stack_pointer_range.collect_vec(); - prop_assert_eq!(expected_op_stack_pointers, op_stack_pointers); - } - - #[proptest] - fn clk_stays_same_in_sequence_of_op_stack_table_entries( - clk: u32, - #[strategy(OpStackElement::COUNT..1024)] stack_pointer: usize, - #[strategy(vec(arb(), ..OpStackElement::COUNT))] base_field_elements: Vec, - sequence_of_writes: bool, - ) { - let underflow_io_operation = match sequence_of_writes { - true => UnderflowIO::Write, - false => UnderflowIO::Read, - }; - let underflow_io = base_field_elements - .into_iter() - .map(underflow_io_operation) - .collect(); - - let op_stack_pointer = u64::try_from(stack_pointer).unwrap().into(); - let entries = - OpStackTableEntry::from_underflow_io_sequence(clk, op_stack_pointer, underflow_io); - let clk_values = entries.iter().map(|entry| entry.clk).collect_vec(); - let all_clk_values_are_clk = clk_values.iter().all(|&c| c == clk); - prop_assert!(all_clk_values_are_clk); - } - - #[proptest] - fn compare_rows_with_unequal_stack_pointer_and_equal_clk( - stack_pointer_0: u64, - stack_pointer_1: u64, - clk: u64, - ) { - let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[StackPointer.base_table_index()] = stack_pointer_0.into(); - row_0[CLK.base_table_index()] = clk.into(); - - let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[StackPointer.base_table_index()] = stack_pointer_1.into(); - row_1[CLK.base_table_index()] = clk.into(); - - let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1); - let row_comparison = OpStackTable::compare_rows(row_0.view(), row_1.view()); - - prop_assert_eq!(stack_pointer_comparison, row_comparison); - } - - #[proptest] - fn compare_rows_with_equal_stack_pointer_and_unequal_clk( - stack_pointer: u64, - clk_0: u64, - clk_1: u64, - ) { - let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[StackPointer.base_table_index()] = stack_pointer.into(); - row_0[CLK.base_table_index()] = clk_0.into(); - - let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[StackPointer.base_table_index()] = stack_pointer.into(); - row_1[CLK.base_table_index()] = clk_1.into(); - - let clk_comparison = clk_0.cmp(&clk_1); - let row_comparison = OpStackTable::compare_rows(row_0.view(), row_1.view()); - - prop_assert_eq!(clk_comparison, row_comparison); - } -} diff --git a/triton-vm/src/table/processor.rs b/triton-vm/src/table/processor.rs new file mode 100644 index 000000000..3bd239bc1 --- /dev/null +++ b/triton-vm/src/table/processor.rs @@ -0,0 +1,1501 @@ +use std::cmp::max; +use std::ops::Mul; + +use air::challenge_id::ChallengeId; +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::*; +use air::table::processor::ProcessorTable; +use air::table::ram; +use air::table_column::ProcessorBaseTableColumn::*; +use air::table_column::ProcessorExtTableColumn::*; +use air::table_column::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; +use isa::instruction::AnInstruction::*; +use isa::instruction::Instruction; +use isa::instruction::InstructionBit; +use isa::instruction::ALL_INSTRUCTIONS; +use isa::op_stack::NumberOfWords; +use isa::op_stack::OpStackElement; +use isa::op_stack::NUM_OP_STACK_REGISTERS; +use itertools::izip; +use itertools::Itertools; +use ndarray::parallel::prelude::*; +use ndarray::*; +use num_traits::ConstOne; +use num_traits::One; +use num_traits::Zero; +use strum::EnumCount; +use strum::IntoEnumIterator; +use twenty_first::math::traits::FiniteField; +use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::ndarray_helper::contiguous_column_slices; +use crate::ndarray_helper::horizontal_multi_slice_mut; +use crate::profiler::profiler; +use crate::table::TraceTable; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub(super) struct ClkJumpDiffs { + pub op_stack: Vec, + pub ram: Vec, + pub jump_stack: Vec, +} + +impl TraceTable for ProcessorTable { + type FillParam = ClkJumpDiffs; + type FillReturnInfo = (); + + fn fill( + mut main_table: ArrayViewMut2, + aet: &AlgebraicExecutionTrace, + clk_jump_diffs: Self::FillParam, + ) { + let num_rows = aet.processor_trace.nrows(); + let mut clk_jump_diff_multiplicities = Array1::zeros([num_rows]); + + for clk_jump_diff in clk_jump_diffs + .op_stack + .into_iter() + .chain(clk_jump_diffs.ram) + .chain(clk_jump_diffs.jump_stack) + { + let clk = clk_jump_diff.value() as usize; + clk_jump_diff_multiplicities[clk] += BFieldElement::ONE; + } + + let mut processor_table = main_table.slice_mut(s![0..num_rows, ..]); + processor_table.assign(&aet.processor_trace); + processor_table + .column_mut(ClockJumpDifferenceLookupMultiplicity.base_table_index()) + .assign(&clk_jump_diff_multiplicities); + } + + fn pad(mut main_table: ArrayViewMut2, table_len: usize) { + assert!(table_len > 0, "Processor Table must have at least one row."); + let mut padding_template = main_table.row(table_len - 1).to_owned(); + padding_template[IsPadding.base_table_index()] = bfe!(1); + padding_template[ClockJumpDifferenceLookupMultiplicity.base_table_index()] = bfe!(0); + main_table + .slice_mut(s![table_len.., ..]) + .axis_iter_mut(Axis(0)) + .into_par_iter() + .for_each(|mut row| row.assign(&padding_template)); + + let clk_range = table_len..main_table.nrows(); + let clk_col = Array1::from_iter(clk_range.map(|a| bfe!(a as u64))); + clk_col.move_into(main_table.slice_mut(s![table_len.., CLK.base_table_index()])); + + // The Jump Stack Table does not have a padding indicator. Hence, clock jump differences are + // being looked up in its padding sections. The clock jump differences in that section are + // always 1. The lookup multiplicities of clock value 1 must be increased accordingly: one + // per padding row. + let num_padding_rows = main_table.nrows() - table_len; + let num_padding_rows = bfe!(num_padding_rows as u64); + let mut row_1 = main_table.row_mut(1); + + row_1[ClockJumpDifferenceLookupMultiplicity.base_table_index()] += num_padding_rows; + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "processor table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let all_column_indices = ProcessorExtTableColumn::iter() + .map(|column| column.ext_table_index()) + .collect_vec(); + let all_column_slices = horizontal_multi_slice_mut( + aux_table.view_mut(), + &contiguous_column_slices(&all_column_indices), + ); + + let all_column_generators = [ + extension_column_input_table_eval_argument, + extension_column_output_table_eval_argument, + extension_column_instruction_lookup_argument, + extension_column_op_stack_table_perm_argument, + extension_column_ram_table_perm_argument, + extension_column_jump_stack_table_perm_argument, + extension_column_hash_input_eval_argument, + extension_column_hash_digest_eval_argument, + extension_column_sponge_eval_argument, + extension_column_for_u32_lookup_argument, + extension_column_for_clock_jump_difference_lookup_argument, + ]; + all_column_generators + .into_par_iter() + .zip_eq(all_column_slices) + .for_each(|(generator, slice)| { + generator(main_table, challenges).move_into(slice); + }); + + profiler!(stop "processor table"); + } +} + +fn extension_column_input_table_eval_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut input_table_running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(input_table_running_evaluation); + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + if let Some(Instruction::ReadIo(st)) = instruction_from_row(previous_row) { + for i in (0..st.num_words()).rev() { + let input_symbol_column = ProcessorTable::op_stack_column_by_index(i); + let input_symbol = current_row[input_symbol_column.base_table_index()]; + input_table_running_evaluation = input_table_running_evaluation + * challenges[StandardInputIndeterminate] + + input_symbol; + } + } + extension_column.push(input_table_running_evaluation); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_output_table_eval_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut output_table_running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(output_table_running_evaluation); + for (previous_row, _) in base_table.rows().into_iter().tuple_windows() { + if let Some(Instruction::WriteIo(st)) = instruction_from_row(previous_row) { + for i in 0..st.num_words() { + let output_symbol_column = ProcessorTable::op_stack_column_by_index(i); + let output_symbol = previous_row[output_symbol_column.base_table_index()]; + output_table_running_evaluation = output_table_running_evaluation + * challenges[StandardOutputIndeterminate] + + output_symbol; + } + } + extension_column.push(output_table_running_evaluation); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_instruction_lookup_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + // collect all to-be-inverted elements for batch inversion + let mut to_invert = vec![]; + for row in base_table.rows() { + if row[IsPadding.base_table_index()].is_one() { + break; // padding marks the end of the trace + } + + let compressed_row = row[IP.base_table_index()] * challenges[ProgramAddressWeight] + + row[CI.base_table_index()] * challenges[ProgramInstructionWeight] + + row[NIA.base_table_index()] * challenges[ProgramNextInstructionWeight]; + to_invert.push(challenges[InstructionLookupIndeterminate] - compressed_row); + } + + // populate extension column with inverses + let mut instruction_lookup_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for inverse in XFieldElement::batch_inversion(to_invert) { + instruction_lookup_log_derivative += inverse; + extension_column.push(instruction_lookup_log_derivative); + } + + // fill padding section + extension_column.resize(base_table.nrows(), instruction_lookup_log_derivative); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_op_stack_table_perm_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut op_stack_table_running_product = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(op_stack_table_running_product); + for (prev, curr) in base_table.rows().into_iter().tuple_windows() { + op_stack_table_running_product *= + factor_for_op_stack_table_running_product(prev, curr, challenges); + extension_column.push(op_stack_table_running_product); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_ram_table_perm_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut ram_table_running_product = PermArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(ram_table_running_product); + for (prev, curr) in base_table.rows().into_iter().tuple_windows() { + if let Some(f) = factor_for_ram_table_running_product(prev, curr, challenges) { + ram_table_running_product *= f; + }; + extension_column.push(ram_table_running_product); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_jump_stack_table_perm_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut jump_stack_running_product = PermArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] + + row[CI.base_table_index()] * challenges[JumpStackCiWeight] + + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] + + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] + + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; + jump_stack_running_product *= challenges[JumpStackIndeterminate] - compressed_row; + extension_column.push(jump_stack_running_product); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +/// Hash Table – `hash`'s or `merkle_step`'s input from Processor to Hash Coprocessor +fn extension_column_hash_input_eval_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; + let hash_state_weights = &challenges[StackWeight0..StackWeight10]; + + let merkle_step_left_sibling = [ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3, HV4]; + let merkle_step_right_sibling = [HV0, HV1, HV2, HV3, HV4, ST0, ST1, ST2, ST3, ST4]; + + let mut hash_input_running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + let current_instruction = row[CI.base_table_index()]; + if current_instruction == Instruction::Hash.opcode_b() + || current_instruction == Instruction::MerkleStep.opcode_b() + || current_instruction == Instruction::MerkleStepMem.opcode_b() + { + let is_left_sibling = row[ST5.base_table_index()].value() % 2 == 0; + let hash_input = match instruction_from_row(row) { + Some(MerkleStep | MerkleStepMem) if is_left_sibling => merkle_step_left_sibling, + Some(MerkleStep | MerkleStepMem) => merkle_step_right_sibling, + Some(Hash) => st0_through_st9, + _ => unreachable!(), + }; + let compressed_row = hash_input + .map(|st| row[st.base_table_index()]) + .into_iter() + .zip_eq(hash_state_weights.iter()) + .map(|(st, &weight)| weight * st) + .sum::(); + hash_input_running_evaluation = + hash_input_running_evaluation * challenges[HashInputIndeterminate] + compressed_row; + } + extension_column.push(hash_input_running_evaluation); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +/// Hash Table – `hash`'s output from Hash Coprocessor to Processor +fn extension_column_hash_digest_eval_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut hash_digest_running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(hash_digest_running_evaluation); + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + let previous_ci = previous_row[CI.base_table_index()]; + if previous_ci == Instruction::Hash.opcode_b() + || previous_ci == Instruction::MerkleStep.opcode_b() + || previous_ci == Instruction::MerkleStepMem.opcode_b() + { + let compressed_row = [ST0, ST1, ST2, ST3, ST4] + .map(|st| current_row[st.base_table_index()]) + .into_iter() + .zip_eq(&challenges[StackWeight0..=StackWeight4]) + .map(|(st, &weight)| weight * st) + .sum::(); + hash_digest_running_evaluation = hash_digest_running_evaluation + * challenges[HashDigestIndeterminate] + + compressed_row; + } + extension_column.push(hash_digest_running_evaluation); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +/// Hash Table – `hash`'s or `merkle_step`'s input from Processor to Hash Coprocessor +fn extension_column_sponge_eval_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; + let hash_state_weights = &challenges[StackWeight0..StackWeight10]; + + let mut sponge_running_evaluation = EvalArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(sponge_running_evaluation); + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + let previous_ci = previous_row[CI.base_table_index()]; + if previous_ci == Instruction::SpongeInit.opcode_b() { + sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] + + challenges[HashCIWeight] * Instruction::SpongeInit.opcode_b(); + } else if previous_ci == Instruction::SpongeAbsorb.opcode_b() { + let compressed_row = st0_through_st9 + .map(|st| previous_row[st.base_table_index()]) + .into_iter() + .zip_eq(hash_state_weights.iter()) + .map(|(st, &weight)| weight * st) + .sum::(); + sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] + + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + + compressed_row; + } else if previous_ci == Instruction::SpongeAbsorbMem.opcode_b() { + let stack_elements = [ST1, ST2, ST3, ST4]; + let helper_variables = [HV0, HV1, HV2, HV3, HV4, HV5]; + let compressed_row = stack_elements + .map(|st| current_row[st.base_table_index()]) + .into_iter() + .chain(helper_variables.map(|hv| previous_row[hv.base_table_index()])) + .zip_eq(hash_state_weights.iter()) + .map(|(element, &weight)| weight * element) + .sum::(); + sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] + + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + + compressed_row; + } else if previous_ci == Instruction::SpongeSqueeze.opcode_b() { + let compressed_row = st0_through_st9 + .map(|st| current_row[st.base_table_index()]) + .into_iter() + .zip_eq(hash_state_weights.iter()) + .map(|(st, &weight)| weight * st) + .sum::(); + sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] + + challenges[HashCIWeight] * Instruction::SpongeSqueeze.opcode_b() + + compressed_row; + } + extension_column.push(sponge_running_evaluation); + } + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_for_u32_lookup_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + // collect elements to be inverted for more performant batch inversion + let mut to_invert = vec![]; + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + let previous_ci = previous_row[CI.base_table_index()]; + if previous_ci == Instruction::Split.opcode_b() { + let compressed_row = current_row[ST0.base_table_index()] * challenges[U32LhsWeight] + + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] + + previous_row[CI.base_table_index()] * challenges[U32CiWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row); + } else if previous_ci == Instruction::Lt.opcode_b() + || previous_ci == Instruction::And.opcode_b() + || previous_ci == Instruction::Pow.opcode_b() + { + let compressed_row = previous_row[ST0.base_table_index()] * challenges[U32LhsWeight] + + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] + + previous_row[CI.base_table_index()] * challenges[U32CiWeight] + + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row); + } else if previous_ci == Instruction::Xor.opcode_b() { + // Triton VM uses the following equality to compute the results of both the + // `and` and `xor` instruction using the u32 coprocessor's `and` capability: + // a ^ b = a + b - 2 · (a & b) + // <=> a & b = (a + b - a ^ b) / 2 + let st0_prev = previous_row[ST0.base_table_index()]; + let st1_prev = previous_row[ST1.base_table_index()]; + let st0 = current_row[ST0.base_table_index()]; + let from_xor_in_processor_to_and_in_u32_coprocessor = + (st0_prev + st1_prev - st0) / bfe!(2); + let compressed_row = st0_prev * challenges[U32LhsWeight] + + st1_prev * challenges[U32RhsWeight] + + Instruction::And.opcode_b() * challenges[U32CiWeight] + + from_xor_in_processor_to_and_in_u32_coprocessor * challenges[U32ResultWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row); + } else if previous_ci == Instruction::Log2Floor.opcode_b() + || previous_ci == Instruction::PopCount.opcode_b() + { + let compressed_row = previous_row[ST0.base_table_index()] * challenges[U32LhsWeight] + + previous_row[CI.base_table_index()] * challenges[U32CiWeight] + + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row); + } else if previous_ci == Instruction::DivMod.opcode_b() { + let compressed_row_for_lt_check = current_row[ST0.base_table_index()] + * challenges[U32LhsWeight] + + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] + + Instruction::Lt.opcode_b() * challenges[U32CiWeight] + + bfe!(1) * challenges[U32ResultWeight]; + let compressed_row_for_range_check = previous_row[ST0.base_table_index()] + * challenges[U32LhsWeight] + + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] + + Instruction::Split.opcode_b() * challenges[U32CiWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row_for_lt_check); + to_invert.push(challenges[U32Indeterminate] - compressed_row_for_range_check); + } else if previous_ci == Instruction::MerkleStep.opcode_b() + || previous_ci == Instruction::MerkleStepMem.opcode_b() + { + let compressed_row = previous_row[ST5.base_table_index()] * challenges[U32LhsWeight] + + current_row[ST5.base_table_index()] * challenges[U32RhsWeight] + + Instruction::Split.opcode_b() * challenges[U32CiWeight]; + to_invert.push(challenges[U32Indeterminate] - compressed_row); + } + } + let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); + + // populate column with inverses + let mut u32_table_running_sum_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(u32_table_running_sum_log_derivative); + for (previous_row, _) in base_table.rows().into_iter().tuple_windows() { + let previous_ci = previous_row[CI.base_table_index()]; + if Instruction::try_from(previous_ci) + .unwrap() + .is_u32_instruction() + { + u32_table_running_sum_log_derivative += inverses.next().unwrap(); + } + + // instruction `div_mod` requires a second inverse + if previous_ci == Instruction::DivMod.opcode_b() { + u32_table_running_sum_log_derivative += inverses.next().unwrap(); + } + + extension_column.push(u32_table_running_sum_log_derivative); + } + + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_for_clock_jump_difference_lookup_argument( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + // collect inverses to batch invert + let mut to_invert = vec![]; + for row in base_table.rows() { + let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; + if !lookup_multiplicity.is_zero() { + let clk = row[CLK.base_table_index()]; + to_invert.push(challenges[ClockJumpDifferenceLookupIndeterminate] - clk); + } + } + let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); + + // populate extension column with inverses + let mut cjd_lookup_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; + if !lookup_multiplicity.is_zero() { + cjd_lookup_log_derivative += inverses.next().unwrap() * lookup_multiplicity; + } + extension_column.push(cjd_lookup_log_derivative); + } + + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn factor_for_op_stack_table_running_product( + previous_row: ArrayView1, + current_row: ArrayView1, + challenges: &Challenges, +) -> XFieldElement { + let default_factor = xfe!(1); + + let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); + if is_padding_row { + return default_factor; + } + + let Some(previous_instruction) = instruction_from_row(previous_row) else { + return default_factor; + }; + + // shorter stack means relevant information is on top of stack, i.e., in stack registers + let row_with_shorter_stack = if previous_instruction.op_stack_size_influence() > 0 { + previous_row.view() + } else { + current_row.view() + }; + let op_stack_delta = previous_instruction + .op_stack_size_influence() + .unsigned_abs() as usize; + + let mut factor = default_factor; + for op_stack_pointer_offset in 0..op_stack_delta { + let max_stack_element_index = OpStackElement::COUNT - 1; + let stack_element_index = max_stack_element_index - op_stack_pointer_offset; + let stack_element_column = ProcessorTable::op_stack_column_by_index(stack_element_index); + let underflow_element = row_with_shorter_stack[stack_element_column.base_table_index()]; + + let op_stack_pointer = row_with_shorter_stack[OpStackPointer.base_table_index()]; + let offset = bfe!(op_stack_pointer_offset as u64); + let offset_op_stack_pointer = op_stack_pointer + offset; + + let clk = previous_row[CLK.base_table_index()]; + let ib1_shrink_stack = previous_row[IB1.base_table_index()]; + let compressed_row = clk * challenges[OpStackClkWeight] + + ib1_shrink_stack * challenges[OpStackIb1Weight] + + offset_op_stack_pointer * challenges[OpStackPointerWeight] + + underflow_element * challenges[OpStackFirstUnderflowElementWeight]; + factor *= challenges[OpStackIndeterminate] - compressed_row; + } + factor +} + +fn factor_for_ram_table_running_product( + previous_row: ArrayView1, + current_row: ArrayView1, + challenges: &Challenges, +) -> Option { + let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); + if is_padding_row { + return None; + } + + let instruction = instruction_from_row(previous_row)?; + + let clk = previous_row[CLK.base_table_index()]; + let instruction_type = match instruction { + ReadMem(_) => ram::INSTRUCTION_TYPE_READ, + WriteMem(_) => ram::INSTRUCTION_TYPE_WRITE, + SpongeAbsorbMem => ram::INSTRUCTION_TYPE_READ, + MerkleStepMem => ram::INSTRUCTION_TYPE_READ, + XxDotStep => ram::INSTRUCTION_TYPE_READ, + XbDotStep => ram::INSTRUCTION_TYPE_READ, + _ => return None, + }; + let mut accesses = vec![]; + + match instruction { + ReadMem(_) | WriteMem(_) => { + // longer stack means relevant information is on top of stack, i.e., + // available in stack registers + let row_with_longer_stack = if let ReadMem(_) = instruction { + current_row.view() + } else { + previous_row.view() + }; + let op_stack_delta = instruction.op_stack_size_influence().unsigned_abs() as usize; + + let num_ram_pointers = 1; + for ram_pointer_offset in 0..op_stack_delta { + let ram_value_index = ram_pointer_offset + num_ram_pointers; + let ram_value_column = ProcessorTable::op_stack_column_by_index(ram_value_index); + let ram_value = row_with_longer_stack[ram_value_column.base_table_index()]; + let offset_ram_pointer = + offset_ram_pointer(instruction, row_with_longer_stack, ram_pointer_offset); + accesses.push((offset_ram_pointer, ram_value)); + } + } + SpongeAbsorbMem => { + let mem_pointer = previous_row[ST0.base_table_index()]; + accesses.push((mem_pointer + bfe!(0), current_row[ST1.base_table_index()])); + accesses.push((mem_pointer + bfe!(1), current_row[ST2.base_table_index()])); + accesses.push((mem_pointer + bfe!(2), current_row[ST3.base_table_index()])); + accesses.push((mem_pointer + bfe!(3), current_row[ST4.base_table_index()])); + accesses.push((mem_pointer + bfe!(4), previous_row[HV0.base_table_index()])); + accesses.push((mem_pointer + bfe!(5), previous_row[HV1.base_table_index()])); + accesses.push((mem_pointer + bfe!(6), previous_row[HV2.base_table_index()])); + accesses.push((mem_pointer + bfe!(7), previous_row[HV3.base_table_index()])); + accesses.push((mem_pointer + bfe!(8), previous_row[HV4.base_table_index()])); + accesses.push((mem_pointer + bfe!(9), previous_row[HV5.base_table_index()])); + } + MerkleStepMem => { + let mem_pointer = previous_row[ST7.base_table_index()]; + accesses.push((mem_pointer + bfe!(0), previous_row[HV0.base_table_index()])); + accesses.push((mem_pointer + bfe!(1), previous_row[HV1.base_table_index()])); + accesses.push((mem_pointer + bfe!(2), previous_row[HV2.base_table_index()])); + accesses.push((mem_pointer + bfe!(3), previous_row[HV3.base_table_index()])); + accesses.push((mem_pointer + bfe!(4), previous_row[HV4.base_table_index()])); + } + XxDotStep => { + let rhs_pointer = previous_row[ST0.base_table_index()]; + let lhs_pointer = previous_row[ST1.base_table_index()]; + accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); + accesses.push((rhs_pointer + bfe!(1), previous_row[HV1.base_table_index()])); + accesses.push((rhs_pointer + bfe!(2), previous_row[HV2.base_table_index()])); + accesses.push((lhs_pointer + bfe!(0), previous_row[HV3.base_table_index()])); + accesses.push((lhs_pointer + bfe!(1), previous_row[HV4.base_table_index()])); + accesses.push((lhs_pointer + bfe!(2), previous_row[HV5.base_table_index()])); + } + XbDotStep => { + let rhs_pointer = previous_row[ST0.base_table_index()]; + let lhs_pointer = previous_row[ST1.base_table_index()]; + accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); + accesses.push((lhs_pointer + bfe!(0), previous_row[HV1.base_table_index()])); + accesses.push((lhs_pointer + bfe!(1), previous_row[HV2.base_table_index()])); + accesses.push((lhs_pointer + bfe!(2), previous_row[HV3.base_table_index()])); + } + _ => unreachable!(), + }; + + accesses + .into_iter() + .map(|(ramp, ramv)| { + clk * challenges[RamClkWeight] + + instruction_type * challenges[RamInstructionTypeWeight] + + ramp * challenges[RamPointerWeight] + + ramv * challenges[RamValueWeight] + }) + .map(|compressed_row| challenges[RamIndeterminate] - compressed_row) + .reduce(|l, r| l * r) +} + +fn offset_ram_pointer( + instruction: Instruction, + row_with_longer_stack: ArrayView1, + ram_pointer_offset: usize, +) -> BFieldElement { + let ram_pointer = row_with_longer_stack[ST0.base_table_index()]; + let offset = bfe!(ram_pointer_offset as u64); + + match instruction { + // adjust for ram_pointer pointing in front of last-read address: + // `push 0 read_mem 1` leaves stack as `_ a -1` where `a` was read from address 0. + ReadMem(_) => ram_pointer + offset + bfe!(1), + WriteMem(_) => ram_pointer + offset, + _ => unreachable!(), + } +} + +fn instruction_from_row(row: ArrayView1) -> Option { + let opcode = row[CI.base_table_index()]; + let instruction = Instruction::try_from(opcode).ok()?; + + if instruction.arg().is_some() { + let arg = row[NIA.base_table_index()]; + return instruction.change_arg(arg).ok(); + } + + Some(instruction) +} + +#[cfg(test)] +pub(crate) mod tests { + use std::collections::HashMap; + + use air::table::processor::NUM_HELPER_VARIABLE_REGISTERS; + use air::table::NUM_BASE_COLUMNS; + use air::table::NUM_EXT_COLUMNS; + use assert2::assert; + use isa::instruction::Instruction; + use isa::op_stack::NumberOfWords::*; + use isa::op_stack::OpStackElement; + use isa::program::Program; + use isa::triton_asm; + use isa::triton_program; + use ndarray::Array2; + use proptest::collection::vec; + use proptest::prop_assert_eq; + use proptest_arbitrary_interop::arb; + use rand::thread_rng; + use rand::Rng; + use strum::IntoEnumIterator; + use test_strategy::proptest; + + use crate::error::InstructionError::DivisionByZero; + use crate::prelude::PublicInput; + use crate::shared_tests::ProgramAndInput; + use crate::stark::tests::master_tables_for_low_security_level; + use crate::table::master_table::*; + use crate::vm::VMState; + use crate::vm::VM; + use crate::NonDeterminism; + + use super::*; + + const MAIN_WIDTH: usize = ::MainColumn::COUNT; + + /// Does printing recurse infinitely? + #[test] + fn print_simple_processor_table_row() { + let program = triton_program!(push 2 sponge_init assert halt); + let err = VM::run(&program, [].into(), [].into()).unwrap_err(); + println!("\n{}", err.vm_state); + } + + #[derive(Debug, Clone)] + struct TestRows { + pub challenges: Challenges, + pub consecutive_master_base_table_rows: Array2, + pub consecutive_ext_base_table_rows: Array2, + } + + #[derive(Debug, Clone)] + struct TestRowsDebugInfo { + pub instruction: Instruction, + pub debug_cols_curr_row: Vec, + pub debug_cols_next_row: Vec, + } + + fn test_row_from_program(program: Program, row_num: usize) -> TestRows { + test_row_from_program_with_input(ProgramAndInput::new(program), row_num) + } + + fn test_row_from_program_with_input( + program_and_input: ProgramAndInput, + row_num: usize, + ) -> TestRows { + let (_, _, master_base_table, master_ext_table, challenges) = + master_tables_for_low_security_level(program_and_input); + TestRows { + challenges, + consecutive_master_base_table_rows: master_base_table + .trace_table() + .slice(s![row_num..=row_num + 1, ..]) + .to_owned(), + consecutive_ext_base_table_rows: master_ext_table + .trace_table() + .slice(s![row_num..=row_num + 1, ..]) + .to_owned(), + } + } + + fn assert_constraints_for_rows_with_debug_info( + test_rows: &[TestRows], + debug_info: TestRowsDebugInfo, + ) { + let instruction = debug_info.instruction; + let circuit_builder = ConstraintCircuitBuilder::new(); + let transition_constraints = air::table::processor::transition_constraints_for_instruction( + &circuit_builder, + instruction, + ); + + for (case_idx, rows) in test_rows.iter().enumerate() { + let curr_row = rows.consecutive_master_base_table_rows.slice(s![0, ..]); + let next_row = rows.consecutive_master_base_table_rows.slice(s![1, ..]); + + println!("Testing all constraints of {instruction} for test case {case_idx}…"); + for &c in &debug_info.debug_cols_curr_row { + print!("{c} = {}, ", curr_row[c.master_base_table_index()]); + } + println!(); + for &c in &debug_info.debug_cols_next_row { + print!("{c}' = {}, ", next_row[c.master_base_table_index()]); + } + println!(); + + assert!( + instruction.opcode_b() == curr_row[CI.master_base_table_index()], + "The test is trying to check the wrong transition constraint polynomials." + ); + + for (constraint_idx, constraint) in transition_constraints.iter().enumerate() { + let evaluation_result = constraint.clone().consume().evaluate( + rows.consecutive_master_base_table_rows.view(), + rows.consecutive_ext_base_table_rows.view(), + &rows.challenges.challenges, + ); + assert!( + evaluation_result.is_zero(), + "For case {case_idx}, transition constraint polynomial with \ + index {constraint_idx} must evaluate to zero. Got {evaluation_result} instead.", + ); + } + } + } + + #[proptest(cases = 20)] + fn transition_constraints_for_instruction_pop_n(#[strategy(arb())] n: NumberOfWords) { + let program = triton_program!(push 1 push 2 push 3 push 4 push 5 pop {n} halt); + + let test_rows = [test_row_from_program(program, 5)]; + let debug_info = TestRowsDebugInfo { + instruction: Pop(n), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_push() { + let test_rows = [test_row_from_program(triton_program!(push 1 halt), 0)]; + + let debug_info = TestRowsDebugInfo { + instruction: Push(bfe!(1)), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[proptest(cases = 20)] + fn transition_constraints_for_instruction_divine_n(#[strategy(arb())] n: NumberOfWords) { + let program = triton_program! { divine {n} halt }; + + let non_determinism = (1..=16).map(|b| bfe!(b)).collect_vec(); + let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); + let test_rows = [test_row_from_program_with_input(program_and_input, 0)]; + let debug_info = TestRowsDebugInfo { + instruction: Divine(n), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_dup() { + let programs = [ + triton_program!(dup 0 halt), + triton_program!(dup 1 halt), + triton_program!(dup 2 halt), + triton_program!(dup 3 halt), + triton_program!(dup 4 halt), + triton_program!(dup 5 halt), + triton_program!(dup 6 halt), + triton_program!(dup 7 halt), + triton_program!(dup 8 halt), + triton_program!(dup 9 halt), + triton_program!(dup 10 halt), + triton_program!(dup 11 halt), + triton_program!(dup 12 halt), + triton_program!(dup 13 halt), + triton_program!(dup 14 halt), + triton_program!(dup 15 halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 0)); + + let debug_info = TestRowsDebugInfo { + instruction: Dup(OpStackElement::ST0), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_swap() { + let test_rows = (0..OpStackElement::COUNT) + .map(|i| triton_program!(swap {i} halt)) + .map(|program| test_row_from_program(program, 0)) + .collect_vec(); + let debug_info = TestRowsDebugInfo { + instruction: Swap(OpStackElement::ST0), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_skiz() { + let programs = [ + triton_program!(push 1 skiz halt), // ST0 is non-zero + triton_program!(push 0 skiz assert halt), // ST0 is zero, next instruction of size 1 + triton_program!(push 0 skiz push 1 halt), // ST0 is zero, next instruction of size 2 + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 1)); + let debug_info = TestRowsDebugInfo { + instruction: Skiz, + debug_cols_curr_row: vec![IP, NIA, ST0, HV5, HV4, HV3, HV2], + debug_cols_next_row: vec![IP], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_call() { + let programs = [triton_program!(call label label: halt)]; + let test_rows = programs.map(|program| test_row_from_program(program, 0)); + let debug_info = TestRowsDebugInfo { + instruction: Call(BFieldElement::default()), + debug_cols_curr_row: vec![IP, CI, NIA, JSP, JSO, JSD], + debug_cols_next_row: vec![IP, CI, NIA, JSP, JSO, JSD], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_return() { + let programs = [triton_program!(call label halt label: return)]; + let test_rows = programs.map(|program| test_row_from_program(program, 1)); + let debug_info = TestRowsDebugInfo { + instruction: Return, + debug_cols_curr_row: vec![IP, JSP, JSO, JSD], + debug_cols_next_row: vec![IP, JSP, JSO, JSD], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_recurse() { + let programs = + [triton_program!(push 2 call label halt label: push -1 add dup 0 skiz recurse return)]; + let test_rows = programs.map(|program| test_row_from_program(program, 6)); + let debug_info = TestRowsDebugInfo { + instruction: Recurse, + debug_cols_curr_row: vec![IP, JSP, JSO, JSD], + debug_cols_next_row: vec![IP, JSP, JSO, JSD], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_recurse_or_return() { + let program = triton_program! { + push 2 swap 6 + call loop halt + loop: + swap 5 push 1 add swap 5 + recurse_or_return + }; + let test_rows = [ + test_row_from_program(program.clone(), 7), // recurse + test_row_from_program(program, 12), // return + ]; + let debug_info = TestRowsDebugInfo { + instruction: RecurseOrReturn, + debug_cols_curr_row: vec![IP, JSP, JSO, JSD, ST5, ST6, HV4], + debug_cols_next_row: vec![IP, JSP, JSO, JSD], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_read_mem() { + let programs = [ + triton_program!(push 1 read_mem 1 push 0 eq assert assert halt), + triton_program!(push 2 read_mem 2 push 0 eq assert swap 1 push 2 eq assert halt), + triton_program!(push 3 read_mem 3 push 0 eq assert swap 2 push 3 eq assert halt), + triton_program!(push 4 read_mem 4 push 0 eq assert swap 3 push 4 eq assert halt), + triton_program!(push 5 read_mem 5 push 0 eq assert swap 4 push 5 eq assert halt), + ]; + let initial_ram = (1..=5) + .map(|i| (bfe!(i), bfe!(i))) + .collect::>(); + let non_determinism = NonDeterminism::default().with_ram(initial_ram); + let programs_with_input = programs.map(|program| { + ProgramAndInput::new(program).with_non_determinism(non_determinism.clone()) + }); + let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 1)); + let debug_info = TestRowsDebugInfo { + instruction: ReadMem(N1), + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0, ST1], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_write_mem() { + let push_10_elements = triton_asm![push 2; 10]; + let programs = [ + triton_program!({&push_10_elements} write_mem 1 halt), + triton_program!({&push_10_elements} write_mem 2 halt), + triton_program!({&push_10_elements} write_mem 3 halt), + triton_program!({&push_10_elements} write_mem 4 halt), + triton_program!({&push_10_elements} write_mem 5 halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 10)); + let debug_info = TestRowsDebugInfo { + instruction: WriteMem(N1), + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0, ST1], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_merkle_step() { + let programs = [ + triton_program!(push 2 swap 5 merkle_step halt), + triton_program!(push 3 swap 5 merkle_step halt), + ]; + let dummy_digest = Digest::new(bfe_array![1, 2, 3, 4, 5]); + let non_determinism = NonDeterminism::default().with_digests(vec![dummy_digest]); + let programs_with_input = programs.map(|program| { + ProgramAndInput::new(program).with_non_determinism(non_determinism.clone()) + }); + let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 2)); + + let debug_info = TestRowsDebugInfo { + instruction: MerkleStep, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, HV0, HV1, HV2, HV3, HV4, HV5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_merkle_step_mem() { + let sibling_digest = bfe_array![1, 2, 3, 4, 5]; + let acc_digest = bfe_array![11, 12, 13, 14, 15]; + let test_program = |node_index: u32| { + triton_program! { + push 42 // RAM pointer + push 1 // dummy + push {node_index} + push {acc_digest[0]} + push {acc_digest[1]} + push {acc_digest[2]} + push {acc_digest[3]} + push {acc_digest[4]} + merkle_step_mem + halt + } + }; + let mut ram = HashMap::new(); + ram.insert(bfe!(42), sibling_digest[0]); + ram.insert(bfe!(43), sibling_digest[1]); + ram.insert(bfe!(44), sibling_digest[2]); + ram.insert(bfe!(45), sibling_digest[3]); + ram.insert(bfe!(46), sibling_digest[4]); + let non_determinism = NonDeterminism::default().with_ram(ram); + + let node_indices = [2, 3]; + let test_rows = node_indices + .map(test_program) + .map(ProgramAndInput::new) + .map(|p| p.with_non_determinism(non_determinism.clone())) + .map(|p| test_row_from_program_with_input(p, 8)); + + let debug_info = TestRowsDebugInfo { + instruction: MerkleStepMem, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, HV0, HV1, HV2, HV3, HV4, HV5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_sponge_init() { + let programs = [triton_program!(sponge_init halt)]; + let test_rows = programs.map(|program| test_row_from_program(program, 0)); + let debug_info = TestRowsDebugInfo { + instruction: SpongeInit, + debug_cols_curr_row: vec![], + debug_cols_next_row: vec![], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_sponge_absorb() { + let push_10_zeros = triton_asm![push 0; 10]; + let push_10_ones = triton_asm![push 1; 10]; + let programs = [ + triton_program!(sponge_init {&push_10_zeros} sponge_absorb halt), + triton_program!(sponge_init {&push_10_ones} sponge_absorb halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 11)); + let debug_info = TestRowsDebugInfo { + instruction: SpongeAbsorb, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_sponge_absorb_mem() { + let programs = [triton_program!(sponge_init push 0 sponge_absorb_mem halt)]; + let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let debug_info = TestRowsDebugInfo { + instruction: SpongeAbsorbMem, + debug_cols_curr_row: vec![ST0, HV0, HV1, HV2, HV3, HV4, HV5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_sponge_squeeze() { + let programs = [triton_program!(sponge_init sponge_squeeze halt)]; + let test_rows = programs.map(|program| test_row_from_program(program, 1)); + let debug_info = TestRowsDebugInfo { + instruction: SpongeSqueeze, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_eq() { + let programs = [ + triton_program!(push 3 push 3 eq assert halt), + triton_program!(push 3 push 2 eq push 0 eq assert halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let debug_info = TestRowsDebugInfo { + instruction: Eq, + debug_cols_curr_row: vec![ST0, ST1, HV0], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_split() { + let programs = [ + triton_program!(push -1 split halt), + triton_program!(push 0 split halt), + triton_program!(push 1 split halt), + triton_program!(push 2 split halt), + triton_program!(push 3 split halt), + // test pushing push 2^32 +- 1 + triton_program!(push 4294967295 split halt), + triton_program!(push 4294967296 split halt), + triton_program!(push 4294967297 split halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 1)); + let debug_info = TestRowsDebugInfo { + instruction: Split, + debug_cols_curr_row: vec![ST0, ST1, HV0], + debug_cols_next_row: vec![ST0, ST1, HV0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_lt() { + let programs = [ + triton_program!(push 3 push 3 lt push 0 eq assert halt), + triton_program!(push 3 push 2 lt push 1 eq assert halt), + triton_program!(push 2 push 3 lt push 0 eq assert halt), + triton_program!(push 512 push 513 lt push 0 eq assert halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let debug_info = TestRowsDebugInfo { + instruction: Lt, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_and() { + let test_rows = [test_row_from_program( + triton_program!(push 5 push 12 and push 4 eq assert halt), + 2, + )]; + let debug_info = TestRowsDebugInfo { + instruction: And, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_xor() { + let test_rows = [test_row_from_program( + triton_program!(push 5 push 12 xor push 9 eq assert halt), + 2, + )]; + let debug_info = TestRowsDebugInfo { + instruction: Xor, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_log2floor() { + let programs = [ + triton_program!(push 1 log_2_floor push 0 eq assert halt), + triton_program!(push 2 log_2_floor push 1 eq assert halt), + triton_program!(push 3 log_2_floor push 1 eq assert halt), + triton_program!(push 4 log_2_floor push 2 eq assert halt), + triton_program!(push 5 log_2_floor push 2 eq assert halt), + triton_program!(push 6 log_2_floor push 2 eq assert halt), + triton_program!(push 7 log_2_floor push 2 eq assert halt), + triton_program!(push 8 log_2_floor push 3 eq assert halt), + triton_program!(push 9 log_2_floor push 3 eq assert halt), + triton_program!(push 10 log_2_floor push 3 eq assert halt), + triton_program!(push 11 log_2_floor push 3 eq assert halt), + triton_program!(push 12 log_2_floor push 3 eq assert halt), + triton_program!(push 13 log_2_floor push 3 eq assert halt), + triton_program!(push 14 log_2_floor push 3 eq assert halt), + triton_program!(push 15 log_2_floor push 3 eq assert halt), + triton_program!(push 16 log_2_floor push 4 eq assert halt), + triton_program!(push 17 log_2_floor push 4 eq assert halt), + ]; + + let test_rows = programs.map(|program| test_row_from_program(program, 1)); + let debug_info = TestRowsDebugInfo { + instruction: Log2Floor, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_pow() { + let programs = [ + triton_program!(push 0 push 0 pow push 1 eq assert halt), + triton_program!(push 1 push 0 pow push 0 eq assert halt), + triton_program!(push 2 push 0 pow push 0 eq assert halt), + triton_program!(push 0 push 1 pow push 1 eq assert halt), + triton_program!(push 1 push 1 pow push 1 eq assert halt), + triton_program!(push 2 push 1 pow push 1 eq assert halt), + triton_program!(push 0 push 2 pow push 1 eq assert halt), + triton_program!(push 1 push 2 pow push 2 eq assert halt), + triton_program!(push 2 push 2 pow push 4 eq assert halt), + triton_program!(push 3 push 2 pow push 8 eq assert halt), + triton_program!(push 4 push 2 pow push 16 eq assert halt), + triton_program!(push 5 push 2 pow push 32 eq assert halt), + triton_program!(push 0 push 3 pow push 1 eq assert halt), + triton_program!(push 1 push 3 pow push 3 eq assert halt), + triton_program!(push 2 push 3 pow push 9 eq assert halt), + triton_program!(push 3 push 3 pow push 27 eq assert halt), + triton_program!(push 4 push 3 pow push 81 eq assert halt), + triton_program!(push 0 push 17 pow push 1 eq assert halt), + triton_program!(push 1 push 17 pow push 17 eq assert halt), + triton_program!(push 2 push 17 pow push 289 eq assert halt), + ]; + + let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let debug_info = TestRowsDebugInfo { + instruction: Pow, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_div_mod() { + let programs = [ + triton_program!(push 2 push 3 div_mod push 1 eq assert push 1 eq assert halt), + triton_program!(push 3 push 7 div_mod push 1 eq assert push 2 eq assert halt), + triton_program!(push 4 push 7 div_mod push 3 eq assert push 1 eq assert halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let debug_info = TestRowsDebugInfo { + instruction: DivMod, + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0, ST1], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn division_by_zero_is_impossible() { + let program = ProgramAndInput::new(triton_program! { div_mod }); + let err = program.run().unwrap_err(); + assert_eq!(DivisionByZero, err.source); + } + + #[test] + fn transition_constraints_for_instruction_xx_add() { + let programs = [ + triton_program!(push 5 push 6 push 7 push 8 push 9 push 10 xx_add halt), + triton_program!(push 2 push 3 push 4 push -2 push -3 push -4 xx_add halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 6)); + let debug_info = TestRowsDebugInfo { + instruction: XxAdd, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_xx_mul() { + let programs = [ + triton_program!(push 5 push 6 push 7 push 8 push 9 push 10 xx_mul halt), + triton_program!(push 2 push 3 push 4 push -2 push -3 push -4 xx_mul halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 6)); + let debug_info = TestRowsDebugInfo { + instruction: XxMul, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_x_invert() { + let programs = [ + triton_program!(push 5 push 6 push 7 x_invert halt), + triton_program!(push -2 push -3 push -4 x_invert halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 3)); + let debug_info = TestRowsDebugInfo { + instruction: XInvert, + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_xb_mul() { + let programs = [ + triton_program!(push 5 push 6 push 7 push 2 xb_mul halt), + triton_program!(push 2 push 3 push 4 push -2 xb_mul halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 4)); + let debug_info = TestRowsDebugInfo { + instruction: XbMul, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, OpStackPointer, HV0], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, OpStackPointer, HV0], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[proptest(cases = 20)] + fn transition_constraints_for_instruction_read_io_n(#[strategy(arb())] n: NumberOfWords) { + let program = triton_program! {read_io {n} halt}; + + let public_input = (1..=16).map(|i| bfe!(i)).collect_vec(); + let program_and_input = ProgramAndInput::new(program).with_input(public_input); + let test_rows = [test_row_from_program_with_input(program_and_input, 0)]; + let debug_info = TestRowsDebugInfo { + instruction: ReadIo(n), + debug_cols_curr_row: vec![ST0, ST1, ST2], + debug_cols_next_row: vec![ST0, ST1, ST2], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[proptest(cases = 20)] + fn transition_constraints_for_instruction_write_io_n(#[strategy(arb())] n: NumberOfWords) { + let program = triton_program! {divine 5 write_io {n} halt}; + + let non_determinism = (1..=16).map(|b| bfe!(b)).collect_vec(); + let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); + let test_rows = [test_row_from_program_with_input(program_and_input, 1)]; + let debug_info = TestRowsDebugInfo { + instruction: WriteIo(n), + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_xb_dot_step() { + let program = triton_program! { + push 10 push 20 push 30 // accumulator `[30, 20, 10]` + push 96 // pointer to extension-field element `[3, 5, 7]` + push 42 // pointer to base-field element `2` + xb_dot_step + push 43 eq assert + push 99 eq assert + push {30 + 2 * 3} eq assert + push {20 + 2 * 5} eq assert + push {10 + 2 * 7} eq assert + halt + }; + + let mut ram = HashMap::new(); + ram.insert(bfe!(42), bfe!(2)); + ram.insert(bfe!(96), bfe!(3)); + ram.insert(bfe!(97), bfe!(5)); + ram.insert(bfe!(98), bfe!(7)); + let non_determinism = NonDeterminism::default().with_ram(ram); + let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); + let test_rows = [test_row_from_program_with_input(program_and_input, 5)]; + let debug_info = TestRowsDebugInfo { + instruction: XbDotStep, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn transition_constraints_for_instruction_xx_dot_step() { + let operand_0 = xfe!([3, 5, 7]); + let operand_1 = xfe!([11, 13, 17]); + let product = operand_0 * operand_1; + + let program = triton_program! { + push 10 push 20 push 30 // accumulator `[30, 20, 10]` + push 96 // pointer to `operand_1` + push 42 // pointer to `operand_0` + xx_dot_step + push 45 eq assert + push 99 eq assert + push {bfe!(30) + product.coefficients[0]} eq assert + push {bfe!(20) + product.coefficients[1]} eq assert + push {bfe!(10) + product.coefficients[2]} eq assert + halt + }; + + let mut ram = HashMap::new(); + ram.insert(bfe!(42), operand_0.coefficients[0]); + ram.insert(bfe!(43), operand_0.coefficients[1]); + ram.insert(bfe!(44), operand_0.coefficients[2]); + ram.insert(bfe!(96), operand_1.coefficients[0]); + ram.insert(bfe!(97), operand_1.coefficients[1]); + ram.insert(bfe!(98), operand_1.coefficients[2]); + let non_determinism = NonDeterminism::default().with_ram(ram); + let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); + let test_rows = [test_row_from_program_with_input(program_and_input, 5)]; + let debug_info = TestRowsDebugInfo { + instruction: XxDotStep, + debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3, HV4, HV5], + debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], + }; + assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); + } + + #[test] + fn opcode_decomposition_for_skiz_is_unique() { + let max_value_of_skiz_constraint_for_nia_decomposition = + (3 << 7) * (3 << 5) * (3 << 3) * (3 << 1) * 2; + for instruction in Instruction::iter() { + assert!( + instruction.opcode() < max_value_of_skiz_constraint_for_nia_decomposition, + "Opcode for {instruction} is too high." + ); + } + } + + #[proptest] + fn constructing_factor_for_op_stack_table_running_product_never_panics( + #[strategy(vec(arb(), MAIN_WIDTH))] previous_row: Vec, + #[strategy(vec(arb(), MAIN_WIDTH))] current_row: Vec, + #[strategy(arb())] challenges: Challenges, + ) { + let previous_row = Array1::from(previous_row); + let current_row = Array1::from(current_row); + let _ = factor_for_op_stack_table_running_product( + previous_row.view(), + current_row.view(), + &challenges, + ); + } + + #[proptest] + fn constructing_factor_for_ram_table_running_product_never_panics( + #[strategy(vec(arb(), MAIN_WIDTH))] previous_row: Vec, + #[strategy(vec(arb(), MAIN_WIDTH))] current_row: Vec, + #[strategy(arb())] challenges: Challenges, + ) { + let previous_row = Array1::from(previous_row); + let current_row = Array1::from(current_row); + let _ = factor_for_ram_table_running_product( + previous_row.view(), + current_row.view(), + &challenges, + ); + } +} diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs deleted file mode 100644 index ce95fece0..000000000 --- a/triton-vm/src/table/processor_table.rs +++ /dev/null @@ -1,4987 +0,0 @@ -use std::cmp::max; -use std::ops::Mul; - -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use isa::instruction::AnInstruction::*; -use isa::instruction::Instruction; -use isa::instruction::InstructionBit; -use isa::instruction::ALL_INSTRUCTIONS; -use isa::op_stack::NumberOfWords; -use isa::op_stack::OpStackElement; -use isa::op_stack::NUM_OP_STACK_REGISTERS; -use itertools::izip; -use itertools::Itertools; -use ndarray::parallel::prelude::*; -use ndarray::*; -use num_traits::ConstOne; -use num_traits::One; -use num_traits::Zero; -use strum::EnumCount; -use strum::IntoEnumIterator; -use twenty_first::math::traits::FiniteField; -use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::ndarray_helper::contiguous_column_slices; -use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::*; -use crate::table::ram_table; -use crate::table::table_column::ProcessorBaseTableColumn::*; -use crate::table::table_column::ProcessorExtTableColumn::*; -use crate::table::table_column::*; - -pub const BASE_WIDTH: usize = ProcessorBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = ProcessorExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ProcessorTable; - -impl ProcessorTable { - pub fn fill_trace( - processor_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - clk_jump_diffs_op_stack: &[BFieldElement], - clk_jump_diffs_ram: &[BFieldElement], - clk_jump_diffs_jump_stack: &[BFieldElement], - ) { - let num_rows = aet.processor_trace.nrows(); - let mut clk_jump_diff_multiplicities = Array1::zeros([num_rows]); - - for clk_jump_diff in clk_jump_diffs_op_stack - .iter() - .chain(clk_jump_diffs_ram) - .chain(clk_jump_diffs_jump_stack) - { - let clk = clk_jump_diff.value() as usize; - clk_jump_diff_multiplicities[clk] += BFieldElement::ONE; - } - - let mut processor_table = processor_table.slice_mut(s![0..num_rows, ..]); - processor_table.assign(&aet.processor_trace); - processor_table - .column_mut(ClockJumpDifferenceLookupMultiplicity.base_table_index()) - .assign(&clk_jump_diff_multiplicities); - } - - pub fn pad_trace( - mut processor_table: ArrayViewMut2, - processor_table_len: usize, - ) { - assert!( - processor_table_len > 0, - "Processor Table must have at least one row." - ); - let mut padding_template = processor_table.row(processor_table_len - 1).to_owned(); - padding_template[IsPadding.base_table_index()] = bfe!(1); - padding_template[ClockJumpDifferenceLookupMultiplicity.base_table_index()] = bfe!(0); - processor_table - .slice_mut(s![processor_table_len.., ..]) - .axis_iter_mut(Axis(0)) - .into_par_iter() - .for_each(|mut row| row.assign(&padding_template)); - - let clk_range = processor_table_len..processor_table.nrows(); - let clk_col = Array1::from_iter(clk_range.map(|a| bfe!(a as u64))); - clk_col.move_into( - processor_table.slice_mut(s![processor_table_len.., CLK.base_table_index()]), - ); - - // The Jump Stack Table does not have a padding indicator. Hence, clock jump differences are - // being looked up in its padding sections. The clock jump differences in that section are - // always 1. The lookup multiplicities of clock value 1 must be increased accordingly: one - // per padding row. - let num_padding_rows = processor_table.nrows() - processor_table_len; - let num_padding_rows = bfe!(num_padding_rows as u64); - let mut row_1 = processor_table.row_mut(1); - - row_1[ClockJumpDifferenceLookupMultiplicity.base_table_index()] += num_padding_rows; - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "processor table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let all_column_indices = ProcessorExtTableColumn::iter() - .map(|column| column.ext_table_index()) - .collect_vec(); - let all_column_slices = horizontal_multi_slice_mut( - ext_table.view_mut(), - &contiguous_column_slices(&all_column_indices), - ); - - let all_column_generators = [ - Self::extension_column_input_table_eval_argument, - Self::extension_column_output_table_eval_argument, - Self::extension_column_instruction_lookup_argument, - Self::extension_column_op_stack_table_perm_argument, - Self::extension_column_ram_table_perm_argument, - Self::extension_column_jump_stack_table_perm_argument, - Self::extension_column_hash_input_eval_argument, - Self::extension_column_hash_digest_eval_argument, - Self::extension_column_sponge_eval_argument, - Self::extension_column_for_u32_lookup_argument, - Self::extension_column_for_clock_jump_difference_lookup_argument, - ]; - all_column_generators - .into_par_iter() - .zip_eq(all_column_slices) - .for_each(|(generator, slice)| { - generator(base_table, challenges).move_into(slice); - }); - - profiler!(stop "processor table"); - } - - fn extension_column_input_table_eval_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut input_table_running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(input_table_running_evaluation); - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if let Some(Instruction::ReadIo(st)) = Self::instruction_from_row(previous_row) { - for i in (0..st.num_words()).rev() { - let input_symbol_column = Self::op_stack_column_by_index(i); - let input_symbol = current_row[input_symbol_column.base_table_index()]; - input_table_running_evaluation = input_table_running_evaluation - * challenges[StandardInputIndeterminate] - + input_symbol; - } - } - extension_column.push(input_table_running_evaluation); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_output_table_eval_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut output_table_running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(output_table_running_evaluation); - for (previous_row, _) in base_table.rows().into_iter().tuple_windows() { - if let Some(Instruction::WriteIo(st)) = Self::instruction_from_row(previous_row) { - for i in 0..st.num_words() { - let output_symbol_column = Self::op_stack_column_by_index(i); - let output_symbol = previous_row[output_symbol_column.base_table_index()]; - output_table_running_evaluation = output_table_running_evaluation - * challenges[StandardOutputIndeterminate] - + output_symbol; - } - } - extension_column.push(output_table_running_evaluation); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_instruction_lookup_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - // collect all to-be-inverted elements for batch inversion - let mut to_invert = vec![]; - for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { - break; // padding marks the end of the trace - } - - let compressed_row = row[IP.base_table_index()] * challenges[ProgramAddressWeight] - + row[CI.base_table_index()] * challenges[ProgramInstructionWeight] - + row[NIA.base_table_index()] * challenges[ProgramNextInstructionWeight]; - to_invert.push(challenges[InstructionLookupIndeterminate] - compressed_row); - } - - // populate extension column with inverses - let mut instruction_lookup_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for inverse in XFieldElement::batch_inversion(to_invert) { - instruction_lookup_log_derivative += inverse; - extension_column.push(instruction_lookup_log_derivative); - } - - // fill padding section - extension_column.resize(base_table.nrows(), instruction_lookup_log_derivative); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_op_stack_table_perm_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut op_stack_table_running_product = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(op_stack_table_running_product); - for (prev, curr) in base_table.rows().into_iter().tuple_windows() { - op_stack_table_running_product *= - Self::factor_for_op_stack_table_running_product(prev, curr, challenges); - extension_column.push(op_stack_table_running_product); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_ram_table_perm_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut ram_table_running_product = PermArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(ram_table_running_product); - for (prev, curr) in base_table.rows().into_iter().tuple_windows() { - if let Some(f) = Self::factor_for_ram_table_running_product(prev, curr, challenges) { - ram_table_running_product *= f; - }; - extension_column.push(ram_table_running_product); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_jump_stack_table_perm_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut jump_stack_running_product = PermArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] - + row[CI.base_table_index()] * challenges[JumpStackCiWeight] - + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] - + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] - + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; - jump_stack_running_product *= challenges[JumpStackIndeterminate] - compressed_row; - extension_column.push(jump_stack_running_product); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - /// Hash Table – `hash`'s or `merkle_step`'s input from Processor to Hash Coprocessor - fn extension_column_hash_input_eval_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let hash_state_weights = &challenges[StackWeight0..StackWeight10]; - - let merkle_step_left_sibling = [ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3, HV4]; - let merkle_step_right_sibling = [HV0, HV1, HV2, HV3, HV4, ST0, ST1, ST2, ST3, ST4]; - - let mut hash_input_running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - let current_instruction = row[CI.base_table_index()]; - if current_instruction == Instruction::Hash.opcode_b() - || current_instruction == Instruction::MerkleStep.opcode_b() - || current_instruction == Instruction::MerkleStepMem.opcode_b() - { - let is_left_sibling = row[ST5.base_table_index()].value() % 2 == 0; - let hash_input = match Self::instruction_from_row(row) { - Some(MerkleStep | MerkleStepMem) if is_left_sibling => merkle_step_left_sibling, - Some(MerkleStep | MerkleStepMem) => merkle_step_right_sibling, - Some(Hash) => st0_through_st9, - _ => unreachable!(), - }; - let compressed_row = hash_input - .map(|st| row[st.base_table_index()]) - .into_iter() - .zip_eq(hash_state_weights.iter()) - .map(|(st, &weight)| weight * st) - .sum::(); - hash_input_running_evaluation = hash_input_running_evaluation - * challenges[HashInputIndeterminate] - + compressed_row; - } - extension_column.push(hash_input_running_evaluation); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - /// Hash Table – `hash`'s output from Hash Coprocessor to Processor - fn extension_column_hash_digest_eval_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut hash_digest_running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(hash_digest_running_evaluation); - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; - if previous_ci == Instruction::Hash.opcode_b() - || previous_ci == Instruction::MerkleStep.opcode_b() - || previous_ci == Instruction::MerkleStepMem.opcode_b() - { - let compressed_row = [ST0, ST1, ST2, ST3, ST4] - .map(|st| current_row[st.base_table_index()]) - .into_iter() - .zip_eq(&challenges[StackWeight0..=StackWeight4]) - .map(|(st, &weight)| weight * st) - .sum::(); - hash_digest_running_evaluation = hash_digest_running_evaluation - * challenges[HashDigestIndeterminate] - + compressed_row; - } - extension_column.push(hash_digest_running_evaluation); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - /// Hash Table – `hash`'s or `merkle_step`'s input from Processor to Hash Coprocessor - fn extension_column_sponge_eval_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let hash_state_weights = &challenges[StackWeight0..StackWeight10]; - - let mut sponge_running_evaluation = EvalArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(sponge_running_evaluation); - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; - if previous_ci == Instruction::SpongeInit.opcode_b() { - sponge_running_evaluation = sponge_running_evaluation - * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeInit.opcode_b(); - } else if previous_ci == Instruction::SpongeAbsorb.opcode_b() { - let compressed_row = st0_through_st9 - .map(|st| previous_row[st.base_table_index()]) - .into_iter() - .zip_eq(hash_state_weights.iter()) - .map(|(st, &weight)| weight * st) - .sum::(); - sponge_running_evaluation = sponge_running_evaluation - * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() - + compressed_row; - } else if previous_ci == Instruction::SpongeAbsorbMem.opcode_b() { - let stack_elements = [ST1, ST2, ST3, ST4]; - let helper_variables = [HV0, HV1, HV2, HV3, HV4, HV5]; - let compressed_row = stack_elements - .map(|st| current_row[st.base_table_index()]) - .into_iter() - .chain(helper_variables.map(|hv| previous_row[hv.base_table_index()])) - .zip_eq(hash_state_weights.iter()) - .map(|(element, &weight)| weight * element) - .sum::(); - sponge_running_evaluation = sponge_running_evaluation - * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() - + compressed_row; - } else if previous_ci == Instruction::SpongeSqueeze.opcode_b() { - let compressed_row = st0_through_st9 - .map(|st| current_row[st.base_table_index()]) - .into_iter() - .zip_eq(hash_state_weights.iter()) - .map(|(st, &weight)| weight * st) - .sum::(); - sponge_running_evaluation = sponge_running_evaluation - * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeSqueeze.opcode_b() - + compressed_row; - } - extension_column.push(sponge_running_evaluation); - } - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_for_u32_lookup_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - // collect elements to be inverted for more performant batch inversion - let mut to_invert = vec![]; - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; - if previous_ci == Instruction::Split.opcode_b() { - let compressed_row = current_row[ST0.base_table_index()] * challenges[U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); - } else if previous_ci == Instruction::Lt.opcode_b() - || previous_ci == Instruction::And.opcode_b() - || previous_ci == Instruction::Pow.opcode_b() - { - let compressed_row = previous_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); - } else if previous_ci == Instruction::Xor.opcode_b() { - // Triton VM uses the following equality to compute the results of both the - // `and` and `xor` instruction using the u32 coprocessor's `and` capability: - // a ^ b = a + b - 2 · (a & b) - // <=> a & b = (a + b - a ^ b) / 2 - let st0_prev = previous_row[ST0.base_table_index()]; - let st1_prev = previous_row[ST1.base_table_index()]; - let st0 = current_row[ST0.base_table_index()]; - let from_xor_in_processor_to_and_in_u32_coprocessor = - (st0_prev + st1_prev - st0) / bfe!(2); - let compressed_row = st0_prev * challenges[U32LhsWeight] - + st1_prev * challenges[U32RhsWeight] - + Instruction::And.opcode_b() * challenges[U32CiWeight] - + from_xor_in_processor_to_and_in_u32_coprocessor * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); - } else if previous_ci == Instruction::Log2Floor.opcode_b() - || previous_ci == Instruction::PopCount.opcode_b() - { - let compressed_row = previous_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); - } else if previous_ci == Instruction::DivMod.opcode_b() { - let compressed_row_for_lt_check = current_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Lt.opcode_b() * challenges[U32CiWeight] - + bfe!(1) * challenges[U32ResultWeight]; - let compressed_row_for_range_check = previous_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Split.opcode_b() * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row_for_lt_check); - to_invert.push(challenges[U32Indeterminate] - compressed_row_for_range_check); - } else if previous_ci == Instruction::MerkleStep.opcode_b() - || previous_ci == Instruction::MerkleStepMem.opcode_b() - { - let compressed_row = previous_row[ST5.base_table_index()] - * challenges[U32LhsWeight] - + current_row[ST5.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Split.opcode_b() * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); - } - } - let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); - - // populate column with inverses - let mut u32_table_running_sum_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(u32_table_running_sum_log_derivative); - for (previous_row, _) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; - if Instruction::try_from(previous_ci) - .unwrap() - .is_u32_instruction() - { - u32_table_running_sum_log_derivative += inverses.next().unwrap(); - } - - // instruction `div_mod` requires a second inverse - if previous_ci == Instruction::DivMod.opcode_b() { - u32_table_running_sum_log_derivative += inverses.next().unwrap(); - } - - extension_column.push(u32_table_running_sum_log_derivative); - } - - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_for_clock_jump_difference_lookup_argument( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - // collect inverses to batch invert - let mut to_invert = vec![]; - for row in base_table.rows() { - let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; - if !lookup_multiplicity.is_zero() { - let clk = row[CLK.base_table_index()]; - to_invert.push(challenges[ClockJumpDifferenceLookupIndeterminate] - clk); - } - } - let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); - - // populate extension column with inverses - let mut cjd_lookup_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; - if !lookup_multiplicity.is_zero() { - cjd_lookup_log_derivative += inverses.next().unwrap() * lookup_multiplicity; - } - extension_column.push(cjd_lookup_log_derivative); - } - - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn factor_for_op_stack_table_running_product( - previous_row: ArrayView1, - current_row: ArrayView1, - challenges: &Challenges, - ) -> XFieldElement { - let default_factor = xfe!(1); - - let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); - if is_padding_row { - return default_factor; - } - - let Some(previous_instruction) = Self::instruction_from_row(previous_row) else { - return default_factor; - }; - - // shorter stack means relevant information is on top of stack, i.e., in stack registers - let row_with_shorter_stack = if previous_instruction.op_stack_size_influence() > 0 { - previous_row.view() - } else { - current_row.view() - }; - let op_stack_delta = previous_instruction - .op_stack_size_influence() - .unsigned_abs() as usize; - - let mut factor = default_factor; - for op_stack_pointer_offset in 0..op_stack_delta { - let max_stack_element_index = OpStackElement::COUNT - 1; - let stack_element_index = max_stack_element_index - op_stack_pointer_offset; - let stack_element_column = Self::op_stack_column_by_index(stack_element_index); - let underflow_element = row_with_shorter_stack[stack_element_column.base_table_index()]; - - let op_stack_pointer = row_with_shorter_stack[OpStackPointer.base_table_index()]; - let offset = bfe!(op_stack_pointer_offset as u64); - let offset_op_stack_pointer = op_stack_pointer + offset; - - let clk = previous_row[CLK.base_table_index()]; - let ib1_shrink_stack = previous_row[IB1.base_table_index()]; - let compressed_row = clk * challenges[OpStackClkWeight] - + ib1_shrink_stack * challenges[OpStackIb1Weight] - + offset_op_stack_pointer * challenges[OpStackPointerWeight] - + underflow_element * challenges[OpStackFirstUnderflowElementWeight]; - factor *= challenges[OpStackIndeterminate] - compressed_row; - } - factor - } - - fn factor_for_ram_table_running_product( - previous_row: ArrayView1, - current_row: ArrayView1, - challenges: &Challenges, - ) -> Option { - let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); - if is_padding_row { - return None; - } - - let instruction = Self::instruction_from_row(previous_row)?; - - let clk = previous_row[CLK.base_table_index()]; - let instruction_type = match instruction { - ReadMem(_) => ram_table::INSTRUCTION_TYPE_READ, - WriteMem(_) => ram_table::INSTRUCTION_TYPE_WRITE, - SpongeAbsorbMem => ram_table::INSTRUCTION_TYPE_READ, - MerkleStepMem => ram_table::INSTRUCTION_TYPE_READ, - XxDotStep => ram_table::INSTRUCTION_TYPE_READ, - XbDotStep => ram_table::INSTRUCTION_TYPE_READ, - _ => return None, - }; - let mut accesses = vec![]; - - match instruction { - ReadMem(_) | WriteMem(_) => { - // longer stack means relevant information is on top of stack, i.e., - // available in stack registers - let row_with_longer_stack = if let ReadMem(_) = instruction { - current_row.view() - } else { - previous_row.view() - }; - let op_stack_delta = instruction.op_stack_size_influence().unsigned_abs() as usize; - - let num_ram_pointers = 1; - for ram_pointer_offset in 0..op_stack_delta { - let ram_value_index = ram_pointer_offset + num_ram_pointers; - let ram_value_column = Self::op_stack_column_by_index(ram_value_index); - let ram_value = row_with_longer_stack[ram_value_column.base_table_index()]; - let offset_ram_pointer = Self::offset_ram_pointer( - instruction, - row_with_longer_stack, - ram_pointer_offset, - ); - accesses.push((offset_ram_pointer, ram_value)); - } - } - SpongeAbsorbMem => { - let mem_pointer = previous_row[ST0.base_table_index()]; - accesses.push((mem_pointer + bfe!(0), current_row[ST1.base_table_index()])); - accesses.push((mem_pointer + bfe!(1), current_row[ST2.base_table_index()])); - accesses.push((mem_pointer + bfe!(2), current_row[ST3.base_table_index()])); - accesses.push((mem_pointer + bfe!(3), current_row[ST4.base_table_index()])); - accesses.push((mem_pointer + bfe!(4), previous_row[HV0.base_table_index()])); - accesses.push((mem_pointer + bfe!(5), previous_row[HV1.base_table_index()])); - accesses.push((mem_pointer + bfe!(6), previous_row[HV2.base_table_index()])); - accesses.push((mem_pointer + bfe!(7), previous_row[HV3.base_table_index()])); - accesses.push((mem_pointer + bfe!(8), previous_row[HV4.base_table_index()])); - accesses.push((mem_pointer + bfe!(9), previous_row[HV5.base_table_index()])); - } - MerkleStepMem => { - let mem_pointer = previous_row[ST7.base_table_index()]; - accesses.push((mem_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((mem_pointer + bfe!(1), previous_row[HV1.base_table_index()])); - accesses.push((mem_pointer + bfe!(2), previous_row[HV2.base_table_index()])); - accesses.push((mem_pointer + bfe!(3), previous_row[HV3.base_table_index()])); - accesses.push((mem_pointer + bfe!(4), previous_row[HV4.base_table_index()])); - } - XxDotStep => { - let rhs_pointer = previous_row[ST0.base_table_index()]; - let lhs_pointer = previous_row[ST1.base_table_index()]; - accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((rhs_pointer + bfe!(1), previous_row[HV1.base_table_index()])); - accesses.push((rhs_pointer + bfe!(2), previous_row[HV2.base_table_index()])); - accesses.push((lhs_pointer + bfe!(0), previous_row[HV3.base_table_index()])); - accesses.push((lhs_pointer + bfe!(1), previous_row[HV4.base_table_index()])); - accesses.push((lhs_pointer + bfe!(2), previous_row[HV5.base_table_index()])); - } - XbDotStep => { - let rhs_pointer = previous_row[ST0.base_table_index()]; - let lhs_pointer = previous_row[ST1.base_table_index()]; - accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((lhs_pointer + bfe!(0), previous_row[HV1.base_table_index()])); - accesses.push((lhs_pointer + bfe!(1), previous_row[HV2.base_table_index()])); - accesses.push((lhs_pointer + bfe!(2), previous_row[HV3.base_table_index()])); - } - _ => unreachable!(), - }; - - accesses - .into_iter() - .map(|(ramp, ramv)| { - clk * challenges[RamClkWeight] - + instruction_type * challenges[RamInstructionTypeWeight] - + ramp * challenges[RamPointerWeight] - + ramv * challenges[RamValueWeight] - }) - .map(|compressed_row| challenges[RamIndeterminate] - compressed_row) - .reduce(|l, r| l * r) - } - - fn offset_ram_pointer( - instruction: Instruction, - row_with_longer_stack: ArrayView1, - ram_pointer_offset: usize, - ) -> BFieldElement { - let ram_pointer = row_with_longer_stack[ST0.base_table_index()]; - let offset = bfe!(ram_pointer_offset as u64); - - match instruction { - // adjust for ram_pointer pointing in front of last-read address: - // `push 0 read_mem 1` leaves stack as `_ a -1` where `a` was read from address 0. - ReadMem(_) => ram_pointer + offset + bfe!(1), - WriteMem(_) => ram_pointer + offset, - _ => unreachable!(), - } - } - - fn instruction_from_row(row: ArrayView1) -> Option { - let opcode = row[CI.base_table_index()]; - let instruction = Instruction::try_from(opcode).ok()?; - - if instruction.arg().is_some() { - let arg = row[NIA.base_table_index()]; - return instruction.change_arg(arg).ok(); - } - - Some(instruction) - } - - fn op_stack_column_by_index(index: usize) -> ProcessorBaseTableColumn { - match index { - 0 => ST0, - 1 => ST1, - 2 => ST2, - 3 => ST3, - 4 => ST4, - 5 => ST5, - 6 => ST6, - 7 => ST7, - 8 => ST8, - 9 => ST9, - 10 => ST10, - 11 => ST11, - 12 => ST12, - 13 => ST13, - 14 => ST14, - 15 => ST15, - i => panic!("Op Stack column index must be in [0, 15], not {i}."), - } - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtProcessorTable; - -impl ExtProcessorTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let x_constant = |x| circuit_builder.x_constant(x); - let challenge = |c| circuit_builder.challenge(c); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(ExtRow(col.master_ext_table_index())) - }; - - let clk_is_0 = base_row(CLK); - let ip_is_0 = base_row(IP); - let jsp_is_0 = base_row(JSP); - let jso_is_0 = base_row(JSO); - let jsd_is_0 = base_row(JSD); - let st0_is_0 = base_row(ST0); - let st1_is_0 = base_row(ST1); - let st2_is_0 = base_row(ST2); - let st3_is_0 = base_row(ST3); - let st4_is_0 = base_row(ST4); - let st5_is_0 = base_row(ST5); - let st6_is_0 = base_row(ST6); - let st7_is_0 = base_row(ST7); - let st8_is_0 = base_row(ST8); - let st9_is_0 = base_row(ST9); - let st10_is_0 = base_row(ST10); - let op_stack_pointer_is_16 = base_row(OpStackPointer) - constant(16); - - // Compress the program digest using an Evaluation Argument. - // Lowest index in the digest corresponds to lowest index on the stack. - let program_digest: [_; Digest::LEN] = [ - base_row(ST11), - base_row(ST12), - base_row(ST13), - base_row(ST14), - base_row(ST15), - ]; - let compressed_program_digest = program_digest.into_iter().fold( - circuit_builder.x_constant(EvalArg::default_initial()), - |acc, digest_element| { - acc * challenge(CompressProgramDigestIndeterminate) + digest_element - }, - ); - let compressed_program_digest_is_expected_program_digest = - compressed_program_digest - challenge(CompressedProgramDigest); - - // Permutation and Evaluation Arguments with all tables the Processor Table relates to - - // standard input - let running_evaluation_for_standard_input_is_initialized_correctly = - ext_row(InputTableEvalArg) - x_constant(EvalArg::default_initial()); - - // program table - let instruction_lookup_indeterminate = challenge(InstructionLookupIndeterminate); - let instruction_ci_weight = challenge(ProgramInstructionWeight); - let instruction_nia_weight = challenge(ProgramNextInstructionWeight); - let compressed_row_for_instruction_lookup = - instruction_ci_weight * base_row(CI) + instruction_nia_weight * base_row(NIA); - let instruction_lookup_log_derivative_is_initialized_correctly = - (ext_row(InstructionLookupClientLogDerivative) - - x_constant(LookupArg::default_initial())) - * (instruction_lookup_indeterminate - compressed_row_for_instruction_lookup) - - constant(1); - - // standard output - let running_evaluation_for_standard_output_is_initialized_correctly = - ext_row(OutputTableEvalArg) - x_constant(EvalArg::default_initial()); - - let running_product_for_op_stack_table_is_initialized_correctly = - ext_row(OpStackTablePermArg) - x_constant(PermArg::default_initial()); - - // ram table - let running_product_for_ram_table_is_initialized_correctly = - ext_row(RamTablePermArg) - x_constant(PermArg::default_initial()); - - // jump-stack table - let jump_stack_indeterminate = challenge(JumpStackIndeterminate); - let jump_stack_ci_weight = challenge(JumpStackCiWeight); - // note: `clk`, `jsp`, `jso`, and `jsd` are already constrained to be 0. - let compressed_row_for_jump_stack_table = jump_stack_ci_weight * base_row(CI); - let running_product_for_jump_stack_table_is_initialized_correctly = - ext_row(JumpStackTablePermArg) - - x_constant(PermArg::default_initial()) - * (jump_stack_indeterminate - compressed_row_for_jump_stack_table); - - // clock jump difference lookup argument - // The clock jump difference logarithmic derivative accumulator starts - // off having accumulated the contribution from the first row. - // Note that (challenge(ClockJumpDifferenceLookupIndeterminate) - base_row(CLK)) - // collapses to challenge(ClockJumpDifferenceLookupIndeterminate) - // because base_row(CLK) = 0 is already a constraint. - let clock_jump_diff_lookup_log_derivative_is_initialized_correctly = - ext_row(ClockJumpDifferenceLookupServerLogDerivative) - * challenge(ClockJumpDifferenceLookupIndeterminate) - - base_row(ClockJumpDifferenceLookupMultiplicity); - - // from processor to hash table - let hash_selector = base_row(CI) - constant(Instruction::Hash.opcode()); - let hash_deselector = - Self::instruction_deselector_single_row(circuit_builder, Instruction::Hash); - let hash_input_indeterminate = challenge(HashInputIndeterminate); - // the opStack is guaranteed to be initialized to 0 by virtue of other initial constraints - let compressed_row = constant(0); - let running_evaluation_hash_input_has_absorbed_first_row = ext_row(HashInputEvalArg) - - hash_input_indeterminate * x_constant(EvalArg::default_initial()) - - compressed_row; - let running_evaluation_hash_input_is_default_initial = - ext_row(HashInputEvalArg) - x_constant(EvalArg::default_initial()); - let running_evaluation_hash_input_is_initialized_correctly = hash_selector - * running_evaluation_hash_input_is_default_initial - + hash_deselector * running_evaluation_hash_input_has_absorbed_first_row; - - // from hash table to processor - let running_evaluation_hash_digest_is_initialized_correctly = - ext_row(HashDigestEvalArg) - x_constant(EvalArg::default_initial()); - - // Hash Table – Sponge - let running_evaluation_sponge_absorb_is_initialized_correctly = - ext_row(SpongeEvalArg) - x_constant(EvalArg::default_initial()); - - // u32 table - let running_sum_log_derivative_for_u32_table_is_initialized_correctly = - ext_row(U32LookupClientLogDerivative) - x_constant(LookupArg::default_initial()); - - vec![ - clk_is_0, - ip_is_0, - jsp_is_0, - jso_is_0, - jsd_is_0, - st0_is_0, - st1_is_0, - st2_is_0, - st3_is_0, - st4_is_0, - st5_is_0, - st6_is_0, - st7_is_0, - st8_is_0, - st9_is_0, - st10_is_0, - compressed_program_digest_is_expected_program_digest, - op_stack_pointer_is_16, - running_evaluation_for_standard_input_is_initialized_correctly, - instruction_lookup_log_derivative_is_initialized_correctly, - running_evaluation_for_standard_output_is_initialized_correctly, - running_product_for_op_stack_table_is_initialized_correctly, - running_product_for_ram_table_is_initialized_correctly, - running_product_for_jump_stack_table_is_initialized_correctly, - clock_jump_diff_lookup_log_derivative_is_initialized_correctly, - running_evaluation_hash_input_is_initialized_correctly, - running_evaluation_hash_digest_is_initialized_correctly, - running_evaluation_sponge_absorb_is_initialized_correctly, - running_sum_log_derivative_for_u32_table_is_initialized_correctly, - ] - } - - pub fn consistency_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - - // The composition of instruction bits ib0-ib7 corresponds the current instruction ci. - let ib_composition = base_row(IB0) - + constant(1 << 1) * base_row(IB1) - + constant(1 << 2) * base_row(IB2) - + constant(1 << 3) * base_row(IB3) - + constant(1 << 4) * base_row(IB4) - + constant(1 << 5) * base_row(IB5) - + constant(1 << 6) * base_row(IB6); - let ci_corresponds_to_ib0_thru_ib7 = base_row(CI) - ib_composition; - - let ib0_is_bit = base_row(IB0) * (base_row(IB0) - constant(1)); - let ib1_is_bit = base_row(IB1) * (base_row(IB1) - constant(1)); - let ib2_is_bit = base_row(IB2) * (base_row(IB2) - constant(1)); - let ib3_is_bit = base_row(IB3) * (base_row(IB3) - constant(1)); - let ib4_is_bit = base_row(IB4) * (base_row(IB4) - constant(1)); - let ib5_is_bit = base_row(IB5) * (base_row(IB5) - constant(1)); - let ib6_is_bit = base_row(IB6) * (base_row(IB6) - constant(1)); - let is_padding_is_bit = base_row(IsPadding) * (base_row(IsPadding) - constant(1)); - - // In padding rows, the clock jump difference lookup multiplicity is 0. The one row - // exempt from this rule is the row wth CLK == 1: since the memory-like tables don't have - // an “awareness” of padding rows, they keep looking up clock jump differences of - // magnitude 1. - let clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows = base_row(IsPadding) - * (base_row(CLK) - constant(1)) - * base_row(ClockJumpDifferenceLookupMultiplicity); - - vec![ - ib0_is_bit, - ib1_is_bit, - ib2_is_bit, - ib3_is_bit, - ib4_is_bit, - ib5_is_bit, - ib6_is_bit, - is_padding_is_bit, - ci_corresponds_to_ib0_thru_ib7, - clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows, - ] - } - - /// A polynomial that is 1 when evaluated on the given index, and 0 otherwise. - fn indicator_polynomial( - circuit_builder: &ConstraintCircuitBuilder, - index: usize, - ) -> ConstraintCircuitMonad { - let one = || circuit_builder.b_constant(1); - let hv = |idx| Self::helper_variable(circuit_builder, idx); - - match index { - 0 => (one() - hv(3)) * (one() - hv(2)) * (one() - hv(1)) * (one() - hv(0)), - 1 => (one() - hv(3)) * (one() - hv(2)) * (one() - hv(1)) * hv(0), - 2 => (one() - hv(3)) * (one() - hv(2)) * hv(1) * (one() - hv(0)), - 3 => (one() - hv(3)) * (one() - hv(2)) * hv(1) * hv(0), - 4 => (one() - hv(3)) * hv(2) * (one() - hv(1)) * (one() - hv(0)), - 5 => (one() - hv(3)) * hv(2) * (one() - hv(1)) * hv(0), - 6 => (one() - hv(3)) * hv(2) * hv(1) * (one() - hv(0)), - 7 => (one() - hv(3)) * hv(2) * hv(1) * hv(0), - 8 => hv(3) * (one() - hv(2)) * (one() - hv(1)) * (one() - hv(0)), - 9 => hv(3) * (one() - hv(2)) * (one() - hv(1)) * hv(0), - 10 => hv(3) * (one() - hv(2)) * hv(1) * (one() - hv(0)), - 11 => hv(3) * (one() - hv(2)) * hv(1) * hv(0), - 12 => hv(3) * hv(2) * (one() - hv(1)) * (one() - hv(0)), - 13 => hv(3) * hv(2) * (one() - hv(1)) * hv(0), - 14 => hv(3) * hv(2) * hv(1) * (one() - hv(0)), - 15 => hv(3) * hv(2) * hv(1) * hv(0), - i => panic!("indicator polynomial index {i} out of bounds"), - } - } - - fn helper_variable( - circuit_builder: &ConstraintCircuitBuilder, - index: usize, - ) -> ConstraintCircuitMonad { - match index { - 0 => circuit_builder.input(CurrentBaseRow(HV0.master_base_table_index())), - 1 => circuit_builder.input(CurrentBaseRow(HV1.master_base_table_index())), - 2 => circuit_builder.input(CurrentBaseRow(HV2.master_base_table_index())), - 3 => circuit_builder.input(CurrentBaseRow(HV3.master_base_table_index())), - 4 => circuit_builder.input(CurrentBaseRow(HV4.master_base_table_index())), - 5 => circuit_builder.input(CurrentBaseRow(HV5.master_base_table_index())), - i => unimplemented!("Helper variable index {i} out of bounds."), - } - } - - /// Instruction-specific transition constraints are combined with deselectors in such a way - /// that arbitrary sets of mutually exclusive combinations are summed, i.e., - /// - /// ```py - /// [ deselector_pop * tc_pop_0 + deselector_push * tc_push_0 + ..., - /// deselector_pop * tc_pop_1 + deselector_push * tc_push_1 + ..., - /// ..., - /// deselector_pop * tc_pop_i + deselector_push * tc_push_i + ..., - /// deselector_pop * 0 + deselector_push * tc_push_{i+1} + ..., - /// ..., - /// ] - /// ``` - /// For instructions that have fewer transition constraints than the maximal number of - /// transition constraints among all instructions, the deselector is multiplied with a zero, - /// causing no additional terms in the final sets of combined transition constraint polynomials. - fn combine_instruction_constraints_with_deselectors( - circuit_builder: &ConstraintCircuitBuilder, - instr_tc_polys_tuples: [(Instruction, Vec>); - Instruction::COUNT], - ) -> Vec> { - let (all_instructions, all_tc_polys_for_all_instructions): (Vec<_>, Vec<_>) = - instr_tc_polys_tuples.into_iter().unzip(); - - let all_instruction_deselectors = all_instructions - .into_iter() - .map(|instr| Self::instruction_deselector_current_row(circuit_builder, instr)) - .collect_vec(); - - let max_number_of_constraints = all_tc_polys_for_all_instructions - .iter() - .map(|tc_polys_for_instr| tc_polys_for_instr.len()) - .max() - .unwrap(); - - let zero_poly = circuit_builder.b_constant(0); - let all_tc_polys_for_all_instructions_transposed = (0..max_number_of_constraints) - .map(|idx| { - all_tc_polys_for_all_instructions - .iter() - .map(|tc_polys_for_instr| tc_polys_for_instr.get(idx).unwrap_or(&zero_poly)) - .collect_vec() - }) - .collect_vec(); - - all_tc_polys_for_all_instructions_transposed - .into_iter() - .map(|row| { - all_instruction_deselectors - .clone() - .into_iter() - .zip(row) - .map(|(deselector, instruction_tc)| deselector * instruction_tc.to_owned()) - .sum() - }) - .collect_vec() - } - - fn combine_transition_constraints_with_padding_constraints( - circuit_builder: &ConstraintCircuitBuilder, - instruction_transition_constraints: Vec>, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let padding_row_transition_constraints = [ - vec![ - next_base_row(IP) - curr_base_row(IP), - next_base_row(CI) - curr_base_row(CI), - next_base_row(NIA) - curr_base_row(NIA), - ], - Self::instruction_group_keep_jump_stack(circuit_builder), - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat(); - - let padding_row_deselector = constant(1) - next_base_row(IsPadding); - let padding_row_selector = next_base_row(IsPadding); - - let max_number_of_constraints = max( - instruction_transition_constraints.len(), - padding_row_transition_constraints.len(), - ); - - (0..max_number_of_constraints) - .map(|idx| { - let instruction_constraint = instruction_transition_constraints - .get(idx) - .unwrap_or(&constant(0)) - .to_owned(); - let padding_constraint = padding_row_transition_constraints - .get(idx) - .unwrap_or(&constant(0)) - .to_owned(); - - instruction_constraint * padding_row_deselector.clone() - + padding_constraint * padding_row_selector.clone() - }) - .collect_vec() - } - - fn instruction_group_decompose_arg( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - - let hv0_is_a_bit = curr_base_row(HV0) * (curr_base_row(HV0) - constant(1)); - let hv1_is_a_bit = curr_base_row(HV1) * (curr_base_row(HV1) - constant(1)); - let hv2_is_a_bit = curr_base_row(HV2) * (curr_base_row(HV2) - constant(1)); - let hv3_is_a_bit = curr_base_row(HV3) * (curr_base_row(HV3) - constant(1)); - - let helper_variables_are_binary_decomposition_of_nia = curr_base_row(NIA) - - constant(8) * curr_base_row(HV3) - - constant(4) * curr_base_row(HV2) - - constant(2) * curr_base_row(HV1) - - curr_base_row(HV0); - - vec![ - hv0_is_a_bit, - hv1_is_a_bit, - hv2_is_a_bit, - hv3_is_a_bit, - helper_variables_are_binary_decomposition_of_nia, - ] - } - - /// The permutation argument accumulator with the RAM table does - /// not change, because there is no RAM access. - fn instruction_group_no_ram( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - vec![next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg)] - } - - fn instruction_group_no_io( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - vec![ - Self::running_evaluation_for_standard_input_remains_unchanged(circuit_builder), - Self::running_evaluation_for_standard_output_remains_unchanged(circuit_builder), - ] - } - - /// Op Stack height does not change and except for the top n elements, - /// the values remain also. - fn instruction_group_op_stack_remains_except_top_n( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - assert!(n <= NUM_OP_STACK_REGISTERS); - - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let stack = (0..OpStackElement::COUNT) - .map(ProcessorTable::op_stack_column_by_index) - .collect_vec(); - let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); - let curr_stack = stack.iter().map(|&st| curr_row(st)).collect_vec(); - - let compress_stack_except_top_n = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { - assert_eq!(NUM_OP_STACK_REGISTERS, stack.len()); - let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); - stack - .into_iter() - .enumerate() - .skip(n) - .map(|(i, st)| weight(i) * st) - .sum() - }; - - let all_but_n_top_elements_remain = - compress_stack_except_top_n(next_stack) - compress_stack_except_top_n(curr_stack); - - let mut constraints = Self::instruction_group_keep_op_stack_height(circuit_builder); - constraints.push(all_but_n_top_elements_remain); - constraints - } - - /// Op stack does not change, _i.e._, all stack elements persist - fn instruction_group_keep_op_stack( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 0) - } - - /// Op stack *height* does not change, _i.e._, the accumulator for the - /// permutation argument with the op stack table remains the same as does - /// the op stack pointer. - fn instruction_group_keep_op_stack_height( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let op_stack_pointer_curr = - circuit_builder.input(CurrentBaseRow(OpStackPointer.master_base_table_index())); - let op_stack_pointer_next = - circuit_builder.input(NextBaseRow(OpStackPointer.master_base_table_index())); - let osp_remains_unchanged = op_stack_pointer_next - op_stack_pointer_curr; - - let op_stack_table_perm_arg_curr = - circuit_builder.input(CurrentExtRow(OpStackTablePermArg.master_ext_table_index())); - let op_stack_table_perm_arg_next = - circuit_builder.input(NextExtRow(OpStackTablePermArg.master_ext_table_index())); - let perm_arg_remains_unchanged = - op_stack_table_perm_arg_next - op_stack_table_perm_arg_curr; - - vec![osp_remains_unchanged, perm_arg_remains_unchanged] - } - - fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - vec![ - next_base_row(ST2) - curr_base_row(ST1), - next_base_row(ST3) - curr_base_row(ST2), - next_base_row(ST4) - curr_base_row(ST3), - next_base_row(ST5) - curr_base_row(ST4), - next_base_row(ST6) - curr_base_row(ST5), - next_base_row(ST7) - curr_base_row(ST6), - next_base_row(ST8) - curr_base_row(ST7), - next_base_row(ST9) - curr_base_row(ST8), - next_base_row(ST10) - curr_base_row(ST9), - next_base_row(ST11) - curr_base_row(ST10), - next_base_row(ST12) - curr_base_row(ST11), - next_base_row(ST13) - curr_base_row(ST12), - next_base_row(ST14) - curr_base_row(ST13), - next_base_row(ST15) - curr_base_row(ST14), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(1), - Self::running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, 1), - ] - } - - fn instruction_group_grow_op_stack( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![next_base_row(ST1) - curr_base_row(ST0)]; - let inherited_constraints = - Self::instruction_group_grow_op_stack_and_top_two_elements_unconstrained( - circuit_builder, - ); - - [specific_constraints, inherited_constraints].concat() - } - - fn instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - vec![ - next_base_row(ST3) - curr_base_row(ST4), - next_base_row(ST4) - curr_base_row(ST5), - next_base_row(ST5) - curr_base_row(ST6), - next_base_row(ST6) - curr_base_row(ST7), - next_base_row(ST7) - curr_base_row(ST8), - next_base_row(ST8) - curr_base_row(ST9), - next_base_row(ST9) - curr_base_row(ST10), - next_base_row(ST10) - curr_base_row(ST11), - next_base_row(ST11) - curr_base_row(ST12), - next_base_row(ST12) - curr_base_row(ST13), - next_base_row(ST13) - curr_base_row(ST14), - next_base_row(ST14) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(1), - Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 1), - ] - } - - fn instruction_group_binop( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![ - next_base_row(ST1) - curr_base_row(ST2), - next_base_row(ST2) - curr_base_row(ST3), - ]; - let inherited_constraints = - Self::instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( - circuit_builder, - ); - - [specific_constraints, inherited_constraints].concat() - } - - fn instruction_group_shrink_op_stack( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST1)]; - let inherited_constraints = Self::instruction_group_binop(circuit_builder); - - [specific_constraints, inherited_constraints].concat() - } - - fn instruction_group_keep_jump_stack( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let jsp_does_not_change = next_base_row(JSP) - curr_base_row(JSP); - let jso_does_not_change = next_base_row(JSO) - curr_base_row(JSO); - let jsd_does_not_change = next_base_row(JSD) - curr_base_row(JSD); - - vec![ - jsp_does_not_change, - jso_does_not_change, - jsd_does_not_change, - ] - } - - /// Increase the instruction pointer by 1. - fn instruction_group_step_1( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let instruction_pointer_increases_by_one = - next_base_row(IP) - curr_base_row(IP) - constant(1); - [ - Self::instruction_group_keep_jump_stack(circuit_builder), - vec![instruction_pointer_increases_by_one], - ] - .concat() - } - - /// Increase the instruction pointer by 2. - fn instruction_group_step_2( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let instruction_pointer_increases_by_two = - next_base_row(IP) - curr_base_row(IP) - constant(2); - [ - Self::instruction_group_keep_jump_stack(circuit_builder), - vec![instruction_pointer_increases_by_two], - ] - .concat() - } - - /// Internal helper function to de-duplicate functionality common between the similar (but - /// different on a type level) functions for construction deselectors. - fn instruction_deselector_common_functionality( - circuit_builder: &ConstraintCircuitBuilder, - instruction: Instruction, - instruction_bit_polynomials: [ConstraintCircuitMonad; InstructionBit::COUNT], - ) -> ConstraintCircuitMonad { - let one = || circuit_builder.b_constant(1); - - let selector_bits: [_; InstructionBit::COUNT] = [ - instruction.ib(InstructionBit::IB0), - instruction.ib(InstructionBit::IB1), - instruction.ib(InstructionBit::IB2), - instruction.ib(InstructionBit::IB3), - instruction.ib(InstructionBit::IB4), - instruction.ib(InstructionBit::IB5), - instruction.ib(InstructionBit::IB6), - ]; - let deselector_polynomials = selector_bits.map(|b| one() - circuit_builder.b_constant(b)); - - instruction_bit_polynomials - .into_iter() - .zip_eq(deselector_polynomials) - .map(|(instruction_bit_poly, deselector_poly)| instruction_bit_poly - deselector_poly) - .fold(one(), ConstraintCircuitMonad::mul) - } - - /// A polynomial that has no solutions when `ci` is `instruction`. - /// The number of variables in the polynomial corresponds to two rows. - fn instruction_deselector_current_row( - circuit_builder: &ConstraintCircuitBuilder, - instruction: Instruction, - ) -> ConstraintCircuitMonad { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - - let instruction_bit_polynomials = [ - curr_base_row(IB0), - curr_base_row(IB1), - curr_base_row(IB2), - curr_base_row(IB3), - curr_base_row(IB4), - curr_base_row(IB5), - curr_base_row(IB6), - ]; - - Self::instruction_deselector_common_functionality( - circuit_builder, - instruction, - instruction_bit_polynomials, - ) - } - - /// A polynomial that has no solutions when `ci_next` is `instruction`. - /// The number of variables in the polynomial corresponds to two rows. - fn instruction_deselector_next_row( - circuit_builder: &ConstraintCircuitBuilder, - instruction: Instruction, - ) -> ConstraintCircuitMonad { - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let instruction_bit_polynomials = [ - next_base_row(IB0), - next_base_row(IB1), - next_base_row(IB2), - next_base_row(IB3), - next_base_row(IB4), - next_base_row(IB5), - next_base_row(IB6), - ]; - - Self::instruction_deselector_common_functionality( - circuit_builder, - instruction, - instruction_bit_polynomials, - ) - } - - /// A polynomial that has no solutions when `ci` is `instruction`. - /// The number of variables in the polynomial corresponds to a single row. - fn instruction_deselector_single_row( - circuit_builder: &ConstraintCircuitBuilder, - instruction: Instruction, - ) -> ConstraintCircuitMonad { - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - - let instruction_bit_polynomials = [ - base_row(IB0), - base_row(IB1), - base_row(IB2), - base_row(IB3), - base_row(IB4), - base_row(IB5), - base_row(IB6), - ]; - - Self::instruction_deselector_common_functionality( - circuit_builder, - instruction, - instruction_bit_polynomials, - ) - } - - fn instruction_pop( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - Self::stack_shrinks_by_any_of(circuit_builder, &NumberOfWords::legal_values()), - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_push( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(NIA)]; - [ - specific_constraints, - Self::instruction_group_grow_op_stack(circuit_builder), - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_divine( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - Self::stack_grows_by_any_of(circuit_builder, &NumberOfWords::legal_values()), - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_dup( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let indicator_poly = |idx| Self::indicator_polynomial(circuit_builder, idx); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let st_column = ProcessorTable::op_stack_column_by_index; - let duplicate_element = |i| indicator_poly(i) * (next_row(ST0) - curr_row(st_column(i))); - let duplicate_indicated_element = (0..OpStackElement::COUNT).map(duplicate_element).sum(); - - [ - vec![duplicate_indicated_element], - Self::instruction_group_decompose_arg(circuit_builder), - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_grow_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_swap( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let stack = (0..OpStackElement::COUNT) - .map(ProcessorTable::op_stack_column_by_index) - .collect_vec(); - let stack_with_swapped_i = |i| { - let mut stack = stack.clone(); - stack.swap(0, i); - stack.into_iter() - }; - - let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); - let curr_stack_with_swapped_i = |i| stack_with_swapped_i(i).map(curr_row).collect_vec(); - let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { - assert_eq!(OpStackElement::COUNT, stack.len()); - let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); - let enumerated_stack = stack.into_iter().enumerate(); - enumerated_stack.map(|(i, st)| weight(i) * st).sum() - }; - - let next_stack_is_current_stack_with_swapped_i = |i| { - Self::indicator_polynomial(circuit_builder, i) - * (compress(next_stack.clone()) - compress(curr_stack_with_swapped_i(i))) - }; - let next_stack_is_current_stack_with_correct_element_swapped = (0..OpStackElement::COUNT) - .map(next_stack_is_current_stack_with_swapped_i) - .sum(); - - [ - vec![next_stack_is_current_stack_with_correct_element_swapped], - Self::instruction_group_decompose_arg(circuit_builder), - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - Self::instruction_group_keep_op_stack_height(circuit_builder), - ] - .concat() - } - - fn instruction_nop( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_skiz( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let hv0_is_inverse_of_st0 = curr_base_row(HV0) * curr_base_row(ST0) - one(); - let hv0_is_inverse_of_st0_or_hv0_is_0 = hv0_is_inverse_of_st0.clone() * curr_base_row(HV0); - let hv0_is_inverse_of_st0_or_st0_is_0 = hv0_is_inverse_of_st0 * curr_base_row(ST0); - - // The next instruction nia is decomposed into helper variables hv. - let nia_decomposes_to_hvs = curr_base_row(NIA) - - curr_base_row(HV1) - - constant(1 << 1) * curr_base_row(HV2) - - constant(1 << 3) * curr_base_row(HV3) - - constant(1 << 5) * curr_base_row(HV4) - - constant(1 << 7) * curr_base_row(HV5); - - // If `st0` is non-zero, register `ip` is incremented by 1. - // If `st0` is 0 and `nia` takes no argument, register `ip` is incremented by 2. - // If `st0` is 0 and `nia` takes an argument, register `ip` is incremented by 3. - // - // The opcodes are constructed such that hv1 == 1 means that nia takes an argument. - // - // Written as Disjunctive Normal Form, the constraint can be expressed as: - // (Register `st0` is 0 or `ip` is incremented by 1), and - // (`st0` has a multiplicative inverse or `hv1` is 1 or `ip` is incremented by 2), and - // (`st0` has a multiplicative inverse or `hv1` is 0 or `ip` is incremented by 3). - let ip_case_1 = (next_base_row(IP) - curr_base_row(IP) - constant(1)) * curr_base_row(ST0); - let ip_case_2 = (next_base_row(IP) - curr_base_row(IP) - constant(2)) - * (curr_base_row(ST0) * curr_base_row(HV0) - one()) - * (curr_base_row(HV1) - one()); - let ip_case_3 = (next_base_row(IP) - curr_base_row(IP) - constant(3)) - * (curr_base_row(ST0) * curr_base_row(HV0) - one()) - * curr_base_row(HV1); - let ip_incr_by_1_or_2_or_3 = ip_case_1 + ip_case_2 + ip_case_3; - - let specific_constraints = vec![ - hv0_is_inverse_of_st0_or_hv0_is_0, - hv0_is_inverse_of_st0_or_st0_is_0, - nia_decomposes_to_hvs, - ip_incr_by_1_or_2_or_3, - ]; - [ - specific_constraints, - Self::next_instruction_range_check_constraints_for_instruction_skiz(circuit_builder), - Self::instruction_group_keep_jump_stack(circuit_builder), - Self::instruction_group_shrink_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn next_instruction_range_check_constraints_for_instruction_skiz( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - - let is_0_or_1 = - |var: ProcessorBaseTableColumn| curr_base_row(var) * (curr_base_row(var) - constant(1)); - let is_0_or_1_or_2_or_3 = |var: ProcessorBaseTableColumn| { - curr_base_row(var) - * (curr_base_row(var) - constant(1)) - * (curr_base_row(var) - constant(2)) - * (curr_base_row(var) - constant(3)) - }; - - vec![ - is_0_or_1(HV1), - is_0_or_1_or_2_or_3(HV2), - is_0_or_1_or_2_or_3(HV3), - is_0_or_1_or_2_or_3(HV4), - is_0_or_1_or_2_or_3(HV5), - ] - } - - fn instruction_call( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // The jump stack pointer jsp is incremented by 1. - let jsp_incr_1 = next_base_row(JSP) - curr_base_row(JSP) - constant(1); - - // The jump's origin jso is set to the current instruction pointer ip plus 2. - let jso_becomes_ip_plus_2 = next_base_row(JSO) - curr_base_row(IP) - constant(2); - - // The jump's destination jsd is set to the instruction's argument. - let jsd_becomes_nia = next_base_row(JSD) - curr_base_row(NIA); - - // The instruction pointer ip is set to the instruction's argument. - let ip_becomes_nia = next_base_row(IP) - curr_base_row(NIA); - - let specific_constraints = vec![ - jsp_incr_1, - jso_becomes_ip_plus_2, - jsd_becomes_nia, - ip_becomes_nia, - ]; - [ - specific_constraints, - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_return( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let jsp_decrements_by_1 = next_base_row(JSP) - curr_base_row(JSP) + constant(1); - let ip_is_set_to_jso = next_base_row(IP) - curr_base_row(JSO); - let specific_constraints = vec![jsp_decrements_by_1, ip_is_set_to_jso]; - - [ - specific_constraints, - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_recurse( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // The instruction pointer ip is set to the last jump's destination jsd. - let ip_becomes_jsd = next_base_row(IP) - curr_base_row(JSD); - let specific_constraints = vec![ip_becomes_jsd]; - [ - specific_constraints, - Self::instruction_group_keep_jump_stack(circuit_builder), - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_recurse_or_return( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let one = || circuit_builder.b_constant(1); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // Zero if the ST5 equals ST6. One if they are not equal. - let st5_eq_st6 = || curr_row(HV0) * (curr_row(ST6) - curr_row(ST5)); - let st5_neq_st6 = || one() - st5_eq_st6(); - - let maybe_return = vec![ - // hv0 is inverse-or-zero of the difference of ST6 and ST5. - st5_neq_st6() * curr_row(HV0), - st5_neq_st6() * (curr_row(ST6) - curr_row(ST5)), - st5_neq_st6() * (next_row(IP) - curr_row(JSO)), - st5_neq_st6() * (next_row(JSP) - curr_row(JSP) + one()), - ]; - let maybe_recurse = vec![ - // constraints are ordered to line up nicely with group “maybe_return” - st5_eq_st6() * (next_row(JSO) - curr_row(JSO)), - st5_eq_st6() * (next_row(JSD) - curr_row(JSD)), - st5_eq_st6() * (next_row(IP) - curr_row(JSD)), - st5_eq_st6() * (next_row(JSP) - curr_row(JSP)), - ]; - - // The two constraint groups are mutually exclusive: the stack element is either - // equal to its successor or not, indicated by `st5_eq_st6` and `st5_neq_st6`. - // Therefore, it is safe (and sound) to combine the groups into a single set of - // constraints. - let constraint_groups = vec![maybe_return, maybe_recurse]; - let specific_constraints = - Self::combine_mutually_exclusive_constraint_groups(circuit_builder, constraint_groups); - - [ - specific_constraints, - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_assert( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - - // The current top of the stack st0 is 1. - let st_0_is_1 = curr_base_row(ST0) - constant(1); - - let specific_constraints = vec![st_0_is_1]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_shrink_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_halt( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // The instruction executed in the following step is instruction halt. - let halt_is_followed_by_halt = next_base_row(CI) - curr_base_row(CI); - - let specific_constraints = vec![halt_is_followed_by_halt]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_read_mem( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - Self::read_from_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_write_mem( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - Self::write_to_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - /// Two Evaluation Arguments with the Hash Table guarantee correct transition. - fn instruction_hash( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let op_stack_shrinks_by_5_and_top_5_unconstrained = vec![ - next_base_row(ST5) - curr_base_row(ST10), - next_base_row(ST6) - curr_base_row(ST11), - next_base_row(ST7) - curr_base_row(ST12), - next_base_row(ST8) - curr_base_row(ST13), - next_base_row(ST9) - curr_base_row(ST14), - next_base_row(ST10) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(5), - Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 5), - ]; - - [ - Self::instruction_group_step_1(circuit_builder), - op_stack_shrinks_by_5_and_top_5_unconstrained, - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_merkle_step( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_merkle_step_shared_constraints(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 6), - Self::instruction_group_no_ram(circuit_builder), - ] - .concat() - } - - fn instruction_merkle_step_mem( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let stack_weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); - let curr = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let ram_pointers = [0, 1, 2, 3, 4].map(|i| curr(ST7) + constant(i)); - let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4].map(curr); - let read_from_ram_to_hvs = - Self::read_from_ram_to(circuit_builder, ram_pointers, ram_read_destinations); - - let st6_does_not_change = next(ST6) - curr(ST6); - let st7_increments_by_5 = next(ST7) - curr(ST7) - constant(5); - let st6_and_st7_update_correctly = - stack_weight(6) * st6_does_not_change + stack_weight(7) * st7_increments_by_5; - - [ - vec![st6_and_st7_update_correctly, read_from_ram_to_hvs], - Self::instruction_merkle_step_shared_constraints(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 8), - ] - .concat() - } - - /// Recall that in a Merkle tree, the indices of left (respectively right) - /// leaves have least-significant bit 0 (respectively 1). - /// - /// Two Evaluation Arguments with the Hash Table guarantee correct transition of - /// stack elements ST0 through ST4. - fn instruction_merkle_step_shared_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let curr = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let hv5_is_0_or_1 = curr(HV5) * (curr(HV5) - one()); - let new_st5_is_previous_st5_div_2 = constant(2) * next(ST5) + curr(HV5) - curr(ST5); - let update_merkle_tree_node_index = vec![hv5_is_0_or_1, new_st5_is_previous_st5_div_2]; - - [ - update_merkle_tree_node_index, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_assert_vector( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![ - curr_base_row(ST5) - curr_base_row(ST0), - curr_base_row(ST6) - curr_base_row(ST1), - curr_base_row(ST7) - curr_base_row(ST2), - curr_base_row(ST8) - curr_base_row(ST3), - curr_base_row(ST9) - curr_base_row(ST4), - ]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::constraints_for_shrinking_stack_by(circuit_builder, 5), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_sponge_init( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_keep_op_stack(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_sponge_absorb( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::constraints_for_shrinking_stack_by(circuit_builder, 10), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_sponge_absorb_mem( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let constant = |c| circuit_builder.b_constant(c); - - let increment_ram_pointer = - next_base_row(ST0) - curr_base_row(ST0) - constant(tip5::RATE as u32); - - [ - vec![increment_ram_pointer], - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_sponge_squeeze( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::constraints_for_growing_stack_by(circuit_builder, 10), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_add( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = - vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST1)]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_addi( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = - vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(NIA)]; - [ - specific_constraints, - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_mul( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = - vec![next_base_row(ST0) - curr_base_row(ST0) * curr_base_row(ST1)]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_invert( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let specific_constraints = vec![next_base_row(ST0) * curr_base_row(ST0) - one()]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_eq( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let st0_eq_st1 = || one() - curr_base_row(HV0) * (curr_base_row(ST1) - curr_base_row(ST0)); - - // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two top-most - // elements: `hv0·(1 - hv0·(st1 - st0))` - let hv0_is_inverse_of_diff_or_hv0_is_0 = curr_base_row(HV0) * st0_eq_st1(); - - // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two - // top-most elements: `(st1 - st0)·(1 - hv0·(st1 - st0))` - let hv0_is_inverse_of_diff_or_diff_is_0 = - (curr_base_row(ST1) - curr_base_row(ST0)) * st0_eq_st1(); - - // The new top of the stack is 1 if the difference between the stack's two top-most - // elements is not invertible, 0 otherwise: `st0' - (1 - hv0·(st1 - st0))` - let st0_becomes_1_if_diff_is_not_invertible = next_base_row(ST0) - st0_eq_st1(); - - let specific_constraints = vec![ - hv0_is_inverse_of_diff_or_hv0_is_0, - hv0_is_inverse_of_diff_or_diff_is_0, - st0_becomes_1_if_diff_is_not_invertible, - ]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_split( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // The top of the stack is decomposed as 32-bit chunks into the stack's top-most elements: - // st0 - (2^32·st0' + st1') = 0$ - let st0_decomposes_to_two_32_bit_chunks = - curr_base_row(ST0) - (constant(1 << 32) * next_base_row(ST1) + next_base_row(ST0)); - - // Helper variable `hv0` = 0 if either - // 1. `hv0` is the difference between (2^32 - 1) and the high 32 bits (`st0'`), or - // 1. the low 32 bits (`st1'`) are 0. - // - // st1'·(hv0·(st0' - (2^32 - 1)) - 1) - // lo·(hv0·(hi - 0xffff_ffff)) - 1) - let hv0_holds_inverse_of_chunk_difference_or_low_bits_are_0 = { - let hv0 = curr_base_row(HV0); - let hi = next_base_row(ST1); - let lo = next_base_row(ST0); - let ffff_ffff = constant(0xffff_ffff); - - lo * (hv0 * (hi - ffff_ffff) - one()) - }; - - let specific_constraints = vec![ - st0_decomposes_to_two_32_bit_chunks, - hv0_holds_inverse_of_chunk_difference_or_low_bits_are_0, - ]; - [ - specific_constraints, - Self::instruction_group_grow_op_stack_and_top_two_elements_unconstrained( - circuit_builder, - ), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_lt( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_and( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_xor( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_log_2_floor( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_pow( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_binop(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_div_mod( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // `n == d·q + r` means `st0 - st1·st1' - st0'` - let numerator_is_quotient_times_denominator_plus_remainder = - curr_base_row(ST0) - curr_base_row(ST1) * next_base_row(ST1) - next_base_row(ST0); - - let specific_constraints = vec![numerator_is_quotient_times_denominator_plus_remainder]; - [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 2), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_pop_count( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - [ - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 1), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_xx_add( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let st0_becomes_st0_plus_st3 = next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST3); - let st1_becomes_st1_plus_st4 = next_base_row(ST1) - curr_base_row(ST1) - curr_base_row(ST4); - let st2_becomes_st2_plus_st5 = next_base_row(ST2) - curr_base_row(ST2) - curr_base_row(ST5); - let specific_constraints = vec![ - st0_becomes_st0_plus_st3, - st1_becomes_st1_plus_st4, - st2_becomes_st2_plus_st5, - ]; - - [ - specific_constraints, - Self::constraints_for_shrinking_stack_by_3_and_top_3_unconstrained(circuit_builder), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_xx_mul( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(curr_base_row); - let [c0, c1, c2] = Self::xx_product([x0, x1, x2], [y0, y1, y2]); - - let specific_constraints = vec![ - next_base_row(ST0) - c0, - next_base_row(ST1) - c1, - next_base_row(ST2) - c2, - ]; - [ - specific_constraints, - Self::constraints_for_shrinking_stack_by_3_and_top_3_unconstrained(circuit_builder), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_xinv( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let first_coefficient_of_product_of_element_and_inverse_is_1 = curr_base_row(ST0) - * next_base_row(ST0) - - curr_base_row(ST2) * next_base_row(ST1) - - curr_base_row(ST1) * next_base_row(ST2) - - constant(1); - - let second_coefficient_of_product_of_element_and_inverse_is_0 = - curr_base_row(ST1) * next_base_row(ST0) + curr_base_row(ST0) * next_base_row(ST1) - - curr_base_row(ST2) * next_base_row(ST2) - + curr_base_row(ST2) * next_base_row(ST1) - + curr_base_row(ST1) * next_base_row(ST2); - - let third_coefficient_of_product_of_element_and_inverse_is_0 = curr_base_row(ST2) - * next_base_row(ST0) - + curr_base_row(ST1) * next_base_row(ST1) - + curr_base_row(ST0) * next_base_row(ST2) - + curr_base_row(ST2) * next_base_row(ST2); - - let specific_constraints = vec![ - first_coefficient_of_product_of_element_and_inverse_is_1, - second_coefficient_of_product_of_element_and_inverse_is_0, - third_coefficient_of_product_of_element_and_inverse_is_0, - ]; - [ - specific_constraints, - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 3), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_xb_mul( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let [x, y0, y1, y2] = [ST0, ST1, ST2, ST3].map(curr_base_row); - let [c0, c1, c2] = Self::xb_product([y0, y1, y2], x); - - let specific_constraints = vec![ - next_base_row(ST0) - c0, - next_base_row(ST1) - c1, - next_base_row(ST2) - c2, - ]; - [ - specific_constraints, - Self::instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( - circuit_builder, - ), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - ] - .concat() - } - - fn instruction_read_io( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constraint_groups_for_legal_arguments = NumberOfWords::legal_values() - .map(|n| Self::grow_stack_by_n_and_read_n_symbols_from_input(circuit_builder, n)) - .to_vec(); - let read_any_legal_number_of_words = Self::combine_mutually_exclusive_constraint_groups( - circuit_builder, - constraint_groups_for_legal_arguments, - ); - - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - read_any_legal_number_of_words, - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - vec![Self::running_evaluation_for_standard_output_remains_unchanged(circuit_builder)], - ] - .concat() - } - - fn instruction_write_io( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constraint_groups_for_legal_arguments = NumberOfWords::legal_values() - .map(|n| Self::shrink_stack_by_n_and_write_n_symbols_to_output(circuit_builder, n)) - .to_vec(); - let write_any_of_1_through_5_elements = Self::combine_mutually_exclusive_constraint_groups( - circuit_builder, - constraint_groups_for_legal_arguments, - ); - - [ - Self::instruction_group_step_2(circuit_builder), - Self::instruction_group_decompose_arg(circuit_builder), - write_any_of_1_through_5_elements, - Self::prohibit_any_illegal_number_of_words(circuit_builder), - Self::instruction_group_no_ram(circuit_builder), - vec![Self::running_evaluation_for_standard_input_remains_unchanged(circuit_builder)], - ] - .concat() - } - - /// Update the accumulator for the Permutation Argument with the RAM table in - /// accordance with reading a bunch of words from the indicated ram pointers to - /// the indicated destination registers. - /// - /// Does not constrain the op stack by default.[^stack] For that, see: - /// [`Self::read_from_ram_any_of`]. - /// - /// [^stack]: Op stack registers used in arguments will be constrained. - fn read_from_ram_to( - circuit_builder: &ConstraintCircuitBuilder, - ram_pointers: [ConstraintCircuitMonad; N], - destinations: [ConstraintCircuitMonad; N], - ) -> ConstraintCircuitMonad { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let constant = |bfe| circuit_builder.b_constant(bfe); - - let compress_row = |(ram_pointer, destination)| { - curr_base_row(CLK) * challenge(RamClkWeight) - + constant(ram_table::INSTRUCTION_TYPE_READ) * challenge(RamInstructionTypeWeight) - + ram_pointer * challenge(RamPointerWeight) - + destination * challenge(RamValueWeight) - }; - - let factor = ram_pointers - .into_iter() - .zip(destinations) - .map(compress_row) - .map(|compressed_row| challenge(RamIndeterminate) - compressed_row) - .reduce(|l, r| l * r) - .unwrap_or_else(|| constant(bfe!(1))); - curr_ext_row(RamTablePermArg) * factor - next_ext_row(RamTablePermArg) - } - - fn xx_product( - [x_0, x_1, x_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], - [y_0, y_1, y_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], - ) -> [ConstraintCircuitMonad; EXTENSION_DEGREE] { - let z0 = x_0.clone() * y_0.clone(); - let z1 = x_1.clone() * y_0.clone() + x_0.clone() * y_1.clone(); - let z2 = x_2.clone() * y_0 + x_1.clone() * y_1.clone() + x_0 * y_2.clone(); - let z3 = x_2.clone() * y_1 + x_1 * y_2.clone(); - let z4 = x_2 * y_2; - - // reduce modulo x³ - x + 1 - [z0 - z3.clone(), z1 - z4.clone() + z3, z2 + z4] - } - - fn xb_product( - [x_0, x_1, x_2]: [ConstraintCircuitMonad; EXTENSION_DEGREE], - y: ConstraintCircuitMonad, - ) -> [ConstraintCircuitMonad; EXTENSION_DEGREE] { - let z0 = x_0 * y.clone(); - let z1 = x_1 * y.clone(); - let z2 = x_2 * y; - [z0, z1, z2] - } - - fn update_dotstep_accumulator( - circuit_builder: &ConstraintCircuitBuilder, - accumulator_indices: [ProcessorBaseTableColumn; EXTENSION_DEGREE], - difference: [ConstraintCircuitMonad; EXTENSION_DEGREE], - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr = accumulator_indices.map(curr_base_row); - let next = accumulator_indices.map(next_base_row); - izip!(curr, next, difference) - .map(|(c, n, d)| n - c - d) - .collect() - } - - fn instruction_xx_dot_step( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let constant = |c| circuit_builder.b_constant(c); - - let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(3); - let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); - - let rhs_ptr0 = curr_base_row(ST0); - let rhs_ptr1 = rhs_ptr0.clone() + constant(1); - let rhs_ptr2 = rhs_ptr0.clone() + constant(2); - let lhs_ptr0 = curr_base_row(ST1); - let lhs_ptr1 = lhs_ptr0.clone() + constant(1); - let lhs_ptr2 = lhs_ptr0.clone() + constant(2); - let ram_read_sources = [rhs_ptr0, rhs_ptr1, rhs_ptr2, lhs_ptr0, lhs_ptr1, lhs_ptr2]; - let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); - let read_two_xfes_from_ram = - Self::read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); - - let ram_pointer_constraints = vec![ - increment_ram_pointer_st0, - increment_ram_pointer_st1, - read_two_xfes_from_ram, - ]; - - let [hv0, hv1, hv2, hv3, hv4, hv5] = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); - let hv_product = Self::xx_product([hv0, hv1, hv2], [hv3, hv4, hv5]); - - [ - ram_pointer_constraints, - Self::update_dotstep_accumulator(circuit_builder, [ST2, ST3, ST4], hv_product), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), - ] - .concat() - } - - fn instruction_xb_dot_step( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let constant = |c| circuit_builder.b_constant(c); - - let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(1); - let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); - - let rhs_ptr0 = curr_base_row(ST0); - let lhs_ptr0 = curr_base_row(ST1); - let lhs_ptr1 = lhs_ptr0.clone() + constant(1); - let lhs_ptr2 = lhs_ptr0.clone() + constant(2); - let ram_read_sources = [rhs_ptr0, lhs_ptr0, lhs_ptr1, lhs_ptr2]; - let ram_read_destinations = [HV0, HV1, HV2, HV3].map(curr_base_row); - let read_bfe_and_xfe_from_ram = - Self::read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); - - let ram_pointer_constraints = vec![ - increment_ram_pointer_st0, - increment_ram_pointer_st1, - read_bfe_and_xfe_from_ram, - ]; - - let [hv0, hv1, hv2, hv3] = [HV0, HV1, HV2, HV3].map(curr_base_row); - let hv_product = Self::xb_product([hv1, hv2, hv3], hv0); - - [ - ram_pointer_constraints, - Self::update_dotstep_accumulator(circuit_builder, [ST2, ST3, ST4], hv_product), - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_no_io(circuit_builder), - Self::instruction_group_op_stack_remains_except_top_n(circuit_builder, 5), - ] - .concat() - } - - fn transition_constraints_for_instruction( - circuit_builder: &ConstraintCircuitBuilder, - instruction: Instruction, - ) -> Vec> { - match instruction { - Pop(_) => ExtProcessorTable::instruction_pop(circuit_builder), - Push(_) => ExtProcessorTable::instruction_push(circuit_builder), - Divine(_) => ExtProcessorTable::instruction_divine(circuit_builder), - Dup(_) => ExtProcessorTable::instruction_dup(circuit_builder), - Swap(_) => ExtProcessorTable::instruction_swap(circuit_builder), - Halt => ExtProcessorTable::instruction_halt(circuit_builder), - Nop => ExtProcessorTable::instruction_nop(circuit_builder), - Skiz => ExtProcessorTable::instruction_skiz(circuit_builder), - Call(_) => ExtProcessorTable::instruction_call(circuit_builder), - Return => ExtProcessorTable::instruction_return(circuit_builder), - Recurse => ExtProcessorTable::instruction_recurse(circuit_builder), - RecurseOrReturn => ExtProcessorTable::instruction_recurse_or_return(circuit_builder), - Assert => ExtProcessorTable::instruction_assert(circuit_builder), - ReadMem(_) => ExtProcessorTable::instruction_read_mem(circuit_builder), - WriteMem(_) => ExtProcessorTable::instruction_write_mem(circuit_builder), - Hash => ExtProcessorTable::instruction_hash(circuit_builder), - AssertVector => ExtProcessorTable::instruction_assert_vector(circuit_builder), - SpongeInit => ExtProcessorTable::instruction_sponge_init(circuit_builder), - SpongeAbsorb => ExtProcessorTable::instruction_sponge_absorb(circuit_builder), - SpongeAbsorbMem => ExtProcessorTable::instruction_sponge_absorb_mem(circuit_builder), - SpongeSqueeze => ExtProcessorTable::instruction_sponge_squeeze(circuit_builder), - Add => ExtProcessorTable::instruction_add(circuit_builder), - AddI(_) => ExtProcessorTable::instruction_addi(circuit_builder), - Mul => ExtProcessorTable::instruction_mul(circuit_builder), - Invert => ExtProcessorTable::instruction_invert(circuit_builder), - Eq => ExtProcessorTable::instruction_eq(circuit_builder), - Split => ExtProcessorTable::instruction_split(circuit_builder), - Lt => ExtProcessorTable::instruction_lt(circuit_builder), - And => ExtProcessorTable::instruction_and(circuit_builder), - Xor => ExtProcessorTable::instruction_xor(circuit_builder), - Log2Floor => ExtProcessorTable::instruction_log_2_floor(circuit_builder), - Pow => ExtProcessorTable::instruction_pow(circuit_builder), - DivMod => ExtProcessorTable::instruction_div_mod(circuit_builder), - PopCount => ExtProcessorTable::instruction_pop_count(circuit_builder), - XxAdd => ExtProcessorTable::instruction_xx_add(circuit_builder), - XxMul => ExtProcessorTable::instruction_xx_mul(circuit_builder), - XInvert => ExtProcessorTable::instruction_xinv(circuit_builder), - XbMul => ExtProcessorTable::instruction_xb_mul(circuit_builder), - ReadIo(_) => ExtProcessorTable::instruction_read_io(circuit_builder), - WriteIo(_) => ExtProcessorTable::instruction_write_io(circuit_builder), - MerkleStep => ExtProcessorTable::instruction_merkle_step(circuit_builder), - MerkleStepMem => ExtProcessorTable::instruction_merkle_step_mem(circuit_builder), - XxDotStep => ExtProcessorTable::instruction_xx_dot_step(circuit_builder), - XbDotStep => ExtProcessorTable::instruction_xb_dot_step(circuit_builder), - } - } - - /// Constrains instruction argument `nia` such that 0 < nia <= 5. - fn prohibit_any_illegal_number_of_words( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - vec![NumberOfWords::illegal_values() - .map(|n| Self::indicator_polynomial(circuit_builder, n)) - .into_iter() - .sum()] - } - - fn log_derivative_accumulates_clk_next( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - (next_ext_row(ClockJumpDifferenceLookupServerLogDerivative) - - curr_ext_row(ClockJumpDifferenceLookupServerLogDerivative)) - * (challenge(ClockJumpDifferenceLookupIndeterminate) - next_base_row(CLK)) - - next_base_row(ClockJumpDifferenceLookupMultiplicity) - } - - fn running_evaluation_for_standard_input_remains_unchanged( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - next_ext_row(InputTableEvalArg) - curr_ext_row(InputTableEvalArg) - } - - fn running_evaluation_for_standard_output_remains_unchanged( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - next_ext_row(OutputTableEvalArg) - curr_ext_row(OutputTableEvalArg) - } - - fn grow_stack_by_n_and_read_n_symbols_from_input( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let indeterminate = || circuit_builder.challenge(StandardInputIndeterminate); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let mut running_evaluation = curr_ext_row(InputTableEvalArg); - for i in (0..n).rev() { - let stack_element = ProcessorTable::op_stack_column_by_index(i); - running_evaluation = - indeterminate() * running_evaluation + next_base_row(stack_element); - } - let running_evaluation_update = next_ext_row(InputTableEvalArg) - running_evaluation; - let conditional_running_evaluation_update = - Self::indicator_polynomial(circuit_builder, n) * running_evaluation_update; - - let mut constraints = - Self::conditional_constraints_for_growing_stack_by(circuit_builder, n); - constraints.push(conditional_running_evaluation_update); - constraints - } - - fn shrink_stack_by_n_and_write_n_symbols_to_output( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let indeterminate = || circuit_builder.challenge(StandardOutputIndeterminate); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let mut running_evaluation = curr_ext_row(OutputTableEvalArg); - for i in 0..n { - let stack_element = ProcessorTable::op_stack_column_by_index(i); - running_evaluation = - indeterminate() * running_evaluation + curr_base_row(stack_element); - } - let running_evaluation_update = next_ext_row(OutputTableEvalArg) - running_evaluation; - let conditional_running_evaluation_update = - Self::indicator_polynomial(circuit_builder, n) * running_evaluation_update; - - let mut constraints = - Self::conditional_constraints_for_shrinking_stack_by(circuit_builder, n); - constraints.push(conditional_running_evaluation_update); - constraints - } - - fn log_derivative_for_instruction_lookup_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let one = || circuit_builder.b_constant(1); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let compressed_row = challenge(ProgramAddressWeight) * next_base_row(IP) - + challenge(ProgramInstructionWeight) * next_base_row(CI) - + challenge(ProgramNextInstructionWeight) * next_base_row(NIA); - let log_derivative_updates = (next_ext_row(InstructionLookupClientLogDerivative) - - curr_ext_row(InstructionLookupClientLogDerivative)) - * (challenge(InstructionLookupIndeterminate) - compressed_row) - - one(); - let log_derivative_remains = next_ext_row(InstructionLookupClientLogDerivative) - - curr_ext_row(InstructionLookupClientLogDerivative); - - (one() - next_base_row(IsPadding)) * log_derivative_updates - + next_base_row(IsPadding) * log_derivative_remains - } - - fn constraints_for_shrinking_stack_by_3_and_top_3_unconstrained( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - vec![ - next_base_row(ST3) - curr_base_row(ST6), - next_base_row(ST4) - curr_base_row(ST7), - next_base_row(ST5) - curr_base_row(ST8), - next_base_row(ST6) - curr_base_row(ST9), - next_base_row(ST7) - curr_base_row(ST10), - next_base_row(ST8) - curr_base_row(ST11), - next_base_row(ST9) - curr_base_row(ST12), - next_base_row(ST10) - curr_base_row(ST13), - next_base_row(ST11) - curr_base_row(ST14), - next_base_row(ST12) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(3), - Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 3), - ] - } - - fn stack_shrinks_by_any_of( - circuit_builder: &ConstraintCircuitBuilder, - shrinkages: &[usize], - ) -> Vec> { - let all_constraints_for_all_shrinkages = shrinkages - .iter() - .map(|&n| Self::conditional_constraints_for_shrinking_stack_by(circuit_builder, n)) - .collect_vec(); - - Self::combine_mutually_exclusive_constraint_groups( - circuit_builder, - all_constraints_for_all_shrinkages, - ) - } - - fn stack_grows_by_any_of( - circuit_builder: &ConstraintCircuitBuilder, - growths: &[usize], - ) -> Vec> { - let all_constraints_for_all_growths = growths - .iter() - .map(|&n| Self::conditional_constraints_for_growing_stack_by(circuit_builder, n)) - .collect_vec(); - - Self::combine_mutually_exclusive_constraint_groups( - circuit_builder, - all_constraints_for_all_growths, - ) - } - - /// Reduces the number of constraints by summing mutually exclusive constraints. The mutual - /// exclusion is due to the conditional nature of the constraints, which has to be guaranteed by - /// the caller. - /// - /// For example, the constraints for shrinking the stack by 2, 3, and 4 elements are: - /// - /// ```markdown - /// | shrink by 2 | shrink by 3 | shrink by 4 | - /// |-----------------------:|-----------------------:|-----------------------:| - /// | ind_2·(st0' - st2) | ind_3·(st0' - st3) | ind_4·(st0' - st4) | - /// | ind_2·(st1' - st3) | ind_3·(st1' - st4) | ind_4·(st1' - st5) | - /// | … | … | … | - /// | ind_2·(st11' - st13) | ind_3·(st11' - st14) | ind_4·(st11' - st15) | - /// | ind_2·(st12' - st14) | ind_3·(st12' - st15) | ind_4·(osp' - osp + 4) | - /// | ind_2·(st13' - st15) | ind_3·(osp' - osp + 3) | ind_4·(rp' - rp·fac_4) | - /// | ind_2·(osp' - osp + 2) | ind_3·(rp' - rp·fac_3) | | - /// | ind_2·(rp' - rp·fac_2) | | | - /// ``` - /// - /// This method sums these constraints “per row”. That is, the resulting constraints are: - /// - /// ```markdown - /// | shrink by 2 or 3 or 4 | - /// |-----------------------------------------------------------------------:| - /// | ind_2·(st0' - st2) + ind_3·(st0' - st3) + ind_4·(st0' - st4) | - /// | ind_2·(st1' - st3) + ind_3·(st1' - st4) + ind_4·(st1' - st5) | - /// | … | - /// | ind_2·(st11' - st13) + ind_3·(st11' - st14) + ind_4·(st11' - st15) | - /// | ind_2·(st12' - st14) + ind_3·(st12' - st15) + ind_4·(osp' - osp + 4) | - /// | ind_2·(st13' - st15) + ind_3·(osp' - osp + 3) + ind_4·(rp' - rp·fac_4) | - /// | ind_2·(osp' - osp + 2) + ind_3·(rp' - rp·fac_3) | - /// | ind_2·(rp' - rp·fac_2) | - /// ``` - /// - /// Syntax in above example: - /// - `ind_n` is the [indicator polynomial](Self::indicator_polynomial) for `n` - /// - `osp` is the [op stack pointer](OpStackPointer) - /// - `rp` is the running product for the permutation argument - /// - `fac_n` is the factor for the running product - fn combine_mutually_exclusive_constraint_groups( - circuit_builder: &ConstraintCircuitBuilder, - all_constraint_groups: Vec>>, - ) -> Vec> { - let constraint_group_lengths = all_constraint_groups.iter().map(|x| x.len()); - let num_constraints = constraint_group_lengths.max().unwrap_or(0); - - let zero_constraint = || circuit_builder.b_constant(0); - let mut combined_constraints = vec![]; - for i in 0..num_constraints { - let combined_constraint = all_constraint_groups - .iter() - .filter_map(|constraint_group| constraint_group.get(i)) - .fold(zero_constraint(), |acc, summand| acc + summand.clone()); - combined_constraints.push(combined_constraint); - } - combined_constraints - } - - fn constraints_for_shrinking_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); - let new_stack = stack().dropping_back(n).map(next_row).collect_vec(); - let old_stack_with_top_n_removed = stack().skip(n).map(curr_row).collect_vec(); - - let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { - assert_eq!(OpStackElement::COUNT - n, stack.len()); - let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); - let enumerated_stack = stack.into_iter().enumerate(); - enumerated_stack.map(|(i, st)| weight(i) * st).sum() - }; - let compressed_new_stack = compress(new_stack); - let compressed_old_stack = compress(old_stack_with_top_n_removed); - - let op_stack_pointer_shrinks_by_n = - next_row(OpStackPointer) - curr_row(OpStackPointer) + constant(n); - let new_stack_is_old_stack_with_top_n_removed = compressed_new_stack - compressed_old_stack; - - vec![ - op_stack_pointer_shrinks_by_n, - new_stack_is_old_stack_with_top_n_removed, - Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, n), - ] - } - - fn constraints_for_growing_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); - let new_stack = stack().skip(n).map(next_row).collect_vec(); - let old_stack_with_top_n_added = stack().map(curr_row).dropping_back(n).collect_vec(); - - let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { - assert_eq!(OpStackElement::COUNT - n, stack.len()); - let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); - let enumerated_stack = stack.into_iter().enumerate(); - enumerated_stack.map(|(i, st)| weight(i) * st).sum() - }; - let compressed_new_stack = compress(new_stack); - let compressed_old_stack = compress(old_stack_with_top_n_added); - - let op_stack_pointer_grows_by_n = - next_row(OpStackPointer) - curr_row(OpStackPointer) - constant(n); - let new_stack_is_old_stack_with_top_n_added = compressed_new_stack - compressed_old_stack; - - vec![ - op_stack_pointer_grows_by_n, - new_stack_is_old_stack_with_top_n_added, - Self::running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, n), - ] - } - - fn conditional_constraints_for_shrinking_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - Self::constraints_for_shrinking_stack_by(circuit_builder, n) - .into_iter() - .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) - .collect() - } - - fn conditional_constraints_for_growing_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - Self::constraints_for_growing_stack_by(circuit_builder, n) - .into_iter() - .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) - .collect() - } - - fn running_product_op_stack_accounts_for_growing_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - let single_grow_factor = |op_stack_pointer_offset| { - Self::single_factor_for_permutation_argument_with_op_stack_table( - circuit_builder, - CurrentBaseRow, - op_stack_pointer_offset, - ) - }; - - let mut factor = constant(1); - for op_stack_pointer_offset in 0..n { - factor = factor * single_grow_factor(op_stack_pointer_offset); - } - - next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor - } - - fn running_product_op_stack_accounts_for_shrinking_stack_by( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - let single_shrink_factor = |op_stack_pointer_offset| { - Self::single_factor_for_permutation_argument_with_op_stack_table( - circuit_builder, - NextBaseRow, - op_stack_pointer_offset, - ) - }; - - let mut factor = constant(1); - for op_stack_pointer_offset in 0..n { - factor = factor * single_shrink_factor(op_stack_pointer_offset); - } - - next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor - } - - fn single_factor_for_permutation_argument_with_op_stack_table( - circuit_builder: &ConstraintCircuitBuilder, - row_with_shorter_stack_indicator: fn(usize) -> DualRowIndicator, - op_stack_pointer_offset: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let row_with_shorter_stack = |col: ProcessorBaseTableColumn| { - circuit_builder.input(row_with_shorter_stack_indicator( - col.master_base_table_index(), - )) - }; - - let max_stack_element_index = OpStackElement::COUNT - 1; - let stack_element_index = max_stack_element_index - op_stack_pointer_offset; - let stack_element = ProcessorTable::op_stack_column_by_index(stack_element_index); - let underflow_element = row_with_shorter_stack(stack_element); - - let op_stack_pointer = row_with_shorter_stack(OpStackPointer); - let offset = constant(op_stack_pointer_offset as u32); - let offset_op_stack_pointer = op_stack_pointer + offset; - - let compressed_row = challenge(OpStackClkWeight) * curr_base_row(CLK) - + challenge(OpStackIb1Weight) * curr_base_row(IB1) - + challenge(OpStackPointerWeight) * offset_op_stack_pointer - + challenge(OpStackFirstUnderflowElementWeight) * underflow_element; - challenge(OpStackIndeterminate) - compressed_row - } - - /// Build constraints for popping `n` elements from the top of the stack and - /// writing them to RAM. The reciprocal of [`Self::read_from_ram_any_of`]. - fn write_to_ram_any_of( - circuit_builder: &ConstraintCircuitBuilder, - number_of_words: &[usize], - ) -> Vec> { - let all_constraint_groups = number_of_words - .iter() - .map(|&n| { - Self::conditional_constraints_for_writing_n_elements_to_ram(circuit_builder, n) - }) - .collect_vec(); - Self::combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) - } - - /// Build constraints for reading `n` elements from RAM and putting them on top - /// of the stack. The reciprocal of [`Self::write_to_ram_any_of`]. - /// - /// To constrain RAM reads with more flexible target locations, see - /// [`Self::read_from_ram_to`]. - fn read_from_ram_any_of( - circuit_builder: &ConstraintCircuitBuilder, - number_of_words: &[usize], - ) -> Vec> { - let all_constraint_groups = number_of_words - .iter() - .map(|&n| { - Self::conditional_constraints_for_reading_n_elements_from_ram(circuit_builder, n) - }) - .collect_vec(); - Self::combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) - } - - fn conditional_constraints_for_writing_n_elements_to_ram( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - Self::shrink_stack_by_n_and_write_n_elements_to_ram(circuit_builder, n) - .into_iter() - .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) - .collect() - } - - fn conditional_constraints_for_reading_n_elements_from_ram( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - Self::grow_stack_by_n_and_read_n_elements_from_ram(circuit_builder, n) - .into_iter() - .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) - .collect() - } - - fn shrink_stack_by_n_and_write_n_elements_to_ram( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let op_stack_pointer_shrinks_by_n = - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(n); - let ram_pointer_grows_by_n = next_base_row(ST0) - curr_base_row(ST0) - constant(n); - - let mut constraints = vec![ - op_stack_pointer_shrinks_by_n, - ram_pointer_grows_by_n, - Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, n), - Self::running_product_ram_accounts_for_writing_n_elements(circuit_builder, n), - ]; - - let num_ram_pointers = 1; - for i in n + num_ram_pointers..OpStackElement::COUNT { - let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); - let next_stack_element = ProcessorTable::op_stack_column_by_index(i - n); - let element_i_is_shifted_by_n = - next_base_row(next_stack_element) - curr_base_row(curr_stack_element); - constraints.push(element_i_is_shifted_by_n); - } - constraints - } - - fn grow_stack_by_n_and_read_n_elements_from_ram( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> Vec> { - let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let op_stack_pointer_grows_by_n = - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(n); - let ram_pointer_shrinks_by_n = next_base_row(ST0) - curr_base_row(ST0) + constant(n); - - let mut constraints = vec![ - op_stack_pointer_grows_by_n, - ram_pointer_shrinks_by_n, - Self::running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, n), - Self::running_product_ram_accounts_for_reading_n_elements(circuit_builder, n), - ]; - - let num_ram_pointers = 1; - for i in num_ram_pointers..OpStackElement::COUNT - n { - let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); - let next_stack_element = ProcessorTable::op_stack_column_by_index(i + n); - let element_i_is_shifted_by_n = - next_base_row(next_stack_element) - curr_base_row(curr_stack_element); - constraints.push(element_i_is_shifted_by_n); - } - constraints - } - - fn running_product_ram_accounts_for_writing_n_elements( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - let single_write_factor = |ram_pointer_offset| { - Self::single_factor_for_permutation_argument_with_ram_table( - circuit_builder, - CurrentBaseRow, - ram_table::INSTRUCTION_TYPE_WRITE, - ram_pointer_offset, - ) - }; - - let mut factor = constant(1); - for ram_pointer_offset in 0..n { - factor = factor * single_write_factor(ram_pointer_offset); - } - - next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor - } - - fn running_product_ram_accounts_for_reading_n_elements( - circuit_builder: &ConstraintCircuitBuilder, - n: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - let single_read_factor = |ram_pointer_offset| { - Self::single_factor_for_permutation_argument_with_ram_table( - circuit_builder, - NextBaseRow, - ram_table::INSTRUCTION_TYPE_READ, - ram_pointer_offset, - ) - }; - - let mut factor = constant(1); - for ram_pointer_offset in 0..n { - factor = factor * single_read_factor(ram_pointer_offset); - } - - next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor - } - - fn single_factor_for_permutation_argument_with_ram_table( - circuit_builder: &ConstraintCircuitBuilder, - row_with_longer_stack_indicator: fn(usize) -> DualRowIndicator, - instruction_type: BFieldElement, - ram_pointer_offset: usize, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let b_constant = |c| circuit_builder.b_constant(c); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let row_with_longer_stack = |col: ProcessorBaseTableColumn| { - circuit_builder.input(row_with_longer_stack_indicator( - col.master_base_table_index(), - )) - }; - - let num_ram_pointers = 1; - let ram_value_index = ram_pointer_offset + num_ram_pointers; - let ram_value_column = ProcessorTable::op_stack_column_by_index(ram_value_index); - let ram_value = row_with_longer_stack(ram_value_column); - - let additional_offset = match instruction_type { - ram_table::INSTRUCTION_TYPE_READ => 1, - ram_table::INSTRUCTION_TYPE_WRITE => 0, - _ => panic!("Invalid instruction type"), - }; - - let ram_pointer = row_with_longer_stack(ST0); - let offset = constant(additional_offset + ram_pointer_offset as u32); - let offset_ram_pointer = ram_pointer + offset; - - let compressed_row = curr_base_row(CLK) * challenge(RamClkWeight) - + b_constant(instruction_type) * challenge(RamInstructionTypeWeight) - + offset_ram_pointer * challenge(RamPointerWeight) - + ram_value * challenge(RamValueWeight); - challenge(RamIndeterminate) - compressed_row - } - - fn running_product_for_jump_stack_table_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let compressed_row = challenge(JumpStackClkWeight) * next_base_row(CLK) - + challenge(JumpStackCiWeight) * next_base_row(CI) - + challenge(JumpStackJspWeight) * next_base_row(JSP) - + challenge(JumpStackJsoWeight) * next_base_row(JSO) - + challenge(JumpStackJsdWeight) * next_base_row(JSD); - - next_ext_row(JumpStackTablePermArg) - - curr_ext_row(JumpStackTablePermArg) - * (challenge(JumpStackIndeterminate) - compressed_row) - } - - /// Deal with instructions `hash` and `merkle_step`. The registers from which - /// the preimage is loaded differs between the two instructions: - /// 1. `hash` always loads the stack's 10 top elements, - /// 1. `merkle_step` loads the stack's 5 top elements and helper variables 0 - /// through 4. The order of those two quintuplets depends on helper variable - /// hv5. - /// - /// The Hash Table does not “know” about instruction `merkle_step`. - /// - /// Note that using `next_row` might be confusing at first glance; See the - /// [specification](https://triton-vm.org/spec/processor-table.html). - fn running_evaluation_hash_input_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let hash_deselector = - Self::instruction_deselector_next_row(circuit_builder, Instruction::Hash); - let merkle_step_deselector = - Self::instruction_deselector_next_row(circuit_builder, Instruction::MerkleStep); - let merkle_step_mem_deselector = - Self::instruction_deselector_next_row(circuit_builder, Instruction::MerkleStepMem); - let hash_and_merkle_step_selector = (next_base_row(CI) - - constant(Instruction::Hash.opcode())) - * (next_base_row(CI) - constant(Instruction::MerkleStep.opcode())) - * (next_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); - - let weights = [ - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - StackWeight5, - StackWeight6, - StackWeight7, - StackWeight8, - StackWeight9, - ] - .map(challenge); - - // hash - let state_for_hash = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9].map(next_base_row); - let compressed_hash_row = weights - .iter() - .zip_eq(state_for_hash) - .map(|(weight, state)| weight.clone() * state) - .sum(); - - // merkle step - let is_left_sibling = || next_base_row(HV5); - let is_right_sibling = || one() - next_base_row(HV5); - let merkle_step_state_element = - |l, r| is_right_sibling() * next_base_row(l) + is_left_sibling() * next_base_row(r); - let state_for_merkle_step = [ - merkle_step_state_element(ST0, HV0), - merkle_step_state_element(ST1, HV1), - merkle_step_state_element(ST2, HV2), - merkle_step_state_element(ST3, HV3), - merkle_step_state_element(ST4, HV4), - merkle_step_state_element(HV0, ST0), - merkle_step_state_element(HV1, ST1), - merkle_step_state_element(HV2, ST2), - merkle_step_state_element(HV3, ST3), - merkle_step_state_element(HV4, ST4), - ]; - let compressed_merkle_step_row = weights - .into_iter() - .zip_eq(state_for_merkle_step) - .map(|(weight, state)| weight * state) - .sum::>(); - - let running_evaluation_updates_with = |compressed_row| { - next_ext_row(HashInputEvalArg) - - challenge(HashInputIndeterminate) * curr_ext_row(HashInputEvalArg) - - compressed_row - }; - let running_evaluation_remains = - next_ext_row(HashInputEvalArg) - curr_ext_row(HashInputEvalArg); - - hash_and_merkle_step_selector * running_evaluation_remains - + hash_deselector * running_evaluation_updates_with(compressed_hash_row) - + merkle_step_deselector - * running_evaluation_updates_with(compressed_merkle_step_row.clone()) - + merkle_step_mem_deselector - * running_evaluation_updates_with(compressed_merkle_step_row) - } - - fn running_evaluation_hash_digest_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let hash_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Hash); - let merkle_step_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::MerkleStep); - let merkle_step_mem_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); - let hash_and_merkle_step_selector = (curr_base_row(CI) - - constant(Instruction::Hash.opcode())) - * (curr_base_row(CI) - constant(Instruction::MerkleStep.opcode())) - * (curr_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); - - let weights = [ - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - ] - .map(challenge); - let state = [ST0, ST1, ST2, ST3, ST4].map(next_base_row); - let compressed_row = weights - .into_iter() - .zip_eq(state) - .map(|(weight, state)| weight * state) - .sum(); - - let running_evaluation_updates = next_ext_row(HashDigestEvalArg) - - challenge(HashDigestIndeterminate) * curr_ext_row(HashDigestEvalArg) - - compressed_row; - let running_evaluation_remains = - next_ext_row(HashDigestEvalArg) - curr_ext_row(HashDigestEvalArg); - - hash_and_merkle_step_selector * running_evaluation_remains - + (hash_deselector + merkle_step_deselector + merkle_step_mem_deselector) - * running_evaluation_updates - } - - fn running_evaluation_sponge_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let sponge_init_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::SpongeInit); - let sponge_absorb_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::SpongeAbsorb); - let sponge_absorb_mem_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::SpongeAbsorbMem); - let sponge_squeeze_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::SpongeSqueeze); - - let sponge_instruction_selector = (curr_base_row(CI) - - constant(Instruction::SpongeInit.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeAbsorb.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeAbsorbMem.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeSqueeze.opcode())); - - let weighted_sum = |state| { - let weights = [ - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - StackWeight5, - StackWeight6, - StackWeight7, - StackWeight8, - StackWeight9, - ]; - let weights = weights.map(challenge).into_iter(); - weights.zip_eq(state).map(|(w, st)| w * st).sum() - }; - - let state = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let compressed_row_current = weighted_sum(state.map(curr_base_row)); - let compressed_row_next = weighted_sum(state.map(next_base_row)); - - // Use domain-specific knowledge: the compressed row (i.e., random linear sum) - // of the initial Sponge state is 0. - let running_evaluation_updates_for_sponge_init = next_ext_row(SpongeEvalArg) - - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) - - challenge(HashCIWeight) * curr_base_row(CI); - let running_evaluation_updates_for_absorb = - running_evaluation_updates_for_sponge_init.clone() - compressed_row_current; - let running_evaluation_updates_for_squeeze = - running_evaluation_updates_for_sponge_init.clone() - compressed_row_next; - let running_evaluation_remains = next_ext_row(SpongeEvalArg) - curr_ext_row(SpongeEvalArg); - - // `sponge_absorb_mem` - let stack_elements = [ST1, ST2, ST3, ST4].map(next_base_row); - let hv_elements = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); - let absorb_mem_elements = stack_elements.into_iter().chain(hv_elements); - let absorb_mem_elements = absorb_mem_elements.collect_vec().try_into().unwrap(); - let compressed_row_absorb_mem = weighted_sum(absorb_mem_elements); - let running_evaluation_updates_for_absorb_mem = next_ext_row(SpongeEvalArg) - - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) - - challenge(HashCIWeight) * constant(Instruction::SpongeAbsorb.opcode()) - - compressed_row_absorb_mem; - - sponge_instruction_selector * running_evaluation_remains - + sponge_init_deselector * running_evaluation_updates_for_sponge_init - + sponge_absorb_deselector * running_evaluation_updates_for_absorb - + sponge_absorb_mem_deselector * running_evaluation_updates_for_absorb_mem - + sponge_squeeze_deselector * running_evaluation_updates_for_squeeze - } - - fn log_derivative_with_u32_table_updates_correctly( - circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let constant = |c: u32| circuit_builder.b_constant(c); - let one = || constant(1); - let two_inverse = circuit_builder.b_constant(bfe!(2).inverse()); - let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let split_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Split); - let lt_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Lt); - let and_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::And); - let xor_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Xor); - let pow_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Pow); - let log_2_floor_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::Log2Floor); - let div_mod_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::DivMod); - let pop_count_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::PopCount); - let merkle_step_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::MerkleStep); - let merkle_step_mem_deselector = - Self::instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); - - let running_sum = curr_ext_row(U32LookupClientLogDerivative); - let running_sum_next = next_ext_row(U32LookupClientLogDerivative); - - let split_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * next_base_row(ST0) - - challenge(U32RhsWeight) * next_base_row(ST1) - - challenge(U32CiWeight) * curr_base_row(CI); - let binop_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) - - challenge(U32CiWeight) * curr_base_row(CI) - - challenge(U32ResultWeight) * next_base_row(ST0); - let xor_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) - - challenge(U32CiWeight) * constant(Instruction::And.opcode()) - - challenge(U32ResultWeight) - * (curr_base_row(ST0) + curr_base_row(ST1) - next_base_row(ST0)) - * two_inverse; - let unop_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32CiWeight) * curr_base_row(CI) - - challenge(U32ResultWeight) * next_base_row(ST0); - let div_mod_factor_for_lt = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * next_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) - - challenge(U32CiWeight) * constant(Instruction::Lt.opcode()) - - challenge(U32ResultWeight); - let div_mod_factor_for_range_check = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * next_base_row(ST1) - - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); - let merkle_step_range_check_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST5) - - challenge(U32RhsWeight) * next_base_row(ST5) - - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); - - let running_sum_absorbs_split_factor = - (running_sum_next.clone() - running_sum.clone()) * split_factor - one(); - let running_sum_absorbs_binop_factor = - (running_sum_next.clone() - running_sum.clone()) * binop_factor - one(); - let running_sum_absorb_xor_factor = - (running_sum_next.clone() - running_sum.clone()) * xor_factor - one(); - let running_sum_absorbs_unop_factor = - (running_sum_next.clone() - running_sum.clone()) * unop_factor - one(); - let running_sum_absorbs_merkle_step_factor = - (running_sum_next.clone() - running_sum.clone()) * merkle_step_range_check_factor - - one(); - - let split_summand = split_deselector * running_sum_absorbs_split_factor; - let lt_summand = lt_deselector * running_sum_absorbs_binop_factor.clone(); - let and_summand = and_deselector * running_sum_absorbs_binop_factor.clone(); - let xor_summand = xor_deselector * running_sum_absorb_xor_factor; - let pow_summand = pow_deselector * running_sum_absorbs_binop_factor; - let log_2_floor_summand = log_2_floor_deselector * running_sum_absorbs_unop_factor.clone(); - let div_mod_summand = div_mod_deselector - * ((running_sum_next.clone() - running_sum.clone()) - * div_mod_factor_for_lt.clone() - * div_mod_factor_for_range_check.clone() - - div_mod_factor_for_lt - - div_mod_factor_for_range_check); - let pop_count_summand = pop_count_deselector * running_sum_absorbs_unop_factor; - let merkle_step_summand = - merkle_step_deselector * running_sum_absorbs_merkle_step_factor.clone(); - let merkle_step_mem_summand = - merkle_step_mem_deselector * running_sum_absorbs_merkle_step_factor; - let no_update_summand = (one() - curr_base_row(IB2)) * (running_sum_next - running_sum); - - split_summand - + lt_summand - + and_summand - + xor_summand - + pow_summand - + log_2_floor_summand - + div_mod_summand - + pop_count_summand - + merkle_step_summand - + merkle_step_mem_summand - + no_update_summand - } - - fn stack_weight_by_index(i: usize) -> ChallengeId { - match i { - 0 => StackWeight0, - 1 => StackWeight1, - 2 => StackWeight2, - 3 => StackWeight3, - 4 => StackWeight4, - 5 => StackWeight5, - 6 => StackWeight6, - 7 => StackWeight7, - 8 => StackWeight8, - 9 => StackWeight9, - 10 => StackWeight10, - 11 => StackWeight11, - 12 => StackWeight12, - 13 => StackWeight13, - 14 => StackWeight14, - 15 => StackWeight15, - i => panic!("Op Stack weight index must be in [0, 15], not {i}."), - } - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // constraints common to all instructions - let clk_increases_by_1 = next_base_row(CLK) - curr_base_row(CLK) - constant(1); - let is_padding_is_0_or_does_not_change = - curr_base_row(IsPadding) * (next_base_row(IsPadding) - curr_base_row(IsPadding)); - - let instruction_independent_constraints = - vec![clk_increases_by_1, is_padding_is_0_or_does_not_change]; - - // instruction-specific constraints - let transition_constraints_for_instruction = - |instr| Self::transition_constraints_for_instruction(circuit_builder, instr); - let all_instructions_and_their_transition_constraints = - ALL_INSTRUCTIONS.map(|instr| (instr, transition_constraints_for_instruction(instr))); - let deselected_transition_constraints = - Self::combine_instruction_constraints_with_deselectors( - circuit_builder, - all_instructions_and_their_transition_constraints, - ); - - // if next row is padding row: disable transition constraints, enable padding constraints - let doubly_deselected_transition_constraints = - Self::combine_transition_constraints_with_padding_constraints( - circuit_builder, - deselected_transition_constraints, - ); - - let table_linking_constraints = vec![ - Self::log_derivative_accumulates_clk_next(circuit_builder), - Self::log_derivative_for_instruction_lookup_updates_correctly(circuit_builder), - Self::running_product_for_jump_stack_table_updates_correctly(circuit_builder), - Self::running_evaluation_hash_input_updates_correctly(circuit_builder), - Self::running_evaluation_hash_digest_updates_correctly(circuit_builder), - Self::running_evaluation_sponge_updates_correctly(circuit_builder), - Self::log_derivative_with_u32_table_updates_correctly(circuit_builder), - ]; - - [ - instruction_independent_constraints, - doubly_deselected_transition_constraints, - table_linking_constraints, - ] - .concat() - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let constant = |c| circuit_builder.b_constant(c); - - // In the last row, register “current instruction” `ci` corresponds to instruction `halt`. - let last_ci_is_halt = base_row(CI) - constant(Instruction::Halt.opcode_b()); - - vec![last_ci_is_halt] - } -} - -#[cfg(test)] -pub(crate) mod tests { - use std::collections::HashMap; - - use assert2::assert; - use isa::instruction::Instruction; - use isa::op_stack::NumberOfWords::*; - use isa::op_stack::OpStackElement; - use isa::program::Program; - use isa::triton_asm; - use isa::triton_program; - use ndarray::Array2; - use proptest::collection::vec; - use proptest::prop_assert_eq; - use proptest_arbitrary_interop::arb; - use rand::thread_rng; - use rand::Rng; - use strum::IntoEnumIterator; - use test_strategy::proptest; - - use crate::error::InstructionError::DivisionByZero; - use crate::prelude::PublicInput; - use crate::shared_tests::ProgramAndInput; - use crate::stark::tests::master_tables_for_low_security_level; - use crate::table::master_table::*; - use crate::vm::VMState; - use crate::vm::NUM_HELPER_VARIABLE_REGISTERS; - use crate::vm::VM; - use crate::NonDeterminism; - - use super::*; - - /// Does printing recurse infinitely? - #[test] - fn print_simple_processor_table_row() { - let program = triton_program!(push 2 sponge_init assert halt); - let err = VM::run(&program, [].into(), [].into()).unwrap_err(); - println!("\n{}", err.vm_state); - } - - #[derive(Debug, Clone)] - struct TestRows { - pub challenges: Challenges, - pub consecutive_master_base_table_rows: Array2, - pub consecutive_ext_base_table_rows: Array2, - } - - #[derive(Debug, Clone)] - struct TestRowsDebugInfo { - pub instruction: Instruction, - pub debug_cols_curr_row: Vec, - pub debug_cols_next_row: Vec, - } - - fn test_row_from_program(program: Program, row_num: usize) -> TestRows { - test_row_from_program_with_input(ProgramAndInput::new(program), row_num) - } - - fn test_row_from_program_with_input( - program_and_input: ProgramAndInput, - row_num: usize, - ) -> TestRows { - let (_, _, master_base_table, master_ext_table, challenges) = - master_tables_for_low_security_level(program_and_input); - TestRows { - challenges, - consecutive_master_base_table_rows: master_base_table - .trace_table() - .slice(s![row_num..=row_num + 1, ..]) - .to_owned(), - consecutive_ext_base_table_rows: master_ext_table - .trace_table() - .slice(s![row_num..=row_num + 1, ..]) - .to_owned(), - } - } - - fn assert_constraints_for_rows_with_debug_info( - test_rows: &[TestRows], - debug_info: TestRowsDebugInfo, - ) { - let instruction = debug_info.instruction; - let circuit_builder = ConstraintCircuitBuilder::new(); - let transition_constraints = ExtProcessorTable::transition_constraints_for_instruction( - &circuit_builder, - instruction, - ); - - for (case_idx, rows) in test_rows.iter().enumerate() { - let curr_row = rows.consecutive_master_base_table_rows.slice(s![0, ..]); - let next_row = rows.consecutive_master_base_table_rows.slice(s![1, ..]); - - println!("Testing all constraints of {instruction} for test case {case_idx}…"); - for &c in &debug_info.debug_cols_curr_row { - print!("{c} = {}, ", curr_row[c.master_base_table_index()]); - } - println!(); - for &c in &debug_info.debug_cols_next_row { - print!("{c}' = {}, ", next_row[c.master_base_table_index()]); - } - println!(); - - assert!( - instruction.opcode_b() == curr_row[CI.master_base_table_index()], - "The test is trying to check the wrong transition constraint polynomials." - ); - - for (constraint_idx, constraint) in transition_constraints.iter().enumerate() { - let evaluation_result = constraint.clone().consume().evaluate( - rows.consecutive_master_base_table_rows.view(), - rows.consecutive_ext_base_table_rows.view(), - &rows.challenges.challenges, - ); - assert!( - evaluation_result.is_zero(), - "For case {case_idx}, transition constraint polynomial with \ - index {constraint_idx} must evaluate to zero. Got {evaluation_result} instead.", - ); - } - } - } - - #[proptest(cases = 20)] - fn transition_constraints_for_instruction_pop_n(#[strategy(arb())] n: NumberOfWords) { - let program = triton_program!(push 1 push 2 push 3 push 4 push 5 pop {n} halt); - - let test_rows = [test_row_from_program(program, 5)]; - let debug_info = TestRowsDebugInfo { - instruction: Pop(n), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_push() { - let test_rows = [test_row_from_program(triton_program!(push 1 halt), 0)]; - - let debug_info = TestRowsDebugInfo { - instruction: Push(bfe!(1)), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[proptest(cases = 20)] - fn transition_constraints_for_instruction_divine_n(#[strategy(arb())] n: NumberOfWords) { - let program = triton_program! { divine {n} halt }; - - let non_determinism = (1..=16).map(|b| bfe!(b)).collect_vec(); - let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); - let test_rows = [test_row_from_program_with_input(program_and_input, 0)]; - let debug_info = TestRowsDebugInfo { - instruction: Divine(n), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_dup() { - let programs = [ - triton_program!(dup 0 halt), - triton_program!(dup 1 halt), - triton_program!(dup 2 halt), - triton_program!(dup 3 halt), - triton_program!(dup 4 halt), - triton_program!(dup 5 halt), - triton_program!(dup 6 halt), - triton_program!(dup 7 halt), - triton_program!(dup 8 halt), - triton_program!(dup 9 halt), - triton_program!(dup 10 halt), - triton_program!(dup 11 halt), - triton_program!(dup 12 halt), - triton_program!(dup 13 halt), - triton_program!(dup 14 halt), - triton_program!(dup 15 halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 0)); - - let debug_info = TestRowsDebugInfo { - instruction: Dup(OpStackElement::ST0), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_swap() { - let test_rows = (0..OpStackElement::COUNT) - .map(|i| triton_program!(swap {i} halt)) - .map(|program| test_row_from_program(program, 0)) - .collect_vec(); - let debug_info = TestRowsDebugInfo { - instruction: Swap(OpStackElement::ST0), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_skiz() { - let programs = [ - triton_program!(push 1 skiz halt), // ST0 is non-zero - triton_program!(push 0 skiz assert halt), // ST0 is zero, next instruction of size 1 - triton_program!(push 0 skiz push 1 halt), // ST0 is zero, next instruction of size 2 - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 1)); - let debug_info = TestRowsDebugInfo { - instruction: Skiz, - debug_cols_curr_row: vec![IP, NIA, ST0, HV5, HV4, HV3, HV2], - debug_cols_next_row: vec![IP], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_call() { - let programs = [triton_program!(call label label: halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 0)); - let debug_info = TestRowsDebugInfo { - instruction: Call(BFieldElement::default()), - debug_cols_curr_row: vec![IP, CI, NIA, JSP, JSO, JSD], - debug_cols_next_row: vec![IP, CI, NIA, JSP, JSO, JSD], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_return() { - let programs = [triton_program!(call label halt label: return)]; - let test_rows = programs.map(|program| test_row_from_program(program, 1)); - let debug_info = TestRowsDebugInfo { - instruction: Return, - debug_cols_curr_row: vec![IP, JSP, JSO, JSD], - debug_cols_next_row: vec![IP, JSP, JSO, JSD], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_recurse() { - let programs = - [triton_program!(push 2 call label halt label: push -1 add dup 0 skiz recurse return)]; - let test_rows = programs.map(|program| test_row_from_program(program, 6)); - let debug_info = TestRowsDebugInfo { - instruction: Recurse, - debug_cols_curr_row: vec![IP, JSP, JSO, JSD], - debug_cols_next_row: vec![IP, JSP, JSO, JSD], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_recurse_or_return() { - let program = triton_program! { - push 2 swap 6 - call loop halt - loop: - swap 5 push 1 add swap 5 - recurse_or_return - }; - let test_rows = [ - test_row_from_program(program.clone(), 7), // recurse - test_row_from_program(program, 12), // return - ]; - let debug_info = TestRowsDebugInfo { - instruction: RecurseOrReturn, - debug_cols_curr_row: vec![IP, JSP, JSO, JSD, ST5, ST6, HV4], - debug_cols_next_row: vec![IP, JSP, JSO, JSD], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_read_mem() { - let programs = [ - triton_program!(push 1 read_mem 1 push 0 eq assert assert halt), - triton_program!(push 2 read_mem 2 push 0 eq assert swap 1 push 2 eq assert halt), - triton_program!(push 3 read_mem 3 push 0 eq assert swap 2 push 3 eq assert halt), - triton_program!(push 4 read_mem 4 push 0 eq assert swap 3 push 4 eq assert halt), - triton_program!(push 5 read_mem 5 push 0 eq assert swap 4 push 5 eq assert halt), - ]; - let initial_ram = (1..=5) - .map(|i| (bfe!(i), bfe!(i))) - .collect::>(); - let non_determinism = NonDeterminism::default().with_ram(initial_ram); - let programs_with_input = programs.map(|program| { - ProgramAndInput::new(program).with_non_determinism(non_determinism.clone()) - }); - let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 1)); - let debug_info = TestRowsDebugInfo { - instruction: ReadMem(N1), - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0, ST1], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_write_mem() { - let push_10_elements = triton_asm![push 2; 10]; - let programs = [ - triton_program!({&push_10_elements} write_mem 1 halt), - triton_program!({&push_10_elements} write_mem 2 halt), - triton_program!({&push_10_elements} write_mem 3 halt), - triton_program!({&push_10_elements} write_mem 4 halt), - triton_program!({&push_10_elements} write_mem 5 halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 10)); - let debug_info = TestRowsDebugInfo { - instruction: WriteMem(N1), - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0, ST1], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_merkle_step() { - let programs = [ - triton_program!(push 2 swap 5 merkle_step halt), - triton_program!(push 3 swap 5 merkle_step halt), - ]; - let dummy_digest = Digest::new(bfe_array![1, 2, 3, 4, 5]); - let non_determinism = NonDeterminism::default().with_digests(vec![dummy_digest]); - let programs_with_input = programs.map(|program| { - ProgramAndInput::new(program).with_non_determinism(non_determinism.clone()) - }); - let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 2)); - - let debug_info = TestRowsDebugInfo { - instruction: MerkleStep, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, HV0, HV1, HV2, HV3, HV4, HV5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_merkle_step_mem() { - let sibling_digest = bfe_array![1, 2, 3, 4, 5]; - let acc_digest = bfe_array![11, 12, 13, 14, 15]; - let test_program = |node_index: u32| { - triton_program! { - push 42 // RAM pointer - push 1 // dummy - push {node_index} - push {acc_digest[0]} - push {acc_digest[1]} - push {acc_digest[2]} - push {acc_digest[3]} - push {acc_digest[4]} - merkle_step_mem - halt - } - }; - let mut ram = HashMap::new(); - ram.insert(bfe!(42), sibling_digest[0]); - ram.insert(bfe!(43), sibling_digest[1]); - ram.insert(bfe!(44), sibling_digest[2]); - ram.insert(bfe!(45), sibling_digest[3]); - ram.insert(bfe!(46), sibling_digest[4]); - let non_determinism = NonDeterminism::default().with_ram(ram); - - let node_indices = [2, 3]; - let test_rows = node_indices - .map(test_program) - .map(ProgramAndInput::new) - .map(|p| p.with_non_determinism(non_determinism.clone())) - .map(|p| test_row_from_program_with_input(p, 8)); - - let debug_info = TestRowsDebugInfo { - instruction: MerkleStepMem, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, HV0, HV1, HV2, HV3, HV4, HV5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_sponge_init() { - let programs = [triton_program!(sponge_init halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 0)); - let debug_info = TestRowsDebugInfo { - instruction: SpongeInit, - debug_cols_curr_row: vec![], - debug_cols_next_row: vec![], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_sponge_absorb() { - let push_10_zeros = triton_asm![push 0; 10]; - let push_10_ones = triton_asm![push 1; 10]; - let programs = [ - triton_program!(sponge_init {&push_10_zeros} sponge_absorb halt), - triton_program!(sponge_init {&push_10_ones} sponge_absorb halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 11)); - let debug_info = TestRowsDebugInfo { - instruction: SpongeAbsorb, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_sponge_absorb_mem() { - let programs = [triton_program!(sponge_init push 0 sponge_absorb_mem halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 2)); - let debug_info = TestRowsDebugInfo { - instruction: SpongeAbsorbMem, - debug_cols_curr_row: vec![ST0, HV0, HV1, HV2, HV3, HV4, HV5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_sponge_squeeze() { - let programs = [triton_program!(sponge_init sponge_squeeze halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 1)); - let debug_info = TestRowsDebugInfo { - instruction: SpongeSqueeze, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9, ST10], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_eq() { - let programs = [ - triton_program!(push 3 push 3 eq assert halt), - triton_program!(push 3 push 2 eq push 0 eq assert halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 2)); - let debug_info = TestRowsDebugInfo { - instruction: Eq, - debug_cols_curr_row: vec![ST0, ST1, HV0], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_split() { - let programs = [ - triton_program!(push -1 split halt), - triton_program!(push 0 split halt), - triton_program!(push 1 split halt), - triton_program!(push 2 split halt), - triton_program!(push 3 split halt), - // test pushing push 2^32 +- 1 - triton_program!(push 4294967295 split halt), - triton_program!(push 4294967296 split halt), - triton_program!(push 4294967297 split halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 1)); - let debug_info = TestRowsDebugInfo { - instruction: Split, - debug_cols_curr_row: vec![ST0, ST1, HV0], - debug_cols_next_row: vec![ST0, ST1, HV0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_lt() { - let programs = [ - triton_program!(push 3 push 3 lt push 0 eq assert halt), - triton_program!(push 3 push 2 lt push 1 eq assert halt), - triton_program!(push 2 push 3 lt push 0 eq assert halt), - triton_program!(push 512 push 513 lt push 0 eq assert halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 2)); - let debug_info = TestRowsDebugInfo { - instruction: Lt, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_and() { - let test_rows = [test_row_from_program( - triton_program!(push 5 push 12 and push 4 eq assert halt), - 2, - )]; - let debug_info = TestRowsDebugInfo { - instruction: And, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_xor() { - let test_rows = [test_row_from_program( - triton_program!(push 5 push 12 xor push 9 eq assert halt), - 2, - )]; - let debug_info = TestRowsDebugInfo { - instruction: Xor, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_log2floor() { - let programs = [ - triton_program!(push 1 log_2_floor push 0 eq assert halt), - triton_program!(push 2 log_2_floor push 1 eq assert halt), - triton_program!(push 3 log_2_floor push 1 eq assert halt), - triton_program!(push 4 log_2_floor push 2 eq assert halt), - triton_program!(push 5 log_2_floor push 2 eq assert halt), - triton_program!(push 6 log_2_floor push 2 eq assert halt), - triton_program!(push 7 log_2_floor push 2 eq assert halt), - triton_program!(push 8 log_2_floor push 3 eq assert halt), - triton_program!(push 9 log_2_floor push 3 eq assert halt), - triton_program!(push 10 log_2_floor push 3 eq assert halt), - triton_program!(push 11 log_2_floor push 3 eq assert halt), - triton_program!(push 12 log_2_floor push 3 eq assert halt), - triton_program!(push 13 log_2_floor push 3 eq assert halt), - triton_program!(push 14 log_2_floor push 3 eq assert halt), - triton_program!(push 15 log_2_floor push 3 eq assert halt), - triton_program!(push 16 log_2_floor push 4 eq assert halt), - triton_program!(push 17 log_2_floor push 4 eq assert halt), - ]; - - let test_rows = programs.map(|program| test_row_from_program(program, 1)); - let debug_info = TestRowsDebugInfo { - instruction: Log2Floor, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_pow() { - let programs = [ - triton_program!(push 0 push 0 pow push 1 eq assert halt), - triton_program!(push 1 push 0 pow push 0 eq assert halt), - triton_program!(push 2 push 0 pow push 0 eq assert halt), - triton_program!(push 0 push 1 pow push 1 eq assert halt), - triton_program!(push 1 push 1 pow push 1 eq assert halt), - triton_program!(push 2 push 1 pow push 1 eq assert halt), - triton_program!(push 0 push 2 pow push 1 eq assert halt), - triton_program!(push 1 push 2 pow push 2 eq assert halt), - triton_program!(push 2 push 2 pow push 4 eq assert halt), - triton_program!(push 3 push 2 pow push 8 eq assert halt), - triton_program!(push 4 push 2 pow push 16 eq assert halt), - triton_program!(push 5 push 2 pow push 32 eq assert halt), - triton_program!(push 0 push 3 pow push 1 eq assert halt), - triton_program!(push 1 push 3 pow push 3 eq assert halt), - triton_program!(push 2 push 3 pow push 9 eq assert halt), - triton_program!(push 3 push 3 pow push 27 eq assert halt), - triton_program!(push 4 push 3 pow push 81 eq assert halt), - triton_program!(push 0 push 17 pow push 1 eq assert halt), - triton_program!(push 1 push 17 pow push 17 eq assert halt), - triton_program!(push 2 push 17 pow push 289 eq assert halt), - ]; - - let test_rows = programs.map(|program| test_row_from_program(program, 2)); - let debug_info = TestRowsDebugInfo { - instruction: Pow, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_div_mod() { - let programs = [ - triton_program!(push 2 push 3 div_mod push 1 eq assert push 1 eq assert halt), - triton_program!(push 3 push 7 div_mod push 1 eq assert push 2 eq assert halt), - triton_program!(push 4 push 7 div_mod push 3 eq assert push 1 eq assert halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 2)); - let debug_info = TestRowsDebugInfo { - instruction: DivMod, - debug_cols_curr_row: vec![ST0, ST1], - debug_cols_next_row: vec![ST0, ST1], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn division_by_zero_is_impossible() { - let program = ProgramAndInput::new(triton_program! { div_mod }); - let err = program.run().unwrap_err(); - assert_eq!(DivisionByZero, err.source); - } - - #[test] - fn transition_constraints_for_instruction_xx_add() { - let programs = [ - triton_program!(push 5 push 6 push 7 push 8 push 9 push 10 xx_add halt), - triton_program!(push 2 push 3 push 4 push -2 push -3 push -4 xx_add halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 6)); - let debug_info = TestRowsDebugInfo { - instruction: XxAdd, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_xx_mul() { - let programs = [ - triton_program!(push 5 push 6 push 7 push 8 push 9 push 10 xx_mul halt), - triton_program!(push 2 push 3 push 4 push -2 push -3 push -4 xx_mul halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 6)); - let debug_info = TestRowsDebugInfo { - instruction: XxMul, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_x_invert() { - let programs = [ - triton_program!(push 5 push 6 push 7 x_invert halt), - triton_program!(push -2 push -3 push -4 x_invert halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 3)); - let debug_info = TestRowsDebugInfo { - instruction: XInvert, - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_xb_mul() { - let programs = [ - triton_program!(push 5 push 6 push 7 push 2 xb_mul halt), - triton_program!(push 2 push 3 push 4 push -2 xb_mul halt), - ]; - let test_rows = programs.map(|program| test_row_from_program(program, 4)); - let debug_info = TestRowsDebugInfo { - instruction: XbMul, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, OpStackPointer, HV0], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, OpStackPointer, HV0], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[proptest(cases = 20)] - fn transition_constraints_for_instruction_read_io_n(#[strategy(arb())] n: NumberOfWords) { - let program = triton_program! {read_io {n} halt}; - - let public_input = (1..=16).map(|i| bfe!(i)).collect_vec(); - let program_and_input = ProgramAndInput::new(program).with_input(public_input); - let test_rows = [test_row_from_program_with_input(program_and_input, 0)]; - let debug_info = TestRowsDebugInfo { - instruction: ReadIo(n), - debug_cols_curr_row: vec![ST0, ST1, ST2], - debug_cols_next_row: vec![ST0, ST1, ST2], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[proptest(cases = 20)] - fn transition_constraints_for_instruction_write_io_n(#[strategy(arb())] n: NumberOfWords) { - let program = triton_program! {divine 5 write_io {n} halt}; - - let non_determinism = (1..=16).map(|b| bfe!(b)).collect_vec(); - let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); - let test_rows = [test_row_from_program_with_input(program_and_input, 1)]; - let debug_info = TestRowsDebugInfo { - instruction: WriteIo(n), - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4, ST5], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_xb_dot_step() { - let program = triton_program! { - push 10 push 20 push 30 // accumulator `[30, 20, 10]` - push 96 // pointer to extension-field element `[3, 5, 7]` - push 42 // pointer to base-field element `2` - xb_dot_step - push 43 eq assert - push 99 eq assert - push {30 + 2 * 3} eq assert - push {20 + 2 * 5} eq assert - push {10 + 2 * 7} eq assert - halt - }; - - let mut ram = HashMap::new(); - ram.insert(bfe!(42), bfe!(2)); - ram.insert(bfe!(96), bfe!(3)); - ram.insert(bfe!(97), bfe!(5)); - ram.insert(bfe!(98), bfe!(7)); - let non_determinism = NonDeterminism::default().with_ram(ram); - let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); - let test_rows = [test_row_from_program_with_input(program_and_input, 5)]; - let debug_info = TestRowsDebugInfo { - instruction: XbDotStep, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn transition_constraints_for_instruction_xx_dot_step() { - let operand_0 = xfe!([3, 5, 7]); - let operand_1 = xfe!([11, 13, 17]); - let product = operand_0 * operand_1; - - let program = triton_program! { - push 10 push 20 push 30 // accumulator `[30, 20, 10]` - push 96 // pointer to `operand_1` - push 42 // pointer to `operand_0` - xx_dot_step - push 45 eq assert - push 99 eq assert - push {bfe!(30) + product.coefficients[0]} eq assert - push {bfe!(20) + product.coefficients[1]} eq assert - push {bfe!(10) + product.coefficients[2]} eq assert - halt - }; - - let mut ram = HashMap::new(); - ram.insert(bfe!(42), operand_0.coefficients[0]); - ram.insert(bfe!(43), operand_0.coefficients[1]); - ram.insert(bfe!(44), operand_0.coefficients[2]); - ram.insert(bfe!(96), operand_1.coefficients[0]); - ram.insert(bfe!(97), operand_1.coefficients[1]); - ram.insert(bfe!(98), operand_1.coefficients[2]); - let non_determinism = NonDeterminism::default().with_ram(ram); - let program_and_input = ProgramAndInput::new(program).with_non_determinism(non_determinism); - let test_rows = [test_row_from_program_with_input(program_and_input, 5)]; - let debug_info = TestRowsDebugInfo { - instruction: XxDotStep, - debug_cols_curr_row: vec![ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3, HV4, HV5], - debug_cols_next_row: vec![ST0, ST1, ST2, ST3, ST4], - }; - assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); - } - - #[test] - fn instruction_deselector_gives_0_for_all_other_instructions() { - let circuit_builder = ConstraintCircuitBuilder::new(); - - let mut master_base_table = Array2::zeros([2, NUM_BASE_COLUMNS]); - let master_ext_table = Array2::zeros([2, NUM_EXT_COLUMNS]); - - // For this test, dummy challenges suffice to evaluate the constraints. - let dummy_challenges = Challenges::default().challenges; - for instruction in ALL_INSTRUCTIONS { - use ProcessorBaseTableColumn::*; - let deselector = ExtProcessorTable::instruction_deselector_current_row( - &circuit_builder, - instruction, - ); - - println!("\n\nThe Deselector for instruction {instruction} is:\n{deselector}",); - - // Negative tests - for other_instruction in ALL_INSTRUCTIONS - .into_iter() - .filter(|other_instruction| *other_instruction != instruction) - { - let mut curr_row = master_base_table.slice_mut(s![0, ..]); - curr_row[IB0.master_base_table_index()] = other_instruction.ib(InstructionBit::IB0); - curr_row[IB1.master_base_table_index()] = other_instruction.ib(InstructionBit::IB1); - curr_row[IB2.master_base_table_index()] = other_instruction.ib(InstructionBit::IB2); - curr_row[IB3.master_base_table_index()] = other_instruction.ib(InstructionBit::IB3); - curr_row[IB4.master_base_table_index()] = other_instruction.ib(InstructionBit::IB4); - curr_row[IB5.master_base_table_index()] = other_instruction.ib(InstructionBit::IB5); - curr_row[IB6.master_base_table_index()] = other_instruction.ib(InstructionBit::IB6); - let result = deselector.clone().consume().evaluate( - master_base_table.view(), - master_ext_table.view(), - &dummy_challenges, - ); - - assert!( - result.is_zero(), - "Deselector for {instruction} should return 0 for all other instructions, \ - including {other_instruction} whose opcode is {}", - other_instruction.opcode() - ) - } - - // Positive tests - let mut curr_row = master_base_table.slice_mut(s![0, ..]); - curr_row[IB0.master_base_table_index()] = instruction.ib(InstructionBit::IB0); - curr_row[IB1.master_base_table_index()] = instruction.ib(InstructionBit::IB1); - curr_row[IB2.master_base_table_index()] = instruction.ib(InstructionBit::IB2); - curr_row[IB3.master_base_table_index()] = instruction.ib(InstructionBit::IB3); - curr_row[IB4.master_base_table_index()] = instruction.ib(InstructionBit::IB4); - curr_row[IB5.master_base_table_index()] = instruction.ib(InstructionBit::IB5); - curr_row[IB6.master_base_table_index()] = instruction.ib(InstructionBit::IB6); - let result = deselector.consume().evaluate( - master_base_table.view(), - master_ext_table.view(), - &dummy_challenges, - ); - assert!( - !result.is_zero(), - "Deselector for {instruction} should be non-zero when CI is {}", - instruction.opcode() - ) - } - } - - #[test] - fn print_number_and_degrees_of_transition_constraints_for_all_instructions() { - println!(); - println!("| Instruction | #polys | max deg | Degrees"); - println!("|:--------------------|-------:|--------:|:------------"); - let circuit_builder = ConstraintCircuitBuilder::new(); - for instruction in ALL_INSTRUCTIONS { - let constraints = ExtProcessorTable::transition_constraints_for_instruction( - &circuit_builder, - instruction, - ); - let degrees = constraints - .iter() - .map(|circuit| circuit.clone().consume().degree()) - .collect_vec(); - let max_degree = degrees.iter().max().unwrap_or(&0); - let degrees_str = degrees.iter().join(", "); - println!( - "| {:<19} | {:>6} | {max_degree:>7} | [{degrees_str}]", - format!("{instruction}"), - constraints.len(), - ); - } - } - - #[test] - fn opcode_decomposition_for_skiz_is_unique() { - let max_value_of_skiz_constraint_for_nia_decomposition = - (3 << 7) * (3 << 5) * (3 << 3) * (3 << 1) * 2; - for instruction in Instruction::iter() { - assert!( - instruction.opcode() < max_value_of_skiz_constraint_for_nia_decomposition, - "Opcode for {instruction} is too high." - ); - } - } - - #[test] - fn range_check_for_skiz_is_as_efficient_as_possible() { - let range_check_constraints = - ExtProcessorTable::next_instruction_range_check_constraints_for_instruction_skiz( - &ConstraintCircuitBuilder::new(), - ); - let range_check_constraints = range_check_constraints.iter(); - let all_degrees = range_check_constraints.map(|c| c.clone().consume().degree()); - let max_constraint_degree = all_degrees.max().unwrap_or(0); - assert!( - AIR_TARGET_DEGREE <= max_constraint_degree, - "Can the range check constraints be of a higher degree, saving columns?" - ); - } - - #[test] - fn helper_variables_in_bounds() { - let circuit_builder = ConstraintCircuitBuilder::new(); - for index in 0..NUM_HELPER_VARIABLE_REGISTERS { - ExtProcessorTable::helper_variable(&circuit_builder, index); - } - } - - #[test] - #[should_panic(expected = "out of bounds")] - fn helper_variables_out_of_bounds() { - let index = thread_rng().gen_range(NUM_HELPER_VARIABLE_REGISTERS..usize::MAX); - let circuit_builder = ConstraintCircuitBuilder::new(); - ExtProcessorTable::helper_variable(&circuit_builder, index); - } - - #[test] - fn indicator_polynomial_in_bounds() { - let circuit_builder = ConstraintCircuitBuilder::new(); - for index in 0..16 { - ExtProcessorTable::indicator_polynomial(&circuit_builder, index); - } - } - - #[test] - #[should_panic(expected = "out of bounds")] - fn indicator_polynomial_out_of_bounds() { - let index = thread_rng().gen_range(16..usize::MAX); - let circuit_builder = ConstraintCircuitBuilder::new(); - ExtProcessorTable::indicator_polynomial(&circuit_builder, index); - } - - #[proptest] - fn indicator_polynomial_is_one_on_indicated_index_and_zero_on_all_other_indices( - #[strategy(0_usize..16)] indicator_poly_index: usize, - #[strategy(0_usize..16)] query_index: usize, - ) { - // Unfortunately, setting up the query index requires a pretty elaborate setup. - let program = triton_program!(dup {query_index} halt); - let input = PublicInput::default(); - let non_determinism = NonDeterminism::default(); - let vm_state = VMState::new(&program, input, non_determinism); - let helper_variables = vm_state.derive_helper_variables(); - - let mut base_table = Array2::ones([2, NUM_BASE_COLUMNS]); - base_table[[0, HV0.master_base_table_index()]] = helper_variables[0]; - base_table[[0, HV1.master_base_table_index()]] = helper_variables[1]; - base_table[[0, HV2.master_base_table_index()]] = helper_variables[2]; - base_table[[0, HV3.master_base_table_index()]] = helper_variables[3]; - base_table[[0, HV4.master_base_table_index()]] = helper_variables[4]; - base_table[[0, HV5.master_base_table_index()]] = helper_variables[5]; - - let builder = ConstraintCircuitBuilder::new(); - let indicator_poly = - ExtProcessorTable::indicator_polynomial(&builder, indicator_poly_index); - let indicator_poly = indicator_poly.consume(); - - let ext_table = Array2::ones([2, NUM_EXT_COLUMNS]); - let challenges = Challenges::default().challenges; - let evaluation = indicator_poly.evaluate(base_table.view(), ext_table.view(), &challenges); - - if indicator_poly_index == query_index { - prop_assert_eq!(xfe!(1), evaluation); - } else { - prop_assert_eq!(xfe!(0), evaluation); - } - } - - #[test] - fn can_get_op_stack_column_for_in_range_index() { - for index in 0..OpStackElement::COUNT { - let _ = ProcessorTable::op_stack_column_by_index(index); - } - } - - #[proptest] - #[should_panic(expected = "[0, 15]")] - fn cannot_get_op_stack_column_for_out_of_range_index( - #[strategy(OpStackElement::COUNT..)] index: usize, - ) { - let _ = ProcessorTable::op_stack_column_by_index(index); - } - - #[test] - fn can_get_stack_weight_for_in_range_index() { - for index in 0..OpStackElement::COUNT { - let _ = ExtProcessorTable::stack_weight_by_index(index); - } - } - - #[proptest] - #[should_panic(expected = "[0, 15]")] - fn cannot_get_stack_weight_for_out_of_range_index( - #[strategy(OpStackElement::COUNT..)] index: usize, - ) { - let _ = ExtProcessorTable::stack_weight_by_index(index); - } - - #[proptest] - fn constructing_factor_for_op_stack_table_running_product_never_panics( - #[strategy(vec(arb(), BASE_WIDTH))] previous_row: Vec, - #[strategy(vec(arb(), BASE_WIDTH))] current_row: Vec, - #[strategy(arb())] challenges: Challenges, - ) { - let previous_row = Array1::from(previous_row); - let current_row = Array1::from(current_row); - let _ = ProcessorTable::factor_for_op_stack_table_running_product( - previous_row.view(), - current_row.view(), - &challenges, - ); - } - - #[proptest] - fn constructing_factor_for_ram_table_running_product_never_panics( - #[strategy(vec(arb(), BASE_WIDTH))] previous_row: Vec, - #[strategy(vec(arb(), BASE_WIDTH))] current_row: Vec, - #[strategy(arb())] challenges: Challenges, - ) { - let previous_row = Array1::from(previous_row); - let current_row = Array1::from(current_row); - let _ = ProcessorTable::factor_for_ram_table_running_product( - previous_row.view(), - current_row.view(), - &challenges, - ); - } - - #[proptest] - fn xx_product_is_accurate( - #[strategy(arb())] a: XFieldElement, - #[strategy(arb())] b: XFieldElement, - ) { - let circuit_builder = ConstraintCircuitBuilder::new(); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(base_row); - - let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); - let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); - let challenges = Challenges::default().challenges; - base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; - base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; - base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; - base_table[[0, ST3.master_base_table_index()]] = b.coefficients[0]; - base_table[[0, ST4.master_base_table_index()]] = b.coefficients[1]; - base_table[[0, ST5.master_base_table_index()]] = b.coefficients[2]; - - let [c0, c1, c2] = ExtProcessorTable::xx_product([x0, x1, x2], [y0, y1, y2]) - .map(|c| c.consume()) - .map(|c| c.evaluate(base_table.view(), ext_table.view(), &challenges)); - - let c = a * b; - prop_assert_eq!(c.coefficients[0], c0.coefficients[0]); - prop_assert_eq!(c.coefficients[1], c1.coefficients[0]); - prop_assert_eq!(c.coefficients[2], c2.coefficients[0]); - } - - #[proptest] - fn xb_product_is_accurate( - #[strategy(arb())] a: XFieldElement, - #[strategy(arb())] b: BFieldElement, - ) { - let circuit_builder = ConstraintCircuitBuilder::new(); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let [x0, x1, x2, y] = [ST0, ST1, ST2, ST3].map(base_row); - - let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); - let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); - let challenges = Challenges::default().challenges; - base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; - base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; - base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; - base_table[[0, ST3.master_base_table_index()]] = b; - - let [c0, c1, c2] = ExtProcessorTable::xb_product([x0, x1, x2], y) - .map(|c| c.consume()) - .map(|c| c.evaluate(base_table.view(), ext_table.view(), &challenges)); - - let c = a * b; - prop_assert_eq!(c.coefficients[0], c0.coefficients[0]); - prop_assert_eq!(c.coefficients[1], c1.coefficients[0]); - prop_assert_eq!(c.coefficients[2], c2.coefficients[0]); - } -} diff --git a/triton-vm/src/table/program.rs b/triton-vm/src/table/program.rs new file mode 100644 index 000000000..86e77f533 --- /dev/null +++ b/triton-vm/src/table/program.rs @@ -0,0 +1,259 @@ +use std::cmp::Ordering; + +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::EvalArg; +use air::cross_table_argument::LookupArg; +use air::table::program::ProgramTable; +use air::table::TableId; +use air::table_column::ProgramBaseTableColumn::*; +use air::table_column::ProgramExtTableColumn::*; +use air::table_column::*; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; +use ndarray::s; +use ndarray::Array1; +use ndarray::ArrayView1; +use ndarray::ArrayView2; +use ndarray::ArrayViewMut2; +use num_traits::One; +use num_traits::Zero; +use strum::EnumCount; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::profiler::profiler; +use crate::table::TraceTable; + +impl TraceTable for ProgramTable { + type FillParam = (); + type FillReturnInfo = (); + + fn fill(mut program_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { + let max_index_in_chunk = bfe!(Tip5::RATE as u64 - 1); + + let instructions = aet.program.to_bwords(); + let program_len = instructions.len(); + let padded_program_len = aet.height_of_table(TableId::Program); + + let one_iter = bfe_array![1].into_iter(); + let zero_iter = bfe_array![0].into_iter(); + let padding_iter = one_iter.chain(zero_iter.cycle()); + let padded_instructions = instructions.into_iter().chain(padding_iter); + let padded_instructions = padded_instructions.take(padded_program_len); + + for (row_idx, instruction) in padded_instructions.enumerate() { + let address = u64::try_from(row_idx).unwrap(); + let address = bfe!(address); + + let lookup_multiplicity = match row_idx.cmp(&program_len) { + Ordering::Less => aet.instruction_multiplicities[row_idx], + _ => 0, + }; + let lookup_multiplicity = bfe!(lookup_multiplicity); + let index_in_chunk = bfe!((row_idx % Tip5::RATE) as u64); + + let max_minus_index_in_chunk_inv = + (max_index_in_chunk - index_in_chunk).inverse_or_zero(); + + let is_hash_input_padding = match row_idx.cmp(&program_len) { + Ordering::Less => bfe!(0), + _ => bfe!(1), + }; + + let mut current_row = program_table.row_mut(row_idx); + current_row[Address.base_table_index()] = address; + current_row[Instruction.base_table_index()] = instruction; + current_row[LookupMultiplicity.base_table_index()] = lookup_multiplicity; + current_row[IndexInChunk.base_table_index()] = index_in_chunk; + current_row[MaxMinusIndexInChunkInv.base_table_index()] = max_minus_index_in_chunk_inv; + current_row[IsHashInputPadding.base_table_index()] = is_hash_input_padding; + } + } + + fn pad(mut program_table: ArrayViewMut2, program_len: usize) { + let addresses = + (program_len..program_table.nrows()).map(|a| bfe!(u64::try_from(a).unwrap())); + let addresses = Array1::from_iter(addresses); + let address_column = program_table.slice_mut(s![program_len.., Address.base_table_index()]); + addresses.move_into(address_column); + + let indices_in_chunks = (program_len..program_table.nrows()) + .map(|idx| idx % Tip5::RATE) + .map(|ac| bfe!(u64::try_from(ac).unwrap())); + let indices_in_chunks = Array1::from_iter(indices_in_chunks); + let index_in_chunk_column = + program_table.slice_mut(s![program_len.., IndexInChunk.base_table_index()]); + indices_in_chunks.move_into(index_in_chunk_column); + + let max_minus_indices_in_chunks_inverses = (program_len..program_table.nrows()) + .map(|idx| Tip5::RATE - 1 - (idx % Tip5::RATE)) + .map(|ac| BFieldElement::new(ac.try_into().unwrap())) + .map(|bfe| bfe.inverse_or_zero()); + let max_minus_indices_in_chunks_inverses = + Array1::from_iter(max_minus_indices_in_chunks_inverses); + let max_minus_index_in_chunk_inv_column = program_table.slice_mut(s![ + program_len.., + MaxMinusIndexInChunkInv.base_table_index() + ]); + max_minus_indices_in_chunks_inverses.move_into(max_minus_index_in_chunk_inv_column); + + program_table + .slice_mut(s![program_len.., IsHashInputPadding.base_table_index()]) + .fill(BFieldElement::one()); + program_table + .slice_mut(s![program_len.., IsTablePadding.base_table_index()]) + .fill(BFieldElement::one()); + } + + fn extend( + main_table: ArrayView2, + mut aux_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "program table"); + assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(main_table.nrows(), aux_table.nrows()); + + let mut instruction_lookup_log_derivative = LookupArg::default_initial(); + let mut prepare_chunk_running_evaluation = EvalArg::default_initial(); + let mut send_chunk_running_evaluation = EvalArg::default_initial(); + + for (idx, consecutive_rows) in main_table + .windows([2, Self::MainColumn::COUNT]) + .into_iter() + .enumerate() + { + let row = consecutive_rows.row(0); + let next_row = consecutive_rows.row(1); + let mut extension_row = aux_table.row_mut(idx); + + // In the Program Table, the logarithmic derivative for the instruction lookup + // argument does record the initial in the first row, as an exception to all other + // table-linking arguments. + // This is necessary because an instruction's potential argument, or else the next + // instruction, is recorded in the next row. To be able to check correct initialization + // of the logarithmic derivative, both the current and the next row must be accessible + // to the constraint. Only transition constraints can access both rows. Hence, the + // initial value of the logarithmic derivative must be independent of the second row. + // The logarithmic derivative's final value, allowing for a meaningful cross-table + // argument, is recorded in the first padding row. This row is guaranteed to exist + // due to the hash-input padding mechanics. + extension_row[InstructionLookupServerLogDerivative.ext_table_index()] = + instruction_lookup_log_derivative; + + instruction_lookup_log_derivative = update_instruction_lookup_log_derivative( + challenges, + row, + next_row, + instruction_lookup_log_derivative, + ); + prepare_chunk_running_evaluation = update_prepare_chunk_running_evaluation( + row, + challenges, + prepare_chunk_running_evaluation, + ); + send_chunk_running_evaluation = update_send_chunk_running_evaluation( + row, + challenges, + send_chunk_running_evaluation, + prepare_chunk_running_evaluation, + ); + + extension_row[PrepareChunkRunningEvaluation.ext_table_index()] = + prepare_chunk_running_evaluation; + extension_row[SendChunkRunningEvaluation.ext_table_index()] = + send_chunk_running_evaluation; + } + + // special treatment for the last row + let base_rows_iter = main_table.rows().into_iter(); + let ext_rows_iter = aux_table.rows_mut().into_iter(); + let last_base_row = base_rows_iter.last().unwrap(); + let mut last_ext_row = ext_rows_iter.last().unwrap(); + + prepare_chunk_running_evaluation = update_prepare_chunk_running_evaluation( + last_base_row, + challenges, + prepare_chunk_running_evaluation, + ); + send_chunk_running_evaluation = update_send_chunk_running_evaluation( + last_base_row, + challenges, + send_chunk_running_evaluation, + prepare_chunk_running_evaluation, + ); + + last_ext_row[InstructionLookupServerLogDerivative.ext_table_index()] = + instruction_lookup_log_derivative; + last_ext_row[PrepareChunkRunningEvaluation.ext_table_index()] = + prepare_chunk_running_evaluation; + last_ext_row[SendChunkRunningEvaluation.ext_table_index()] = send_chunk_running_evaluation; + + profiler!(stop "program table"); + } +} + +fn update_instruction_lookup_log_derivative( + challenges: &Challenges, + row: ArrayView1, + next_row: ArrayView1, + instruction_lookup_log_derivative: XFieldElement, +) -> XFieldElement { + if row[IsHashInputPadding.base_table_index()].is_one() { + return instruction_lookup_log_derivative; + } + instruction_lookup_log_derivative + + instruction_lookup_log_derivative_summand(row, next_row, challenges) +} + +fn instruction_lookup_log_derivative_summand( + row: ArrayView1, + next_row: ArrayView1, + challenges: &Challenges, +) -> XFieldElement { + let compressed_row = row[Address.base_table_index()] * challenges[ProgramAddressWeight] + + row[Instruction.base_table_index()] * challenges[ProgramInstructionWeight] + + next_row[Instruction.base_table_index()] * challenges[ProgramNextInstructionWeight]; + (challenges[InstructionLookupIndeterminate] - compressed_row).inverse() + * row[LookupMultiplicity.base_table_index()] +} + +fn update_prepare_chunk_running_evaluation( + row: ArrayView1, + challenges: &Challenges, + prepare_chunk_running_evaluation: XFieldElement, +) -> XFieldElement { + let running_evaluation_resets = row[IndexInChunk.base_table_index()].is_zero(); + let prepare_chunk_running_evaluation = if running_evaluation_resets { + EvalArg::default_initial() + } else { + prepare_chunk_running_evaluation + }; + + prepare_chunk_running_evaluation * challenges[ProgramAttestationPrepareChunkIndeterminate] + + row[Instruction.base_table_index()] +} + +fn update_send_chunk_running_evaluation( + row: ArrayView1, + challenges: &Challenges, + send_chunk_running_evaluation: XFieldElement, + prepare_chunk_running_evaluation: XFieldElement, +) -> XFieldElement { + let index_in_chunk = row[IndexInChunk.base_table_index()]; + let is_table_padding_row = row[IsTablePadding.base_table_index()].is_one(); + let max_index_in_chunk = Tip5::RATE as u64 - 1; + let running_evaluation_needs_update = + !is_table_padding_row && index_in_chunk.value() == max_index_in_chunk; + + if !running_evaluation_needs_update { + return send_chunk_running_evaluation; + } + + send_chunk_running_evaluation * challenges[ProgramAttestationSendChunkIndeterminate] + + prepare_chunk_running_evaluation +} diff --git a/triton-vm/src/table/program_table.rs b/triton-vm/src/table/program_table.rs deleted file mode 100644 index 568fd8884..000000000 --- a/triton-vm/src/table/program_table.rs +++ /dev/null @@ -1,510 +0,0 @@ -use std::cmp::Ordering; - -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use ndarray::s; -use ndarray::Array1; -use ndarray::ArrayView1; -use ndarray::ArrayView2; -use ndarray::ArrayViewMut2; -use num_traits::One; -use num_traits::Zero; -use strum::EnumCount; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::EvalArg; -use crate::table::cross_table_argument::LookupArg; -use crate::table::master_table::TableId; -use crate::table::table_column::ProgramBaseTableColumn::*; -use crate::table::table_column::ProgramExtTableColumn::*; -use crate::table::table_column::*; - -pub const BASE_WIDTH: usize = ProgramBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = ProgramExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ProgramTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtProgramTable; - -impl ExtProgramTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let x_constant = |xfe| circuit_builder.x_constant(xfe); - let base_row = |col: ProgramBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let ext_row = |col: ProgramExtTableColumn| { - circuit_builder.input(ExtRow(col.master_ext_table_index())) - }; - - let address = base_row(Address); - let instruction = base_row(Instruction); - let index_in_chunk = base_row(IndexInChunk); - let is_hash_input_padding = base_row(IsHashInputPadding); - let instruction_lookup_log_derivative = ext_row(InstructionLookupServerLogDerivative); - let prepare_chunk_running_evaluation = ext_row(PrepareChunkRunningEvaluation); - let send_chunk_running_evaluation = ext_row(SendChunkRunningEvaluation); - - let lookup_arg_initial = x_constant(LookupArg::default_initial()); - let eval_arg_initial = x_constant(EvalArg::default_initial()); - - let program_attestation_prepare_chunk_indeterminate = - challenge(ProgramAttestationPrepareChunkIndeterminate); - - let first_address_is_zero = address; - let index_in_chunk_is_zero = index_in_chunk; - let hash_input_padding_indicator_is_zero = is_hash_input_padding; - - let instruction_lookup_log_derivative_is_initialized_correctly = - instruction_lookup_log_derivative - lookup_arg_initial; - - let prepare_chunk_running_evaluation_has_absorbed_first_instruction = - prepare_chunk_running_evaluation - - eval_arg_initial.clone() * program_attestation_prepare_chunk_indeterminate - - instruction; - - let send_chunk_running_evaluation_is_default_initial = - send_chunk_running_evaluation - eval_arg_initial; - - vec![ - first_address_is_zero, - index_in_chunk_is_zero, - hash_input_padding_indicator_is_zero, - instruction_lookup_log_derivative_is_initialized_correctly, - prepare_chunk_running_evaluation_has_absorbed_first_instruction, - send_chunk_running_evaluation_is_default_initial, - ] - } - - pub fn consistency_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let base_row = |col: ProgramBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - - let one = constant(1); - let max_index_in_chunk = constant((Tip5::RATE - 1).try_into().unwrap()); - - let index_in_chunk = base_row(IndexInChunk); - let max_minus_index_in_chunk_inv = base_row(MaxMinusIndexInChunkInv); - let is_hash_input_padding = base_row(IsHashInputPadding); - let is_table_padding = base_row(IsTablePadding); - - let max_minus_index_in_chunk = max_index_in_chunk - index_in_chunk; - let max_minus_index_in_chunk_inv_is_zero_or_the_inverse_of_max_minus_index_in_chunk = - (one.clone() - max_minus_index_in_chunk.clone() * max_minus_index_in_chunk_inv.clone()) - * max_minus_index_in_chunk_inv.clone(); - let max_minus_index_in_chunk_is_zero_or_the_inverse_of_max_minus_index_in_chunk_inv = - (one.clone() - max_minus_index_in_chunk.clone() * max_minus_index_in_chunk_inv) - * max_minus_index_in_chunk; - - let is_hash_input_padding_is_bit = - is_hash_input_padding.clone() * (is_hash_input_padding - one.clone()); - let is_table_padding_is_bit = is_table_padding.clone() * (is_table_padding - one); - - vec![ - max_minus_index_in_chunk_inv_is_zero_or_the_inverse_of_max_minus_index_in_chunk, - max_minus_index_in_chunk_is_zero_or_the_inverse_of_max_minus_index_in_chunk_inv, - is_hash_input_padding_is_bit, - is_table_padding_is_bit, - ] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let constant = |c: u64| circuit_builder.b_constant(c); - - let current_base_row = |col: ProgramBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProgramBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let current_ext_row = |col: ProgramExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProgramExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let one = constant(1); - let rate_minus_one = constant(u64::try_from(Tip5::RATE).unwrap() - 1); - - let prepare_chunk_indeterminate = challenge(ProgramAttestationPrepareChunkIndeterminate); - let send_chunk_indeterminate = challenge(ProgramAttestationSendChunkIndeterminate); - - let address = current_base_row(Address); - let instruction = current_base_row(Instruction); - let lookup_multiplicity = current_base_row(LookupMultiplicity); - let index_in_chunk = current_base_row(IndexInChunk); - let max_minus_index_in_chunk_inv = current_base_row(MaxMinusIndexInChunkInv); - let is_hash_input_padding = current_base_row(IsHashInputPadding); - let is_table_padding = current_base_row(IsTablePadding); - let log_derivative = current_ext_row(InstructionLookupServerLogDerivative); - let prepare_chunk_running_evaluation = current_ext_row(PrepareChunkRunningEvaluation); - let send_chunk_running_evaluation = current_ext_row(SendChunkRunningEvaluation); - - let address_next = next_base_row(Address); - let instruction_next = next_base_row(Instruction); - let index_in_chunk_next = next_base_row(IndexInChunk); - let max_minus_index_in_chunk_inv_next = next_base_row(MaxMinusIndexInChunkInv); - let is_hash_input_padding_next = next_base_row(IsHashInputPadding); - let is_table_padding_next = next_base_row(IsTablePadding); - let log_derivative_next = next_ext_row(InstructionLookupServerLogDerivative); - let prepare_chunk_running_evaluation_next = next_ext_row(PrepareChunkRunningEvaluation); - let send_chunk_running_evaluation_next = next_ext_row(SendChunkRunningEvaluation); - - let address_increases_by_one = address_next - (address.clone() + one.clone()); - let is_table_padding_is_0_or_remains_unchanged = - is_table_padding.clone() * (is_table_padding_next.clone() - is_table_padding); - - let index_in_chunk_cycles_correctly = (one.clone() - - max_minus_index_in_chunk_inv.clone() - * (rate_minus_one.clone() - index_in_chunk.clone())) - * index_in_chunk_next.clone() - + max_minus_index_in_chunk_inv.clone() - * (index_in_chunk_next.clone() - index_in_chunk.clone() - one.clone()); - - let hash_input_indicator_is_0_or_remains_unchanged = - is_hash_input_padding.clone() * (is_hash_input_padding_next.clone() - one.clone()); - - let first_hash_input_padding_is_1 = (is_hash_input_padding.clone() - one.clone()) - * is_hash_input_padding_next - * (instruction_next.clone() - one.clone()); - - let hash_input_padding_is_0_after_the_first_1 = - is_hash_input_padding.clone() * instruction_next.clone(); - - let next_row_is_table_padding_row = is_table_padding_next.clone() - one.clone(); - let table_padding_starts_when_hash_input_padding_is_active_and_index_in_chunk_is_zero = - is_hash_input_padding.clone() - * (one.clone() - - max_minus_index_in_chunk_inv.clone() - * (rate_minus_one.clone() - index_in_chunk.clone())) - * next_row_is_table_padding_row.clone(); - - let log_derivative_remains = log_derivative_next.clone() - log_derivative.clone(); - let compressed_row = challenge(ProgramAddressWeight) * address - + challenge(ProgramInstructionWeight) * instruction - + challenge(ProgramNextInstructionWeight) * instruction_next.clone(); - - let indeterminate = challenge(InstructionLookupIndeterminate); - let log_derivative_updates = (log_derivative_next - log_derivative) - * (indeterminate - compressed_row) - - lookup_multiplicity; - let log_derivative_updates_if_and_only_if_not_a_padding_row = - (one.clone() - is_hash_input_padding.clone()) * log_derivative_updates - + is_hash_input_padding * log_derivative_remains; - - let prepare_chunk_running_evaluation_absorbs_next_instruction = - prepare_chunk_running_evaluation_next.clone() - - prepare_chunk_indeterminate.clone() * prepare_chunk_running_evaluation - - instruction_next.clone(); - let prepare_chunk_running_evaluation_resets_and_absorbs_next_instruction = - prepare_chunk_running_evaluation_next.clone() - - prepare_chunk_indeterminate - - instruction_next; - let index_in_chunk_is_max = rate_minus_one.clone() - index_in_chunk.clone(); - let index_in_chunk_is_not_max = - one.clone() - max_minus_index_in_chunk_inv * (rate_minus_one.clone() - index_in_chunk); - let prepare_chunk_running_evaluation_resets_every_rate_rows_and_absorbs_next_instruction = - index_in_chunk_is_max * prepare_chunk_running_evaluation_absorbs_next_instruction - + index_in_chunk_is_not_max - * prepare_chunk_running_evaluation_resets_and_absorbs_next_instruction; - - let send_chunk_running_evaluation_absorbs_next_chunk = send_chunk_running_evaluation_next - .clone() - - send_chunk_indeterminate * send_chunk_running_evaluation.clone() - - prepare_chunk_running_evaluation_next; - let send_chunk_running_evaluation_does_not_change = - send_chunk_running_evaluation_next - send_chunk_running_evaluation; - let index_in_chunk_next_is_max = rate_minus_one - index_in_chunk_next; - let index_in_chunk_next_is_not_max = - one - max_minus_index_in_chunk_inv_next * index_in_chunk_next_is_max.clone(); - - let send_chunk_running_eval_absorbs_chunk_iff_index_in_chunk_next_is_max_and_not_padding_row = - send_chunk_running_evaluation_absorbs_next_chunk - * next_row_is_table_padding_row - * index_in_chunk_next_is_not_max - + send_chunk_running_evaluation_does_not_change.clone() * is_table_padding_next - + send_chunk_running_evaluation_does_not_change * index_in_chunk_next_is_max; - - vec![ - address_increases_by_one, - is_table_padding_is_0_or_remains_unchanged, - index_in_chunk_cycles_correctly, - hash_input_indicator_is_0_or_remains_unchanged, - first_hash_input_padding_is_1, - hash_input_padding_is_0_after_the_first_1, - table_padding_starts_when_hash_input_padding_is_active_and_index_in_chunk_is_zero, - log_derivative_updates_if_and_only_if_not_a_padding_row, - prepare_chunk_running_evaluation_resets_every_rate_rows_and_absorbs_next_instruction, - send_chunk_running_eval_absorbs_chunk_iff_index_in_chunk_next_is_max_and_not_padding_row, - ] - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u64| circuit_builder.b_constant(c); - let base_row = |col: ProgramBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - - let index_in_chunk = base_row(IndexInChunk); - let is_hash_input_padding = base_row(IsHashInputPadding); - let is_table_padding = base_row(IsTablePadding); - - let hash_input_padding_is_one = is_hash_input_padding - constant(1); - - let max_index_in_chunk = Tip5::RATE as u64 - 1; - let index_in_chunk_is_max_or_row_is_padding_row = - (index_in_chunk - constant(max_index_in_chunk)) * (is_table_padding - constant(1)); - - vec![ - hash_input_padding_is_one, - index_in_chunk_is_max_or_row_is_padding_row, - ] - } -} - -impl ProgramTable { - pub fn fill_trace( - program_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) { - let max_index_in_chunk = bfe!(Tip5::RATE as u64 - 1); - - let instructions = aet.program.to_bwords(); - let program_len = instructions.len(); - let padded_program_len = aet.height_of_table(TableId::Program); - - let one_iter = bfe_array![1].into_iter(); - let zero_iter = bfe_array![0].into_iter(); - let padding_iter = one_iter.chain(zero_iter.cycle()); - let padded_instructions = instructions.into_iter().chain(padding_iter); - let padded_instructions = padded_instructions.take(padded_program_len); - - for (row_idx, instruction) in padded_instructions.enumerate() { - let address = u64::try_from(row_idx).unwrap(); - let address = bfe!(address); - - let lookup_multiplicity = match row_idx.cmp(&program_len) { - Ordering::Less => aet.instruction_multiplicities[row_idx], - _ => 0, - }; - let lookup_multiplicity = bfe!(lookup_multiplicity); - let index_in_chunk = bfe!((row_idx % Tip5::RATE) as u64); - - let max_minus_index_in_chunk_inv = - (max_index_in_chunk - index_in_chunk).inverse_or_zero(); - - let is_hash_input_padding = match row_idx.cmp(&program_len) { - Ordering::Less => bfe!(0), - _ => bfe!(1), - }; - - let mut current_row = program_table.row_mut(row_idx); - current_row[Address.base_table_index()] = address; - current_row[Instruction.base_table_index()] = instruction; - current_row[LookupMultiplicity.base_table_index()] = lookup_multiplicity; - current_row[IndexInChunk.base_table_index()] = index_in_chunk; - current_row[MaxMinusIndexInChunkInv.base_table_index()] = max_minus_index_in_chunk_inv; - current_row[IsHashInputPadding.base_table_index()] = is_hash_input_padding; - } - } - - pub fn pad_trace(mut program_table: ArrayViewMut2, program_len: usize) { - let addresses = - (program_len..program_table.nrows()).map(|a| bfe!(u64::try_from(a).unwrap())); - let addresses = Array1::from_iter(addresses); - let address_column = program_table.slice_mut(s![program_len.., Address.base_table_index()]); - addresses.move_into(address_column); - - let indices_in_chunks = (program_len..program_table.nrows()) - .map(|idx| idx % Tip5::RATE) - .map(|ac| bfe!(u64::try_from(ac).unwrap())); - let indices_in_chunks = Array1::from_iter(indices_in_chunks); - let index_in_chunk_column = - program_table.slice_mut(s![program_len.., IndexInChunk.base_table_index()]); - indices_in_chunks.move_into(index_in_chunk_column); - - let max_minus_indices_in_chunks_inverses = (program_len..program_table.nrows()) - .map(|idx| Tip5::RATE - 1 - (idx % Tip5::RATE)) - .map(|ac| BFieldElement::new(ac.try_into().unwrap())) - .map(|bfe| bfe.inverse_or_zero()); - let max_minus_indices_in_chunks_inverses = - Array1::from_iter(max_minus_indices_in_chunks_inverses); - let max_minus_index_in_chunk_inv_column = program_table.slice_mut(s![ - program_len.., - MaxMinusIndexInChunkInv.base_table_index() - ]); - max_minus_indices_in_chunks_inverses.move_into(max_minus_index_in_chunk_inv_column); - - program_table - .slice_mut(s![program_len.., IsHashInputPadding.base_table_index()]) - .fill(BFieldElement::one()); - program_table - .slice_mut(s![program_len.., IsTablePadding.base_table_index()]) - .fill(BFieldElement::one()); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "program table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let mut instruction_lookup_log_derivative = LookupArg::default_initial(); - let mut prepare_chunk_running_evaluation = EvalArg::default_initial(); - let mut send_chunk_running_evaluation = EvalArg::default_initial(); - - for (idx, consecutive_rows) in base_table.windows([2, BASE_WIDTH]).into_iter().enumerate() { - let row = consecutive_rows.row(0); - let next_row = consecutive_rows.row(1); - let mut extension_row = ext_table.row_mut(idx); - - // In the Program Table, the logarithmic derivative for the instruction lookup - // argument does record the initial in the first row, as an exception to all other - // table-linking arguments. - // This is necessary because an instruction's potential argument, or else the next - // instruction, is recorded in the next row. To be able to check correct initialization - // of the logarithmic derivative, both the current and the next row must be accessible - // to the constraint. Only transition constraints can access both rows. Hence, the - // initial value of the logarithmic derivative must be independent of the second row. - // The logarithmic derivative's final value, allowing for a meaningful cross-table - // argument, is recorded in the first padding row. This row is guaranteed to exist - // due to the hash-input padding mechanics. - extension_row[InstructionLookupServerLogDerivative.ext_table_index()] = - instruction_lookup_log_derivative; - - instruction_lookup_log_derivative = Self::update_instruction_lookup_log_derivative( - challenges, - row, - next_row, - instruction_lookup_log_derivative, - ); - prepare_chunk_running_evaluation = Self::update_prepare_chunk_running_evaluation( - row, - challenges, - prepare_chunk_running_evaluation, - ); - send_chunk_running_evaluation = Self::update_send_chunk_running_evaluation( - row, - challenges, - send_chunk_running_evaluation, - prepare_chunk_running_evaluation, - ); - - extension_row[PrepareChunkRunningEvaluation.ext_table_index()] = - prepare_chunk_running_evaluation; - extension_row[SendChunkRunningEvaluation.ext_table_index()] = - send_chunk_running_evaluation; - } - - // special treatment for the last row - let base_rows_iter = base_table.rows().into_iter(); - let ext_rows_iter = ext_table.rows_mut().into_iter(); - let last_base_row = base_rows_iter.last().unwrap(); - let mut last_ext_row = ext_rows_iter.last().unwrap(); - - prepare_chunk_running_evaluation = Self::update_prepare_chunk_running_evaluation( - last_base_row, - challenges, - prepare_chunk_running_evaluation, - ); - send_chunk_running_evaluation = Self::update_send_chunk_running_evaluation( - last_base_row, - challenges, - send_chunk_running_evaluation, - prepare_chunk_running_evaluation, - ); - - last_ext_row[InstructionLookupServerLogDerivative.ext_table_index()] = - instruction_lookup_log_derivative; - last_ext_row[PrepareChunkRunningEvaluation.ext_table_index()] = - prepare_chunk_running_evaluation; - last_ext_row[SendChunkRunningEvaluation.ext_table_index()] = send_chunk_running_evaluation; - - profiler!(stop "program table"); - } - - fn update_instruction_lookup_log_derivative( - challenges: &Challenges, - row: ArrayView1, - next_row: ArrayView1, - instruction_lookup_log_derivative: XFieldElement, - ) -> XFieldElement { - if row[IsHashInputPadding.base_table_index()].is_one() { - return instruction_lookup_log_derivative; - } - instruction_lookup_log_derivative - + Self::instruction_lookup_log_derivative_summand(row, next_row, challenges) - } - - fn instruction_lookup_log_derivative_summand( - row: ArrayView1, - next_row: ArrayView1, - challenges: &Challenges, - ) -> XFieldElement { - let compressed_row = row[Address.base_table_index()] * challenges[ProgramAddressWeight] - + row[Instruction.base_table_index()] * challenges[ProgramInstructionWeight] - + next_row[Instruction.base_table_index()] * challenges[ProgramNextInstructionWeight]; - (challenges[InstructionLookupIndeterminate] - compressed_row).inverse() - * row[LookupMultiplicity.base_table_index()] - } - - fn update_prepare_chunk_running_evaluation( - row: ArrayView1, - challenges: &Challenges, - prepare_chunk_running_evaluation: XFieldElement, - ) -> XFieldElement { - let running_evaluation_resets = row[IndexInChunk.base_table_index()].is_zero(); - let prepare_chunk_running_evaluation = match running_evaluation_resets { - true => EvalArg::default_initial(), - false => prepare_chunk_running_evaluation, - }; - - prepare_chunk_running_evaluation * challenges[ProgramAttestationPrepareChunkIndeterminate] - + row[Instruction.base_table_index()] - } - - fn update_send_chunk_running_evaluation( - row: ArrayView1, - challenges: &Challenges, - send_chunk_running_evaluation: XFieldElement, - prepare_chunk_running_evaluation: XFieldElement, - ) -> XFieldElement { - let index_in_chunk = row[IndexInChunk.base_table_index()]; - let is_table_padding_row = row[IsTablePadding.base_table_index()].is_one(); - let max_index_in_chunk = Tip5::RATE as u64 - 1; - let running_evaluation_needs_update = - !is_table_padding_row && index_in_chunk.value() == max_index_in_chunk; - - if !running_evaluation_needs_update { - return send_chunk_running_evaluation; - } - - send_chunk_running_evaluation * challenges[ProgramAttestationSendChunkIndeterminate] - + prepare_chunk_running_evaluation - } -} diff --git a/triton-vm/src/table/ram.rs b/triton-vm/src/table/ram.rs new file mode 100644 index 000000000..0a0f33b9a --- /dev/null +++ b/triton-vm/src/table/ram.rs @@ -0,0 +1,468 @@ +use std::cmp::Ordering; + +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::*; +use air::table::ram::RamTable; +use air::table::ram::PADDING_INDICATOR; +use air::table::TableId; +use air::table_column::RamBaseTableColumn::*; +use air::table_column::RamExtTableColumn::*; +use air::table_column::*; +use air::AIR; +use arbitrary::Arbitrary; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::SingleRowIndicator::*; +use constraint_circuit::*; +use itertools::Itertools; +use ndarray::parallel::prelude::*; +use ndarray::prelude::*; +use num_traits::ConstOne; +use num_traits::One; +use num_traits::Zero; +use serde::Deserialize; +use serde::Serialize; +use strum::EnumCount; +use strum::IntoEnumIterator; +use twenty_first::math::traits::FiniteField; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::ndarray_helper::contiguous_column_slices; +use crate::ndarray_helper::horizontal_multi_slice_mut; +use crate::profiler::profiler; +use crate::table::TraceTable; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)] +pub struct RamTableCall { + pub clk: u32, + pub ram_pointer: BFieldElement, + pub ram_value: BFieldElement, + pub is_write: bool, +} + +impl RamTableCall { + pub fn to_table_row(self) -> Array1 { + let instruction_type = if self.is_write { + air::table::ram::INSTRUCTION_TYPE_WRITE + } else { + air::table::ram::INSTRUCTION_TYPE_READ + }; + + let mut row = Array1::zeros(::MainColumn::COUNT); + row[CLK.base_table_index()] = self.clk.into(); + row[InstructionType.base_table_index()] = instruction_type; + row[RamPointer.base_table_index()] = self.ram_pointer; + row[RamValue.base_table_index()] = self.ram_value; + row + } +} + +impl TraceTable for RamTable { + type FillParam = (); + type FillReturnInfo = Vec; + + fn fill( + mut ram_table: ArrayViewMut2, + aet: &AlgebraicExecutionTrace, + _: Self::FillParam, + ) -> Self::FillReturnInfo { + let mut ram_table = ram_table.slice_mut(s![0..aet.height_of_table(TableId::Ram), ..]); + let trace_iter = aet.ram_trace.rows().into_iter(); + + let sorted_rows = + trace_iter.sorted_by(|row_0, row_1| compare_rows(row_0.view(), row_1.view())); + for (row_index, row) in sorted_rows.enumerate() { + ram_table.row_mut(row_index).assign(&row); + } + + let all_ram_pointers = ram_table.column(RamPointer.base_table_index()); + let unique_ram_pointers = all_ram_pointers.iter().unique().copied().collect_vec(); + let (bezout_0, bezout_1) = + bezout_coefficient_polynomials_coefficients(&unique_ram_pointers); + + make_ram_table_consistent(&mut ram_table, bezout_0, bezout_1) + } + + fn pad(mut main_table: ArrayViewMut2, table_len: usize) { + let last_row_index = table_len.saturating_sub(1); + let mut padding_row = main_table.row(last_row_index).to_owned(); + padding_row[InstructionType.base_table_index()] = PADDING_INDICATOR; + if table_len == 0 { + padding_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = + BFieldElement::ONE; + } + + let mut padding_section = main_table.slice_mut(s![table_len.., ..]); + padding_section + .axis_iter_mut(Axis(0)) + .into_par_iter() + .for_each(|mut row| row.assign(&padding_row)); + } + + fn extend( + base_table: ArrayView2, + mut ext_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "ram table"); + assert_eq!(Self::MainColumn::COUNT, base_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, ext_table.ncols()); + assert_eq!(base_table.nrows(), ext_table.nrows()); + + let extension_column_indices = RamExtTableColumn::iter() + // RunningProductOfRAMP + FormalDerivative are constitute one + // slice and are populated by the same function + .filter(|column| *column != RamExtTableColumn::FormalDerivative) + .map(|column| column.ext_table_index()) + .collect_vec(); + let extension_column_slices = horizontal_multi_slice_mut( + ext_table.view_mut(), + &contiguous_column_slices(&extension_column_indices), + ); + let extension_functions = [ + extension_column_running_product_of_ramp_and_formal_derivative, + extension_column_bezout_coefficient_0, + extension_column_bezout_coefficient_1, + extension_column_running_product_perm_arg, + extension_column_clock_jump_difference_lookup_log_derivative, + ]; + extension_functions + .into_par_iter() + .zip_eq(extension_column_slices) + .for_each(|(generator, slice)| { + generator(base_table, challenges).move_into(slice); + }); + + profiler!(stop "ram table"); + } +} + +fn compare_rows(row_0: ArrayView1, row_1: ArrayView1) -> Ordering { + let ram_pointer_0 = row_0[RamPointer.base_table_index()].value(); + let ram_pointer_1 = row_1[RamPointer.base_table_index()].value(); + let compare_ram_pointers = ram_pointer_0.cmp(&ram_pointer_1); + + let clk_0 = row_0[CLK.base_table_index()].value(); + let clk_1 = row_1[CLK.base_table_index()].value(); + let compare_clocks = clk_0.cmp(&clk_1); + + compare_ram_pointers.then(compare_clocks) +} + +/// Compute the [Bézout coefficients](https://en.wikipedia.org/wiki/B%C3%A9zout%27s_identity) +/// of the polynomial with the given roots and its formal derivative. +/// +/// All roots _must_ be unique. That is, the corresponding polynomial must be square free. +#[doc(hidden)] // public for benchmarking purposes only +pub fn bezout_coefficient_polynomials_coefficients( + unique_roots: &[BFieldElement], +) -> (Vec, Vec) { + if unique_roots.is_empty() { + return (vec![], vec![]); + } + + // The structure of the problem is exploited heavily to compute the Bézout coefficients + // as fast as possible. In the following paragraphs, let `rp` denote the polynomial with the + // given `unique_roots` as its roots, and `fd` the formal derivative of `rp`. + // + // The naïve approach is to perform the extended Euclidean algorithm (xgcd) on `rp` and + // `fd`. This has a time complexity in O(n^2) where `n` is the number of roots: for the + // given problem shape, the degrees `rp` and `fd` are `n` and `n-1`, respectively. Each step + // of the (x)gcd takes O(n) time and reduces the degree of the polynomials by one. + // For programs with a large number of different RAM accesses, `n` is large. + // + // The approach taken here is to exploit the structure of the problem. Concretely, since all + // roots of `rp` are unique, _i.e._, `rp` is square free, the gcd of `rp` and `fd` is 1. + // This implies `∀ r ∈ unique_roots: fd(r)·b(r) = 1`, where `b` is one of the Bézout + // coefficients. In other words, the evaluation of `fd` in `unique_roots` is the inverse of + // the evaluation of `b` in `unique_roots`. Furthermore, `b` is a polynomial of degree `n`, + // and therefore fully determined by the evaluations in `unique_roots`. Finally, the other + // Bézout coefficient `a` is determined by `a = (1 - fd·b) / rp`. + // In total, this allows computing the Bézout coefficients in O(n·(log n)^2) time. + + debug_assert!(unique_roots.iter().all_unique()); + let rp = Polynomial::par_zerofier(unique_roots); + let fd = rp.formal_derivative(); + let fd_in_roots = fd.par_batch_evaluate(unique_roots); + let b_in_roots = BFieldElement::batch_inversion(fd_in_roots); + let b = Polynomial::par_interpolate(unique_roots, &b_in_roots); + let one_minus_fd_b = Polynomial::one() - fd.multiply(&b); + let a = one_minus_fd_b.clean_divide(rp); + + let mut coefficients_0 = a.coefficients; + let mut coefficients_1 = b.coefficients; + coefficients_0.resize(unique_roots.len(), bfe!(0)); + coefficients_1.resize(unique_roots.len(), bfe!(0)); + (coefficients_0, coefficients_1) +} + +/// - Set inverse of RAM pointer difference +/// - Fill in the Bézout coefficients if the RAM pointer changes between two consecutive rows +/// - Collect and return all clock jump differences +fn make_ram_table_consistent( + ram_table: &mut ArrayViewMut2, + mut bezout_coefficient_polynomial_coefficients_0: Vec, + mut bezout_coefficient_polynomial_coefficients_1: Vec, +) -> Vec { + if ram_table.nrows() == 0 { + assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); + assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); + return vec![]; + } + + let mut current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); + let mut current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = + current_bcpc_0; + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = + current_bcpc_1; + + let mut clock_jump_differences = vec![]; + for row_idx in 0..ram_table.nrows() - 1 { + let (mut curr_row, mut next_row) = + ram_table.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); + + let ramp_diff = + next_row[RamPointer.base_table_index()] - curr_row[RamPointer.base_table_index()]; + let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; + + if ramp_diff.is_zero() { + clock_jump_differences.push(clk_diff); + } else { + current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); + current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); + } + + curr_row[InverseOfRampDifference.base_table_index()] = ramp_diff.inverse_or_zero(); + next_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = current_bcpc_0; + next_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = current_bcpc_1; + } + + assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); + assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); + clock_jump_differences +} + +fn extension_column_running_product_of_ramp_and_formal_derivative( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; + + let mut extension_columns = Vec::with_capacity(2 * base_table.nrows()); + let mut running_product_ram_pointer = + bezout_indeterminate - base_table.row(0)[RamPointer.base_table_index()]; + let mut formal_derivative = xfe!(1); + + extension_columns.push(running_product_ram_pointer); + extension_columns.push(formal_derivative); + + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + let instruction_type = current_row[InstructionType.base_table_index()]; + let is_no_padding_row = instruction_type != PADDING_INDICATOR; + + if is_no_padding_row { + let current_ram_pointer = current_row[RamPointer.base_table_index()]; + let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; + if previous_ram_pointer != current_ram_pointer { + formal_derivative = (bezout_indeterminate - current_ram_pointer) + * formal_derivative + + running_product_ram_pointer; + running_product_ram_pointer *= bezout_indeterminate - current_ram_pointer; + } + } + + extension_columns.push(running_product_ram_pointer); + extension_columns.push(formal_derivative); + } + + Array2::from_shape_vec((base_table.nrows(), 2), extension_columns).unwrap() +} + +fn extension_column_bezout_coefficient_0( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + extension_column_bezout_coefficient( + base_table, + challenges, + BezoutCoefficientPolynomialCoefficient0, + ) +} + +fn extension_column_bezout_coefficient_1( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + extension_column_bezout_coefficient( + base_table, + challenges, + BezoutCoefficientPolynomialCoefficient1, + ) +} + +fn extension_column_bezout_coefficient( + base_table: ArrayView2, + challenges: &Challenges, + bezout_cefficient_column: RamBaseTableColumn, +) -> Array2 { + let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; + + let mut bezout_coefficient = + base_table.row(0)[bezout_cefficient_column.base_table_index()].lift(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(bezout_coefficient); + + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { + break; // padding marks the end of the trace + } + + let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; + let current_ram_pointer = current_row[RamPointer.base_table_index()]; + if previous_ram_pointer != current_ram_pointer { + bezout_coefficient *= bezout_indeterminate; + bezout_coefficient += current_row[bezout_cefficient_column.base_table_index()]; + } + extension_column.push(bezout_coefficient); + } + + // fill padding section + extension_column.resize(base_table.nrows(), bezout_coefficient); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_running_product_perm_arg( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let mut running_product_for_perm_arg = PermArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + for row in base_table.rows() { + let instruction_type = row[InstructionType.base_table_index()]; + if instruction_type == PADDING_INDICATOR { + break; // padding marks the end of the trace + } + + let clk = row[CLK.base_table_index()]; + let current_ram_pointer = row[RamPointer.base_table_index()]; + let ram_value = row[RamValue.base_table_index()]; + let compressed_row = clk * challenges[RamClkWeight] + + instruction_type * challenges[RamInstructionTypeWeight] + + current_ram_pointer * challenges[RamPointerWeight] + + ram_value * challenges[RamValueWeight]; + running_product_for_perm_arg *= challenges[RamIndeterminate] - compressed_row; + extension_column.push(running_product_for_perm_arg); + } + + // fill padding section + extension_column.resize(base_table.nrows(), running_product_for_perm_arg); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +fn extension_column_clock_jump_difference_lookup_log_derivative( + base_table: ArrayView2, + challenges: &Challenges, +) -> Array2 { + let indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; + + let mut cjd_lookup_log_derivative = LookupArg::default_initial(); + let mut extension_column = Vec::with_capacity(base_table.nrows()); + extension_column.push(cjd_lookup_log_derivative); + + for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { + if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { + break; // padding marks the end of the trace + } + + let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; + let current_ram_pointer = current_row[RamPointer.base_table_index()]; + if previous_ram_pointer == current_ram_pointer { + let previous_clock = previous_row[CLK.base_table_index()]; + let current_clock = current_row[CLK.base_table_index()]; + let clock_jump_difference = current_clock - previous_clock; + let log_derivative_summand = (indeterminate - clock_jump_difference).inverse(); + cjd_lookup_log_derivative += log_derivative_summand; + } + extension_column.push(cjd_lookup_log_derivative); + } + + // fill padding section + extension_column.resize(base_table.nrows(), cjd_lookup_log_derivative); + Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() +} + +#[cfg(test)] +pub(crate) mod tests { + use proptest::prelude::*; + use proptest_arbitrary_interop::arb; + use test_strategy::proptest; + + use super::*; + + #[proptest] + fn ram_table_call_can_be_converted_to_table_row( + #[strategy(arb())] ram_table_call: RamTableCall, + ) { + ram_table_call.to_table_row(); + } + + #[test] + fn bezout_coefficient_polynomials_of_empty_ram_table_are_default() { + let (a, b) = bezout_coefficient_polynomials_coefficients(&[]); + assert_eq!(a, vec![]); + assert_eq!(b, vec![]); + } + + #[test] + fn bezout_coefficient_polynomials_are_as_expected() { + let rp = bfe_array![1, 2, 3]; + let (a, b) = bezout_coefficient_polynomials_coefficients(&rp); + + let expected_a = bfe_array![9, 0x7fff_ffff_7fff_fffc_u64, 0]; + let expected_b = bfe_array![5, 0xffff_fffe_ffff_fffb_u64, 0x7fff_ffff_8000_0002_u64]; + + assert_eq!(expected_a, *a); + assert_eq!(expected_b, *b); + } + + #[proptest] + fn bezout_coefficient_polynomials_agree_with_xgcd( + #[strategy(arb())] + #[filter(#ram_pointers.iter().all_unique())] + ram_pointers: Vec, + ) { + let (a, b) = bezout_coefficient_polynomials_coefficients(&ram_pointers); + + let rp = Polynomial::zerofier(&ram_pointers); + let fd = rp.formal_derivative(); + let (_, a_xgcd, b_xgcd) = Polynomial::xgcd(rp, fd); + + let mut a_xgcd = a_xgcd.coefficients; + let mut b_xgcd = b_xgcd.coefficients; + + a_xgcd.resize(ram_pointers.len(), bfe!(0)); + b_xgcd.resize(ram_pointers.len(), bfe!(0)); + + prop_assert_eq!(a, a_xgcd); + prop_assert_eq!(b, b_xgcd); + } + + #[proptest] + fn bezout_coefficients_are_actually_bezout_coefficients( + #[strategy(arb())] + #[filter(!#ram_pointers.is_empty())] + #[filter(#ram_pointers.iter().all_unique())] + ram_pointers: Vec, + ) { + let (a, b) = bezout_coefficient_polynomials_coefficients(&ram_pointers); + + let rp = Polynomial::zerofier(&ram_pointers); + let fd = rp.formal_derivative(); + + let [a, b] = [a, b].map(Polynomial::new); + let gcd = rp * a + fd * b; + prop_assert_eq!(Polynomial::one(), gcd); + } +} diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs deleted file mode 100644 index e6065f969..000000000 --- a/triton-vm/src/table/ram_table.rs +++ /dev/null @@ -1,711 +0,0 @@ -use std::cmp::Ordering; - -use arbitrary::Arbitrary; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use itertools::Itertools; -use ndarray::parallel::prelude::*; -use ndarray::prelude::*; -use num_traits::ConstOne; -use num_traits::One; -use num_traits::Zero; -use serde::Deserialize; -use serde::Serialize; -use strum::EnumCount; -use strum::IntoEnumIterator; -use twenty_first::math::traits::FiniteField; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::ndarray_helper::contiguous_column_slices; -use crate::ndarray_helper::horizontal_multi_slice_mut; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::*; -use crate::table::master_table::TableId; -use crate::table::table_column::RamBaseTableColumn::*; -use crate::table::table_column::RamExtTableColumn::*; -use crate::table::table_column::*; - -pub const BASE_WIDTH: usize = RamBaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = RamExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -pub const INSTRUCTION_TYPE_WRITE: BFieldElement = BFieldElement::new(0); -pub const INSTRUCTION_TYPE_READ: BFieldElement = BFieldElement::new(1); -pub const PADDING_INDICATOR: BFieldElement = BFieldElement::new(2); - -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)] -pub struct RamTableCall { - pub clk: u32, - pub ram_pointer: BFieldElement, - pub ram_value: BFieldElement, - pub is_write: bool, -} - -impl RamTableCall { - pub fn to_table_row(self) -> Array1 { - let instruction_type = match self.is_write { - true => INSTRUCTION_TYPE_WRITE, - false => INSTRUCTION_TYPE_READ, - }; - - let mut row = Array1::zeros(BASE_WIDTH); - row[CLK.base_table_index()] = self.clk.into(); - row[InstructionType.base_table_index()] = instruction_type; - row[RamPointer.base_table_index()] = self.ram_pointer; - row[RamValue.base_table_index()] = self.ram_value; - row - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct RamTable; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtRamTable; - -impl RamTable { - /// Fills the trace table in-place and returns all clock jump differences. - pub fn fill_trace( - ram_table: &mut ArrayViewMut2, - aet: &AlgebraicExecutionTrace, - ) -> Vec { - let mut ram_table = ram_table.slice_mut(s![0..aet.height_of_table(TableId::Ram), ..]); - let trace_iter = aet.ram_trace.rows().into_iter(); - - let sorted_rows = - trace_iter.sorted_by(|row_0, row_1| Self::compare_rows(row_0.view(), row_1.view())); - for (row_index, row) in sorted_rows.enumerate() { - ram_table.row_mut(row_index).assign(&row); - } - - let all_ram_pointers = ram_table.column(RamPointer.base_table_index()); - let unique_ram_pointers = all_ram_pointers.iter().unique().copied().collect_vec(); - let (bezout_0, bezout_1) = - Self::bezout_coefficient_polynomials_coefficients(&unique_ram_pointers); - - Self::make_ram_table_consistent(&mut ram_table, bezout_0, bezout_1) - } - - fn compare_rows( - row_0: ArrayView1, - row_1: ArrayView1, - ) -> Ordering { - let ram_pointer_0 = row_0[RamPointer.base_table_index()].value(); - let ram_pointer_1 = row_1[RamPointer.base_table_index()].value(); - let compare_ram_pointers = ram_pointer_0.cmp(&ram_pointer_1); - - let clk_0 = row_0[CLK.base_table_index()].value(); - let clk_1 = row_1[CLK.base_table_index()].value(); - let compare_clocks = clk_0.cmp(&clk_1); - - compare_ram_pointers.then(compare_clocks) - } - - /// Compute the [Bézout coefficients](https://en.wikipedia.org/wiki/B%C3%A9zout%27s_identity) - /// of the polynomial with the given roots and its formal derivative. - /// - /// All roots _must_ be unique. That is, the corresponding polynomial must be square free. - pub fn bezout_coefficient_polynomials_coefficients( - unique_roots: &[BFieldElement], - ) -> (Vec, Vec) { - if unique_roots.is_empty() { - return (vec![], vec![]); - } - - // The structure of the problem is exploited heavily to compute the Bézout coefficients - // as fast as possible. In the following paragraphs, let `rp` denote the polynomial with the - // given `unique_roots` as its roots, and `fd` the formal derivative of `rp`. - // - // The naïve approach is to perform the extended Euclidean algorithm (xgcd) on `rp` and - // `fd`. This has a time complexity in O(n^2) where `n` is the number of roots: for the - // given problem shape, the degrees `rp` and `fd` are `n` and `n-1`, respectively. Each step - // of the (x)gcd takes O(n) time and reduces the degree of the polynomials by one. - // For programs with a large number of different RAM accesses, `n` is large. - // - // The approach taken here is to exploit the structure of the problem. Concretely, since all - // roots of `rp` are unique, _i.e._, `rp` is square free, the gcd of `rp` and `fd` is 1. - // This implies `∀ r ∈ unique_roots: fd(r)·b(r) = 1`, where `b` is one of the Bézout - // coefficients. In other words, the evaluation of `fd` in `unique_roots` is the inverse of - // the evaluation of `b` in `unique_roots`. Furthermore, `b` is a polynomial of degree `n`, - // and therefore fully determined by the evaluations in `unique_roots`. Finally, the other - // Bézout coefficient `a` is determined by `a = (1 - fd·b) / rp`. - // In total, this allows computing the Bézout coefficients in O(n·(log n)^2) time. - - debug_assert!(unique_roots.iter().all_unique()); - let rp = Polynomial::par_zerofier(unique_roots); - let fd = rp.formal_derivative(); - let fd_in_roots = fd.par_batch_evaluate(unique_roots); - let b_in_roots = BFieldElement::batch_inversion(fd_in_roots); - let b = Polynomial::par_interpolate(unique_roots, &b_in_roots); - let one_minus_fd_b = Polynomial::one() - fd.multiply(&b); - let a = one_minus_fd_b.clean_divide(rp); - - let mut coefficients_0 = a.coefficients; - let mut coefficients_1 = b.coefficients; - coefficients_0.resize(unique_roots.len(), bfe!(0)); - coefficients_1.resize(unique_roots.len(), bfe!(0)); - (coefficients_0, coefficients_1) - } - - /// - Set inverse of RAM pointer difference - /// - Fill in the Bézout coefficients if the RAM pointer changes between two consecutive rows - /// - Collect and return all clock jump differences - fn make_ram_table_consistent( - ram_table: &mut ArrayViewMut2, - mut bezout_coefficient_polynomial_coefficients_0: Vec, - mut bezout_coefficient_polynomial_coefficients_1: Vec, - ) -> Vec { - if ram_table.nrows() == 0 { - assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); - assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); - return vec![]; - } - - let mut current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); - let mut current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); - ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = - current_bcpc_0; - ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = - current_bcpc_1; - - let mut clock_jump_differences = vec![]; - for row_idx in 0..ram_table.nrows() - 1 { - let (mut curr_row, mut next_row) = - ram_table.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - - let ramp_diff = - next_row[RamPointer.base_table_index()] - curr_row[RamPointer.base_table_index()]; - let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; - - if ramp_diff.is_zero() { - clock_jump_differences.push(clk_diff); - } else { - current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); - current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); - } - - curr_row[InverseOfRampDifference.base_table_index()] = ramp_diff.inverse_or_zero(); - next_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = current_bcpc_0; - next_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = current_bcpc_1; - } - - assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); - assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); - clock_jump_differences - } - - pub fn pad_trace(mut ram_table: ArrayViewMut2, ram_table_len: usize) { - let last_row_index = ram_table_len.saturating_sub(1); - let mut padding_row = ram_table.row(last_row_index).to_owned(); - padding_row[InstructionType.base_table_index()] = PADDING_INDICATOR; - if ram_table_len == 0 { - padding_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = - BFieldElement::ONE; - } - - let mut padding_section = ram_table.slice_mut(s![ram_table_len.., ..]); - padding_section - .axis_iter_mut(Axis(0)) - .into_par_iter() - .for_each(|mut row| row.assign(&padding_row)); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "ram table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let extension_column_indices = RamExtTableColumn::iter() - // RunningProductOfRAMP + FormalDerivative are constitute one - // slice and are populated by the same function - .filter(|column| *column != RamExtTableColumn::FormalDerivative) - .map(|column| column.ext_table_index()) - .collect_vec(); - let extension_column_slices = horizontal_multi_slice_mut( - ext_table.view_mut(), - &contiguous_column_slices(&extension_column_indices), - ); - let extension_functions = [ - Self::extension_column_running_product_of_ramp_and_formal_derivative, - Self::extension_column_bezout_coefficient_0, - Self::extension_column_bezout_coefficient_1, - Self::extension_column_running_product_perm_arg, - Self::extension_column_clock_jump_difference_lookup_log_derivative, - ]; - extension_functions - .into_par_iter() - .zip_eq(extension_column_slices) - .for_each(|(generator, slice)| { - generator(base_table, challenges).move_into(slice); - }); - - profiler!(stop "ram table"); - } - - fn extension_column_running_product_of_ramp_and_formal_derivative( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; - - let mut extension_columns = Vec::with_capacity(2 * base_table.nrows()); - let mut running_product_ram_pointer = - bezout_indeterminate - base_table.row(0)[RamPointer.base_table_index()]; - let mut formal_derivative = xfe!(1); - - extension_columns.push(running_product_ram_pointer); - extension_columns.push(formal_derivative); - - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let instruction_type = current_row[InstructionType.base_table_index()]; - let is_no_padding_row = instruction_type != PADDING_INDICATOR; - - if is_no_padding_row { - let current_ram_pointer = current_row[RamPointer.base_table_index()]; - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; - if previous_ram_pointer != current_ram_pointer { - formal_derivative = (bezout_indeterminate - current_ram_pointer) - * formal_derivative - + running_product_ram_pointer; - running_product_ram_pointer *= bezout_indeterminate - current_ram_pointer; - } - } - - extension_columns.push(running_product_ram_pointer); - extension_columns.push(formal_derivative); - } - - Array2::from_shape_vec((base_table.nrows(), 2), extension_columns).unwrap() - } - - fn extension_column_bezout_coefficient_0( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - Self::extension_column_bezout_coefficient( - base_table, - challenges, - BezoutCoefficientPolynomialCoefficient0, - ) - } - - fn extension_column_bezout_coefficient_1( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - Self::extension_column_bezout_coefficient( - base_table, - challenges, - BezoutCoefficientPolynomialCoefficient1, - ) - } - - fn extension_column_bezout_coefficient( - base_table: ArrayView2, - challenges: &Challenges, - bezout_cefficient_column: RamBaseTableColumn, - ) -> Array2 { - let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; - - let mut bezout_coefficient = - base_table.row(0)[bezout_cefficient_column.base_table_index()].lift(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(bezout_coefficient); - - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { - break; // padding marks the end of the trace - } - - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; - let current_ram_pointer = current_row[RamPointer.base_table_index()]; - if previous_ram_pointer != current_ram_pointer { - bezout_coefficient *= bezout_indeterminate; - bezout_coefficient += current_row[bezout_cefficient_column.base_table_index()]; - } - extension_column.push(bezout_coefficient); - } - - // fill padding section - extension_column.resize(base_table.nrows(), bezout_coefficient); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_running_product_perm_arg( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let mut running_product_for_perm_arg = PermArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - for row in base_table.rows() { - let instruction_type = row[InstructionType.base_table_index()]; - if instruction_type == PADDING_INDICATOR { - break; // padding marks the end of the trace - } - - let clk = row[CLK.base_table_index()]; - let current_ram_pointer = row[RamPointer.base_table_index()]; - let ram_value = row[RamValue.base_table_index()]; - let compressed_row = clk * challenges[RamClkWeight] - + instruction_type * challenges[RamInstructionTypeWeight] - + current_ram_pointer * challenges[RamPointerWeight] - + ram_value * challenges[RamValueWeight]; - running_product_for_perm_arg *= challenges[RamIndeterminate] - compressed_row; - extension_column.push(running_product_for_perm_arg); - } - - // fill padding section - extension_column.resize(base_table.nrows(), running_product_for_perm_arg); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } - - fn extension_column_clock_jump_difference_lookup_log_derivative( - base_table: ArrayView2, - challenges: &Challenges, - ) -> Array2 { - let indeterminate = challenges[ClockJumpDifferenceLookupIndeterminate]; - - let mut cjd_lookup_log_derivative = LookupArg::default_initial(); - let mut extension_column = Vec::with_capacity(base_table.nrows()); - extension_column.push(cjd_lookup_log_derivative); - - for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { - break; // padding marks the end of the trace - } - - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; - let current_ram_pointer = current_row[RamPointer.base_table_index()]; - if previous_ram_pointer == current_ram_pointer { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; - let clock_jump_difference = current_clock - previous_clock; - let log_derivative_summand = (indeterminate - clock_jump_difference).inverse(); - cjd_lookup_log_derivative += log_derivative_summand; - } - extension_column.push(cjd_lookup_log_derivative); - } - - // fill padding section - extension_column.resize(base_table.nrows(), cjd_lookup_log_derivative); - Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() - } -} - -impl ExtRamTable { - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let constant = |c| circuit_builder.b_constant(c); - let x_constant = |c| circuit_builder.x_constant(c); - let base_row = |column: RamBaseTableColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let ext_row = |column: RamExtTableColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; - - let first_row_is_padding_row = base_row(InstructionType) - constant(PADDING_INDICATOR); - let first_row_is_not_padding_row = (base_row(InstructionType) - - constant(INSTRUCTION_TYPE_READ)) - * (base_row(InstructionType) - constant(INSTRUCTION_TYPE_WRITE)); - - let bezout_coefficient_polynomial_coefficient_0_is_0 = - base_row(BezoutCoefficientPolynomialCoefficient0); - let bezout_coefficient_0_is_0 = ext_row(BezoutCoefficient0); - let bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1 = - ext_row(BezoutCoefficient1) - base_row(BezoutCoefficientPolynomialCoefficient1); - let formal_derivative_is_1 = ext_row(FormalDerivative) - constant(1_u32.into()); - let running_product_polynomial_is_initialized_correctly = ext_row(RunningProductOfRAMP) - - challenge(RamTableBezoutRelationIndeterminate) - + base_row(RamPointer); - - let clock_jump_diff_log_derivative_is_default_initial = - ext_row(ClockJumpDifferenceLookupClientLogDerivative) - - x_constant(LookupArg::default_initial()); - - let compressed_row_for_permutation_argument = base_row(CLK) * challenge(RamClkWeight) - + base_row(InstructionType) * challenge(RamInstructionTypeWeight) - + base_row(RamPointer) * challenge(RamPointerWeight) - + base_row(RamValue) * challenge(RamValueWeight); - let running_product_permutation_argument_has_accumulated_first_row = - ext_row(RunningProductPermArg) - challenge(RamIndeterminate) - + compressed_row_for_permutation_argument; - let running_product_permutation_argument_is_default_initial = - ext_row(RunningProductPermArg) - x_constant(PermArg::default_initial()); - - let running_product_permutation_argument_starts_correctly = - running_product_permutation_argument_has_accumulated_first_row - * first_row_is_padding_row - + running_product_permutation_argument_is_default_initial - * first_row_is_not_padding_row; - - vec![ - bezout_coefficient_polynomial_coefficient_0_is_0, - bezout_coefficient_0_is_0, - bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1, - running_product_polynomial_is_initialized_correctly, - formal_derivative_is_1, - running_product_permutation_argument_starts_correctly, - clock_jump_diff_log_derivative_is_default_initial, - ] - } - - pub fn consistency_constraints( - _circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - // no further constraints - vec![] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c| circuit_builder.b_constant(c); - let challenge = |c| circuit_builder.challenge(c); - let curr_base_row = |column: RamBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) - }; - let curr_ext_row = |column: RamExtTableColumn| { - circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) - }; - let next_base_row = |column: RamBaseTableColumn| { - circuit_builder.input(NextBaseRow(column.master_base_table_index())) - }; - let next_ext_row = |column: RamExtTableColumn| { - circuit_builder.input(NextExtRow(column.master_ext_table_index())) - }; - - let one = constant(1_u32.into()); - - let bezout_challenge = challenge(RamTableBezoutRelationIndeterminate); - - let clock = curr_base_row(CLK); - let ram_pointer = curr_base_row(RamPointer); - let ram_value = curr_base_row(RamValue); - let instruction_type = curr_base_row(InstructionType); - let inverse_of_ram_pointer_difference = curr_base_row(InverseOfRampDifference); - let bcpc0 = curr_base_row(BezoutCoefficientPolynomialCoefficient0); - let bcpc1 = curr_base_row(BezoutCoefficientPolynomialCoefficient1); - - let running_product_ram_pointer = curr_ext_row(RunningProductOfRAMP); - let fd = curr_ext_row(FormalDerivative); - let bc0 = curr_ext_row(BezoutCoefficient0); - let bc1 = curr_ext_row(BezoutCoefficient1); - let rppa = curr_ext_row(RunningProductPermArg); - let clock_jump_diff_log_derivative = - curr_ext_row(ClockJumpDifferenceLookupClientLogDerivative); - - let clock_next = next_base_row(CLK); - let ram_pointer_next = next_base_row(RamPointer); - let ram_value_next = next_base_row(RamValue); - let instruction_type_next = next_base_row(InstructionType); - let bcpc0_next = next_base_row(BezoutCoefficientPolynomialCoefficient0); - let bcpc1_next = next_base_row(BezoutCoefficientPolynomialCoefficient1); - - let running_product_ram_pointer_next = next_ext_row(RunningProductOfRAMP); - let fd_next = next_ext_row(FormalDerivative); - let bc0_next = next_ext_row(BezoutCoefficient0); - let bc1_next = next_ext_row(BezoutCoefficient1); - let rppa_next = next_ext_row(RunningProductPermArg); - let clock_jump_diff_log_derivative_next = - next_ext_row(ClockJumpDifferenceLookupClientLogDerivative); - - let next_row_is_padding_row = - instruction_type_next.clone() - constant(PADDING_INDICATOR).clone(); - let if_current_row_is_padding_row_then_next_row_is_padding_row = (instruction_type.clone() - - constant(INSTRUCTION_TYPE_READ)) - * (instruction_type - constant(INSTRUCTION_TYPE_WRITE)) - * next_row_is_padding_row.clone(); - - let ram_pointer_difference = ram_pointer_next.clone() - ram_pointer; - let ram_pointer_changes = one.clone() - - ram_pointer_difference.clone() * inverse_of_ram_pointer_difference.clone(); - - let iord_is_0_or_iord_is_inverse_of_ram_pointer_difference = - inverse_of_ram_pointer_difference * ram_pointer_changes.clone(); - - let ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference = - ram_pointer_difference.clone() * ram_pointer_changes.clone(); - - let ram_pointer_changes_or_write_mem_or_ram_value_stays = ram_pointer_changes.clone() - * (constant(INSTRUCTION_TYPE_WRITE) - instruction_type_next.clone()) - * (ram_value_next.clone() - ram_value); - - let bcbp0_only_changes_if_ram_pointer_changes = - ram_pointer_changes.clone() * (bcpc0_next.clone() - bcpc0); - - let bcbp1_only_changes_if_ram_pointer_changes = - ram_pointer_changes.clone() * (bcpc1_next.clone() - bcpc1); - - let running_product_ram_pointer_updates_correctly = ram_pointer_difference.clone() - * (running_product_ram_pointer_next.clone() - - running_product_ram_pointer.clone() - * (bezout_challenge.clone() - ram_pointer_next.clone())) - + ram_pointer_changes.clone() - * (running_product_ram_pointer_next - running_product_ram_pointer.clone()); - - let formal_derivative_updates_correctly = ram_pointer_difference.clone() - * (fd_next.clone() - - running_product_ram_pointer - - (bezout_challenge.clone() - ram_pointer_next.clone()) * fd.clone()) - + ram_pointer_changes.clone() * (fd_next - fd); - - let bezout_coefficient_0_is_constructed_correctly = ram_pointer_difference.clone() - * (bc0_next.clone() - bezout_challenge.clone() * bc0.clone() - bcpc0_next) - + ram_pointer_changes.clone() * (bc0_next - bc0); - - let bezout_coefficient_1_is_constructed_correctly = ram_pointer_difference.clone() - * (bc1_next.clone() - bezout_challenge * bc1.clone() - bcpc1_next) - + ram_pointer_changes.clone() * (bc1_next - bc1); - - let compressed_row = clock_next.clone() * challenge(RamClkWeight) - + ram_pointer_next * challenge(RamPointerWeight) - + ram_value_next * challenge(RamValueWeight) - + instruction_type_next.clone() * challenge(RamInstructionTypeWeight); - let rppa_accumulates_next_row = - rppa_next.clone() - rppa.clone() * (challenge(RamIndeterminate) - compressed_row); - - let next_row_is_not_padding_row = (instruction_type_next.clone() - - constant(INSTRUCTION_TYPE_READ)) - * (instruction_type_next - constant(INSTRUCTION_TYPE_WRITE)); - let rppa_remains_unchanged = rppa_next - rppa; - - let rppa_updates_correctly = rppa_accumulates_next_row * next_row_is_padding_row.clone() - + rppa_remains_unchanged * next_row_is_not_padding_row.clone(); - - let clock_difference = clock_next - clock; - let log_derivative_accumulates = (clock_jump_diff_log_derivative_next.clone() - - clock_jump_diff_log_derivative.clone()) - * (challenge(ClockJumpDifferenceLookupIndeterminate) - clock_difference) - - one.clone(); - let log_derivative_remains = - clock_jump_diff_log_derivative_next - clock_jump_diff_log_derivative.clone(); - - let log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row = - log_derivative_accumulates * ram_pointer_changes.clone() * next_row_is_padding_row; - let log_derivative_remains_or_ram_pointer_doesnt_change = - log_derivative_remains.clone() * ram_pointer_difference.clone(); - let log_derivative_remains_or_next_row_is_not_padding_row = - log_derivative_remains * next_row_is_not_padding_row; - - let log_derivative_updates_correctly = - log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row - + log_derivative_remains_or_ram_pointer_doesnt_change - + log_derivative_remains_or_next_row_is_not_padding_row; - - vec![ - if_current_row_is_padding_row_then_next_row_is_padding_row, - iord_is_0_or_iord_is_inverse_of_ram_pointer_difference, - ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference, - ram_pointer_changes_or_write_mem_or_ram_value_stays, - bcbp0_only_changes_if_ram_pointer_changes, - bcbp1_only_changes_if_ram_pointer_changes, - running_product_ram_pointer_updates_correctly, - formal_derivative_updates_correctly, - bezout_coefficient_0_is_constructed_correctly, - bezout_coefficient_1_is_constructed_correctly, - rppa_updates_correctly, - log_derivative_updates_correctly, - ] - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let constant = |c: u32| circuit_builder.b_constant(c); - let ext_row = |column: RamExtTableColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; - - let bezout_relation_holds = ext_row(BezoutCoefficient0) * ext_row(RunningProductOfRAMP) - + ext_row(BezoutCoefficient1) * ext_row(FormalDerivative) - - constant(1); - - vec![bezout_relation_holds] - } -} - -#[cfg(test)] -pub(crate) mod tests { - use proptest::prelude::*; - use proptest_arbitrary_interop::arb; - use test_strategy::proptest; - - use super::*; - - #[proptest] - fn ram_table_call_can_be_converted_to_table_row( - #[strategy(arb())] ram_table_call: RamTableCall, - ) { - ram_table_call.to_table_row(); - } - - #[test] - fn bezout_coefficient_polynomials_of_empty_ram_table_are_default() { - let (a, b) = RamTable::bezout_coefficient_polynomials_coefficients(&[]); - assert_eq!(a, vec![]); - assert_eq!(b, vec![]); - } - - #[test] - fn bezout_coefficient_polynomials_are_as_expected() { - let rp = bfe_array![1, 2, 3]; - let (a, b) = RamTable::bezout_coefficient_polynomials_coefficients(&rp); - - let expected_a = bfe_array![9, 0x7fff_ffff_7fff_fffc_u64, 0]; - let expected_b = bfe_array![5, 0xffff_fffe_ffff_fffb_u64, 0x7fff_ffff_8000_0002_u64]; - - assert_eq!(expected_a, *a); - assert_eq!(expected_b, *b); - } - - #[proptest] - fn bezout_coefficient_polynomials_agree_with_xgcd( - #[strategy(arb())] - #[filter(#ram_pointers.iter().all_unique())] - ram_pointers: Vec, - ) { - let (a, b) = RamTable::bezout_coefficient_polynomials_coefficients(&ram_pointers); - - let rp = Polynomial::zerofier(&ram_pointers); - let fd = rp.formal_derivative(); - let (_, a_xgcd, b_xgcd) = Polynomial::xgcd(rp, fd); - - let mut a_xgcd = a_xgcd.coefficients; - let mut b_xgcd = b_xgcd.coefficients; - - a_xgcd.resize(ram_pointers.len(), bfe!(0)); - b_xgcd.resize(ram_pointers.len(), bfe!(0)); - - prop_assert_eq!(a, a_xgcd); - prop_assert_eq!(b, b_xgcd); - } - - #[proptest] - fn bezout_coefficients_are_actually_bezout_coefficients( - #[strategy(arb())] - #[filter(!#ram_pointers.is_empty())] - #[filter(#ram_pointers.iter().all_unique())] - ram_pointers: Vec, - ) { - let (a, b) = RamTable::bezout_coefficient_polynomials_coefficients(&ram_pointers); - - let rp = Polynomial::zerofier(&ram_pointers); - let fd = rp.formal_derivative(); - - let [a, b] = [a, b].map(Polynomial::new); - let gcd = rp * a + fd * b; - prop_assert_eq!(Polynomial::one(), gcd); - } -} diff --git a/triton-vm/src/table/u32.rs b/triton-vm/src/table/u32.rs new file mode 100644 index 000000000..7b9f5f7be --- /dev/null +++ b/triton-vm/src/table/u32.rs @@ -0,0 +1,261 @@ +use std::cmp::max; +use std::ops::Mul; + +use air::challenge_id::ChallengeId::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::LookupArg; +use air::table::u32::U32Table; +use air::table_column::MasterBaseTableColumn; +use air::table_column::MasterExtTableColumn; +use air::table_column::U32BaseTableColumn; +use air::table_column::U32BaseTableColumn::*; +use air::table_column::U32ExtTableColumn; +use air::table_column::U32ExtTableColumn::*; +use arbitrary::Arbitrary; +use constraint_circuit::ConstraintCircuitBuilder; +use constraint_circuit::ConstraintCircuitMonad; +use constraint_circuit::DualRowIndicator; +use constraint_circuit::DualRowIndicator::*; +use constraint_circuit::InputIndicator; +use constraint_circuit::SingleRowIndicator; +use constraint_circuit::SingleRowIndicator::*; +use isa::instruction::Instruction; +use ndarray::parallel::prelude::*; +use ndarray::s; +use ndarray::Array1; +use ndarray::Array2; +use ndarray::ArrayView2; +use ndarray::ArrayViewMut2; +use ndarray::Axis; +use num_traits::One; +use num_traits::Zero; +use strum::EnumCount; +use twenty_first::prelude::*; + +use crate::aet::AlgebraicExecutionTrace; +use crate::challenges::Challenges; +use crate::profiler::profiler; +use crate::table::TraceTable; + +/// An executed u32 instruction as well as its operands. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] +pub struct U32TableEntry { + pub instruction: Instruction, + pub left_operand: BFieldElement, + pub right_operand: BFieldElement, +} + +impl U32TableEntry { + pub fn new(instruction: Instruction, left_operand: L, right_operand: R) -> Self + where + L: Into, + R: Into, + { + Self { + instruction, + left_operand: left_operand.into(), + right_operand: right_operand.into(), + } + } + + /// The number of rows this entry contributes to the U32 Table. + pub(crate) fn table_height_contribution(&self) -> u32 { + let lhs = self.left_operand.value(); + let rhs = self.right_operand.value(); + let dominant_operand = match self.instruction { + Instruction::Pow => rhs, // left-hand side doesn't change between rows + _ => max(lhs, rhs), + }; + match dominant_operand { + 0 => 2 - 1, + _ => 2 + dominant_operand.ilog2(), + } + } +} + +impl TraceTable for U32Table { + type FillParam = (); + type FillReturnInfo = (); + + fn fill(mut u32_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { + let mut next_section_start = 0; + for (&u32_table_entry, &multiplicity) in &aet.u32_entries { + let mut first_row = Array2::zeros([1, Self::MainColumn::COUNT]); + first_row[[0, CopyFlag.base_table_index()]] = bfe!(1); + first_row[[0, Bits.base_table_index()]] = bfe!(0); + first_row[[0, BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); + first_row[[0, CI.base_table_index()]] = u32_table_entry.instruction.opcode_b(); + first_row[[0, LHS.base_table_index()]] = u32_table_entry.left_operand; + first_row[[0, RHS.base_table_index()]] = u32_table_entry.right_operand; + first_row[[0, LookupMultiplicity.base_table_index()]] = multiplicity.into(); + let u32_section = u32_section_next_row(first_row); + + let next_section_end = next_section_start + u32_section.nrows(); + u32_table + .slice_mut(s![next_section_start..next_section_end, ..]) + .assign(&u32_section); + next_section_start = next_section_end; + } + } + + fn pad(mut main_table: ArrayViewMut2, table_len: usize) { + let mut padding_row = Array1::zeros([Self::MainColumn::COUNT]); + padding_row[[CI.base_table_index()]] = Instruction::Split.opcode_b(); + padding_row[[BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); + + if table_len > 0 { + let last_row = main_table.row(table_len - 1); + padding_row[[CI.base_table_index()]] = last_row[CI.base_table_index()]; + padding_row[[LHS.base_table_index()]] = last_row[LHS.base_table_index()]; + padding_row[[LhsInv.base_table_index()]] = last_row[LhsInv.base_table_index()]; + padding_row[[Result.base_table_index()]] = last_row[Result.base_table_index()]; + + // In the edge case that the last non-padding row comes from executing instruction + // `lt` on operands 0 and 0, the `Result` column is 0. For the padding section, + // where the `CopyFlag` is always 0, the `Result` needs to be set to 2 instead. + if padding_row[[CI.base_table_index()]] == Instruction::Lt.opcode_b() { + padding_row[[Result.base_table_index()]] = bfe!(2); + } + } + + main_table + .slice_mut(s![table_len.., ..]) + .axis_iter_mut(Axis(0)) + .into_par_iter() + .for_each(|mut row| row.assign(&padding_row)); + } + + fn extend( + base_table: ArrayView2, + mut ext_table: ArrayViewMut2, + challenges: &Challenges, + ) { + profiler!(start "u32 table"); + assert_eq!(Self::MainColumn::COUNT, base_table.ncols()); + assert_eq!(Self::AuxColumn::COUNT, ext_table.ncols()); + assert_eq!(base_table.nrows(), ext_table.nrows()); + + let ci_weight = challenges[U32CiWeight]; + let lhs_weight = challenges[U32LhsWeight]; + let rhs_weight = challenges[U32RhsWeight]; + let result_weight = challenges[U32ResultWeight]; + let lookup_indeterminate = challenges[U32Indeterminate]; + + let mut running_sum_log_derivative = LookupArg::default_initial(); + for row_idx in 0..base_table.nrows() { + let current_row = base_table.row(row_idx); + if current_row[CopyFlag.base_table_index()].is_one() { + let lookup_multiplicity = current_row[LookupMultiplicity.base_table_index()]; + let compressed_row = ci_weight * current_row[CI.base_table_index()] + + lhs_weight * current_row[LHS.base_table_index()] + + rhs_weight * current_row[RHS.base_table_index()] + + result_weight * current_row[Result.base_table_index()]; + running_sum_log_derivative += + lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse(); + } + + let mut extension_row = ext_table.row_mut(row_idx); + extension_row[LookupServerLogDerivative.ext_table_index()] = running_sum_log_derivative; + } + profiler!(stop "u32 table"); + } +} + +fn u32_section_next_row(mut section: Array2) -> Array2 { + let row_idx = section.nrows() - 1; + let current_instruction: Instruction = section[[row_idx, CI.base_table_index()]] + .value() + .try_into() + .expect("Unknown instruction"); + + // Is the last row in this section reached? + if (section[[row_idx, LHS.base_table_index()]].is_zero() + || current_instruction == Instruction::Pow) + && section[[row_idx, RHS.base_table_index()]].is_zero() + { + section[[row_idx, Result.base_table_index()]] = match current_instruction { + Instruction::Split => bfe!(0), + Instruction::Lt => bfe!(2), + Instruction::And => bfe!(0), + Instruction::Log2Floor => bfe!(-1), + Instruction::Pow => bfe!(1), + Instruction::PopCount => bfe!(0), + _ => panic!("Must be u32 instruction, not {current_instruction}."), + }; + + // If instruction `lt` is executed on operands 0 and 0, the result is known to be 0. + // The edge case can be reliably detected by checking whether column `Bits` is 0. + let both_operands_are_0 = section[[row_idx, Bits.base_table_index()]].is_zero(); + if current_instruction == Instruction::Lt && both_operands_are_0 { + section[[row_idx, Result.base_table_index()]] = bfe!(0); + } + + // The right hand side is guaranteed to be 0. However, if the current instruction is + // `pow`, then the left hand side might be non-zero. + let lhs_inv_or_0 = section[[row_idx, LHS.base_table_index()]].inverse_or_zero(); + section[[row_idx, LhsInv.base_table_index()]] = lhs_inv_or_0; + + return section; + } + + let lhs_lsb = bfe!(section[[row_idx, LHS.base_table_index()]].value() % 2); + let rhs_lsb = bfe!(section[[row_idx, RHS.base_table_index()]].value() % 2); + let mut next_row = section.row(row_idx).to_owned(); + next_row[CopyFlag.base_table_index()] = bfe!(0); + next_row[Bits.base_table_index()] += bfe!(1); + next_row[BitsMinus33Inv.base_table_index()] = + (next_row[Bits.base_table_index()] - bfe!(33)).inverse(); + next_row[LHS.base_table_index()] = match current_instruction == Instruction::Pow { + true => section[[row_idx, LHS.base_table_index()]], + false => (section[[row_idx, LHS.base_table_index()]] - lhs_lsb) / bfe!(2), + }; + next_row[RHS.base_table_index()] = + (section[[row_idx, RHS.base_table_index()]] - rhs_lsb) / bfe!(2); + next_row[LookupMultiplicity.base_table_index()] = bfe!(0); + + section.push_row(next_row.view()).unwrap(); + section = u32_section_next_row(section); + let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); + + row[LhsInv.base_table_index()] = row[LHS.base_table_index()].inverse_or_zero(); + row[RhsInv.base_table_index()] = row[RHS.base_table_index()].inverse_or_zero(); + + let next_row_result = next_row[Result.base_table_index()]; + row[Result.base_table_index()] = match current_instruction { + Instruction::Split => next_row_result, + Instruction::Lt => { + match ( + next_row_result.value(), + lhs_lsb.value(), + rhs_lsb.value(), + row[CopyFlag.base_table_index()].value(), + ) { + (0 | 1, _, _, _) => next_row_result, // result already known + (2, 0, 1, _) => bfe!(1), // LHS < RHS + (2, 1, 0, _) => bfe!(0), // LHS > RHS + (2, _, _, 1) => bfe!(0), // LHS == RHS + (2, _, _, 0) => bfe!(2), // result still unknown + _ => panic!("Invalid state"), + } + } + Instruction::And => bfe!(2) * next_row_result + lhs_lsb * rhs_lsb, + Instruction::Log2Floor => { + if row[LHS.base_table_index()].is_zero() { + bfe!(-1) + } else if !next_row[LHS.base_table_index()].is_zero() { + next_row_result + } else { + // LHS != 0 && LHS' == 0 + row[Bits.base_table_index()] + } + } + Instruction::Pow => match rhs_lsb.is_zero() { + true => next_row_result * next_row_result, + false => next_row_result * next_row_result * row[LHS.base_table_index()], + }, + Instruction::PopCount => next_row_result + lhs_lsb, + _ => panic!("Must be u32 instruction, not {current_instruction}."), + }; + + section +} diff --git a/triton-vm/src/table/u32_table.rs b/triton-vm/src/table/u32_table.rs deleted file mode 100644 index da015dff7..000000000 --- a/triton-vm/src/table/u32_table.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::cmp::max; -use std::ops::Mul; - -use arbitrary::Arbitrary; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::InputIndicator; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; -use isa::instruction::Instruction; -use ndarray::parallel::prelude::*; -use ndarray::s; -use ndarray::Array1; -use ndarray::Array2; -use ndarray::ArrayView2; -use ndarray::ArrayViewMut2; -use ndarray::Axis; -use num_traits::One; -use num_traits::Zero; -use strum::EnumCount; -use twenty_first::prelude::*; - -use crate::aet::AlgebraicExecutionTrace; -use crate::profiler::profiler; -use crate::table::challenges::ChallengeId::*; -use crate::table::challenges::Challenges; -use crate::table::cross_table_argument::CrossTableArg; -use crate::table::cross_table_argument::LookupArg; -use crate::table::table_column::MasterBaseTableColumn; -use crate::table::table_column::MasterExtTableColumn; -use crate::table::table_column::U32BaseTableColumn; -use crate::table::table_column::U32BaseTableColumn::*; -use crate::table::table_column::U32ExtTableColumn; -use crate::table::table_column::U32ExtTableColumn::*; - -pub const BASE_WIDTH: usize = U32BaseTableColumn::COUNT; -pub const EXT_WIDTH: usize = U32ExtTableColumn::COUNT; -pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; - -/// An executed u32 instruction as well as its operands. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] -pub struct U32TableEntry { - pub instruction: Instruction, - pub left_operand: BFieldElement, - pub right_operand: BFieldElement, -} - -impl U32TableEntry { - pub fn new(instruction: Instruction, left_operand: L, right_operand: R) -> Self - where - L: Into, - R: Into, - { - Self { - instruction, - left_operand: left_operand.into(), - right_operand: right_operand.into(), - } - } - - /// The number of rows this entry contributes to the U32 Table. - pub(crate) fn table_height_contribution(&self) -> u32 { - let lhs = self.left_operand.value(); - let rhs = self.right_operand.value(); - let dominant_operand = match self.instruction { - Instruction::Pow => rhs, // left-hand side doesn't change between rows - _ => max(lhs, rhs), - }; - match dominant_operand { - 0 => 2 - 1, - _ => 2 + dominant_operand.ilog2(), - } - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct U32Table; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ExtU32Table; - -impl ExtU32Table { - fn instruction_deselector( - instruction_to_select: Instruction, - circuit_builder: &ConstraintCircuitBuilder, - current_instruction: &ConstraintCircuitMonad, - ) -> ConstraintCircuitMonad { - [ - Instruction::Split, - Instruction::Lt, - Instruction::And, - Instruction::Log2Floor, - Instruction::Pow, - Instruction::PopCount, - ] - .into_iter() - .filter(|&instruction| instruction != instruction_to_select) - .map(|instr| current_instruction.clone() - circuit_builder.b_constant(instr.opcode_b())) - .fold(circuit_builder.b_constant(1), ConstraintCircuitMonad::mul) - } - - pub fn initial_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let one = circuit_builder.b_constant(1); - - let copy_flag = circuit_builder.input(BaseRow(CopyFlag.master_base_table_index())); - let lhs = circuit_builder.input(BaseRow(LHS.master_base_table_index())); - let rhs = circuit_builder.input(BaseRow(RHS.master_base_table_index())); - let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); - let result = circuit_builder.input(BaseRow(Result.master_base_table_index())); - let lookup_multiplicity = - circuit_builder.input(BaseRow(LookupMultiplicity.master_base_table_index())); - - let running_sum_log_derivative = - circuit_builder.input(ExtRow(LookupServerLogDerivative.master_ext_table_index())); - - let compressed_row = challenge(U32LhsWeight) * lhs - + challenge(U32RhsWeight) * rhs - + challenge(U32CiWeight) * ci - + challenge(U32ResultWeight) * result; - let if_copy_flag_is_1_then_log_derivative_has_accumulated_first_row = copy_flag.clone() - * (running_sum_log_derivative.clone() * (challenge(U32Indeterminate) - compressed_row) - - lookup_multiplicity); - - let default_initial = circuit_builder.x_constant(LookupArg::default_initial()); - let if_copy_flag_is_0_then_log_derivative_is_default_initial = - (copy_flag - one) * (running_sum_log_derivative - default_initial); - - let running_sum_log_derivative_starts_correctly = - if_copy_flag_is_0_then_log_derivative_is_default_initial - + if_copy_flag_is_1_then_log_derivative_has_accumulated_first_row; - - vec![running_sum_log_derivative_starts_correctly] - } - - pub fn consistency_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let one = || circuit_builder.b_constant(1); - let two = || circuit_builder.b_constant(2); - - let copy_flag = circuit_builder.input(BaseRow(CopyFlag.master_base_table_index())); - let bits = circuit_builder.input(BaseRow(Bits.master_base_table_index())); - let bits_minus_33_inv = - circuit_builder.input(BaseRow(BitsMinus33Inv.master_base_table_index())); - let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); - let lhs = circuit_builder.input(BaseRow(LHS.master_base_table_index())); - let lhs_inv = circuit_builder.input(BaseRow(LhsInv.master_base_table_index())); - let rhs = circuit_builder.input(BaseRow(RHS.master_base_table_index())); - let rhs_inv = circuit_builder.input(BaseRow(RhsInv.master_base_table_index())); - let result = circuit_builder.input(BaseRow(Result.master_base_table_index())); - let lookup_multiplicity = - circuit_builder.input(BaseRow(LookupMultiplicity.master_base_table_index())); - - let instruction_deselector = |instruction_to_select| { - Self::instruction_deselector(instruction_to_select, circuit_builder, &ci) - }; - - let copy_flag_is_bit = copy_flag.clone() * (one() - copy_flag.clone()); - let copy_flag_is_0_or_bits_is_0 = copy_flag.clone() * bits.clone(); - let bits_minus_33_inv_is_inverse_of_bits_minus_33 = - one() - bits_minus_33_inv * (bits - circuit_builder.b_constant(33)); - let lhs_inv_is_0_or_the_inverse_of_lhs = - lhs_inv.clone() * (one() - lhs.clone() * lhs_inv.clone()); - let lhs_is_0_or_lhs_inverse_is_the_inverse_of_lhs = - lhs.clone() * (one() - lhs.clone() * lhs_inv.clone()); - let rhs_inv_is_0_or_the_inverse_of_rhs = - rhs_inv.clone() * (one() - rhs.clone() * rhs_inv.clone()); - let rhs_is_0_or_rhs_inverse_is_the_inverse_of_rhs = - rhs.clone() * (one() - rhs.clone() * rhs_inv.clone()); - let result_is_initialized_correctly_for_lt_with_copy_flag_0 = - instruction_deselector(Instruction::Lt) - * (copy_flag.clone() - one()) - * (one() - lhs.clone() * lhs_inv.clone()) - * (one() - rhs.clone() * rhs_inv.clone()) - * (result.clone() - two()); - let result_is_initialized_correctly_for_lt_with_copy_flag_1 = - instruction_deselector(Instruction::Lt) - * copy_flag.clone() - * (one() - lhs.clone() * lhs_inv.clone()) - * (one() - rhs.clone() * rhs_inv.clone()) - * result.clone(); - let result_is_initialized_correctly_for_and = instruction_deselector(Instruction::And) - * (one() - lhs.clone() * lhs_inv.clone()) - * (one() - rhs.clone() * rhs_inv.clone()) - * result.clone(); - let result_is_initialized_correctly_for_pow = instruction_deselector(Instruction::Pow) - * (one() - rhs * rhs_inv) - * (result.clone() - one()); - let result_is_initialized_correctly_for_log_2_floor = - instruction_deselector(Instruction::Log2Floor) - * (copy_flag.clone() - one()) - * (one() - lhs.clone() * lhs_inv.clone()) - * (result.clone() + one()); - let result_is_initialized_correctly_for_pop_count = - instruction_deselector(Instruction::PopCount) - * (one() - lhs.clone() * lhs_inv.clone()) - * result; - let if_log_2_floor_on_0_then_vm_crashes = instruction_deselector(Instruction::Log2Floor) - * copy_flag.clone() - * (one() - lhs * lhs_inv); - let if_copy_flag_is_0_then_lookup_multiplicity_is_0 = - (copy_flag - one()) * lookup_multiplicity; - - vec![ - copy_flag_is_bit, - copy_flag_is_0_or_bits_is_0, - bits_minus_33_inv_is_inverse_of_bits_minus_33, - lhs_inv_is_0_or_the_inverse_of_lhs, - lhs_is_0_or_lhs_inverse_is_the_inverse_of_lhs, - rhs_inv_is_0_or_the_inverse_of_rhs, - rhs_is_0_or_rhs_inverse_is_the_inverse_of_rhs, - result_is_initialized_correctly_for_lt_with_copy_flag_0, - result_is_initialized_correctly_for_lt_with_copy_flag_1, - result_is_initialized_correctly_for_and, - result_is_initialized_correctly_for_pow, - result_is_initialized_correctly_for_log_2_floor, - result_is_initialized_correctly_for_pop_count, - if_log_2_floor_on_0_then_vm_crashes, - if_copy_flag_is_0_then_lookup_multiplicity_is_0, - ] - } - - pub fn transition_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let challenge = |c| circuit_builder.challenge(c); - let one = || circuit_builder.b_constant(1); - let two = || circuit_builder.b_constant(2); - - let copy_flag = circuit_builder.input(CurrentBaseRow(CopyFlag.master_base_table_index())); - let bits = circuit_builder.input(CurrentBaseRow(Bits.master_base_table_index())); - let ci = circuit_builder.input(CurrentBaseRow(CI.master_base_table_index())); - let lhs = circuit_builder.input(CurrentBaseRow(LHS.master_base_table_index())); - let rhs = circuit_builder.input(CurrentBaseRow(RHS.master_base_table_index())); - let result = circuit_builder.input(CurrentBaseRow(Result.master_base_table_index())); - let running_sum_log_derivative = circuit_builder.input(CurrentExtRow( - LookupServerLogDerivative.master_ext_table_index(), - )); - - let copy_flag_next = circuit_builder.input(NextBaseRow(CopyFlag.master_base_table_index())); - let bits_next = circuit_builder.input(NextBaseRow(Bits.master_base_table_index())); - let ci_next = circuit_builder.input(NextBaseRow(CI.master_base_table_index())); - let lhs_next = circuit_builder.input(NextBaseRow(LHS.master_base_table_index())); - let rhs_next = circuit_builder.input(NextBaseRow(RHS.master_base_table_index())); - let result_next = circuit_builder.input(NextBaseRow(Result.master_base_table_index())); - let lhs_inv_next = circuit_builder.input(NextBaseRow(LhsInv.master_base_table_index())); - let lookup_multiplicity_next = - circuit_builder.input(NextBaseRow(LookupMultiplicity.master_base_table_index())); - let running_sum_log_derivative_next = circuit_builder.input(NextExtRow( - LookupServerLogDerivative.master_ext_table_index(), - )); - - let instruction_deselector = |instruction_to_select: Instruction| { - Self::instruction_deselector(instruction_to_select, circuit_builder, &ci_next) - }; - - // helpful aliases - let ci_is_pow = ci.clone() - circuit_builder.b_constant(Instruction::Pow.opcode_b()); - let lhs_lsb = lhs.clone() - two() * lhs_next.clone(); - let rhs_lsb = rhs.clone() - two() * rhs_next.clone(); - - // general constraints - let if_copy_flag_next_is_1_then_lhs_is_0_or_ci_is_pow = - copy_flag_next.clone() * lhs.clone() * ci_is_pow.clone(); - let if_copy_flag_next_is_1_then_rhs_is_0 = copy_flag_next.clone() * rhs.clone(); - let if_copy_flag_next_is_0_then_ci_stays = - (copy_flag_next.clone() - one()) * (ci_next.clone() - ci); - let if_copy_flag_next_is_0_and_lhs_next_is_nonzero_and_ci_not_pow_then_bits_increases_by_1 = - (copy_flag_next.clone() - one()) - * lhs.clone() - * ci_is_pow.clone() - * (bits_next.clone() - bits.clone() - one()); - let if_copy_flag_next_is_0_and_rhs_next_is_nonzero_then_bits_increases_by_1 = - (copy_flag_next.clone() - one()) * rhs * (bits_next - bits.clone() - one()); - let if_copy_flag_next_is_0_and_ci_not_pow_then_lhs_lsb_is_a_bit = (copy_flag_next.clone() - - one()) - * ci_is_pow - * lhs_lsb.clone() - * (lhs_lsb.clone() - one()); - let if_copy_flag_next_is_0_then_rhs_lsb_is_a_bit = - (copy_flag_next.clone() - one()) * rhs_lsb.clone() * (rhs_lsb.clone() - one()); - - // instruction lt - let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_0_then_result_is_0 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * (result_next.clone() - one()) - * (result_next.clone() - two()) - * result.clone(); - let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_1_then_result_is_1 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * result_next.clone() - * (result_next.clone() - two()) - * (result.clone() - one()); - let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_0_then_result_is_0 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * result_next.clone() - * (result_next.clone() - one()) - * (lhs_lsb.clone() - one()) - * rhs_lsb.clone() - * (result.clone() - one()); - let if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_1_then_result_is_1 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * result_next.clone() - * (result_next.clone() - one()) - * lhs_lsb.clone() - * (rhs_lsb.clone() - one()) - * result.clone(); - let if_copy_flag_next_is_0_and_ci_is_lt_and_result_still_not_known_then_result_is_2 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * result_next.clone() - * (result_next.clone() - one()) - * (one() - lhs_lsb.clone() - rhs_lsb.clone() - + two() * lhs_lsb.clone() * rhs_lsb.clone()) - * (copy_flag.clone() - one()) - * (result.clone() - two()); - let if_copy_flag_next_is_0_and_ci_is_lt_and_copy_flag_dictates_result_then_result_is_0 = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Lt) - * result_next.clone() - * (result_next.clone() - one()) - * (one() - lhs_lsb.clone() - rhs_lsb.clone() - + two() * lhs_lsb.clone() * rhs_lsb.clone()) - * copy_flag - * result.clone(); - - // instruction and - let if_copy_flag_next_is_0_and_ci_is_and_then_results_updates_correctly = (copy_flag_next - .clone() - - one()) - * instruction_deselector(Instruction::And) - * (result.clone() - two() * result_next.clone() - lhs_lsb.clone() * rhs_lsb.clone()); - - // instruction log_2_floor - let if_copy_flag_next_is_0_and_ci_is_log_2_floor_lhs_next_0_for_first_time_then_set_result = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Log2Floor) - * (one() - lhs_next.clone() * lhs_inv_next) - * lhs.clone() - * (result.clone() - bits); - let if_copy_flag_next_is_0_and_ci_is_log_2_floor_and_lhs_next_not_0_then_copy_result = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Log2Floor) - * lhs_next.clone() - * (result_next.clone() - result.clone()); - - // instruction pow - let if_copy_flag_next_is_0_and_ci_is_pow_then_lhs_remains_unchanged = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Pow) - * (lhs_next.clone() - lhs.clone()); - - let if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_0_then_result_squares = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Pow) - * (rhs_lsb.clone() - one()) - * (result.clone() - result_next.clone() * result_next.clone()); - - let if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_1_then_result_squares_and_mults = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::Pow) - * rhs_lsb - * (result.clone() - result_next.clone() * result_next.clone() * lhs); - - let if_copy_flag_next_is_0_and_ci_is_pop_count_then_result_increases_by_lhs_lsb = - (copy_flag_next.clone() - one()) - * instruction_deselector(Instruction::PopCount) - * (result - result_next.clone() - lhs_lsb); - - // running sum for Lookup Argument with Processor Table - let if_copy_flag_next_is_0_then_running_sum_log_derivative_stays = (copy_flag_next.clone() - - one()) - * (running_sum_log_derivative_next.clone() - running_sum_log_derivative.clone()); - - let compressed_row_next = challenge(U32CiWeight) * ci_next - + challenge(U32LhsWeight) * lhs_next - + challenge(U32RhsWeight) * rhs_next - + challenge(U32ResultWeight) * result_next; - let if_copy_flag_next_is_1_then_running_sum_log_derivative_accumulates_next_row = - copy_flag_next - * ((running_sum_log_derivative_next - running_sum_log_derivative) - * (challenge(U32Indeterminate) - compressed_row_next) - - lookup_multiplicity_next); - - vec![ - if_copy_flag_next_is_1_then_lhs_is_0_or_ci_is_pow, - if_copy_flag_next_is_1_then_rhs_is_0, - if_copy_flag_next_is_0_then_ci_stays, - if_copy_flag_next_is_0_and_lhs_next_is_nonzero_and_ci_not_pow_then_bits_increases_by_1, - if_copy_flag_next_is_0_and_rhs_next_is_nonzero_then_bits_increases_by_1, - if_copy_flag_next_is_0_and_ci_not_pow_then_lhs_lsb_is_a_bit, - if_copy_flag_next_is_0_then_rhs_lsb_is_a_bit, - if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_0_then_result_is_0, - if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_1_then_result_is_1, - if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_0_then_result_is_0, - if_copy_flag_next_is_0_and_ci_is_lt_and_result_next_is_2_and_lt_is_1_then_result_is_1, - if_copy_flag_next_is_0_and_ci_is_lt_and_result_still_not_known_then_result_is_2, - if_copy_flag_next_is_0_and_ci_is_lt_and_copy_flag_dictates_result_then_result_is_0, - if_copy_flag_next_is_0_and_ci_is_and_then_results_updates_correctly, - if_copy_flag_next_is_0_and_ci_is_log_2_floor_lhs_next_0_for_first_time_then_set_result, - if_copy_flag_next_is_0_and_ci_is_log_2_floor_and_lhs_next_not_0_then_copy_result, - if_copy_flag_next_is_0_and_ci_is_pow_then_lhs_remains_unchanged, - if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_0_then_result_squares, - if_copy_flag_next_is_0_and_ci_is_pow_and_rhs_lsb_is_1_then_result_squares_and_mults, - if_copy_flag_next_is_0_and_ci_is_pop_count_then_result_increases_by_lhs_lsb, - if_copy_flag_next_is_0_then_running_sum_log_derivative_stays, - if_copy_flag_next_is_1_then_running_sum_log_derivative_accumulates_next_row, - ] - } - - pub fn terminal_constraints( - circuit_builder: &ConstraintCircuitBuilder, - ) -> Vec> { - let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); - let lhs = circuit_builder.input(BaseRow(LHS.master_base_table_index())); - let rhs = circuit_builder.input(BaseRow(RHS.master_base_table_index())); - - let lhs_is_0_or_ci_is_pow = - lhs * (ci - circuit_builder.b_constant(Instruction::Pow.opcode_b())); - let rhs_is_0 = rhs; - - vec![lhs_is_0_or_ci_is_pow, rhs_is_0] - } -} - -impl U32Table { - pub fn fill_trace(u32_table: &mut ArrayViewMut2, aet: &AlgebraicExecutionTrace) { - let mut next_section_start = 0; - for (&u32_table_entry, &multiplicity) in &aet.u32_entries { - let mut first_row = Array2::zeros([1, BASE_WIDTH]); - first_row[[0, CopyFlag.base_table_index()]] = bfe!(1); - first_row[[0, Bits.base_table_index()]] = bfe!(0); - first_row[[0, BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); - first_row[[0, CI.base_table_index()]] = u32_table_entry.instruction.opcode_b(); - first_row[[0, LHS.base_table_index()]] = u32_table_entry.left_operand; - first_row[[0, RHS.base_table_index()]] = u32_table_entry.right_operand; - first_row[[0, LookupMultiplicity.base_table_index()]] = multiplicity.into(); - let u32_section = Self::u32_section_next_row(first_row); - - let next_section_end = next_section_start + u32_section.nrows(); - u32_table - .slice_mut(s![next_section_start..next_section_end, ..]) - .assign(&u32_section); - next_section_start = next_section_end; - } - } - - fn u32_section_next_row(mut section: Array2) -> Array2 { - let row_idx = section.nrows() - 1; - let current_instruction: Instruction = section[[row_idx, CI.base_table_index()]] - .value() - .try_into() - .expect("Unknown instruction"); - - // Is the last row in this section reached? - if (section[[row_idx, LHS.base_table_index()]].is_zero() - || current_instruction == Instruction::Pow) - && section[[row_idx, RHS.base_table_index()]].is_zero() - { - section[[row_idx, Result.base_table_index()]] = match current_instruction { - Instruction::Split => bfe!(0), - Instruction::Lt => bfe!(2), - Instruction::And => bfe!(0), - Instruction::Log2Floor => bfe!(-1), - Instruction::Pow => bfe!(1), - Instruction::PopCount => bfe!(0), - _ => panic!("Must be u32 instruction, not {current_instruction}."), - }; - - // If instruction `lt` is executed on operands 0 and 0, the result is known to be 0. - // The edge case can be reliably detected by checking whether column `Bits` is 0. - let both_operands_are_0 = section[[row_idx, Bits.base_table_index()]].is_zero(); - if current_instruction == Instruction::Lt && both_operands_are_0 { - section[[row_idx, Result.base_table_index()]] = bfe!(0); - } - - // The right hand side is guaranteed to be 0. However, if the current instruction is - // `pow`, then the left hand side might be non-zero. - let lhs_inv_or_0 = section[[row_idx, LHS.base_table_index()]].inverse_or_zero(); - section[[row_idx, LhsInv.base_table_index()]] = lhs_inv_or_0; - - return section; - } - - let lhs_lsb = bfe!(section[[row_idx, LHS.base_table_index()]].value() % 2); - let rhs_lsb = bfe!(section[[row_idx, RHS.base_table_index()]].value() % 2); - let mut next_row = section.row(row_idx).to_owned(); - next_row[CopyFlag.base_table_index()] = bfe!(0); - next_row[Bits.base_table_index()] += bfe!(1); - next_row[BitsMinus33Inv.base_table_index()] = - (next_row[Bits.base_table_index()] - bfe!(33)).inverse(); - next_row[LHS.base_table_index()] = match current_instruction == Instruction::Pow { - true => section[[row_idx, LHS.base_table_index()]], - false => (section[[row_idx, LHS.base_table_index()]] - lhs_lsb) / bfe!(2), - }; - next_row[RHS.base_table_index()] = - (section[[row_idx, RHS.base_table_index()]] - rhs_lsb) / bfe!(2); - next_row[LookupMultiplicity.base_table_index()] = bfe!(0); - - section.push_row(next_row.view()).unwrap(); - section = Self::u32_section_next_row(section); - let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - - row[LhsInv.base_table_index()] = row[LHS.base_table_index()].inverse_or_zero(); - row[RhsInv.base_table_index()] = row[RHS.base_table_index()].inverse_or_zero(); - - let next_row_result = next_row[Result.base_table_index()]; - row[Result.base_table_index()] = match current_instruction { - Instruction::Split => next_row_result, - Instruction::Lt => { - match ( - next_row_result.value(), - lhs_lsb.value(), - rhs_lsb.value(), - row[CopyFlag.base_table_index()].value(), - ) { - (0 | 1, _, _, _) => next_row_result, // result already known - (2, 0, 1, _) => bfe!(1), // LHS < RHS - (2, 1, 0, _) => bfe!(0), // LHS > RHS - (2, _, _, 1) => bfe!(0), // LHS == RHS - (2, _, _, 0) => bfe!(2), // result still unknown - _ => panic!("Invalid state"), - } - } - Instruction::And => bfe!(2) * next_row_result + lhs_lsb * rhs_lsb, - Instruction::Log2Floor => { - if row[LHS.base_table_index()].is_zero() { - bfe!(-1) - } else if !next_row[LHS.base_table_index()].is_zero() { - next_row_result - } else { - // LHS != 0 && LHS' == 0 - row[Bits.base_table_index()] - } - } - Instruction::Pow => match rhs_lsb.is_zero() { - true => next_row_result * next_row_result, - false => next_row_result * next_row_result * row[LHS.base_table_index()], - }, - Instruction::PopCount => next_row_result + lhs_lsb, - _ => panic!("Must be u32 instruction, not {current_instruction}."), - }; - - section - } - - pub fn pad_trace(mut u32_table: ArrayViewMut2, u32_table_len: usize) { - let mut padding_row = Array1::zeros([BASE_WIDTH]); - padding_row[[CI.base_table_index()]] = Instruction::Split.opcode_b(); - padding_row[[BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); - - if u32_table_len > 0 { - let last_row = u32_table.row(u32_table_len - 1); - padding_row[[CI.base_table_index()]] = last_row[CI.base_table_index()]; - padding_row[[LHS.base_table_index()]] = last_row[LHS.base_table_index()]; - padding_row[[LhsInv.base_table_index()]] = last_row[LhsInv.base_table_index()]; - padding_row[[Result.base_table_index()]] = last_row[Result.base_table_index()]; - - // In the edge case that the last non-padding row comes from executing instruction - // `lt` on operands 0 and 0, the `Result` column is 0. For the padding section, - // where the `CopyFlag` is always 0, the `Result` needs to be set to 2 instead. - if padding_row[[CI.base_table_index()]] == Instruction::Lt.opcode_b() { - padding_row[[Result.base_table_index()]] = bfe!(2); - } - } - - u32_table - .slice_mut(s![u32_table_len.., ..]) - .axis_iter_mut(Axis(0)) - .into_par_iter() - .for_each(|mut row| row.assign(&padding_row)); - } - - pub fn extend( - base_table: ArrayView2, - mut ext_table: ArrayViewMut2, - challenges: &Challenges, - ) { - profiler!(start "u32 table"); - assert_eq!(BASE_WIDTH, base_table.ncols()); - assert_eq!(EXT_WIDTH, ext_table.ncols()); - assert_eq!(base_table.nrows(), ext_table.nrows()); - - let ci_weight = challenges[U32CiWeight]; - let lhs_weight = challenges[U32LhsWeight]; - let rhs_weight = challenges[U32RhsWeight]; - let result_weight = challenges[U32ResultWeight]; - let lookup_indeterminate = challenges[U32Indeterminate]; - - let mut running_sum_log_derivative = LookupArg::default_initial(); - for row_idx in 0..base_table.nrows() { - let current_row = base_table.row(row_idx); - if current_row[CopyFlag.base_table_index()].is_one() { - let lookup_multiplicity = current_row[LookupMultiplicity.base_table_index()]; - let compressed_row = ci_weight * current_row[CI.base_table_index()] - + lhs_weight * current_row[LHS.base_table_index()] - + rhs_weight * current_row[RHS.base_table_index()] - + result_weight * current_row[Result.base_table_index()]; - running_sum_log_derivative += - lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse(); - } - - let mut extension_row = ext_table.row_mut(row_idx); - extension_row[LookupServerLogDerivative.ext_table_index()] = running_sum_log_derivative; - } - profiler!(stop "u32 table"); - } -} diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index e2d349548..ba85bd4dd 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -5,6 +5,11 @@ use std::fmt::Formatter; use std::fmt::Result as FmtResult; use std::ops::Range; +use air::table::hash::PermutationTrace; +use air::table::processor::ProcessorTable; +use air::table::processor::NUM_HELPER_VARIABLE_REGISTERS; +use air::table_column::*; +use air::AIR; use arbitrary::Arbitrary; use isa::error::InstructionError; use isa::instruction::AnInstruction::*; @@ -19,6 +24,7 @@ use num_traits::One; use num_traits::Zero; use serde::Deserialize; use serde::Serialize; +use strum::EnumCount; use twenty_first::math::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; use twenty_first::util_types::algebraic_hasher::Domain; @@ -28,20 +34,15 @@ use crate::error::VMError; use crate::execution_trace_profiler::ExecutionTraceProfile; use crate::execution_trace_profiler::ExecutionTraceProfiler; use crate::profiler::profiler; -use crate::table::hash_table::PermutationTrace; -use crate::table::op_stack_table::OpStackTableEntry; -use crate::table::processor_table; -use crate::table::ram_table::RamTableCall; -use crate::table::table_column::*; -use crate::table::u32_table::U32TableEntry; +use crate::table::op_stack::OpStackTableEntry; +use crate::table::processor; +use crate::table::ram::RamTableCall; +use crate::table::u32::U32TableEntry; use crate::vm::CoProcessorCall::*; type VMResult = Result; type InstructionResult = Result; -/// The number of helper variable registers -pub const NUM_HELPER_VARIABLE_REGISTERS: usize = 6; - #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary)] pub struct VM; @@ -1050,7 +1051,7 @@ impl VMState { pub fn to_processor_row(&self) -> Array1 { use isa::instruction::InstructionBit; use ProcessorBaseTableColumn::*; - let mut processor_row = Array1::zeros(processor_table::BASE_WIDTH); + let mut processor_row = Array1::zeros(::MainColumn::COUNT); let current_instruction = self.current_instruction().unwrap_or(Nop); let helper_variables = self.derive_helper_variables(); @@ -1390,6 +1391,7 @@ pub(crate) mod tests { use std::ops::BitAnd; use std::ops::BitXor; + use air::table::TableId; use assert2::assert; use assert2::let_assert; use isa::instruction::AnInstruction; @@ -1416,7 +1418,6 @@ pub(crate) mod tests { use crate::shared_tests::LeavedMerkleTreeTestData; use crate::shared_tests::ProgramAndInput; use crate::shared_tests::DEFAULT_LOG2_FRI_EXPANSION_FACTOR_FOR_TESTS; - use crate::table::master_table::TableId; use super::*; From aae5ccd4e5ae4ae5fff78685e2c6d7351dfe1ddc Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 3 Sep 2024 13:33:51 +0200 Subject: [PATCH 07/15] refactor!: Move constraint builder to own crate Use a build script to generate the required constraint code. This breaks the circular dependency of Triton VM. changelog: ignore --- Cargo.toml | 9 +- triton-air/src/table/hash.rs | 4 +- triton-air/src/table/processor.rs | 2 +- triton-constraint-builder/Cargo.toml | 35 ++++ .../src/codegen.rs | 66 +++---- .../src/lib.rs | 124 ++++++++++-- .../src}/substitutions.rs | 78 +++++--- triton-constraint-circuit/src/lib.rs | 2 +- triton-isa/src/instruction.rs | 2 +- triton-isa/src/parser.rs | 12 +- triton-vm/Cargo.toml | 12 +- triton-vm/build.rs | 38 ++++ triton-vm/src/air.rs | 68 +++---- triton-vm/src/air/tasm_air_constraints.rs | 21 -- triton-vm/src/lib.rs | 19 +- triton-vm/src/{air => }/memory_layout.rs | 21 +- triton-vm/src/stark.rs | 70 +++---- triton-vm/src/table.rs | 138 +++----------- triton-vm/src/table/cascade.rs | 72 ++++--- triton-vm/src/table/constraints.rs | 112 ----------- triton-vm/src/table/degree_lowering.rs | 5 + triton-vm/src/table/degree_lowering_table.rs | 35 ---- triton-vm/src/table/extension_table.rs | 41 +--- triton-vm/src/table/hash.rs | 20 +- triton-vm/src/table/jump_stack.rs | 61 +++--- triton-vm/src/table/lookup.rs | 58 +++--- triton-vm/src/table/master_table.rs | 92 ++++----- triton-vm/src/table/op_stack.rs | 78 ++++---- triton-vm/src/table/processor.rs | 179 +++++++++--------- triton-vm/src/table/program.rs | 3 - triton-vm/src/table/ram.rs | 4 - triton-vm/src/table/u32.rs | 134 ++++++------- triton-vm/src/vm.rs | 1 - 33 files changed, 752 insertions(+), 864 deletions(-) create mode 100644 triton-constraint-builder/Cargo.toml rename triton-vm/src/codegen/constraints.rs => triton-constraint-builder/src/codegen.rs (94%) rename triton-vm/src/codegen/mod.rs => triton-constraint-builder/src/lib.rs (62%) rename {triton-vm/src/codegen => triton-constraint-builder/src}/substitutions.rs (86%) create mode 100644 triton-vm/build.rs delete mode 100644 triton-vm/src/air/tasm_air_constraints.rs rename triton-vm/src/{air => }/memory_layout.rs (93%) delete mode 100644 triton-vm/src/table/constraints.rs create mode 100644 triton-vm/src/table/degree_lowering.rs delete mode 100644 triton-vm/src/table/degree_lowering_table.rs diff --git a/Cargo.toml b/Cargo.toml index 83c85a7b7..55cb8d41b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,11 @@ [workspace] -members = ["triton-air", "triton-constraint-circuit", "triton-isa", "triton-vm"] +members = [ + "triton-air", + "triton-constraint-builder", + "triton-constraint-circuit", + "triton-isa", + "triton-vm", +] resolver = "2" [profile.test] @@ -27,6 +33,7 @@ anyhow = "1.0" arbitrary = { version = "1", features = ["derive"] } assert2 = "0.3" colored = "2.1" +constraint-builder = { path = "triton-constraint-builder", package = "triton-constraint-builder" } constraint-circuit = { path = "triton-constraint-circuit", package = "triton-constraint-circuit" } clap = { version = "4", features = ["derive", "cargo", "wrap_help", "unicode", "string"] } criterion = { version = "0.5", features = ["html_reports"] } diff --git a/triton-air/src/table/hash.rs b/triton-air/src/table/hash.rs index b7ce592a1..89fac3fc3 100644 --- a/triton-air/src/table/hash.rs +++ b/triton-air/src/table/hash.rs @@ -634,7 +634,7 @@ impl AIR for HashTable { ), cascade_log_derivative_init_circuit( Self::MainColumn::State1MidLowLkIn, - Self::MainColumn::State2MidHighLkIn, + Self::MainColumn::State1MidLowLkOut, Self::AuxColumn::CascadeState1MidLowClientLogDerivative, ), cascade_log_derivative_init_circuit( @@ -643,7 +643,7 @@ impl AIR for HashTable { Self::AuxColumn::CascadeState1LowestClientLogDerivative, ), cascade_log_derivative_init_circuit( - Self::MainColumn::State2LowestLkIn, + Self::MainColumn::State2HighestLkIn, Self::MainColumn::State2HighestLkOut, Self::AuxColumn::CascadeState2HighestClientLogDerivative, ), diff --git a/triton-air/src/table/processor.rs b/triton-air/src/table/processor.rs index 628ec3537..d319d72d5 100644 --- a/triton-air/src/table/processor.rs +++ b/triton-air/src/table/processor.rs @@ -148,7 +148,7 @@ impl ProcessorTable { /// Panics if the index is out of bounds. pub fn op_stack_column_by_index(index: usize) -> ProcessorBaseTableColumn { assert!( - OpStackElement::COUNT < index, + index < OpStackElement::COUNT, "Op Stack column index must be in [0, 15], not {index}" ); diff --git a/triton-constraint-builder/Cargo.toml b/triton-constraint-builder/Cargo.toml new file mode 100644 index 000000000..bc07adc4d --- /dev/null +++ b/triton-constraint-builder/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "triton-constraint-builder" +description = """ +Emits efficient code from Triton VM's AIR. +""" + +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +readme.workspace = true + +[dependencies] +air.workspace = true +arbitrary.workspace = true +constraint-circuit.workspace = true +isa.workspace = true +itertools.workspace = true +prettyplease.workspace = true +proc-macro2.workspace = true +quote.workspace = true +strum.workspace = true +syn.workspace = true +twenty-first.workspace = true + +[dev-dependencies] +proptest.workspace = true +proptest-arbitrary-interop.workspace = true +test-strategy.workspace = true + +[lints] +workspace = true diff --git a/triton-vm/src/codegen/constraints.rs b/triton-constraint-builder/src/codegen.rs similarity index 94% rename from triton-vm/src/codegen/constraints.rs rename to triton-constraint-builder/src/codegen.rs index 2bdbf33d7..ff8d9edae 100644 --- a/triton-vm/src/codegen/constraints.rs +++ b/triton-constraint-builder/src/codegen.rs @@ -17,9 +17,9 @@ use quote::ToTokens; use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; -use crate::codegen::Constraints; +use crate::Constraints; -pub(crate) trait Codegen { +pub trait Codegen { fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream; fn tokenize_bfe(bfe: BFieldElement) -> TokenStream { @@ -34,7 +34,7 @@ pub(crate) trait Codegen { } #[derive(Debug, Default, Clone, Eq, PartialEq)] -pub(crate) struct RustBackend { +pub struct RustBackend { /// All [circuit] IDs known to be in scope. /// /// [circuit]: triton_vm::table::circuit::ConstraintCircuit @@ -42,7 +42,7 @@ pub(crate) struct RustBackend { } #[derive(Debug, Default, Clone, Eq, PartialEq)] -pub(crate) struct TasmBackend { +pub struct TasmBackend { /// All [circuit] IDs known to be processed and stored to memory. /// /// [circuit]: triton_vm::table::circuit::ConstraintCircuit @@ -72,7 +72,6 @@ impl Codegen for RustBackend { let (term_constraint_degrees, term_constraints_bfe, term_constraints_xfe) = Self::tokenize_circuits(&constraints.term()); - let uses = Self::uses(); let evaluable_over_base_field = Self::generate_evaluable_implementation_over_field( &init_constraints_bfe, &cons_constraints_bfe, @@ -89,20 +88,24 @@ impl Codegen for RustBackend { ); let quotient_trait_impl = quote!( - impl Quotientable for MasterExtTable { - const NUM_INITIAL_CONSTRAINTS: usize = #num_init_constraints; - const NUM_CONSISTENCY_CONSTRAINTS: usize = #num_cons_constraints; - const NUM_TRANSITION_CONSTRAINTS: usize = #num_tran_constraints; - const NUM_TERMINAL_CONSTRAINTS: usize = #num_term_constraints; + impl MasterExtTable { + pub const NUM_INITIAL_CONSTRAINTS: usize = #num_init_constraints; + pub const NUM_CONSISTENCY_CONSTRAINTS: usize = #num_cons_constraints; + pub const NUM_TRANSITION_CONSTRAINTS: usize = #num_tran_constraints; + pub const NUM_TERMINAL_CONSTRAINTS: usize = #num_term_constraints; + pub const NUM_CONSTRAINTS: usize = Self::NUM_INITIAL_CONSTRAINTS + + Self::NUM_CONSISTENCY_CONSTRAINTS + + Self::NUM_TRANSITION_CONSTRAINTS + + Self::NUM_TERMINAL_CONSTRAINTS; #[allow(unused_variables)] - fn initial_quotient_degree_bounds(interpolant_degree: isize) -> Vec { + pub fn initial_quotient_degree_bounds(interpolant_degree: isize) -> Vec { let zerofier_degree = 1; [#init_constraint_degrees].to_vec() } #[allow(unused_variables)] - fn consistency_quotient_degree_bounds( + pub fn consistency_quotient_degree_bounds( interpolant_degree: isize, padded_height: usize, ) -> Vec { @@ -111,7 +114,7 @@ impl Codegen for RustBackend { } #[allow(unused_variables)] - fn transition_quotient_degree_bounds( + pub fn transition_quotient_degree_bounds( interpolant_degree: isize, padded_height: usize, ) -> Vec { @@ -120,7 +123,7 @@ impl Codegen for RustBackend { } #[allow(unused_variables)] - fn terminal_quotient_degree_bounds(interpolant_degree: isize) -> Vec { + pub fn terminal_quotient_degree_bounds(interpolant_degree: isize) -> Vec { let zerofier_degree = 1; [#term_constraint_degrees].to_vec() } @@ -128,7 +131,6 @@ impl Codegen for RustBackend { ); quote!( - #uses #evaluable_over_base_field #evaluable_over_ext_field #quotient_trait_impl @@ -137,19 +139,6 @@ impl Codegen for RustBackend { } impl RustBackend { - fn uses() -> TokenStream { - quote!( - use ndarray::ArrayView1; - use twenty_first::prelude::BFieldElement; - use twenty_first::prelude::XFieldElement; - - use crate::table::challenges::Challenges; - use crate::table::extension_table::Evaluable; - use crate::table::extension_table::Quotientable; - use crate::table::master_table::MasterExtTable; - ) - } - fn generate_evaluable_implementation_over_field( init_constraints: &TokenStream, cons_constraints: &TokenStream, @@ -373,14 +362,16 @@ impl RustBackend { } } +/// The minimal required size of a memory page in [`BFieldElement`]s. +pub const MEM_PAGE_SIZE: usize = 1 << 32; + /// An offset from the [memory layout][layout]'s `free_mem_page_ptr`, in number of /// [`XFieldElement`]s. Indicates the start of the to-be-returned array. /// /// [layout]: memory_layout::IntegralMemoryLayout const OUT_ARRAY_OFFSET: usize = { - let mem_page_size = crate::air::memory_layout::MEM_PAGE_SIZE; let max_num_words_for_evaluated_constraints = 1 << 16; // magic! - let out_array_offset_in_words = mem_page_size - max_num_words_for_evaluated_constraints; + let out_array_offset_in_words = MEM_PAGE_SIZE - max_num_words_for_evaluated_constraints; assert!(out_array_offset_in_words % EXTENSION_DEGREE == 0); out_array_offset_in_words / EXTENSION_DEGREE }; @@ -554,13 +545,10 @@ impl TasmBackend { quote!( use twenty_first::prelude::BFieldCodec; use twenty_first::prelude::BFieldElement; - use crate::instruction::LabelledInstruction; - use crate::Program; - use crate::air::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; - use crate::air::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; - // for rustdoc – https://github.com/rust-lang/rust/issues/74563 - #[allow(unused_imports)] - use crate::table::extension_table::Quotientable; + use isa::instruction::LabelledInstruction; + use isa::program::Program; + use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; + use crate::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; ) } @@ -596,7 +584,7 @@ impl TasmBackend { (respectively and in this order) correspond to the evaluations of the initial, consistency, transition, and terminal constraints. - [integral]: crate::air::memory_layout::IntegralMemoryLayout::is_integral + [integral]: crate::memory_layout::IntegralMemoryLayout::is_integral [xfe]: twenty_first::prelude::XFieldElement [total]: crate::table::master_table::MasterExtTable::NUM_CONSTRAINTS [init]: crate::table::master_table::MasterExtTable::NUM_INITIAL_CONSTRAINTS @@ -638,7 +626,7 @@ impl TasmBackend { (respectively and in this order) correspond to the evaluations of the initial, consistency, transition, and terminal constraints. - [integral]: crate::air::memory_layout::IntegralMemoryLayout::is_integral + [integral]: crate::memory_layout::IntegralMemoryLayout::is_integral [xfe]: twenty_first::prelude::XFieldElement [total]: crate::table::master_table::MasterExtTable::NUM_CONSTRAINTS [init]: crate::table::master_table::MasterExtTable::NUM_INITIAL_CONSTRAINTS diff --git a/triton-vm/src/codegen/mod.rs b/triton-constraint-builder/src/lib.rs similarity index 62% rename from triton-vm/src/codegen/mod.rs rename to triton-constraint-builder/src/lib.rs index 507caa23a..fefbfdd6a 100644 --- a/triton-vm/src/codegen/mod.rs +++ b/triton-constraint-builder/src/lib.rs @@ -1,4 +1,16 @@ +use air::cross_table_argument::GrandCrossTableArg; +use air::table::cascade::CascadeTable; +use air::table::hash::HashTable; +use air::table::jump_stack::JumpStackTable; +use air::table::lookup::LookupTable; +use air::table::op_stack::OpStackTable; +use air::table::processor::ProcessorTable; +use air::table::program::ProgramTable; +use air::table::ram::RamTable; +use air::table::u32::U32Table; +use air::AIR; use constraint_circuit::ConstraintCircuit; +use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DegreeLoweringInfo; use constraint_circuit::DualRowIndicator; @@ -8,13 +20,13 @@ use itertools::Itertools; use proc_macro2::TokenStream; use std::fs::write; -use crate::codegen::constraints::Codegen; -use crate::codegen::constraints::RustBackend; -use crate::codegen::constraints::TasmBackend; -use crate::codegen::substitutions::AllSubstitutions; -use crate::codegen::substitutions::Substitutions; +use crate::codegen::Codegen; +use crate::codegen::RustBackend; +use crate::codegen::TasmBackend; +use crate::substitutions::AllSubstitutions; +use crate::substitutions::Substitutions; -mod constraints; +pub mod codegen; mod substitutions; pub fn gen(mut constraints: Constraints, info: DegreeLoweringInfo) { @@ -40,7 +52,7 @@ fn write_code_to_file(code: TokenStream, file_name: &str) { } #[derive(Debug, Clone)] -pub(crate) struct Constraints { +pub struct Constraints { pub init: Vec>, pub cons: Vec>, pub tran: Vec>, @@ -48,11 +60,88 @@ pub(crate) struct Constraints { } impl Constraints { + pub fn all() -> Constraints { + Constraints { + init: Self::initial_constraints(), + cons: Self::consistency_constraints(), + tran: Self::transition_constraints(), + term: Self::terminal_constraints(), + } + } + + pub fn initial_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ProgramTable::initial_constraints(&circuit_builder), + ProcessorTable::initial_constraints(&circuit_builder), + OpStackTable::initial_constraints(&circuit_builder), + RamTable::initial_constraints(&circuit_builder), + JumpStackTable::initial_constraints(&circuit_builder), + HashTable::initial_constraints(&circuit_builder), + CascadeTable::initial_constraints(&circuit_builder), + LookupTable::initial_constraints(&circuit_builder), + U32Table::initial_constraints(&circuit_builder), + GrandCrossTableArg::initial_constraints(&circuit_builder), + ] + .concat() + } + + pub fn consistency_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ProgramTable::consistency_constraints(&circuit_builder), + ProcessorTable::consistency_constraints(&circuit_builder), + OpStackTable::consistency_constraints(&circuit_builder), + RamTable::consistency_constraints(&circuit_builder), + JumpStackTable::consistency_constraints(&circuit_builder), + HashTable::consistency_constraints(&circuit_builder), + CascadeTable::consistency_constraints(&circuit_builder), + LookupTable::consistency_constraints(&circuit_builder), + U32Table::consistency_constraints(&circuit_builder), + GrandCrossTableArg::consistency_constraints(&circuit_builder), + ] + .concat() + } + + pub fn transition_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ProgramTable::transition_constraints(&circuit_builder), + ProcessorTable::transition_constraints(&circuit_builder), + OpStackTable::transition_constraints(&circuit_builder), + RamTable::transition_constraints(&circuit_builder), + JumpStackTable::transition_constraints(&circuit_builder), + HashTable::transition_constraints(&circuit_builder), + CascadeTable::transition_constraints(&circuit_builder), + LookupTable::transition_constraints(&circuit_builder), + U32Table::transition_constraints(&circuit_builder), + GrandCrossTableArg::transition_constraints(&circuit_builder), + ] + .concat() + } + + pub fn terminal_constraints() -> Vec> { + let circuit_builder = ConstraintCircuitBuilder::new(); + vec![ + ProgramTable::terminal_constraints(&circuit_builder), + ProcessorTable::terminal_constraints(&circuit_builder), + OpStackTable::terminal_constraints(&circuit_builder), + RamTable::terminal_constraints(&circuit_builder), + JumpStackTable::terminal_constraints(&circuit_builder), + HashTable::terminal_constraints(&circuit_builder), + CascadeTable::terminal_constraints(&circuit_builder), + LookupTable::terminal_constraints(&circuit_builder), + U32Table::terminal_constraints(&circuit_builder), + GrandCrossTableArg::terminal_constraints(&circuit_builder), + ] + .concat() + } + pub fn lower_to_target_degree_through_substitutions( &mut self, - mut info: DegreeLoweringInfo, + lowering_info: DegreeLoweringInfo, ) -> AllSubstitutions { - let lowering_info = info; + let mut info = lowering_info; let (init_base_substitutions, init_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree(&mut self.init, info); @@ -133,8 +222,6 @@ mod tests { use constraint_circuit::ConstraintCircuitBuilder; use twenty_first::prelude::*; - use crate::table; - use super::*; #[repr(usize)] @@ -157,6 +244,19 @@ mod tests { } } + #[test] + fn public_types_implement_usual_auto_traits() { + fn implements_auto_traits() {} + + implements_auto_traits::(); + implements_auto_traits::(); + + // maybe some day + // implements_auto_traits::(); + // implements_auto_traits::(); + // implements_auto_traits::(); + } + #[test] fn test_constraints_can_be_fetched() { Constraints::test_constraints(); @@ -172,7 +272,7 @@ mod tests { #[test] fn degree_lowering_tables_code_can_be_generated_from_all_constraints() { - let mut constraints = table::constraints(); + let mut constraints = Constraints::all(); let substitutions = constraints.lower_to_target_degree_through_substitutions(degree_lowering_info()); let _unused = substitutions.generate_degree_lowering_table_code(); diff --git a/triton-vm/src/codegen/substitutions.rs b/triton-constraint-builder/src/substitutions.rs similarity index 86% rename from triton-vm/src/codegen/substitutions.rs rename to triton-constraint-builder/src/substitutions.rs index ea6da7ef4..dc2df5dc4 100644 --- a/triton-vm/src/codegen/substitutions.rs +++ b/triton-constraint-builder/src/substitutions.rs @@ -11,14 +11,16 @@ use proc_macro2::TokenStream; use quote::format_ident; use quote::quote; -use crate::codegen::constraints::RustBackend; +use crate::codegen::RustBackend; -pub(crate) struct AllSubstitutions { +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct AllSubstitutions { pub base: Substitutions, pub ext: Substitutions, } -pub(crate) struct Substitutions { +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Substitutions { pub lowering_info: DegreeLoweringInfo, pub init: Vec>, pub cons: Vec>, @@ -56,15 +58,6 @@ impl AllSubstitutions { let fill_ext_columns_code = self.ext.generate_fill_ext_columns_code(); quote!( - //! The degree lowering table contains the introduced variables that allow - //! lowering the degree of the AIR. See - //! [`table::master_table::AIR_TARGET_DEGREE`] - //! for additional information. - //! - //! This file has been auto-generated. Any modifications _will_ be lost. - //! To re-generate, execute: - //! `cargo run --bin constraint-evaluation-generator` - use ndarray::Array1; use ndarray::s; use ndarray::ArrayView2; @@ -76,14 +69,11 @@ impl AllSubstitutions { use strum::EnumIter; use twenty_first::prelude::BFieldElement; use twenty_first::prelude::XFieldElement; + use air::table_column::MasterBaseTableColumn; + use air::table_column::MasterExtTableColumn; - use crate::table::challenges::Challenges; - use crate::table::master_table::NUM_BASE_COLUMNS; - use crate::table::master_table::NUM_EXT_COLUMNS; - - pub const BASE_WIDTH: usize = DegreeLoweringBaseTableColumn::COUNT; - pub const EXT_WIDTH: usize = DegreeLoweringExtTableColumn::COUNT; - pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; + use crate::challenges::Challenges; + use crate::table::master_table::MasterTable; #base_repr_usize #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] @@ -91,12 +81,34 @@ impl AllSubstitutions { #(#base_columns),* } + impl MasterBaseTableColumn for DegreeLoweringBaseTableColumn { + fn base_table_index(&self) -> usize { + (*self) as usize + } + + fn master_base_table_index(&self) -> usize { + // hardcore domain-specific knowledge, and bad style + air::table::U32_TABLE_END + self.base_table_index() + } + } + #ext_repr_usize #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] pub enum DegreeLoweringExtTableColumn { #(#ext_columns),* } + impl MasterExtTableColumn for DegreeLoweringExtTableColumn { + fn ext_table_index(&self) -> usize { + (*self) as usize + } + + fn master_ext_table_index(&self) -> usize { + // hardcore domain-specific knowledge, and bad style + air::table::EXT_U32_TABLE_END + self.ext_table_index() + } + } + #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct DegreeLoweringTable; @@ -134,14 +146,18 @@ impl Substitutions { Self::base_single_row_substitutions(derived_section_term_start, &term_substitutions); quote!( - #[allow(unused_variables)] - pub fn fill_derived_base_columns(mut master_base_table: ArrayViewMut2) { - assert_eq!(NUM_BASE_COLUMNS, master_base_table.ncols()); - #init_substitutions - #cons_substitutions - #tran_substitutions - #term_substitutions - } + #[allow(unused_variables)] + pub fn fill_derived_base_columns( + mut master_base_table: ArrayViewMut2 + ) { + let num_expected_columns = + crate::table::master_table::MasterBaseTable::NUM_COLUMNS; + assert_eq!(num_expected_columns, master_base_table.ncols()); + #init_substitutions + #cons_substitutions + #tran_substitutions + #term_substitutions + } ) } @@ -173,8 +189,12 @@ impl Substitutions { mut master_ext_table: ArrayViewMut2, challenges: &Challenges, ) { - assert_eq!(NUM_BASE_COLUMNS, master_base_table.ncols()); - assert_eq!(NUM_EXT_COLUMNS, master_ext_table.ncols()); + let num_expected_main_columns = + crate::table::master_table::MasterBaseTable::NUM_COLUMNS; + let num_expected_aux_columns = + crate::table::master_table::MasterExtTable::NUM_COLUMNS; + assert_eq!(num_expected_main_columns, master_base_table.ncols()); + assert_eq!(num_expected_aux_columns, master_ext_table.ncols()); assert_eq!(master_base_table.nrows(), master_ext_table.nrows()); #init_substitutions #cons_substitutions diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index 30292f9b2..a4e7f0a5c 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -1372,7 +1372,7 @@ mod tests { let mut multicircuit = [constraint_0, constraint_1]; let degree_lowering_info = DegreeLoweringInfo { - target_degree: 2, + target_degree: 3, num_base_cols: 9, num_ext_cols: 0, }; diff --git a/triton-isa/src/instruction.rs b/triton-isa/src/instruction.rs index f6c2df29c..1884457ee 100644 --- a/triton-isa/src/instruction.rs +++ b/triton-isa/src/instruction.rs @@ -398,7 +398,7 @@ impl AnInstruction { ((opcode >> bit_number) & 1).into() } - pub(crate) fn map_call_address(&self, f: F) -> AnInstruction + pub fn map_call_address(&self, f: F) -> AnInstruction where F: FnOnce(&Dest) -> NewDest, NewDest: PartialEq + Default, diff --git a/triton-isa/src/parser.rs b/triton-isa/src/parser.rs index f9748fdd8..37f43153f 100644 --- a/triton-isa/src/parser.rs +++ b/triton-isa/src/parser.rs @@ -983,14 +983,14 @@ pub(crate) mod tests { fn parse_program_label() { TestCase { input: "foo: call foo", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "parse labels and calls to labels", } .run(); TestCase { input: "foo:call foo", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "whitespace is not required after 'label:'", } .run(); @@ -1024,14 +1024,14 @@ pub(crate) mod tests { TestCase { input: "pops: call pops", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "labels that share a common prefix with instruction are labels", } .run(); TestCase { input: "_call: call _call", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "labels that share a common suffix with instruction are labels", } .run(); @@ -1068,7 +1068,7 @@ pub(crate) mod tests { fn parse_program_bracket_syntax() { TestCase { input: "foo: [foo]", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "Handle brackets as call syntax sugar", } .run(); @@ -1102,7 +1102,7 @@ pub(crate) mod tests { TestCase { input: "_foo: call _foo", - expected: vec![Instruction::Call(bfe!(0))], + expected: vec![Instruction::Call(bfe!(0)), Instruction::Call(bfe!(0))], message: "labels can start with an underscore", } .run(); diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index 31622881e..b54daa884 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -21,7 +21,7 @@ readme.workspace = true air.workspace = true arbitrary.workspace = true colored.workspace = true -constraint-circuit.workspace = true +constraint-builder.workspace = true criterion.workspace = true get-size.workspace = true indexmap.workspace = true @@ -47,6 +47,7 @@ unicode-width.workspace = true [dev-dependencies] assert2.workspace = true cargo-husky.workspace = true +constraint-circuit.workspace = true fs-err.workspace = true pretty_assertions.workspace = true prettyplease.workspace = true @@ -56,6 +57,15 @@ serde_json.workspace = true test-strategy.workspace = true trybuild.workspace = true +[build-dependencies] +air.workspace = true +constraint-builder.workspace = true +constraint-circuit.workspace = true +prettyplease.workspace = true +proc-macro2.workspace = true +quote.workspace = true +syn.workspace = true + [features] default = ["no_profile"] no_profile = [] # see `profiler.rs` for an explanation of this seemingly backwards feature diff --git a/triton-vm/build.rs b/triton-vm/build.rs new file mode 100644 index 000000000..6c8fd6c74 --- /dev/null +++ b/triton-vm/build.rs @@ -0,0 +1,38 @@ +use std::path::Path; + +use constraint_builder::codegen::Codegen; +use constraint_builder::codegen::RustBackend; +use constraint_builder::codegen::TasmBackend; +use constraint_builder::Constraints; +use proc_macro2::TokenStream; + +fn main() { + println!("cargo::rerun-if-changed=build.rs"); + + let mut constraints = Constraints::all(); + let degree_lowering_info = constraint_circuit::DegreeLoweringInfo { + target_degree: air::TARGET_DEGREE, + num_base_cols: air::table::NUM_BASE_COLUMNS, + num_ext_cols: air::table::NUM_EXT_COLUMNS, + }; + let substitutions = + constraints.lower_to_target_degree_through_substitutions(degree_lowering_info); + let deg_low_table = substitutions.generate_degree_lowering_table_code(); + + let constraints = constraints.combine_with_substitution_induced_constraints(substitutions); + let rust = RustBackend::constraint_evaluation_code(&constraints); + let tasm = TasmBackend::constraint_evaluation_code(&constraints); + + write_code_to_file(deg_low_table, "degree_lowering_table.rs"); + write_code_to_file(rust, "evaluate_constraints.rs"); + write_code_to_file(tasm, "tasm_constraints.rs"); +} + +fn write_code_to_file(code: TokenStream, file_name: &str) { + let syntax_tree = syn::parse2(code).unwrap(); + let code = prettyplease::unparse(&syntax_tree); + + let out_dir = std::env::var_os("OUT_DIR").unwrap(); + let file_path = Path::new(&out_dir).join(file_name); + std::fs::write(file_path, code).unwrap(); +} diff --git a/triton-vm/src/air.rs b/triton-vm/src/air.rs index 4d1fcf1ae..91371b0d6 100644 --- a/triton-vm/src/air.rs +++ b/triton-vm/src/air.rs @@ -1,11 +1,7 @@ -pub mod memory_layout; -#[rustfmt::skip] -pub mod tasm_air_constraints; +include!(concat!(env!("OUT_DIR"), "/tasm_constraints.rs")); #[cfg(test)] mod test { - use air::table::NUM_BASE_COLUMNS; - use air::table::NUM_EXT_COLUMNS; use isa::instruction::AnInstruction; use itertools::Itertools; use ndarray::Array1; @@ -17,34 +13,34 @@ mod test { use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; - use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; - use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; + use crate::air::dynamic_air_constraint_evaluation_tasm; + use crate::air::static_air_constraint_evaluation_tasm; use crate::challenges::Challenges; + use crate::memory_layout; + use crate::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; + use crate::memory_layout::IntegralMemoryLayout; + use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; use crate::prelude::*; use crate::table::extension_table::Evaluable; - use crate::table::extension_table::Quotientable; use crate::table::master_table::MasterExtTable; - - use super::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; - use super::memory_layout::IntegralMemoryLayout; - use super::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; - use super::*; + use crate::table::NUM_AUX_COLUMNS; + use crate::table::NUM_MAIN_COLUMNS; #[derive(Debug, Clone, test_strategy::Arbitrary)] struct ConstraintEvaluationPoint { - #[strategy(vec(arb(), NUM_BASE_COLUMNS))] + #[strategy(vec(arb(), NUM_MAIN_COLUMNS))] #[map(Array1::from)] curr_base_row: Array1, - #[strategy(vec(arb(), NUM_EXT_COLUMNS))] + #[strategy(vec(arb(), NUM_AUX_COLUMNS))] #[map(Array1::from)] curr_ext_row: Array1, - #[strategy(vec(arb(), NUM_BASE_COLUMNS))] + #[strategy(vec(arb(), NUM_MAIN_COLUMNS))] #[map(Array1::from)] next_base_row: Array1, - #[strategy(vec(arb(), NUM_EXT_COLUMNS))] + #[strategy(vec(arb(), NUM_AUX_COLUMNS))] #[map(Array1::from)] next_ext_row: Array1, @@ -88,24 +84,16 @@ mod test { let program = self.tasm_static_constraint_evaluation_code(); let mut vm_state = self.set_up_triton_vm_to_evaluate_constraints_in_tasm_static(&program); - vm_state.run().unwrap(); - - let output_list_ptr = vm_state.op_stack.pop().unwrap().value(); - let num_quotients = MasterExtTable::NUM_CONSTRAINTS; - Self::read_xfe_list_at_address(vm_state.ram, output_list_ptr, num_quotients) + Self::extract_constraint_evaluations(vm_state) } fn evaluate_all_constraints_tasm_dynamic(&self) -> Vec { let program = self.tasm_dynamic_constraint_evaluation_code(); let mut vm_state = self.set_up_triton_vm_to_evaluate_constraints_in_tasm_dynamic(&program); - vm_state.run().unwrap(); - - let output_list_ptr = vm_state.op_stack.pop().unwrap().value(); - let num_quotients = MasterExtTable::NUM_CONSTRAINTS; - Self::read_xfe_list_at_address(vm_state.ram, output_list_ptr, num_quotients) + Self::extract_constraint_evaluations(vm_state) } fn tasm_static_constraint_evaluation_code(&self) -> Program { @@ -124,6 +112,14 @@ mod test { Program::new(&source_code) } + /// Requires a VM State that has executed constraint evaluation code. + fn extract_constraint_evaluations(mut vm_state: VMState) -> Vec { + assert!(vm_state.halting); + let output_list_ptr = vm_state.op_stack.pop().unwrap().value(); + let num_quotients = MasterExtTable::NUM_CONSTRAINTS; + Self::read_xfe_list_at_address(vm_state.ram, output_list_ptr, num_quotients) + } + fn set_up_triton_vm_to_evaluate_constraints_in_tasm_static( &self, program: &Program, @@ -152,18 +148,12 @@ mod test { // for convenience, reuse the (integral) static memory layout let mut vm_state = self.set_up_triton_vm_to_evaluate_constraints_in_tasm_static(program); - vm_state - .op_stack - .push(self.static_memory_layout.curr_base_row_ptr); - vm_state - .op_stack - .push(self.static_memory_layout.curr_ext_row_ptr); - vm_state - .op_stack - .push(self.static_memory_layout.next_base_row_ptr); - vm_state - .op_stack - .push(self.static_memory_layout.next_ext_row_ptr); + + let layout = self.static_memory_layout; + vm_state.op_stack.push(layout.curr_base_row_ptr); + vm_state.op_stack.push(layout.curr_ext_row_ptr); + vm_state.op_stack.push(layout.next_base_row_ptr); + vm_state.op_stack.push(layout.next_ext_row_ptr); vm_state } diff --git a/triton-vm/src/air/tasm_air_constraints.rs b/triton-vm/src/air/tasm_air_constraints.rs deleted file mode 100644 index 95edab7db..000000000 --- a/triton-vm/src/air/tasm_air_constraints.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! This file is a placeholder for auto-generated code -//! Run `cargo run --bin constraint-evaluation-generator` -//! to fill in this file with optimized constraints. - -use isa::instruction::LabelledInstruction; - -use crate::air::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; -use crate::air::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; -use crate::table::constraints::ERROR_MESSAGE_GENERATE_CONSTRAINTS; - -pub fn static_air_constraint_evaluation_tasm( - _: StaticTasmConstraintEvaluationMemoryLayout, -) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}"); -} - -pub fn dynamic_air_constraint_evaluation_tasm( - _: DynamicTasmConstraintEvaluationMemoryLayout, -) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}"); -} diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 32652675f..5c4e07d1a 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -168,12 +168,12 @@ pub mod aet; pub mod air; pub mod arithmetic_domain; pub mod challenges; -mod codegen; pub mod config; pub mod error; pub mod example_programs; pub mod execution_trace_profiler; pub mod fri; +pub mod memory_layout; mod ndarray_helper; pub mod prelude; pub mod profiler; @@ -331,30 +331,31 @@ mod tests { // table things implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); // other implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); } diff --git a/triton-vm/src/air/memory_layout.rs b/triton-vm/src/memory_layout.rs similarity index 93% rename from triton-vm/src/air/memory_layout.rs rename to triton-vm/src/memory_layout.rs index c19879b0f..57f9bd2dc 100644 --- a/triton-vm/src/air/memory_layout.rs +++ b/triton-vm/src/memory_layout.rs @@ -1,18 +1,17 @@ +pub use constraint_builder::codegen::MEM_PAGE_SIZE; + +use air::challenge_id::ChallengeId; use air::table::NUM_BASE_COLUMNS; use air::table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; use itertools::Itertools; +use strum::EnumCount; use twenty_first::prelude::*; -use crate::challenges::Challenges; - -/// The minimal required size of a memory page in [`BFieldElement`]s. -pub const MEM_PAGE_SIZE: usize = 1 << 32; - /// Memory layout guarantees for the [Triton assembly AIR constraint evaluator][tasm_air] /// with input lists at dynamically known memory locations. /// -/// [tasm_air]: crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm +/// [tasm_air]: crate::air::dynamic_air_constraint_evaluation_tasm #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct DynamicTasmConstraintEvaluationMemoryLayout { /// Pointer to a region of memory that is reserved for (a) pointers to {current, @@ -23,14 +22,14 @@ pub struct DynamicTasmConstraintEvaluationMemoryLayout { /// Pointer to an array of [`XFieldElement`]s of length [`NUM_CHALLENGES`][num_challenges]. /// - /// [num_challenges]: Challenges::COUNT + /// [num_challenges]: ChallengeId::COUNT pub challenges_ptr: BFieldElement, } /// Memory layout guarantees for the [Triton assembly AIR constraint evaluator][tasm_air] /// with input lists at statically known memory locations. /// -/// [tasm_air]: crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm +/// [tasm_air]: crate::air::static_air_constraint_evaluation_tasm #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct StaticTasmConstraintEvaluationMemoryLayout { /// Pointer to a region of memory that is reserved for constraint evaluation. @@ -51,7 +50,7 @@ pub struct StaticTasmConstraintEvaluationMemoryLayout { /// Pointer to an array of [`XFieldElement`]s of length [`NUM_CHALLENGES`][num_challenges]. /// - /// [num_challenges]: Challenges::COUNT + /// [num_challenges]: ChallengeId::COUNT pub challenges_ptr: BFieldElement, } @@ -82,7 +81,7 @@ impl IntegralMemoryLayout for StaticTasmConstraintEvaluationMemoryLayout { MemoryRegion::new(self.curr_ext_row_ptr, NUM_EXT_COLUMNS), MemoryRegion::new(self.next_base_row_ptr, NUM_BASE_COLUMNS), MemoryRegion::new(self.next_ext_row_ptr, NUM_EXT_COLUMNS), - MemoryRegion::new(self.challenges_ptr, Challenges::COUNT), + MemoryRegion::new(self.challenges_ptr, ChallengeId::COUNT), ]; Box::new(all_regions) } @@ -92,7 +91,7 @@ impl IntegralMemoryLayout for DynamicTasmConstraintEvaluationMemoryLayout { fn memory_regions(&self) -> Box<[MemoryRegion]> { let all_regions = [ MemoryRegion::new(self.free_mem_page_ptr, MEM_PAGE_SIZE), - MemoryRegion::new(self.challenges_ptr, Challenges::COUNT), + MemoryRegion::new(self.challenges_ptr, ChallengeId::COUNT), ]; Box::new(all_regions) } diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index ddeb32862..496ea79af 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1,8 +1,6 @@ use std::ops::Mul; use std::ops::MulAssign; -use air::table::NUM_BASE_COLUMNS; -use air::table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; use arbitrary::Unstructured; use itertools::izip; @@ -31,7 +29,6 @@ use crate::proof::Proof; use crate::proof_item::ProofItem; use crate::proof_stream::ProofStream; use crate::table::extension_table::Evaluable; -use crate::table::extension_table::Quotientable; use crate::table::master_table::all_quotients_combined; use crate::table::master_table::interpolant_degree; use crate::table::master_table::max_degree_with_origin; @@ -250,12 +247,12 @@ impl Stark { profiler!(start "linear combination"); profiler!(start "base" ("CC")); let base_combination_polynomial = - Self::random_linear_sum(master_base_table.interpolation_polynomials(), weights.base); + Self::random_linear_sum(master_base_table.interpolation_polynomials(), weights.main); profiler!(stop "base"); profiler!(start "ext" ("CC")); let ext_combination_polynomial = - Self::random_linear_sum(master_ext_table.interpolation_polynomials(), weights.ext); + Self::random_linear_sum(master_ext_table.interpolation_polynomials(), weights.aux); profiler!(stop "ext"); let base_and_ext_combination_polynomial = base_combination_polynomial + ext_combination_polynomial; @@ -363,7 +360,7 @@ impl Stark { if let Some(fri_domain_table) = master_base_table.fri_domain_table() { Self::read_revealed_rows(fri_domain_table, &revealed_current_row_indices)? } else { - Self::recompute_revealed_rows::( + Self::recompute_revealed_rows::<{ MasterBaseTable::NUM_COLUMNS }, BFieldElement>( &master_base_table.interpolation_polynomials(), &revealed_current_row_indices, fri.domain, @@ -1266,11 +1263,11 @@ impl<'a> Arbitrary<'a> for Stark { /// Fiat-Shamir-sampled challenges to compress a row into a single /// [extension field element][XFieldElement]. struct LinearCombinationWeights { - /// of length [`NUM_BASE_COLUMNS`] - base: Array1, + /// of length [`MasterBaseTable::NUM_COLUMNS`] + main: Array1, - /// of length [`NUM_EXT_COLUMNS`] - ext: Array1, + /// of length [`MasterExtTable::NUM_COLUMNS`] + aux: Array1, /// of length [`NUM_QUOTIENT_SEGMENTS`] quot_segments: Array1, @@ -1280,27 +1277,29 @@ struct LinearCombinationWeights { } impl LinearCombinationWeights { - const NUM: usize = - NUM_BASE_COLUMNS + NUM_EXT_COLUMNS + NUM_QUOTIENT_SEGMENTS + NUM_DEEP_CODEWORD_COMPONENTS; + const NUM: usize = MasterBaseTable::NUM_COLUMNS + + MasterExtTable::NUM_COLUMNS + + NUM_QUOTIENT_SEGMENTS + + NUM_DEEP_CODEWORD_COMPONENTS; fn sample(proof_stream: &mut ProofStream) -> Self { - const BASE_END: usize = NUM_BASE_COLUMNS; - const EXT_END: usize = BASE_END + NUM_EXT_COLUMNS; - const QUOT_END: usize = EXT_END + NUM_QUOTIENT_SEGMENTS; + const MAIN_END: usize = MasterBaseTable::NUM_COLUMNS; + const AUX_END: usize = MAIN_END + MasterExtTable::NUM_COLUMNS; + const QUOT_END: usize = AUX_END + NUM_QUOTIENT_SEGMENTS; let weights = proof_stream.sample_scalars(Self::NUM); Self { - base: weights[..BASE_END].to_vec().into(), - ext: weights[BASE_END..EXT_END].to_vec().into(), - quot_segments: weights[EXT_END..QUOT_END].to_vec().into(), + main: weights[..MAIN_END].to_vec().into(), + aux: weights[MAIN_END..AUX_END].to_vec().into(), + quot_segments: weights[AUX_END..QUOT_END].to_vec().into(), deep: weights[QUOT_END..].to_vec().into(), } } fn base_and_ext(&self) -> Array1 { - let base = self.base.clone().into_iter(); - base.chain(self.ext.clone()).collect() + let base = self.main.clone().into_iter(); + base.chain(self.aux.clone()).collect() } } @@ -1357,7 +1356,6 @@ pub(crate) mod tests { use crate::shared_tests::*; use crate::table::extension_table; use crate::table::extension_table::Evaluable; - use crate::table::extension_table::Quotientable; use crate::table::master_table::MasterExtTable; use crate::triton_program; use crate::vm::tests::*; @@ -1628,8 +1626,8 @@ pub(crate) mod tests { #[test] fn constraint_polynomials_use_right_number_of_variables() { let challenges = Challenges::default(); - let base_row = Array1::::zeros(NUM_BASE_COLUMNS); - let ext_row = Array1::zeros(NUM_EXT_COLUMNS); + let base_row = Array1::::zeros(MasterBaseTable::NUM_COLUMNS); + let ext_row = Array1::zeros(MasterExtTable::NUM_COLUMNS); let br = base_row.view(); let er = ext_row.view(); @@ -1738,8 +1736,8 @@ pub(crate) mod tests { #[test] fn number_of_quotient_degree_bounds_match_number_of_constraints() { - let base_row = Array1::::zeros(NUM_BASE_COLUMNS); - let ext_row = Array1::zeros(NUM_EXT_COLUMNS); + let base_row = Array1::::zeros(MasterBaseTable::NUM_COLUMNS); + let ext_row = Array1::zeros(MasterExtTable::NUM_COLUMNS); let ch = Challenges::default(); let padded_height = 2; let num_trace_randomizers = 2; @@ -2193,9 +2191,9 @@ pub(crate) mod tests { let (_, _, master_base_table, master_ext_table, challenges) = master_tables_for_low_security_level(program_and_input); - let num_base_rows = master_base_table.randomized_trace_table().nrows(); - let num_ext_rows = master_ext_table.randomized_trace_table().nrows(); - assert!(num_base_rows == num_ext_rows); + let num_main_rows = master_base_table.randomized_trace_table().nrows(); + let num_aux_rows = master_ext_table.randomized_trace_table().nrows(); + assert!(num_main_rows == num_aux_rows); let mbt = master_base_table.trace_table(); let met = master_ext_table.trace_table(); @@ -2471,8 +2469,10 @@ pub(crate) mod tests { #[strategy(arb())] #[filter(!#offset.is_zero())] offset: BFieldElement, - #[strategy(arb())] main_polynomials: [Polynomial; NUM_BASE_COLUMNS], - #[strategy(arb())] aux_polynomials: [Polynomial; NUM_EXT_COLUMNS], + #[strategy(arb())] main_polynomials: [Polynomial; + MasterBaseTable::NUM_COLUMNS], + #[strategy(arb())] aux_polynomials: [Polynomial; + MasterExtTable::NUM_COLUMNS], #[strategy(arb())] challenges: Challenges, #[strategy(arb())] quotient_weights: [XFieldElement; MasterExtTable::NUM_CONSTRAINTS], ) { @@ -2529,14 +2529,14 @@ pub(crate) mod tests { quotient_weights: &[XFieldElement], ) -> (Array2, Array1>) { let mut base_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, NUM_BASE_COLUMNS]); + Array2::::zeros([quotient_domain.length, MasterBaseTable::NUM_COLUMNS]); Zip::from(base_quotient_domain_codewords.axis_iter_mut(Axis(1))) .and(main_polynomials.axis_iter(Axis(0))) .for_each(|codeword, polynomial| { Array1::from_vec(quotient_domain.evaluate(&polynomial[()])).move_into(codeword); }); let mut ext_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, NUM_EXT_COLUMNS]); + Array2::::zeros([quotient_domain.length, MasterExtTable::NUM_COLUMNS]); Zip::from(ext_quotient_domain_codewords.axis_iter_mut(Axis(1))) .and(aux_polynomials.axis_iter(Axis(0))) .for_each(|codeword, polynomial| { @@ -2621,12 +2621,12 @@ pub(crate) mod tests { ) { let weights = LinearCombinationWeights::sample(&mut proof_stream); - prop_assert_eq!(NUM_BASE_COLUMNS, weights.base.len()); - prop_assert_eq!(NUM_EXT_COLUMNS, weights.ext.len()); + prop_assert_eq!(MasterBaseTable::NUM_COLUMNS, weights.main.len()); + prop_assert_eq!(MasterExtTable::NUM_COLUMNS, weights.aux.len()); prop_assert_eq!(NUM_QUOTIENT_SEGMENTS, weights.quot_segments.len()); prop_assert_eq!(NUM_DEEP_CODEWORD_COMPONENTS, weights.deep.len()); prop_assert_eq!( - NUM_BASE_COLUMNS + NUM_EXT_COLUMNS, + MasterBaseTable::NUM_COLUMNS + MasterExtTable::NUM_COLUMNS, weights.base_and_ext().len() ); } diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index 8bb7e8321..bc1ace5f6 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -1,21 +1,7 @@ -use air::cross_table_argument::GrandCrossTableArg; use air::table::cascade::CascadeTable; use air::table::hash::HashTable; use air::table::jump_stack::JumpStackTable; -use air::table::lookup::LookupTable; -use air::table::op_stack::OpStackTable; -use air::table::processor::ProcessorTable; -use air::table::program::ProgramTable; -use air::table::ram::RamTable; -use air::table::u32::U32Table; -use air::table::NUM_BASE_COLUMNS; -use air::table::NUM_EXT_COLUMNS; use air::AIR; -use arbitrary::Arbitrary; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::SingleRowIndicator; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; use strum::Display; @@ -25,13 +11,12 @@ use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; use crate::challenges::Challenges; -use crate::codegen::Constraints; pub use crate::stark::NUM_QUOTIENT_SEGMENTS; +use crate::table::master_table::MasterBaseTable; +use crate::table::master_table::MasterExtTable; +use crate::table::master_table::MasterTable; -#[rustfmt::skip] -pub mod constraints; -#[rustfmt::skip] -pub mod degree_lowering_table; +pub mod degree_lowering; pub mod cascade; pub mod extension_table; @@ -45,6 +30,17 @@ pub mod program; pub mod ram; pub mod u32; +/// The total number of main columns across all tables. +/// The degree lowering columns _are_ included. +pub const NUM_MAIN_COLUMNS: usize = + air::table::NUM_BASE_COLUMNS + degree_lowering::DegreeLoweringBaseTableColumn::COUNT; + +/// The total number of auxiliary columns across all tables. +/// The degree lowering columns _are_ included, +/// randomizer polynomials are _not_ included. +pub const NUM_AUX_COLUMNS: usize = + air::table::NUM_EXT_COLUMNS + degree_lowering::DegreeLoweringExtTableColumn::COUNT; + trait TraceTable: AIR { // a nicer design is in order type FillParam; @@ -89,101 +85,29 @@ pub enum ConstraintType { /// [`XFieldElement`]s. /// /// [table]: master_table::MasterBaseTable -pub type BaseRow = [T; NUM_BASE_COLUMNS]; +pub type BaseRow = [T; MasterBaseTable::NUM_COLUMNS]; /// A single row of a [`MasterExtensionTable`][table]. /// /// [table]: master_table::MasterExtTable -pub type ExtensionRow = [XFieldElement; NUM_EXT_COLUMNS]; +pub type ExtensionRow = [XFieldElement; MasterExtTable::NUM_COLUMNS]; /// An element of the split-up quotient polynomial. /// /// See also [`NUM_QUOTIENT_SEGMENTS`]. pub type QuotientSegments = [XFieldElement; NUM_QUOTIENT_SEGMENTS]; -pub(crate) fn constraints() -> Constraints { - Constraints { - init: initial_constraints(), - cons: consistency_constraints(), - tran: transition_constraints(), - term: terminal_constraints(), - } -} - -fn initial_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ProgramTable::initial_constraints(&circuit_builder), - ProcessorTable::initial_constraints(&circuit_builder), - OpStackTable::initial_constraints(&circuit_builder), - RamTable::initial_constraints(&circuit_builder), - JumpStackTable::initial_constraints(&circuit_builder), - HashTable::initial_constraints(&circuit_builder), - CascadeTable::initial_constraints(&circuit_builder), - LookupTable::initial_constraints(&circuit_builder), - U32Table::initial_constraints(&circuit_builder), - GrandCrossTableArg::initial_constraints(&circuit_builder), - ] - .concat() -} - -fn consistency_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ProgramTable::consistency_constraints(&circuit_builder), - ProcessorTable::consistency_constraints(&circuit_builder), - OpStackTable::consistency_constraints(&circuit_builder), - RamTable::consistency_constraints(&circuit_builder), - JumpStackTable::consistency_constraints(&circuit_builder), - HashTable::consistency_constraints(&circuit_builder), - CascadeTable::consistency_constraints(&circuit_builder), - LookupTable::consistency_constraints(&circuit_builder), - U32Table::consistency_constraints(&circuit_builder), - GrandCrossTableArg::consistency_constraints(&circuit_builder), - ] - .concat() -} - -fn transition_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ProgramTable::transition_constraints(&circuit_builder), - ProcessorTable::transition_constraints(&circuit_builder), - OpStackTable::transition_constraints(&circuit_builder), - RamTable::transition_constraints(&circuit_builder), - JumpStackTable::transition_constraints(&circuit_builder), - HashTable::transition_constraints(&circuit_builder), - CascadeTable::transition_constraints(&circuit_builder), - LookupTable::transition_constraints(&circuit_builder), - U32Table::transition_constraints(&circuit_builder), - GrandCrossTableArg::transition_constraints(&circuit_builder), - ] - .concat() -} - -fn terminal_constraints() -> Vec> { - let circuit_builder = ConstraintCircuitBuilder::new(); - vec![ - ProgramTable::terminal_constraints(&circuit_builder), - ProcessorTable::terminal_constraints(&circuit_builder), - OpStackTable::terminal_constraints(&circuit_builder), - RamTable::terminal_constraints(&circuit_builder), - JumpStackTable::terminal_constraints(&circuit_builder), - HashTable::terminal_constraints(&circuit_builder), - CascadeTable::terminal_constraints(&circuit_builder), - LookupTable::terminal_constraints(&circuit_builder), - U32Table::terminal_constraints(&circuit_builder), - GrandCrossTableArg::terminal_constraints(&circuit_builder), - ] - .concat() -} - #[cfg(test)] mod tests { use std::collections::HashMap; use air::table::hash::HashTable; + use air::table::lookup::LookupTable; use air::table::op_stack::OpStackTable; + use air::table::processor::ProcessorTable; + use air::table::program::ProgramTable; + use air::table::ram::RamTable; + use air::table::u32::U32Table; use air::table::CASCADE_TABLE_END; use air::table::EXT_CASCADE_TABLE_END; use air::table::EXT_HASH_TABLE_END; @@ -220,7 +144,7 @@ mod tests { use crate::challenges::Challenges; use crate::prelude::Claim; - use crate::table::degree_lowering_table::DegreeLoweringTable; + use crate::table::degree_lowering::DegreeLoweringTable; use super::*; @@ -241,8 +165,8 @@ mod tests { let challenges = &challenges.challenges; let num_rows = 2; - let base_shape = [num_rows, NUM_BASE_COLUMNS]; - let ext_shape = [num_rows, NUM_EXT_COLUMNS]; + let base_shape = [num_rows, NUM_MAIN_COLUMNS]; + let ext_shape = [num_rows, NUM_AUX_COLUMNS]; let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); let base_rows = base_rows.view(); @@ -402,8 +326,8 @@ mod tests { let num_rows = 2; let num_new_base_constraints = new_base_constraints.len(); let num_new_ext_constraints = new_ext_constraints.len(); - let num_base_cols = NUM_BASE_COLUMNS + num_new_base_constraints; - let num_ext_cols = NUM_EXT_COLUMNS + num_new_ext_constraints; + let num_base_cols = NUM_MAIN_COLUMNS + num_new_base_constraints; + let num_ext_cols = NUM_AUX_COLUMNS + num_new_ext_constraints; let base_shape = [num_rows, num_base_cols]; let ext_shape = [num_rows, num_ext_cols]; let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); @@ -524,8 +448,8 @@ mod tests { #[ignore = "(probably) requires normalization of circuit expressions"] fn substitution_rules_are_unique() { let challenges = Challenges::default(); - let mut base_table_rows = Array2::from_shape_fn((2, NUM_BASE_COLUMNS), |_| random()); - let mut ext_table_rows = Array2::from_shape_fn((2, NUM_EXT_COLUMNS), |_| random()); + let mut base_table_rows = Array2::from_shape_fn((2, NUM_MAIN_COLUMNS), |_| random()); + let mut ext_table_rows = Array2::from_shape_fn((2, NUM_AUX_COLUMNS), |_| random()); DegreeLoweringTable::fill_derived_base_columns(base_table_rows.view_mut()); DegreeLoweringTable::fill_derived_ext_columns( @@ -535,7 +459,7 @@ mod tests { ); let mut encountered_values = HashMap::new(); - for col_idx in 0..NUM_BASE_COLUMNS { + for col_idx in 0..NUM_MAIN_COLUMNS { let val = base_table_rows[(0, col_idx)].lift(); let other_entry = encountered_values.insert(val, col_idx); if let Some(other_idx) = other_entry { @@ -543,7 +467,7 @@ mod tests { } } println!("Now comparing extension columns…"); - for col_idx in 0..NUM_EXT_COLUMNS { + for col_idx in 0..NUM_AUX_COLUMNS { let val = ext_table_rows[(0, col_idx)]; let other_entry = encountered_values.insert(val, col_idx); if let Some(other_idx) = other_entry { diff --git a/triton-vm/src/table/cascade.rs b/triton-vm/src/table/cascade.rs index a12eb5d18..434d94294 100644 --- a/triton-vm/src/table/cascade.rs +++ b/triton-vm/src/table/cascade.rs @@ -1,21 +1,10 @@ use air::challenge_id::ChallengeId; -use air::challenge_id::ChallengeId::*; use air::cross_table_argument::CrossTableArg; use air::cross_table_argument::LookupArg; use air::table::cascade::CascadeTable; -use air::table_column::CascadeBaseTableColumn; -use air::table_column::CascadeBaseTableColumn::*; -use air::table_column::CascadeExtTableColumn; -use air::table_column::CascadeExtTableColumn::*; use air::table_column::MasterBaseTableColumn; use air::table_column::MasterExtTableColumn; use air::AIR; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; use ndarray::s; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; @@ -29,6 +18,9 @@ use crate::challenges::Challenges; use crate::profiler::profiler; use crate::table::TraceTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + fn lookup_8_bit_limb(to_look_up: u8) -> BFieldElement { tip5::LOOKUP_TABLE[usize::from(to_look_up)].into() } @@ -53,17 +45,20 @@ impl TraceTable for CascadeTable { let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; let mut row = main_table.row_mut(row_idx); - row[LookInLo.base_table_index()] = bfe!(to_look_up_lo); - row[LookInHi.base_table_index()] = bfe!(to_look_up_hi); - row[LookOutLo.base_table_index()] = lookup_8_bit_limb(to_look_up_lo); - row[LookOutHi.base_table_index()] = lookup_8_bit_limb(to_look_up_hi); - row[LookupMultiplicity.base_table_index()] = bfe!(multiplicity); + row[MainColumn::LookInLo.base_table_index()] = bfe!(to_look_up_lo); + row[MainColumn::LookInHi.base_table_index()] = bfe!(to_look_up_hi); + row[MainColumn::LookOutLo.base_table_index()] = lookup_8_bit_limb(to_look_up_lo); + row[MainColumn::LookOutHi.base_table_index()] = lookup_8_bit_limb(to_look_up_hi); + row[MainColumn::LookupMultiplicity.base_table_index()] = bfe!(multiplicity); } } fn pad(mut main_table: ArrayViewMut2, cascade_table_length: usize) { main_table - .slice_mut(s![cascade_table_length.., IsPadding.base_table_index()]) + .slice_mut(s![ + cascade_table_length.., + MainColumn::IsPadding.base_table_index() + ]) .fill(BFieldElement::ONE); } @@ -73,8 +68,8 @@ impl TraceTable for CascadeTable { challenges: &Challenges, ) { profiler!(start "cascade table"); - assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); - assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(MainColumn::COUNT, main_table.ncols()); + assert_eq!(AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); let mut hash_table_log_derivative = LookupArg::default_initial(); @@ -82,41 +77,44 @@ impl TraceTable for CascadeTable { let two_pow_8 = bfe!(1 << 8); - let hash_indeterminate = challenges[HashCascadeLookupIndeterminate]; - let hash_input_weight = challenges[HashCascadeLookInWeight]; - let hash_output_weight = challenges[HashCascadeLookOutWeight]; + let hash_indeterminate = challenges[ChallengeId::HashCascadeLookupIndeterminate]; + let hash_input_weight = challenges[ChallengeId::HashCascadeLookInWeight]; + let hash_output_weight = challenges[ChallengeId::HashCascadeLookOutWeight]; - let lookup_indeterminate = challenges[CascadeLookupIndeterminate]; - let lookup_input_weight = challenges[LookupTableInputWeight]; - let lookup_output_weight = challenges[LookupTableOutputWeight]; + let lookup_indeterminate = challenges[ChallengeId::CascadeLookupIndeterminate]; + let lookup_input_weight = challenges[ChallengeId::LookupTableInputWeight]; + let lookup_output_weight = challenges[ChallengeId::LookupTableOutputWeight]; for row_idx in 0..main_table.nrows() { let base_row = main_table.row(row_idx); - let is_padding = base_row[IsPadding.base_table_index()].is_one(); + let is_padding = base_row[MainColumn::IsPadding.base_table_index()].is_one(); if !is_padding { - let look_in = two_pow_8 * base_row[LookInHi.base_table_index()] - + base_row[LookInLo.base_table_index()]; - let look_out = two_pow_8 * base_row[LookOutHi.base_table_index()] - + base_row[LookOutLo.base_table_index()]; + let look_in = two_pow_8 * base_row[MainColumn::LookInHi.base_table_index()] + + base_row[MainColumn::LookInLo.base_table_index()]; + let look_out = two_pow_8 * base_row[MainColumn::LookOutHi.base_table_index()] + + base_row[MainColumn::LookOutLo.base_table_index()]; let compressed_row_hash = hash_input_weight * look_in + hash_output_weight * look_out; - let lookup_multiplicity = base_row[LookupMultiplicity.base_table_index()]; + let lookup_multiplicity = + base_row[MainColumn::LookupMultiplicity.base_table_index()]; hash_table_log_derivative += (hash_indeterminate - compressed_row_hash).inverse() * lookup_multiplicity; - let compressed_row_lo = lookup_input_weight * base_row[LookInLo.base_table_index()] - + lookup_output_weight * base_row[LookOutLo.base_table_index()]; - let compressed_row_hi = lookup_input_weight * base_row[LookInHi.base_table_index()] - + lookup_output_weight * base_row[LookOutHi.base_table_index()]; + let compressed_row_lo = lookup_input_weight + * base_row[MainColumn::LookInLo.base_table_index()] + + lookup_output_weight * base_row[MainColumn::LookOutLo.base_table_index()]; + let compressed_row_hi = lookup_input_weight + * base_row[MainColumn::LookInHi.base_table_index()] + + lookup_output_weight * base_row[MainColumn::LookOutHi.base_table_index()]; lookup_table_log_derivative += (lookup_indeterminate - compressed_row_lo).inverse(); lookup_table_log_derivative += (lookup_indeterminate - compressed_row_hi).inverse(); } let mut extension_row = aux_table.row_mut(row_idx); - extension_row[HashTableServerLogDerivative.ext_table_index()] = + extension_row[AuxColumn::HashTableServerLogDerivative.ext_table_index()] = hash_table_log_derivative; - extension_row[LookupTableClientLogDerivative.ext_table_index()] = + extension_row[AuxColumn::LookupTableClientLogDerivative.ext_table_index()] = lookup_table_log_derivative; } profiler!(stop "cascade table"); diff --git a/triton-vm/src/table/constraints.rs b/triton-vm/src/table/constraints.rs deleted file mode 100644 index abd5383de..000000000 --- a/triton-vm/src/table/constraints.rs +++ /dev/null @@ -1,112 +0,0 @@ -//! This file is a placeholder for auto-generated code -//! Run `cargo run --bin constraint-evaluation-generator` -//! to fill in this file with optimized constraints. - -use ndarray::ArrayView1; -use twenty_first::prelude::BFieldElement; -use twenty_first::prelude::XFieldElement; - -use crate::challenges::Challenges; -use crate::table::extension_table::Evaluable; -use crate::table::extension_table::Quotientable; -use crate::table::master_table::MasterExtTable; - -pub(crate) const ERROR_MESSAGE_GENERATE_CONSTRAINTS: &str = - "Constraints must be in place. Run: `cargo run --bin constraint-evaluation-generator`"; -const ERROR_MESSAGE_GENERATE_DEGREE_BOUNDS: &str = - "Degree bounds must be in place. Run: `cargo run --bin constraint-evaluation-generator`"; - -impl Evaluable for MasterExtTable { - fn evaluate_initial_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_consistency_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_transition_constraints( - _: ArrayView1, - _: ArrayView1, - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_terminal_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } -} - -impl Evaluable for MasterExtTable { - fn evaluate_initial_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_consistency_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_transition_constraints( - _: ArrayView1, - _: ArrayView1, - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } - - fn evaluate_terminal_constraints( - _: ArrayView1, - _: ArrayView1, - _: &Challenges, - ) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_CONSTRAINTS}") - } -} - -impl Quotientable for MasterExtTable { - const NUM_INITIAL_CONSTRAINTS: usize = 0; - const NUM_CONSISTENCY_CONSTRAINTS: usize = 0; - const NUM_TRANSITION_CONSTRAINTS: usize = 0; - const NUM_TERMINAL_CONSTRAINTS: usize = 0; - - fn initial_quotient_degree_bounds(_: isize) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_DEGREE_BOUNDS}") - } - - fn consistency_quotient_degree_bounds(_: isize, _: usize) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_DEGREE_BOUNDS}") - } - - fn transition_quotient_degree_bounds(_: isize, _: usize) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_DEGREE_BOUNDS}") - } - - fn terminal_quotient_degree_bounds(_: isize) -> Vec { - panic!("{ERROR_MESSAGE_GENERATE_DEGREE_BOUNDS}") - } -} diff --git a/triton-vm/src/table/degree_lowering.rs b/triton-vm/src/table/degree_lowering.rs new file mode 100644 index 000000000..770201232 --- /dev/null +++ b/triton-vm/src/table/degree_lowering.rs @@ -0,0 +1,5 @@ +//! The degree lowering table contains the introduced variables that allow +//! lowering the degree of the AIR. See [`air::TARGET_DEGREE`] +//! for additional information. + +include!(concat!(env!("OUT_DIR"), "/degree_lowering_table.rs")); diff --git a/triton-vm/src/table/degree_lowering_table.rs b/triton-vm/src/table/degree_lowering_table.rs deleted file mode 100644 index 697567e20..000000000 --- a/triton-vm/src/table/degree_lowering_table.rs +++ /dev/null @@ -1,35 +0,0 @@ -//! This file is a placeholder for auto-generated code. -//! Run `cargo run --bin constraint-evaluation-generator` to generate the actual code. - -use ndarray::ArrayView2; -use ndarray::ArrayViewMut2; -use strum::Display; -use strum::EnumCount; -use strum::EnumIter; -use twenty_first::prelude::BFieldElement; -use twenty_first::prelude::XFieldElement; - -use crate::challenges::Challenges; - -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum DegreeLoweringBaseTableColumn {} - -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum DegreeLoweringExtTableColumn {} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct DegreeLoweringTable; - -impl DegreeLoweringTable { - pub fn fill_derived_base_columns(mut _master_base_table: ArrayViewMut2) { - // to be filled by generated code - } - - pub fn fill_derived_ext_columns( - _master_base_table: ArrayView2, - mut _master_ext_table: ArrayViewMut2, - _challenges: &Challenges, - ) { - // to be filled by generated code - } -} diff --git a/triton-vm/src/table/extension_table.rs b/triton-vm/src/table/extension_table.rs index c43485f44..519189c7a 100644 --- a/triton-vm/src/table/extension_table.rs +++ b/triton-vm/src/table/extension_table.rs @@ -11,28 +11,22 @@ use crate::challenges::Challenges; use crate::table::master_table::MasterExtTable; use crate::table::ConstraintType; -/// The implementations of these functions are automatically generated using the -/// command `cargo run --bin constraint-evaluation-generator` and live in -/// `constraints.rs`. +include!(concat!(env!("OUT_DIR"), "/evaluate_constraints.rs")); + +// The implementations of these functions are generated in `build.rs`. pub trait Evaluable { - /// The code for this method must be generated by running - /// `cargo run --bin constraint-evaluation-generator` fn evaluate_initial_constraints( base_row: ArrayView1, ext_row: ArrayView1, challenges: &Challenges, ) -> Vec; - /// The code for this method must be generated by running - /// `cargo run --bin constraint-evaluation-generator` fn evaluate_consistency_constraints( base_row: ArrayView1, ext_row: ArrayView1, challenges: &Challenges, ) -> Vec; - /// The code for this method must be generated by running - /// `cargo run --bin constraint-evaluation-generator` fn evaluate_transition_constraints( current_base_row: ArrayView1, current_ext_row: ArrayView1, @@ -41,8 +35,6 @@ pub trait Evaluable { challenges: &Challenges, ) -> Vec; - /// The code for this method must be generated by running - /// `cargo run --bin constraint-evaluation-generator` fn evaluate_terminal_constraints( base_row: ArrayView1, ext_row: ArrayView1, @@ -50,33 +42,6 @@ pub trait Evaluable { ) -> Vec; } -pub trait Quotientable: Evaluable { - const NUM_INITIAL_CONSTRAINTS: usize; - const NUM_CONSISTENCY_CONSTRAINTS: usize; - const NUM_TRANSITION_CONSTRAINTS: usize; - const NUM_TERMINAL_CONSTRAINTS: usize; - - /// The total number of constraints. The number of quotients is identical. - const NUM_CONSTRAINTS: usize = Self::NUM_INITIAL_CONSTRAINTS - + Self::NUM_CONSISTENCY_CONSTRAINTS - + Self::NUM_TRANSITION_CONSTRAINTS - + Self::NUM_TERMINAL_CONSTRAINTS; - - fn initial_quotient_degree_bounds(interpolant_degree: isize) -> Vec; - - fn consistency_quotient_degree_bounds( - interpolant_degree: isize, - padded_height: usize, - ) -> Vec; - - fn transition_quotient_degree_bounds( - interpolant_degree: isize, - padded_height: usize, - ) -> Vec; - - fn terminal_quotient_degree_bounds(interpolant_degree: isize) -> Vec; -} - /// Helps debugging and benchmarking. The maximal degree achieved in any table dictates the length /// of the FRI domain, which in turn is responsible for the main performance bottleneck. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] diff --git a/triton-vm/src/table/hash.rs b/triton-vm/src/table/hash.rs index d0a3df342..14272035c 100644 --- a/triton-vm/src/table/hash.rs +++ b/triton-vm/src/table/hash.rs @@ -6,35 +6,17 @@ use air::table::hash::HashTable; use air::table::hash::HashTableMode; use air::table::hash::PermutationTrace; use air::table::hash::MONTGOMERY_MODULUS; -use air::table::hash::NUM_ROUND_CONSTANTS; use air::table_column::HashBaseTableColumn::*; use air::table_column::HashExtTableColumn::*; use air::table_column::MasterBaseTableColumn; use air::table_column::MasterExtTableColumn; use air::AIR; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::InputIndicator; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; -use isa::instruction::AnInstruction::Hash; -use isa::instruction::AnInstruction::SpongeAbsorb; -use isa::instruction::AnInstruction::SpongeInit; -use isa::instruction::AnInstruction::SpongeSqueeze; use isa::instruction::Instruction; use itertools::Itertools; use ndarray::*; use num_traits::Zero; -use strum::Display; use strum::EnumCount; -use strum::EnumIter; -use strum::IntoEnumIterator; use twenty_first::prelude::tip5::NUM_ROUNDS; -use twenty_first::prelude::tip5::NUM_SPLIT_AND_LOOKUP; -use twenty_first::prelude::tip5::RATE; -use twenty_first::prelude::tip5::ROUND_CONSTANTS; use twenty_first::prelude::tip5::STATE_SIZE; use twenty_first::prelude::*; @@ -573,7 +555,9 @@ pub(crate) mod tests { use air::table::TableId; use air::table_column::HashBaseTableColumn; use air::AIR; + use constraint_circuit::ConstraintCircuitBuilder; use std::collections::HashMap; + use strum::IntoEnumIterator; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::master_tables_for_low_security_level; diff --git a/triton-vm/src/table/jump_stack.rs b/triton-vm/src/table/jump_stack.rs index 43afe3da5..1f2301d44 100644 --- a/triton-vm/src/table/jump_stack.rs +++ b/triton-vm/src/table/jump_stack.rs @@ -3,15 +3,12 @@ use std::collections::HashMap; use std::ops::Range; use air::challenge_id::ChallengeId::*; -use air::cross_table_argument::*; +use air::cross_table_argument::CrossTableArg; +use air::cross_table_argument::LookupArg; +use air::cross_table_argument::PermArg; use air::table::jump_stack::JumpStackTable; -use air::table_column::JumpStackBaseTableColumn::*; -use air::table_column::JumpStackExtTableColumn::*; use air::table_column::*; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; -use isa::instruction::Instruction; +use air::AIR; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; @@ -27,6 +24,9 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; use crate::table::TraceTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + fn extension_column_running_product_permutation_argument( main_table: ArrayView2, challenges: &Challenges, @@ -34,11 +34,12 @@ fn extension_column_running_product_permutation_argument( let mut running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(main_table.nrows()); for row in main_table.rows() { - let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] - + row[CI.base_table_index()] * challenges[JumpStackCiWeight] - + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] - + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] - + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; + let compressed_row = row[MainColumn::CLK.base_table_index()] + * challenges[JumpStackClkWeight] + + row[MainColumn::CI.base_table_index()] * challenges[JumpStackCiWeight] + + row[MainColumn::JSP.base_table_index()] * challenges[JumpStackJspWeight] + + row[MainColumn::JSO.base_table_index()] * challenges[JumpStackJsoWeight] + + row[MainColumn::JSD.base_table_index()] * challenges[JumpStackJsdWeight]; running_product *= challenges[JumpStackIndeterminate] - compressed_row; extension_column.push(running_product); } @@ -66,9 +67,11 @@ fn extension_column_clock_jump_diff_lookup_log_derivative( let mut extension_column = Vec::with_capacity(main_table.nrows()); extension_column.push(cjd_lookup_log_derivative); for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() { - if previous_row[JSP.base_table_index()] == current_row[JSP.base_table_index()] { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; + if previous_row[MainColumn::JSP.base_table_index()] + == current_row[MainColumn::JSP.base_table_index()] + { + let previous_clock = previous_row[MainColumn::CLK.base_table_index()]; + let current_clock = current_row[MainColumn::CLK.base_table_index()]; let clock_jump_difference = current_clock - previous_clock; let &mut inverse = inverses_dictionary .entry(clock_jump_difference) @@ -116,11 +119,11 @@ impl TraceTable for JumpStackTable { { let jsp = bfe!(jsp_val as u64); for (clk, ci, jso, jsd) in rows_with_this_jsp { - jump_stack_table[(jump_stack_table_row, CLK.base_table_index())] = clk; - jump_stack_table[(jump_stack_table_row, CI.base_table_index())] = ci; - jump_stack_table[(jump_stack_table_row, JSP.base_table_index())] = jsp; - jump_stack_table[(jump_stack_table_row, JSO.base_table_index())] = jso; - jump_stack_table[(jump_stack_table_row, JSD.base_table_index())] = jsd; + jump_stack_table[(jump_stack_table_row, MainColumn::CLK.base_table_index())] = clk; + jump_stack_table[(jump_stack_table_row, MainColumn::CI.base_table_index())] = ci; + jump_stack_table[(jump_stack_table_row, MainColumn::JSP.base_table_index())] = jsp; + jump_stack_table[(jump_stack_table_row, MainColumn::JSO.base_table_index())] = jso; + jump_stack_table[(jump_stack_table_row, MainColumn::JSD.base_table_index())] = jsd; jump_stack_table_row += 1; } } @@ -132,8 +135,11 @@ impl TraceTable for JumpStackTable { for row_idx in 0..aet.processor_trace.nrows() - 1 { let curr_row = jump_stack_table.row(row_idx); let next_row = jump_stack_table.row(row_idx + 1); - let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; - if curr_row[JSP.base_table_index()] == next_row[JSP.base_table_index()] { + let clk_diff = next_row[MainColumn::CLK.base_table_index()] + - curr_row[MainColumn::CLK.base_table_index()]; + if curr_row[MainColumn::JSP.base_table_index()] + == next_row[MainColumn::JSP.base_table_index()] + { clock_jump_differences.push(clk_diff); } } @@ -151,7 +157,9 @@ impl TraceTable for JumpStackTable { .rows() .into_iter() .enumerate() - .find(|(_, row)| row[CLK.base_table_index()].value() as usize == max_clk_before_padding) + .find(|(_, row)| { + row[MainColumn::CLK.base_table_index()].value() as usize == max_clk_before_padding + }) .map(|(idx, _)| idx) .expect("Jump Stack Table must contain row with clock cycle equal to max cycle."); let rows_to_move_source_section_start = max_clk_before_padding_row_idx + 1; @@ -191,7 +199,8 @@ impl TraceTable for JumpStackTable { // CLK keeps increasing by 1 also in the padding section. let new_clk_values = Array1::from_iter((table_len..padded_height).map(|clk| bfe!(clk as u64))); - new_clk_values.move_into(padding_section.slice_mut(s![.., CLK.base_table_index()])); + new_clk_values + .move_into(padding_section.slice_mut(s![.., MainColumn::CLK.base_table_index()])); } fn extend( @@ -200,8 +209,8 @@ impl TraceTable for JumpStackTable { challenges: &Challenges, ) { profiler!(start "jump stack table"); - assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); - assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(MainColumn::COUNT, main_table.ncols()); + assert_eq!(AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); // use strum::IntoEnumIterator; diff --git a/triton-vm/src/table/lookup.rs b/triton-vm/src/table/lookup.rs index 17d0d2853..ed5295245 100644 --- a/triton-vm/src/table/lookup.rs +++ b/triton-vm/src/table/lookup.rs @@ -1,21 +1,11 @@ use air::challenge_id::ChallengeId; -use air::challenge_id::ChallengeId::*; use air::cross_table_argument::CrossTableArg; use air::cross_table_argument::EvalArg; use air::cross_table_argument::LookupArg; use air::table::lookup::LookupTable; -use air::table_column::LookupBaseTableColumn; -use air::table_column::LookupBaseTableColumn::*; -use air::table_column::LookupExtTableColumn; -use air::table_column::LookupExtTableColumn::*; use air::table_column::MasterBaseTableColumn; use air::table_column::MasterExtTableColumn; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; +use air::AIR; use itertools::Itertools; use ndarray::prelude::*; use num_traits::ConstOne; @@ -32,26 +22,29 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; use crate::table::TraceTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + fn extension_column_cascade_running_sum_log_derivative( base_table: ArrayView2, challenges: &Challenges, ) -> Array2 { - let look_in_weight = challenges[LookupTableInputWeight]; - let look_out_weight = challenges[LookupTableOutputWeight]; - let indeterminate = challenges[CascadeLookupIndeterminate]; + let look_in_weight = challenges[ChallengeId::LookupTableInputWeight]; + let look_out_weight = challenges[ChallengeId::LookupTableOutputWeight]; + let indeterminate = challenges[ChallengeId::CascadeLookupIndeterminate]; let mut cascade_table_running_sum_log_derivative = LookupArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { + if row[MainColumn::IsPadding.base_table_index()].is_one() { break; } - let lookup_input = row[LookIn.base_table_index()]; - let lookup_output = row[LookOut.base_table_index()]; + let lookup_input = row[MainColumn::LookIn.base_table_index()]; + let lookup_output = row[MainColumn::LookOut.base_table_index()]; let compressed_row = lookup_input * look_in_weight + lookup_output * look_out_weight; - let lookup_multiplicity = row[LookupMultiplicity.base_table_index()]; + let lookup_multiplicity = row[MainColumn::LookupMultiplicity.base_table_index()]; cascade_table_running_sum_log_derivative += (indeterminate - compressed_row).inverse() * lookup_multiplicity; @@ -70,12 +63,13 @@ fn extension_column_public_running_evaluation( let mut running_evaluation = EvalArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { + if row[MainColumn::IsPadding.base_table_index()].is_one() { break; } - running_evaluation = running_evaluation * challenges[LookupTablePublicIndeterminate] - + row[LookOut.base_table_index()]; + running_evaluation = running_evaluation + * challenges[ChallengeId::LookupTablePublicIndeterminate] + + row[MainColumn::LookOut.base_table_index()]; extension_column.push(running_evaluation); } @@ -94,14 +88,18 @@ impl TraceTable for LookupTable { // Lookup Table input let lookup_input = Array1::from_iter((0..LOOKUP_TABLE_LEN).map(|i| bfe!(i as u64))); - let lookup_input_column = - main_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookIn.base_table_index()]); + let lookup_input_column = main_table.slice_mut(s![ + ..LOOKUP_TABLE_LEN, + MainColumn::LookIn.base_table_index() + ]); lookup_input.move_into(lookup_input_column); // Lookup Table output let lookup_output = Array1::from_iter(tip5::LOOKUP_TABLE.map(BFieldElement::from)); - let lookup_output_column = - main_table.slice_mut(s![..LOOKUP_TABLE_LEN, LookOut.base_table_index()]); + let lookup_output_column = main_table.slice_mut(s![ + ..LOOKUP_TABLE_LEN, + MainColumn::LookOut.base_table_index() + ]); lookup_output.move_into(lookup_output_column); // Lookup Table multiplicities @@ -111,14 +109,14 @@ impl TraceTable for LookupTable { ); let lookup_multiplicities_column = main_table.slice_mut(s![ ..LOOKUP_TABLE_LEN, - LookupMultiplicity.base_table_index() + MainColumn::LookupMultiplicity.base_table_index() ]); lookup_multiplicities.move_into(lookup_multiplicities_column); } fn pad(mut lookup_table: ArrayViewMut2, table_length: usize) { lookup_table - .slice_mut(s![table_length.., IsPadding.base_table_index()]) + .slice_mut(s![table_length.., MainColumn::IsPadding.base_table_index()]) .fill(BFieldElement::ONE); } @@ -128,11 +126,11 @@ impl TraceTable for LookupTable { challenges: &Challenges, ) { profiler!(start "lookup table"); - assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); - assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(MainColumn::COUNT, main_table.ncols()); + assert_eq!(AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); - let extension_column_indices = LookupExtTableColumn::iter() + let extension_column_indices = AuxColumn::iter() .map(|column| column.ext_table_index()) .collect_vec(); let extension_column_slices = horizontal_multi_slice_mut( diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index ddbf67d18..2d3fd548b 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -62,9 +62,7 @@ use num_traits::Zero; use rand::distributions::Standard; use rand::prelude::Distribution; use rand::random; -use strum::Display; use strum::EnumCount; -use strum::EnumIter; use twenty_first::math::tip5::RATE; use twenty_first::math::traits::FiniteField; use twenty_first::prelude::*; @@ -80,10 +78,9 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::ndarray_helper::partial_sums; use crate::profiler::profiler; use crate::stark::NUM_RANDOMIZER_POLYNOMIALS; -use crate::table::degree_lowering_table::DegreeLoweringTable; +use crate::table::degree_lowering::DegreeLoweringTable; use crate::table::extension_table::all_degrees_with_origin; use crate::table::extension_table::DegreeWithOrigin; -use crate::table::extension_table::Quotientable; use crate::table::processor::ClkJumpDiffs; use crate::table::*; @@ -461,7 +458,7 @@ pub struct MasterExtTable { } impl MasterTable for MasterBaseTable { - const NUM_COLUMNS: usize = NUM_BASE_COLUMNS; + const NUM_COLUMNS: usize = NUM_MAIN_COLUMNS; fn trace_domain(&self) -> ArithmeticDomain { self.trace_domain @@ -567,7 +564,7 @@ impl MasterTable for MasterBaseTable { } impl MasterTable for MasterExtTable { - const NUM_COLUMNS: usize = NUM_EXT_COLUMNS; + const NUM_COLUMNS: usize = NUM_AUX_COLUMNS + NUM_RANDOMIZER_POLYNOMIALS; fn trace_domain(&self) -> ArithmeticDomain { self.trace_domain @@ -588,13 +585,13 @@ impl MasterTable for MasterExtTable { fn trace_table(&self) -> ArrayView2 { let unit_distance = self.randomized_trace_domain().length / self.trace_domain().length; self.randomized_trace_table - .slice(s![..; unit_distance, ..NUM_EXT_COLUMNS]) + .slice(s![..; unit_distance, ..Self::NUM_COLUMNS]) } fn trace_table_mut(&mut self) -> ArrayViewMut2 { let unit_distance = self.randomized_trace_domain().length / self.trace_domain().length; self.randomized_trace_table - .slice_mut(s![..; unit_distance, ..NUM_EXT_COLUMNS]) + .slice_mut(s![..; unit_distance, ..Self::NUM_COLUMNS]) } fn randomized_trace_table(&self) -> ArrayView2 { @@ -655,7 +652,7 @@ impl MasterTable for MasterExtTable { fn out_of_domain_row(&self, indeterminate: XFieldElement) -> Array1 { self.interpolation_polynomials() - .slice(s![..NUM_EXT_COLUMNS]) + .slice(s![..Self::NUM_COLUMNS]) .into_par_iter() .map(|polynomial| polynomial.evaluate(indeterminate)) .collect::>() @@ -682,7 +679,7 @@ impl MasterBaseTable { ArithmeticDomain::of_length(randomized_padded_trace_len).unwrap(); let num_rows = randomized_padded_trace_len; - let num_columns = NUM_BASE_COLUMNS; + let num_columns = Self::NUM_COLUMNS; let randomized_trace_table = Array2::zeros([num_rows, num_columns].f()); let mut master_base_table = Self { @@ -811,11 +808,13 @@ impl MasterBaseTable { // randomizer polynomials let num_rows = self.randomized_trace_table().nrows(); profiler!(start "initialize master table"); + let num_aux_columns = MasterExtTable::NUM_COLUMNS; let mut randomized_trace_extension_table = - fast_zeros_column_major(num_rows, NUM_EXT_COLUMNS + NUM_RANDOMIZER_POLYNOMIALS); + fast_zeros_column_major(num_rows, num_aux_columns); + let randomizers_start = MasterExtTable::NUM_COLUMNS - NUM_RANDOMIZER_POLYNOMIALS; randomized_trace_extension_table - .slice_mut(s![.., NUM_EXT_COLUMNS..]) + .slice_mut(s![.., randomizers_start..]) .par_mapv_inplace(|_| random::()); profiler!(stop "initialize master table"); @@ -834,7 +833,7 @@ impl MasterBaseTable { let unit_distance = self.randomized_trace_domain().length / self.trace_domain().length; let master_ext_table_without_randomizers = master_ext_table .randomized_trace_table - .slice_mut(s![..; unit_distance, ..NUM_EXT_COLUMNS]); + .slice_mut(s![..; unit_distance, ..randomizers_start]); let extension_tables: [_; TableId::COUNT] = horizontal_multi_slice_mut( master_ext_table_without_randomizers, &partial_sums(&[ @@ -856,8 +855,8 @@ impl MasterBaseTable { profiler!(start "all tables"); Self::all_extend_functions() .into_par_iter() - .zip_eq(self.base_tables_for_extending().into_par_iter()) - .zip_eq(extension_tables.into_par_iter()) + .zip_eq(self.base_tables_for_extending()) + .zip_eq(extension_tables) .for_each(|((extend, base_table), ext_table)| { extend(base_table, ext_table, challenges) }); @@ -937,7 +936,7 @@ impl MasterBaseTable { row: Array1, ) -> Result, ProvingError> { let err = || ProvingError::TableRowConversionError { - expected_len: NUM_BASE_COLUMNS, + expected_len: Self::NUM_COLUMNS, actual_len: row.len(), }; row.to_vec().try_into().map_err(|_| err()) @@ -978,7 +977,7 @@ impl MasterExtTable { pub(crate) fn try_to_ext_row(row: Array1) -> Result { let err = || ProvingError::TableRowConversionError { - expected_len: NUM_EXT_COLUMNS, + expected_len: Self::NUM_COLUMNS, actual_len: row.len(), }; row.to_vec().try_into().map_err(|_| err()) @@ -1194,8 +1193,8 @@ mod tests { use constraint_circuit::SingleRowIndicator; use isa::instruction::Instruction; use isa::instruction::InstructionBit; - use ndarray::s; use ndarray::Array2; + use num_traits::ConstZero; use num_traits::Zero; use proptest::prelude::*; use proptest_arbitrary_interop::arb; @@ -1207,15 +1206,15 @@ mod tests { use twenty_first::math::traits::FiniteField; use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; - use crate::air::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; - use crate::air::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; - use crate::air::tasm_air_constraints::dynamic_air_constraint_evaluation_tasm; - use crate::air::tasm_air_constraints::static_air_constraint_evaluation_tasm; + use crate::air::dynamic_air_constraint_evaluation_tasm; + use crate::air::static_air_constraint_evaluation_tasm; use crate::arithmetic_domain::ArithmeticDomain; + use crate::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; + use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::*; - use crate::table::degree_lowering_table::DegreeLoweringBaseTableColumn; - use crate::table::degree_lowering_table::DegreeLoweringExtTableColumn; + use crate::table::degree_lowering::DegreeLoweringBaseTableColumn; + use crate::table::degree_lowering::DegreeLoweringExtTableColumn; use crate::table::*; use crate::triton_program; @@ -1305,15 +1304,6 @@ mod tests { ::AuxColumn::COUNT, master_ext_table.table(TableId::U32).ncols() ); - - // use some domain-specific knowledge to also check for the randomizer columns - assert_eq!( - NUM_RANDOMIZER_POLYNOMIALS, - master_ext_table - .randomized_trace_table() - .slice(s![.., EXT_U32_TABLE_END..]) - .ncols() - ); } #[test] @@ -1410,14 +1400,13 @@ mod tests { } fn generate_table_overview() -> SpecSnippet { - const NUM_DEGREE_LOWERING_TARGETS: usize = 3; - const DEGREE_LOWERING_TARGETS: [Option; NUM_DEGREE_LOWERING_TARGETS] = - [None, Some(8), Some(4)]; - - fn table_widths() -> ((usize, usize)) { + fn table_widths() -> (usize, usize) { (A::MainColumn::COUNT, A::AuxColumn::COUNT) } + const NUM_DEGREE_LOWERING_TARGETS: usize = 3; + const DEGREE_LOWERING_TARGETS: [Option; NUM_DEGREE_LOWERING_TARGETS] = + [None, Some(8), Some(4)]; assert!(DEGREE_LOWERING_TARGETS.contains(&Some(air::TARGET_DEGREE))); let mut all_table_info = [ @@ -1463,7 +1452,7 @@ mod tests { ConstraintCircuitMonad::lower_to_degree(&mut constraints, degree_lowering_info) }; - let constraints = crate::table::constraints(); + let constraints = constraint_builder::Constraints::all(); let (init_main, init_aux) = lower_to_target_degree_single_row(constraints.init); let (cons_main, cons_aux) = lower_to_target_degree_single_row(constraints.cons); let (tran_main, tran_aux) = lower_to_target_degree_double_row(constraints.tran); @@ -1768,10 +1757,21 @@ mod tests { } fn generate_tasm_air_evaluation_cost_overview() -> SpecSnippet { - let static_layout = StaticTasmConstraintEvaluationMemoryLayout::default(); - let static_tasm = static_air_constraint_evaluation_tasm(static_layout); - let dynamic_layout = DynamicTasmConstraintEvaluationMemoryLayout::default(); - let dynamic_tasm = dynamic_air_constraint_evaluation_tasm(dynamic_layout); + let dummy_static_layout = StaticTasmConstraintEvaluationMemoryLayout { + free_mem_page_ptr: BFieldElement::ZERO, + curr_base_row_ptr: BFieldElement::ZERO, + curr_ext_row_ptr: BFieldElement::ZERO, + next_base_row_ptr: BFieldElement::ZERO, + next_ext_row_ptr: BFieldElement::ZERO, + challenges_ptr: BFieldElement::ZERO, + }; + let dummy_dynamic_layout = DynamicTasmConstraintEvaluationMemoryLayout { + free_mem_page_ptr: BFieldElement::ZERO, + challenges_ptr: BFieldElement::ZERO, + }; + + let static_tasm = static_air_constraint_evaluation_tasm(dummy_static_layout); + let dynamic_tasm = dynamic_air_constraint_evaluation_tasm(dummy_dynamic_layout); let mut snippet = "\ | Type | Processor | Op Stack | RAM |\n\ @@ -1934,7 +1934,7 @@ mod tests { print_columns!(base CascadeBaseTableColumn for "cascade"); print_columns!(base LookupBaseTableColumn for "lookup"); print_columns!(base U32BaseTableColumn for "u32"); - // print_columns!(base DegreeLoweringBaseTableColumn for "degree low."); // todo + print_columns!(base DegreeLoweringBaseTableColumn for "degree low."); println!(); println!("| idx | table | extension column"); @@ -1948,7 +1948,7 @@ mod tests { print_columns!(ext CascadeExtTableColumn for "cascade"); print_columns!(ext LookupExtTableColumn for "lookup"); print_columns!(ext U32ExtTableColumn for "u32"); - // print_columns!(ext DegreeLoweringExtTableColumn for "degree low."); // todo + print_columns!(ext DegreeLoweringExtTableColumn for "degree low."); } #[test] @@ -1959,7 +1959,7 @@ mod tests { let fri_domain = ArithmeticDomain::of_length(1 << 11).unwrap(); let randomized_trace_table = - Array2::zeros((randomized_trace_domain.length, NUM_EXT_COLUMNS)); + Array2::zeros((randomized_trace_domain.length, NUM_AUX_COLUMNS)); let mut master_table = MasterExtTable { num_trace_randomizers: 16, diff --git a/triton-vm/src/table/op_stack.rs b/triton-vm/src/table/op_stack.rs index 6c128f6f1..f6341002f 100644 --- a/triton-vm/src/table/op_stack.rs +++ b/triton-vm/src/table/op_stack.rs @@ -7,14 +7,9 @@ use air::cross_table_argument::*; use air::table::op_stack::OpStackTable; use air::table::op_stack::PADDING_VALUE; use air::table::TableId; -use air::table_column::OpStackBaseTableColumn::*; -use air::table_column::OpStackExtTableColumn::*; use air::table_column::*; use air::AIR; use arbitrary::Arbitrary; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; use isa::op_stack::OpStackElement; use isa::op_stack::UnderflowIO; use itertools::Itertools; @@ -32,6 +27,9 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::profiler::profiler; use crate::table::TraceTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct OpStackTableEntry { pub clk: u32, @@ -92,10 +90,10 @@ impl OpStackTableEntry { }; let mut row = Array1::zeros(::MainColumn::COUNT); - row[CLK.base_table_index()] = self.clk.into(); - row[IB1ShrinkStack.base_table_index()] = shrink_stack_indicator; - row[StackPointer.base_table_index()] = self.op_stack_pointer; - row[FirstUnderflowElement.base_table_index()] = self.underflow_io.payload(); + row[MainColumn::CLK.base_table_index()] = self.clk.into(); + row[MainColumn::IB1ShrinkStack.base_table_index()] = shrink_stack_indicator; + row[MainColumn::StackPointer.base_table_index()] = self.op_stack_pointer; + row[MainColumn::FirstUnderflowElement.base_table_index()] = self.underflow_io.payload(); row } } @@ -109,11 +107,13 @@ fn extension_column_running_product_permutation_argument( let mut running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[IB1ShrinkStack.base_table_index()] != PADDING_VALUE { - let compressed_row = row[CLK.base_table_index()] * challenges[OpStackClkWeight] - + row[IB1ShrinkStack.base_table_index()] * challenges[OpStackIb1Weight] - + row[StackPointer.base_table_index()] * challenges[OpStackPointerWeight] - + row[FirstUnderflowElement.base_table_index()] + if row[MainColumn::IB1ShrinkStack.base_table_index()] != PADDING_VALUE { + let compressed_row = row[MainColumn::CLK.base_table_index()] + * challenges[OpStackClkWeight] + + row[MainColumn::IB1ShrinkStack.base_table_index()] * challenges[OpStackIb1Weight] + + row[MainColumn::StackPointer.base_table_index()] + * challenges[OpStackPointerWeight] + + row[MainColumn::FirstUnderflowElement.base_table_index()] * challenges[OpStackFirstUnderflowElementWeight]; running_product *= perm_arg_indeterminate - compressed_row; } @@ -144,15 +144,15 @@ fn extension_column_clock_jump_diff_lookup_log_derivative( let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(cjd_lookup_log_derivative); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[IB1ShrinkStack.base_table_index()] == PADDING_VALUE { + if current_row[MainColumn::IB1ShrinkStack.base_table_index()] == PADDING_VALUE { break; }; - let previous_stack_pointer = previous_row[StackPointer.base_table_index()]; - let current_stack_pointer = current_row[StackPointer.base_table_index()]; + let previous_stack_pointer = previous_row[MainColumn::StackPointer.base_table_index()]; + let current_stack_pointer = current_row[MainColumn::StackPointer.base_table_index()]; if previous_stack_pointer == current_stack_pointer { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; + let previous_clock = previous_row[MainColumn::CLK.base_table_index()]; + let current_clock = current_row[MainColumn::CLK.base_table_index()]; let clock_jump_difference = current_clock - previous_clock; let &mut inverse = inverses_dictionary .entry(clock_jump_difference) @@ -192,10 +192,10 @@ impl TraceTable for OpStackTable { fn pad(mut op_stack_table: ArrayViewMut2, op_stack_table_len: usize) { let last_row_index = op_stack_table_len.saturating_sub(1); let mut padding_row = op_stack_table.row(last_row_index).to_owned(); - padding_row[IB1ShrinkStack.base_table_index()] = PADDING_VALUE; + padding_row[MainColumn::IB1ShrinkStack.base_table_index()] = PADDING_VALUE; if op_stack_table_len == 0 { let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into(); - padding_row[StackPointer.base_table_index()] = first_stack_pointer; + padding_row[MainColumn::StackPointer.base_table_index()] = first_stack_pointer; } let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]); @@ -211,8 +211,8 @@ impl TraceTable for OpStackTable { challenges: &Challenges, ) { profiler!(start "op stack table"); - assert_eq!(Self::MainColumn::COUNT, main_table.ncols()); - assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); + assert_eq!(MainColumn::COUNT, main_table.ncols()); + assert_eq!(AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); let extension_column_indices = OpStackExtTableColumn::iter() @@ -239,12 +239,12 @@ impl TraceTable for OpStackTable { } fn compare_rows(row_0: ArrayView1, row_1: ArrayView1) -> Ordering { - let stack_pointer_0 = row_0[StackPointer.base_table_index()].value(); - let stack_pointer_1 = row_1[StackPointer.base_table_index()].value(); + let stack_pointer_0 = row_0[MainColumn::StackPointer.base_table_index()].value(); + let stack_pointer_1 = row_1[MainColumn::StackPointer.base_table_index()].value(); let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1); - let clk_0 = row_0[CLK.base_table_index()].value(); - let clk_1 = row_1[CLK.base_table_index()].value(); + let clk_0 = row_0[MainColumn::CLK.base_table_index()].value(); + let clk_1 = row_1[MainColumn::CLK.base_table_index()].value(); let compare_clocks = clk_0.cmp(&clk_1); compare_stack_pointers.then(compare_clocks) @@ -255,11 +255,11 @@ fn clock_jump_differences(op_stack_table: ArrayView2) -> Vec::MainColumn::COUNT; let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[StackPointer.base_table_index()] = stack_pointer_0.into(); - row_0[CLK.base_table_index()] = clk.into(); + row_0[MainColumn::StackPointer.base_table_index()] = stack_pointer_0.into(); + row_0[MainColumn::CLK.base_table_index()] = clk.into(); let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[StackPointer.base_table_index()] = stack_pointer_1.into(); - row_1[CLK.base_table_index()] = clk.into(); + row_1[MainColumn::StackPointer.base_table_index()] = stack_pointer_1.into(); + row_1[MainColumn::CLK.base_table_index()] = clk.into(); let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1); let row_comparison = compare_rows(row_0.view(), row_1.view()); @@ -379,12 +379,12 @@ pub(crate) mod tests { const BASE_WIDTH: usize = ::MainColumn::COUNT; let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[StackPointer.base_table_index()] = stack_pointer.into(); - row_0[CLK.base_table_index()] = clk_0.into(); + row_0[MainColumn::StackPointer.base_table_index()] = stack_pointer.into(); + row_0[MainColumn::CLK.base_table_index()] = clk_0.into(); let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[StackPointer.base_table_index()] = stack_pointer.into(); - row_1[CLK.base_table_index()] = clk_1.into(); + row_1[MainColumn::StackPointer.base_table_index()] = stack_pointer.into(); + row_1[MainColumn::CLK.base_table_index()] = clk_1.into(); let clk_comparison = clk_0.cmp(&clk_1); let row_comparison = compare_rows(row_0.view(), row_1.view()); diff --git a/triton-vm/src/table/processor.rs b/triton-vm/src/table/processor.rs index 3bd239bc1..eb699b05f 100644 --- a/triton-vm/src/table/processor.rs +++ b/triton-vm/src/table/processor.rs @@ -1,25 +1,12 @@ -use std::cmp::max; -use std::ops::Mul; - use air::challenge_id::ChallengeId; -use air::challenge_id::ChallengeId::*; use air::cross_table_argument::*; use air::table::processor::ProcessorTable; use air::table::ram; use air::table_column::ProcessorBaseTableColumn::*; -use air::table_column::ProcessorExtTableColumn::*; use air::table_column::*; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; use isa::instruction::AnInstruction::*; use isa::instruction::Instruction; -use isa::instruction::InstructionBit; -use isa::instruction::ALL_INSTRUCTIONS; -use isa::op_stack::NumberOfWords; use isa::op_stack::OpStackElement; -use isa::op_stack::NUM_OP_STACK_REGISTERS; -use itertools::izip; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::*; @@ -29,7 +16,6 @@ use num_traits::Zero; use strum::EnumCount; use strum::IntoEnumIterator; use twenty_first::math::traits::FiniteField; -use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; @@ -156,7 +142,7 @@ fn extension_column_input_table_eval_argument( let input_symbol_column = ProcessorTable::op_stack_column_by_index(i); let input_symbol = current_row[input_symbol_column.base_table_index()]; input_table_running_evaluation = input_table_running_evaluation - * challenges[StandardInputIndeterminate] + * challenges[ChallengeId::StandardInputIndeterminate] + input_symbol; } } @@ -178,7 +164,7 @@ fn extension_column_output_table_eval_argument( let output_symbol_column = ProcessorTable::op_stack_column_by_index(i); let output_symbol = previous_row[output_symbol_column.base_table_index()]; output_table_running_evaluation = output_table_running_evaluation - * challenges[StandardOutputIndeterminate] + * challenges[ChallengeId::StandardOutputIndeterminate] + output_symbol; } } @@ -198,10 +184,11 @@ fn extension_column_instruction_lookup_argument( break; // padding marks the end of the trace } - let compressed_row = row[IP.base_table_index()] * challenges[ProgramAddressWeight] - + row[CI.base_table_index()] * challenges[ProgramInstructionWeight] - + row[NIA.base_table_index()] * challenges[ProgramNextInstructionWeight]; - to_invert.push(challenges[InstructionLookupIndeterminate] - compressed_row); + let compressed_row = row[IP.base_table_index()] + * challenges[ChallengeId::ProgramAddressWeight] + + row[CI.base_table_index()] * challenges[ChallengeId::ProgramInstructionWeight] + + row[NIA.base_table_index()] * challenges[ChallengeId::ProgramNextInstructionWeight]; + to_invert.push(challenges[ChallengeId::InstructionLookupIndeterminate] - compressed_row); } // populate extension column with inverses @@ -255,12 +242,14 @@ fn extension_column_jump_stack_table_perm_argument( let mut jump_stack_running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - let compressed_row = row[CLK.base_table_index()] * challenges[JumpStackClkWeight] - + row[CI.base_table_index()] * challenges[JumpStackCiWeight] - + row[JSP.base_table_index()] * challenges[JumpStackJspWeight] - + row[JSO.base_table_index()] * challenges[JumpStackJsoWeight] - + row[JSD.base_table_index()] * challenges[JumpStackJsdWeight]; - jump_stack_running_product *= challenges[JumpStackIndeterminate] - compressed_row; + let compressed_row = row[CLK.base_table_index()] + * challenges[ChallengeId::JumpStackClkWeight] + + row[CI.base_table_index()] * challenges[ChallengeId::JumpStackCiWeight] + + row[JSP.base_table_index()] * challenges[ChallengeId::JumpStackJspWeight] + + row[JSO.base_table_index()] * challenges[ChallengeId::JumpStackJsoWeight] + + row[JSD.base_table_index()] * challenges[ChallengeId::JumpStackJsdWeight]; + jump_stack_running_product *= + challenges[ChallengeId::JumpStackIndeterminate] - compressed_row; extension_column.push(jump_stack_running_product); } Array2::from_shape_vec((base_table.nrows(), 1), extension_column).unwrap() @@ -272,7 +261,7 @@ fn extension_column_hash_input_eval_argument( challenges: &Challenges, ) -> Array2 { let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let hash_state_weights = &challenges[StackWeight0..StackWeight10]; + let hash_state_weights = &challenges[ChallengeId::StackWeight0..ChallengeId::StackWeight10]; let merkle_step_left_sibling = [ST0, ST1, ST2, ST3, ST4, HV0, HV1, HV2, HV3, HV4]; let merkle_step_right_sibling = [HV0, HV1, HV2, HV3, HV4, ST0, ST1, ST2, ST3, ST4]; @@ -298,8 +287,9 @@ fn extension_column_hash_input_eval_argument( .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) .sum::(); - hash_input_running_evaluation = - hash_input_running_evaluation * challenges[HashInputIndeterminate] + compressed_row; + hash_input_running_evaluation = hash_input_running_evaluation + * challenges[ChallengeId::HashInputIndeterminate] + + compressed_row; } extension_column.push(hash_input_running_evaluation); } @@ -323,11 +313,11 @@ fn extension_column_hash_digest_eval_argument( let compressed_row = [ST0, ST1, ST2, ST3, ST4] .map(|st| current_row[st.base_table_index()]) .into_iter() - .zip_eq(&challenges[StackWeight0..=StackWeight4]) + .zip_eq(&challenges[ChallengeId::StackWeight0..=ChallengeId::StackWeight4]) .map(|(st, &weight)| weight * st) .sum::(); hash_digest_running_evaluation = hash_digest_running_evaluation - * challenges[HashDigestIndeterminate] + * challenges[ChallengeId::HashDigestIndeterminate] + compressed_row; } extension_column.push(hash_digest_running_evaluation); @@ -341,7 +331,7 @@ fn extension_column_sponge_eval_argument( challenges: &Challenges, ) -> Array2 { let st0_through_st9 = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let hash_state_weights = &challenges[StackWeight0..StackWeight10]; + let hash_state_weights = &challenges[ChallengeId::StackWeight0..ChallengeId::StackWeight10]; let mut sponge_running_evaluation = EvalArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); @@ -349,8 +339,9 @@ fn extension_column_sponge_eval_argument( for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { let previous_ci = previous_row[CI.base_table_index()]; if previous_ci == Instruction::SpongeInit.opcode_b() { - sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeInit.opcode_b(); + sponge_running_evaluation = sponge_running_evaluation + * challenges[ChallengeId::SpongeIndeterminate] + + challenges[ChallengeId::HashCIWeight] * Instruction::SpongeInit.opcode_b(); } else if previous_ci == Instruction::SpongeAbsorb.opcode_b() { let compressed_row = st0_through_st9 .map(|st| previous_row[st.base_table_index()]) @@ -358,8 +349,9 @@ fn extension_column_sponge_eval_argument( .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) .sum::(); - sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + sponge_running_evaluation = sponge_running_evaluation + * challenges[ChallengeId::SpongeIndeterminate] + + challenges[ChallengeId::HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + compressed_row; } else if previous_ci == Instruction::SpongeAbsorbMem.opcode_b() { let stack_elements = [ST1, ST2, ST3, ST4]; @@ -371,8 +363,9 @@ fn extension_column_sponge_eval_argument( .zip_eq(hash_state_weights.iter()) .map(|(element, &weight)| weight * element) .sum::(); - sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + sponge_running_evaluation = sponge_running_evaluation + * challenges[ChallengeId::SpongeIndeterminate] + + challenges[ChallengeId::HashCIWeight] * Instruction::SpongeAbsorb.opcode_b() + compressed_row; } else if previous_ci == Instruction::SpongeSqueeze.opcode_b() { let compressed_row = st0_through_st9 @@ -381,8 +374,9 @@ fn extension_column_sponge_eval_argument( .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) .sum::(); - sponge_running_evaluation = sponge_running_evaluation * challenges[SpongeIndeterminate] - + challenges[HashCIWeight] * Instruction::SpongeSqueeze.opcode_b() + sponge_running_evaluation = sponge_running_evaluation + * challenges[ChallengeId::SpongeIndeterminate] + + challenges[ChallengeId::HashCIWeight] * Instruction::SpongeSqueeze.opcode_b() + compressed_row; } extension_column.push(sponge_running_evaluation); @@ -399,19 +393,21 @@ fn extension_column_for_u32_lookup_argument( for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { let previous_ci = previous_row[CI.base_table_index()]; if previous_ci == Instruction::Split.opcode_b() { - let compressed_row = current_row[ST0.base_table_index()] * challenges[U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); + let compressed_row = current_row[ST0.base_table_index()] + * challenges[ChallengeId::U32LhsWeight] + + current_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::Lt.opcode_b() || previous_ci == Instruction::And.opcode_b() || previous_ci == Instruction::Pow.opcode_b() { - let compressed_row = previous_row[ST0.base_table_index()] * challenges[U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); + let compressed_row = previous_row[ST0.base_table_index()] + * challenges[ChallengeId::U32LhsWeight] + + previous_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight] + + current_row[ST0.base_table_index()] * challenges[ChallengeId::U32ResultWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::Xor.opcode_b() { // Triton VM uses the following equality to compute the results of both the // `and` and `xor` instruction using the u32 coprocessor's `and` capability: @@ -422,37 +418,41 @@ fn extension_column_for_u32_lookup_argument( let st0 = current_row[ST0.base_table_index()]; let from_xor_in_processor_to_and_in_u32_coprocessor = (st0_prev + st1_prev - st0) / bfe!(2); - let compressed_row = st0_prev * challenges[U32LhsWeight] - + st1_prev * challenges[U32RhsWeight] - + Instruction::And.opcode_b() * challenges[U32CiWeight] - + from_xor_in_processor_to_and_in_u32_coprocessor * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); + let compressed_row = st0_prev * challenges[ChallengeId::U32LhsWeight] + + st1_prev * challenges[ChallengeId::U32RhsWeight] + + Instruction::And.opcode_b() * challenges[ChallengeId::U32CiWeight] + + from_xor_in_processor_to_and_in_u32_coprocessor + * challenges[ChallengeId::U32ResultWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::Log2Floor.opcode_b() || previous_ci == Instruction::PopCount.opcode_b() { - let compressed_row = previous_row[ST0.base_table_index()] * challenges[U32LhsWeight] - + previous_row[CI.base_table_index()] * challenges[U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[U32ResultWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); + let compressed_row = previous_row[ST0.base_table_index()] + * challenges[ChallengeId::U32LhsWeight] + + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight] + + current_row[ST0.base_table_index()] * challenges[ChallengeId::U32ResultWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::DivMod.opcode_b() { let compressed_row_for_lt_check = current_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Lt.opcode_b() * challenges[U32CiWeight] - + bfe!(1) * challenges[U32ResultWeight]; + * challenges[ChallengeId::U32LhsWeight] + + previous_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + Instruction::Lt.opcode_b() * challenges[ChallengeId::U32CiWeight] + + bfe!(1) * challenges[ChallengeId::U32ResultWeight]; let compressed_row_for_range_check = previous_row[ST0.base_table_index()] - * challenges[U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Split.opcode_b() * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row_for_lt_check); - to_invert.push(challenges[U32Indeterminate] - compressed_row_for_range_check); + * challenges[ChallengeId::U32LhsWeight] + + current_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + Instruction::Split.opcode_b() * challenges[ChallengeId::U32CiWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row_for_lt_check); + to_invert + .push(challenges[ChallengeId::U32Indeterminate] - compressed_row_for_range_check); } else if previous_ci == Instruction::MerkleStep.opcode_b() || previous_ci == Instruction::MerkleStepMem.opcode_b() { - let compressed_row = previous_row[ST5.base_table_index()] * challenges[U32LhsWeight] - + current_row[ST5.base_table_index()] * challenges[U32RhsWeight] - + Instruction::Split.opcode_b() * challenges[U32CiWeight]; - to_invert.push(challenges[U32Indeterminate] - compressed_row); + let compressed_row = previous_row[ST5.base_table_index()] + * challenges[ChallengeId::U32LhsWeight] + + current_row[ST5.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + Instruction::Split.opcode_b() * challenges[ChallengeId::U32CiWeight]; + to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } } let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); @@ -491,7 +491,7 @@ fn extension_column_for_clock_jump_difference_lookup_argument( let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; if !lookup_multiplicity.is_zero() { let clk = row[CLK.base_table_index()]; - to_invert.push(challenges[ClockJumpDifferenceLookupIndeterminate] - clk); + to_invert.push(challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate] - clk); } } let mut inverses = XFieldElement::batch_inversion(to_invert).into_iter(); @@ -549,11 +549,11 @@ fn factor_for_op_stack_table_running_product( let clk = previous_row[CLK.base_table_index()]; let ib1_shrink_stack = previous_row[IB1.base_table_index()]; - let compressed_row = clk * challenges[OpStackClkWeight] - + ib1_shrink_stack * challenges[OpStackIb1Weight] - + offset_op_stack_pointer * challenges[OpStackPointerWeight] - + underflow_element * challenges[OpStackFirstUnderflowElementWeight]; - factor *= challenges[OpStackIndeterminate] - compressed_row; + let compressed_row = clk * challenges[ChallengeId::OpStackClkWeight] + + ib1_shrink_stack * challenges[ChallengeId::OpStackIb1Weight] + + offset_op_stack_pointer * challenges[ChallengeId::OpStackPointerWeight] + + underflow_element * challenges[ChallengeId::OpStackFirstUnderflowElementWeight]; + factor *= challenges[ChallengeId::OpStackIndeterminate] - compressed_row; } factor } @@ -648,12 +648,12 @@ fn factor_for_ram_table_running_product( accesses .into_iter() .map(|(ramp, ramv)| { - clk * challenges[RamClkWeight] - + instruction_type * challenges[RamInstructionTypeWeight] - + ramp * challenges[RamPointerWeight] - + ramv * challenges[RamValueWeight] + clk * challenges[ChallengeId::RamClkWeight] + + instruction_type * challenges[ChallengeId::RamInstructionTypeWeight] + + ramp * challenges[ChallengeId::RamPointerWeight] + + ramv * challenges[ChallengeId::RamValueWeight] }) - .map(|compressed_row| challenges[RamIndeterminate] - compressed_row) + .map(|compressed_row| challenges[ChallengeId::RamIndeterminate] - compressed_row) .reduce(|l, r| l * r) } @@ -690,31 +690,24 @@ fn instruction_from_row(row: ArrayView1) -> Option { pub(crate) mod tests { use std::collections::HashMap; - use air::table::processor::NUM_HELPER_VARIABLE_REGISTERS; - use air::table::NUM_BASE_COLUMNS; - use air::table::NUM_EXT_COLUMNS; use assert2::assert; + use constraint_circuit::*; use isa::instruction::Instruction; - use isa::op_stack::NumberOfWords::*; + use isa::op_stack::NumberOfWords; use isa::op_stack::OpStackElement; use isa::program::Program; use isa::triton_asm; use isa::triton_program; use ndarray::Array2; use proptest::collection::vec; - use proptest::prop_assert_eq; use proptest_arbitrary_interop::arb; - use rand::thread_rng; - use rand::Rng; use strum::IntoEnumIterator; use test_strategy::proptest; use crate::error::InstructionError::DivisionByZero; - use crate::prelude::PublicInput; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::master_tables_for_low_security_level; use crate::table::master_table::*; - use crate::vm::VMState; use crate::vm::VM; use crate::NonDeterminism; @@ -988,7 +981,7 @@ pub(crate) mod tests { }); let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 1)); let debug_info = TestRowsDebugInfo { - instruction: ReadMem(N1), + instruction: ReadMem(NumberOfWords::N1), debug_cols_curr_row: vec![ST0, ST1], debug_cols_next_row: vec![ST0, ST1], }; @@ -1007,7 +1000,7 @@ pub(crate) mod tests { ]; let test_rows = programs.map(|program| test_row_from_program(program, 10)); let debug_info = TestRowsDebugInfo { - instruction: WriteMem(N1), + instruction: WriteMem(NumberOfWords::N1), debug_cols_curr_row: vec![ST0, ST1], debug_cols_next_row: vec![ST0, ST1], }; diff --git a/triton-vm/src/table/program.rs b/triton-vm/src/table/program.rs index 86e77f533..78d334bf8 100644 --- a/triton-vm/src/table/program.rs +++ b/triton-vm/src/table/program.rs @@ -9,9 +9,6 @@ use air::table::TableId; use air::table_column::ProgramBaseTableColumn::*; use air::table_column::ProgramExtTableColumn::*; use air::table_column::*; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; use ndarray::s; use ndarray::Array1; use ndarray::ArrayView1; diff --git a/triton-vm/src/table/ram.rs b/triton-vm/src/table/ram.rs index 0a0f33b9a..567d20a51 100644 --- a/triton-vm/src/table/ram.rs +++ b/triton-vm/src/table/ram.rs @@ -6,13 +6,9 @@ use air::table::ram::RamTable; use air::table::ram::PADDING_INDICATOR; use air::table::TableId; use air::table_column::RamBaseTableColumn::*; -use air::table_column::RamExtTableColumn::*; use air::table_column::*; use air::AIR; use arbitrary::Arbitrary; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::SingleRowIndicator::*; -use constraint_circuit::*; use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::prelude::*; diff --git a/triton-vm/src/table/u32.rs b/triton-vm/src/table/u32.rs index 7b9f5f7be..e6a92caf5 100644 --- a/triton-vm/src/table/u32.rs +++ b/triton-vm/src/table/u32.rs @@ -1,5 +1,4 @@ use std::cmp::max; -use std::ops::Mul; use air::challenge_id::ChallengeId::*; use air::cross_table_argument::CrossTableArg; @@ -7,18 +6,7 @@ use air::cross_table_argument::LookupArg; use air::table::u32::U32Table; use air::table_column::MasterBaseTableColumn; use air::table_column::MasterExtTableColumn; -use air::table_column::U32BaseTableColumn; -use air::table_column::U32BaseTableColumn::*; -use air::table_column::U32ExtTableColumn; -use air::table_column::U32ExtTableColumn::*; use arbitrary::Arbitrary; -use constraint_circuit::ConstraintCircuitBuilder; -use constraint_circuit::ConstraintCircuitMonad; -use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::*; -use constraint_circuit::InputIndicator; -use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::*; use isa::instruction::Instruction; use ndarray::parallel::prelude::*; use ndarray::s; @@ -37,6 +25,9 @@ use crate::challenges::Challenges; use crate::profiler::profiler; use crate::table::TraceTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + /// An executed u32 instruction as well as its operands. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct U32TableEntry { @@ -80,14 +71,15 @@ impl TraceTable for U32Table { fn fill(mut u32_table: ArrayViewMut2, aet: &AlgebraicExecutionTrace, _: ()) { let mut next_section_start = 0; for (&u32_table_entry, &multiplicity) in &aet.u32_entries { - let mut first_row = Array2::zeros([1, Self::MainColumn::COUNT]); - first_row[[0, CopyFlag.base_table_index()]] = bfe!(1); - first_row[[0, Bits.base_table_index()]] = bfe!(0); - first_row[[0, BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); - first_row[[0, CI.base_table_index()]] = u32_table_entry.instruction.opcode_b(); - first_row[[0, LHS.base_table_index()]] = u32_table_entry.left_operand; - first_row[[0, RHS.base_table_index()]] = u32_table_entry.right_operand; - first_row[[0, LookupMultiplicity.base_table_index()]] = multiplicity.into(); + let mut first_row = Array2::zeros([1, MainColumn::COUNT]); + first_row[[0, MainColumn::CopyFlag.base_table_index()]] = bfe!(1); + first_row[[0, MainColumn::Bits.base_table_index()]] = bfe!(0); + first_row[[0, MainColumn::BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); + first_row[[0, MainColumn::CI.base_table_index()]] = + u32_table_entry.instruction.opcode_b(); + first_row[[0, MainColumn::LHS.base_table_index()]] = u32_table_entry.left_operand; + first_row[[0, MainColumn::RHS.base_table_index()]] = u32_table_entry.right_operand; + first_row[[0, MainColumn::LookupMultiplicity.base_table_index()]] = multiplicity.into(); let u32_section = u32_section_next_row(first_row); let next_section_end = next_section_start + u32_section.nrows(); @@ -99,22 +91,26 @@ impl TraceTable for U32Table { } fn pad(mut main_table: ArrayViewMut2, table_len: usize) { - let mut padding_row = Array1::zeros([Self::MainColumn::COUNT]); - padding_row[[CI.base_table_index()]] = Instruction::Split.opcode_b(); - padding_row[[BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); + let mut padding_row = Array1::zeros([MainColumn::COUNT]); + padding_row[[MainColumn::CI.base_table_index()]] = Instruction::Split.opcode_b(); + padding_row[[MainColumn::BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); if table_len > 0 { let last_row = main_table.row(table_len - 1); - padding_row[[CI.base_table_index()]] = last_row[CI.base_table_index()]; - padding_row[[LHS.base_table_index()]] = last_row[LHS.base_table_index()]; - padding_row[[LhsInv.base_table_index()]] = last_row[LhsInv.base_table_index()]; - padding_row[[Result.base_table_index()]] = last_row[Result.base_table_index()]; + padding_row[[MainColumn::CI.base_table_index()]] = + last_row[MainColumn::CI.base_table_index()]; + padding_row[[MainColumn::LHS.base_table_index()]] = + last_row[MainColumn::LHS.base_table_index()]; + padding_row[[MainColumn::LhsInv.base_table_index()]] = + last_row[MainColumn::LhsInv.base_table_index()]; + padding_row[[MainColumn::Result.base_table_index()]] = + last_row[MainColumn::Result.base_table_index()]; // In the edge case that the last non-padding row comes from executing instruction // `lt` on operands 0 and 0, the `Result` column is 0. For the padding section, // where the `CopyFlag` is always 0, the `Result` needs to be set to 2 instead. - if padding_row[[CI.base_table_index()]] == Instruction::Lt.opcode_b() { - padding_row[[Result.base_table_index()]] = bfe!(2); + if padding_row[[MainColumn::CI.base_table_index()]] == Instruction::Lt.opcode_b() { + padding_row[[MainColumn::Result.base_table_index()]] = bfe!(2); } } @@ -131,8 +127,8 @@ impl TraceTable for U32Table { challenges: &Challenges, ) { profiler!(start "u32 table"); - assert_eq!(Self::MainColumn::COUNT, base_table.ncols()); - assert_eq!(Self::AuxColumn::COUNT, ext_table.ncols()); + assert_eq!(MainColumn::COUNT, base_table.ncols()); + assert_eq!(AuxColumn::COUNT, ext_table.ncols()); assert_eq!(base_table.nrows(), ext_table.nrows()); let ci_weight = challenges[U32CiWeight]; @@ -144,18 +140,20 @@ impl TraceTable for U32Table { let mut running_sum_log_derivative = LookupArg::default_initial(); for row_idx in 0..base_table.nrows() { let current_row = base_table.row(row_idx); - if current_row[CopyFlag.base_table_index()].is_one() { - let lookup_multiplicity = current_row[LookupMultiplicity.base_table_index()]; - let compressed_row = ci_weight * current_row[CI.base_table_index()] - + lhs_weight * current_row[LHS.base_table_index()] - + rhs_weight * current_row[RHS.base_table_index()] - + result_weight * current_row[Result.base_table_index()]; + if current_row[MainColumn::CopyFlag.base_table_index()].is_one() { + let lookup_multiplicity = + current_row[MainColumn::LookupMultiplicity.base_table_index()]; + let compressed_row = ci_weight * current_row[MainColumn::CI.base_table_index()] + + lhs_weight * current_row[MainColumn::LHS.base_table_index()] + + rhs_weight * current_row[MainColumn::RHS.base_table_index()] + + result_weight * current_row[MainColumn::Result.base_table_index()]; running_sum_log_derivative += lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse(); } let mut extension_row = ext_table.row_mut(row_idx); - extension_row[LookupServerLogDerivative.ext_table_index()] = running_sum_log_derivative; + extension_row[AuxColumn::LookupServerLogDerivative.ext_table_index()] = + running_sum_log_derivative; } profiler!(stop "u32 table"); } @@ -163,17 +161,17 @@ impl TraceTable for U32Table { fn u32_section_next_row(mut section: Array2) -> Array2 { let row_idx = section.nrows() - 1; - let current_instruction: Instruction = section[[row_idx, CI.base_table_index()]] + let current_instruction: Instruction = section[[row_idx, MainColumn::CI.base_table_index()]] .value() .try_into() .expect("Unknown instruction"); // Is the last row in this section reached? - if (section[[row_idx, LHS.base_table_index()]].is_zero() + if (section[[row_idx, MainColumn::LHS.base_table_index()]].is_zero() || current_instruction == Instruction::Pow) - && section[[row_idx, RHS.base_table_index()]].is_zero() + && section[[row_idx, MainColumn::RHS.base_table_index()]].is_zero() { - section[[row_idx, Result.base_table_index()]] = match current_instruction { + section[[row_idx, MainColumn::Result.base_table_index()]] = match current_instruction { Instruction::Split => bfe!(0), Instruction::Lt => bfe!(2), Instruction::And => bfe!(0), @@ -185,50 +183,52 @@ fn u32_section_next_row(mut section: Array2) -> Array2 section[[row_idx, LHS.base_table_index()]], - false => (section[[row_idx, LHS.base_table_index()]] - lhs_lsb) / bfe!(2), + next_row[MainColumn::CopyFlag.base_table_index()] = bfe!(0); + next_row[MainColumn::Bits.base_table_index()] += bfe!(1); + next_row[MainColumn::BitsMinus33Inv.base_table_index()] = + (next_row[MainColumn::Bits.base_table_index()] - bfe!(33)).inverse(); + next_row[MainColumn::LHS.base_table_index()] = match current_instruction == Instruction::Pow { + true => section[[row_idx, MainColumn::LHS.base_table_index()]], + false => (section[[row_idx, MainColumn::LHS.base_table_index()]] - lhs_lsb) / bfe!(2), }; - next_row[RHS.base_table_index()] = - (section[[row_idx, RHS.base_table_index()]] - rhs_lsb) / bfe!(2); - next_row[LookupMultiplicity.base_table_index()] = bfe!(0); + next_row[MainColumn::RHS.base_table_index()] = + (section[[row_idx, MainColumn::RHS.base_table_index()]] - rhs_lsb) / bfe!(2); + next_row[MainColumn::LookupMultiplicity.base_table_index()] = bfe!(0); section.push_row(next_row.view()).unwrap(); section = u32_section_next_row(section); let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - row[LhsInv.base_table_index()] = row[LHS.base_table_index()].inverse_or_zero(); - row[RhsInv.base_table_index()] = row[RHS.base_table_index()].inverse_or_zero(); + row[MainColumn::LhsInv.base_table_index()] = + row[MainColumn::LHS.base_table_index()].inverse_or_zero(); + row[MainColumn::RhsInv.base_table_index()] = + row[MainColumn::RHS.base_table_index()].inverse_or_zero(); - let next_row_result = next_row[Result.base_table_index()]; - row[Result.base_table_index()] = match current_instruction { + let next_row_result = next_row[MainColumn::Result.base_table_index()]; + row[MainColumn::Result.base_table_index()] = match current_instruction { Instruction::Split => next_row_result, Instruction::Lt => { match ( next_row_result.value(), lhs_lsb.value(), rhs_lsb.value(), - row[CopyFlag.base_table_index()].value(), + row[MainColumn::CopyFlag.base_table_index()].value(), ) { (0 | 1, _, _, _) => next_row_result, // result already known (2, 0, 1, _) => bfe!(1), // LHS < RHS @@ -240,18 +240,18 @@ fn u32_section_next_row(mut section: Array2) -> Array2 bfe!(2) * next_row_result + lhs_lsb * rhs_lsb, Instruction::Log2Floor => { - if row[LHS.base_table_index()].is_zero() { + if row[MainColumn::LHS.base_table_index()].is_zero() { bfe!(-1) - } else if !next_row[LHS.base_table_index()].is_zero() { + } else if !next_row[MainColumn::LHS.base_table_index()].is_zero() { next_row_result } else { // LHS != 0 && LHS' == 0 - row[Bits.base_table_index()] + row[MainColumn::Bits.base_table_index()] } } Instruction::Pow => match rhs_lsb.is_zero() { true => next_row_result * next_row_result, - false => next_row_result * next_row_result * row[LHS.base_table_index()], + false => next_row_result * next_row_result * row[MainColumn::LHS.base_table_index()], }, Instruction::PopCount => next_row_result + lhs_lsb, _ => panic!("Must be u32 instruction, not {current_instruction}."), diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index ba85bd4dd..da5162ac0 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -35,7 +35,6 @@ use crate::execution_trace_profiler::ExecutionTraceProfile; use crate::execution_trace_profiler::ExecutionTraceProfiler; use crate::profiler::profiler; use crate::table::op_stack::OpStackTableEntry; -use crate::table::processor; use crate::table::ram::RamTableCall; use crate::table::u32::U32TableEntry; use crate::vm::CoProcessorCall::*; From 70229677bbc197363c782df986f0d455bba5df6d Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 3 Sep 2024 15:15:33 +0200 Subject: [PATCH 08/15] CI: Remove obsolete constraint builder directive changelog: ignore --- .github/workflows/coverage.yml | 3 --- .github/workflows/main.yml | 3 --- 2 files changed, 6 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 91cff2fca..41779007a 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -27,9 +27,6 @@ jobs: - name: Install nextest uses: taiki-e/install-action@nextest - - name: Build AIR constraints - run: cargo run --bin constraint-evaluation-generator - - name: Collect coverage data run: cargo llvm-cov nextest --all-targets --lcov --output-path lcov.info diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 519f2986c..7f8b3ce08 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,9 +28,6 @@ jobs: - name: Install nextest uses: taiki-e/install-action@nextest - - name: Build AIR constraints - run: cargo run --bin constraint-evaluation-generator - - name: Run fmt run: cargo fmt --all -- --check From 65545fe9f6eb2f690abaad505eb79914e86544f6 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Wed, 4 Sep 2024 14:24:02 +0200 Subject: [PATCH 09/15] perf(test): Remove super slow try-build test --- Cargo.toml | 1 - triton-vm/Cargo.toml | 1 - triton-vm/src/profiler.rs | 7 ------- triton-vm/trybuild/profiler_macro_is_private.rs | 3 --- triton-vm/trybuild/profiler_macro_is_private.stderr | 11 ----------- 5 files changed, 23 deletions(-) delete mode 100644 triton-vm/trybuild/profiler_macro_is_private.rs delete mode 100644 triton-vm/trybuild/profiler_macro_is_private.stderr diff --git a/Cargo.toml b/Cargo.toml index 55cb8d41b..4bbce6559 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,6 @@ strum = { version = "0.26", features = ["derive"] } syn = "2.0" test-strategy = "0.4.0" thiserror = "1.0" -trybuild = "1.0" twenty-first = "0.42.0-alpha.9" unicode-width = "0.1" diff --git a/triton-vm/Cargo.toml b/triton-vm/Cargo.toml index b54daa884..b352ef9b6 100644 --- a/triton-vm/Cargo.toml +++ b/triton-vm/Cargo.toml @@ -55,7 +55,6 @@ proptest.workspace = true proptest-arbitrary-interop.workspace = true serde_json.workspace = true test-strategy.workspace = true -trybuild.workspace = true [build-dependencies] air.workspace = true diff --git a/triton-vm/src/profiler.rs b/triton-vm/src/profiler.rs index a4f7350ea..5043e830c 100644 --- a/triton-vm/src/profiler.rs +++ b/triton-vm/src/profiler.rs @@ -671,16 +671,9 @@ mod tests { use std::time::Duration; use test_strategy::proptest; - use trybuild; use super::*; - #[test] - fn profiler_macro_is_private() { - let trybuild = trybuild::TestCases::new(); - trybuild.compile_fail("trybuild/profiler_macro_is_private.rs"); - } - #[test] fn sanity() { let mut profiler = VMPerformanceProfiler::new("Sanity Test"); diff --git a/triton-vm/trybuild/profiler_macro_is_private.rs b/triton-vm/trybuild/profiler_macro_is_private.rs deleted file mode 100644 index 9f9a45620..000000000 --- a/triton-vm/trybuild/profiler_macro_is_private.rs +++ /dev/null @@ -1,3 +0,0 @@ -use triton_vm::profiler::profiler; - -fn main() {} diff --git a/triton-vm/trybuild/profiler_macro_is_private.stderr b/triton-vm/trybuild/profiler_macro_is_private.stderr deleted file mode 100644 index 48cd30746..000000000 --- a/triton-vm/trybuild/profiler_macro_is_private.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error[E0603]: macro `profiler` is private - --> trybuild/profiler_macro_is_private.rs:1:26 - | -1 | use triton_vm::profiler::profiler; - | ^^^^^^^^ private macro - | -note: the macro `profiler` is defined here - --> src/profiler.rs - | - | pub(crate) use profiler; - | ^^^^^^^^ From 098103e3ebc59bb559779cf3fbaaa37d68d09939 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Wed, 4 Sep 2024 15:28:50 +0200 Subject: [PATCH 10/15] docs: Fix intra-doc links changelog: ignore --- triton-air/src/challenge_id.rs | 64 ++++++++++++++++-------- triton-air/src/lib.rs | 11 +--- triton-air/src/table/hash.rs | 16 +++--- triton-air/src/table_column.rs | 18 +++---- triton-constraint-builder/src/codegen.rs | 2 +- triton-vm/src/aet.rs | 4 +- triton-vm/src/challenges.rs | 26 ++++++---- triton-vm/src/stark.rs | 4 +- triton-vm/src/table.rs | 8 +-- triton-vm/src/table/master_table.rs | 41 +++++++++++---- 10 files changed, 115 insertions(+), 79 deletions(-) diff --git a/triton-air/src/challenge_id.rs b/triton-air/src/challenge_id.rs index 044c6414e..fd8975290 100644 --- a/triton-air/src/challenge_id.rs +++ b/triton-air/src/challenge_id.rs @@ -14,21 +14,28 @@ use strum::EnumIter; #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] pub enum ChallengeId { - /// The indeterminate for the [Evaluation Argument](EvalArg) compressing the program digest - /// into a single extension field element, _i.e._, [`CompressedProgramDigest`]. + /// The indeterminate for the [Evaluation Argument][eval] compressing the program digest + /// into a single extension field element, _i.e._, + /// [`CompressedProgramDigest`][Self::CompressedProgramDigest]. /// Relates to program attestation. + /// + /// [eval]: crate::cross_table_argument::EvalArg CompressProgramDigestIndeterminate, - /// The indeterminate for the [Evaluation Argument](EvalArg) with standard input. + /// The indeterminate for the [Evaluation Argument][eval] with standard input. + /// + /// [eval]: crate::cross_table_argument::EvalArg StandardInputIndeterminate, - /// The indeterminate for the [Evaluation Argument](EvalArg) with standard output. + /// The indeterminate for the [Evaluation Argument][eval] with standard output. + /// + /// [eval]: crate::cross_table_argument::EvalArg StandardOutputIndeterminate, /// The indeterminate for the instruction - /// [Lookup Argument](crate::table::cross_table_argument::LookupArg) - /// between the [Processor Table](crate::table::processor_table) and the - /// [Program Table](crate::table::program_table) guaranteeing that the instructions and their + /// [Lookup Argument](crate::cross_table_argument::LookupArg) + /// between the [Processor Table](crate::table::processor) and the + /// [Program Table](crate::table::program) guaranteeing that the instructions and their /// arguments are copied correctly. InstructionLookupIndeterminate, @@ -87,18 +94,20 @@ pub enum ChallengeId { /// /// Used by the evaluation argument [`PrepareChunkEvalArg`][prep] and in the Hash Table. /// - /// [rate]: tip5::RATE - /// [prep]: crate::table::table_column::ProgramExtTableColumn::PrepareChunkRunningEvaluation + /// [rate]: twenty_first::prelude::tip5::RATE + /// [prep]: crate::table_column::ProgramExtTableColumn::PrepareChunkRunningEvaluation ProgramAttestationPrepareChunkIndeterminate, /// The indeterminate for the bus over which the [`RATE`][rate]-sized chunks of instructions /// are sent. Relates to program attestation. /// Used by the evaluation arguments [`SendChunkEvalArg`][send] and - /// [`ReceiveChunkEvalArg`][recv]. See also: [`ProgramAttestationPrepareChunkIndeterminate`]. + /// [`ReceiveChunkEvalArg`][recv]. See also: + /// [`ProgramAttestationPrepareChunkIndeterminate`][ind]. /// - /// [rate]: tip5::RATE - /// [send]: crate::table::table_column::ProgramExtTableColumn::SendChunkRunningEvaluation - /// [recv]: crate::table::table_column::HashExtTableColumn::ReceiveChunkRunningEvaluation + /// [rate]: twenty_first::prelude::tip5::RATE + /// [send]: crate::table_column::ProgramExtTableColumn::SendChunkRunningEvaluation + /// [recv]: crate::table_column::HashExtTableColumn::ReceiveChunkRunningEvaluation + /// [ind]: ChallengeId::ProgramAttestationPrepareChunkIndeterminate ProgramAttestationSendChunkIndeterminate, HashCIWeight, @@ -160,23 +169,34 @@ pub enum ChallengeId { // When modifying this, be sure to add to the compile-time assertions in the // `#[test] const fn compile_time_index_assertions() { … }` // at the end of this file. - /// The terminal for the [`EvaluationArgument`](EvalArg) with standard input. - /// Makes use of challenge [`StandardInputIndeterminate`]. + /// The terminal for the [`EvaluationArgument`][eval] with standard input. + /// Makes use of challenge + /// [`StandardInputIndeterminate`][Self::StandardInputIndeterminate]. + /// + /// [eval]: crate::cross_table_argument::EvalArg StandardInputTerminal, - /// The terminal for the [`EvaluationArgument`](EvalArg) with standard output. - /// Makes use of challenge [`StandardOutputIndeterminate`]. + /// The terminal for the [`EvaluationArgument`][eval] with standard output. + /// Makes use of challenge + /// [`StandardOutputIndeterminate`][Self::StandardOutputIndeterminate]. + /// + /// [eval]: crate::cross_table_argument::EvalArg StandardOutputTerminal, - /// The terminal for the [`EvaluationArgument`](EvalArg) establishing correctness of the - /// [Lookup Table](crate::table::lookup_table::LookupTable). - /// Makes use of challenge [`LookupTablePublicIndeterminate`]. + /// The terminal for the [`EvaluationArgument`][eval] establishing correctness of the + /// [Lookup Table](crate::table::lookup::LookupTable). + /// Makes use of challenge + /// [`LookupTablePublicIndeterminate`][Self::LookupTablePublicIndeterminate]. + /// + /// [eval]: crate::cross_table_argument::EvalArg LookupTablePublicTerminal, /// The digest of the program to be executed, compressed into a single extension field element. - /// The compression happens using an [`EvaluationArgument`](EvalArg) under challenge - /// [`CompressProgramDigestIndeterminate`]. + /// The compression happens using an [`EvaluationArgument`][eval] under challenge + /// [`CompressProgramDigestIndeterminate`][Self::CompressProgramDigestIndeterminate]. /// Relates to program attestation. + /// + /// [eval]: crate::cross_table_argument::EvalArg CompressedProgramDigest, } diff --git a/triton-air/src/lib.rs b/triton-air/src/lib.rs index 1f11a29fe..8046f54c2 100644 --- a/triton-air/src/lib.rs +++ b/triton-air/src/lib.rs @@ -16,7 +16,7 @@ pub mod table_column; /// /// Using substitution and the introduction of new variables, the degree of the AIR as specified /// in the respective tables -/// (e.g., in [`processor_table::ExtProcessorTable::transition_constraints`]) +/// (e.g., in [`table::processor::ProcessorTable::transition_constraints`]) /// is lowered to this value. /// For example, with a target degree of 2 and a (fictional) constraint of the form /// `a = b²·c²·d`, @@ -25,14 +25,7 @@ pub mod table_column; /// - introduce new constraints `e = b²`, `f = c²`, and `g = e·f`, /// - replace the original constraint with `a = g·d`. /// -/// The degree lowering happens in the constraint evaluation generator. -/// It can be executed by running `cargo run --bin constraint-evaluation-generator`. -/// Executing the constraint evaluator is a prerequisite for running both the Stark prover -/// and the Stark verifier. -/// -/// The new variables introduced by the degree lowering step are called “derived columns.” -/// They are added to the [`DegreeLoweringTable`], whose sole purpose is to store the values -/// of these derived columns. +/// The degree lowering happens in the Triton VM's build script, `build.rs`. pub const TARGET_DEGREE: isize = 4; pub trait AIR { diff --git a/triton-air/src/table/hash.rs b/triton-air/src/table/hash.rs index 89fac3fc3..4f463a82e 100644 --- a/triton-air/src/table/hash.rs +++ b/triton-air/src/table/hash.rs @@ -26,12 +26,11 @@ use crate::table_column::MasterBaseTableColumn; use crate::table_column::MasterExtTableColumn; use crate::AIR; -/// See [`HashTable::base_field_element_into_16_bit_limbs`] for more details. pub const MONTGOMERY_MODULUS: BFieldElement = BFieldElement::new(((1_u128 << 64) % BFieldElement::P as u128) as u64); -pub const POWER_MAP_EXPONENT: u64 = 7; -pub const NUM_ROUND_CONSTANTS: usize = tip5::STATE_SIZE; +const POWER_MAP_EXPONENT: u64 = 7; +const NUM_ROUND_CONSTANTS: usize = tip5::STATE_SIZE; pub const PERMUTATION_TRACE_LENGTH: usize = NUM_ROUNDS + 1; @@ -159,9 +158,11 @@ impl HashTable { .fold(constant(1), |accumulator, factor| accumulator * factor) } - /// The [`HashBaseTableColumn`] for the round constant corresponding to the given index. + /// The [main column][col] for the round constant corresponding to the given index. /// Valid indices are 0 through 15, corresponding to the 16 round constants - /// [`Constant0`] through [`Constant15`]. + /// `Constant0` through `Constant15`. + /// + /// [col]: crate::table_column::HashBaseTableColumn pub fn round_constant_column_by_index(index: usize) -> ::MainColumn { match index { 0 => ::MainColumn::Constant0, @@ -1329,8 +1330,8 @@ impl AIR for HashTable { /// 1. Processing the `hash` instruction. /// 1. Padding mode. /// -/// Changing the mode is only possible when the current [`RoundNumber`] is [`NUM_ROUNDS`]. -/// The mode evolves as +/// Changing the mode is only possible when the current [`RoundNumber`][round_no] +/// is [`NUM_ROUNDS`]. The mode evolves as /// [`ProgramHashing`][prog_hash] → [`Sponge`][sponge] → [`Hash`][hash] → [`Pad`][pad]. /// Once mode [`Pad`][pad] is reached, it is not possible to change the mode anymore. /// Skipping any or all of the modes [`Sponge`][sponge], [`Hash`][hash], or [`Pad`][pad] @@ -1344,6 +1345,7 @@ impl AIR for HashTable { /// The empty program is not valid since any valid [`Program`][program] must execute /// instruction `halt`. /// +/// [round_no]: crate::table_column::HashBaseTableColumn::RoundNumber /// [program]: isa::program::Program /// [prog_hash]: HashTableMode::ProgramHashing /// [sponge]: HashTableMode::Sponge diff --git a/triton-air/src/table_column.rs b/triton-air/src/table_column.rs index 7031da9c8..400934cb0 100644 --- a/triton-air/src/table_column.rs +++ b/triton-air/src/table_column.rs @@ -188,9 +188,9 @@ pub enum RamBaseTableColumn { /// Is [`INSTRUCTION_TYPE_READ`] for instruction `read_mem` and [`INSTRUCTION_TYPE_WRITE`] /// for instruction `write_mem`. For padding rows, this is set to [`PADDING_INDICATOR`]. /// - /// [`INSTRUCTION_TYPE_READ`]: crate::table::ram_table::INSTRUCTION_TYPE_READ - /// [`INSTRUCTION_TYPE_WRITE`]: crate::table::ram_table::INSTRUCTION_TYPE_WRITE - /// [`PADDING_INDICATOR`]: crate::table::ram_table::PADDING_INDICATOR + /// [`INSTRUCTION_TYPE_READ`]: crate::table::ram::INSTRUCTION_TYPE_READ + /// [`INSTRUCTION_TYPE_WRITE`]: crate::table::ram::INSTRUCTION_TYPE_WRITE + /// [`PADDING_INDICATOR`]: crate::table::ram::PADDING_INDICATOR InstructionType, RamPointer, RamValue, @@ -236,14 +236,14 @@ pub enum JumpStackExtTableColumn { pub enum HashBaseTableColumn { /// The indicator for the [`HashTableMode`][mode]. /// - /// [mode]: crate::table::hash_table::HashTableMode + /// [mode]: crate::table::hash::HashTableMode Mode, /// The current instruction. Only relevant for [`Mode`][mode] [`Sponge`][mode_sponge] /// in order to distinguish between the different Sponge instructions. /// /// [mode]: HashBaseTableColumn::Mode - /// [mode_sponge]: crate::table::hash_table::HashTableMode::Sponge + /// [mode_sponge]: crate::table::hash::HashTableMode::Sponge CI, /// The number of the current round in the permutation. The round number evolves as @@ -255,10 +255,10 @@ pub enum HashBaseTableColumn { /// /// [ci]: HashBaseTableColumn::CI /// [mode]: HashBaseTableColumn::Mode - /// [mode_prog_hash]: crate::table::hash_table::HashTableMode::ProgramHashing - /// [mode_sponge]: crate::table::hash_table::HashTableMode::Sponge - /// [mode_hash]: crate::table::hash_table::HashTableMode::Hash - /// [mode_pad]: crate::table::hash_table::HashTableMode::Pad + /// [mode_prog_hash]: crate::table::hash::HashTableMode::ProgramHashing + /// [mode_sponge]: crate::table::hash::HashTableMode::Sponge + /// [mode_hash]: crate::table::hash::HashTableMode::Hash + /// [mode_pad]: crate::table::hash::HashTableMode::Pad RoundNumber, State0HighestLkIn, diff --git a/triton-constraint-builder/src/codegen.rs b/triton-constraint-builder/src/codegen.rs index ff8d9edae..15b9dd00c 100644 --- a/triton-constraint-builder/src/codegen.rs +++ b/triton-constraint-builder/src/codegen.rs @@ -417,7 +417,7 @@ impl Codegen for TasmBackend { /// Emits a function that emits [Triton assembly][tasm] that evaluates Triton VM's AIR /// constraints over the [extension field][XFieldElement]. /// - /// [tasm]: triton_vm::prelude::triton_asm + /// [tasm]: isa::triton_asm fn constraint_evaluation_code(constraints: &Constraints) -> TokenStream { let doc_comment = Self::doc_comment_static_version(); diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index 38a2997cf..9bba2f14b 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -121,7 +121,7 @@ impl AlgebraicExecutionTrace { /// /// Guaranteed to be a power of two. /// - /// [pad]: master_table::MasterBaseTable::pad + /// [pad]: table::master_table::MasterBaseTable::pad pub fn padded_height(&self) -> usize { self.height().height.next_power_of_two() } @@ -129,7 +129,7 @@ impl AlgebraicExecutionTrace { /// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad]. /// Corresponds to the height of the longest table. /// - /// [pad]: master_table::MasterBaseTable::pad + /// [pad]: table::master_table::MasterBaseTable::pad pub fn height(&self) -> TableHeight { TableId::iter() .map(|t| TableHeight::new(t, self.height_of_table(t))) diff --git a/triton-vm/src/challenges.rs b/triton-vm/src/challenges.rs index ca2e99c5f..1fd156181 100644 --- a/triton-vm/src/challenges.rs +++ b/triton-vm/src/challenges.rs @@ -1,7 +1,7 @@ //! Challenges are needed for the [cross-table arguments](CrossTableArg), _i.e._, -//! [Permutation Arguments](crate::cross_table_argument::PermArg), -//! [Evaluation Arguments](crate::cross_table_argument::EvalArg), and -//! [Lookup Arguments](crate::cross_table_argument::LookupArg), +//! [Permutation Arguments](air::cross_table_argument::PermArg), +//! [Evaluation Arguments](EvalArg), and +//! [Lookup Arguments](air::cross_table_argument::LookupArg), //! as well as for the RAM Table's Contiguity Argument. //! //! There are three types of challenges: @@ -50,14 +50,18 @@ impl Challenges { /// from publicly known values and other, sampled challenges. /// /// Concretely: - /// - The [`StandardInputTerminal`] is computed from Triton VM's public input and the sampled - /// indeterminate [`StandardInputIndeterminate`]. - /// - The [`StandardOutputTerminal`] is computed from Triton VM's public output and the sampled - /// indeterminate [`StandardOutputIndeterminate`]. - /// - The [`LookupTablePublicTerminal`] is computed from the publicly known and constant - /// lookup table and the sampled indeterminate [`LookupTablePublicIndeterminate`]. - /// - The [`CompressedProgramDigest`] is computed from the program to be executed and the - /// sampled indeterminate [`CompressProgramDigestIndeterminate`]. + /// - The [`StandardInputTerminal`][ChallengeId::StandardInputTerminal] is computed + /// from Triton VM's public input and the sampled indeterminate + /// [`StandardInputIndeterminate`][ChallengeId::StandardInputIndeterminate]. + /// - The [`StandardOutputTerminal`][ChallengeId::StandardOutputTerminal] is computed + /// from Triton VM's public output and the sampled indeterminate + /// [`StandardOutputIndeterminate`][ChallengeId::StandardOutputIndeterminate]. + /// - The [`LookupTablePublicTerminal`][ChallengeId::LookupTablePublicTerminal] is + /// computed from the publicly known and constant lookup table and the sampled indeterminate + /// [`LookupTablePublicIndeterminate`][ChallengeId::LookupTablePublicIndeterminate]. + /// - The [`CompressedProgramDigest`][ChallengeId::CompressedProgramDigest] is computed + /// from the program to be executed and the sampled indeterminate + /// [`CompressProgramDigestIndeterminate`][ChallengeId::CompressProgramDigestIndeterminate]. pub const SAMPLE_COUNT: usize = Self::COUNT - ChallengeId::NUM_DERIVED_CHALLENGES; pub fn new(mut challenges: Vec, claim: &Claim) -> Self { diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 496ea79af..19bc91723 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -563,8 +563,8 @@ impl Stark { /// length of the execution trace and the FRI expansion factor, a security parameter. /// /// In principle, the FRI domain is also influenced by the AIR's degree - /// (see [`TARGET_DEGREE`]). However, by segmenting the quotient polynomial into - /// [`TARGET_DEGREE`]-many parts, that influence is mitigated. + /// (see [`air::TARGET_DEGREE`]). However, by segmenting the quotient polynomial into + /// `TARGET_DEGREE`-many parts, that influence is mitigated. pub fn derive_fri(&self, padded_height: usize) -> fri::SetupResult { let interpolant_degree = interpolant_degree(padded_height, self.num_trace_randomizers); let interpolant_codeword_length = interpolant_degree as usize + 1; diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index bc1ace5f6..146d87e0a 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -78,18 +78,14 @@ pub enum ConstraintType { Terminal, } -/// A single row of a [`MasterBaseTable`][table]. +/// A single row of a [`MasterBaseTable`]. /// /// Usually, the elements in the table are [`BFieldElement`]s. For out-of-domain rows, which is /// relevant for “Domain Extension to Eliminate Pretenders” (DEEP), the elements are /// [`XFieldElement`]s. -/// -/// [table]: master_table::MasterBaseTable pub type BaseRow = [T; MasterBaseTable::NUM_COLUMNS]; -/// A single row of a [`MasterExtensionTable`][table]. -/// -/// [table]: master_table::MasterExtTable +/// A single row of a [`MasterExtTable`]. pub type ExtensionRow = [XFieldElement; MasterExtTable::NUM_COLUMNS]; /// An element of the split-up quotient polynomial. diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 2d3fd548b..c1ac9e2ef 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -122,7 +122,7 @@ use crate::table::*; /// [`MasterExtensionTable`][master_ext_table] but does induce a nonzero number of constraints /// and thus terms in the [quotient combination][all_quotients_combined]. /// -/// [cross_arg]: cross_table_argument::GrandCrossTableArg +/// [cross_arg]: air::cross_table_argument::GrandCrossTableArg /// [overwrite_cache]: crate::config::overwrite_lde_trace_caching_to /// [lde]: Self::low_degree_extend_all_columns /// [quot_table]: Self::quotient_domain_table @@ -1410,15 +1410,36 @@ mod tests { assert!(DEGREE_LOWERING_TARGETS.contains(&Some(air::TARGET_DEGREE))); let mut all_table_info = [ - ("program-table.md", table_widths::()), - ("processor-table.md", table_widths::()), - ("operational-stack-table.md", table_widths::()), - ("random-access-memory-table.md", table_widths::()), - ("jump-stack-table.md", table_widths::()), - ("hash-table.md", table_widths::()), - ("cascade-table.md", table_widths::()), - ("lookup-table.md", table_widths::()), - ("u32-table.md", table_widths::()), + ( + "[ProgramTable](program-table.md)", + table_widths::(), + ), + ( + "[ProcessorTable](processor-table.md)", + table_widths::(), + ), + ( + "[OpStackTable](operational-stack-table.md)", + table_widths::(), + ), + ( + "[RamTable](random-access-memory-table.md)", + table_widths::(), + ), + ( + "[JumpStackTable](jump-stack-table.md)", + table_widths::(), + ), + ("[HashTable](hash-table.md)", table_widths::()), + ( + "[CascadeTable](cascade-table.md)", + table_widths::(), + ), + ( + "[LookupTable](lookup-table.md)", + table_widths::(), + ), + ("[U32Table](u32-table.md)", table_widths::()), ] .map(|(description, (main_width, aux_width))| { ( From f1b494cb54e2a0b0d14b12733a04dfe8b9ca98e6 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 9 Sep 2024 11:44:44 +0200 Subject: [PATCH 11/15] style: introduce local type alias changelog: ignore --- triton-air/src/table/hash.rs | 275 +++++++++++++++-------------------- 1 file changed, 121 insertions(+), 154 deletions(-) diff --git a/triton-air/src/table/hash.rs b/triton-air/src/table/hash.rs index 4f463a82e..06566c988 100644 --- a/triton-air/src/table/hash.rs +++ b/triton-air/src/table/hash.rs @@ -39,6 +39,9 @@ pub type PermutationTrace = [[BFieldElement; tip5::STATE_SIZE]; PERMUTATION_TRAC #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct HashTable; +type MainColumn = ::MainColumn; +type AuxColumn = ::AuxColumn; + impl HashTable { /// Get the MDS matrix's entry in row `row_idx` and column `col_idx`. const fn mds_matrix_entry(row_idx: usize, col_idx: usize) -> BFieldElement { @@ -163,24 +166,24 @@ impl HashTable { /// `Constant0` through `Constant15`. /// /// [col]: crate::table_column::HashBaseTableColumn - pub fn round_constant_column_by_index(index: usize) -> ::MainColumn { + pub fn round_constant_column_by_index(index: usize) -> MainColumn { match index { - 0 => ::MainColumn::Constant0, - 1 => ::MainColumn::Constant1, - 2 => ::MainColumn::Constant2, - 3 => ::MainColumn::Constant3, - 4 => ::MainColumn::Constant4, - 5 => ::MainColumn::Constant5, - 6 => ::MainColumn::Constant6, - 7 => ::MainColumn::Constant7, - 8 => ::MainColumn::Constant8, - 9 => ::MainColumn::Constant9, - 10 => ::MainColumn::Constant10, - 11 => ::MainColumn::Constant11, - 12 => ::MainColumn::Constant12, - 13 => ::MainColumn::Constant13, - 14 => ::MainColumn::Constant14, - 15 => ::MainColumn::Constant15, + 0 => MainColumn::Constant0, + 1 => MainColumn::Constant1, + 2 => MainColumn::Constant2, + 3 => MainColumn::Constant3, + 4 => MainColumn::Constant4, + 5 => MainColumn::Constant5, + 6 => MainColumn::Constant6, + 7 => MainColumn::Constant7, + 8 => MainColumn::Constant8, + 9 => MainColumn::Constant9, + 10 => MainColumn::Constant10, + 11 => MainColumn::Constant11, + 12 => MainColumn::Constant12, + 13 => MainColumn::Constant13, + 14 => MainColumn::Constant14, + 15 => MainColumn::Constant15, _ => panic!("invalid constant column index"), } } @@ -192,104 +195,68 @@ impl HashTable { /// States with indices 0 through 3 have to be assembled from the respective limbs; /// see [`Self::re_compose_states_0_through_3_before_lookup`] /// or [`Self::re_compose_16_bit_limbs`]. - fn state_column_by_index(index: usize) -> ::MainColumn { + fn state_column_by_index(index: usize) -> MainColumn { match index { - 4 => ::MainColumn::State4, - 5 => ::MainColumn::State5, - 6 => ::MainColumn::State6, - 7 => ::MainColumn::State7, - 8 => ::MainColumn::State8, - 9 => ::MainColumn::State9, - 10 => ::MainColumn::State10, - 11 => ::MainColumn::State11, - 12 => ::MainColumn::State12, - 13 => ::MainColumn::State13, - 14 => ::MainColumn::State14, - 15 => ::MainColumn::State15, + 4 => MainColumn::State4, + 5 => MainColumn::State5, + 6 => MainColumn::State6, + 7 => MainColumn::State7, + 8 => MainColumn::State8, + 9 => MainColumn::State9, + 10 => MainColumn::State10, + 11 => MainColumn::State11, + 12 => MainColumn::State12, + 13 => MainColumn::State13, + 14 => MainColumn::State14, + 15 => MainColumn::State15, _ => panic!("invalid state column index"), } } - fn indicate_column_index_in_base_row(column: ::MainColumn) -> SingleRowIndicator { + fn indicate_column_index_in_base_row(column: MainColumn) -> SingleRowIndicator { BaseRow(column.master_base_table_index()) } - fn indicate_column_index_in_current_base_row( - column: ::MainColumn, - ) -> DualRowIndicator { + fn indicate_column_index_in_current_base_row(column: MainColumn) -> DualRowIndicator { CurrentBaseRow(column.master_base_table_index()) } - fn indicate_column_index_in_next_base_row( - column: ::MainColumn, - ) -> DualRowIndicator { + fn indicate_column_index_in_next_base_row(column: MainColumn) -> DualRowIndicator { NextBaseRow(column.master_base_table_index()) } fn re_compose_states_0_through_3_before_lookup( circuit_builder: &ConstraintCircuitBuilder, - main_row_to_input_indicator: fn(::MainColumn) -> II, + main_row_to_input_indicator: fn(MainColumn) -> II, ) -> [ConstraintCircuitMonad; 4] { let input = |input_indicator: II| circuit_builder.input(input_indicator); let state_0 = Self::re_compose_16_bit_limbs( circuit_builder, - input(main_row_to_input_indicator( - ::MainColumn::State0HighestLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State0MidHighLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State0MidLowLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State0LowestLkIn, - )), + input(main_row_to_input_indicator(MainColumn::State0HighestLkIn)), + input(main_row_to_input_indicator(MainColumn::State0MidHighLkIn)), + input(main_row_to_input_indicator(MainColumn::State0MidLowLkIn)), + input(main_row_to_input_indicator(MainColumn::State0LowestLkIn)), ); let state_1 = Self::re_compose_16_bit_limbs( circuit_builder, - input(main_row_to_input_indicator( - ::MainColumn::State1HighestLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State1MidHighLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State1MidLowLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State1LowestLkIn, - )), + input(main_row_to_input_indicator(MainColumn::State1HighestLkIn)), + input(main_row_to_input_indicator(MainColumn::State1MidHighLkIn)), + input(main_row_to_input_indicator(MainColumn::State1MidLowLkIn)), + input(main_row_to_input_indicator(MainColumn::State1LowestLkIn)), ); let state_2 = Self::re_compose_16_bit_limbs( circuit_builder, - input(main_row_to_input_indicator( - ::MainColumn::State2HighestLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State2MidHighLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State2MidLowLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State2LowestLkIn, - )), + input(main_row_to_input_indicator(MainColumn::State2HighestLkIn)), + input(main_row_to_input_indicator(MainColumn::State2MidHighLkIn)), + input(main_row_to_input_indicator(MainColumn::State2MidLowLkIn)), + input(main_row_to_input_indicator(MainColumn::State2LowestLkIn)), ); let state_3 = Self::re_compose_16_bit_limbs( circuit_builder, - input(main_row_to_input_indicator( - ::MainColumn::State3HighestLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State3MidHighLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State3MidLowLkIn, - )), - input(main_row_to_input_indicator( - ::MainColumn::State3LowestLkIn, - )), + input(main_row_to_input_indicator(MainColumn::State3HighestLkIn)), + input(main_row_to_input_indicator(MainColumn::State3MidHighLkIn)), + input(main_row_to_input_indicator(MainColumn::State3MidLowLkIn)), + input(main_row_to_input_indicator(MainColumn::State3LowestLkIn)), ); [state_0, state_1, state_2, state_3] } @@ -302,55 +269,55 @@ impl HashTable { ) { let constant = |c: u64| circuit_builder.b_constant(c); let b_constant = |c| circuit_builder.b_constant(c); - let current_main_row = |column_idx: ::MainColumn| { + let current_main_row = |column_idx: MainColumn| { circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) }; - let next_main_row = |column_idx: ::MainColumn| { + let next_main_row = |column_idx: MainColumn| { circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) }; let state_0_after_lookup = Self::re_compose_16_bit_limbs( circuit_builder, - current_main_row(::MainColumn::State0HighestLkOut), - current_main_row(::MainColumn::State0MidHighLkOut), - current_main_row(::MainColumn::State0MidLowLkOut), - current_main_row(::MainColumn::State0LowestLkOut), + current_main_row(MainColumn::State0HighestLkOut), + current_main_row(MainColumn::State0MidHighLkOut), + current_main_row(MainColumn::State0MidLowLkOut), + current_main_row(MainColumn::State0LowestLkOut), ); let state_1_after_lookup = Self::re_compose_16_bit_limbs( circuit_builder, - current_main_row(::MainColumn::State1HighestLkOut), - current_main_row(::MainColumn::State1MidHighLkOut), - current_main_row(::MainColumn::State1MidLowLkOut), - current_main_row(::MainColumn::State1LowestLkOut), + current_main_row(MainColumn::State1HighestLkOut), + current_main_row(MainColumn::State1MidHighLkOut), + current_main_row(MainColumn::State1MidLowLkOut), + current_main_row(MainColumn::State1LowestLkOut), ); let state_2_after_lookup = Self::re_compose_16_bit_limbs( circuit_builder, - current_main_row(::MainColumn::State2HighestLkOut), - current_main_row(::MainColumn::State2MidHighLkOut), - current_main_row(::MainColumn::State2MidLowLkOut), - current_main_row(::MainColumn::State2LowestLkOut), + current_main_row(MainColumn::State2HighestLkOut), + current_main_row(MainColumn::State2MidHighLkOut), + current_main_row(MainColumn::State2MidLowLkOut), + current_main_row(MainColumn::State2LowestLkOut), ); let state_3_after_lookup = Self::re_compose_16_bit_limbs( circuit_builder, - current_main_row(::MainColumn::State3HighestLkOut), - current_main_row(::MainColumn::State3MidHighLkOut), - current_main_row(::MainColumn::State3MidLowLkOut), - current_main_row(::MainColumn::State3LowestLkOut), + current_main_row(MainColumn::State3HighestLkOut), + current_main_row(MainColumn::State3MidHighLkOut), + current_main_row(MainColumn::State3MidLowLkOut), + current_main_row(MainColumn::State3LowestLkOut), ); let state_part_before_power_map: [_; tip5::STATE_SIZE - tip5::NUM_SPLIT_AND_LOOKUP] = [ - ::MainColumn::State4, - ::MainColumn::State5, - ::MainColumn::State6, - ::MainColumn::State7, - ::MainColumn::State8, - ::MainColumn::State9, - ::MainColumn::State10, - ::MainColumn::State11, - ::MainColumn::State12, - ::MainColumn::State13, - ::MainColumn::State14, - ::MainColumn::State15, + MainColumn::State4, + MainColumn::State5, + MainColumn::State6, + MainColumn::State7, + MainColumn::State8, + MainColumn::State9, + MainColumn::State10, + MainColumn::State11, + MainColumn::State12, + MainColumn::State13, + MainColumn::State14, + MainColumn::State15, ] .map(current_main_row); @@ -392,22 +359,22 @@ impl HashTable { } let round_constants: [_; tip5::STATE_SIZE] = [ - ::MainColumn::Constant0, - ::MainColumn::Constant1, - ::MainColumn::Constant2, - ::MainColumn::Constant3, - ::MainColumn::Constant4, - ::MainColumn::Constant5, - ::MainColumn::Constant6, - ::MainColumn::Constant7, - ::MainColumn::Constant8, - ::MainColumn::Constant9, - ::MainColumn::Constant10, - ::MainColumn::Constant11, - ::MainColumn::Constant12, - ::MainColumn::Constant13, - ::MainColumn::Constant14, - ::MainColumn::Constant15, + MainColumn::Constant0, + MainColumn::Constant1, + MainColumn::Constant2, + MainColumn::Constant3, + MainColumn::Constant4, + MainColumn::Constant5, + MainColumn::Constant6, + MainColumn::Constant7, + MainColumn::Constant8, + MainColumn::Constant9, + MainColumn::Constant10, + MainColumn::Constant11, + MainColumn::Constant12, + MainColumn::Constant13, + MainColumn::Constant14, + MainColumn::Constant15, ] .map(current_main_row); @@ -427,21 +394,21 @@ impl HashTable { state_1_next, state_2_next, state_3_next, - next_main_row(::MainColumn::State4), - next_main_row(::MainColumn::State5), - next_main_row(::MainColumn::State6), - next_main_row(::MainColumn::State7), - next_main_row(::MainColumn::State8), - next_main_row(::MainColumn::State9), - next_main_row(::MainColumn::State10), - next_main_row(::MainColumn::State11), - next_main_row(::MainColumn::State12), - next_main_row(::MainColumn::State13), - next_main_row(::MainColumn::State14), - next_main_row(::MainColumn::State15), + next_main_row(MainColumn::State4), + next_main_row(MainColumn::State5), + next_main_row(MainColumn::State6), + next_main_row(MainColumn::State7), + next_main_row(MainColumn::State8), + next_main_row(MainColumn::State9), + next_main_row(MainColumn::State10), + next_main_row(MainColumn::State11), + next_main_row(MainColumn::State12), + next_main_row(MainColumn::State13), + next_main_row(MainColumn::State14), + next_main_row(MainColumn::State15), ]; - let round_number_next = next_main_row(::MainColumn::RoundNumber); + let round_number_next = next_main_row(MainColumn::RoundNumber); let hash_function_round_correctly_performs_update = state_after_round_constant_addition .into_iter() .zip_eq(state_next.clone()) @@ -457,20 +424,20 @@ impl HashTable { fn cascade_log_derivative_update_circuit( circuit_builder: &ConstraintCircuitBuilder, - look_in_column: ::MainColumn, - look_out_column: ::MainColumn, - cascade_log_derivative_column: ::AuxColumn, + look_in_column: MainColumn, + look_out_column: MainColumn, + cascade_log_derivative_column: AuxColumn, ) -> ConstraintCircuitMonad { let challenge = |c| circuit_builder.challenge(c); let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); let constant = |c: u32| circuit_builder.b_constant(c); - let next_main_row = |column_idx: ::MainColumn| { + let next_main_row = |column_idx: MainColumn| { circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) }; - let current_aux_row = |column_idx: ::AuxColumn| { + let current_aux_row = |column_idx: AuxColumn| { circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) }; - let next_aux_row = |column_idx: ::AuxColumn| { + let next_aux_row = |column_idx: AuxColumn| { circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) }; @@ -478,9 +445,9 @@ impl HashTable { let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight); let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight); - let ci_next = next_main_row(::MainColumn::CI); - let mode_next = next_main_row(::MainColumn::Mode); - let round_number_next = next_main_row(::MainColumn::RoundNumber); + let ci_next = next_main_row(MainColumn::CI); + let mode_next = next_main_row(MainColumn::Mode); + let round_number_next = next_main_row(MainColumn::RoundNumber); let cascade_log_derivative = current_aux_row(cascade_log_derivative_column); let cascade_log_derivative_next = next_aux_row(cascade_log_derivative_column); From d88a76131bb70ba2ccace46197bb1b7dd884852b Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 9 Sep 2024 13:09:09 +0200 Subject: [PATCH 12/15] build: Remove Makefile The Makefile simplified building Triton VM despite the cyclic dependency in its build process. Now, the cyclic build dependency is broken. Cargo is the recommended build tool for Triton VM. --- .gitignore | 3 --- Makefile | 72 ------------------------------------------------------ 2 files changed, 75 deletions(-) delete mode 100644 Makefile diff --git a/.gitignore b/.gitignore index 3e914a14b..6b7b618cf 100644 --- a/.gitignore +++ b/.gitignore @@ -29,9 +29,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb -# Added by Makefile configuration -/makefile-target - ### VisualStudioCode ### .vscode/* !.vscode/settings.json diff --git a/Makefile b/Makefile deleted file mode 100644 index 02f21d50f..000000000 --- a/Makefile +++ /dev/null @@ -1,72 +0,0 @@ -# Treat `cargo clippy` warnings as errors. -CLIPPY_ARGS = --all-targets -- -D warnings - -# Fail if `cargo fmt` changes anything. -FMT_ARGS = --all -- --check - -# Treat all warnings as errors -export RUSTFLAGS = -Dwarnings - -# Set another target dir than default to avoid builds from `make` -# to invalidate cache from barebones use of `cargo` commands. -# The cache is cleared when a new `RUSTFLAGS` value is encountered, -# so to prevent the two builds from interfering, we use two dirs. -export CARGO_TARGET_DIR=./makefile-target - -# By first building tests, and consequently building constraints, before -# running `fmt` and `clippy`, the auto-generated constraints are exposed -# to `fmt` and `clippy`. -default: build-constraints build-tests build-bench fmt-only clippy-only clean-constraints - -# Run `make all` when the constraints are already in place. -all: test build-bench fmt-only clippy-only - -# Alternative to `cargo build --all-targets` -build: build-constraints - cargo build --all-targets - -# Alternative to `cargo test --all-targets` -test: - cargo test --all-targets - -# Alternative to `cargo bench --all-targets` -bench: build-constraints - cargo bench --all-targets - -# Alternative to `cargo clippy ...` -clippy: build-constraints - cargo clippy $(CLIPPY_ARGS) - -# Alternative to `cargo fmt ...` -fmt-check: build-constraints - cargo fmt $(FMT_ARGS) - -# Alternative to `cargo clean` -clean: - cargo clean - make clean-constraints - -# Auxiliary targets -# -# Assume constraints are compiled. - -build-tests: - cargo test --all-targets --no-run - -build-bench: - cargo bench --all-targets --no-run - -build-constraints: - cargo run --bin constraint-evaluation-generator - -clean-constraints: - git restore --staged triton-vm/src/table/constraints.rs - git restore --staged triton-vm/src/table/degree_lowering_table.rs - git restore triton-vm/src/table/constraints.rs - git restore triton-vm/src/table/degree_lowering_table.rs - -fmt-only: - cargo fmt $(FMT_ARGS) - -clippy-only: - cargo clippy $(CLIPPY_ARGS) From 278244a725a5fa2ad2bd0dfb00d86f40e5cbf15c Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 9 Sep 2024 13:36:10 +0200 Subject: [PATCH 13/15] refactor: Add commonly used types to prelude changelog: ignore --- triton-vm/src/{air.rs => constraints.rs} | 5 +++-- triton-vm/src/error.rs | 5 +++++ triton-vm/src/lib.rs | 8 ++++---- triton-vm/src/memory_layout.rs | 4 ++-- triton-vm/src/prelude.rs | 13 +++++++++++-- triton-vm/src/table/master_table.rs | 4 ++-- 6 files changed, 27 insertions(+), 12 deletions(-) rename triton-vm/src/{air.rs => constraints.rs} (98%) diff --git a/triton-vm/src/air.rs b/triton-vm/src/constraints.rs similarity index 98% rename from triton-vm/src/air.rs rename to triton-vm/src/constraints.rs index 91371b0d6..e1acdd888 100644 --- a/triton-vm/src/air.rs +++ b/triton-vm/src/constraints.rs @@ -13,8 +13,6 @@ mod test { use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; use twenty_first::prelude::*; - use crate::air::dynamic_air_constraint_evaluation_tasm; - use crate::air::static_air_constraint_evaluation_tasm; use crate::challenges::Challenges; use crate::memory_layout; use crate::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; @@ -26,6 +24,9 @@ mod test { use crate::table::NUM_AUX_COLUMNS; use crate::table::NUM_MAIN_COLUMNS; + use super::dynamic_air_constraint_evaluation_tasm; + use super::static_air_constraint_evaluation_tasm; + #[derive(Debug, Clone, test_strategy::Arbitrary)] struct ConstraintEvaluationPoint { #[strategy(vec(arb(), NUM_MAIN_COLUMNS))] diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index 50ce9306c..5f377ac80 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -1,4 +1,9 @@ pub use isa::error::InstructionError; +pub use isa::error::NumberOfWordsError; +pub use isa::error::OpStackElementError; +pub use isa::error::OpStackError; +pub use isa::error::ParseError; +pub use isa::error::ProgramDecodingError; use std::fmt; use std::fmt::Display; diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 5c4e07d1a..33c316d9e 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -141,7 +141,7 @@ //! `halt` as its last instruction. Certain instructions, such as `assert`, `invert`, or the u32 //! instructions, can also cause the VM to crash. Upon crashing Triton VM, methods like //! [`run`](VM::run) and [`trace_execution`](VM::trace_execution) will return a -//! [`VMError`][vm_error]. This can be helpful for debugging. +//! [`VMError`]. This can be helpful for debugging. //! //! ``` //! # use triton_vm::*; @@ -152,11 +152,11 @@ //! // inspect the VM state //! eprintln!("{vm_error}"); //! ``` -//! -//! [vm_error]: error::VMError #![recursion_limit = "4096"] +pub use air; +pub use isa; pub use twenty_first; use isa::program::Program; @@ -165,10 +165,10 @@ use crate::error::ProvingError; use crate::prelude::*; pub mod aet; -pub mod air; pub mod arithmetic_domain; pub mod challenges; pub mod config; +pub mod constraints; pub mod error; pub mod example_programs; pub mod execution_trace_profiler; diff --git a/triton-vm/src/memory_layout.rs b/triton-vm/src/memory_layout.rs index 57f9bd2dc..0e816b0d6 100644 --- a/triton-vm/src/memory_layout.rs +++ b/triton-vm/src/memory_layout.rs @@ -11,7 +11,7 @@ use twenty_first::prelude::*; /// Memory layout guarantees for the [Triton assembly AIR constraint evaluator][tasm_air] /// with input lists at dynamically known memory locations. /// -/// [tasm_air]: crate::air::dynamic_air_constraint_evaluation_tasm +/// [tasm_air]: crate::constraints::dynamic_air_constraint_evaluation_tasm #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct DynamicTasmConstraintEvaluationMemoryLayout { /// Pointer to a region of memory that is reserved for (a) pointers to {current, @@ -29,7 +29,7 @@ pub struct DynamicTasmConstraintEvaluationMemoryLayout { /// Memory layout guarantees for the [Triton assembly AIR constraint evaluator][tasm_air] /// with input lists at statically known memory locations. /// -/// [tasm_air]: crate::air::static_air_constraint_evaluation_tasm +/// [tasm_air]: crate::constraints::static_air_constraint_evaluation_tasm #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub struct StaticTasmConstraintEvaluationMemoryLayout { /// Pointer to a region of memory that is reserved for constraint evaluation. diff --git a/triton-vm/src/prelude.rs b/triton-vm/src/prelude.rs index 7f4a583cf..1773c7f60 100644 --- a/triton-vm/src/prelude.rs +++ b/triton-vm/src/prelude.rs @@ -19,14 +19,23 @@ pub use twenty_first::prelude::Digest; pub use twenty_first::prelude::Tip5; pub use twenty_first::prelude::XFieldElement; -pub use isa as triton_isa; -pub use isa::error::InstructionError; +pub use isa; pub use isa::instruction::LabelledInstruction; pub use isa::program::Program; pub use isa::triton_asm; pub use isa::triton_instr; pub use isa::triton_program; +pub use air::table::TableId; +pub use air::AIR; + +pub use crate::error::InstructionError; +pub use crate::error::NumberOfWordsError; +pub use crate::error::OpStackElementError; +pub use crate::error::OpStackError; +pub use crate::error::ParseError; +pub use crate::error::ProgramDecodingError; +pub use crate::error::VMError; pub use crate::proof::Claim; pub use crate::proof::Proof; pub use crate::stark::Stark; diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index c1ac9e2ef..820a7ecf6 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1206,9 +1206,9 @@ mod tests { use twenty_first::math::traits::FiniteField; use twenty_first::prelude::x_field_element::EXTENSION_DEGREE; - use crate::air::dynamic_air_constraint_evaluation_tasm; - use crate::air::static_air_constraint_evaluation_tasm; use crate::arithmetic_domain::ArithmeticDomain; + use crate::constraints::dynamic_air_constraint_evaluation_tasm; + use crate::constraints::static_air_constraint_evaluation_tasm; use crate::memory_layout::DynamicTasmConstraintEvaluationMemoryLayout; use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; use crate::shared_tests::ProgramAndInput; From 07fdf1599dc6cb6f95fed30cb44940e2c5fc29b8 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 10 Sep 2024 11:37:34 +0200 Subject: [PATCH 14/15] =?UTF-8?q?chore:=20Use=20=E2=80=9Cmain=E2=80=9D=20&?= =?UTF-8?q?=20=E2=80=9Caux=20over=20=E2=80=9Cbase=E2=80=9D=20&=20=E2=80=9C?= =?UTF-8?q?ext=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Over the years, the terminology of “main” and “auxiliary” tables has become dominant over our terminology of “base” and “extension” tables. changelog: ignore --- specification/src/arithmetization.md | 10 +- specification/src/cascade-table.md | 8 +- .../contiguity-of-memory-pointer-regions.md | 18 +- specification/src/evaluation-argument.md | 10 +- specification/src/hash-table.md | 8 +- specification/src/jump-stack-table.md | 6 +- specification/src/lookup-argument.md | 10 +- specification/src/lookup-table.md | 8 +- specification/src/operational-stack-table.md | 6 +- specification/src/permutation-argument.md | 12 +- specification/src/processor-table.md | 6 +- specification/src/program-table.md | 8 +- .../src/proof-of-memory-consistency.md | 2 +- .../src/random-access-memory-table.md | 12 +- specification/src/u32-table.md | 6 +- triton-air/src/challenge_id.rs | 6 +- triton-air/src/cross_table_argument.rs | 144 +- triton-air/src/lib.rs | 44 +- triton-air/src/table.rs | 52 +- triton-air/src/table/cascade.rs | 43 +- triton-air/src/table/hash.rs | 109 +- triton-air/src/table/jump_stack.rs | 85 +- triton-air/src/table/lookup.rs | 73 +- triton-air/src/table/op_stack.rs | 47 +- triton-air/src/table/processor.rs | 1671 ++++++++--------- triton-air/src/table/program.rs | 105 +- triton-air/src/table/ram.rs | 118 +- triton-air/src/table/u32.rs | 57 +- triton-air/src/table_column.rs | 302 +-- triton-constraint-builder/src/codegen.rs | 112 +- triton-constraint-builder/src/lib.rs | 35 +- .../src/substitutions.rs | 162 +- triton-constraint-circuit/src/lib.rs | 171 +- triton-vm/build.rs | 4 +- triton-vm/src/aet.rs | 18 +- triton-vm/src/constraints.rs | 78 +- triton-vm/src/error.rs | 8 +- triton-vm/src/lib.rs | 8 +- triton-vm/src/memory_layout.rs | 46 +- triton-vm/src/proof_item.rs | 20 +- triton-vm/src/proof_stream.rs | 36 +- triton-vm/src/shared_tests.rs | 8 +- triton-vm/src/stark.rs | 431 +++-- triton-vm/src/table.rs | 158 +- ...{extension_table.rs => auxiliary_table.rs} | 30 +- triton-vm/src/table/cascade.rs | 43 +- triton-vm/src/table/hash.rs | 236 ++- triton-vm/src/table/jump_stack.rs | 56 +- triton-vm/src/table/lookup.rs | 34 +- triton-vm/src/table/master_table.rs | 384 ++-- triton-vm/src/table/op_stack.rs | 72 +- triton-vm/src/table/processor.rs | 202 +- triton-vm/src/table/program.rs | 74 +- triton-vm/src/table/ram.rs | 87 +- triton-vm/src/table/u32.rs | 116 +- triton-vm/src/vm.rs | 90 +- 56 files changed, 2729 insertions(+), 2976 deletions(-) rename triton-vm/src/table/{extension_table.rs => auxiliary_table.rs} (85%) diff --git a/specification/src/arithmetization.md b/specification/src/arithmetization.md index 947f41346..5e905d7f7 100644 --- a/specification/src/arithmetization.md +++ b/specification/src/arithmetization.md @@ -28,17 +28,17 @@ See “[Arguments Using Public Information](arithmetization.md#arguments-using-p ![](img/aet-relations.png) -### Base Tables +### Main Tables The values of all registers, and consequently the elements on the stack, in memory, and so on, are elements of the _B-field_, _i.e._, $\mathbb{F}_p$ where $p$ is the Oxfoi prime, $2^{64}-2^{32}+1$. All values of columns corresponding to one such register are elements of the B-Field as well. -Together, these columns are referred to as table's _base_ columns, and make up the _base table_. +Together, these columns are referred to as table's _main_ columns, and make up the _main table_. -### Extension Tables +### Auxiliary Tables The entries of a table's columns corresponding to [Permutation](permutation-argument.md), [Evaluation](evaluation-argument.md), and [Lookup Arguments](lookup-argument.md) are elements from the _X-field_ $\mathbb{F}_{p^3}$. -These columns are referred to as a table's _extension_ columns, both because the entries are elements of the X-field and because the entries can only be computed using the base tables, through an _extension_ process. -Together, these columns are referred to as a table's _extension_ columns, and make up the _extension table_. +These columns are referred to as a table's _auxiliary_ columns, both because the entries are elements of the X-field and because the entries can only be computed using the main tables, through an _auxiliary_ process. +Together, these columns are referred to as a table's _auxiliary_ columns, and make up the _auxiliary table_. ### Padding diff --git a/specification/src/cascade-table.md b/specification/src/cascade-table.md index d2f8a4402..55d0828b3 100644 --- a/specification/src/cascade-table.md +++ b/specification/src/cascade-table.md @@ -6,9 +6,9 @@ The Cascade Table facilitates the translation of limb widths. For the actual lookup of the 8-bit limbs, the [Lookup Table](lookup-table.md) is used. For a more detailed explanation and in-depth analysis, see the [Tip5 paper](https://eprint.iacr.org/2023/107.pdf). -## Base Columns +## Main Columns -The Cascade Table has 6 base columns: +The Cascade Table has 6 main columns: | name | description | |:---------------------|:--------------------------------------------------------------| @@ -19,9 +19,9 @@ The Cascade Table has 6 base columns: | `LookOutLo` | The less significant bits of the lookup output. | | `LookupMultiplicity` | The number of times the value is looked up by the Hash Table. | -## Extension Columns +## Auxiliary Columns -The Cascade Table has 2 extension columns: +The Cascade Table has 2 auxiliary columns: - `HashTableServerLogDerivative`, the (running sum of the) logarithmic derivative for the Lookup Argument with the Hash Table. In every row, the sum accumulates `LookupMultiplicity / (🧺 - Combo)` where 🧺 is a verifier-supplied challenge diff --git a/specification/src/contiguity-of-memory-pointer-regions.md b/specification/src/contiguity-of-memory-pointer-regions.md index fd280b062..1cb4d689a 100644 --- a/specification/src/contiguity-of-memory-pointer-regions.md +++ b/specification/src/contiguity-of-memory-pointer-regions.md @@ -19,7 +19,7 @@ Analogously to the Op Stack Table, the Jump Stack's memory pointer `jsp` can onl ## Contiguity for RAM Table The *Contiguity Argument* for the RAM table establishes that all RAM pointer regions start with distinct values. -It is easy to ignore _consecutive_ duplicates in the list of all RAM pointers using one additional base column. +It is easy to ignore _consecutive_ duplicates in the list of all RAM pointers using one additional main column. This allows identification of the RAM pointer values at the regions' boundaries, $A$. The Contiguity Argument then shows that the list $A$ contains no duplicates. For this, it uses [Bézout's identity for univariate polynomials](https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor#B%C3%A9zout's_identity_and_extended_GCD_algorithm). @@ -36,11 +36,11 @@ This implies that all roots of $f_A(X)$ have multiplicity 1, which holds if and The following columns and constraints are needed for the Contiguity Argument: - - Base column `iord` and two deterministic transition constraints enable conditioning on a changed memory pointer. - - Base columns `bcpc0` and `bcpc1` and two deterministic transition constraints contain and constrain the symbolic Bézout coefficient polynomials' coefficients. - - Extension column `rpp` is a running product similar to that of a conditioned [permutation argument](permutation-argument.md). A randomized transition constraint verifies the correct accumulation of factors for updating this column. - - Extension column `fd` is the formal derivative of `rpp`. A randomized transition constraint verifies the correct application of the product rule of differentiation to update this column. - - Extension columns `bc0` and `bc1` build up the Bézout coefficient polynomials based on the corresponding base columns, `bcpc0` and `bcpc1`. + - Main column `iord` and two deterministic transition constraints enable conditioning on a changed memory pointer. + - Main columns `bcpc0` and `bcpc1` and two deterministic transition constraints contain and constrain the symbolic Bézout coefficient polynomials' coefficients. + - Auxiliary column `rpp` is a running product similar to that of a conditioned [permutation argument](permutation-argument.md). A randomized transition constraint verifies the correct accumulation of factors for updating this column. + - Auxiliary column `fd` is the formal derivative of `rpp`. A randomized transition constraint verifies the correct application of the product rule of differentiation to update this column. + - Auxiliary columns `bc0` and `bc1` build up the Bézout coefficient polynomials based on the corresponding main columns, `bcpc0` and `bcpc1`. Two randomized transition constraints enforce the correct build-up of the Bézout coefficient polynomials. - A terminal constraint takes the weighted sum of the running product and the formal derivative, where the weights are the Bézout coefficient polynomials, and equates it to one. This equation asserts the Bézout relation. @@ -57,7 +57,7 @@ Columns not needed for establishing memory consistency are not displayed. | $c$ | 0 | $k$ | $n$ | $(X - a)(X - b)(X - c)$ | $q(X)$ | $jX + k$ | $\ell X^2 + mX + n$ | | $c$ | - | $k$ | $n$ | $(X - a)(X - b)(X - c)$ | $q(X)$ | $jX + k$ | $\ell X^2 + mX + n$ | -The values contained in the extension columns are undetermined until the verifier's challenge $\alpha$ is known; before that happens it is worthwhile to present the polynomial expressions in $X$, anticipating the substitution $X \mapsto \alpha$. The constraints are articulated relative to `α`. +The values contained in the auxiliary columns are undetermined until the verifier's challenge $\alpha$ is known; before that happens it is worthwhile to present the polynomial expressions in $X$, anticipating the substitution $X \mapsto \alpha$. The constraints are articulated relative to `α`. The inverse of RAMP difference `iord` takes the inverse of the difference between the current and next `ramp` values if that difference is non-zero, and zero else. This constraint corresponds to two transition constraint polynomials: @@ -87,7 +87,7 @@ $$f_\mathsf{bc0}(X) \cdot f_{\mathsf{rp}}(X) + f_\mathsf{bc1}(X) \cdot f_{\maths The prover finds $f_\mathsf{bc0}(X)$ and $f_\mathsf{bc1}(X)$ as the minimal-degree Bézout coefficients as returned by the extended Euclidean algorithm. Concretely, the degree of $f_\mathsf{bc0}(X)$ is smaller than the degree of $f_\mathsf{fd}(X)$, and the degree of $f_\mathsf{bc1}(X)$ is smaller than the degree of $f_\mathsf{rp}(X)$. -The (scalar) coefficients of the Bézout coefficient polynomials are recorded in base columns `bcpc0` and `bcpc1`, respectively. +The (scalar) coefficients of the Bézout coefficient polynomials are recorded in main columns `bcpc0` and `bcpc1`, respectively. The transition constraints for these columns enforce that the value in one such column can only change if the memory pointer `ramp` changes. However, unlike the conditional update rule enforced by the transition constraints of `rp` and `fd`, the new value is unconstrained. Concretely, the two transition constraints are: @@ -96,7 +96,7 @@ Concretely, the two transition constraints are: - `(1 - (ramp' - ramp) ⋅ iord) ⋅ (bcpc1' - bcpc1)` Additionally, `bcpc0` must initially be zero, which is enforced by an initial constraint. -This upper-bounds the degrees of the Bézout coefficient polynomials, which are built from base columns `bcpc0` and `bcpc1`. +This upper-bounds the degrees of the Bézout coefficient polynomials, which are built from main columns `bcpc0` and `bcpc1`. Two transition constraints enforce the correct build-up of the Bézout coefficient polynomials: - `(1 - (ramp' - ramp) ⋅ iord) ⋅ (bc0' - bc0) + (ramp' - ramp) ⋅ (bc0' - α ⋅ bc0 - bcpc0')` diff --git a/specification/src/evaluation-argument.md b/specification/src/evaluation-argument.md index dcf3a76e8..70e02f4ae 100644 --- a/specification/src/evaluation-argument.md +++ b/specification/src/evaluation-argument.md @@ -16,13 +16,13 @@ By the [Schwartz–Zippel lemma](https://en.wikipedia.org/wiki/Schwartz%E2%80%93 In Triton VM, the Evaluation Argument is generally used to show that (parts of) some row appear in two tables in the same order. To establish this, the prover -- commits to the base column in question,[^2] +- commits to the main column in question,[^2] - samples a random challenge $\alpha$ through the Fiat-Shamir heuristic, -- computes the _running evaluation_ of $f_A(\alpha)$ and $f_B(\alpha)$ in the respective tables' extension column. +- computes the _running evaluation_ of $f_A(\alpha)$ and $f_B(\alpha)$ in the respective tables' auxiliary column. For example, in both Table A and B: -| base column | extension column: running evaluation | +| main column | auxiliary column: running evaluation | |------------:|:-----------------------------------------------------------| | 0 | $\alpha^1 + 0\alpha^0$ | | 1 | $\alpha^2 + 0\alpha^1 + 1\alpha^0$ | @@ -33,13 +33,13 @@ It is possible to establish a subset relation by skipping over certain elements The running evaluation must incorporate the same elements in both tables. Otherwise, the Evaluation Argument will fail. -Examples for subset Evaluation Arguments can be found between the [Hash Table](hash-table.md#extension-columns) and the [Processor Table](processor-table.md#extension-colums). +Examples for subset Evaluation Arguments can be found between the [Hash Table](hash-table.md#auxiliary-columns) and the [Processor Table](processor-table.md#auxiliary-colums). --- [^1]: This depends on the length $n$ of the lists $A$ and $B$ as well as the field size. For Triton VM, $n < 2^{32}$. -The polynomials $f_A(X)$ and $f_B(X)$ are evaluated over the extension field with $p^3 \approx 2^{192}$ elements. +The polynomials $f_A(X)$ and $f_B(X)$ are evaluated over the auxiliary field with $p^3 \approx 2^{192}$ elements. The false positive rate is therefore $n / |\mathbb{F}_{p^3}| \leqslant 2^{-160}$. [^2]: See “[Compressing Multiple Elements](table-linking.md#compressing-multiple-elements).” diff --git a/specification/src/hash-table.md b/specification/src/hash-table.md index 81285727a..b3ebbad61 100644 --- a/specification/src/hash-table.md +++ b/specification/src/hash-table.md @@ -63,9 +63,9 @@ For convenience, this document occasionally refers to those states as if they we This is an alias for $(2^{48}\cdot\texttt{state\_i\_highest\_lkin} + 2^{32}\cdot\texttt{state\_i\_mid\_high\_lkin} + 2^{16}\cdot\texttt{state\_i\_mid\_low\_lkin} + \texttt{state\_i\_lowest\_lkin})\cdot R^{-1}$. -## Base Columns +## Main Columns -The Hash Table has 67 base columns: +The Hash Table has 67 main columns: - The `Mode` indicator, as described above. It takes value @@ -84,9 +84,9 @@ This column is only relevant for mode `sponge`. - 4 columns `state_i_inv` establishing correct decomposition of `state_0_*_lkin` through `state_3_*_lkin` into 16-bit wide limbs. - 16 columns `constant_i`, which hold the round constant for the round indicated by `RoundNumber`, or 0 if no round with this round number exists. -## Extension Columns +## Auxiliary Columns -The Hash Table has 20 extension columns: +The Hash Table has 20 auxiliary columns: - `RunningEvaluationReceiveChunk` for the [Evaluation Argument](evaluation-argument.md) for copying chunks of size $\texttt{rate}$ from the [Program Table](program-table.md). Relevant for [program attestation](program-attestation.md). diff --git a/specification/src/jump-stack-table.md b/specification/src/jump-stack-table.md index 0825f7aa4..9a0a12734 100644 --- a/specification/src/jump-stack-table.md +++ b/specification/src/jump-stack-table.md @@ -2,7 +2,7 @@ The Jump Stack Memory contains the underflow from the Jump Stack. -## Base Columns +## Main Columns The Jump Stack Table consists of 5 columns: 1. the cycle counter `clk` @@ -103,9 +103,9 @@ Jump Stack Table: | 14 | `bar` | 2 | `0xB3` | `0xC0` | | 15 | `return` | 2 | `0xB3` | `0xC0` | -## Extension Columns +## Auxiliary Columns -The Jump Stack Table has 2 extension columns, `rppa` and `ClockJumpDifferenceLookupClientLogDerivative`. +The Jump Stack Table has 2 auxiliary columns, `rppa` and `ClockJumpDifferenceLookupClientLogDerivative`. 1. A Permutation Argument establishes that the rows of the Jump Stack Table match with the rows in the [Processor Table](processor-table.md). The running product for this argument is contained in the `rppa` column. diff --git a/specification/src/lookup-argument.md b/specification/src/lookup-argument.md index 57a229c6f..16a2f26ae 100644 --- a/specification/src/lookup-argument.md +++ b/specification/src/lookup-argument.md @@ -69,14 +69,14 @@ $$ \sum_{i=0}^\ell \frac{1}{X - a_i} = \sum_{i=0}^n \frac{m_i}{X - b_i} $$ -To compute the sums, the lists $A$ and $B$ are base columns in the two respective tables. -Additionally, the lookup multiplicity is recorded explicitly in a base column of the lookup table. +To compute the sums, the lists $A$ and $B$ are main columns in the two respective tables. +Additionally, the lookup multiplicity is recorded explicitly in a main column of the lookup table. ## Example In Table A: -| base column A | extension column A: logarithmic derivative | +| main column A | auxiliary column A: logarithmic derivative | |--------------:|:---------------------------------------------------------------------| | 0 | $\frac{1}{\alpha - 0}$ | | 2 | $\frac{1}{\alpha - 0} + \frac{1}{\alpha - 2}$ | @@ -86,7 +86,7 @@ In Table A: And in Table B: -| base column B | multiplicity | extension column B: logarithmic derivative | +| main column B | multiplicity | auxiliary column B: logarithmic derivative | |--------------:|-------------:|:---------------------------------------------------------------------| | 0 | 1 | $\frac{1}{\alpha - 0}$ | | 1 | 1 | $\frac{1}{\alpha - 0} + \frac{1}{\alpha - 1}$ | @@ -96,7 +96,7 @@ It is possible to establish a subset relation by skipping over certain elements The logarithmic derivative must incorporate the same elements with the same multiplicity in both tables. Otherwise, the Lookup Argument will fail. -An example for a Lookup Argument can be found between the [Program Table](program-table.md) and the [Processor Table](processor-table.md#extension-colums). +An example for a Lookup Argument can be found between the [Program Table](program-table.md) and the [Processor Table](processor-table.md#auxiliary-colums). --- diff --git a/specification/src/lookup-table.md b/specification/src/lookup-table.md index bc278de4c..cf96a7406 100644 --- a/specification/src/lookup-table.md +++ b/specification/src/lookup-table.md @@ -9,9 +9,9 @@ Correct creation of the Lookup Table is guaranteed through a public-facing [Eval after sampling some challenge $X$, the verifier computes the terminal of the Evaluation Argument over the list of all the expected lookup values with respect to challenge $X$. The equality of this verifier-supplied terminal against the similarly computed, in-table part of the Evaluation Argument is checked by the Lookup Table's terminal constraint. -## Base Columns +## Main Columns -The Lookup Table has 4 base columns: +The Lookup Table has 4 main columns: | name | description | |:---------------------|:--------------------------------------------| @@ -20,9 +20,9 @@ The Lookup Table has 4 base columns: | `LookOut` | The lookup output. | | `LookupMultiplicity` | The number of times the value is looked up. | -## Extension Columns +## Auxiliary Columns -The Lookup Table has 2 extension columns: +The Lookup Table has 2 auxiliary columns: - `CascadeTableServerLogDerivative`, the (running sum of the) logarithmic derivative for the Lookup Argument with the Cascade Table. In every row, accumulates the summand `LookupMultiplicity / Combo` where `Combo` is the verifier-weighted combination of `LookIn` and `LookOut`. diff --git a/specification/src/operational-stack-table.md b/specification/src/operational-stack-table.md index 0b4b2f5b9..43d0bd08b 100644 --- a/specification/src/operational-stack-table.md +++ b/specification/src/operational-stack-table.md @@ -12,7 +12,7 @@ The sole task of the Op Stack Table is to keep underflow memory immutable. To achieve this, any read or write accesses to the underflow memory are recorded in the Op Stack Table. Read and write accesses to op stack underflow memory are a side effect of shrinking or growing the op stack. -## Base Columns +## Main Columns The Op Stack Table consists of 4 columns: 1. the cycle counter `clk` @@ -100,9 +100,9 @@ Operational Stack Table: | 8 | 1 | 10 | 44 | -## Extension Columns +## Auxiliary Columns -The Op Stack Table has 2 extension columns, `rppa` and `ClockJumpDifferenceLookupClientLogDerivative`. +The Op Stack Table has 2 auxiliary columns, `rppa` and `ClockJumpDifferenceLookupClientLogDerivative`. 1. A Permutation Argument establishes that the rows of the Op Stack Table correspond to the rows of the [Processor Table](processor-table.md). The running product for this argument is contained in the `rppa` column. diff --git a/specification/src/permutation-argument.md b/specification/src/permutation-argument.md index 77c679edd..38b93c27c 100644 --- a/specification/src/permutation-argument.md +++ b/specification/src/permutation-argument.md @@ -16,13 +16,13 @@ By the [Schwartz–Zippel lemma](https://en.wikipedia.org/wiki/Schwartz%E2%80%93 In Triton VM, the Permutation Argument is generally applied to show that the rows of one table appear in some other table without enforcing the rows' order in relation to each other. To establish this, the prover -- commits to the base column in question,[^2] +- commits to the main column in question,[^2] - samples a random challenge $\alpha$ through the Fiat-Shamir heuristic, -- computes the _running product_ of $f_A(\alpha)$ and $f_B(\alpha)$ in the respective tables' extension column. +- computes the _running product_ of $f_A(\alpha)$ and $f_B(\alpha)$ in the respective tables' auxiliary column. For example, in Table A: -| base column A | extension column A: running product | +| main column A | auxiliary column A: running product | |--------------:|:---------------------------------------------------| | 0 | $(\alpha - 0)$ | | 1 | $(\alpha - 0)(\alpha - 1)$ | @@ -31,7 +31,7 @@ For example, in Table A: And in Table B: -| base column B | extension column B: running product | +| main column B | auxiliary column B: running product | |--------------:|:---------------------------------------------------| | 2 | $(\alpha - 2)$ | | 1 | $(\alpha - 2)(\alpha - 1)$ | @@ -42,13 +42,13 @@ It is possible to establish a subset relation by skipping over certain elements The running product must incorporate the same elements in both tables. Otherwise, the Permutation Argument will fail. -An example of a subset Permutation Argument can be found between the [U32 Table](u32-table.md#extension-columns) and the [Processor Table](processor-table.md#extension-colums). +An example of a subset Permutation Argument can be found between the [U32 Table](u32-table.md#auxiliary-columns) and the [Processor Table](processor-table.md#auxiliary-colums). --- [^1]: This depends on the length $n$ of the lists $A$ and $B$ as well as the field size. For Triton VM, $n < 2^{32}$. -The polynomials $f_A(X)$ and $f_B(X)$ are evaluated over the extension field with $p^3 \approx 2^{192}$ elements. +The polynomials $f_A(X)$ and $f_B(X)$ are evaluated over the auxiliary field with $p^3 \approx 2^{192}$ elements. The false positive rate is therefore $n / |\mathbb{F}_{p^3}| \leqslant 2^{-160}$. [^2]: See “[Compressing Multiple Elements](table-linking.md#compressing-multiple-elements).” diff --git a/specification/src/processor-table.md b/specification/src/processor-table.md index b5bac4d7b..e033d51a2 100644 --- a/specification/src/processor-table.md +++ b/specification/src/processor-table.md @@ -13,14 +13,14 @@ they can compare their own program digest to the program digest of the proof the This way, a recursive verifier can easily determine if they are actually recursing, or whether the proof they are checking was generated using an entirely different program. A more detailed explanation of the mechanics can be found on the page about [program attestation](program-attestation.md). -## Base Columns +## Main Columns The processor consists of all registers defined in the [Instruction Set Architecture](isa.md). Each register is assigned a column in the processor table. -## Extension Columns +## Auxiliary Columns -The Processor Table has the following extension columns, corresponding to [Evaluation Arguments](evaluation-argument.md), [Permutation Arguments](permutation-argument.md), and [Lookup Arguments](lookup-argument.md): +The Processor Table has the following auxiliary columns, corresponding to [Evaluation Arguments](evaluation-argument.md), [Permutation Arguments](permutation-argument.md), and [Lookup Arguments](lookup-argument.md): 1. `RunningEvaluationStandardInput` for the Evaluation Argument with the input symbols. 1. `RunningEvaluationStandardOutput` for the Evaluation Argument with the output symbols. diff --git a/specification/src/program-table.md b/specification/src/program-table.md index e3d924b23..d9b9dfc05 100644 --- a/specification/src/program-table.md +++ b/specification/src/program-table.md @@ -6,9 +6,9 @@ The [processor](processor-table.md) looks up instructions and arguments using it For [program attestation](program-attestation.md), the program is [padded](program-attestation.md#mechanics) and sent to the [Hash Table](hash-table.md) in chunks of size 10, which is the $\texttt{rate}$ of the [Tip5 hash function][tip5]. Program padding is one 1 followed by the minimal number of 0’s necessary to make the padded input length a multiple of the $\texttt{rate}$[^padding]. -## Base Columns +## Main Columns -The Program Table consists of 7 base columns. +The Program Table consists of 7 main columns. Those columns marked with an asterisk (\*) are only used for [program attestation](program-attestation.md). | Column | Description | @@ -21,10 +21,10 @@ Those columns marked with an asterisk (\*) are only used for [program attestatio | \*`IsHashInputPadding` | padding indicator for absorbing the program into the Sponge | | `IsTablePadding` | padding indicator for rows only required due to the dominating length of some other table | -## Extension Columns +## Auxiliary Columns A [Lookup Argument](lookup-argument.md) with the [Processor Table](processor-table.md) establishes that the processor has loaded the correct instruction (and its argument) from program memory. -To establish the program memory's side of the Lookup Argument, the Program Table has extension column `InstructionLookupServerLogDerivative`. +To establish the program memory's side of the Lookup Argument, the Program Table has auxiliary column `InstructionLookupServerLogDerivative`. For sending the padded program to the [Hash Table](hash-table.md), a combination of two [Evaluation Arguments](evaluation-argument.md) is used. The first, `PrepareChunkRunningEvaluation`, absorbs one chunk of $\texttt{rate}$ (_i.e._ 10) instructions at a time, after which it is reset and starts absorbing again. diff --git a/specification/src/proof-of-memory-consistency.md b/specification/src/proof-of-memory-consistency.md index e5124f8f5..84b8922f6 100644 --- a/specification/src/proof-of-memory-consistency.md +++ b/specification/src/proof-of-memory-consistency.md @@ -1,6 +1,6 @@ # Proof of Memory Consistency -Whenever the Processor Table reads a value "from" a memory-like table, this value appears nondeterministically and is unconstrained by the base table AIR constraints. However, there is a permutation argument that links the Processor Table to the memory-like table in question. *The construction satisfies memory consistency if it guarantees that whenever a memory cell is read, its value is consistent with the last time that cell was written.* +Whenever the Processor Table reads a value "from" a memory-like table, this value appears nondeterministically and is unconstrained by the main table AIR constraints. However, there is a permutation argument that links the Processor Table to the memory-like table in question. *The construction satisfies memory consistency if it guarantees that whenever a memory cell is read, its value is consistent with the last time that cell was written.* The above is too informal to provide a meaningful proof for. Let's put formal meanings on the proposition and premises, before reducing the former to the latter. diff --git a/specification/src/random-access-memory-table.md b/specification/src/random-access-memory-table.md index 7195968d6..444200051 100644 --- a/specification/src/random-access-memory-table.md +++ b/specification/src/random-access-memory-table.md @@ -12,9 +12,9 @@ If some RAM address is read from before it is written to, the corresponding valu This is one of interfaces for non-deterministic input in Triton VM. Consecutive reads from any such address always returns the same value (until overwritten via `write_mem`). -## Base Columns +## Main Columns -The RAM Table has 7 base columns: +The RAM Table has 7 main columns: 1. the cycle counter `clk`, 1. the executed `instruction_type` – 0 for “write”, 1 for “read”, 2 for padding rows, 1. RAM pointer `ram_pointer`, @@ -29,13 +29,13 @@ The function of `iord` is best explained in the context of sorting the RAM Table The Bézout coefficient polynomial coefficients `bcpc0` and `bcpc1` represent the coefficients of polynomials that are needed for the [contiguity argument](memory-consistency.md#contiguity-for-ram-table). This argument establishes that all regions of constant `ram_pointer` are contiguous. -## Extension Columns +## Auxiliary Columns -The RAM Table has 6 extension columns: +The RAM Table has 6 auxiliary columns: 1. `RunningProductOfRAMP`, accumulating next row's `ram_pointer` as a root whenever `ram_pointer` changes between two rows, 1. `FormalDerivative`, the (evaluated) formal derivative of `RunningProductOfRAMP`, -1. `BezoutCoefficient0`, the (evaluated) polynomial with base column `bcpc0` as coefficients, -1. `BezoutCoefficient1`, the (evaluated) polynomial with base column `bcpc1` as coefficients, +1. `BezoutCoefficient0`, the (evaluated) polynomial with main column `bcpc0` as coefficients, +1. `BezoutCoefficient1`, the (evaluated) polynomial with main column `bcpc1` as coefficients, 1. `RunningProductPermArg`, the [Permutation Argument](permutation-argument.md) with the [Processor Table](processor-table.md), and 1. `ClockJumpDifferenceLookupClientLogDerivative`, part of [memory consistency](clock-jump-differences-and-inner-sorting.md). diff --git a/specification/src/u32-table.md b/specification/src/u32-table.md index 7ee29b76b..c9507c8c1 100644 --- a/specification/src/u32-table.md +++ b/specification/src/u32-table.md @@ -12,7 +12,7 @@ The processor's current instruction `CI` is recorded within the section and dict Crucially, the rows of the U32 table are independent of the processor's clock. Hence, the result of the instruction can be transferred into the processor within one clock cycle. -## Base Columns +## Main Columns | name | description | |:---------------------|:------------------------------------------------------------------------------------------------| @@ -98,9 +98,9 @@ A new row with `CopyFlag = 1` can only be inserted if It is impossible to create a valid proof of correct execution of Triton VM if `Bits` is 33 in any row. -## Extension Columns +## Auxiliary Columns -The U32 Table has 1 extension column, `U32LookupServerLogDerivative`. +The U32 Table has 1 auxiliary column, `U32LookupServerLogDerivative`. It corresponds to the [Lookup Argument](lookup-argument.md) with the [Processor Table](processor-table.md), establishing that whenever the processor executes a u32 instruction, the following holds: - the processor's requested left-hand side is copied into `LHS`, diff --git a/triton-air/src/challenge_id.rs b/triton-air/src/challenge_id.rs index fd8975290..04e04ebee 100644 --- a/triton-air/src/challenge_id.rs +++ b/triton-air/src/challenge_id.rs @@ -95,7 +95,7 @@ pub enum ChallengeId { /// Used by the evaluation argument [`PrepareChunkEvalArg`][prep] and in the Hash Table. /// /// [rate]: twenty_first::prelude::tip5::RATE - /// [prep]: crate::table_column::ProgramExtTableColumn::PrepareChunkRunningEvaluation + /// [prep]: crate::table_column::ProgramAuxColumn::PrepareChunkRunningEvaluation ProgramAttestationPrepareChunkIndeterminate, /// The indeterminate for the bus over which the [`RATE`][rate]-sized chunks of instructions @@ -105,8 +105,8 @@ pub enum ChallengeId { /// [`ProgramAttestationPrepareChunkIndeterminate`][ind]. /// /// [rate]: twenty_first::prelude::tip5::RATE - /// [send]: crate::table_column::ProgramExtTableColumn::SendChunkRunningEvaluation - /// [recv]: crate::table_column::HashExtTableColumn::ReceiveChunkRunningEvaluation + /// [send]: crate::table_column::ProgramAuxColumn::SendChunkRunningEvaluation + /// [recv]: crate::table_column::HashAuxColumn::ReceiveChunkRunningEvaluation /// [ind]: ChallengeId::ProgramAttestationPrepareChunkIndeterminate ProgramAttestationSendChunkIndeterminate, diff --git a/triton-air/src/cross_table_argument.rs b/triton-air/src/cross_table_argument.rs index 575087b98..c3e17e8a6 100644 --- a/triton-air/src/cross_table_argument.rs +++ b/triton-air/src/cross_table_argument.rs @@ -5,21 +5,21 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; use twenty_first::prelude::*; use crate::challenge_id::ChallengeId; -use crate::table_column::CascadeExtTableColumn; -use crate::table_column::HashExtTableColumn; -use crate::table_column::JumpStackExtTableColumn; -use crate::table_column::LookupExtTableColumn; -use crate::table_column::MasterExtTableColumn; -use crate::table_column::OpStackExtTableColumn; -use crate::table_column::ProcessorExtTableColumn; -use crate::table_column::ProgramExtTableColumn; -use crate::table_column::RamExtTableColumn; -use crate::table_column::U32ExtTableColumn; +use crate::table_column::CascadeAuxColumn; +use crate::table_column::HashAuxColumn; +use crate::table_column::JumpStackAuxColumn; +use crate::table_column::LookupAuxColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::OpStackAuxColumn; +use crate::table_column::ProcessorAuxColumn; +use crate::table_column::ProgramAuxColumn; +use crate::table_column::RamAuxColumn; +use crate::table_column::U32AuxColumn; pub trait CrossTableArg { fn default_initial() -> XFieldElement @@ -131,81 +131,67 @@ impl GrandCrossTableArg { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let challenge = |c| circuit_builder.challenge(c); - let ext_row = |col_index| circuit_builder.input(ExtRow(col_index)); + let aux_row = |col_index| circuit_builder.input(Aux(col_index)); // Closures cannot take arguments of type `impl Trait`. Hence: some more helpers. \o/ - let program_ext_row = - |column: ProgramExtTableColumn| ext_row(column.master_ext_table_index()); - let processor_ext_row = - |column: ProcessorExtTableColumn| ext_row(column.master_ext_table_index()); - let op_stack_ext_row = - |column: OpStackExtTableColumn| ext_row(column.master_ext_table_index()); - let ram_ext_row = |column: RamExtTableColumn| ext_row(column.master_ext_table_index()); - let jump_stack_ext_row = - |column: JumpStackExtTableColumn| ext_row(column.master_ext_table_index()); - let hash_ext_row = |column: HashExtTableColumn| ext_row(column.master_ext_table_index()); - let cascade_ext_row = - |column: CascadeExtTableColumn| ext_row(column.master_ext_table_index()); - let lookup_ext_row = - |column: LookupExtTableColumn| ext_row(column.master_ext_table_index()); - let u32_ext_row = |column: U32ExtTableColumn| ext_row(column.master_ext_table_index()); + let program_aux_row = |column: ProgramAuxColumn| aux_row(column.master_aux_index()); + let processor_aux_row = |column: ProcessorAuxColumn| aux_row(column.master_aux_index()); + let op_stack_aux_row = |column: OpStackAuxColumn| aux_row(column.master_aux_index()); + let ram_aux_row = |column: RamAuxColumn| aux_row(column.master_aux_index()); + let j_stack_aux_row = |column: JumpStackAuxColumn| aux_row(column.master_aux_index()); + let hash_aux_row = |column: HashAuxColumn| aux_row(column.master_aux_index()); + let cascade_aux_row = |column: CascadeAuxColumn| aux_row(column.master_aux_index()); + let lookup_aux_row = |column: LookupAuxColumn| aux_row(column.master_aux_index()); + let u32_aux_row = |column: U32AuxColumn| aux_row(column.master_aux_index()); - let program_attestation = - program_ext_row(ProgramExtTableColumn::SendChunkRunningEvaluation) - - hash_ext_row(HashExtTableColumn::ReceiveChunkRunningEvaluation); + let program_attestation = program_aux_row(ProgramAuxColumn::SendChunkRunningEvaluation) + - hash_aux_row(HashAuxColumn::ReceiveChunkRunningEvaluation); let input_to_processor = challenge(ChallengeId::StandardInputTerminal) - - processor_ext_row(ProcessorExtTableColumn::InputTableEvalArg); - let processor_to_output = processor_ext_row(ProcessorExtTableColumn::OutputTableEvalArg) + - processor_aux_row(ProcessorAuxColumn::InputTableEvalArg); + let processor_to_output = processor_aux_row(ProcessorAuxColumn::OutputTableEvalArg) - challenge(ChallengeId::StandardOutputTerminal); let instruction_lookup = - processor_ext_row(ProcessorExtTableColumn::InstructionLookupClientLogDerivative) - - program_ext_row(ProgramExtTableColumn::InstructionLookupServerLogDerivative); - let processor_to_op_stack = processor_ext_row(ProcessorExtTableColumn::OpStackTablePermArg) - - op_stack_ext_row(OpStackExtTableColumn::RunningProductPermArg); - let processor_to_ram = processor_ext_row(ProcessorExtTableColumn::RamTablePermArg) - - ram_ext_row(RamExtTableColumn::RunningProductPermArg); - let processor_to_jump_stack = - processor_ext_row(ProcessorExtTableColumn::JumpStackTablePermArg) - - jump_stack_ext_row(JumpStackExtTableColumn::RunningProductPermArg); - let hash_input = processor_ext_row(ProcessorExtTableColumn::HashInputEvalArg) - - hash_ext_row(HashExtTableColumn::HashInputRunningEvaluation); - let hash_digest = hash_ext_row(HashExtTableColumn::HashDigestRunningEvaluation) - - processor_ext_row(ProcessorExtTableColumn::HashDigestEvalArg); - let sponge = processor_ext_row(ProcessorExtTableColumn::SpongeEvalArg) - - hash_ext_row(HashExtTableColumn::SpongeRunningEvaluation); - let hash_to_cascade = cascade_ext_row(CascadeExtTableColumn::HashTableServerLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState0HighestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState0MidHighClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState0MidLowClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState0LowestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState1HighestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState1MidHighClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState1MidLowClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState1LowestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState2HighestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState2MidHighClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState2MidLowClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState2LowestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState3HighestClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState3MidHighClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState3MidLowClientLogDerivative) - - hash_ext_row(HashExtTableColumn::CascadeState3LowestClientLogDerivative); - let cascade_to_lookup = - cascade_ext_row(CascadeExtTableColumn::LookupTableClientLogDerivative) - - lookup_ext_row(LookupExtTableColumn::CascadeTableServerLogDerivative); - let processor_to_u32 = - processor_ext_row(ProcessorExtTableColumn::U32LookupClientLogDerivative) - - u32_ext_row(U32ExtTableColumn::LookupServerLogDerivative); + processor_aux_row(ProcessorAuxColumn::InstructionLookupClientLogDerivative) + - program_aux_row(ProgramAuxColumn::InstructionLookupServerLogDerivative); + let processor_to_op_stack = processor_aux_row(ProcessorAuxColumn::OpStackTablePermArg) + - op_stack_aux_row(OpStackAuxColumn::RunningProductPermArg); + let processor_to_ram = processor_aux_row(ProcessorAuxColumn::RamTablePermArg) + - ram_aux_row(RamAuxColumn::RunningProductPermArg); + let processor_to_jump_stack = processor_aux_row(ProcessorAuxColumn::JumpStackTablePermArg) + - j_stack_aux_row(JumpStackAuxColumn::RunningProductPermArg); + let hash_input = processor_aux_row(ProcessorAuxColumn::HashInputEvalArg) + - hash_aux_row(HashAuxColumn::HashInputRunningEvaluation); + let hash_digest = hash_aux_row(HashAuxColumn::HashDigestRunningEvaluation) + - processor_aux_row(ProcessorAuxColumn::HashDigestEvalArg); + let sponge = processor_aux_row(ProcessorAuxColumn::SpongeEvalArg) + - hash_aux_row(HashAuxColumn::SpongeRunningEvaluation); + let hash_to_cascade = cascade_aux_row(CascadeAuxColumn::HashTableServerLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState0HighestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState0MidHighClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState0MidLowClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState0LowestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState1HighestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState1MidHighClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState1MidLowClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState1LowestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState2HighestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState2MidHighClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState2MidLowClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState2LowestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState3HighestClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState3MidHighClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState3MidLowClientLogDerivative) + - hash_aux_row(HashAuxColumn::CascadeState3LowestClientLogDerivative); + let cascade_to_lookup = cascade_aux_row(CascadeAuxColumn::LookupTableClientLogDerivative) + - lookup_aux_row(LookupAuxColumn::CascadeTableServerLogDerivative); + let processor_to_u32 = processor_aux_row(ProcessorAuxColumn::U32LookupClientLogDerivative) + - u32_aux_row(U32AuxColumn::LookupServerLogDerivative); - // Introduce new variable names to increase readability. Potentially opinionated. - let processor_cjdld = ProcessorExtTableColumn::ClockJumpDifferenceLookupServerLogDerivative; - let op_stack_cjdld = OpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; - let ram_cjdld = RamExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; - let j_stack_cjdld = JumpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; - let clock_jump_difference_lookup = processor_ext_row(processor_cjdld) - - op_stack_ext_row(op_stack_cjdld) - - ram_ext_row(ram_cjdld) - - jump_stack_ext_row(j_stack_cjdld); + let clock_jump_difference_lookup = + processor_aux_row(ProcessorAuxColumn::ClockJumpDifferenceLookupServerLogDerivative) + - op_stack_aux_row(OpStackAuxColumn::ClockJumpDifferenceLookupClientLogDerivative) + - ram_aux_row(RamAuxColumn::ClockJumpDifferenceLookupClientLogDerivative) + - j_stack_aux_row(JumpStackAuxColumn::ClockJumpDifferenceLookupClientLogDerivative); vec![ program_attestation, diff --git a/triton-air/src/lib.rs b/triton-air/src/lib.rs index 8046f54c2..7f4c36973 100644 --- a/triton-air/src/lib.rs +++ b/triton-air/src/lib.rs @@ -4,8 +4,8 @@ use constraint_circuit::DualRowIndicator; use constraint_circuit::SingleRowIndicator; use strum::EnumCount; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; pub mod challenge_id; pub mod cross_table_argument; @@ -29,8 +29,8 @@ pub mod table_column; pub const TARGET_DEGREE: isize = 4; pub trait AIR { - type MainColumn: MasterBaseTableColumn + EnumCount; - type AuxColumn: MasterExtTableColumn + EnumCount; + type MainColumn: MasterMainColumn + EnumCount; + type AuxColumn: MasterAuxColumn + EnumCount; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, @@ -62,24 +62,24 @@ mod tests { implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); diff --git a/triton-air/src/table.rs b/triton-air/src/table.rs index 7e17ad4c8..207d1f0d8 100644 --- a/triton-air/src/table.rs +++ b/triton-air/src/table.rs @@ -26,7 +26,7 @@ pub mod u32; /// The total number of main columns across all tables. /// The degree lowering columns are _not_ included. -pub const NUM_BASE_COLUMNS: usize = ::MainColumn::COUNT +pub const NUM_MAIN_COLUMNS: usize = ::MainColumn::COUNT + ::MainColumn::COUNT + ::MainColumn::COUNT + ::MainColumn::COUNT @@ -39,7 +39,7 @@ pub const NUM_BASE_COLUMNS: usize = ::MainColumn::COUNT /// The total number of auxiliary columns across all tables. /// The degree lowering columns as well as any randomizer polynomials are _not_ /// included. -pub const NUM_EXT_COLUMNS: usize = ::AuxColumn::COUNT +pub const NUM_AUX_COLUMNS: usize = ::AuxColumn::COUNT + ::AuxColumn::COUNT + ::AuxColumn::COUNT + ::AuxColumn::COUNT @@ -71,30 +71,30 @@ pub const LOOKUP_TABLE_END: usize = LOOKUP_TABLE_START + ::M pub const U32_TABLE_START: usize = LOOKUP_TABLE_END; pub const U32_TABLE_END: usize = U32_TABLE_START + ::MainColumn::COUNT; -pub const EXT_PROGRAM_TABLE_START: usize = 0; -pub const EXT_PROGRAM_TABLE_END: usize = - EXT_PROGRAM_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_PROCESSOR_TABLE_START: usize = EXT_PROGRAM_TABLE_END; -pub const EXT_PROCESSOR_TABLE_END: usize = - EXT_PROCESSOR_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_OP_STACK_TABLE_START: usize = EXT_PROCESSOR_TABLE_END; -pub const EXT_OP_STACK_TABLE_END: usize = - EXT_OP_STACK_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_RAM_TABLE_START: usize = EXT_OP_STACK_TABLE_END; -pub const EXT_RAM_TABLE_END: usize = EXT_RAM_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_JUMP_STACK_TABLE_START: usize = EXT_RAM_TABLE_END; -pub const EXT_JUMP_STACK_TABLE_END: usize = - EXT_JUMP_STACK_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_HASH_TABLE_START: usize = EXT_JUMP_STACK_TABLE_END; -pub const EXT_HASH_TABLE_END: usize = EXT_HASH_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_CASCADE_TABLE_START: usize = EXT_HASH_TABLE_END; -pub const EXT_CASCADE_TABLE_END: usize = - EXT_CASCADE_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_LOOKUP_TABLE_START: usize = EXT_CASCADE_TABLE_END; -pub const EXT_LOOKUP_TABLE_END: usize = - EXT_LOOKUP_TABLE_START + ::AuxColumn::COUNT; -pub const EXT_U32_TABLE_START: usize = EXT_LOOKUP_TABLE_END; -pub const EXT_U32_TABLE_END: usize = EXT_U32_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_PROGRAM_TABLE_START: usize = 0; +pub const AUX_PROGRAM_TABLE_END: usize = + AUX_PROGRAM_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_PROCESSOR_TABLE_START: usize = AUX_PROGRAM_TABLE_END; +pub const AUX_PROCESSOR_TABLE_END: usize = + AUX_PROCESSOR_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_OP_STACK_TABLE_START: usize = AUX_PROCESSOR_TABLE_END; +pub const AUX_OP_STACK_TABLE_END: usize = + AUX_OP_STACK_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_RAM_TABLE_START: usize = AUX_OP_STACK_TABLE_END; +pub const AUX_RAM_TABLE_END: usize = AUX_RAM_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_JUMP_STACK_TABLE_START: usize = AUX_RAM_TABLE_END; +pub const AUX_JUMP_STACK_TABLE_END: usize = + AUX_JUMP_STACK_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_HASH_TABLE_START: usize = AUX_JUMP_STACK_TABLE_END; +pub const AUX_HASH_TABLE_END: usize = AUX_HASH_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_CASCADE_TABLE_START: usize = AUX_HASH_TABLE_END; +pub const AUX_CASCADE_TABLE_END: usize = + AUX_CASCADE_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_LOOKUP_TABLE_START: usize = AUX_CASCADE_TABLE_END; +pub const AUX_LOOKUP_TABLE_END: usize = + AUX_LOOKUP_TABLE_START + ::AuxColumn::COUNT; +pub const AUX_U32_TABLE_START: usize = AUX_LOOKUP_TABLE_END; +pub const AUX_U32_TABLE_END: usize = AUX_U32_TABLE_START + ::AuxColumn::COUNT; /// Uniquely determines one of Triton VM's tables. #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter, Arbitrary)] diff --git a/triton-air/src/table/cascade.rs b/triton-air/src/table/cascade.rs index e9beeb93b..d7be94b17 100644 --- a/triton-air/src/table/cascade.rs +++ b/triton-air/src/table/cascade.rs @@ -1,13 +1,13 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use crate::challenge_id::ChallengeId; use crate::challenge_id::ChallengeId::CascadeLookupIndeterminate; @@ -18,26 +18,24 @@ use crate::challenge_id::ChallengeId::LookupTableInputWeight; use crate::challenge_id::ChallengeId::LookupTableOutputWeight; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct CascadeTable; impl AIR for CascadeTable { - type MainColumn = crate::table_column::CascadeBaseTableColumn; - type AuxColumn = crate::table_column::CascadeExtTableColumn; + type MainColumn = crate::table_column::CascadeMainColumn; + type AuxColumn = crate::table_column::CascadeAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = |col_id: Self::MainColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; - let aux_row = |col_id: Self::AuxColumn| { - circuit_builder.input(ExtRow(col_id.master_ext_table_index())) - }; + let main_row = + |col_id: Self::MainColumn| circuit_builder.input(Main(col_id.master_main_index())); + let aux_row = + |col_id: Self::AuxColumn| circuit_builder.input(Aux(col_id.master_aux_index())); let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); let one = || circuit_builder.b_constant(1); @@ -105,9 +103,8 @@ impl AIR for CascadeTable { fn consistency_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let row = |col_id: Self::MainColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; + let row = + |col_id: Self::MainColumn| circuit_builder.input(Main(col_id.master_main_index())); let one = circuit_builder.b_constant(1); let is_padding = row(Self::MainColumn::IsPadding); @@ -123,16 +120,16 @@ impl AIR for CascadeTable { let constant = |c: u64| circuit_builder.b_constant(c); let curr_main_row = |column_idx: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(CurrentMain(column_idx.master_main_index())) }; let next_main_row = |column_idx: Self::MainColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(NextMain(column_idx.master_main_index())) }; let curr_aux_row = |column_idx: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) + circuit_builder.input(CurrentAux(column_idx.master_aux_index())) }; let next_aux_row = |column_idx: Self::AuxColumn| { - circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) + circuit_builder.input(NextAux(column_idx.master_aux_index())) }; let one = constant(1); diff --git a/triton-air/src/table/hash.rs b/triton-air/src/table/hash.rs index 06566c988..f52456206 100644 --- a/triton-air/src/table/hash.rs +++ b/triton-air/src/table/hash.rs @@ -1,14 +1,14 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::InputIndicator; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use isa::instruction::Instruction; use itertools::Itertools; use strum::Display; @@ -22,8 +22,8 @@ use crate::challenge_id::ChallengeId; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::EvalArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; pub const MONTGOMERY_MODULUS: BFieldElement = @@ -165,7 +165,7 @@ impl HashTable { /// Valid indices are 0 through 15, corresponding to the 16 round constants /// `Constant0` through `Constant15`. /// - /// [col]: crate::table_column::HashBaseTableColumn + /// [col]: crate::table_column::HashMainColumn pub fn round_constant_column_by_index(index: usize) -> MainColumn { match index { 0 => MainColumn::Constant0, @@ -188,7 +188,7 @@ impl HashTable { } } - /// The [`HashBaseTableColumn`] for the state corresponding to the given index. + /// The [`HashMainColumn`] for the state corresponding to the given index. /// Valid indices are 4 through 15, corresponding to the 12 state columns /// [`State4`] through [`State15`]. /// @@ -213,16 +213,16 @@ impl HashTable { } } - fn indicate_column_index_in_base_row(column: MainColumn) -> SingleRowIndicator { - BaseRow(column.master_base_table_index()) + fn indicate_column_index_in_main_row(column: MainColumn) -> SingleRowIndicator { + Main(column.master_main_index()) } - fn indicate_column_index_in_current_base_row(column: MainColumn) -> DualRowIndicator { - CurrentBaseRow(column.master_base_table_index()) + fn indicate_column_index_in_current_main_row(column: MainColumn) -> DualRowIndicator { + CurrentMain(column.master_main_index()) } - fn indicate_column_index_in_next_base_row(column: MainColumn) -> DualRowIndicator { - NextBaseRow(column.master_base_table_index()) + fn indicate_column_index_in_next_main_row(column: MainColumn) -> DualRowIndicator { + NextMain(column.master_main_index()) } fn re_compose_states_0_through_3_before_lookup( @@ -270,10 +270,10 @@ impl HashTable { let constant = |c: u64| circuit_builder.b_constant(c); let b_constant = |c| circuit_builder.b_constant(c); let current_main_row = |column_idx: MainColumn| { - circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(CurrentMain(column_idx.master_main_index())) }; let next_main_row = |column_idx: MainColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(NextMain(column_idx.master_main_index())) }; let state_0_after_lookup = Self::re_compose_16_bit_limbs( @@ -387,7 +387,7 @@ impl HashTable { let [state_0_next, state_1_next, state_2_next, state_3_next] = Self::re_compose_states_0_through_3_before_lookup( circuit_builder, - Self::indicate_column_index_in_next_base_row, + Self::indicate_column_index_in_next_main_row, ); let state_next = [ state_0_next, @@ -432,14 +432,13 @@ impl HashTable { let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); let constant = |c: u32| circuit_builder.b_constant(c); let next_main_row = |column_idx: MainColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(NextMain(column_idx.master_main_index())) }; let current_aux_row = |column_idx: AuxColumn| { - circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) - }; - let next_aux_row = |column_idx: AuxColumn| { - circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) + circuit_builder.input(CurrentAux(column_idx.master_aux_index())) }; + let next_aux_row = + |column_idx: AuxColumn| circuit_builder.input(NextAux(column_idx.master_aux_index())); let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate); let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight); @@ -477,8 +476,8 @@ impl HashTable { } impl AIR for HashTable { - type MainColumn = crate::table_column::HashBaseTableColumn; - type AuxColumn = crate::table_column::HashExtTableColumn; + type MainColumn = crate::table_column::HashMainColumn; + type AuxColumn = crate::table_column::HashAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, @@ -486,12 +485,10 @@ impl AIR for HashTable { let challenge = |c| circuit_builder.challenge(c); let constant = |c: u64| circuit_builder.b_constant(c); - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let aux_row = |column: Self::AuxColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); + let aux_row = + |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index())); let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial()); @@ -515,7 +512,7 @@ impl AIR for HashTable { let [state_0, state_1, state_2, state_3] = Self::re_compose_states_0_through_3_before_lookup( circuit_builder, - Self::indicate_column_index_in_base_row, + Self::indicate_column_index_in_main_row, ); let state_rate_part: [_; tip5::RATE] = [ state_0, @@ -659,7 +656,7 @@ impl AIR for HashTable { let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); let constant = |c: u64| circuit_builder.b_constant(c); let main_row = |column_id: Self::MainColumn| { - circuit_builder.input(BaseRow(column_id.master_base_table_index())) + circuit_builder.input(Main(column_id.master_main_index())) }; let mode = main_row(Self::MainColumn::Mode); @@ -837,16 +834,16 @@ impl AIR for HashTable { let opcode_sponge_squeeze = opcode(Instruction::SpongeSqueeze); let current_main_row = |column_idx: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index())) + circuit_builder.input(CurrentMain(column_idx.master_main_index())) }; - let next_base_row = |column_idx: Self::MainColumn| { - circuit_builder.input(NextBaseRow(column_idx.master_base_table_index())) + let next_main_row = |column_idx: Self::MainColumn| { + circuit_builder.input(NextMain(column_idx.master_main_index())) }; - let current_ext_row = |column_idx: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(column_idx.master_ext_table_index())) + let current_aux_row = |column_idx: Self::AuxColumn| { + circuit_builder.input(CurrentAux(column_idx.master_aux_index())) }; - let next_ext_row = |column_idx: Self::AuxColumn| { - circuit_builder.input(NextExtRow(column_idx.master_ext_table_index())) + let next_aux_row = |column_idx: Self::AuxColumn| { + circuit_builder.input(NextAux(column_idx.master_aux_index())) }; let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial()); @@ -866,28 +863,28 @@ impl AIR for HashTable { let ci = current_main_row(Self::MainColumn::CI); let round_number = current_main_row(Self::MainColumn::RoundNumber); let running_evaluation_receive_chunk = - current_ext_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); + current_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); let running_evaluation_hash_input = - current_ext_row(Self::AuxColumn::HashInputRunningEvaluation); + current_aux_row(Self::AuxColumn::HashInputRunningEvaluation); let running_evaluation_hash_digest = - current_ext_row(Self::AuxColumn::HashDigestRunningEvaluation); - let running_evaluation_sponge = current_ext_row(Self::AuxColumn::SpongeRunningEvaluation); + current_aux_row(Self::AuxColumn::HashDigestRunningEvaluation); + let running_evaluation_sponge = current_aux_row(Self::AuxColumn::SpongeRunningEvaluation); - let mode_next = next_base_row(Self::MainColumn::Mode); - let ci_next = next_base_row(Self::MainColumn::CI); - let round_number_next = next_base_row(Self::MainColumn::RoundNumber); + let mode_next = next_main_row(Self::MainColumn::Mode); + let ci_next = next_main_row(Self::MainColumn::CI); + let round_number_next = next_main_row(Self::MainColumn::RoundNumber); let running_evaluation_receive_chunk_next = - next_ext_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); + next_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation); let running_evaluation_hash_input_next = - next_ext_row(Self::AuxColumn::HashInputRunningEvaluation); + next_aux_row(Self::AuxColumn::HashInputRunningEvaluation); let running_evaluation_hash_digest_next = - next_ext_row(Self::AuxColumn::HashDigestRunningEvaluation); - let running_evaluation_sponge_next = next_ext_row(Self::AuxColumn::SpongeRunningEvaluation); + next_aux_row(Self::AuxColumn::HashDigestRunningEvaluation); + let running_evaluation_sponge_next = next_aux_row(Self::AuxColumn::SpongeRunningEvaluation); let [state_0, state_1, state_2, state_3] = Self::re_compose_states_0_through_3_before_lookup( circuit_builder, - Self::indicate_column_index_in_current_base_row, + Self::indicate_column_index_in_current_main_row, ); let state_current = [ @@ -1248,7 +1245,7 @@ impl AIR for HashTable { let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b()); let constant = |c: u64| circuit_builder.b_constant(c); let main_row = |column_idx: Self::MainColumn| { - circuit_builder.input(BaseRow(column_idx.master_base_table_index())) + circuit_builder.input(Main(column_idx.master_main_index())) }; let mode = main_row(Self::MainColumn::Mode); @@ -1263,7 +1260,7 @@ impl AIR for HashTable { let [state_0, state_1, state_2, state_3] = Self::re_compose_states_0_through_3_before_lookup( circuit_builder, - Self::indicate_column_index_in_base_row, + Self::indicate_column_index_in_main_row, ); let state_4 = main_row(Self::MainColumn::State4); let program_digest = [state_0, state_1, state_2, state_3, state_4]; @@ -1312,7 +1309,7 @@ impl AIR for HashTable { /// The empty program is not valid since any valid [`Program`][program] must execute /// instruction `halt`. /// -/// [round_no]: crate::table_column::HashBaseTableColumn::RoundNumber +/// [round_no]: crate::table_column::HashMainColumn::RoundNumber /// [program]: isa::program::Program /// [prog_hash]: HashTableMode::ProgramHashing /// [sponge]: HashTableMode::Sponge diff --git a/triton-air/src/table/jump_stack.rs b/triton-air/src/table/jump_stack.rs index cc2b3e23e..18624345a 100644 --- a/triton-air/src/table/jump_stack.rs +++ b/triton-air/src/table/jump_stack.rs @@ -1,13 +1,13 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use isa::instruction::Instruction; use twenty_first::prelude::BFieldElement; @@ -20,35 +20,35 @@ use crate::challenge_id::ChallengeId::JumpStackJsoWeight; use crate::challenge_id::ChallengeId::JumpStackJspWeight; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::JumpStackBaseTableColumn::CI; -use crate::table_column::JumpStackBaseTableColumn::CLK; -use crate::table_column::JumpStackBaseTableColumn::JSD; -use crate::table_column::JumpStackBaseTableColumn::JSO; -use crate::table_column::JumpStackBaseTableColumn::JSP; -use crate::table_column::JumpStackExtTableColumn::ClockJumpDifferenceLookupClientLogDerivative; -use crate::table_column::JumpStackExtTableColumn::RunningProductPermArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::JumpStackAuxColumn::ClockJumpDifferenceLookupClientLogDerivative; +use crate::table_column::JumpStackAuxColumn::RunningProductPermArg; +use crate::table_column::JumpStackMainColumn::CI; +use crate::table_column::JumpStackMainColumn::CLK; +use crate::table_column::JumpStackMainColumn::JSD; +use crate::table_column::JumpStackMainColumn::JSO; +use crate::table_column::JumpStackMainColumn::JSP; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct JumpStackTable; impl AIR for JumpStackTable { - type MainColumn = crate::table_column::JumpStackBaseTableColumn; - type AuxColumn = crate::table_column::JumpStackExtTableColumn; + type MainColumn = crate::table_column::JumpStackMainColumn; + type AuxColumn = crate::table_column::JumpStackAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let clk = circuit_builder.input(BaseRow(CLK.master_base_table_index())); - let jsp = circuit_builder.input(BaseRow(JSP.master_base_table_index())); - let jso = circuit_builder.input(BaseRow(JSO.master_base_table_index())); - let jsd = circuit_builder.input(BaseRow(JSD.master_base_table_index())); - let ci = circuit_builder.input(BaseRow(CI.master_base_table_index())); - let rppa = circuit_builder.input(ExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative = circuit_builder.input(ExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + let clk = circuit_builder.input(Main(CLK.master_main_index())); + let jsp = circuit_builder.input(Main(JSP.master_main_index())); + let jso = circuit_builder.input(Main(JSO.master_main_index())); + let jsd = circuit_builder.input(Main(JSD.master_main_index())); + let ci = circuit_builder.input(Main(CI.master_main_index())); + let rppa = circuit_builder.input(Aux(RunningProductPermArg.master_aux_index())); + let clock_jump_diff_log_derivative = circuit_builder.input(Aux( + ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(), )); let processor_perm_indeterminate = circuit_builder.challenge(JumpStackIndeterminate); @@ -87,27 +87,24 @@ impl AIR for JumpStackTable { let recurse_or_return_opcode = circuit_builder.b_constant(Instruction::RecurseOrReturn.opcode_b()); - let clk = circuit_builder.input(CurrentBaseRow(CLK.master_base_table_index())); - let ci = circuit_builder.input(CurrentBaseRow(CI.master_base_table_index())); - let jsp = circuit_builder.input(CurrentBaseRow(JSP.master_base_table_index())); - let jso = circuit_builder.input(CurrentBaseRow(JSO.master_base_table_index())); - let jsd = circuit_builder.input(CurrentBaseRow(JSD.master_base_table_index())); - let rppa = circuit_builder.input(CurrentExtRow( - RunningProductPermArg.master_ext_table_index(), - )); - let clock_jump_diff_log_derivative = circuit_builder.input(CurrentExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + let clk = circuit_builder.input(CurrentMain(CLK.master_main_index())); + let ci = circuit_builder.input(CurrentMain(CI.master_main_index())); + let jsp = circuit_builder.input(CurrentMain(JSP.master_main_index())); + let jso = circuit_builder.input(CurrentMain(JSO.master_main_index())); + let jsd = circuit_builder.input(CurrentMain(JSD.master_main_index())); + let rppa = circuit_builder.input(CurrentAux(RunningProductPermArg.master_aux_index())); + let clock_jump_diff_log_derivative = circuit_builder.input(CurrentAux( + ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(), )); - let clk_next = circuit_builder.input(NextBaseRow(CLK.master_base_table_index())); - let ci_next = circuit_builder.input(NextBaseRow(CI.master_base_table_index())); - let jsp_next = circuit_builder.input(NextBaseRow(JSP.master_base_table_index())); - let jso_next = circuit_builder.input(NextBaseRow(JSO.master_base_table_index())); - let jsd_next = circuit_builder.input(NextBaseRow(JSD.master_base_table_index())); - let rppa_next = - circuit_builder.input(NextExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative_next = circuit_builder.input(NextExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), + let clk_next = circuit_builder.input(NextMain(CLK.master_main_index())); + let ci_next = circuit_builder.input(NextMain(CI.master_main_index())); + let jsp_next = circuit_builder.input(NextMain(JSP.master_main_index())); + let jso_next = circuit_builder.input(NextMain(JSO.master_main_index())); + let jsd_next = circuit_builder.input(NextMain(JSD.master_main_index())); + let rppa_next = circuit_builder.input(NextAux(RunningProductPermArg.master_aux_index())); + let clock_jump_diff_log_derivative_next = circuit_builder.input(NextAux( + ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(), )); let jsp_inc_or_stays = diff --git a/triton-air/src/table/lookup.rs b/triton-air/src/table/lookup.rs index a296a4ede..7706b1f5d 100644 --- a/triton-air/src/table/lookup.rs +++ b/triton-air/src/table/lookup.rs @@ -1,13 +1,13 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use crate::challenge_id::ChallengeId; use crate::challenge_id::ChallengeId::CascadeLookupIndeterminate; @@ -18,34 +18,32 @@ use crate::challenge_id::ChallengeId::LookupTablePublicTerminal; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::EvalArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::LookupBaseTableColumn; -use crate::table_column::LookupBaseTableColumn::IsPadding; -use crate::table_column::LookupBaseTableColumn::LookIn; -use crate::table_column::LookupBaseTableColumn::LookOut; -use crate::table_column::LookupBaseTableColumn::LookupMultiplicity; -use crate::table_column::LookupExtTableColumn; -use crate::table_column::LookupExtTableColumn::CascadeTableServerLogDerivative; -use crate::table_column::LookupExtTableColumn::PublicEvaluationArgument; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::LookupAuxColumn; +use crate::table_column::LookupAuxColumn::CascadeTableServerLogDerivative; +use crate::table_column::LookupAuxColumn::PublicEvaluationArgument; +use crate::table_column::LookupMainColumn; +use crate::table_column::LookupMainColumn::IsPadding; +use crate::table_column::LookupMainColumn::LookIn; +use crate::table_column::LookupMainColumn::LookOut; +use crate::table_column::LookupMainColumn::LookupMultiplicity; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct LookupTable; impl AIR for LookupTable { - type MainColumn = LookupBaseTableColumn; - type AuxColumn = LookupExtTableColumn; + type MainColumn = LookupMainColumn; + type AuxColumn = LookupAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = |col_id: Self::MainColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; - let aux_row = |col_id: Self::AuxColumn| { - circuit_builder.input(ExtRow(col_id.master_ext_table_index())) - }; + let main_row = + |col_id: Self::MainColumn| circuit_builder.input(Main(col_id.master_main_index())); + let aux_row = + |col_id: Self::AuxColumn| circuit_builder.input(Aux(col_id.master_aux_index())); let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); let lookup_input = main_row(LookIn); @@ -85,9 +83,8 @@ impl AIR for LookupTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let main_row = |col_id: Self::MainColumn| { - circuit_builder.input(BaseRow(col_id.master_base_table_index())) - }; + let main_row = + |col_id: Self::MainColumn| circuit_builder.input(Main(col_id.master_main_index())); let padding_is_0_or_1 = main_row(IsPadding) * (constant(1) - main_row(IsPadding)); @@ -100,17 +97,14 @@ impl AIR for LookupTable { let one = || circuit_builder.b_constant(1); let current_main_row = |col_id: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(col_id.master_base_table_index())) - }; - let next_main_row = |col_id: Self::MainColumn| { - circuit_builder.input(NextBaseRow(col_id.master_base_table_index())) - }; - let current_aux_row = |col_id: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(col_id.master_ext_table_index())) - }; - let next_aux_row = |col_id: Self::AuxColumn| { - circuit_builder.input(NextExtRow(col_id.master_ext_table_index())) + circuit_builder.input(CurrentMain(col_id.master_main_index())) }; + let next_main_row = + |col_id: Self::MainColumn| circuit_builder.input(NextMain(col_id.master_main_index())); + let current_aux_row = + |col_id: Self::AuxColumn| circuit_builder.input(CurrentAux(col_id.master_aux_index())); + let next_aux_row = + |col_id: Self::AuxColumn| circuit_builder.input(NextAux(col_id.master_aux_index())); let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); let lookup_input = current_main_row(LookIn); @@ -177,9 +171,8 @@ impl AIR for LookupTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id); - let aux_row = |col_id: Self::AuxColumn| { - circuit_builder.input(ExtRow(col_id.master_ext_table_index())) - }; + let aux_row = + |col_id: Self::AuxColumn| circuit_builder.input(Aux(col_id.master_aux_index())); let narrow_table_terminal_matches_user_supplied_terminal = aux_row(PublicEvaluationArgument) - challenge(LookupTablePublicTerminal); diff --git a/triton-air/src/table/op_stack.rs b/triton-air/src/table/op_stack.rs index 6a822cc08..cc417af37 100644 --- a/triton-air/src/table/op_stack.rs +++ b/triton-air/src/table/op_stack.rs @@ -1,13 +1,13 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use isa::op_stack::OpStackElement; use strum::EnumCount; use twenty_first::prelude::*; @@ -16,8 +16,8 @@ use crate::challenge_id::ChallengeId; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::LookupArg; use crate::cross_table_argument::PermArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; /// The value indicating a padding row in the op stack table. Stored in the @@ -28,8 +28,8 @@ pub const PADDING_VALUE: BFieldElement = BFieldElement::new(2); pub struct OpStackTable; impl AIR for OpStackTable { - type MainColumn = crate::table_column::OpStackBaseTableColumn; - type AuxColumn = crate::table_column::OpStackExtTableColumn; + type MainColumn = crate::table_column::OpStackMainColumn; + type AuxColumn = crate::table_column::OpStackAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, @@ -37,12 +37,10 @@ impl AIR for OpStackTable { let challenge = |c| circuit_builder.challenge(c); let constant = |c| circuit_builder.b_constant(c); let x_constant = |c| circuit_builder.x_constant(c); - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let aux_row = |column: Self::AuxColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); + let aux_row = + |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index())); let initial_stack_length = u32::try_from(OpStackElement::COUNT).unwrap(); let initial_stack_length = constant(initial_stack_length.into()); @@ -97,17 +95,14 @@ impl AIR for OpStackTable { let constant = |c| circuit_builder.b_constant(c); let challenge = |c| circuit_builder.challenge(c); let current_main_row = |column: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) - }; - let current_aux_row = |column: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) - }; - let next_main_row = |column: Self::MainColumn| { - circuit_builder.input(NextBaseRow(column.master_base_table_index())) - }; - let next_aux_row = |column: Self::AuxColumn| { - circuit_builder.input(NextExtRow(column.master_ext_table_index())) + circuit_builder.input(CurrentMain(column.master_main_index())) }; + let current_aux_row = + |column: Self::AuxColumn| circuit_builder.input(CurrentAux(column.master_aux_index())); + let next_main_row = + |column: Self::MainColumn| circuit_builder.input(NextMain(column.master_main_index())); + let next_aux_row = + |column: Self::AuxColumn| circuit_builder.input(NextAux(column.master_aux_index())); let one = constant(1_u32.into()); let padding_indicator = constant(PADDING_VALUE); diff --git a/triton-air/src/table/processor.rs b/triton-air/src/table/processor.rs index d319d72d5..d3fa080cc 100644 --- a/triton-air/src/table/processor.rs +++ b/triton-air/src/table/processor.rs @@ -4,14 +4,14 @@ use std::ops::Mul; use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::InputIndicator; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use isa::instruction::Instruction; use isa::instruction::InstructionBit; use isa::instruction::ALL_INSTRUCTIONS; @@ -80,60 +80,60 @@ use crate::cross_table_argument::EvalArg; use crate::cross_table_argument::LookupArg; use crate::cross_table_argument::PermArg; use crate::table; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; -use crate::table_column::ProcessorBaseTableColumn; -use crate::table_column::ProcessorBaseTableColumn::ClockJumpDifferenceLookupMultiplicity; -use crate::table_column::ProcessorBaseTableColumn::IsPadding; -use crate::table_column::ProcessorBaseTableColumn::OpStackPointer; -use crate::table_column::ProcessorBaseTableColumn::CI; -use crate::table_column::ProcessorBaseTableColumn::CLK; -use crate::table_column::ProcessorBaseTableColumn::HV0; -use crate::table_column::ProcessorBaseTableColumn::HV1; -use crate::table_column::ProcessorBaseTableColumn::HV2; -use crate::table_column::ProcessorBaseTableColumn::HV3; -use crate::table_column::ProcessorBaseTableColumn::HV4; -use crate::table_column::ProcessorBaseTableColumn::HV5; -use crate::table_column::ProcessorBaseTableColumn::IB0; -use crate::table_column::ProcessorBaseTableColumn::IB1; -use crate::table_column::ProcessorBaseTableColumn::IB2; -use crate::table_column::ProcessorBaseTableColumn::IB3; -use crate::table_column::ProcessorBaseTableColumn::IB4; -use crate::table_column::ProcessorBaseTableColumn::IB5; -use crate::table_column::ProcessorBaseTableColumn::IB6; -use crate::table_column::ProcessorBaseTableColumn::IP; -use crate::table_column::ProcessorBaseTableColumn::JSD; -use crate::table_column::ProcessorBaseTableColumn::JSO; -use crate::table_column::ProcessorBaseTableColumn::JSP; -use crate::table_column::ProcessorBaseTableColumn::NIA; -use crate::table_column::ProcessorBaseTableColumn::ST0; -use crate::table_column::ProcessorBaseTableColumn::ST1; -use crate::table_column::ProcessorBaseTableColumn::ST10; -use crate::table_column::ProcessorBaseTableColumn::ST11; -use crate::table_column::ProcessorBaseTableColumn::ST12; -use crate::table_column::ProcessorBaseTableColumn::ST13; -use crate::table_column::ProcessorBaseTableColumn::ST14; -use crate::table_column::ProcessorBaseTableColumn::ST15; -use crate::table_column::ProcessorBaseTableColumn::ST2; -use crate::table_column::ProcessorBaseTableColumn::ST3; -use crate::table_column::ProcessorBaseTableColumn::ST4; -use crate::table_column::ProcessorBaseTableColumn::ST5; -use crate::table_column::ProcessorBaseTableColumn::ST6; -use crate::table_column::ProcessorBaseTableColumn::ST7; -use crate::table_column::ProcessorBaseTableColumn::ST8; -use crate::table_column::ProcessorBaseTableColumn::ST9; -use crate::table_column::ProcessorExtTableColumn; -use crate::table_column::ProcessorExtTableColumn::ClockJumpDifferenceLookupServerLogDerivative; -use crate::table_column::ProcessorExtTableColumn::HashDigestEvalArg; -use crate::table_column::ProcessorExtTableColumn::HashInputEvalArg; -use crate::table_column::ProcessorExtTableColumn::InputTableEvalArg; -use crate::table_column::ProcessorExtTableColumn::InstructionLookupClientLogDerivative; -use crate::table_column::ProcessorExtTableColumn::JumpStackTablePermArg; -use crate::table_column::ProcessorExtTableColumn::OpStackTablePermArg; -use crate::table_column::ProcessorExtTableColumn::OutputTableEvalArg; -use crate::table_column::ProcessorExtTableColumn::RamTablePermArg; -use crate::table_column::ProcessorExtTableColumn::SpongeEvalArg; -use crate::table_column::ProcessorExtTableColumn::U32LookupClientLogDerivative; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; +use crate::table_column::ProcessorAuxColumn; +use crate::table_column::ProcessorAuxColumn::ClockJumpDifferenceLookupServerLogDerivative; +use crate::table_column::ProcessorAuxColumn::HashDigestEvalArg; +use crate::table_column::ProcessorAuxColumn::HashInputEvalArg; +use crate::table_column::ProcessorAuxColumn::InputTableEvalArg; +use crate::table_column::ProcessorAuxColumn::InstructionLookupClientLogDerivative; +use crate::table_column::ProcessorAuxColumn::JumpStackTablePermArg; +use crate::table_column::ProcessorAuxColumn::OpStackTablePermArg; +use crate::table_column::ProcessorAuxColumn::OutputTableEvalArg; +use crate::table_column::ProcessorAuxColumn::RamTablePermArg; +use crate::table_column::ProcessorAuxColumn::SpongeEvalArg; +use crate::table_column::ProcessorAuxColumn::U32LookupClientLogDerivative; +use crate::table_column::ProcessorMainColumn; +use crate::table_column::ProcessorMainColumn::ClockJumpDifferenceLookupMultiplicity; +use crate::table_column::ProcessorMainColumn::IsPadding; +use crate::table_column::ProcessorMainColumn::OpStackPointer; +use crate::table_column::ProcessorMainColumn::CI; +use crate::table_column::ProcessorMainColumn::CLK; +use crate::table_column::ProcessorMainColumn::HV0; +use crate::table_column::ProcessorMainColumn::HV1; +use crate::table_column::ProcessorMainColumn::HV2; +use crate::table_column::ProcessorMainColumn::HV3; +use crate::table_column::ProcessorMainColumn::HV4; +use crate::table_column::ProcessorMainColumn::HV5; +use crate::table_column::ProcessorMainColumn::IB0; +use crate::table_column::ProcessorMainColumn::IB1; +use crate::table_column::ProcessorMainColumn::IB2; +use crate::table_column::ProcessorMainColumn::IB3; +use crate::table_column::ProcessorMainColumn::IB4; +use crate::table_column::ProcessorMainColumn::IB5; +use crate::table_column::ProcessorMainColumn::IB6; +use crate::table_column::ProcessorMainColumn::IP; +use crate::table_column::ProcessorMainColumn::JSD; +use crate::table_column::ProcessorMainColumn::JSO; +use crate::table_column::ProcessorMainColumn::JSP; +use crate::table_column::ProcessorMainColumn::NIA; +use crate::table_column::ProcessorMainColumn::ST0; +use crate::table_column::ProcessorMainColumn::ST1; +use crate::table_column::ProcessorMainColumn::ST10; +use crate::table_column::ProcessorMainColumn::ST11; +use crate::table_column::ProcessorMainColumn::ST12; +use crate::table_column::ProcessorMainColumn::ST13; +use crate::table_column::ProcessorMainColumn::ST14; +use crate::table_column::ProcessorMainColumn::ST15; +use crate::table_column::ProcessorMainColumn::ST2; +use crate::table_column::ProcessorMainColumn::ST3; +use crate::table_column::ProcessorMainColumn::ST4; +use crate::table_column::ProcessorMainColumn::ST5; +use crate::table_column::ProcessorMainColumn::ST6; +use crate::table_column::ProcessorMainColumn::ST7; +use crate::table_column::ProcessorMainColumn::ST8; +use crate::table_column::ProcessorMainColumn::ST9; use crate::AIR; /// The number of helper variable registers @@ -146,7 +146,7 @@ impl ProcessorTable { /// # Panics /// /// Panics if the index is out of bounds. - pub fn op_stack_column_by_index(index: usize) -> ProcessorBaseTableColumn { + pub fn op_stack_column_by_index(index: usize) -> ProcessorMainColumn { assert!( index < OpStackElement::COUNT, "Op Stack column index must be in [0, 15], not {index}" @@ -175,8 +175,8 @@ impl ProcessorTable { } impl AIR for ProcessorTable { - type MainColumn = crate::table_column::ProcessorBaseTableColumn; - type AuxColumn = crate::table_column::ProcessorExtTableColumn; + type MainColumn = crate::table_column::ProcessorMainColumn; + type AuxColumn = crate::table_column::ProcessorAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, @@ -184,39 +184,36 @@ impl AIR for ProcessorTable { let constant = |c: u32| circuit_builder.b_constant(c); let x_constant = |x| circuit_builder.x_constant(x); let challenge = |c| circuit_builder.challenge(c); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(ExtRow(col.master_ext_table_index())) - }; - - let clk_is_0 = base_row(CLK); - let ip_is_0 = base_row(IP); - let jsp_is_0 = base_row(JSP); - let jso_is_0 = base_row(JSO); - let jsd_is_0 = base_row(JSD); - let st0_is_0 = base_row(ST0); - let st1_is_0 = base_row(ST1); - let st2_is_0 = base_row(ST2); - let st3_is_0 = base_row(ST3); - let st4_is_0 = base_row(ST4); - let st5_is_0 = base_row(ST5); - let st6_is_0 = base_row(ST6); - let st7_is_0 = base_row(ST7); - let st8_is_0 = base_row(ST8); - let st9_is_0 = base_row(ST9); - let st10_is_0 = base_row(ST10); - let op_stack_pointer_is_16 = base_row(OpStackPointer) - constant(16); + let main_row = + |col: ProcessorMainColumn| circuit_builder.input(Main(col.master_main_index())); + let aux_row = |col: ProcessorAuxColumn| circuit_builder.input(Aux(col.master_aux_index())); + + let clk_is_0 = main_row(CLK); + let ip_is_0 = main_row(IP); + let jsp_is_0 = main_row(JSP); + let jso_is_0 = main_row(JSO); + let jsd_is_0 = main_row(JSD); + let st0_is_0 = main_row(ST0); + let st1_is_0 = main_row(ST1); + let st2_is_0 = main_row(ST2); + let st3_is_0 = main_row(ST3); + let st4_is_0 = main_row(ST4); + let st5_is_0 = main_row(ST5); + let st6_is_0 = main_row(ST6); + let st7_is_0 = main_row(ST7); + let st8_is_0 = main_row(ST8); + let st9_is_0 = main_row(ST9); + let st10_is_0 = main_row(ST10); + let op_stack_pointer_is_16 = main_row(OpStackPointer) - constant(16); // Compress the program digest using an Evaluation Argument. // Lowest index in the digest corresponds to lowest index on the stack. let program_digest: [_; Digest::LEN] = [ - base_row(ST11), - base_row(ST12), - base_row(ST13), - base_row(ST14), - base_row(ST15), + main_row(ST11), + main_row(ST12), + main_row(ST13), + main_row(ST14), + main_row(ST15), ]; let compressed_program_digest = program_digest.into_iter().fold( circuit_builder.x_constant(EvalArg::default_initial()), @@ -231,78 +228,78 @@ impl AIR for ProcessorTable { // standard input let running_evaluation_for_standard_input_is_initialized_correctly = - ext_row(InputTableEvalArg) - x_constant(EvalArg::default_initial()); + aux_row(InputTableEvalArg) - x_constant(EvalArg::default_initial()); // program table let instruction_lookup_indeterminate = challenge(InstructionLookupIndeterminate); let instruction_ci_weight = challenge(ProgramInstructionWeight); let instruction_nia_weight = challenge(ProgramNextInstructionWeight); let compressed_row_for_instruction_lookup = - instruction_ci_weight * base_row(CI) + instruction_nia_weight * base_row(NIA); + instruction_ci_weight * main_row(CI) + instruction_nia_weight * main_row(NIA); let instruction_lookup_log_derivative_is_initialized_correctly = - (ext_row(InstructionLookupClientLogDerivative) + (aux_row(InstructionLookupClientLogDerivative) - x_constant(LookupArg::default_initial())) * (instruction_lookup_indeterminate - compressed_row_for_instruction_lookup) - constant(1); // standard output let running_evaluation_for_standard_output_is_initialized_correctly = - ext_row(OutputTableEvalArg) - x_constant(EvalArg::default_initial()); + aux_row(OutputTableEvalArg) - x_constant(EvalArg::default_initial()); let running_product_for_op_stack_table_is_initialized_correctly = - ext_row(OpStackTablePermArg) - x_constant(PermArg::default_initial()); + aux_row(OpStackTablePermArg) - x_constant(PermArg::default_initial()); // ram table let running_product_for_ram_table_is_initialized_correctly = - ext_row(RamTablePermArg) - x_constant(PermArg::default_initial()); + aux_row(RamTablePermArg) - x_constant(PermArg::default_initial()); // jump-stack table let jump_stack_indeterminate = challenge(JumpStackIndeterminate); let jump_stack_ci_weight = challenge(JumpStackCiWeight); // note: `clk`, `jsp`, `jso`, and `jsd` are already constrained to be 0. - let compressed_row_for_jump_stack_table = jump_stack_ci_weight * base_row(CI); + let compressed_row_for_jump_stack_table = jump_stack_ci_weight * main_row(CI); let running_product_for_jump_stack_table_is_initialized_correctly = - ext_row(JumpStackTablePermArg) + aux_row(JumpStackTablePermArg) - x_constant(PermArg::default_initial()) * (jump_stack_indeterminate - compressed_row_for_jump_stack_table); // clock jump difference lookup argument // The clock jump difference logarithmic derivative accumulator starts // off having accumulated the contribution from the first row. - // Note that (challenge(ClockJumpDifferenceLookupIndeterminate) - base_row(CLK)) + // Note that (challenge(ClockJumpDifferenceLookupIndeterminate) - main_row(CLK)) // collapses to challenge(ClockJumpDifferenceLookupIndeterminate) - // because base_row(CLK) = 0 is already a constraint. + // because main_row(CLK) = 0 is already a constraint. let clock_jump_diff_lookup_log_derivative_is_initialized_correctly = - ext_row(ClockJumpDifferenceLookupServerLogDerivative) + aux_row(ClockJumpDifferenceLookupServerLogDerivative) * challenge(ClockJumpDifferenceLookupIndeterminate) - - base_row(ClockJumpDifferenceLookupMultiplicity); + - main_row(ClockJumpDifferenceLookupMultiplicity); // from processor to hash table - let hash_selector = base_row(CI) - constant(Instruction::Hash.opcode()); + let hash_selector = main_row(CI) - constant(Instruction::Hash.opcode()); let hash_deselector = instruction_deselector_single_row(circuit_builder, Instruction::Hash); let hash_input_indeterminate = challenge(HashInputIndeterminate); // the opStack is guaranteed to be initialized to 0 by virtue of other initial constraints let compressed_row = constant(0); - let running_evaluation_hash_input_has_absorbed_first_row = ext_row(HashInputEvalArg) + let running_evaluation_hash_input_has_absorbed_first_row = aux_row(HashInputEvalArg) - hash_input_indeterminate * x_constant(EvalArg::default_initial()) - compressed_row; let running_evaluation_hash_input_is_default_initial = - ext_row(HashInputEvalArg) - x_constant(EvalArg::default_initial()); + aux_row(HashInputEvalArg) - x_constant(EvalArg::default_initial()); let running_evaluation_hash_input_is_initialized_correctly = hash_selector * running_evaluation_hash_input_is_default_initial + hash_deselector * running_evaluation_hash_input_has_absorbed_first_row; // from hash table to processor let running_evaluation_hash_digest_is_initialized_correctly = - ext_row(HashDigestEvalArg) - x_constant(EvalArg::default_initial()); + aux_row(HashDigestEvalArg) - x_constant(EvalArg::default_initial()); // Hash Table – Sponge let running_evaluation_sponge_absorb_is_initialized_correctly = - ext_row(SpongeEvalArg) - x_constant(EvalArg::default_initial()); + aux_row(SpongeEvalArg) - x_constant(EvalArg::default_initial()); // u32 table let running_sum_log_derivative_for_u32_table_is_initialized_correctly = - ext_row(U32LookupClientLogDerivative) - x_constant(LookupArg::default_initial()); + aux_row(U32LookupClientLogDerivative) - x_constant(LookupArg::default_initial()); vec![ clk_is_0, @@ -341,36 +338,35 @@ impl AIR for ProcessorTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; + let main_row = + |col: ProcessorMainColumn| circuit_builder.input(Main(col.master_main_index())); // The composition of instruction bits ib0-ib7 corresponds the current instruction ci. - let ib_composition = base_row(IB0) - + constant(1 << 1) * base_row(IB1) - + constant(1 << 2) * base_row(IB2) - + constant(1 << 3) * base_row(IB3) - + constant(1 << 4) * base_row(IB4) - + constant(1 << 5) * base_row(IB5) - + constant(1 << 6) * base_row(IB6); - let ci_corresponds_to_ib0_thru_ib7 = base_row(CI) - ib_composition; - - let ib0_is_bit = base_row(IB0) * (base_row(IB0) - constant(1)); - let ib1_is_bit = base_row(IB1) * (base_row(IB1) - constant(1)); - let ib2_is_bit = base_row(IB2) * (base_row(IB2) - constant(1)); - let ib3_is_bit = base_row(IB3) * (base_row(IB3) - constant(1)); - let ib4_is_bit = base_row(IB4) * (base_row(IB4) - constant(1)); - let ib5_is_bit = base_row(IB5) * (base_row(IB5) - constant(1)); - let ib6_is_bit = base_row(IB6) * (base_row(IB6) - constant(1)); - let is_padding_is_bit = base_row(IsPadding) * (base_row(IsPadding) - constant(1)); + let ib_composition = main_row(IB0) + + constant(1 << 1) * main_row(IB1) + + constant(1 << 2) * main_row(IB2) + + constant(1 << 3) * main_row(IB3) + + constant(1 << 4) * main_row(IB4) + + constant(1 << 5) * main_row(IB5) + + constant(1 << 6) * main_row(IB6); + let ci_corresponds_to_ib0_thru_ib7 = main_row(CI) - ib_composition; + + let ib0_is_bit = main_row(IB0) * (main_row(IB0) - constant(1)); + let ib1_is_bit = main_row(IB1) * (main_row(IB1) - constant(1)); + let ib2_is_bit = main_row(IB2) * (main_row(IB2) - constant(1)); + let ib3_is_bit = main_row(IB3) * (main_row(IB3) - constant(1)); + let ib4_is_bit = main_row(IB4) * (main_row(IB4) - constant(1)); + let ib5_is_bit = main_row(IB5) * (main_row(IB5) - constant(1)); + let ib6_is_bit = main_row(IB6) * (main_row(IB6) - constant(1)); + let is_padding_is_bit = main_row(IsPadding) * (main_row(IsPadding) - constant(1)); // In padding rows, the clock jump difference lookup multiplicity is 0. The one row // exempt from this rule is the row wth CLK == 1: since the memory-like tables don't have // an “awareness” of padding rows, they keep looking up clock jump differences of // magnitude 1. - let clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows = base_row(IsPadding) - * (base_row(CLK) - constant(1)) - * base_row(ClockJumpDifferenceLookupMultiplicity); + let clock_jump_diff_lookup_multiplicity_is_0_in_padding_rows = main_row(IsPadding) + * (main_row(CLK) - constant(1)) + * main_row(ClockJumpDifferenceLookupMultiplicity); vec![ ib0_is_bit, @@ -390,17 +386,15 @@ impl AIR for ProcessorTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // constraints common to all instructions - let clk_increases_by_1 = next_base_row(CLK) - curr_base_row(CLK) - constant(1); + let clk_increases_by_1 = next_main_row(CLK) - curr_main_row(CLK) - constant(1); let is_padding_is_0_or_does_not_change = - curr_base_row(IsPadding) * (next_base_row(IsPadding) - curr_base_row(IsPadding)); + curr_main_row(IsPadding) * (next_main_row(IsPadding) - curr_main_row(IsPadding)); let instruction_independent_constraints = vec![clk_increases_by_1, is_padding_is_0_or_does_not_change]; @@ -443,8 +437,7 @@ impl AIR for ProcessorTable { fn terminal_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = - |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + let main_row = |col: Self::MainColumn| circuit_builder.input(Main(col.master_main_index())); let constant = |c| circuit_builder.b_constant(c); let last_ci_is_halt = @@ -516,18 +509,16 @@ fn combine_transition_constraints_with_padding_constraints( instruction_transition_constraints: Vec>, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let padding_row_transition_constraints = [ vec![ - next_base_row(IP) - curr_base_row(IP), - next_base_row(CI) - curr_base_row(CI), - next_base_row(NIA) - curr_base_row(NIA), + next_main_row(IP) - curr_main_row(IP), + next_main_row(CI) - curr_main_row(CI), + next_main_row(NIA) - curr_main_row(NIA), ], instruction_group_keep_jump_stack(circuit_builder), instruction_group_keep_op_stack(circuit_builder), @@ -536,8 +527,8 @@ fn combine_transition_constraints_with_padding_constraints( ] .concat(); - let padding_row_deselector = constant(1) - next_base_row(IsPadding); - let padding_row_selector = next_base_row(IsPadding); + let padding_row_deselector = constant(1) - next_main_row(IsPadding); + let padding_row_selector = next_main_row(IsPadding); let max_number_of_constraints = max( instruction_transition_constraints.len(), @@ -565,20 +556,19 @@ fn instruction_group_decompose_arg( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); - let hv0_is_a_bit = curr_base_row(HV0) * (curr_base_row(HV0) - constant(1)); - let hv1_is_a_bit = curr_base_row(HV1) * (curr_base_row(HV1) - constant(1)); - let hv2_is_a_bit = curr_base_row(HV2) * (curr_base_row(HV2) - constant(1)); - let hv3_is_a_bit = curr_base_row(HV3) * (curr_base_row(HV3) - constant(1)); + let hv0_is_a_bit = curr_main_row(HV0) * (curr_main_row(HV0) - constant(1)); + let hv1_is_a_bit = curr_main_row(HV1) * (curr_main_row(HV1) - constant(1)); + let hv2_is_a_bit = curr_main_row(HV2) * (curr_main_row(HV2) - constant(1)); + let hv3_is_a_bit = curr_main_row(HV3) * (curr_main_row(HV3) - constant(1)); - let helper_variables_are_binary_decomposition_of_nia = curr_base_row(NIA) - - constant(8) * curr_base_row(HV3) - - constant(4) * curr_base_row(HV2) - - constant(2) * curr_base_row(HV1) - - curr_base_row(HV0); + let helper_variables_are_binary_decomposition_of_nia = curr_main_row(NIA) + - constant(8) * curr_main_row(HV3) + - constant(4) * curr_main_row(HV2) + - constant(2) * curr_main_row(HV1) + - curr_main_row(HV0); vec![ hv0_is_a_bit, @@ -594,14 +584,12 @@ fn instruction_group_decompose_arg( fn instruction_group_no_ram( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); - vec![next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg)] + vec![next_aux_row(RamTablePermArg) - curr_aux_row(RamTablePermArg)] } fn instruction_group_no_io( @@ -621,12 +609,10 @@ fn instruction_group_op_stack_remains_except_top_n( ) -> Vec> { assert!(n <= NUM_OP_STACK_REGISTERS); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let stack = (0..OpStackElement::COUNT) .map(ProcessorTable::op_stack_column_by_index) @@ -667,15 +653,14 @@ fn instruction_group_keep_op_stack_height( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let op_stack_pointer_curr = - circuit_builder.input(CurrentBaseRow(OpStackPointer.master_base_table_index())); - let op_stack_pointer_next = - circuit_builder.input(NextBaseRow(OpStackPointer.master_base_table_index())); + circuit_builder.input(CurrentMain(OpStackPointer.master_main_index())); + let op_stack_pointer_next = circuit_builder.input(NextMain(OpStackPointer.master_main_index())); let osp_remains_unchanged = op_stack_pointer_next - op_stack_pointer_curr; let op_stack_table_perm_arg_curr = - circuit_builder.input(CurrentExtRow(OpStackTablePermArg.master_ext_table_index())); + circuit_builder.input(CurrentAux(OpStackTablePermArg.master_aux_index())); let op_stack_table_perm_arg_next = - circuit_builder.input(NextExtRow(OpStackTablePermArg.master_ext_table_index())); + circuit_builder.input(NextAux(OpStackTablePermArg.master_aux_index())); let perm_arg_remains_unchanged = op_stack_table_perm_arg_next - op_stack_table_perm_arg_curr; vec![osp_remains_unchanged, perm_arg_remains_unchanged] @@ -685,29 +670,27 @@ fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); vec![ - next_base_row(ST2) - curr_base_row(ST1), - next_base_row(ST3) - curr_base_row(ST2), - next_base_row(ST4) - curr_base_row(ST3), - next_base_row(ST5) - curr_base_row(ST4), - next_base_row(ST6) - curr_base_row(ST5), - next_base_row(ST7) - curr_base_row(ST6), - next_base_row(ST8) - curr_base_row(ST7), - next_base_row(ST9) - curr_base_row(ST8), - next_base_row(ST10) - curr_base_row(ST9), - next_base_row(ST11) - curr_base_row(ST10), - next_base_row(ST12) - curr_base_row(ST11), - next_base_row(ST13) - curr_base_row(ST12), - next_base_row(ST14) - curr_base_row(ST13), - next_base_row(ST15) - curr_base_row(ST14), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(1), + next_main_row(ST2) - curr_main_row(ST1), + next_main_row(ST3) - curr_main_row(ST2), + next_main_row(ST4) - curr_main_row(ST3), + next_main_row(ST5) - curr_main_row(ST4), + next_main_row(ST6) - curr_main_row(ST5), + next_main_row(ST7) - curr_main_row(ST6), + next_main_row(ST8) - curr_main_row(ST7), + next_main_row(ST9) - curr_main_row(ST8), + next_main_row(ST10) - curr_main_row(ST9), + next_main_row(ST11) - curr_main_row(ST10), + next_main_row(ST12) - curr_main_row(ST11), + next_main_row(ST13) - curr_main_row(ST12), + next_main_row(ST14) - curr_main_row(ST13), + next_main_row(ST15) - curr_main_row(ST14), + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) - constant(1), running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, 1), ] } @@ -715,14 +698,12 @@ fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained( fn instruction_group_grow_op_stack( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST1) - curr_base_row(ST0)]; + let specific_constraints = vec![next_main_row(ST1) - curr_main_row(ST0)]; let inherited_constraints = instruction_group_grow_op_stack_and_top_two_elements_unconstrained(circuit_builder); @@ -733,27 +714,25 @@ fn instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); vec![ - next_base_row(ST3) - curr_base_row(ST4), - next_base_row(ST4) - curr_base_row(ST5), - next_base_row(ST5) - curr_base_row(ST6), - next_base_row(ST6) - curr_base_row(ST7), - next_base_row(ST7) - curr_base_row(ST8), - next_base_row(ST8) - curr_base_row(ST9), - next_base_row(ST9) - curr_base_row(ST10), - next_base_row(ST10) - curr_base_row(ST11), - next_base_row(ST11) - curr_base_row(ST12), - next_base_row(ST12) - curr_base_row(ST13), - next_base_row(ST13) - curr_base_row(ST14), - next_base_row(ST14) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(1), + next_main_row(ST3) - curr_main_row(ST4), + next_main_row(ST4) - curr_main_row(ST5), + next_main_row(ST5) - curr_main_row(ST6), + next_main_row(ST6) - curr_main_row(ST7), + next_main_row(ST7) - curr_main_row(ST8), + next_main_row(ST8) - curr_main_row(ST9), + next_main_row(ST9) - curr_main_row(ST10), + next_main_row(ST10) - curr_main_row(ST11), + next_main_row(ST11) - curr_main_row(ST12), + next_main_row(ST12) - curr_main_row(ST13), + next_main_row(ST13) - curr_main_row(ST14), + next_main_row(ST14) - curr_main_row(ST15), + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) + constant(1), running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 1), ] } @@ -761,16 +740,14 @@ fn instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained( fn instruction_group_binop( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let specific_constraints = vec![ - next_base_row(ST1) - curr_base_row(ST2), - next_base_row(ST2) - curr_base_row(ST3), + next_main_row(ST1) - curr_main_row(ST2), + next_main_row(ST2) - curr_main_row(ST3), ]; let inherited_constraints = instruction_group_op_stack_shrinks_and_top_three_elements_unconstrained(circuit_builder); @@ -781,14 +758,12 @@ fn instruction_group_binop( fn instruction_group_shrink_op_stack( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST1)]; + let specific_constraints = vec![next_main_row(ST0) - curr_main_row(ST1)]; let inherited_constraints = instruction_group_binop(circuit_builder); [specific_constraints, inherited_constraints].concat() @@ -797,16 +772,14 @@ fn instruction_group_shrink_op_stack( fn instruction_group_keep_jump_stack( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let jsp_does_not_change = next_base_row(JSP) - curr_base_row(JSP); - let jso_does_not_change = next_base_row(JSO) - curr_base_row(JSO); - let jsd_does_not_change = next_base_row(JSD) - curr_base_row(JSD); + let jsp_does_not_change = next_main_row(JSP) - curr_main_row(JSP); + let jso_does_not_change = next_main_row(JSO) - curr_main_row(JSO); + let jsd_does_not_change = next_main_row(JSD) - curr_main_row(JSD); vec![ jsp_does_not_change, @@ -820,14 +793,12 @@ fn instruction_group_step_1( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let instruction_pointer_increases_by_one = next_base_row(IP) - curr_base_row(IP) - constant(1); + let instruction_pointer_increases_by_one = next_main_row(IP) - curr_main_row(IP) - constant(1); [ instruction_group_keep_jump_stack(circuit_builder), vec![instruction_pointer_increases_by_one], @@ -840,14 +811,12 @@ fn instruction_group_step_2( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let instruction_pointer_increases_by_two = next_base_row(IP) - curr_base_row(IP) - constant(2); + let instruction_pointer_increases_by_two = next_main_row(IP) - curr_main_row(IP) - constant(2); [ instruction_group_keep_jump_stack(circuit_builder), vec![instruction_pointer_increases_by_two], @@ -888,18 +857,17 @@ fn instruction_deselector_current_row( circuit_builder: &ConstraintCircuitBuilder, instruction: Instruction, ) -> ConstraintCircuitMonad { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); let instruction_bit_polynomials = [ - curr_base_row(IB0), - curr_base_row(IB1), - curr_base_row(IB2), - curr_base_row(IB3), - curr_base_row(IB4), - curr_base_row(IB5), - curr_base_row(IB6), + curr_main_row(IB0), + curr_main_row(IB1), + curr_main_row(IB2), + curr_main_row(IB3), + curr_main_row(IB4), + curr_main_row(IB5), + curr_main_row(IB6), ]; instruction_deselector_common_functionality( @@ -915,18 +883,17 @@ fn instruction_deselector_next_row( circuit_builder: &ConstraintCircuitBuilder, instruction: Instruction, ) -> ConstraintCircuitMonad { - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let instruction_bit_polynomials = [ - next_base_row(IB0), - next_base_row(IB1), - next_base_row(IB2), - next_base_row(IB3), - next_base_row(IB4), - next_base_row(IB5), - next_base_row(IB6), + next_main_row(IB0), + next_main_row(IB1), + next_main_row(IB2), + next_main_row(IB3), + next_main_row(IB4), + next_main_row(IB5), + next_main_row(IB6), ]; instruction_deselector_common_functionality( @@ -942,18 +909,16 @@ fn instruction_deselector_single_row( circuit_builder: &ConstraintCircuitBuilder, instruction: Instruction, ) -> ConstraintCircuitMonad { - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; + let main_row = |col: ProcessorMainColumn| circuit_builder.input(Main(col.master_main_index())); let instruction_bit_polynomials = [ - base_row(IB0), - base_row(IB1), - base_row(IB2), - base_row(IB3), - base_row(IB4), - base_row(IB5), - base_row(IB6), + main_row(IB0), + main_row(IB1), + main_row(IB2), + main_row(IB3), + main_row(IB4), + main_row(IB5), + main_row(IB6), ]; instruction_deselector_common_functionality( @@ -980,14 +945,12 @@ fn instruction_pop( fn instruction_push( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(NIA)]; + let specific_constraints = vec![next_main_row(ST0) - curr_main_row(NIA)]; [ specific_constraints, instruction_group_grow_op_stack(circuit_builder), @@ -1016,12 +979,10 @@ fn instruction_dup( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let indicator_poly = |idx| indicator_polynomial(circuit_builder, idx); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let st_column = ProcessorTable::op_stack_column_by_index; let duplicate_element = |i| indicator_poly(i) * (next_row(ST0) - curr_row(st_column(i))); @@ -1041,12 +1002,10 @@ fn instruction_dup( fn instruction_swap( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let stack = (0..OpStackElement::COUNT) .map(ProcessorTable::op_stack_column_by_index) @@ -1102,24 +1061,22 @@ fn instruction_skiz( ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let hv0_is_inverse_of_st0 = curr_base_row(HV0) * curr_base_row(ST0) - one(); - let hv0_is_inverse_of_st0_or_hv0_is_0 = hv0_is_inverse_of_st0.clone() * curr_base_row(HV0); - let hv0_is_inverse_of_st0_or_st0_is_0 = hv0_is_inverse_of_st0 * curr_base_row(ST0); + let hv0_is_inverse_of_st0 = curr_main_row(HV0) * curr_main_row(ST0) - one(); + let hv0_is_inverse_of_st0_or_hv0_is_0 = hv0_is_inverse_of_st0.clone() * curr_main_row(HV0); + let hv0_is_inverse_of_st0_or_st0_is_0 = hv0_is_inverse_of_st0 * curr_main_row(ST0); // The next instruction nia is decomposed into helper variables hv. - let nia_decomposes_to_hvs = curr_base_row(NIA) - - curr_base_row(HV1) - - constant(1 << 1) * curr_base_row(HV2) - - constant(1 << 3) * curr_base_row(HV3) - - constant(1 << 5) * curr_base_row(HV4) - - constant(1 << 7) * curr_base_row(HV5); + let nia_decomposes_to_hvs = curr_main_row(NIA) + - curr_main_row(HV1) + - constant(1 << 1) * curr_main_row(HV2) + - constant(1 << 3) * curr_main_row(HV3) + - constant(1 << 5) * curr_main_row(HV4) + - constant(1 << 7) * curr_main_row(HV5); // If `st0` is non-zero, register `ip` is incremented by 1. // If `st0` is 0 and `nia` takes no argument, register `ip` is incremented by 2. @@ -1131,13 +1088,13 @@ fn instruction_skiz( // (Register `st0` is 0 or `ip` is incremented by 1), and // (`st0` has a multiplicative inverse or `hv1` is 1 or `ip` is incremented by 2), and // (`st0` has a multiplicative inverse or `hv1` is 0 or `ip` is incremented by 3). - let ip_case_1 = (next_base_row(IP) - curr_base_row(IP) - constant(1)) * curr_base_row(ST0); - let ip_case_2 = (next_base_row(IP) - curr_base_row(IP) - constant(2)) - * (curr_base_row(ST0) * curr_base_row(HV0) - one()) - * (curr_base_row(HV1) - one()); - let ip_case_3 = (next_base_row(IP) - curr_base_row(IP) - constant(3)) - * (curr_base_row(ST0) * curr_base_row(HV0) - one()) - * curr_base_row(HV1); + let ip_case_1 = (next_main_row(IP) - curr_main_row(IP) - constant(1)) * curr_main_row(ST0); + let ip_case_2 = (next_main_row(IP) - curr_main_row(IP) - constant(2)) + * (curr_main_row(ST0) * curr_main_row(HV0) - one()) + * (curr_main_row(HV1) - one()); + let ip_case_3 = (next_main_row(IP) - curr_main_row(IP) - constant(3)) + * (curr_main_row(ST0) * curr_main_row(HV0) - one()) + * curr_main_row(HV1); let ip_incr_by_1_or_2_or_3 = ip_case_1 + ip_case_2 + ip_case_3; let specific_constraints = vec![ @@ -1161,17 +1118,16 @@ fn next_instruction_range_check_constraints_for_instruction_skiz( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); let is_0_or_1 = - |var: ProcessorBaseTableColumn| curr_base_row(var) * (curr_base_row(var) - constant(1)); - let is_0_or_1_or_2_or_3 = |var: ProcessorBaseTableColumn| { - curr_base_row(var) - * (curr_base_row(var) - constant(1)) - * (curr_base_row(var) - constant(2)) - * (curr_base_row(var) - constant(3)) + |var: ProcessorMainColumn| curr_main_row(var) * (curr_main_row(var) - constant(1)); + let is_0_or_1_or_2_or_3 = |var: ProcessorMainColumn| { + curr_main_row(var) + * (curr_main_row(var) - constant(1)) + * (curr_main_row(var) - constant(2)) + * (curr_main_row(var) - constant(3)) }; vec![ @@ -1187,24 +1143,22 @@ fn instruction_call( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // The jump stack pointer jsp is incremented by 1. - let jsp_incr_1 = next_base_row(JSP) - curr_base_row(JSP) - constant(1); + let jsp_incr_1 = next_main_row(JSP) - curr_main_row(JSP) - constant(1); // The jump's origin jso is set to the current instruction pointer ip plus 2. - let jso_becomes_ip_plus_2 = next_base_row(JSO) - curr_base_row(IP) - constant(2); + let jso_becomes_ip_plus_2 = next_main_row(JSO) - curr_main_row(IP) - constant(2); // The jump's destination jsd is set to the instruction's argument. - let jsd_becomes_nia = next_base_row(JSD) - curr_base_row(NIA); + let jsd_becomes_nia = next_main_row(JSD) - curr_main_row(NIA); // The instruction pointer ip is set to the instruction's argument. - let ip_becomes_nia = next_base_row(IP) - curr_base_row(NIA); + let ip_becomes_nia = next_main_row(IP) - curr_main_row(NIA); let specific_constraints = vec![ jsp_incr_1, @@ -1225,15 +1179,13 @@ fn instruction_return( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let jsp_decrements_by_1 = next_base_row(JSP) - curr_base_row(JSP) + constant(1); - let ip_is_set_to_jso = next_base_row(IP) - curr_base_row(JSO); + let jsp_decrements_by_1 = next_main_row(JSP) - curr_main_row(JSP) + constant(1); + let ip_is_set_to_jso = next_main_row(IP) - curr_main_row(JSO); let specific_constraints = vec![jsp_decrements_by_1, ip_is_set_to_jso]; [ @@ -1248,15 +1200,13 @@ fn instruction_return( fn instruction_recurse( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // The instruction pointer ip is set to the last jump's destination jsd. - let ip_becomes_jsd = next_base_row(IP) - curr_base_row(JSD); + let ip_becomes_jsd = next_main_row(IP) - curr_main_row(JSD); let specific_constraints = vec![ip_becomes_jsd]; [ specific_constraints, @@ -1272,12 +1222,10 @@ fn instruction_recurse_or_return( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let one = || circuit_builder.b_constant(1); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // Zero if the ST5 equals ST6. One if they are not equal. let st5_eq_st6 = || curr_row(HV0) * (curr_row(ST6) - curr_row(ST5)); @@ -1319,12 +1267,11 @@ fn instruction_assert( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); // The current top of the stack st0 is 1. - let st_0_is_1 = curr_base_row(ST0) - constant(1); + let st_0_is_1 = curr_main_row(ST0) - constant(1); let specific_constraints = vec![st_0_is_1]; [ @@ -1340,15 +1287,13 @@ fn instruction_assert( fn instruction_halt( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // The instruction executed in the following step is instruction halt. - let halt_is_followed_by_halt = next_base_row(CI) - curr_base_row(CI); + let halt_is_followed_by_halt = next_main_row(CI) - curr_main_row(CI); let specific_constraints = vec![halt_is_followed_by_halt]; [ @@ -1392,21 +1337,19 @@ fn instruction_hash( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let op_stack_shrinks_by_5_and_top_5_unconstrained = vec![ - next_base_row(ST5) - curr_base_row(ST10), - next_base_row(ST6) - curr_base_row(ST11), - next_base_row(ST7) - curr_base_row(ST12), - next_base_row(ST8) - curr_base_row(ST13), - next_base_row(ST9) - curr_base_row(ST14), - next_base_row(ST10) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(5), + next_main_row(ST5) - curr_main_row(ST10), + next_main_row(ST6) - curr_main_row(ST11), + next_main_row(ST7) - curr_main_row(ST12), + next_main_row(ST8) - curr_main_row(ST13), + next_main_row(ST9) - curr_main_row(ST14), + next_main_row(ST10) - curr_main_row(ST15), + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) + constant(5), running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 5), ]; @@ -1435,12 +1378,9 @@ fn instruction_merkle_step_mem( ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); let stack_weight = |i| circuit_builder.challenge(stack_weight_by_index(i)); - let curr = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next = |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let ram_pointers = [0, 1, 2, 3, 4].map(|i| curr(ST7) + constant(i)); let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4].map(curr); @@ -1470,12 +1410,9 @@ fn instruction_merkle_step_shared_constraints( ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); let one = || constant(1); - let curr = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next = |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let hv5_is_0_or_1 = curr(HV5) * (curr(HV5) - one()); let new_st5_is_previous_st5_div_2 = constant(2) * next(ST5) + curr(HV5) - curr(ST5); @@ -1492,16 +1429,15 @@ fn instruction_merkle_step_shared_constraints( fn instruction_assert_vector( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); let specific_constraints = vec![ - curr_base_row(ST5) - curr_base_row(ST0), - curr_base_row(ST6) - curr_base_row(ST1), - curr_base_row(ST7) - curr_base_row(ST2), - curr_base_row(ST8) - curr_base_row(ST3), - curr_base_row(ST9) - curr_base_row(ST4), + curr_main_row(ST5) - curr_main_row(ST0), + curr_main_row(ST6) - curr_main_row(ST1), + curr_main_row(ST7) - curr_main_row(ST2), + curr_main_row(ST8) - curr_main_row(ST3), + curr_main_row(ST9) - curr_main_row(ST4), ]; [ specific_constraints, @@ -1540,16 +1476,14 @@ fn instruction_sponge_absorb( fn instruction_sponge_absorb_mem( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let constant = |c| circuit_builder.b_constant(c); let increment_ram_pointer = - next_base_row(ST0) - curr_base_row(ST0) - constant(tip5::RATE as u32); + next_main_row(ST0) - curr_main_row(ST0) - constant(tip5::RATE as u32); [ vec![increment_ram_pointer], @@ -1575,14 +1509,12 @@ fn instruction_sponge_squeeze( fn instruction_add( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST1)]; + let specific_constraints = vec![next_main_row(ST0) - curr_main_row(ST0) - curr_main_row(ST1)]; [ specific_constraints, instruction_group_step_1(circuit_builder), @@ -1596,14 +1528,12 @@ fn instruction_add( fn instruction_addi( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(NIA)]; + let specific_constraints = vec![next_main_row(ST0) - curr_main_row(ST0) - curr_main_row(NIA)]; [ specific_constraints, instruction_group_step_2(circuit_builder), @@ -1617,14 +1547,12 @@ fn instruction_addi( fn instruction_mul( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) - curr_base_row(ST0) * curr_base_row(ST1)]; + let specific_constraints = vec![next_main_row(ST0) - curr_main_row(ST0) * curr_main_row(ST1)]; [ specific_constraints, instruction_group_step_1(circuit_builder), @@ -1640,14 +1568,12 @@ fn instruction_invert( ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let specific_constraints = vec![next_base_row(ST0) * curr_base_row(ST0) - one()]; + let specific_constraints = vec![next_main_row(ST0) * curr_main_row(ST0) - one()]; [ specific_constraints, instruction_group_step_1(circuit_builder), @@ -1663,27 +1589,25 @@ fn instruction_eq( ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let st0_eq_st1 = || one() - curr_base_row(HV0) * (curr_base_row(ST1) - curr_base_row(ST0)); + let st0_eq_st1 = || one() - curr_main_row(HV0) * (curr_main_row(ST1) - curr_main_row(ST0)); // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two top-most // elements: `hv0·(1 - hv0·(st1 - st0))` - let hv0_is_inverse_of_diff_or_hv0_is_0 = curr_base_row(HV0) * st0_eq_st1(); + let hv0_is_inverse_of_diff_or_hv0_is_0 = curr_main_row(HV0) * st0_eq_st1(); // Helper variable hv0 is the inverse-or-zero of the difference of the stack's two // top-most elements: `(st1 - st0)·(1 - hv0·(st1 - st0))` let hv0_is_inverse_of_diff_or_diff_is_0 = - (curr_base_row(ST1) - curr_base_row(ST0)) * st0_eq_st1(); + (curr_main_row(ST1) - curr_main_row(ST0)) * st0_eq_st1(); // The new top of the stack is 1 if the difference between the stack's two top-most // elements is not invertible, 0 otherwise: `st0' - (1 - hv0·(st1 - st0))` - let st0_becomes_1_if_diff_is_not_invertible = next_base_row(ST0) - st0_eq_st1(); + let st0_becomes_1_if_diff_is_not_invertible = next_main_row(ST0) - st0_eq_st1(); let specific_constraints = vec![ hv0_is_inverse_of_diff_or_hv0_is_0, @@ -1705,17 +1629,15 @@ fn instruction_split( ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); let one = || constant(1); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // The top of the stack is decomposed as 32-bit chunks into the stack's top-most elements: // st0 - (2^32·st0' + st1') = 0$ let st0_decomposes_to_two_32_bit_chunks = - curr_base_row(ST0) - (constant(1 << 32) * next_base_row(ST1) + next_base_row(ST0)); + curr_main_row(ST0) - (constant(1 << 32) * next_main_row(ST1) + next_main_row(ST0)); // Helper variable `hv0` = 0 if either // 1. `hv0` is the difference between (2^32 - 1) and the high 32 bits (`st0'`), or @@ -1724,9 +1646,9 @@ fn instruction_split( // st1'·(hv0·(st0' - (2^32 - 1)) - 1) // lo·(hv0·(hi - 0xffff_ffff)) - 1) let hv0_holds_inverse_of_chunk_difference_or_low_bits_are_0 = { - let hv0 = curr_base_row(HV0); - let hi = next_base_row(ST1); - let lo = next_base_row(ST0); + let hv0 = curr_main_row(HV0); + let hi = next_main_row(ST1); + let lo = next_main_row(ST0); let ffff_ffff = constant(0xffff_ffff); lo * (hv0 * (hi - ffff_ffff) - one()) @@ -1809,16 +1731,14 @@ fn instruction_pow( fn instruction_div_mod( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); // `n == d·q + r` means `st0 - st1·st1' - st0'` let numerator_is_quotient_times_denominator_plus_remainder = - curr_base_row(ST0) - curr_base_row(ST1) * next_base_row(ST1) - next_base_row(ST0); + curr_main_row(ST0) - curr_main_row(ST1) * next_main_row(ST1) - next_main_row(ST0); let specific_constraints = vec![numerator_is_quotient_times_denominator_plus_remainder]; [ @@ -1846,16 +1766,14 @@ fn instruction_pop_count( fn instruction_xx_add( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let st0_becomes_st0_plus_st3 = next_base_row(ST0) - curr_base_row(ST0) - curr_base_row(ST3); - let st1_becomes_st1_plus_st4 = next_base_row(ST1) - curr_base_row(ST1) - curr_base_row(ST4); - let st2_becomes_st2_plus_st5 = next_base_row(ST2) - curr_base_row(ST2) - curr_base_row(ST5); + let st0_becomes_st0_plus_st3 = next_main_row(ST0) - curr_main_row(ST0) - curr_main_row(ST3); + let st1_becomes_st1_plus_st4 = next_main_row(ST1) - curr_main_row(ST1) - curr_main_row(ST4); + let st2_becomes_st2_plus_st5 = next_main_row(ST2) - curr_main_row(ST2) - curr_main_row(ST5); let specific_constraints = vec![ st0_becomes_st0_plus_st3, st1_becomes_st1_plus_st4, @@ -1875,20 +1793,18 @@ fn instruction_xx_add( fn instruction_xx_mul( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(curr_base_row); + let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(curr_main_row); let [c0, c1, c2] = xx_product([x0, x1, x2], [y0, y1, y2]); let specific_constraints = vec![ - next_base_row(ST0) - c0, - next_base_row(ST1) - c1, - next_base_row(ST2) - c2, + next_main_row(ST0) - c0, + next_main_row(ST1) - c1, + next_main_row(ST2) - c2, ]; [ specific_constraints, @@ -1904,30 +1820,28 @@ fn instruction_xinv( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - let first_coefficient_of_product_of_element_and_inverse_is_1 = curr_base_row(ST0) - * next_base_row(ST0) - - curr_base_row(ST2) * next_base_row(ST1) - - curr_base_row(ST1) * next_base_row(ST2) + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + + let first_coefficient_of_product_of_element_and_inverse_is_1 = curr_main_row(ST0) + * next_main_row(ST0) + - curr_main_row(ST2) * next_main_row(ST1) + - curr_main_row(ST1) * next_main_row(ST2) - constant(1); let second_coefficient_of_product_of_element_and_inverse_is_0 = - curr_base_row(ST1) * next_base_row(ST0) + curr_base_row(ST0) * next_base_row(ST1) - - curr_base_row(ST2) * next_base_row(ST2) - + curr_base_row(ST2) * next_base_row(ST1) - + curr_base_row(ST1) * next_base_row(ST2); + curr_main_row(ST1) * next_main_row(ST0) + curr_main_row(ST0) * next_main_row(ST1) + - curr_main_row(ST2) * next_main_row(ST2) + + curr_main_row(ST2) * next_main_row(ST1) + + curr_main_row(ST1) * next_main_row(ST2); - let third_coefficient_of_product_of_element_and_inverse_is_0 = curr_base_row(ST2) - * next_base_row(ST0) - + curr_base_row(ST1) * next_base_row(ST1) - + curr_base_row(ST0) * next_base_row(ST2) - + curr_base_row(ST2) * next_base_row(ST2); + let third_coefficient_of_product_of_element_and_inverse_is_0 = curr_main_row(ST2) + * next_main_row(ST0) + + curr_main_row(ST1) * next_main_row(ST1) + + curr_main_row(ST0) * next_main_row(ST2) + + curr_main_row(ST2) * next_main_row(ST2); let specific_constraints = vec![ first_coefficient_of_product_of_element_and_inverse_is_1, @@ -1947,20 +1861,18 @@ fn instruction_xinv( fn instruction_xb_mul( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); - let [x, y0, y1, y2] = [ST0, ST1, ST2, ST3].map(curr_base_row); + let [x, y0, y1, y2] = [ST0, ST1, ST2, ST3].map(curr_main_row); let [c0, c1, c2] = xb_product([y0, y1, y2], x); let specific_constraints = vec![ - next_base_row(ST0) - c0, - next_base_row(ST1) - c1, - next_base_row(ST2) - c2, + next_main_row(ST0) - c0, + next_main_row(ST1) - c1, + next_main_row(ST2) - c2, ]; [ specific_constraints, @@ -2033,20 +1945,17 @@ fn read_from_ram_to( ram_pointers: [ConstraintCircuitMonad; N], destinations: [ConstraintCircuitMonad; N], ) -> ConstraintCircuitMonad { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let challenge = |c: ChallengeId| circuit_builder.challenge(c); let constant = |bfe| circuit_builder.b_constant(bfe); let compress_row = |(ram_pointer, destination)| { - curr_base_row(CLK) * challenge(RamClkWeight) + curr_main_row(CLK) * challenge(RamClkWeight) + constant(table::ram::INSTRUCTION_TYPE_READ) * challenge(RamInstructionTypeWeight) + ram_pointer * challenge(RamPointerWeight) + destination * challenge(RamValueWeight) @@ -2059,7 +1968,7 @@ fn read_from_ram_to( .map(|compressed_row| challenge(RamIndeterminate) - compressed_row) .reduce(|l, r| l * r) .unwrap_or_else(|| constant(bfe!(1))); - curr_ext_row(RamTablePermArg) * factor - next_ext_row(RamTablePermArg) + curr_aux_row(RamTablePermArg) * factor - next_aux_row(RamTablePermArg) } fn xx_product( @@ -2088,17 +1997,15 @@ fn xb_product( fn update_dotstep_accumulator( circuit_builder: &ConstraintCircuitBuilder, - accumulator_indices: [ProcessorBaseTableColumn; EXTENSION_DEGREE], + accumulator_indices: [ProcessorMainColumn; EXTENSION_DEGREE], difference: [ConstraintCircuitMonad; EXTENSION_DEGREE], ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr = accumulator_indices.map(curr_base_row); - let next = accumulator_indices.map(next_base_row); + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr = accumulator_indices.map(curr_main_row); + let next = accumulator_indices.map(next_main_row); izip!(curr, next, difference) .map(|(c, n, d)| n - c - d) .collect() @@ -2107,25 +2014,23 @@ fn update_dotstep_accumulator( fn instruction_xx_dot_step( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let constant = |c| circuit_builder.b_constant(c); - let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(3); - let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); + let increment_ram_pointer_st0 = next_main_row(ST0) - curr_main_row(ST0) - constant(3); + let increment_ram_pointer_st1 = next_main_row(ST1) - curr_main_row(ST1) - constant(3); - let rhs_ptr0 = curr_base_row(ST0); + let rhs_ptr0 = curr_main_row(ST0); let rhs_ptr1 = rhs_ptr0.clone() + constant(1); let rhs_ptr2 = rhs_ptr0.clone() + constant(2); - let lhs_ptr0 = curr_base_row(ST1); + let lhs_ptr0 = curr_main_row(ST1); let lhs_ptr1 = lhs_ptr0.clone() + constant(1); let lhs_ptr2 = lhs_ptr0.clone() + constant(2); let ram_read_sources = [rhs_ptr0, rhs_ptr1, rhs_ptr2, lhs_ptr0, lhs_ptr1, lhs_ptr2]; - let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let ram_read_destinations = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_main_row); let read_two_xfes_from_ram = read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); @@ -2135,7 +2040,7 @@ fn instruction_xx_dot_step( read_two_xfes_from_ram, ]; - let [hv0, hv1, hv2, hv3, hv4, hv5] = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let [hv0, hv1, hv2, hv3, hv4, hv5] = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_main_row); let hv_product = xx_product([hv0, hv1, hv2], [hv3, hv4, hv5]); [ @@ -2151,23 +2056,21 @@ fn instruction_xx_dot_step( fn instruction_xb_dot_step( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let constant = |c| circuit_builder.b_constant(c); - let increment_ram_pointer_st0 = next_base_row(ST0) - curr_base_row(ST0) - constant(1); - let increment_ram_pointer_st1 = next_base_row(ST1) - curr_base_row(ST1) - constant(3); + let increment_ram_pointer_st0 = next_main_row(ST0) - curr_main_row(ST0) - constant(1); + let increment_ram_pointer_st1 = next_main_row(ST1) - curr_main_row(ST1) - constant(3); - let rhs_ptr0 = curr_base_row(ST0); - let lhs_ptr0 = curr_base_row(ST1); + let rhs_ptr0 = curr_main_row(ST0); + let lhs_ptr0 = curr_main_row(ST1); let lhs_ptr1 = lhs_ptr0.clone() + constant(1); let lhs_ptr2 = lhs_ptr0.clone() + constant(2); let ram_read_sources = [rhs_ptr0, lhs_ptr0, lhs_ptr1, lhs_ptr2]; - let ram_read_destinations = [HV0, HV1, HV2, HV3].map(curr_base_row); + let ram_read_destinations = [HV0, HV1, HV2, HV3].map(curr_main_row); let read_bfe_and_xfe_from_ram = read_from_ram_to(circuit_builder, ram_read_sources, ram_read_destinations); @@ -2177,7 +2080,7 @@ fn instruction_xb_dot_step( read_bfe_and_xfe_from_ram, ]; - let [hv0, hv1, hv2, hv3] = [HV0, HV1, HV2, HV3].map(curr_base_row); + let [hv0, hv1, hv2, hv3] = [HV0, HV1, HV2, HV3].map(curr_main_row); let hv_product = xb_product([hv1, hv2, hv3], hv0); [ @@ -2257,46 +2160,39 @@ fn log_derivative_accumulates_clk_next( circuit_builder: &ConstraintCircuitBuilder, ) -> ConstraintCircuitMonad { let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); - (next_ext_row(ClockJumpDifferenceLookupServerLogDerivative) - - curr_ext_row(ClockJumpDifferenceLookupServerLogDerivative)) - * (challenge(ClockJumpDifferenceLookupIndeterminate) - next_base_row(CLK)) - - next_base_row(ClockJumpDifferenceLookupMultiplicity) + (next_aux_row(ClockJumpDifferenceLookupServerLogDerivative) + - curr_aux_row(ClockJumpDifferenceLookupServerLogDerivative)) + * (challenge(ClockJumpDifferenceLookupIndeterminate) - next_main_row(CLK)) + - next_main_row(ClockJumpDifferenceLookupMultiplicity) } fn running_evaluation_for_standard_input_remains_unchanged( circuit_builder: &ConstraintCircuitBuilder, ) -> ConstraintCircuitMonad { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); - next_ext_row(InputTableEvalArg) - curr_ext_row(InputTableEvalArg) + next_aux_row(InputTableEvalArg) - curr_aux_row(InputTableEvalArg) } fn running_evaluation_for_standard_output_remains_unchanged( circuit_builder: &ConstraintCircuitBuilder, ) -> ConstraintCircuitMonad { - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); - next_ext_row(OutputTableEvalArg) - curr_ext_row(OutputTableEvalArg) + next_aux_row(OutputTableEvalArg) - curr_aux_row(OutputTableEvalArg) } fn grow_stack_by_n_and_read_n_symbols_from_input( @@ -2304,22 +2200,19 @@ fn grow_stack_by_n_and_read_n_symbols_from_input( n: usize, ) -> Vec> { let indeterminate = || circuit_builder.challenge(StandardInputIndeterminate); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let mut running_evaluation = curr_ext_row(InputTableEvalArg); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); + + let mut running_evaluation = curr_aux_row(InputTableEvalArg); for i in (0..n).rev() { let stack_element = ProcessorTable::op_stack_column_by_index(i); - running_evaluation = indeterminate() * running_evaluation + next_base_row(stack_element); + running_evaluation = indeterminate() * running_evaluation + next_main_row(stack_element); } - let running_evaluation_update = next_ext_row(InputTableEvalArg) - running_evaluation; + let running_evaluation_update = next_aux_row(InputTableEvalArg) - running_evaluation; let conditional_running_evaluation_update = indicator_polynomial(circuit_builder, n) * running_evaluation_update; @@ -2333,22 +2226,19 @@ fn shrink_stack_by_n_and_write_n_symbols_to_output( n: usize, ) -> Vec> { let indeterminate = || circuit_builder.challenge(StandardOutputIndeterminate); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let mut running_evaluation = curr_ext_row(OutputTableEvalArg); + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); + + let mut running_evaluation = curr_aux_row(OutputTableEvalArg); for i in 0..n { let stack_element = ProcessorTable::op_stack_column_by_index(i); - running_evaluation = indeterminate() * running_evaluation + curr_base_row(stack_element); + running_evaluation = indeterminate() * running_evaluation + curr_main_row(stack_element); } - let running_evaluation_update = next_ext_row(OutputTableEvalArg) - running_evaluation; + let running_evaluation_update = next_aux_row(OutputTableEvalArg) - running_evaluation; let conditional_running_evaluation_update = indicator_polynomial(circuit_builder, n) * running_evaluation_update; @@ -2362,53 +2252,48 @@ fn log_derivative_for_instruction_lookup_updates_correctly( ) -> ConstraintCircuitMonad { let one = || circuit_builder.b_constant(1); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - - let compressed_row = challenge(ProgramAddressWeight) * next_base_row(IP) - + challenge(ProgramInstructionWeight) * next_base_row(CI) - + challenge(ProgramNextInstructionWeight) * next_base_row(NIA); - let log_derivative_updates = (next_ext_row(InstructionLookupClientLogDerivative) - - curr_ext_row(InstructionLookupClientLogDerivative)) + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); + + let compressed_row = challenge(ProgramAddressWeight) * next_main_row(IP) + + challenge(ProgramInstructionWeight) * next_main_row(CI) + + challenge(ProgramNextInstructionWeight) * next_main_row(NIA); + let log_derivative_updates = (next_aux_row(InstructionLookupClientLogDerivative) + - curr_aux_row(InstructionLookupClientLogDerivative)) * (challenge(InstructionLookupIndeterminate) - compressed_row) - one(); - let log_derivative_remains = next_ext_row(InstructionLookupClientLogDerivative) - - curr_ext_row(InstructionLookupClientLogDerivative); + let log_derivative_remains = next_aux_row(InstructionLookupClientLogDerivative) + - curr_aux_row(InstructionLookupClientLogDerivative); - (one() - next_base_row(IsPadding)) * log_derivative_updates - + next_base_row(IsPadding) * log_derivative_remains + (one() - next_main_row(IsPadding)) * log_derivative_updates + + next_main_row(IsPadding) * log_derivative_remains } fn constraints_for_shrinking_stack_by_3_and_top_3_unconstrained( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); vec![ - next_base_row(ST3) - curr_base_row(ST6), - next_base_row(ST4) - curr_base_row(ST7), - next_base_row(ST5) - curr_base_row(ST8), - next_base_row(ST6) - curr_base_row(ST9), - next_base_row(ST7) - curr_base_row(ST10), - next_base_row(ST8) - curr_base_row(ST11), - next_base_row(ST9) - curr_base_row(ST12), - next_base_row(ST10) - curr_base_row(ST13), - next_base_row(ST11) - curr_base_row(ST14), - next_base_row(ST12) - curr_base_row(ST15), - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(3), + next_main_row(ST3) - curr_main_row(ST6), + next_main_row(ST4) - curr_main_row(ST7), + next_main_row(ST5) - curr_main_row(ST8), + next_main_row(ST6) - curr_main_row(ST9), + next_main_row(ST7) - curr_main_row(ST10), + next_main_row(ST8) - curr_main_row(ST11), + next_main_row(ST9) - curr_main_row(ST12), + next_main_row(ST10) - curr_main_row(ST13), + next_main_row(ST11) - curr_main_row(ST14), + next_main_row(ST12) - curr_main_row(ST15), + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) + constant(3), running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, 3), ] } @@ -2503,12 +2388,10 @@ fn constraints_for_shrinking_stack_by( n: usize, ) -> Vec> { let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); let new_stack = stack().dropping_back(n).map(next_row).collect_vec(); @@ -2539,12 +2422,10 @@ fn constraints_for_growing_stack_by( n: usize, ) -> Vec> { let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); - let curr_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let stack = || (0..OpStackElement::COUNT).map(ProcessorTable::op_stack_column_by_index); let new_stack = stack().skip(n).map(next_row).collect_vec(); @@ -2595,16 +2476,14 @@ fn running_product_op_stack_accounts_for_growing_stack_by( n: usize, ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let single_grow_factor = |op_stack_pointer_offset| { single_factor_for_permutation_argument_with_op_stack_table( circuit_builder, - CurrentBaseRow, + CurrentMain, op_stack_pointer_offset, ) }; @@ -2614,7 +2493,7 @@ fn running_product_op_stack_accounts_for_growing_stack_by( factor = factor * single_grow_factor(op_stack_pointer_offset); } - next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor + next_aux_row(OpStackTablePermArg) - curr_aux_row(OpStackTablePermArg) * factor } fn running_product_op_stack_accounts_for_shrinking_stack_by( @@ -2622,16 +2501,14 @@ fn running_product_op_stack_accounts_for_shrinking_stack_by( n: usize, ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let single_shrink_factor = |op_stack_pointer_offset| { single_factor_for_permutation_argument_with_op_stack_table( circuit_builder, - NextBaseRow, + NextMain, op_stack_pointer_offset, ) }; @@ -2641,7 +2518,7 @@ fn running_product_op_stack_accounts_for_shrinking_stack_by( factor = factor * single_shrink_factor(op_stack_pointer_offset); } - next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg) * factor + next_aux_row(OpStackTablePermArg) - curr_aux_row(OpStackTablePermArg) * factor } fn single_factor_for_permutation_argument_with_op_stack_table( @@ -2651,13 +2528,10 @@ fn single_factor_for_permutation_argument_with_op_stack_table( ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let row_with_shorter_stack = |col: ProcessorBaseTableColumn| { - circuit_builder.input(row_with_shorter_stack_indicator( - col.master_base_table_index(), - )) + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let row_with_shorter_stack = |col: ProcessorMainColumn| { + circuit_builder.input(row_with_shorter_stack_indicator(col.master_main_index())) }; let max_stack_element_index = OpStackElement::COUNT - 1; @@ -2669,8 +2543,8 @@ fn single_factor_for_permutation_argument_with_op_stack_table( let offset = constant(op_stack_pointer_offset as u32); let offset_op_stack_pointer = op_stack_pointer + offset; - let compressed_row = challenge(OpStackClkWeight) * curr_base_row(CLK) - + challenge(OpStackIb1Weight) * curr_base_row(IB1) + let compressed_row = challenge(OpStackClkWeight) * curr_main_row(CLK) + + challenge(OpStackIb1Weight) * curr_main_row(IB1) + challenge(OpStackPointerWeight) * offset_op_stack_pointer + challenge(OpStackFirstUnderflowElementWeight) * underflow_element; challenge(OpStackIndeterminate) - compressed_row @@ -2730,16 +2604,14 @@ fn shrink_stack_by_n_and_write_n_elements_to_ram( n: usize, ) -> Vec> { let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap()); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let op_stack_pointer_shrinks_by_n = - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(n); - let ram_pointer_grows_by_n = next_base_row(ST0) - curr_base_row(ST0) - constant(n); + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) + constant(n); + let ram_pointer_grows_by_n = next_main_row(ST0) - curr_main_row(ST0) - constant(n); let mut constraints = vec![ op_stack_pointer_shrinks_by_n, @@ -2753,7 +2625,7 @@ fn shrink_stack_by_n_and_write_n_elements_to_ram( let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); let next_stack_element = ProcessorTable::op_stack_column_by_index(i - n); let element_i_is_shifted_by_n = - next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + next_main_row(next_stack_element) - curr_main_row(curr_stack_element); constraints.push(element_i_is_shifted_by_n); } constraints @@ -2764,16 +2636,14 @@ fn grow_stack_by_n_and_read_n_elements_from_ram( n: usize, ) -> Vec> { let constant = |c: usize| circuit_builder.b_constant(u64::try_from(c).unwrap()); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); let op_stack_pointer_grows_by_n = - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(n); - let ram_pointer_shrinks_by_n = next_base_row(ST0) - curr_base_row(ST0) + constant(n); + next_main_row(OpStackPointer) - curr_main_row(OpStackPointer) - constant(n); + let ram_pointer_shrinks_by_n = next_main_row(ST0) - curr_main_row(ST0) + constant(n); let mut constraints = vec![ op_stack_pointer_grows_by_n, @@ -2787,7 +2657,7 @@ fn grow_stack_by_n_and_read_n_elements_from_ram( let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); let next_stack_element = ProcessorTable::op_stack_column_by_index(i + n); let element_i_is_shifted_by_n = - next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + next_main_row(next_stack_element) - curr_main_row(curr_stack_element); constraints.push(element_i_is_shifted_by_n); } constraints @@ -2798,16 +2668,14 @@ fn running_product_ram_accounts_for_writing_n_elements( n: usize, ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let single_write_factor = |ram_pointer_offset| { single_factor_for_permutation_argument_with_ram_table( circuit_builder, - CurrentBaseRow, + CurrentMain, table::ram::INSTRUCTION_TYPE_WRITE, ram_pointer_offset, ) @@ -2818,7 +2686,7 @@ fn running_product_ram_accounts_for_writing_n_elements( factor = factor * single_write_factor(ram_pointer_offset); } - next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor + next_aux_row(RamTablePermArg) - curr_aux_row(RamTablePermArg) * factor } fn running_product_ram_accounts_for_reading_n_elements( @@ -2826,16 +2694,14 @@ fn running_product_ram_accounts_for_reading_n_elements( n: usize, ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let single_read_factor = |ram_pointer_offset| { single_factor_for_permutation_argument_with_ram_table( circuit_builder, - NextBaseRow, + NextMain, table::ram::INSTRUCTION_TYPE_READ, ram_pointer_offset, ) @@ -2846,7 +2712,7 @@ fn running_product_ram_accounts_for_reading_n_elements( factor = factor * single_read_factor(ram_pointer_offset); } - next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor + next_aux_row(RamTablePermArg) - curr_aux_row(RamTablePermArg) * factor } fn single_factor_for_permutation_argument_with_ram_table( @@ -2858,13 +2724,10 @@ fn single_factor_for_permutation_argument_with_ram_table( let constant = |c: u32| circuit_builder.b_constant(c); let b_constant = |c| circuit_builder.b_constant(c); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let row_with_longer_stack = |col: ProcessorBaseTableColumn| { - circuit_builder.input(row_with_longer_stack_indicator( - col.master_base_table_index(), - )) + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let row_with_longer_stack = |col: ProcessorMainColumn| { + circuit_builder.input(row_with_longer_stack_indicator(col.master_main_index())) }; let num_ram_pointers = 1; @@ -2882,7 +2745,7 @@ fn single_factor_for_permutation_argument_with_ram_table( let offset = constant(additional_offset + ram_pointer_offset as u32); let offset_ram_pointer = ram_pointer + offset; - let compressed_row = curr_base_row(CLK) * challenge(RamClkWeight) + let compressed_row = curr_main_row(CLK) * challenge(RamClkWeight) + b_constant(instruction_type) * challenge(RamInstructionTypeWeight) + offset_ram_pointer * challenge(RamPointerWeight) + ram_value * challenge(RamValueWeight); @@ -2893,24 +2756,21 @@ fn running_product_for_jump_stack_table_updates_correctly( circuit_builder: &ConstraintCircuitBuilder, ) -> ConstraintCircuitMonad { let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); - let compressed_row = challenge(JumpStackClkWeight) * next_base_row(CLK) - + challenge(JumpStackCiWeight) * next_base_row(CI) - + challenge(JumpStackJspWeight) * next_base_row(JSP) - + challenge(JumpStackJsoWeight) * next_base_row(JSO) - + challenge(JumpStackJsdWeight) * next_base_row(JSD); + let compressed_row = challenge(JumpStackClkWeight) * next_main_row(CLK) + + challenge(JumpStackCiWeight) * next_main_row(CI) + + challenge(JumpStackJspWeight) * next_main_row(JSP) + + challenge(JumpStackJsoWeight) * next_main_row(JSO) + + challenge(JumpStackJsdWeight) * next_main_row(JSD); - next_ext_row(JumpStackTablePermArg) - - curr_ext_row(JumpStackTablePermArg) * (challenge(JumpStackIndeterminate) - compressed_row) + next_aux_row(JumpStackTablePermArg) + - curr_aux_row(JumpStackTablePermArg) * (challenge(JumpStackIndeterminate) - compressed_row) } /// Deal with instructions `hash` and `merkle_step`. The registers from which @@ -2930,24 +2790,21 @@ fn running_evaluation_hash_input_updates_correctly( let constant = |c: u32| circuit_builder.b_constant(c); let one = || constant(1); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let hash_deselector = instruction_deselector_next_row(circuit_builder, Instruction::Hash); let merkle_step_deselector = instruction_deselector_next_row(circuit_builder, Instruction::MerkleStep); let merkle_step_mem_deselector = instruction_deselector_next_row(circuit_builder, Instruction::MerkleStepMem); - let hash_and_merkle_step_selector = (next_base_row(CI) - constant(Instruction::Hash.opcode())) - * (next_base_row(CI) - constant(Instruction::MerkleStep.opcode())) - * (next_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); + let hash_and_merkle_step_selector = (next_main_row(CI) - constant(Instruction::Hash.opcode())) + * (next_main_row(CI) - constant(Instruction::MerkleStep.opcode())) + * (next_main_row(CI) - constant(Instruction::MerkleStepMem.opcode())); let weights = [ StackWeight0, @@ -2964,7 +2821,7 @@ fn running_evaluation_hash_input_updates_correctly( .map(challenge); // hash - let state_for_hash = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9].map(next_base_row); + let state_for_hash = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9].map(next_main_row); let compressed_hash_row = weights .iter() .zip_eq(state_for_hash) @@ -2972,10 +2829,10 @@ fn running_evaluation_hash_input_updates_correctly( .sum(); // merkle step - let is_left_sibling = || next_base_row(HV5); - let is_right_sibling = || one() - next_base_row(HV5); + let is_left_sibling = || next_main_row(HV5); + let is_right_sibling = || one() - next_main_row(HV5); let merkle_step_state_element = - |l, r| is_right_sibling() * next_base_row(l) + is_left_sibling() * next_base_row(r); + |l, r| is_right_sibling() * next_main_row(l) + is_left_sibling() * next_main_row(r); let state_for_merkle_step = [ merkle_step_state_element(ST0, HV0), merkle_step_state_element(ST1, HV1), @@ -2995,12 +2852,12 @@ fn running_evaluation_hash_input_updates_correctly( .sum::>(); let running_evaluation_updates_with = |compressed_row| { - next_ext_row(HashInputEvalArg) - - challenge(HashInputIndeterminate) * curr_ext_row(HashInputEvalArg) + next_aux_row(HashInputEvalArg) + - challenge(HashInputIndeterminate) * curr_aux_row(HashInputEvalArg) - compressed_row }; let running_evaluation_remains = - next_ext_row(HashInputEvalArg) - curr_ext_row(HashInputEvalArg); + next_aux_row(HashInputEvalArg) - curr_aux_row(HashInputEvalArg); hash_and_merkle_step_selector * running_evaluation_remains + hash_deselector * running_evaluation_updates_with(compressed_hash_row) @@ -3014,27 +2871,23 @@ fn running_evaluation_hash_digest_updates_correctly( ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let hash_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Hash); let merkle_step_deselector = instruction_deselector_current_row(circuit_builder, Instruction::MerkleStep); let merkle_step_mem_deselector = instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); - let hash_and_merkle_step_selector = (curr_base_row(CI) - constant(Instruction::Hash.opcode())) - * (curr_base_row(CI) - constant(Instruction::MerkleStep.opcode())) - * (curr_base_row(CI) - constant(Instruction::MerkleStepMem.opcode())); + let hash_and_merkle_step_selector = (curr_main_row(CI) - constant(Instruction::Hash.opcode())) + * (curr_main_row(CI) - constant(Instruction::MerkleStep.opcode())) + * (curr_main_row(CI) - constant(Instruction::MerkleStepMem.opcode())); let weights = [ StackWeight0, @@ -3044,18 +2897,18 @@ fn running_evaluation_hash_digest_updates_correctly( StackWeight4, ] .map(challenge); - let state = [ST0, ST1, ST2, ST3, ST4].map(next_base_row); + let state = [ST0, ST1, ST2, ST3, ST4].map(next_main_row); let compressed_row = weights .into_iter() .zip_eq(state) .map(|(weight, state)| weight * state) .sum(); - let running_evaluation_updates = next_ext_row(HashDigestEvalArg) - - challenge(HashDigestIndeterminate) * curr_ext_row(HashDigestEvalArg) + let running_evaluation_updates = next_aux_row(HashDigestEvalArg) + - challenge(HashDigestIndeterminate) * curr_aux_row(HashDigestEvalArg) - compressed_row; let running_evaluation_remains = - next_ext_row(HashDigestEvalArg) - curr_ext_row(HashDigestEvalArg); + next_aux_row(HashDigestEvalArg) - curr_aux_row(HashDigestEvalArg); hash_and_merkle_step_selector * running_evaluation_remains + (hash_deselector + merkle_step_deselector + merkle_step_mem_deselector) @@ -3067,18 +2920,14 @@ fn running_evaluation_sponge_updates_correctly( ) -> ConstraintCircuitMonad { let constant = |c: u32| circuit_builder.b_constant(c); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let sponge_init_deselector = instruction_deselector_current_row(circuit_builder, Instruction::SpongeInit); @@ -3089,11 +2938,11 @@ fn running_evaluation_sponge_updates_correctly( let sponge_squeeze_deselector = instruction_deselector_current_row(circuit_builder, Instruction::SpongeSqueeze); - let sponge_instruction_selector = (curr_base_row(CI) + let sponge_instruction_selector = (curr_main_row(CI) - constant(Instruction::SpongeInit.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeAbsorb.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeAbsorbMem.opcode())) - * (curr_base_row(CI) - constant(Instruction::SpongeSqueeze.opcode())); + * (curr_main_row(CI) - constant(Instruction::SpongeAbsorb.opcode())) + * (curr_main_row(CI) - constant(Instruction::SpongeAbsorbMem.opcode())) + * (curr_main_row(CI) - constant(Instruction::SpongeSqueeze.opcode())); let weighted_sum = |state| { let weights = [ @@ -3113,28 +2962,28 @@ fn running_evaluation_sponge_updates_correctly( }; let state = [ST0, ST1, ST2, ST3, ST4, ST5, ST6, ST7, ST8, ST9]; - let compressed_row_current = weighted_sum(state.map(curr_base_row)); - let compressed_row_next = weighted_sum(state.map(next_base_row)); + let compressed_row_current = weighted_sum(state.map(curr_main_row)); + let compressed_row_next = weighted_sum(state.map(next_main_row)); // Use domain-specific knowledge: the compressed row (i.e., random linear sum) // of the initial Sponge state is 0. - let running_evaluation_updates_for_sponge_init = next_ext_row(SpongeEvalArg) - - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) - - challenge(HashCIWeight) * curr_base_row(CI); + let running_evaluation_updates_for_sponge_init = next_aux_row(SpongeEvalArg) + - challenge(SpongeIndeterminate) * curr_aux_row(SpongeEvalArg) + - challenge(HashCIWeight) * curr_main_row(CI); let running_evaluation_updates_for_absorb = running_evaluation_updates_for_sponge_init.clone() - compressed_row_current; let running_evaluation_updates_for_squeeze = running_evaluation_updates_for_sponge_init.clone() - compressed_row_next; - let running_evaluation_remains = next_ext_row(SpongeEvalArg) - curr_ext_row(SpongeEvalArg); + let running_evaluation_remains = next_aux_row(SpongeEvalArg) - curr_aux_row(SpongeEvalArg); // `sponge_absorb_mem` - let stack_elements = [ST1, ST2, ST3, ST4].map(next_base_row); - let hv_elements = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_base_row); + let stack_elements = [ST1, ST2, ST3, ST4].map(next_main_row); + let hv_elements = [HV0, HV1, HV2, HV3, HV4, HV5].map(curr_main_row); let absorb_mem_elements = stack_elements.into_iter().chain(hv_elements); let absorb_mem_elements = absorb_mem_elements.collect_vec().try_into().unwrap(); let compressed_row_absorb_mem = weighted_sum(absorb_mem_elements); - let running_evaluation_updates_for_absorb_mem = next_ext_row(SpongeEvalArg) - - challenge(SpongeIndeterminate) * curr_ext_row(SpongeEvalArg) + let running_evaluation_updates_for_absorb_mem = next_aux_row(SpongeEvalArg) + - challenge(SpongeIndeterminate) * curr_aux_row(SpongeEvalArg) - challenge(HashCIWeight) * constant(Instruction::SpongeAbsorb.opcode()) - compressed_row_absorb_mem; @@ -3152,18 +3001,14 @@ fn log_derivative_with_u32_table_updates_correctly( let one = || constant(1); let two_inverse = circuit_builder.b_constant(bfe!(2).inverse()); let challenge = |c: ChallengeId| circuit_builder.challenge(c); - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; + let curr_main_row = + |col: ProcessorMainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: ProcessorMainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let curr_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: ProcessorAuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let split_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Split); let lt_deselector = instruction_deselector_current_row(circuit_builder, Instruction::Lt); @@ -3181,41 +3026,41 @@ fn log_derivative_with_u32_table_updates_correctly( let merkle_step_mem_deselector = instruction_deselector_current_row(circuit_builder, Instruction::MerkleStepMem); - let running_sum = curr_ext_row(U32LookupClientLogDerivative); - let running_sum_next = next_ext_row(U32LookupClientLogDerivative); + let running_sum = curr_aux_row(U32LookupClientLogDerivative); + let running_sum_next = next_aux_row(U32LookupClientLogDerivative); let split_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * next_base_row(ST0) - - challenge(U32RhsWeight) * next_base_row(ST1) - - challenge(U32CiWeight) * curr_base_row(CI); + - challenge(U32LhsWeight) * next_main_row(ST0) + - challenge(U32RhsWeight) * next_main_row(ST1) + - challenge(U32CiWeight) * curr_main_row(CI); let binop_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) - - challenge(U32CiWeight) * curr_base_row(CI) - - challenge(U32ResultWeight) * next_base_row(ST0); + - challenge(U32LhsWeight) * curr_main_row(ST0) + - challenge(U32RhsWeight) * curr_main_row(ST1) + - challenge(U32CiWeight) * curr_main_row(CI) + - challenge(U32ResultWeight) * next_main_row(ST0); let xor_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) + - challenge(U32LhsWeight) * curr_main_row(ST0) + - challenge(U32RhsWeight) * curr_main_row(ST1) - challenge(U32CiWeight) * constant(Instruction::And.opcode()) - challenge(U32ResultWeight) - * (curr_base_row(ST0) + curr_base_row(ST1) - next_base_row(ST0)) + * (curr_main_row(ST0) + curr_main_row(ST1) - next_main_row(ST0)) * two_inverse; let unop_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32CiWeight) * curr_base_row(CI) - - challenge(U32ResultWeight) * next_base_row(ST0); + - challenge(U32LhsWeight) * curr_main_row(ST0) + - challenge(U32CiWeight) * curr_main_row(CI) + - challenge(U32ResultWeight) * next_main_row(ST0); let div_mod_factor_for_lt = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * next_base_row(ST0) - - challenge(U32RhsWeight) * curr_base_row(ST1) + - challenge(U32LhsWeight) * next_main_row(ST0) + - challenge(U32RhsWeight) * curr_main_row(ST1) - challenge(U32CiWeight) * constant(Instruction::Lt.opcode()) - challenge(U32ResultWeight); let div_mod_factor_for_range_check = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST0) - - challenge(U32RhsWeight) * next_base_row(ST1) + - challenge(U32LhsWeight) * curr_main_row(ST0) + - challenge(U32RhsWeight) * next_main_row(ST1) - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); let merkle_step_range_check_factor = challenge(U32Indeterminate) - - challenge(U32LhsWeight) * curr_base_row(ST5) - - challenge(U32RhsWeight) * next_base_row(ST5) + - challenge(U32LhsWeight) * curr_main_row(ST5) + - challenge(U32RhsWeight) * next_main_row(ST5) - challenge(U32CiWeight) * constant(Instruction::Split.opcode()); let running_sum_absorbs_split_factor = @@ -3246,7 +3091,7 @@ fn log_derivative_with_u32_table_updates_correctly( merkle_step_deselector * running_sum_absorbs_merkle_step_factor.clone(); let merkle_step_mem_summand = merkle_step_mem_deselector * running_sum_absorbs_merkle_step_factor; - let no_update_summand = (one() - curr_base_row(IB2)) * (running_sum_next - running_sum); + let no_update_summand = (one() - curr_main_row(IB2)) * (running_sum_next - running_sum); split_summand + lt_summand @@ -3317,12 +3162,12 @@ fn helper_variable( index: usize, ) -> ConstraintCircuitMonad { match index { - 0 => circuit_builder.input(CurrentBaseRow(HV0.master_base_table_index())), - 1 => circuit_builder.input(CurrentBaseRow(HV1.master_base_table_index())), - 2 => circuit_builder.input(CurrentBaseRow(HV2.master_base_table_index())), - 3 => circuit_builder.input(CurrentBaseRow(HV3.master_base_table_index())), - 4 => circuit_builder.input(CurrentBaseRow(HV4.master_base_table_index())), - 5 => circuit_builder.input(CurrentBaseRow(HV5.master_base_table_index())), + 0 => circuit_builder.input(CurrentMain(HV0.master_main_index())), + 1 => circuit_builder.input(CurrentMain(HV1.master_main_index())), + 2 => circuit_builder.input(CurrentMain(HV2.master_main_index())), + 3 => circuit_builder.input(CurrentMain(HV3.master_main_index())), + 4 => circuit_builder.input(CurrentMain(HV4.master_main_index())), + 5 => circuit_builder.input(CurrentMain(HV5.master_main_index())), i => unimplemented!("Helper variable index {i} out of bounds."), } } @@ -3336,8 +3181,8 @@ mod tests { use proptest_arbitrary_interop::arb; use test_strategy::proptest; - use crate::table::NUM_BASE_COLUMNS; - use crate::table::NUM_EXT_COLUMNS; + use crate::table::NUM_AUX_COLUMNS; + use crate::table::NUM_MAIN_COLUMNS; use super::*; @@ -3345,15 +3190,15 @@ mod tests { fn instruction_deselector_gives_0_for_all_other_instructions() { let circuit_builder = ConstraintCircuitBuilder::new(); - let mut master_base_table = Array2::zeros([2, NUM_BASE_COLUMNS]); - let master_ext_table = Array2::zeros([2, NUM_EXT_COLUMNS]); + let mut master_base_table = Array2::zeros([2, NUM_MAIN_COLUMNS]); + let master_ext_table = Array2::zeros([2, NUM_AUX_COLUMNS]); // For this test, dummy challenges suffice to evaluate the constraints. let dummy_challenges = (0..ChallengeId::COUNT) .map(|i| XFieldElement::from(i as u64)) .collect_vec(); for instruction in ALL_INSTRUCTIONS { - use ProcessorBaseTableColumn::*; + use ProcessorMainColumn::*; let deselector = instruction_deselector_current_row(&circuit_builder, instruction); println!("\n\nThe Deselector for instruction {instruction} is:\n{deselector}",); @@ -3364,13 +3209,13 @@ mod tests { .filter(|other_instruction| *other_instruction != instruction) { let mut curr_row = master_base_table.slice_mut(s![0, ..]); - curr_row[IB0.master_base_table_index()] = other_instruction.ib(InstructionBit::IB0); - curr_row[IB1.master_base_table_index()] = other_instruction.ib(InstructionBit::IB1); - curr_row[IB2.master_base_table_index()] = other_instruction.ib(InstructionBit::IB2); - curr_row[IB3.master_base_table_index()] = other_instruction.ib(InstructionBit::IB3); - curr_row[IB4.master_base_table_index()] = other_instruction.ib(InstructionBit::IB4); - curr_row[IB5.master_base_table_index()] = other_instruction.ib(InstructionBit::IB5); - curr_row[IB6.master_base_table_index()] = other_instruction.ib(InstructionBit::IB6); + curr_row[IB0.master_main_index()] = other_instruction.ib(InstructionBit::IB0); + curr_row[IB1.master_main_index()] = other_instruction.ib(InstructionBit::IB1); + curr_row[IB2.master_main_index()] = other_instruction.ib(InstructionBit::IB2); + curr_row[IB3.master_main_index()] = other_instruction.ib(InstructionBit::IB3); + curr_row[IB4.master_main_index()] = other_instruction.ib(InstructionBit::IB4); + curr_row[IB5.master_main_index()] = other_instruction.ib(InstructionBit::IB5); + curr_row[IB6.master_main_index()] = other_instruction.ib(InstructionBit::IB6); let result = deselector.clone().consume().evaluate( master_base_table.view(), master_ext_table.view(), @@ -3387,13 +3232,13 @@ mod tests { // Positive tests let mut curr_row = master_base_table.slice_mut(s![0, ..]); - curr_row[IB0.master_base_table_index()] = instruction.ib(InstructionBit::IB0); - curr_row[IB1.master_base_table_index()] = instruction.ib(InstructionBit::IB1); - curr_row[IB2.master_base_table_index()] = instruction.ib(InstructionBit::IB2); - curr_row[IB3.master_base_table_index()] = instruction.ib(InstructionBit::IB3); - curr_row[IB4.master_base_table_index()] = instruction.ib(InstructionBit::IB4); - curr_row[IB5.master_base_table_index()] = instruction.ib(InstructionBit::IB5); - curr_row[IB6.master_base_table_index()] = instruction.ib(InstructionBit::IB6); + curr_row[IB0.master_main_index()] = instruction.ib(InstructionBit::IB0); + curr_row[IB1.master_main_index()] = instruction.ib(InstructionBit::IB1); + curr_row[IB2.master_main_index()] = instruction.ib(InstructionBit::IB2); + curr_row[IB3.master_main_index()] = instruction.ib(InstructionBit::IB3); + curr_row[IB4.master_main_index()] = instruction.ib(InstructionBit::IB4); + curr_row[IB5.master_main_index()] = instruction.ib(InstructionBit::IB5); + curr_row[IB6.master_main_index()] = instruction.ib(InstructionBit::IB6); let result = deselector.consume().evaluate( master_base_table.view(), master_ext_table.view(), @@ -3483,13 +3328,13 @@ mod tests { #[strategy(0_usize..16)] indicator_poly_index: usize, #[strategy(0_u64..16)] query_index: u64, ) { - let mut base_table = Array2::ones([2, NUM_BASE_COLUMNS]); - let aux_table = Array2::ones([2, NUM_EXT_COLUMNS]); + let mut base_table = Array2::ones([2, NUM_MAIN_COLUMNS]); + let aux_table = Array2::ones([2, NUM_AUX_COLUMNS]); - base_table[[0, HV0.master_base_table_index()]] = bfe!(query_index % 2); - base_table[[0, HV1.master_base_table_index()]] = bfe!((query_index >> 1) % 2); - base_table[[0, HV2.master_base_table_index()]] = bfe!((query_index >> 2) % 2); - base_table[[0, HV3.master_base_table_index()]] = bfe!((query_index >> 3) % 2); + base_table[[0, HV0.master_main_index()]] = bfe!(query_index % 2); + base_table[[0, HV1.master_main_index()]] = bfe!((query_index >> 1) % 2); + base_table[[0, HV2.master_main_index()]] = bfe!((query_index >> 2) % 2); + base_table[[0, HV3.master_main_index()]] = bfe!((query_index >> 3) % 2); let builder = ConstraintCircuitBuilder::new(); let indicator_poly = indicator_polynomial(&builder, indicator_poly_index).consume(); @@ -3538,19 +3383,18 @@ mod tests { #[strategy(arb())] b: XFieldElement, ) { let circuit_builder = ConstraintCircuitBuilder::new(); - let main_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; + let main_row = + |col: ProcessorMainColumn| circuit_builder.input(Main(col.master_main_index())); let [x0, x1, x2, y0, y1, y2] = [ST0, ST1, ST2, ST3, ST4, ST5].map(main_row); - let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); - let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); - base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; - base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; - base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; - base_table[[0, ST3.master_base_table_index()]] = b.coefficients[0]; - base_table[[0, ST4.master_base_table_index()]] = b.coefficients[1]; - base_table[[0, ST5.master_base_table_index()]] = b.coefficients[2]; + let mut base_table = Array2::zeros([1, NUM_MAIN_COLUMNS]); + let ext_table = Array2::zeros([1, NUM_AUX_COLUMNS]); + base_table[[0, ST0.master_main_index()]] = a.coefficients[0]; + base_table[[0, ST1.master_main_index()]] = a.coefficients[1]; + base_table[[0, ST2.master_main_index()]] = a.coefficients[2]; + base_table[[0, ST3.master_main_index()]] = b.coefficients[0]; + base_table[[0, ST4.master_main_index()]] = b.coefficients[1]; + base_table[[0, ST5.master_main_index()]] = b.coefficients[2]; let [c0, c1, c2] = xx_product([x0, x1, x2], [y0, y1, y2]) .map(|c| c.consume()) @@ -3569,17 +3413,16 @@ mod tests { #[strategy(arb())] b: BFieldElement, ) { let circuit_builder = ConstraintCircuitBuilder::new(); - let base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(BaseRow(col.master_base_table_index())) - }; - let [x0, x1, x2, y] = [ST0, ST1, ST2, ST3].map(base_row); - - let mut base_table = Array2::zeros([1, NUM_BASE_COLUMNS]); - let ext_table = Array2::zeros([1, NUM_EXT_COLUMNS]); - base_table[[0, ST0.master_base_table_index()]] = a.coefficients[0]; - base_table[[0, ST1.master_base_table_index()]] = a.coefficients[1]; - base_table[[0, ST2.master_base_table_index()]] = a.coefficients[2]; - base_table[[0, ST3.master_base_table_index()]] = b; + let main_row = + |col: ProcessorMainColumn| circuit_builder.input(Main(col.master_main_index())); + let [x0, x1, x2, y] = [ST0, ST1, ST2, ST3].map(main_row); + + let mut base_table = Array2::zeros([1, NUM_MAIN_COLUMNS]); + let ext_table = Array2::zeros([1, NUM_AUX_COLUMNS]); + base_table[[0, ST0.master_main_index()]] = a.coefficients[0]; + base_table[[0, ST1.master_main_index()]] = a.coefficients[1]; + base_table[[0, ST2.master_main_index()]] = a.coefficients[2]; + base_table[[0, ST3.master_main_index()]] = b; let [c0, c1, c2] = xb_product([x0, x1, x2], y) .map(|c| c.consume()) diff --git a/triton-air/src/table/program.rs b/triton-air/src/table/program.rs index f4e67490b..d85544f5e 100644 --- a/triton-air/src/table/program.rs +++ b/triton-air/src/table/program.rs @@ -1,49 +1,47 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use twenty_first::prelude::*; use crate::challenge_id::ChallengeId; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::EvalArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ProgramTable; impl AIR for ProgramTable { - type MainColumn = crate::table_column::ProgramBaseTableColumn; - type AuxColumn = crate::table_column::ProgramExtTableColumn; + type MainColumn = crate::table_column::ProgramMainColumn; + type AuxColumn = crate::table_column::ProgramAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let challenge = |c| circuit_builder.challenge(c); let x_constant = |xfe| circuit_builder.x_constant(xfe); - let base_row = - |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); - let ext_row = - |col: Self::AuxColumn| circuit_builder.input(ExtRow(col.master_ext_table_index())); - - let address = base_row(Self::MainColumn::Address); - let instruction = base_row(Self::MainColumn::Instruction); - let index_in_chunk = base_row(Self::MainColumn::IndexInChunk); - let is_hash_input_padding = base_row(Self::MainColumn::IsHashInputPadding); + let main_row = |col: Self::MainColumn| circuit_builder.input(Main(col.master_main_index())); + let aux_row = |col: Self::AuxColumn| circuit_builder.input(Aux(col.master_aux_index())); + + let address = main_row(Self::MainColumn::Address); + let instruction = main_row(Self::MainColumn::Instruction); + let index_in_chunk = main_row(Self::MainColumn::IndexInChunk); + let is_hash_input_padding = main_row(Self::MainColumn::IsHashInputPadding); let instruction_lookup_log_derivative = - ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + aux_row(Self::AuxColumn::InstructionLookupServerLogDerivative); let prepare_chunk_running_evaluation = - ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); - let send_chunk_running_evaluation = ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + aux_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + let send_chunk_running_evaluation = aux_row(Self::AuxColumn::SendChunkRunningEvaluation); let lookup_arg_initial = x_constant(LookupArg::default_initial()); let eval_arg_initial = x_constant(EvalArg::default_initial()); @@ -80,8 +78,7 @@ impl AIR for ProgramTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let main_row = - |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + let main_row = |col: Self::MainColumn| circuit_builder.input(Main(col.master_main_index())); let one = constant(1); let max_index_in_chunk = constant((Tip5::RATE - 1).try_into().unwrap()); @@ -117,17 +114,14 @@ impl AIR for ProgramTable { let challenge = |c| circuit_builder.challenge(c); let constant = |c: u64| circuit_builder.b_constant(c); - let current_base_row = |col: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: Self::MainColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let current_ext_row = |col: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = - |col: Self::AuxColumn| circuit_builder.input(NextExtRow(col.master_ext_table_index())); + let current_main_row = + |col: Self::MainColumn| circuit_builder.input(CurrentMain(col.master_main_index())); + let next_main_row = + |col: Self::MainColumn| circuit_builder.input(NextMain(col.master_main_index())); + let current_aux_row = + |col: Self::AuxColumn| circuit_builder.input(CurrentAux(col.master_aux_index())); + let next_aux_row = + |col: Self::AuxColumn| circuit_builder.input(NextAux(col.master_aux_index())); let one = constant(1); let rate_minus_one = constant(u64::try_from(Tip5::RATE).unwrap() - 1); @@ -137,33 +131,33 @@ impl AIR for ProgramTable { let send_chunk_indeterminate = challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate); - let address = current_base_row(Self::MainColumn::Address); - let instruction = current_base_row(Self::MainColumn::Instruction); - let lookup_multiplicity = current_base_row(Self::MainColumn::LookupMultiplicity); - let index_in_chunk = current_base_row(Self::MainColumn::IndexInChunk); + let address = current_main_row(Self::MainColumn::Address); + let instruction = current_main_row(Self::MainColumn::Instruction); + let lookup_multiplicity = current_main_row(Self::MainColumn::LookupMultiplicity); + let index_in_chunk = current_main_row(Self::MainColumn::IndexInChunk); let max_minus_index_in_chunk_inv = - current_base_row(Self::MainColumn::MaxMinusIndexInChunkInv); - let is_hash_input_padding = current_base_row(Self::MainColumn::IsHashInputPadding); - let is_table_padding = current_base_row(Self::MainColumn::IsTablePadding); - let log_derivative = current_ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + current_main_row(Self::MainColumn::MaxMinusIndexInChunkInv); + let is_hash_input_padding = current_main_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding = current_main_row(Self::MainColumn::IsTablePadding); + let log_derivative = current_aux_row(Self::AuxColumn::InstructionLookupServerLogDerivative); let prepare_chunk_running_evaluation = - current_ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + current_aux_row(Self::AuxColumn::PrepareChunkRunningEvaluation); let send_chunk_running_evaluation = - current_ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + current_aux_row(Self::AuxColumn::SendChunkRunningEvaluation); - let address_next = next_base_row(Self::MainColumn::Address); - let instruction_next = next_base_row(Self::MainColumn::Instruction); - let index_in_chunk_next = next_base_row(Self::MainColumn::IndexInChunk); + let address_next = next_main_row(Self::MainColumn::Address); + let instruction_next = next_main_row(Self::MainColumn::Instruction); + let index_in_chunk_next = next_main_row(Self::MainColumn::IndexInChunk); let max_minus_index_in_chunk_inv_next = - next_base_row(Self::MainColumn::MaxMinusIndexInChunkInv); - let is_hash_input_padding_next = next_base_row(Self::MainColumn::IsHashInputPadding); - let is_table_padding_next = next_base_row(Self::MainColumn::IsTablePadding); + next_main_row(Self::MainColumn::MaxMinusIndexInChunkInv); + let is_hash_input_padding_next = next_main_row(Self::MainColumn::IsHashInputPadding); + let is_table_padding_next = next_main_row(Self::MainColumn::IsTablePadding); let log_derivative_next = - next_ext_row(Self::AuxColumn::InstructionLookupServerLogDerivative); + next_aux_row(Self::AuxColumn::InstructionLookupServerLogDerivative); let prepare_chunk_running_evaluation_next = - next_ext_row(Self::AuxColumn::PrepareChunkRunningEvaluation); + next_aux_row(Self::AuxColumn::PrepareChunkRunningEvaluation); let send_chunk_running_evaluation_next = - next_ext_row(Self::AuxColumn::SendChunkRunningEvaluation); + next_aux_row(Self::AuxColumn::SendChunkRunningEvaluation); let address_increases_by_one = address_next - (address.clone() + one.clone()); let is_table_padding_is_0_or_remains_unchanged = @@ -258,8 +252,7 @@ impl AIR for ProgramTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u64| circuit_builder.b_constant(c); - let main_row = - |col: Self::MainColumn| circuit_builder.input(BaseRow(col.master_base_table_index())); + let main_row = |col: Self::MainColumn| circuit_builder.input(Main(col.master_main_index())); let index_in_chunk = main_row(Self::MainColumn::IndexInChunk); let is_hash_input_padding = main_row(Self::MainColumn::IsHashInputPadding); diff --git a/triton-air/src/table/ram.rs b/triton-air/src/table/ram.rs index ba29141a3..e3a7e071f 100644 --- a/triton-air/src/table/ram.rs +++ b/triton-air/src/table/ram.rs @@ -1,21 +1,21 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use twenty_first::prelude::*; use crate::challenge_id::ChallengeId; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::LookupArg; use crate::cross_table_argument::PermArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; pub const INSTRUCTION_TYPE_WRITE: BFieldElement = BFieldElement::new(0); @@ -26,8 +26,8 @@ pub const PADDING_INDICATOR: BFieldElement = BFieldElement::new(2); pub struct RamTable; impl AIR for RamTable { - type MainColumn = crate::table_column::RamBaseTableColumn; - type AuxColumn = crate::table_column::RamExtTableColumn; + type MainColumn = crate::table_column::RamMainColumn; + type AuxColumn = crate::table_column::RamAuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, @@ -35,12 +35,10 @@ impl AIR for RamTable { let challenge = |c| circuit_builder.challenge(c); let constant = |c| circuit_builder.b_constant(c); let x_constant = |c| circuit_builder.x_constant(c); - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let aux_row = |column: Self::AuxColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); + let aux_row = + |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index())); let first_row_is_padding_row = main_row(Self::MainColumn::InstructionType) - constant(PADDING_INDICATOR); @@ -108,54 +106,51 @@ impl AIR for RamTable { ) -> Vec> { let constant = |c| circuit_builder.b_constant(c); let challenge = |c| circuit_builder.challenge(c); - let curr_base_row = |column: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) - }; - let curr_ext_row = |column: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) - }; - let next_base_row = |column: Self::MainColumn| { - circuit_builder.input(NextBaseRow(column.master_base_table_index())) - }; - let next_ext_row = |column: Self::AuxColumn| { - circuit_builder.input(NextExtRow(column.master_ext_table_index())) + let curr_main_row = |column: Self::MainColumn| { + circuit_builder.input(CurrentMain(column.master_main_index())) }; + let curr_aux_row = + |column: Self::AuxColumn| circuit_builder.input(CurrentAux(column.master_aux_index())); + let next_main_row = + |column: Self::MainColumn| circuit_builder.input(NextMain(column.master_main_index())); + let next_aux_row = + |column: Self::AuxColumn| circuit_builder.input(NextAux(column.master_aux_index())); let one = constant(1_u32.into()); let bezout_challenge = challenge(ChallengeId::RamTableBezoutRelationIndeterminate); - let clock = curr_base_row(Self::MainColumn::CLK); - let ram_pointer = curr_base_row(Self::MainColumn::RamPointer); - let ram_value = curr_base_row(Self::MainColumn::RamValue); - let instruction_type = curr_base_row(Self::MainColumn::InstructionType); + let clock = curr_main_row(Self::MainColumn::CLK); + let ram_pointer = curr_main_row(Self::MainColumn::RamPointer); + let ram_value = curr_main_row(Self::MainColumn::RamValue); + let instruction_type = curr_main_row(Self::MainColumn::InstructionType); let inverse_of_ram_pointer_difference = - curr_base_row(Self::MainColumn::InverseOfRampDifference); - let bcpc0 = curr_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); - let bcpc1 = curr_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); - - let running_product_ram_pointer = curr_ext_row(Self::AuxColumn::RunningProductOfRAMP); - let fd = curr_ext_row(Self::AuxColumn::FormalDerivative); - let bc0 = curr_ext_row(Self::AuxColumn::BezoutCoefficient0); - let bc1 = curr_ext_row(Self::AuxColumn::BezoutCoefficient1); - let rppa = curr_ext_row(Self::AuxColumn::RunningProductPermArg); + curr_main_row(Self::MainColumn::InverseOfRampDifference); + let bcpc0 = curr_main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); + let bcpc1 = curr_main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer = curr_aux_row(Self::AuxColumn::RunningProductOfRAMP); + let fd = curr_aux_row(Self::AuxColumn::FormalDerivative); + let bc0 = curr_aux_row(Self::AuxColumn::BezoutCoefficient0); + let bc1 = curr_aux_row(Self::AuxColumn::BezoutCoefficient1); + let rppa = curr_aux_row(Self::AuxColumn::RunningProductPermArg); let clock_jump_diff_log_derivative = - curr_ext_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); - - let clock_next = next_base_row(Self::MainColumn::CLK); - let ram_pointer_next = next_base_row(Self::MainColumn::RamPointer); - let ram_value_next = next_base_row(Self::MainColumn::RamValue); - let instruction_type_next = next_base_row(Self::MainColumn::InstructionType); - let bcpc0_next = next_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); - let bcpc1_next = next_base_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); - - let running_product_ram_pointer_next = next_ext_row(Self::AuxColumn::RunningProductOfRAMP); - let fd_next = next_ext_row(Self::AuxColumn::FormalDerivative); - let bc0_next = next_ext_row(Self::AuxColumn::BezoutCoefficient0); - let bc1_next = next_ext_row(Self::AuxColumn::BezoutCoefficient1); - let rppa_next = next_ext_row(Self::AuxColumn::RunningProductPermArg); + curr_aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + + let clock_next = next_main_row(Self::MainColumn::CLK); + let ram_pointer_next = next_main_row(Self::MainColumn::RamPointer); + let ram_value_next = next_main_row(Self::MainColumn::RamValue); + let instruction_type_next = next_main_row(Self::MainColumn::InstructionType); + let bcpc0_next = next_main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient0); + let bcpc1_next = next_main_row(Self::MainColumn::BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer_next = next_aux_row(Self::AuxColumn::RunningProductOfRAMP); + let fd_next = next_aux_row(Self::AuxColumn::FormalDerivative); + let bc0_next = next_aux_row(Self::AuxColumn::BezoutCoefficient0); + let bc1_next = next_aux_row(Self::AuxColumn::BezoutCoefficient1); + let rppa_next = next_aux_row(Self::AuxColumn::RunningProductPermArg); let clock_jump_diff_log_derivative_next = - next_ext_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); + next_aux_row(Self::AuxColumn::ClockJumpDifferenceLookupClientLogDerivative); let next_row_is_padding_row = instruction_type_next.clone() - constant(PADDING_INDICATOR).clone(); @@ -260,14 +255,13 @@ impl AIR for RamTable { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let constant = |c: u32| circuit_builder.b_constant(c); - let ext_row = |column: Self::AuxColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; + let aux_row = + |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index())); - let bezout_relation_holds = ext_row(Self::AuxColumn::BezoutCoefficient0) - * ext_row(Self::AuxColumn::RunningProductOfRAMP) - + ext_row(Self::AuxColumn::BezoutCoefficient1) - * ext_row(Self::AuxColumn::FormalDerivative) + let bezout_relation_holds = aux_row(Self::AuxColumn::BezoutCoefficient0) + * aux_row(Self::AuxColumn::RunningProductOfRAMP) + + aux_row(Self::AuxColumn::BezoutCoefficient1) + * aux_row(Self::AuxColumn::FormalDerivative) - constant(1); vec![bezout_relation_holds] diff --git a/triton-air/src/table/u32.rs b/triton-air/src/table/u32.rs index ea3c37464..a86e8d53e 100644 --- a/triton-air/src/table/u32.rs +++ b/triton-air/src/table/u32.rs @@ -1,40 +1,38 @@ use constraint_circuit::ConstraintCircuitBuilder; use constraint_circuit::ConstraintCircuitMonad; use constraint_circuit::DualRowIndicator; -use constraint_circuit::DualRowIndicator::CurrentBaseRow; -use constraint_circuit::DualRowIndicator::CurrentExtRow; -use constraint_circuit::DualRowIndicator::NextBaseRow; -use constraint_circuit::DualRowIndicator::NextExtRow; +use constraint_circuit::DualRowIndicator::CurrentAux; +use constraint_circuit::DualRowIndicator::CurrentMain; +use constraint_circuit::DualRowIndicator::NextAux; +use constraint_circuit::DualRowIndicator::NextMain; use constraint_circuit::InputIndicator; use constraint_circuit::SingleRowIndicator; -use constraint_circuit::SingleRowIndicator::BaseRow; -use constraint_circuit::SingleRowIndicator::ExtRow; +use constraint_circuit::SingleRowIndicator::Aux; +use constraint_circuit::SingleRowIndicator::Main; use isa::instruction::Instruction; use std::ops::Mul; use crate::challenge_id::ChallengeId; use crate::cross_table_argument::CrossTableArg; use crate::cross_table_argument::LookupArg; -use crate::table_column::MasterBaseTableColumn; -use crate::table_column::MasterExtTableColumn; +use crate::table_column::MasterAuxColumn; +use crate::table_column::MasterMainColumn; use crate::AIR; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct U32Table; impl AIR for U32Table { - type MainColumn = crate::table_column::U32BaseTableColumn; - type AuxColumn = crate::table_column::U32ExtTableColumn; + type MainColumn = crate::table_column::U32MainColumn; + type AuxColumn = crate::table_column::U32AuxColumn; fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; - let aux_row = |column: Self::AuxColumn| { - circuit_builder.input(ExtRow(column.master_ext_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); + let aux_row = + |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index())); let challenge = |c| circuit_builder.challenge(c); let one = circuit_builder.b_constant(1); @@ -70,9 +68,8 @@ impl AIR for U32Table { fn consistency_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); let one = || circuit_builder.b_constant(1); let two = || circuit_builder.b_constant(2); @@ -160,17 +157,14 @@ impl AIR for U32Table { circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { let curr_main_row = |column: Self::MainColumn| { - circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) - }; - let next_main_row = |column: Self::MainColumn| { - circuit_builder.input(NextBaseRow(column.master_base_table_index())) - }; - let curr_aux_row = |column: Self::AuxColumn| { - circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) - }; - let next_aux_row = |column: Self::AuxColumn| { - circuit_builder.input(NextExtRow(column.master_ext_table_index())) + circuit_builder.input(CurrentMain(column.master_main_index())) }; + let next_main_row = + |column: Self::MainColumn| circuit_builder.input(NextMain(column.master_main_index())); + let curr_aux_row = + |column: Self::AuxColumn| circuit_builder.input(CurrentAux(column.master_aux_index())); + let next_aux_row = + |column: Self::AuxColumn| circuit_builder.input(NextAux(column.master_aux_index())); let challenge = |c| circuit_builder.challenge(c); let one = || circuit_builder.b_constant(1); let two = || circuit_builder.b_constant(2); @@ -359,9 +353,8 @@ impl AIR for U32Table { fn terminal_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let main_row = |column: Self::MainColumn| { - circuit_builder.input(BaseRow(column.master_base_table_index())) - }; + let main_row = + |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index())); let constant = |c| circuit_builder.b_constant(c); let ci = main_row(Self::MainColumn::CI); diff --git a/triton-air/src/table_column.rs b/triton-air/src/table_column.rs index 400934cb0..5900d3fa7 100644 --- a/triton-air/src/table_column.rs +++ b/triton-air/src/table_column.rs @@ -7,16 +7,16 @@ use strum::Display; use strum::EnumCount; use strum::EnumIter; +use crate::table::AUX_CASCADE_TABLE_START; +use crate::table::AUX_HASH_TABLE_START; +use crate::table::AUX_JUMP_STACK_TABLE_START; +use crate::table::AUX_LOOKUP_TABLE_START; +use crate::table::AUX_OP_STACK_TABLE_START; +use crate::table::AUX_PROCESSOR_TABLE_START; +use crate::table::AUX_PROGRAM_TABLE_START; +use crate::table::AUX_RAM_TABLE_START; +use crate::table::AUX_U32_TABLE_START; use crate::table::CASCADE_TABLE_START; -use crate::table::EXT_CASCADE_TABLE_START; -use crate::table::EXT_HASH_TABLE_START; -use crate::table::EXT_JUMP_STACK_TABLE_START; -use crate::table::EXT_LOOKUP_TABLE_START; -use crate::table::EXT_OP_STACK_TABLE_START; -use crate::table::EXT_PROCESSOR_TABLE_START; -use crate::table::EXT_PROGRAM_TABLE_START; -use crate::table::EXT_RAM_TABLE_START; -use crate::table::EXT_U32_TABLE_START; use crate::table::HASH_TABLE_START; use crate::table::JUMP_STACK_TABLE_START; use crate::table::LOOKUP_TABLE_START; @@ -28,7 +28,7 @@ use crate::table::U32_TABLE_START; #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum ProgramBaseTableColumn { +pub enum ProgramMainColumn { /// An instruction's address. Address, @@ -43,14 +43,14 @@ pub enum ProgramBaseTableColumn { /// In other words: /// [`Address`] modulo [`Rate`]. /// - /// [`Address`]: ProgramBaseTableColumn::Address + /// [`Address`]: ProgramMainColumn::Address /// [`Rate`]: twenty_first::math::tip5::RATE IndexInChunk, /// The inverse-or-zero of [`Rate`] - 1 - [`IndexInChunk`]. /// Helper variable to guarantee [`IndexInChunk`]'s correct transition. /// - /// [`IndexInChunk`]: ProgramBaseTableColumn::IndexInChunk + /// [`IndexInChunk`]: ProgramMainColumn::IndexInChunk /// [`Rate`]: twenty_first::math::tip5::RATE MaxMinusIndexInChunkInv, @@ -63,16 +63,16 @@ pub enum ProgramBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum ProgramExtTableColumn { +pub enum ProgramAuxColumn { /// The server part of the instruction lookup. /// /// The counterpart to [`InstructionLookupClientLogDerivative`][client]. /// - /// [client]: ProcessorExtTableColumn::InstructionLookupClientLogDerivative + /// [client]: ProcessorAuxColumn::InstructionLookupClientLogDerivative InstructionLookupServerLogDerivative, /// An evaluation argument accumulating [`RATE`][rate] many instructions before - /// they are sent using [`SendChunkEvalArg`](ProgramExtTableColumn::SendChunkRunningEvaluation). + /// they are sent using [`SendChunkEvalArg`](ProgramAuxColumn::SendChunkRunningEvaluation). /// Resets to zero after each chunk. /// Relevant for program attestation. /// @@ -84,16 +84,16 @@ pub enum ProgramExtTableColumn { /// This bus is used for sending those chunks to the Hash Table. /// Relevant for program attestation. /// - /// The counterpart to [`RcvChunkEvalArg`](HashExtTableColumn::ReceiveChunkRunningEvaluation). + /// The counterpart to [`RcvChunkEvalArg`](HashAuxColumn::ReceiveChunkRunningEvaluation). /// /// [rate]: twenty_first::math::tip5::RATE - /// [prep]: ProgramExtTableColumn::PrepareChunkRunningEvaluation + /// [prep]: ProgramAuxColumn::PrepareChunkRunningEvaluation SendChunkRunningEvaluation, } #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum ProcessorBaseTableColumn { +pub enum ProcessorMainColumn { CLK, IsPadding, IP, @@ -138,7 +138,7 @@ pub enum ProcessorBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum ProcessorExtTableColumn { +pub enum ProcessorAuxColumn { InputTableEvalArg, OutputTableEvalArg, InstructionLookupClientLogDerivative, @@ -164,7 +164,7 @@ pub enum ProcessorExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum OpStackBaseTableColumn { +pub enum OpStackMainColumn { CLK, IB1ShrinkStack, StackPointer, @@ -173,7 +173,7 @@ pub enum OpStackBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum OpStackExtTableColumn { +pub enum OpStackAuxColumn { RunningProductPermArg, /// The (running sum of the) logarithmic derivative for the clock jump difference Lookup /// Argument with the Processor Table. @@ -182,7 +182,7 @@ pub enum OpStackExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum RamBaseTableColumn { +pub enum RamMainColumn { CLK, /// Is [`INSTRUCTION_TYPE_READ`] for instruction `read_mem` and [`INSTRUCTION_TYPE_WRITE`] @@ -201,7 +201,7 @@ pub enum RamBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum RamExtTableColumn { +pub enum RamAuxColumn { RunningProductOfRAMP, FormalDerivative, BezoutCoefficient0, @@ -214,7 +214,7 @@ pub enum RamExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum JumpStackBaseTableColumn { +pub enum JumpStackMainColumn { CLK, CI, JSP, @@ -224,7 +224,7 @@ pub enum JumpStackBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum JumpStackExtTableColumn { +pub enum JumpStackAuxColumn { RunningProductPermArg, /// The (running sum of the) logarithmic derivative for the clock jump difference Lookup /// Argument with the Processor Table. @@ -233,7 +233,7 @@ pub enum JumpStackExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum HashBaseTableColumn { +pub enum HashMainColumn { /// The indicator for the [`HashTableMode`][mode]. /// /// [mode]: crate::table::hash::HashTableMode @@ -242,7 +242,7 @@ pub enum HashBaseTableColumn { /// The current instruction. Only relevant for [`Mode`][mode] [`Sponge`][mode_sponge] /// in order to distinguish between the different Sponge instructions. /// - /// [mode]: HashBaseTableColumn::Mode + /// [mode]: HashMainColumn::Mode /// [mode_sponge]: crate::table::hash::HashTableMode::Sponge CI, @@ -253,8 +253,8 @@ pub enum HashBaseTableColumn { /// `sponge_init`, as an exception to above rule, and /// - 0 → 0 in [`Mode`][mode] [`Pad`][mode_pad]. /// - /// [ci]: HashBaseTableColumn::CI - /// [mode]: HashBaseTableColumn::Mode + /// [ci]: HashMainColumn::CI + /// [mode]: HashMainColumn::Mode /// [mode_prog_hash]: crate::table::hash::HashTableMode::ProgramHashing /// [mode_sponge]: crate::table::hash::HashTableMode::Sponge /// [mode_hash]: crate::table::hash::HashTableMode::Hash @@ -331,12 +331,12 @@ pub enum HashBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum HashExtTableColumn { +pub enum HashAuxColumn { /// The evaluation argument corresponding to receiving instructions in chunks of size /// [`RATE`][rate]. The chunks are hashed in Sponge mode. /// This allows program attestation. /// - /// The counterpart to [`SendChunkEvalArg`](ProgramExtTableColumn::SendChunkRunningEvaluation). + /// The counterpart to [`SendChunkEvalArg`](ProgramAuxColumn::SendChunkRunningEvaluation). /// /// [rate]: twenty_first::math::tip5::RATE ReceiveChunkRunningEvaluation, @@ -369,7 +369,7 @@ pub enum HashExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum CascadeBaseTableColumn { +pub enum CascadeMainColumn { /// Indicator for padding rows. IsPadding, @@ -391,7 +391,7 @@ pub enum CascadeBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum CascadeExtTableColumn { +pub enum CascadeAuxColumn { /// The (running sum of the) logarithmic derivative for the Lookup Argument with the Hash Table. /// In every row, the sum accumulates `LookupMultiplicity / (X - Combo)` where `X` is a /// verifier-supplied challenge and `Combo` is the weighted sum of @@ -411,7 +411,7 @@ pub enum CascadeExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum LookupBaseTableColumn { +pub enum LookupMainColumn { /// Indicator for padding rows. IsPadding, @@ -427,7 +427,7 @@ pub enum LookupBaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum LookupExtTableColumn { +pub enum LookupAuxColumn { /// The (running sum of the) logarithmic derivative for the Lookup Argument with the Cascade /// Table. In every row, accumulates the summand `LookupMultiplicity / Combo` where `Combo` is /// the verifier-weighted combination of `LookIn` and `LookOut`. @@ -440,7 +440,7 @@ pub enum LookupExtTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum U32BaseTableColumn { +pub enum U32MainColumn { /// Marks the beginning of an independent section within the U32 table. CopyFlag, @@ -477,7 +477,7 @@ pub enum U32BaseTableColumn { #[repr(usize)] #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] -pub enum U32ExtTableColumn { +pub enum U32AuxColumn { /// The (running sum of the) logarithmic derivative for the Lookup Argument with the /// Processor Table. LookupServerLogDerivative, @@ -488,119 +488,119 @@ pub enum U32ExtTableColumn { /// - one to get the index of the column in the “local” base table, _i.e., not the master base /// table, and /// - one to get the index of the column in the master base table. -pub trait MasterBaseTableColumn { +pub trait MasterMainColumn { /// The index of the column in the “local” base table, _i.e., not the master base table. - fn base_table_index(&self) -> usize; + fn main_index(&self) -> usize; /// The index of the column in the master base table. - fn master_base_table_index(&self) -> usize; + fn master_main_index(&self) -> usize; } -impl MasterBaseTableColumn for ProgramBaseTableColumn { +impl MasterMainColumn for ProgramMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - PROGRAM_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + PROGRAM_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for ProcessorBaseTableColumn { +impl MasterMainColumn for ProcessorMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - PROCESSOR_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + PROCESSOR_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for OpStackBaseTableColumn { +impl MasterMainColumn for OpStackMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - OP_STACK_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + OP_STACK_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for RamBaseTableColumn { +impl MasterMainColumn for RamMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - RAM_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + RAM_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for JumpStackBaseTableColumn { +impl MasterMainColumn for JumpStackMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - JUMP_STACK_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + JUMP_STACK_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for HashBaseTableColumn { +impl MasterMainColumn for HashMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - HASH_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + HASH_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for CascadeBaseTableColumn { +impl MasterMainColumn for CascadeMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - CASCADE_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + CASCADE_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for LookupBaseTableColumn { +impl MasterMainColumn for LookupMainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - LOOKUP_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + LOOKUP_TABLE_START + self.main_index() } } -impl MasterBaseTableColumn for U32BaseTableColumn { +impl MasterMainColumn for U32MainColumn { #[inline] - fn base_table_index(&self) -> usize { + fn main_index(&self) -> usize { (*self) as usize } #[inline] - fn master_base_table_index(&self) -> usize { - U32_TABLE_START + self.base_table_index() + fn master_main_index(&self) -> usize { + U32_TABLE_START + self.main_index() } } @@ -609,120 +609,120 @@ impl MasterBaseTableColumn for U32BaseTableColumn { /// - one to get the index of the column in the “local” extension table, _i.e._, not the master /// extension table, and /// - one to get the index of the column in the master extension table. -pub trait MasterExtTableColumn { +pub trait MasterAuxColumn { /// The index of the column in the “local” extension table, _i.e._, not the master extension /// table. - fn ext_table_index(&self) -> usize; + fn aux_index(&self) -> usize; /// The index of the column in the master extension table. - fn master_ext_table_index(&self) -> usize; + fn master_aux_index(&self) -> usize; } -impl MasterExtTableColumn for ProgramExtTableColumn { +impl MasterAuxColumn for ProgramAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_PROGRAM_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_PROGRAM_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for ProcessorExtTableColumn { +impl MasterAuxColumn for ProcessorAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_PROCESSOR_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_PROCESSOR_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for OpStackExtTableColumn { +impl MasterAuxColumn for OpStackAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_OP_STACK_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_OP_STACK_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for RamExtTableColumn { +impl MasterAuxColumn for RamAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_RAM_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_RAM_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for JumpStackExtTableColumn { +impl MasterAuxColumn for JumpStackAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_JUMP_STACK_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_JUMP_STACK_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for HashExtTableColumn { +impl MasterAuxColumn for HashAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_HASH_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_HASH_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for CascadeExtTableColumn { +impl MasterAuxColumn for CascadeAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_CASCADE_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_CASCADE_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for LookupExtTableColumn { +impl MasterAuxColumn for LookupAuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_LOOKUP_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_LOOKUP_TABLE_START + self.aux_index() } } -impl MasterExtTableColumn for U32ExtTableColumn { +impl MasterAuxColumn for U32AuxColumn { #[inline] - fn ext_table_index(&self) -> usize { + fn aux_index(&self) -> usize { (*self) as usize } #[inline] - fn master_ext_table_index(&self) -> usize { - EXT_U32_TABLE_START + self.ext_table_index() + fn master_aux_index(&self) -> usize { + AUX_U32_TABLE_START + self.aux_index() } } @@ -735,40 +735,40 @@ mod tests { #[test] fn master_base_table_is_contiguous() { let mut expected_column_index = 0; - for column in ProgramBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in ProgramMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in ProcessorBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in ProcessorMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in OpStackBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in OpStackMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in RamBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in RamMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in JumpStackBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in JumpStackMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in HashBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in HashMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in CascadeBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in CascadeMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in LookupBaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in LookupMainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } - for column in U32BaseTableColumn::iter() { - assert_eq!(expected_column_index, column.master_base_table_index()); + for column in U32MainColumn::iter() { + assert_eq!(expected_column_index, column.master_main_index()); expected_column_index += 1; } } @@ -776,40 +776,40 @@ mod tests { #[test] fn master_ext_table_is_contiguous() { let mut expected_column_index = 0; - for column in ProgramExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in ProgramAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in ProcessorExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in ProcessorAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in OpStackExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in OpStackAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in RamExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in RamAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in JumpStackExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in JumpStackAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in HashExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in HashAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in CascadeExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in CascadeAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in LookupExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in LookupAuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } - for column in U32ExtTableColumn::iter() { - assert_eq!(expected_column_index, column.master_ext_table_index()); + for column in U32AuxColumn::iter() { + assert_eq!(expected_column_index, column.master_aux_index()); expected_column_index += 1; } } diff --git a/triton-constraint-builder/src/codegen.rs b/triton-constraint-builder/src/codegen.rs index 15b9dd00c..675f39357 100644 --- a/triton-constraint-builder/src/codegen.rs +++ b/triton-constraint-builder/src/codegen.rs @@ -88,7 +88,7 @@ impl Codegen for RustBackend { ); let quotient_trait_impl = quote!( - impl MasterExtTable { + impl MasterAuxTable { pub const NUM_INITIAL_CONSTRAINTS: usize = #num_init_constraints; pub const NUM_CONSISTENCY_CONSTRAINTS: usize = #num_cons_constraints; pub const NUM_TRANSITION_CONSTRAINTS: usize = #num_tran_constraints; @@ -147,11 +147,11 @@ impl RustBackend { field: TokenStream, ) -> TokenStream { quote!( - impl Evaluable<#field> for MasterExtTable { + impl Evaluable<#field> for MasterAuxTable { #[allow(unused_variables)] fn evaluate_initial_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, + main_row: ArrayView1<#field>, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec { #init_constraints @@ -159,8 +159,8 @@ impl RustBackend { #[allow(unused_variables)] fn evaluate_consistency_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, + main_row: ArrayView1<#field>, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec { #cons_constraints @@ -168,10 +168,10 @@ impl RustBackend { #[allow(unused_variables)] fn evaluate_transition_constraints( - current_base_row: ArrayView1<#field>, - current_ext_row: ArrayView1, - next_base_row: ArrayView1<#field>, - next_ext_row: ArrayView1, + current_main_row: ArrayView1<#field>, + current_aux_row: ArrayView1, + next_main_row: ArrayView1<#field>, + next_aux_row: ArrayView1, challenges: &Challenges, ) -> Vec { #tran_constraints @@ -179,8 +179,8 @@ impl RustBackend { #[allow(unused_variables)] fn evaluate_terminal_constraints( - base_row: ArrayView1<#field>, - ext_row: ArrayView1, + main_row: ArrayView1<#field>, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec { #term_constraints @@ -205,16 +205,16 @@ impl RustBackend { let mut backend = Self::default(); let shared_declarations = backend.declare_shared_nodes(constraints); - let (base_constraints, ext_constraints): (Vec<_>, Vec<_>) = constraints + let (main_constraints, aux_constraints): (Vec<_>, Vec<_>) = constraints .iter() .partition(|constraint| constraint.evaluates_to_base_element()); // The order of the constraints' degrees must match the order of the constraints. // Hence, listing the degrees is only possible after the partition into base and extension // constraints is known. - let tokenized_degree_bounds = base_constraints + let tokenized_degree_bounds = main_constraints .iter() - .chain(&ext_constraints) + .chain(&aux_constraints) .map(|circuit| match circuit.degree() { d if d > 1 => quote!(interpolant_degree * #d - zerofier_degree), 1 => quote!(interpolant_degree - zerofier_degree), @@ -229,32 +229,32 @@ impl RustBackend { .map(|constraint| backend.evaluate_single_node(constraint)) .collect_vec() }; - let tokenized_base_constraints = tokenize_constraint_evaluation(&base_constraints); - let tokenized_ext_constraints = tokenize_constraint_evaluation(&ext_constraints); + let tokenized_main_constraints = tokenize_constraint_evaluation(&main_constraints); + let tokenized_aux_constraints = tokenize_constraint_evaluation(&aux_constraints); // If there are no base constraints, the type needs to be explicitly declared. - let tokenized_bfe_base_constraints = match base_constraints.is_empty() { - true => quote!(let base_constraints: [BFieldElement; 0] = []), - false => quote!(let base_constraints = [#(#tokenized_base_constraints),*]), + let tokenized_bfe_main_constraints = match main_constraints.is_empty() { + true => quote!(let main_constraints: [BFieldElement; 0] = []), + false => quote!(let main_constraints = [#(#tokenized_main_constraints),*]), }; let tokenized_bfe_constraints = quote!( #(#shared_declarations)* - #tokenized_bfe_base_constraints; - let ext_constraints = [#(#tokenized_ext_constraints),*]; - base_constraints + #tokenized_bfe_main_constraints; + let aux_constraints = [#(#tokenized_aux_constraints),*]; + main_constraints .into_iter() .map(|bfe| bfe.lift()) - .chain(ext_constraints) + .chain(aux_constraints) .collect() ); let tokenized_xfe_constraints = quote!( #(#shared_declarations)* - let base_constraints = [#(#tokenized_base_constraints),*]; - let ext_constraints = [#(#tokenized_ext_constraints),*]; - base_constraints + let main_constraints = [#(#tokenized_main_constraints),*]; + let aux_constraints = [#(#tokenized_aux_constraints),*]; + main_constraints .into_iter() - .chain(ext_constraints) + .chain(aux_constraints) .collect() ); @@ -455,10 +455,10 @@ impl Codegen for TasmBackend { mem_layout: StaticTasmConstraintEvaluationMemoryLayout, ) -> Vec { let free_mem_page_ptr = mem_layout.free_mem_page_ptr.value(); - let curr_base_row_ptr = mem_layout.curr_base_row_ptr.value(); - let curr_ext_row_ptr = mem_layout.curr_ext_row_ptr.value(); - let next_base_row_ptr = mem_layout.next_base_row_ptr.value(); - let next_ext_row_ptr = mem_layout.next_ext_row_ptr.value(); + let curr_main_row_ptr = mem_layout.curr_main_row_ptr.value(); + let curr_aux_row_ptr = mem_layout.curr_aux_row_ptr.value(); + let next_main_row_ptr = mem_layout.next_main_row_ptr.value(); + let next_aux_row_ptr = mem_layout.next_aux_row_ptr.value(); let challenges_ptr = mem_layout.challenges_ptr.value(); let raw_instructions = vec![ @@ -497,10 +497,10 @@ impl Codegen for TasmBackend { ) -> Vec { let num_pointer_pointers = 4; let free_mem_page_ptr = mem_layout.free_mem_page_ptr.value() + num_pointer_pointers; - let curr_base_row_ptr = mem_layout.free_mem_page_ptr.value(); - let curr_ext_row_ptr = mem_layout.free_mem_page_ptr.value() + 1; - let next_base_row_ptr = mem_layout.free_mem_page_ptr.value() + 2; - let next_ext_row_ptr = mem_layout.free_mem_page_ptr.value() + 3; + let curr_main_row_ptr = mem_layout.free_mem_page_ptr.value(); + let curr_aux_row_ptr = mem_layout.free_mem_page_ptr.value() + 1; + let next_main_row_ptr = mem_layout.free_mem_page_ptr.value() + 2; + let next_aux_row_ptr = mem_layout.free_mem_page_ptr.value() + 3; let challenges_ptr = mem_layout.challenges_ptr.value(); let raw_instructions = vec![ @@ -586,11 +586,11 @@ impl TasmBackend { [integral]: crate::memory_layout::IntegralMemoryLayout::is_integral [xfe]: twenty_first::prelude::XFieldElement - [total]: crate::table::master_table::MasterExtTable::NUM_CONSTRAINTS - [init]: crate::table::master_table::MasterExtTable::NUM_INITIAL_CONSTRAINTS - [cons]: crate::table::master_table::MasterExtTable::NUM_CONSISTENCY_CONSTRAINTS - [tran]: crate::table::master_table::MasterExtTable::NUM_TRANSITION_CONSTRAINTS - [term]: crate::table::master_table::MasterExtTable::NUM_TERMINAL_CONSTRAINTS + [total]: crate::table::master_table::MasterAuxTable::NUM_CONSTRAINTS + [init]: crate::table::master_table::MasterAuxTable::NUM_INITIAL_CONSTRAINTS + [cons]: crate::table::master_table::MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS + [tran]: crate::table::master_table::MasterAuxTable::NUM_TRANSITION_CONSTRAINTS + [term]: crate::table::master_table::MasterAuxTable::NUM_TERMINAL_CONSTRAINTS " } @@ -628,11 +628,11 @@ impl TasmBackend { [integral]: crate::memory_layout::IntegralMemoryLayout::is_integral [xfe]: twenty_first::prelude::XFieldElement - [total]: crate::table::master_table::MasterExtTable::NUM_CONSTRAINTS - [init]: crate::table::master_table::MasterExtTable::NUM_INITIAL_CONSTRAINTS - [cons]: crate::table::master_table::MasterExtTable::NUM_CONSISTENCY_CONSTRAINTS - [tran]: crate::table::master_table::MasterExtTable::NUM_TRANSITION_CONSTRAINTS - [term]: crate::table::master_table::MasterExtTable::NUM_TERMINAL_CONSTRAINTS + [total]: crate::table::master_table::MasterAuxTable::NUM_CONSTRAINTS + [init]: crate::table::master_table::MasterAuxTable::NUM_INITIAL_CONSTRAINTS + [cons]: crate::table::master_table::MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS + [tran]: crate::table::master_table::MasterAuxTable::NUM_TRANSITION_CONSTRAINTS + [term]: crate::table::master_table::MasterAuxTable::NUM_TERMINAL_CONSTRAINTS " } @@ -670,10 +670,10 @@ impl TasmBackend { let store_shared_nodes = self.store_all_shared_nodes(constraints); // to match the `RustBackend`, base constraints must be emitted first - let (base_constraints, ext_constraints): (Vec<_>, Vec<_>) = constraints + let (main_constraints, aux_constraints): (Vec<_>, Vec<_>) = constraints .iter() .partition(|constraint| constraint.evaluates_to_base_element()); - let sorted_constraints = base_constraints.into_iter().chain(ext_constraints); + let sorted_constraints = main_constraints.into_iter().chain(aux_constraints); let write_to_output = sorted_constraints .map(|c| self.write_evaluated_constraint_into_output_list(c)) .concat(); @@ -779,7 +779,7 @@ impl TasmBackend { } fn load_input(&self, input: II) -> Vec { - let list = match (input.is_current_row(), input.is_base_table_column()) { + let list = match (input.is_current_row(), input.is_main_table_column()) { (true, true) => IOList::CurrBaseRow, (true, false) => IOList::CurrExtRow, (false, true) => IOList::NextBaseRow, @@ -870,10 +870,10 @@ impl ToTokens for IOList { fn to_tokens(&self, tokens: &mut TokenStream) { match self { IOList::FreeMemPage => tokens.extend(quote!(free_mem_page_ptr)), - IOList::CurrBaseRow => tokens.extend(quote!(curr_base_row_ptr)), - IOList::CurrExtRow => tokens.extend(quote!(curr_ext_row_ptr)), - IOList::NextBaseRow => tokens.extend(quote!(next_base_row_ptr)), - IOList::NextExtRow => tokens.extend(quote!(next_ext_row_ptr)), + IOList::CurrBaseRow => tokens.extend(quote!(curr_main_row_ptr)), + IOList::CurrExtRow => tokens.extend(quote!(curr_aux_row_ptr)), + IOList::NextBaseRow => tokens.extend(quote!(next_main_row_ptr)), + IOList::NextExtRow => tokens.extend(quote!(next_aux_row_ptr)), IOList::Challenges => tokens.extend(quote!(challenges_ptr)), } } @@ -891,10 +891,10 @@ mod tests { let circuit_builder = ConstraintCircuitBuilder::new(); let challenge = |c: usize| circuit_builder.challenge(c); let constant = |c: u32| circuit_builder.x_constant(c); - let base_row = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); - let ext_row = |i| circuit_builder.input(SingleRowIndicator::ExtRow(i)); + let main_row = |i| circuit_builder.input(SingleRowIndicator::Main(i)); + let aux_row = |i| circuit_builder.input(SingleRowIndicator::Aux(i)); - let constraint = base_row(0) * challenge(3) - ext_row(1) * constant(42); + let constraint = main_row(0) * challenge(3) - aux_row(1) * constant(42); Constraints { init: vec![constraint], diff --git a/triton-constraint-builder/src/lib.rs b/triton-constraint-builder/src/lib.rs index fefbfdd6a..d6bc6c565 100644 --- a/triton-constraint-builder/src/lib.rs +++ b/triton-constraint-builder/src/lib.rs @@ -145,31 +145,31 @@ impl Constraints { let (init_base_substitutions, init_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree(&mut self.init, info); - info.num_base_cols += init_base_substitutions.len(); - info.num_ext_cols += init_ext_substitutions.len(); + info.num_main_cols += init_base_substitutions.len(); + info.num_aux_cols += init_ext_substitutions.len(); let (cons_base_substitutions, cons_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree(&mut self.cons, info); - info.num_base_cols += cons_base_substitutions.len(); - info.num_ext_cols += cons_ext_substitutions.len(); + info.num_main_cols += cons_base_substitutions.len(); + info.num_aux_cols += cons_ext_substitutions.len(); let (tran_base_substitutions, tran_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree(&mut self.tran, info); - info.num_base_cols += tran_base_substitutions.len(); - info.num_ext_cols += tran_ext_substitutions.len(); + info.num_main_cols += tran_base_substitutions.len(); + info.num_aux_cols += tran_ext_substitutions.len(); let (term_base_substitutions, term_ext_substitutions) = ConstraintCircuitMonad::lower_to_degree(&mut self.term, info); AllSubstitutions { - base: Substitutions { + main: Substitutions { lowering_info, init: init_base_substitutions, cons: cons_base_substitutions, tran: tran_base_substitutions, term: term_base_substitutions, }, - ext: Substitutions { + aux: Substitutions { lowering_info, init: init_ext_substitutions, cons: cons_ext_substitutions, @@ -182,7 +182,10 @@ impl Constraints { #[must_use] pub fn combine_with_substitution_induced_constraints( self, - AllSubstitutions { base, ext }: AllSubstitutions, + AllSubstitutions { + main: base, + aux: ext, + }: AllSubstitutions, ) -> Self { Self { init: [self.init, base.init, ext.init].concat(), @@ -239,8 +242,8 @@ mod tests { fn degree_lowering_info() -> DegreeLoweringInfo { DegreeLoweringInfo { target_degree: 4, - num_base_cols: 42, - num_ext_cols: 13, + num_main_cols: 42, + num_aux_cols: 13, } } @@ -301,7 +304,7 @@ mod tests { let circuit_builder = ConstraintCircuitBuilder::new(); let challenge = |c| circuit_builder.challenge(c); let constant = |c: u32| circuit_builder.b_constant(bfe!(c)); - let input = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i)); + let input = |i| circuit_builder.input(SingleRowIndicator::Main(i)); let input_to_the_4th = |i| input(i) * input(i) * input(i) * input(i); vec![ @@ -316,10 +319,10 @@ mod tests { let challenge = |c| circuit_builder.challenge(c); let constant = |c: u32| circuit_builder.x_constant(c); - let curr_b_row = |col| circuit_builder.input(DualRowIndicator::CurrentBaseRow(col)); - let next_b_row = |col| circuit_builder.input(DualRowIndicator::NextBaseRow(col)); - let curr_x_row = |col| circuit_builder.input(DualRowIndicator::CurrentExtRow(col)); - let next_x_row = |col| circuit_builder.input(DualRowIndicator::NextExtRow(col)); + let curr_b_row = |col| circuit_builder.input(DualRowIndicator::CurrentMain(col)); + let next_b_row = |col| circuit_builder.input(DualRowIndicator::NextMain(col)); + let curr_x_row = |col| circuit_builder.input(DualRowIndicator::CurrentAux(col)); + let next_x_row = |col| circuit_builder.input(DualRowIndicator::NextAux(col)); vec![ curr_b_row(0) * next_x_row(1) - next_b_row(1) * curr_x_row(0), diff --git a/triton-constraint-builder/src/substitutions.rs b/triton-constraint-builder/src/substitutions.rs index dc2df5dc4..e50109fd8 100644 --- a/triton-constraint-builder/src/substitutions.rs +++ b/triton-constraint-builder/src/substitutions.rs @@ -15,8 +15,8 @@ use crate::codegen::RustBackend; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct AllSubstitutions { - pub base: Substitutions, - pub ext: Substitutions, + pub main: Substitutions, + pub aux: Substitutions, } #[derive(Debug, Clone, Eq, PartialEq, Hash)] @@ -32,30 +32,30 @@ impl AllSubstitutions { /// Generate code that evaluates all substitution rules in order. /// This includes generating the columns that are to be filled using the substitution rules. pub fn generate_degree_lowering_table_code(&self) -> TokenStream { - let num_new_base_cols = self.base.len(); - let num_new_ext_cols = self.ext.len(); + let num_new_main_cols = self.main.len(); + let num_new_aux_cols = self.aux.len(); // A zero-variant enum cannot be annotated with `repr(usize)`. - let base_repr_usize = match num_new_base_cols { + let main_repr_usize = match num_new_main_cols { 0 => quote!(), _ => quote!(#[repr(usize)]), }; - let ext_repr_usize = match num_new_ext_cols { + let aux_repr_usize = match num_new_aux_cols { 0 => quote!(), _ => quote!(#[repr(usize)]), }; - let base_columns = (0..num_new_base_cols) - .map(|i| format_ident!("DegreeLoweringBaseCol{i}")) + let main_columns = (0..num_new_main_cols) + .map(|i| format_ident!("DegreeLoweringMainCol{i}")) .map(|ident| quote!(#ident)) .collect_vec(); - let ext_columns = (0..num_new_ext_cols) - .map(|i| format_ident!("DegreeLoweringExtCol{i}")) + let aux_columns = (0..num_new_aux_cols) + .map(|i| format_ident!("DegreeLoweringAuxCol{i}")) .map(|ident| quote!(#ident)) .collect_vec(); - let fill_base_columns_code = self.base.generate_fill_base_columns_code(); - let fill_ext_columns_code = self.ext.generate_fill_ext_columns_code(); + let fill_main_columns_code = self.main.generate_fill_main_columns_code(); + let fill_aux_columns_code = self.aux.generate_fill_aux_columns_code(); quote!( use ndarray::Array1; @@ -69,43 +69,43 @@ impl AllSubstitutions { use strum::EnumIter; use twenty_first::prelude::BFieldElement; use twenty_first::prelude::XFieldElement; - use air::table_column::MasterBaseTableColumn; - use air::table_column::MasterExtTableColumn; + use air::table_column::MasterMainColumn; + use air::table_column::MasterAuxColumn; use crate::challenges::Challenges; use crate::table::master_table::MasterTable; - #base_repr_usize + #main_repr_usize #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] - pub enum DegreeLoweringBaseTableColumn { - #(#base_columns),* + pub enum DegreeLoweringMainColumn { + #(#main_columns),* } - impl MasterBaseTableColumn for DegreeLoweringBaseTableColumn { - fn base_table_index(&self) -> usize { + impl MasterMainColumn for DegreeLoweringMainColumn { + fn main_index(&self) -> usize { (*self) as usize } - fn master_base_table_index(&self) -> usize { + fn master_main_index(&self) -> usize { // hardcore domain-specific knowledge, and bad style - air::table::U32_TABLE_END + self.base_table_index() + air::table::U32_TABLE_END + self.main_index() } } - #ext_repr_usize + #aux_repr_usize #[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)] - pub enum DegreeLoweringExtTableColumn { - #(#ext_columns),* + pub enum DegreeLoweringAuxColumn { + #(#aux_columns),* } - impl MasterExtTableColumn for DegreeLoweringExtTableColumn { - fn ext_table_index(&self) -> usize { + impl MasterAuxColumn for DegreeLoweringAuxColumn { + fn aux_index(&self) -> usize { (*self) as usize } - fn master_ext_table_index(&self) -> usize { + fn master_aux_index(&self) -> usize { // hardcore domain-specific knowledge, and bad style - air::table::EXT_U32_TABLE_END + self.ext_table_index() + air::table::AUX_U32_TABLE_END + self.aux_index() } } @@ -113,8 +113,8 @@ impl AllSubstitutions { pub struct DegreeLoweringTable; impl DegreeLoweringTable { - #fill_base_columns_code - #fill_ext_columns_code + #fill_main_columns_code + #fill_aux_columns_code } ) } @@ -125,8 +125,8 @@ impl Substitutions { self.init.len() + self.cons.len() + self.tran.len() + self.term.len() } - fn generate_fill_base_columns_code(&self) -> TokenStream { - let derived_section_init_start = self.lowering_info.num_base_cols; + fn generate_fill_main_columns_code(&self) -> TokenStream { + let derived_section_init_start = self.lowering_info.num_main_cols; let derived_section_cons_start = derived_section_init_start + self.init.len(); let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); @@ -137,22 +137,22 @@ impl Substitutions { let term_substitutions = Self::several_substitution_rules_to_code(&self.term); let init_substitutions = - Self::base_single_row_substitutions(derived_section_init_start, &init_substitutions); + Self::main_single_row_substitutions(derived_section_init_start, &init_substitutions); let cons_substitutions = - Self::base_single_row_substitutions(derived_section_cons_start, &cons_substitutions); + Self::main_single_row_substitutions(derived_section_cons_start, &cons_substitutions); let tran_substitutions = - Self::base_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); + Self::main_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); let term_substitutions = - Self::base_single_row_substitutions(derived_section_term_start, &term_substitutions); + Self::main_single_row_substitutions(derived_section_term_start, &term_substitutions); quote!( #[allow(unused_variables)] - pub fn fill_derived_base_columns( - mut master_base_table: ArrayViewMut2 + pub fn fill_derived_main_columns( + mut master_main_table: ArrayViewMut2 ) { let num_expected_columns = - crate::table::master_table::MasterBaseTable::NUM_COLUMNS; - assert_eq!(num_expected_columns, master_base_table.ncols()); + crate::table::master_table::MasterMainTable::NUM_COLUMNS; + assert_eq!(num_expected_columns, master_main_table.ncols()); #init_substitutions #cons_substitutions #tran_substitutions @@ -161,8 +161,8 @@ impl Substitutions { ) } - fn generate_fill_ext_columns_code(&self) -> TokenStream { - let derived_section_init_start = self.lowering_info.num_ext_cols; + fn generate_fill_aux_columns_code(&self) -> TokenStream { + let derived_section_init_start = self.lowering_info.num_aux_cols; let derived_section_cons_start = derived_section_init_start + self.init.len(); let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); @@ -173,29 +173,29 @@ impl Substitutions { let term_substitutions = Self::several_substitution_rules_to_code(&self.term); let init_substitutions = - Self::ext_single_row_substitutions(derived_section_init_start, &init_substitutions); + Self::aux_single_row_substitutions(derived_section_init_start, &init_substitutions); let cons_substitutions = - Self::ext_single_row_substitutions(derived_section_cons_start, &cons_substitutions); + Self::aux_single_row_substitutions(derived_section_cons_start, &cons_substitutions); let tran_substitutions = - Self::ext_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); + Self::aux_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); let term_substitutions = - Self::ext_single_row_substitutions(derived_section_term_start, &term_substitutions); + Self::aux_single_row_substitutions(derived_section_term_start, &term_substitutions); quote!( #[allow(unused_variables)] #[allow(unused_mut)] - pub fn fill_derived_ext_columns( - master_base_table: ArrayView2, - mut master_ext_table: ArrayViewMut2, + pub fn fill_derived_aux_columns( + master_main_table: ArrayView2, + mut master_aux_table: ArrayViewMut2, challenges: &Challenges, ) { let num_expected_main_columns = - crate::table::master_table::MasterBaseTable::NUM_COLUMNS; + crate::table::master_table::MasterMainTable::NUM_COLUMNS; let num_expected_aux_columns = - crate::table::master_table::MasterExtTable::NUM_COLUMNS; - assert_eq!(num_expected_main_columns, master_base_table.ncols()); - assert_eq!(num_expected_aux_columns, master_ext_table.ncols()); - assert_eq!(master_base_table.nrows(), master_ext_table.nrows()); + crate::table::master_table::MasterAuxTable::NUM_COLUMNS; + assert_eq!(num_expected_main_columns, master_main_table.ncols()); + assert_eq!(num_expected_aux_columns, master_aux_table.ncols()); + assert_eq!(master_main_table.nrows(), master_aux_table.nrows()); #init_substitutions #cons_substitutions #tran_substitutions @@ -235,7 +235,7 @@ impl Substitutions { RustBackend::default().evaluate_single_node(&expr) } - fn base_single_row_substitutions( + fn main_single_row_substitutions( section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { @@ -246,7 +246,7 @@ impl Substitutions { } quote!( let (original_part, mut current_section) = - master_base_table.multi_slice_mut( + master_main_table.multi_slice_mut( ( s![.., 0..#section_start_index], s![.., #section_start_index..#section_start_index+#num_substitutions], @@ -255,16 +255,16 @@ impl Substitutions { Zip::from(original_part.rows()) .and(current_section.rows_mut()) .par_for_each(|original_row, mut section_row| { - let mut base_row = original_row.to_owned(); + let mut main_row = original_row.to_owned(); #( section_row[#indices] = #substitutions; - base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + main_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); )* }); ) } - fn base_dual_row_substitutions( + fn main_dual_row_substitutions( section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { @@ -274,9 +274,9 @@ impl Substitutions { return quote!(); } quote!( - let num_rows = master_base_table.nrows(); + let num_rows = master_main_table.nrows(); let (original_part, mut current_section) = - master_base_table.multi_slice_mut( + master_main_table.multi_slice_mut( ( s![.., 0..#section_start_index], s![.., #section_start_index..#section_start_index+#num_substitutions], @@ -287,19 +287,19 @@ impl Substitutions { .and(row_indices.view()) .par_for_each( |mut section_row, ¤t_row_index| { let next_row_index = current_row_index + 1; - let current_base_row_slice = original_part.slice(s![current_row_index..=current_row_index, ..]); - let next_base_row_slice = original_part.slice(s![next_row_index..=next_row_index, ..]); - let mut current_base_row = current_base_row_slice.row(0).to_owned(); - let next_base_row = next_base_row_slice.row(0); + let current_main_row_slice = original_part.slice(s![current_row_index..=current_row_index, ..]); + let next_main_row_slice = original_part.slice(s![next_row_index..=next_row_index, ..]); + let mut current_main_row = current_main_row_slice.row(0).to_owned(); + let next_main_row = next_main_row_slice.row(0); #( section_row[#indices] = #substitutions; - current_base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + current_main_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); )* }); ) } - fn ext_single_row_substitutions( + fn aux_single_row_substitutions( section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { @@ -309,30 +309,30 @@ impl Substitutions { return quote!(); } quote!( - let (original_part, mut current_section) = master_ext_table.multi_slice_mut( + let (original_part, mut current_section) = master_aux_table.multi_slice_mut( ( s![.., 0..#section_start_index], s![.., #section_start_index..#section_start_index+#num_substitutions], ) ); - Zip::from(master_base_table.rows()) + Zip::from(master_main_table.rows()) .and(original_part.rows()) .and(current_section.rows_mut()) .par_for_each( - |base_table_row, original_row, mut section_row| { - let mut extension_row = original_row.to_owned(); + |main_table_row, original_row, mut section_row| { + let mut auxiliary_row = original_row.to_owned(); #( - let (original_row_extension_row, mut det_col) = + let (original_row_auxiliary_row, mut det_col) = section_row.multi_slice_mut((s![..#indices],s![#indices..=#indices])); det_col[0] = #substitutions; - extension_row.push(Axis(0), det_col.slice(s![0])).unwrap(); + auxiliary_row.push(Axis(0), det_col.slice(s![0])).unwrap(); )* } ); ) } - fn ext_dual_row_substitutions( + fn aux_dual_row_substitutions( section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { @@ -342,8 +342,8 @@ impl Substitutions { return quote!(); } quote!( - let num_rows = master_base_table.nrows(); - let (original_part, mut current_section) = master_ext_table.multi_slice_mut( + let num_rows = master_main_table.nrows(); + let (original_part, mut current_section) = master_aux_table.multi_slice_mut( ( s![.., 0..#section_start_index], s![.., #section_start_index..#section_start_index+#num_substitutions], @@ -354,13 +354,13 @@ impl Substitutions { .and(row_indices.view()) .par_for_each(|mut section_row, ¤t_row_index| { let next_row_index = current_row_index + 1; - let current_base_row = master_base_table.row(current_row_index); - let next_base_row = master_base_table.row(next_row_index); - let mut current_ext_row = original_part.row(current_row_index).to_owned(); - let next_ext_row = original_part.row(next_row_index); + let current_main_row = master_main_table.row(current_row_index); + let next_main_row = master_main_table.row(next_row_index); + let mut current_aux_row = original_part.row(current_row_index).to_owned(); + let next_aux_row = original_part.row(next_row_index); #( section_row[#indices]= #substitutions; - current_ext_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + current_aux_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); )* }); ) diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index a4e7f0a5c..428ca0fbc 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -40,11 +40,11 @@ pub struct DegreeLoweringInfo { /// The degree after degree lowering. Must be greater than 1. pub target_degree: isize, - /// The total number of base columns _before_ degree lowering has happened. - pub num_base_cols: usize, + /// The total number of main columns _before_ degree lowering has happened. + pub num_main_cols: usize, - /// The total number of extension columns _before_ degree lowering has happened. - pub num_ext_cols: usize, + /// The total number of auxiliary columns _before_ degree lowering has happened. + pub num_aux_cols: usize, } #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -96,7 +96,7 @@ impl BinOp { /// Having `Copy + Hash + Eq` helps to put `InputIndicator`s into containers. pub trait InputIndicator: Debug + Display + Copy + Hash + Eq + ToTokens { /// `true` iff `self` refers to a column in the base table. - fn is_base_table_column(&self) -> bool; + fn is_main_table_column(&self) -> bool; /// `true` iff `self` refers to the current row. fn is_current_row(&self) -> bool; @@ -105,7 +105,7 @@ pub trait InputIndicator: Debug + Display + Copy + Hash + Eq + ToTokens { fn column(&self) -> usize; fn base_table_input(index: usize) -> Self; - fn ext_table_input(index: usize) -> Self; + fn aux_table_input(index: usize) -> Self; fn evaluate( &self, @@ -118,15 +118,15 @@ pub trait InputIndicator: Debug + Display + Copy + Hash + Eq + ToTokens { /// execution trace. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub enum SingleRowIndicator { - BaseRow(usize), - ExtRow(usize), + Main(usize), + Aux(usize), } impl Display for SingleRowIndicator { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { let input_indicator: String = match self { - Self::BaseRow(i) => format!("base_row[{i}]"), - Self::ExtRow(i) => format!("ext_row[{i}]"), + Self::Main(i) => format!("main_row[{i}]"), + Self::Aux(i) => format!("aux_row[{i}]"), }; write!(f, "{input_indicator}") @@ -136,15 +136,15 @@ impl Display for SingleRowIndicator { impl ToTokens for SingleRowIndicator { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { match self { - Self::BaseRow(i) => tokens.extend(quote!(base_row[#i])), - Self::ExtRow(i) => tokens.extend(quote!(ext_row[#i])), + Self::Main(i) => tokens.extend(quote!(main_row[#i])), + Self::Aux(i) => tokens.extend(quote!(aux_row[#i])), } } } impl InputIndicator for SingleRowIndicator { - fn is_base_table_column(&self) -> bool { - matches!(self, Self::BaseRow(_)) + fn is_main_table_column(&self) -> bool { + matches!(self, Self::Main(_)) } fn is_current_row(&self) -> bool { @@ -153,16 +153,16 @@ impl InputIndicator for SingleRowIndicator { fn column(&self) -> usize { match self { - Self::BaseRow(i) | Self::ExtRow(i) => *i, + Self::Main(i) | Self::Aux(i) => *i, } } fn base_table_input(index: usize) -> Self { - Self::BaseRow(index) + Self::Main(index) } - fn ext_table_input(index: usize) -> Self { - Self::ExtRow(index) + fn aux_table_input(index: usize) -> Self { + Self::Aux(index) } fn evaluate( @@ -171,8 +171,8 @@ impl InputIndicator for SingleRowIndicator { ext_table: ArrayView2, ) -> XFieldElement { match self { - Self::BaseRow(i) => base_table[[0, *i]].lift(), - Self::ExtRow(i) => ext_table[[0, *i]], + Self::Main(i) => base_table[[0, *i]].lift(), + Self::Aux(i) => ext_table[[0, *i]], } } } @@ -181,19 +181,19 @@ impl InputIndicator for SingleRowIndicator { /// next) of the execution trace. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)] pub enum DualRowIndicator { - CurrentBaseRow(usize), - CurrentExtRow(usize), - NextBaseRow(usize), - NextExtRow(usize), + CurrentMain(usize), + CurrentAux(usize), + NextMain(usize), + NextAux(usize), } impl Display for DualRowIndicator { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { let input_indicator: String = match self { - Self::CurrentBaseRow(i) => format!("current_base_row[{i}]"), - Self::CurrentExtRow(i) => format!("current_ext_row[{i}]"), - Self::NextBaseRow(i) => format!("next_base_row[{i}]"), - Self::NextExtRow(i) => format!("next_ext_row[{i}]"), + Self::CurrentMain(i) => format!("current_main_row[{i}]"), + Self::CurrentAux(i) => format!("current_aux_row[{i}]"), + Self::NextMain(i) => format!("next_main_row[{i}]"), + Self::NextAux(i) => format!("next_aux_row[{i}]"), }; write!(f, "{input_indicator}") @@ -203,29 +203,26 @@ impl Display for DualRowIndicator { impl ToTokens for DualRowIndicator { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { match self { - Self::CurrentBaseRow(i) => tokens.extend(quote!(current_base_row[#i])), - Self::CurrentExtRow(i) => tokens.extend(quote!(current_ext_row[#i])), - Self::NextBaseRow(i) => tokens.extend(quote!(next_base_row[#i])), - Self::NextExtRow(i) => tokens.extend(quote!(next_ext_row[#i])), + Self::CurrentMain(i) => tokens.extend(quote!(current_main_row[#i])), + Self::CurrentAux(i) => tokens.extend(quote!(current_aux_row[#i])), + Self::NextMain(i) => tokens.extend(quote!(next_main_row[#i])), + Self::NextAux(i) => tokens.extend(quote!(next_aux_row[#i])), } } } impl InputIndicator for DualRowIndicator { - fn is_base_table_column(&self) -> bool { - matches!(self, Self::CurrentBaseRow(_) | Self::NextBaseRow(_)) + fn is_main_table_column(&self) -> bool { + matches!(self, Self::CurrentMain(_) | Self::NextMain(_)) } fn is_current_row(&self) -> bool { - matches!(self, Self::CurrentBaseRow(_) | Self::CurrentExtRow(_)) + matches!(self, Self::CurrentMain(_) | Self::CurrentAux(_)) } fn column(&self) -> usize { match self { - Self::CurrentBaseRow(i) - | Self::NextBaseRow(i) - | Self::CurrentExtRow(i) - | Self::NextExtRow(i) => *i, + Self::CurrentMain(i) | Self::NextMain(i) | Self::CurrentAux(i) | Self::NextAux(i) => *i, } } @@ -233,11 +230,11 @@ impl InputIndicator for DualRowIndicator { // It seems that the choice between `CurrentBaseRow` and `NextBaseRow` is arbitrary: // any transition constraint polynomial is evaluated on both the current and the next row. // Hence, both rows are in scope. - Self::CurrentBaseRow(index) + Self::CurrentMain(index) } - fn ext_table_input(index: usize) -> Self { - Self::CurrentExtRow(index) + fn aux_table_input(index: usize) -> Self { + Self::CurrentAux(index) } fn evaluate( @@ -246,10 +243,10 @@ impl InputIndicator for DualRowIndicator { ext_table: ArrayView2, ) -> XFieldElement { match self { - Self::CurrentBaseRow(i) => base_table[[0, *i]].lift(), - Self::CurrentExtRow(i) => ext_table[[0, *i]], - Self::NextBaseRow(i) => base_table[[1, *i]].lift(), - Self::NextExtRow(i) => ext_table[[1, *i]], + Self::CurrentMain(i) => base_table[[0, *i]].lift(), + Self::CurrentAux(i) => ext_table[[0, *i]], + Self::NextMain(i) => base_table[[1, *i]].lift(), + Self::NextAux(i) => ext_table[[1, *i]], } } } @@ -501,14 +498,14 @@ impl ConstraintCircuit { } /// Recursively check whether this node is composed of only BFieldElements, i.e., only uses - /// 1. inputs from base rows, + /// 1. inputs from main rows, /// 2. constants from the B-field, and /// 3. binary operations on BFieldElements. pub fn evaluates_to_base_element(&self) -> bool { match &self.expression { CircuitExpression::BConstant(_) => true, CircuitExpression::XConstant(_) => false, - CircuitExpression::Input(indicator) => indicator.is_base_table_column(), + CircuitExpression::Input(indicator) => indicator.is_main_table_column(), CircuitExpression::Challenge(_) => false, CircuitExpression::BinaryOperation(_, lhs, rhs) => { lhs.borrow().evaluates_to_base_element() && rhs.borrow().evaluates_to_base_element() @@ -698,8 +695,8 @@ impl ConstraintCircuitMonad { /// The target degree must be greater than 1. /// /// The new constraints are returned as two vector of ConstraintCircuitMonads: - /// the first corresponds to base columns and constraints, - /// the second to extension columns and constraints. + /// the first corresponds to main columns and constraints, + /// the second to auxiliary columns and constraints. /// /// Each returned constraint is guaranteed to correspond to some /// `CircuitExpression::BinaryOperation(BinOp::Sub, lhs, rhs)` where @@ -722,11 +719,11 @@ impl ConstraintCircuitMonad { "Target degree must be greater than 1. Got {target_degree}." ); - let mut base_constraints = vec![]; - let mut ext_constraints = vec![]; + let mut main_constraints = vec![]; + let mut aux_constraints = vec![]; if multicircuit.is_empty() { - return (base_constraints, ext_constraints); + return (main_constraints, aux_constraints); } let builder = multicircuit[0].builder.clone(); @@ -736,13 +733,13 @@ impl ConstraintCircuitMonad { // Create a new variable. let chosen_node = builder.all_nodes.borrow()[&chosen_node_id].clone(); - let chosen_node_is_base_col = chosen_node.circuit.borrow().evaluates_to_base_element(); - let new_input_indicator = if chosen_node_is_base_col { - let new_base_col_idx = info.num_base_cols + base_constraints.len(); - II::base_table_input(new_base_col_idx) + let chosen_node_is_main_col = chosen_node.circuit.borrow().evaluates_to_base_element(); + let new_input_indicator = if chosen_node_is_main_col { + let new_main_col_idx = info.num_main_cols + main_constraints.len(); + II::base_table_input(new_main_col_idx) } else { - let new_ext_col_idx = info.num_ext_cols + ext_constraints.len(); - II::ext_table_input(new_ext_col_idx) + let new_aux_col_idx = info.num_aux_cols + aux_constraints.len(); + II::aux_table_input(new_aux_col_idx) }; let new_variable = builder.input(new_input_indicator); @@ -758,13 +755,13 @@ impl ConstraintCircuitMonad { // Create new constraint and put it into the appropriate return vector. let new_constraint = new_variable - chosen_node; - match chosen_node_is_base_col { - true => base_constraints.push(new_constraint), - false => ext_constraints.push(new_constraint), + match chosen_node_is_main_col { + true => main_constraints.push(new_constraint), + false => aux_constraints.push(new_constraint), } } - (base_constraints, ext_constraints) + (main_constraints, aux_constraints) } /// Heuristically pick a node from the given multicircuit that is to be substituted with a new @@ -1215,8 +1212,8 @@ mod tests { let builder = ConstraintCircuitBuilder::new(); assert_eq!("1", builder.b_constant(1).to_string()); assert_eq!( - "base_row[5] ", - builder.input(SingleRowIndicator::BaseRow(5)).to_string() + "main_row[5] ", + builder.input(SingleRowIndicator::Main(5)).to_string() ); assert_eq!("6", builder.challenge(6_usize).to_string()); @@ -1287,7 +1284,7 @@ mod tests { #[test] fn substitution_replaces_a_node_in_a_circuit() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::Main(i)); let constant = |c: u32| builder.b_constant(c); let challenge = |i: usize| builder.challenge(i); @@ -1316,28 +1313,28 @@ mod tests { #[test] fn simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = || builder.input(SingleRowIndicator::BaseRow(0)); + let x = || builder.input(SingleRowIndicator::Main(0)); let x_pow_3 = x() * x() * x(); let x_pow_5 = x() * x() * x() * x() * x(); let mut multicircuit = [x_pow_5, x_pow_3]; let degree_lowering_info = DegreeLoweringInfo { target_degree: 3, - num_base_cols: 1, - num_ext_cols: 0, + num_main_cols: 1, + num_aux_cols: 0, }; - let (new_base_constraints, new_ext_constraints) = + let (new_main_constraints, new_aux_constraints) = ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); - assert_eq!(1, new_base_constraints.len()); - assert!(new_ext_constraints.is_empty()); + assert_eq!(1, new_main_constraints.len()); + assert!(new_aux_constraints.is_empty()); } #[test] fn somewhat_simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); - let y = |i| builder.input(SingleRowIndicator::ExtRow(i)); + let x = |i| builder.input(SingleRowIndicator::Main(i)); + let y = |i| builder.input(SingleRowIndicator::Aux(i)); let b_con = |i: u64| builder.b_constant(i); let constraint_0 = x(0) * x(0) * (x(1) - x(2)) - x(0) * x(2) - b_con(42); @@ -1351,20 +1348,20 @@ mod tests { let degree_lowering_info = DegreeLoweringInfo { target_degree: 2, - num_base_cols: 3, - num_ext_cols: 2, + num_main_cols: 3, + num_aux_cols: 2, }; - let (new_base_constraints, new_ext_constraints) = + let (new_main_constraints, new_aux_constraints) = ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); - assert!(new_base_constraints.len() <= 3); - assert!(new_ext_constraints.len() <= 1); + assert!(new_main_constraints.len() <= 3); + assert!(new_aux_constraints.len() <= 1); } #[test] fn less_simple_degree_lowering() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::Main(i)); let constraint_0 = (x(0) * x(1) * x(2)) * (x(3) * x(4)) * x(5); let constraint_1 = (x(6) * x(7)) * (x(3) * x(4)) * x(8); @@ -1373,21 +1370,21 @@ mod tests { let degree_lowering_info = DegreeLoweringInfo { target_degree: 3, - num_base_cols: 9, - num_ext_cols: 0, + num_main_cols: 9, + num_aux_cols: 0, }; - let (new_base_constraints, new_ext_constraints) = + let (new_main_constraints, new_aux_constraints) = ConstraintCircuitMonad::lower_to_degree(&mut multicircuit, degree_lowering_info); - assert!(new_base_constraints.len() <= 3); - assert!(new_ext_constraints.is_empty()); + assert!(new_main_constraints.len() <= 3); + assert!(new_aux_constraints.is_empty()); } #[test] fn all_nodes_in_multicircuit_are_identified_correctly() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::Main(i)); let b_con = |i: u64| builder.b_constant(i); let sub_tree_0 = x(0) * x(1) * (x(2) - b_con(1)) * x(3) * x(4); @@ -1447,7 +1444,7 @@ mod tests { fn equivalent_nodes_are_detected_when_present() { let builder = ConstraintCircuitBuilder::new(); - let x = |i| builder.input(SingleRowIndicator::BaseRow(i)); + let x = |i| builder.input(SingleRowIndicator::Main(i)); let ch = |i: usize| builder.challenge(i); let u0 = x(0) + x(1); diff --git a/triton-vm/build.rs b/triton-vm/build.rs index 6c8fd6c74..935698897 100644 --- a/triton-vm/build.rs +++ b/triton-vm/build.rs @@ -12,8 +12,8 @@ fn main() { let mut constraints = Constraints::all(); let degree_lowering_info = constraint_circuit::DegreeLoweringInfo { target_degree: air::TARGET_DEGREE, - num_base_cols: air::table::NUM_BASE_COLUMNS, - num_ext_cols: air::table::NUM_EXT_COLUMNS, + num_main_cols: air::table::NUM_MAIN_COLUMNS, + num_aux_cols: air::table::NUM_AUX_COLUMNS, }; let substitutions = constraints.lower_to_target_degree_through_substitutions(degree_lowering_info); diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index 9bba2f14b..ae5dd8227 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -9,8 +9,8 @@ use air::table::op_stack; use air::table::processor; use air::table::ram; use air::table::TableId; -use air::table_column::HashBaseTableColumn::CI; -use air::table_column::MasterBaseTableColumn; +use air::table_column::HashMainColumn::CI; +use air::table_column::MasterMainColumn; use air::AIR; use arbitrary::Arbitrary; use isa::error::InstructionError; @@ -121,7 +121,7 @@ impl AlgebraicExecutionTrace { /// /// Guaranteed to be a power of two. /// - /// [pad]: table::master_table::MasterBaseTable::pad + /// [pad]: table::master_table::MasterMainTable::pad pub fn padded_height(&self) -> usize { self.height().height.next_power_of_two() } @@ -129,7 +129,7 @@ impl AlgebraicExecutionTrace { /// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad]. /// Corresponds to the height of the longest table. /// - /// [pad]: table::master_table::MasterBaseTable::pad + /// [pad]: table::master_table::MasterMainTable::pad pub fn height(&self) -> TableHeight { TableId::iter() .map(|t| TableHeight::new(t, self.height_of_table(t))) @@ -191,7 +191,7 @@ impl AlgebraicExecutionTrace { .expect("shapes must be identical"); } - let instruction_column_index = CI.base_table_index(); + let instruction_column_index = CI.main_index(); let mut instruction_column = self.program_hash_trace.column_mut(instruction_column_index); instruction_column.fill(Instruction::Hash.opcode_b()); @@ -254,7 +254,7 @@ impl AlgebraicExecutionTrace { self.increase_lookup_multiplicities(trace); let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace); hash_trace_addendum - .slice_mut(s![.., CI.base_table_index()]) + .slice_mut(s![.., CI.main_index()]) .fill(Instruction::Hash.opcode_b()); self.hash_trace .append(Axis(0), hash_trace_addendum.view()) @@ -265,7 +265,7 @@ impl AlgebraicExecutionTrace { let round_number = 0; let initial_state = Tip5::init().state; let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number); - hash_table_row[CI.base_table_index()] = Instruction::SpongeInit.opcode_b(); + hash_table_row[CI.main_index()] = Instruction::SpongeInit.opcode_b(); self.sponge_trace.push_row(hash_table_row.view()).unwrap(); } @@ -277,7 +277,7 @@ impl AlgebraicExecutionTrace { self.increase_lookup_multiplicities(trace); let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace); sponge_trace_addendum - .slice_mut(s![.., CI.base_table_index()]) + .slice_mut(s![.., CI.main_index()]) .fill(instruction.opcode_b()); self.sponge_trace .append(Axis(0), sponge_trace_addendum.view()) @@ -333,7 +333,7 @@ impl AlgebraicExecutionTrace { } fn record_op_stack_entry(&mut self, op_stack_entry: OpStackTableEntry) { - let op_stack_table_row = op_stack_entry.to_base_table_row(); + let op_stack_table_row = op_stack_entry.to_main_table_row(); self.op_stack_underflow_trace .push_row(op_stack_table_row.view()) .unwrap(); diff --git a/triton-vm/src/constraints.rs b/triton-vm/src/constraints.rs index e1acdd888..1e2b4b82a 100644 --- a/triton-vm/src/constraints.rs +++ b/triton-vm/src/constraints.rs @@ -19,31 +19,31 @@ mod test { use crate::memory_layout::IntegralMemoryLayout; use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; use crate::prelude::*; - use crate::table::extension_table::Evaluable; - use crate::table::master_table::MasterExtTable; - use crate::table::NUM_AUX_COLUMNS; - use crate::table::NUM_MAIN_COLUMNS; + use crate::table::auxiliary_table::Evaluable; + use crate::table::master_table::MasterAuxTable; + use crate::table::master_table::MasterMainTable; + use crate::table::master_table::MasterTable; use super::dynamic_air_constraint_evaluation_tasm; use super::static_air_constraint_evaluation_tasm; #[derive(Debug, Clone, test_strategy::Arbitrary)] struct ConstraintEvaluationPoint { - #[strategy(vec(arb(), NUM_MAIN_COLUMNS))] + #[strategy(vec(arb(), MasterMainTable::NUM_COLUMNS))] #[map(Array1::from)] - curr_base_row: Array1, + curr_main_row: Array1, - #[strategy(vec(arb(), NUM_AUX_COLUMNS))] + #[strategy(vec(arb(), MasterAuxTable::NUM_COLUMNS))] #[map(Array1::from)] - curr_ext_row: Array1, + curr_aux_row: Array1, - #[strategy(vec(arb(), NUM_MAIN_COLUMNS))] + #[strategy(vec(arb(), MasterMainTable::NUM_COLUMNS))] #[map(Array1::from)] - next_base_row: Array1, + next_main_row: Array1, - #[strategy(vec(arb(), NUM_AUX_COLUMNS))] + #[strategy(vec(arb(), MasterAuxTable::NUM_COLUMNS))] #[map(Array1::from)] - next_ext_row: Array1, + next_aux_row: Array1, #[strategy(arb())] challenges: Challenges, @@ -55,26 +55,26 @@ mod test { impl ConstraintEvaluationPoint { fn evaluate_all_constraints_rust(&self) -> Vec { - let init = MasterExtTable::evaluate_initial_constraints( - self.curr_base_row.view(), - self.curr_ext_row.view(), + let init = MasterAuxTable::evaluate_initial_constraints( + self.curr_main_row.view(), + self.curr_aux_row.view(), &self.challenges, ); - let cons = MasterExtTable::evaluate_consistency_constraints( - self.curr_base_row.view(), - self.curr_ext_row.view(), + let cons = MasterAuxTable::evaluate_consistency_constraints( + self.curr_main_row.view(), + self.curr_aux_row.view(), &self.challenges, ); - let tran = MasterExtTable::evaluate_transition_constraints( - self.curr_base_row.view(), - self.curr_ext_row.view(), - self.next_base_row.view(), - self.next_ext_row.view(), + let tran = MasterAuxTable::evaluate_transition_constraints( + self.curr_main_row.view(), + self.curr_aux_row.view(), + self.next_main_row.view(), + self.next_aux_row.view(), &self.challenges, ); - let term = MasterExtTable::evaluate_terminal_constraints( - self.curr_base_row.view(), - self.curr_ext_row.view(), + let term = MasterAuxTable::evaluate_terminal_constraints( + self.curr_main_row.view(), + self.curr_aux_row.view(), &self.challenges, ); @@ -117,7 +117,7 @@ mod test { fn extract_constraint_evaluations(mut vm_state: VMState) -> Vec { assert!(vm_state.halting); let output_list_ptr = vm_state.op_stack.pop().unwrap().value(); - let num_quotients = MasterExtTable::NUM_CONSTRAINTS; + let num_quotients = MasterAuxTable::NUM_CONSTRAINTS; Self::read_xfe_list_at_address(vm_state.ram, output_list_ptr, num_quotients) } @@ -125,17 +125,17 @@ mod test { &self, program: &Program, ) -> VMState { - let curr_base_row_ptr = self.static_memory_layout.curr_base_row_ptr; - let curr_ext_row_ptr = self.static_memory_layout.curr_ext_row_ptr; - let next_base_row_ptr = self.static_memory_layout.next_base_row_ptr; - let next_ext_row_ptr = self.static_memory_layout.next_ext_row_ptr; + let curr_main_row_ptr = self.static_memory_layout.curr_main_row_ptr; + let curr_aux_row_ptr = self.static_memory_layout.curr_aux_row_ptr; + let next_main_row_ptr = self.static_memory_layout.next_main_row_ptr; + let next_aux_row_ptr = self.static_memory_layout.next_aux_row_ptr; let challenges_ptr = self.static_memory_layout.challenges_ptr; let mut ram = HashMap::default(); - Self::extend_ram_at_address(&mut ram, self.curr_base_row.to_vec(), curr_base_row_ptr); - Self::extend_ram_at_address(&mut ram, self.curr_ext_row.to_vec(), curr_ext_row_ptr); - Self::extend_ram_at_address(&mut ram, self.next_base_row.to_vec(), next_base_row_ptr); - Self::extend_ram_at_address(&mut ram, self.next_ext_row.to_vec(), next_ext_row_ptr); + Self::extend_ram_at_address(&mut ram, self.curr_main_row.to_vec(), curr_main_row_ptr); + Self::extend_ram_at_address(&mut ram, self.curr_aux_row.to_vec(), curr_aux_row_ptr); + Self::extend_ram_at_address(&mut ram, self.next_main_row.to_vec(), next_main_row_ptr); + Self::extend_ram_at_address(&mut ram, self.next_aux_row.to_vec(), next_aux_row_ptr); Self::extend_ram_at_address(&mut ram, self.challenges.challenges, challenges_ptr); let non_determinism = NonDeterminism::default().with_ram(ram); @@ -151,10 +151,10 @@ mod test { self.set_up_triton_vm_to_evaluate_constraints_in_tasm_static(program); let layout = self.static_memory_layout; - vm_state.op_stack.push(layout.curr_base_row_ptr); - vm_state.op_stack.push(layout.curr_ext_row_ptr); - vm_state.op_stack.push(layout.next_base_row_ptr); - vm_state.op_stack.push(layout.next_ext_row_ptr); + vm_state.op_stack.push(layout.curr_main_row_ptr); + vm_state.op_stack.push(layout.curr_aux_row_ptr); + vm_state.op_stack.push(layout.next_main_row_ptr); + vm_state.op_stack.push(layout.next_aux_row_ptr); vm_state } diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index 5f377ac80..2b7862563 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -173,7 +173,7 @@ pub enum VerificationError { BaseCodewordAuthenticationFailure, #[error("failed to verify authentication path for extension codeword")] - ExtensionCodewordAuthenticationFailure, + AuxiliaryCodewordAuthenticationFailure, #[error("failed to verify authentication path for combined quotient codeword")] QuotientCodewordAuthenticationFailure, @@ -190,10 +190,10 @@ pub enum VerificationError { #[error("the number of received quotient segment elements does not match the parameters")] IncorrectNumberOfQuotientSegmentElements, - #[error("the number of received base table rows does not match the parameters")] - IncorrectNumberOfBaseTableRows, + #[error("the number of received main table rows does not match the parameters")] + IncorrectNumberOfMainTableRows, - #[error("the number of received extension table rows does not match the parameters")] + #[error("the number of received auxiliary table rows does not match the parameters")] IncorrectNumberOfExtTableRows, #[error(transparent)] diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 33c316d9e..48a8e3ee4 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -331,11 +331,11 @@ mod tests { // table things implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); - implements_auto_traits::(); - implements_auto_traits::(); + implements_auto_traits::(); + implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); implements_auto_traits::(); diff --git a/triton-vm/src/memory_layout.rs b/triton-vm/src/memory_layout.rs index 0e816b0d6..7acfa3e42 100644 --- a/triton-vm/src/memory_layout.rs +++ b/triton-vm/src/memory_layout.rs @@ -1,13 +1,15 @@ pub use constraint_builder::codegen::MEM_PAGE_SIZE; use air::challenge_id::ChallengeId; -use air::table::NUM_BASE_COLUMNS; -use air::table::NUM_EXT_COLUMNS; use arbitrary::Arbitrary; use itertools::Itertools; use strum::EnumCount; use twenty_first::prelude::*; +use crate::table::master_table::MasterAuxTable; +use crate::table::master_table::MasterMainTable; +use crate::table::master_table::MasterTable; + /// Memory layout guarantees for the [Triton assembly AIR constraint evaluator][tasm_air] /// with input lists at dynamically known memory locations. /// @@ -36,17 +38,17 @@ pub struct StaticTasmConstraintEvaluationMemoryLayout { /// The size of the region must be at least [`MEM_PAGE_SIZE`] [`BFieldElement`]s. pub free_mem_page_ptr: BFieldElement, - /// Pointer to an array of [`XFieldElement`]s of length [`NUM_BASE_COLUMNS`]. - pub curr_base_row_ptr: BFieldElement, + /// Pointer to an array of [`XFieldElement`]s of length [`MasterMainTable::NUM_COLUMNS`]. + pub curr_main_row_ptr: BFieldElement, - /// Pointer to an array of [`XFieldElement`]s of length [`NUM_EXT_COLUMNS`]. - pub curr_ext_row_ptr: BFieldElement, + /// Pointer to an array of [`XFieldElement`]s of length [`MasterAuxTable::NUM_COLUMNS`]. + pub curr_aux_row_ptr: BFieldElement, - /// Pointer to an array of [`XFieldElement`]s of length [`NUM_BASE_COLUMNS`]. - pub next_base_row_ptr: BFieldElement, + /// Pointer to an array of [`XFieldElement`]s of length [`MasterMainTable::NUM_COLUMNS`]. + pub next_main_row_ptr: BFieldElement, - /// Pointer to an array of [`XFieldElement`]s of length [`NUM_EXT_COLUMNS`]. - pub next_ext_row_ptr: BFieldElement, + /// Pointer to an array of [`XFieldElement`]s of length [`MasterAuxTable::NUM_COLUMNS`]. + pub next_aux_row_ptr: BFieldElement, /// Pointer to an array of [`XFieldElement`]s of length [`NUM_CHALLENGES`][num_challenges]. /// @@ -77,10 +79,10 @@ impl IntegralMemoryLayout for StaticTasmConstraintEvaluationMemoryLayout { fn memory_regions(&self) -> Box<[MemoryRegion]> { let all_regions = [ MemoryRegion::new(self.free_mem_page_ptr, MEM_PAGE_SIZE), - MemoryRegion::new(self.curr_base_row_ptr, NUM_BASE_COLUMNS), - MemoryRegion::new(self.curr_ext_row_ptr, NUM_EXT_COLUMNS), - MemoryRegion::new(self.next_base_row_ptr, NUM_BASE_COLUMNS), - MemoryRegion::new(self.next_ext_row_ptr, NUM_EXT_COLUMNS), + MemoryRegion::new(self.curr_main_row_ptr, MasterMainTable::NUM_COLUMNS), + MemoryRegion::new(self.curr_aux_row_ptr, MasterAuxTable::NUM_COLUMNS), + MemoryRegion::new(self.next_main_row_ptr, MasterMainTable::NUM_COLUMNS), + MemoryRegion::new(self.next_aux_row_ptr, MasterAuxTable::NUM_COLUMNS), MemoryRegion::new(self.challenges_ptr, ChallengeId::COUNT), ]; Box::new(all_regions) @@ -143,10 +145,10 @@ mod tests { let mem_page = |i| bfe!(i * mem_page_size); StaticTasmConstraintEvaluationMemoryLayout { free_mem_page_ptr: mem_page(0), - curr_base_row_ptr: mem_page(1), - curr_ext_row_ptr: mem_page(2), - next_base_row_ptr: mem_page(3), - next_ext_row_ptr: mem_page(4), + curr_main_row_ptr: mem_page(1), + curr_aux_row_ptr: mem_page(2), + next_main_row_ptr: mem_page(3), + next_aux_row_ptr: mem_page(4), challenges_ptr: mem_page(5), } } @@ -196,10 +198,10 @@ mod tests { fn definitely_non_integral_memory_layout_is_detected_as_non_integral() { let layout = StaticTasmConstraintEvaluationMemoryLayout { free_mem_page_ptr: bfe!(0), - curr_base_row_ptr: bfe!(1), - curr_ext_row_ptr: bfe!(2), - next_base_row_ptr: bfe!(3), - next_ext_row_ptr: bfe!(4), + curr_main_row_ptr: bfe!(1), + curr_aux_row_ptr: bfe!(2), + next_main_row_ptr: bfe!(3), + next_aux_row_ptr: bfe!(4), challenges_ptr: bfe!(5), }; assert!(!layout.is_integral()); diff --git a/triton-vm/src/proof_item.rs b/triton-vm/src/proof_item.rs index 1ad3253cb..0597bacd2 100644 --- a/triton-vm/src/proof_item.rs +++ b/triton-vm/src/proof_item.rs @@ -8,8 +8,8 @@ use twenty_first::prelude::*; use crate::error::ProofStreamError; use crate::error::ProofStreamError::UnexpectedItem; use crate::fri::AuthenticationStructure; -use crate::table::BaseRow; -use crate::table::ExtensionRow; +use crate::table::AuxiliaryRow; +use crate::table::MainRow; use crate::table::QuotientSegments; /// A `FriResponse` is an `AuthenticationStructure` together with the values of the @@ -99,14 +99,14 @@ macro_rules! proof_items { proof_items!( MerkleRoot(Digest) => true, try_into_merkle_root, - OutOfDomainBaseRow(Box>) => true, try_into_out_of_domain_base_row, - OutOfDomainExtRow(Box) => true, try_into_out_of_domain_ext_row, + OutOfDomainMainRow(Box>) => true, try_into_out_of_domain_main_row, + OutOfDomainAuxRow(Box) => true, try_into_out_of_domain_aux_row, OutOfDomainQuotientSegments(QuotientSegments) => true, try_into_out_of_domain_quot_segments, // the following are implied by some Merkle root, thus not included in the Fiat-Shamir heuristic AuthenticationStructure(AuthenticationStructure) => false, try_into_authentication_structure, - MasterBaseTableRows(Vec>) => false, try_into_master_base_table_rows, - MasterExtTableRows(Vec) => false, try_into_master_ext_table_rows, + MasterMainTableRows(Vec>) => false, try_into_master_main_table_rows, + MasterAuxTableRows(Vec) => false, try_into_master_aux_table_rows, Log2PaddedHeight(u32) => false, try_into_log2_padded_height, QuotientSegmentsElements(Vec) => false, try_into_quot_segments_elements, FriCodeword(Vec) => false, try_into_fri_codeword, @@ -182,10 +182,10 @@ pub(crate) mod tests { let item = ProofItem::MerkleRoot(fake_root); assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_authentication_structure()); assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_fri_response()); - assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_master_base_table_rows()); - assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_master_ext_table_rows()); - assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_out_of_domain_base_row()); - assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_out_of_domain_ext_row()); + assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_master_main_table_rows()); + assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_master_aux_table_rows()); + assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_out_of_domain_main_row()); + assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_out_of_domain_aux_row()); assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_out_of_domain_quot_segments()); assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_log2_padded_height()); assert!(let Err(UnexpectedItem{..}) = item.clone().try_into_quot_segments_elements()); diff --git a/triton-vm/src/proof_stream.rs b/triton-vm/src/proof_stream.rs index f0c90e28f..6a1f76f29 100644 --- a/triton-vm/src/proof_stream.rs +++ b/triton-vm/src/proof_stream.rs @@ -129,18 +129,18 @@ mod tests { use crate::proof_item::FriResponse; use crate::proof_item::ProofItem; use crate::shared_tests::LeavedMerkleTreeTestData; - use crate::table::BaseRow; - use crate::table::ExtensionRow; + use crate::table::AuxiliaryRow; + use crate::table::MainRow; use crate::table::QuotientSegments; use super::*; #[proptest] fn serialize_proof_with_fiat_shamir( - #[strategy(vec(arb(), 2..100))] base_rows: Vec>, - #[strategy(vec(arb(), 2..100))] ext_rows: Vec, - #[strategy(arb())] ood_base_row: Box>, - #[strategy(arb())] ood_ext_row: Box, + #[strategy(vec(arb(), 2..100))] main_rows: Vec>, + #[strategy(vec(arb(), 2..100))] aux_rows: Vec, + #[strategy(arb())] ood_main_row: Box>, + #[strategy(arb())] ood_aux_row: Box, #[strategy(arb())] quot_elements: Vec, leaved_merkle_tree: LeavedMerkleTreeTestData, ) { @@ -155,13 +155,13 @@ mod tests { sponge_states.push_back(proof_stream.sponge.state); proof_stream.enqueue(ProofItem::AuthenticationStructure(auth_structure.clone())); sponge_states.push_back(proof_stream.sponge.state); - proof_stream.enqueue(ProofItem::MasterBaseTableRows(base_rows.clone())); + proof_stream.enqueue(ProofItem::MasterMainTableRows(main_rows.clone())); sponge_states.push_back(proof_stream.sponge.state); - proof_stream.enqueue(ProofItem::MasterExtTableRows(ext_rows.clone())); + proof_stream.enqueue(ProofItem::MasterAuxTableRows(aux_rows.clone())); sponge_states.push_back(proof_stream.sponge.state); - proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(ood_base_row.clone())); + proof_stream.enqueue(ProofItem::OutOfDomainMainRow(ood_main_row.clone())); sponge_states.push_back(proof_stream.sponge.state); - proof_stream.enqueue(ProofItem::OutOfDomainExtRow(ood_ext_row.clone())); + proof_stream.enqueue(ProofItem::OutOfDomainAuxRow(ood_aux_row.clone())); sponge_states.push_back(proof_stream.sponge.state); proof_stream.enqueue(ProofItem::MerkleRoot(root)); sponge_states.push_back(proof_stream.sponge.state); @@ -181,20 +181,20 @@ mod tests { assert!(auth_structure == auth_structure_); assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state)); - let_assert!(Ok(ProofItem::MasterBaseTableRows(base_rows_)) = proof_stream.dequeue()); - assert!(base_rows == base_rows_); + let_assert!(Ok(ProofItem::MasterMainTableRows(main_rows_)) = proof_stream.dequeue()); + assert!(main_rows == main_rows_); assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state)); - let_assert!(Ok(ProofItem::MasterExtTableRows(ext_rows_)) = proof_stream.dequeue()); - assert!(ext_rows == ext_rows_); + let_assert!(Ok(ProofItem::MasterAuxTableRows(aux_rows_)) = proof_stream.dequeue()); + assert!(aux_rows == aux_rows_); assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state)); - let_assert!(Ok(ProofItem::OutOfDomainBaseRow(ood_base_row_)) = proof_stream.dequeue()); - assert!(ood_base_row == ood_base_row_); + let_assert!(Ok(ProofItem::OutOfDomainMainRow(ood_main_row_)) = proof_stream.dequeue()); + assert!(ood_main_row == ood_main_row_); assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state)); - let_assert!(Ok(ProofItem::OutOfDomainExtRow(ood_ext_row_)) = proof_stream.dequeue()); - assert!(ood_ext_row == ood_ext_row_); + let_assert!(Ok(ProofItem::OutOfDomainAuxRow(ood_aux_row_)) = proof_stream.dequeue()); + assert!(ood_aux_row == ood_aux_row_); assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state)); let_assert!(Ok(ProofItem::MerkleRoot(root_)) = proof_stream.dequeue()); diff --git a/triton-vm/src/shared_tests.rs b/triton-vm/src/shared_tests.rs index 11715d2ac..6abb3e270 100644 --- a/triton-vm/src/shared_tests.rs +++ b/triton-vm/src/shared_tests.rs @@ -14,7 +14,7 @@ use crate::fri::AuthenticationStructure; use crate::prelude::*; use crate::profiler::profiler; use crate::proof_item::FriResponse; -use crate::table::master_table::MasterBaseTable; +use crate::table::master_table::MasterMainTable; pub(crate) const DEFAULT_LOG2_FRI_EXPANSION_FACTOR_FOR_TESTS: usize = 2; @@ -140,15 +140,15 @@ pub(crate) fn low_security_stark(log_expansion_factor: usize) -> Stark { Stark::new(security_level, log_expansion_factor) } -pub(crate) fn construct_master_base_table( +pub(crate) fn construct_master_main_table( stark: Stark, aet: &AlgebraicExecutionTrace, -) -> MasterBaseTable { +) -> MasterMainTable { let padded_height = aet.padded_height(); let fri = stark.derive_fri(padded_height).unwrap(); let max_degree = stark.derive_max_degree(padded_height); let quotient_domain = Stark::quotient_domain(fri.domain, max_degree).unwrap(); - MasterBaseTable::new( + MasterMainTable::new( aet, stark.num_trace_randomizers, quotient_domain, diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 19bc91723..e54a584c7 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -28,12 +28,12 @@ use crate::proof::Claim; use crate::proof::Proof; use crate::proof_item::ProofItem; use crate::proof_stream::ProofStream; -use crate::table::extension_table::Evaluable; +use crate::table::auxiliary_table::Evaluable; use crate::table::master_table::all_quotients_combined; use crate::table::master_table::interpolant_degree; use crate::table::master_table::max_degree_with_origin; -use crate::table::master_table::MasterBaseTable; -use crate::table::master_table::MasterExtTable; +use crate::table::master_table::MasterAuxTable; +use crate::table::master_table::MasterMainTable; use crate::table::master_table::MasterTable; use crate::table::QuotientSegments; @@ -114,65 +114,65 @@ impl Stark { proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2())); profiler!(stop "derive additional parameters"); - profiler!(start "base tables"); + profiler!(start "main tables"); profiler!(start "create" ("gen")); - let mut master_base_table = - MasterBaseTable::new(aet, self.num_trace_randomizers, quotient_domain, fri.domain); + let mut master_main_table = + MasterMainTable::new(aet, self.num_trace_randomizers, quotient_domain, fri.domain); profiler!(stop "create"); profiler!(start "pad" ("gen")); - master_base_table.pad(); + master_main_table.pad(); profiler!(stop "pad"); profiler!(start "randomize trace" ("gen")); - master_base_table.randomize_trace(); + master_main_table.randomize_trace(); profiler!(stop "randomize trace"); profiler!(start "LDE" ("LDE")); - master_base_table.low_degree_extend_all_columns(); + master_main_table.low_degree_extend_all_columns(); profiler!(stop "LDE"); profiler!(start "Merkle tree" ("hash")); - let base_merkle_tree = master_base_table.merkle_tree(); + let main_merkle_tree = master_main_table.merkle_tree(); profiler!(stop "Merkle tree"); profiler!(start "Fiat-Shamir" ("hash")); - proof_stream.enqueue(ProofItem::MerkleRoot(base_merkle_tree.root())); + proof_stream.enqueue(ProofItem::MerkleRoot(main_merkle_tree.root())); let challenges = proof_stream.sample_scalars(Challenges::SAMPLE_COUNT); let challenges = Challenges::new(challenges, claim); profiler!(stop "Fiat-Shamir"); profiler!(start "extend" ("gen")); - let mut master_ext_table = master_base_table.extend(&challenges); + let mut master_aux_table = master_main_table.extend(&challenges); profiler!(stop "extend"); - profiler!(stop "base tables"); + profiler!(stop "main tables"); profiler!(start "ext tables"); profiler!(start "randomize trace" ("gen")); - master_ext_table.randomize_trace(); + master_aux_table.randomize_trace(); profiler!(stop "randomize trace"); profiler!(start "LDE" ("LDE")); - master_ext_table.low_degree_extend_all_columns(); + master_aux_table.low_degree_extend_all_columns(); profiler!(stop "LDE"); profiler!(start "Merkle tree" ("hash")); - let ext_merkle_tree = master_ext_table.merkle_tree(); + let aux_merkle_tree = master_aux_table.merkle_tree(); profiler!(stop "Merkle tree"); profiler!(start "Fiat-Shamir" ("hash")); - proof_stream.enqueue(ProofItem::MerkleRoot(ext_merkle_tree.root())); + proof_stream.enqueue(ProofItem::MerkleRoot(aux_merkle_tree.root())); // Get the weights with which to compress the many quotients into one. let quotient_combination_weights = - proof_stream.sample_scalars(MasterExtTable::NUM_CONSTRAINTS); + proof_stream.sample_scalars(MasterAuxTable::NUM_CONSTRAINTS); profiler!(stop "Fiat-Shamir"); profiler!(stop "ext tables"); let (fri_domain_quotient_segment_codewords, quotient_segment_polynomials) = Self::compute_quotient_segments( - &master_base_table, - &master_ext_table, + &master_main_table, + &master_aux_table, fri.domain, quotient_domain, &challenges, @@ -201,25 +201,25 @@ impl Stark { debug_assert_eq!(fri.domain.length, quot_merkle_tree.num_leafs()); profiler!(start "out-of-domain rows"); - let trace_domain_generator = master_base_table.trace_domain().generator; + let trace_domain_generator = master_main_table.trace_domain().generator; let out_of_domain_point_curr_row = proof_stream.sample_scalars(1)[0]; let out_of_domain_point_next_row = trace_domain_generator * out_of_domain_point_curr_row; - let ood_base_row = master_base_table.out_of_domain_row(out_of_domain_point_curr_row); - let ood_base_row = MasterBaseTable::try_to_base_row(ood_base_row)?; - proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(Box::new(ood_base_row))); + let ood_main_row = master_main_table.out_of_domain_row(out_of_domain_point_curr_row); + let ood_main_row = MasterMainTable::try_to_main_row(ood_main_row)?; + proof_stream.enqueue(ProofItem::OutOfDomainMainRow(Box::new(ood_main_row))); - let ood_ext_row = master_ext_table.out_of_domain_row(out_of_domain_point_curr_row); - let ood_ext_row = MasterExtTable::try_to_ext_row(ood_ext_row)?; - proof_stream.enqueue(ProofItem::OutOfDomainExtRow(Box::new(ood_ext_row))); + let ood_aux_row = master_aux_table.out_of_domain_row(out_of_domain_point_curr_row); + let ood_aux_row = MasterAuxTable::try_to_aux_row(ood_aux_row)?; + proof_stream.enqueue(ProofItem::OutOfDomainAuxRow(Box::new(ood_aux_row))); - let ood_next_base_row = master_base_table.out_of_domain_row(out_of_domain_point_next_row); - let ood_next_base_row = MasterBaseTable::try_to_base_row(ood_next_base_row)?; - proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(Box::new(ood_next_base_row))); + let ood_next_main_row = master_main_table.out_of_domain_row(out_of_domain_point_next_row); + let ood_next_main_row = MasterMainTable::try_to_main_row(ood_next_main_row)?; + proof_stream.enqueue(ProofItem::OutOfDomainMainRow(Box::new(ood_next_main_row))); - let ood_next_ext_row = master_ext_table.out_of_domain_row(out_of_domain_point_next_row); - let ood_next_ext_row = MasterExtTable::try_to_ext_row(ood_next_ext_row)?; - proof_stream.enqueue(ProofItem::OutOfDomainExtRow(Box::new(ood_next_ext_row))); + let ood_next_aux_row = master_aux_table.out_of_domain_row(out_of_domain_point_next_row); + let ood_next_aux_row = MasterAuxTable::try_to_aux_row(ood_next_aux_row)?; + proof_stream.enqueue(ProofItem::OutOfDomainAuxRow(Box::new(ood_next_aux_row))); let out_of_domain_point_curr_row_pow_num_segments = out_of_domain_point_curr_row.mod_pow_u32(NUM_QUOTIENT_SEGMENTS as u32); @@ -247,12 +247,12 @@ impl Stark { profiler!(start "linear combination"); profiler!(start "base" ("CC")); let base_combination_polynomial = - Self::random_linear_sum(master_base_table.interpolation_polynomials(), weights.main); + Self::random_linear_sum(master_main_table.interpolation_polynomials(), weights.main); profiler!(stop "base"); profiler!(start "ext" ("CC")); let ext_combination_polynomial = - Self::random_linear_sum(master_ext_table.interpolation_polynomials(), weights.aux); + Self::random_linear_sum(master_aux_table.interpolation_polynomials(), weights.aux); profiler!(stop "ext"); let base_and_ext_combination_polynomial = base_combination_polynomial + ext_combination_polynomial; @@ -357,35 +357,35 @@ impl Stark { profiler!(start "open trace leafs"); // Open leafs of zipped codewords at indicated positions let revealed_base_elems = - if let Some(fri_domain_table) = master_base_table.fri_domain_table() { + if let Some(fri_domain_table) = master_main_table.fri_domain_table() { Self::read_revealed_rows(fri_domain_table, &revealed_current_row_indices)? } else { - Self::recompute_revealed_rows::<{ MasterBaseTable::NUM_COLUMNS }, BFieldElement>( - &master_base_table.interpolation_polynomials(), + Self::recompute_revealed_rows::<{ MasterMainTable::NUM_COLUMNS }, BFieldElement>( + &master_main_table.interpolation_polynomials(), &revealed_current_row_indices, fri.domain, ) }; let base_authentication_structure = - base_merkle_tree.authentication_structure(&revealed_current_row_indices)?; - proof_stream.enqueue(ProofItem::MasterBaseTableRows(revealed_base_elems)); + main_merkle_tree.authentication_structure(&revealed_current_row_indices)?; + proof_stream.enqueue(ProofItem::MasterMainTableRows(revealed_base_elems)); proof_stream.enqueue(ProofItem::AuthenticationStructure( base_authentication_structure, )); - let revealed_ext_elems = if let Some(fri_domain_table) = master_ext_table.fri_domain_table() + let revealed_ext_elems = if let Some(fri_domain_table) = master_aux_table.fri_domain_table() { Self::read_revealed_rows(fri_domain_table, &revealed_current_row_indices)? } else { Self::recompute_revealed_rows( - &master_ext_table.interpolation_polynomials(), + &master_aux_table.interpolation_polynomials(), &revealed_current_row_indices, fri.domain, ) }; let ext_authentication_structure = - ext_merkle_tree.authentication_structure(&revealed_current_row_indices)?; - proof_stream.enqueue(ProofItem::MasterExtTableRows(revealed_ext_elems)); + aux_merkle_tree.authentication_structure(&revealed_current_row_indices)?; + proof_stream.enqueue(ProofItem::MasterAuxTableRows(revealed_ext_elems)); proof_stream.enqueue(ProofItem::AuthenticationStructure( ext_authentication_structure, )); @@ -412,8 +412,8 @@ impl Stark { } fn compute_quotient_segments( - master_base_table: &MasterBaseTable, - master_ext_table: &MasterExtTable, + master_main_table: &MasterMainTable, + master_aux_table: &MasterAuxTable, fri_domain: ArithmeticDomain, quotient_domain: ArithmeticDomain, challenges: &Challenges, @@ -423,11 +423,11 @@ impl Stark { profiler!(start "quotient calculation (just-in-time)"); let (fri_domain_quotient_segment_codewords, quotient_segment_polynomials) = Self::compute_quotient_segments_with_jit_lde( - master_base_table.interpolation_polynomials(), - master_ext_table.interpolation_polynomials(), - master_base_table.trace_domain(), - master_base_table.randomized_trace_domain(), - master_base_table.fri_domain(), + master_main_table.interpolation_polynomials(), + master_aux_table.interpolation_polynomials(), + master_main_table.trace_domain(), + master_main_table.randomized_trace_domain(), + master_main_table.fri_domain(), challenges, quotient_combination_weights, ); @@ -438,18 +438,18 @@ impl Stark { ) }; - let Some(base_quotient_domain_codewords) = master_base_table.quotient_domain_table() else { + let Some(main_quotient_domain_codewords) = master_main_table.quotient_domain_table() else { return calculate_quotients_with_just_in_time_low_degree_extension(); }; - let Some(ext_quotient_domain_codewords) = master_ext_table.quotient_domain_table() else { + let Some(aux_quotient_domain_codewords) = master_aux_table.quotient_domain_table() else { return calculate_quotients_with_just_in_time_low_degree_extension(); }; profiler!(start "quotient calculation (cached)" ("CC")); let quotient_codeword = all_quotients_combined( - base_quotient_domain_codewords, - ext_quotient_domain_codewords, - master_base_table.trace_domain(), + main_quotient_domain_codewords, + aux_quotient_domain_codewords, + master_main_table.trace_domain(), quotient_domain, challenges, quotient_combination_weights, @@ -738,10 +738,10 @@ impl Stark { let base_merkle_tree_root = proof_stream.dequeue()?.try_into_merkle_root()?; let extension_challenge_weights = proof_stream.sample_scalars(Challenges::SAMPLE_COUNT); let challenges = Challenges::new(extension_challenge_weights, claim); - let extension_tree_merkle_root = proof_stream.dequeue()?.try_into_merkle_root()?; + let auxiliary_tree_merkle_root = proof_stream.dequeue()?.try_into_merkle_root()?; // Sample weights for quotient codeword, which is a part of the combination codeword. // See corresponding part in the prover for a more detailed explanation. - let quot_codeword_weights = proof_stream.sample_scalars(MasterExtTable::NUM_CONSTRAINTS); + let quot_codeword_weights = proof_stream.sample_scalars(MasterAuxTable::NUM_CONSTRAINTS); let quot_codeword_weights = Array1::from(quot_codeword_weights); let quotient_codeword_merkle_root = proof_stream.dequeue()?.try_into_merkle_root()?; profiler!(stop "Fiat-Shamir 1"); @@ -753,48 +753,48 @@ impl Stark { let out_of_domain_point_curr_row_pow_num_segments = out_of_domain_point_curr_row.mod_pow_u32(NUM_QUOTIENT_SEGMENTS as u32); - let out_of_domain_curr_base_row = - proof_stream.dequeue()?.try_into_out_of_domain_base_row()?; - let out_of_domain_curr_ext_row = - proof_stream.dequeue()?.try_into_out_of_domain_ext_row()?; - let out_of_domain_next_base_row = - proof_stream.dequeue()?.try_into_out_of_domain_base_row()?; - let out_of_domain_next_ext_row = - proof_stream.dequeue()?.try_into_out_of_domain_ext_row()?; + let out_of_domain_curr_main_row = + proof_stream.dequeue()?.try_into_out_of_domain_main_row()?; + let out_of_domain_curr_aux_row = + proof_stream.dequeue()?.try_into_out_of_domain_aux_row()?; + let out_of_domain_next_main_row = + proof_stream.dequeue()?.try_into_out_of_domain_main_row()?; + let out_of_domain_next_aux_row = + proof_stream.dequeue()?.try_into_out_of_domain_aux_row()?; let out_of_domain_curr_row_quot_segments = proof_stream .dequeue()? .try_into_out_of_domain_quot_segments()?; - let out_of_domain_curr_base_row = Array1::from(out_of_domain_curr_base_row.to_vec()); - let out_of_domain_curr_ext_row = Array1::from(out_of_domain_curr_ext_row.to_vec()); - let out_of_domain_next_base_row = Array1::from(out_of_domain_next_base_row.to_vec()); - let out_of_domain_next_ext_row = Array1::from(out_of_domain_next_ext_row.to_vec()); + let out_of_domain_curr_main_row = Array1::from(out_of_domain_curr_main_row.to_vec()); + let out_of_domain_curr_aux_row = Array1::from(out_of_domain_curr_aux_row.to_vec()); + let out_of_domain_next_main_row = Array1::from(out_of_domain_next_main_row.to_vec()); + let out_of_domain_next_aux_row = Array1::from(out_of_domain_next_aux_row.to_vec()); let out_of_domain_curr_row_quot_segments = Array1::from(out_of_domain_curr_row_quot_segments.to_vec()); profiler!(stop "dequeue ood point and rows"); profiler!(start "out-of-domain quotient element"); profiler!(start "evaluate AIR" ("AIR")); - let evaluated_initial_constraints = MasterExtTable::evaluate_initial_constraints( - out_of_domain_curr_base_row.view(), - out_of_domain_curr_ext_row.view(), + let evaluated_initial_constraints = MasterAuxTable::evaluate_initial_constraints( + out_of_domain_curr_main_row.view(), + out_of_domain_curr_aux_row.view(), &challenges, ); - let evaluated_consistency_constraints = MasterExtTable::evaluate_consistency_constraints( - out_of_domain_curr_base_row.view(), - out_of_domain_curr_ext_row.view(), + let evaluated_consistency_constraints = MasterAuxTable::evaluate_consistency_constraints( + out_of_domain_curr_main_row.view(), + out_of_domain_curr_aux_row.view(), &challenges, ); - let evaluated_transition_constraints = MasterExtTable::evaluate_transition_constraints( - out_of_domain_curr_base_row.view(), - out_of_domain_curr_ext_row.view(), - out_of_domain_next_base_row.view(), - out_of_domain_next_ext_row.view(), + let evaluated_transition_constraints = MasterAuxTable::evaluate_transition_constraints( + out_of_domain_curr_main_row.view(), + out_of_domain_curr_aux_row.view(), + out_of_domain_next_main_row.view(), + out_of_domain_next_aux_row.view(), &challenges, ); - let evaluated_terminal_constraints = MasterExtTable::evaluate_terminal_constraints( - out_of_domain_curr_base_row.view(), - out_of_domain_curr_ext_row.view(), + let evaluated_terminal_constraints = MasterAuxTable::evaluate_terminal_constraints( + out_of_domain_curr_main_row.view(), + out_of_domain_curr_aux_row.view(), &challenges, ); profiler!(stop "evaluate AIR"); @@ -847,14 +847,14 @@ impl Stark { profiler!(stop "Fiat-Shamir 2"); profiler!(start "sum out-of-domain values" ("CC")); - let out_of_domain_curr_row_base_and_ext_value = Self::linearly_sum_base_and_ext_row( - out_of_domain_curr_base_row.view(), - out_of_domain_curr_ext_row.view(), + let out_of_domain_curr_row_base_and_ext_value = Self::linearly_sum_main_and_aux_row( + out_of_domain_curr_main_row.view(), + out_of_domain_curr_aux_row.view(), base_and_ext_codeword_weights.view(), ); - let out_of_domain_next_row_base_and_ext_value = Self::linearly_sum_base_and_ext_row( - out_of_domain_next_base_row.view(), - out_of_domain_next_ext_row.view(), + let out_of_domain_next_row_base_and_ext_value = Self::linearly_sum_main_and_aux_row( + out_of_domain_next_main_row.view(), + out_of_domain_next_aux_row.view(), base_and_ext_codeword_weights.view(), ); let out_of_domain_curr_row_quotient_segment_value = weights @@ -870,8 +870,8 @@ impl Stark { profiler!(stop "FRI"); profiler!(start "check leafs"); - profiler!(start "dequeue base elements"); - let base_table_rows = proof_stream.dequeue()?.try_into_master_base_table_rows()?; + profiler!(start "dequeue main elements"); + let base_table_rows = proof_stream.dequeue()?.try_into_master_main_table_rows()?; let base_authentication_structure = proof_stream .dequeue()? .try_into_authentication_structure()?; @@ -879,13 +879,13 @@ impl Stark { .par_iter() .map(|revealed_base_elem| Tip5::hash_varlen(revealed_base_elem)) .collect(); - profiler!(stop "dequeue base elements"); + profiler!(stop "dequeue main elements"); let index_leaves = |leaves| { let index_iter = revealed_current_row_indices.iter().copied(); index_iter.zip_eq(leaves).collect() }; - profiler!(start "Merkle verify (base tree)" ("hash")); + profiler!(start "Merkle verify (main tree)" ("hash")); let base_merkle_tree_inclusion_proof = MerkleTreeInclusionProof { tree_height: merkle_tree_height, indexed_leafs: index_leaves(leaf_digests_base), @@ -894,10 +894,10 @@ impl Stark { if !base_merkle_tree_inclusion_proof.verify(base_merkle_tree_root) { return Err(VerificationError::BaseCodewordAuthenticationFailure); } - profiler!(stop "Merkle verify (base tree)"); + profiler!(stop "Merkle verify (main tree)"); - profiler!(start "dequeue extension elements"); - let ext_table_rows = proof_stream.dequeue()?.try_into_master_ext_table_rows()?; + profiler!(start "dequeue auxiliary elements"); + let ext_table_rows = proof_stream.dequeue()?.try_into_master_aux_table_rows()?; let ext_authentication_structure = proof_stream .dequeue()? .try_into_authentication_structure()?; @@ -908,18 +908,18 @@ impl Stark { Tip5::hash_varlen(&b_values.collect_vec()) }) .collect::>(); - profiler!(stop "dequeue extension elements"); + profiler!(stop "dequeue auxiliary elements"); - profiler!(start "Merkle verify (extension tree)" ("hash")); + profiler!(start "Merkle verify (auxiliary tree)" ("hash")); let ext_merkle_tree_inclusion_proof = MerkleTreeInclusionProof { tree_height: merkle_tree_height, indexed_leafs: index_leaves(leaf_digests_ext), authentication_structure: ext_authentication_structure, }; - if !ext_merkle_tree_inclusion_proof.verify(extension_tree_merkle_root) { - return Err(VerificationError::ExtensionCodewordAuthenticationFailure); + if !ext_merkle_tree_inclusion_proof.verify(auxiliary_tree_merkle_root) { + return Err(VerificationError::AuxiliaryCodewordAuthenticationFailure); } - profiler!(stop "Merkle verify (extension tree)"); + profiler!(stop "Merkle verify (auxiliary tree)"); profiler!(start "dequeue quotient segments' elements"); let revealed_quotient_segments_elements = @@ -954,27 +954,27 @@ impl Stark { return Err(VerificationError::IncorrectNumberOfQuotientSegmentElements); }; if self.num_collinearity_checks != base_table_rows.len() { - return Err(VerificationError::IncorrectNumberOfBaseTableRows); + return Err(VerificationError::IncorrectNumberOfMainTableRows); }; if self.num_collinearity_checks != ext_table_rows.len() { return Err(VerificationError::IncorrectNumberOfExtTableRows); }; - for (row_idx, base_row, ext_row, quotient_segments_elements, fri_value) in izip!( + for (row_idx, main_row, aux_row, quotient_segments_elements, fri_value) in izip!( revealed_current_row_indices, base_table_rows, ext_table_rows, revealed_quotient_segments_elements, revealed_fri_values, ) { - let base_row = Array1::from(base_row.to_vec()); - let ext_row = Array1::from(ext_row.to_vec()); + let main_row = Array1::from(main_row.to_vec()); + let aux_row = Array1::from(aux_row.to_vec()); let current_fri_domain_value = fri.domain.domain_value(row_idx as u32); profiler!(start "base & ext elements" ("CC")); - let base_and_ext_curr_row_element = Self::linearly_sum_base_and_ext_row( - base_row.view(), - ext_row.view(), + let base_and_ext_curr_row_element = Self::linearly_sum_main_and_aux_row( + main_row.view(), + aux_row.view(), base_and_ext_codeword_weights.view(), ); let quotient_segments_curr_row_element = weights @@ -1028,9 +1028,9 @@ impl Stark { .collect() } - fn linearly_sum_base_and_ext_row( - base_row: ArrayView1, - ext_row: ArrayView1, + fn linearly_sum_main_and_aux_row( + main_row: ArrayView1, + aux_row: ArrayView1, weights: ArrayView1, ) -> XFieldElement where @@ -1038,8 +1038,8 @@ impl Stark { XFieldElement: Mul, { profiler!(start "collect"); - let mut row = base_row.map(|&element| element.into()); - row.append(Axis(0), ext_row).unwrap(); + let mut row = main_row.map(|&element| element.into()); + row.append(Axis(0), aux_row).unwrap(); profiler!(stop "collect"); profiler!(start "inner product"); // todo: Try to get rid of this clone. The alternative line @@ -1263,10 +1263,10 @@ impl<'a> Arbitrary<'a> for Stark { /// Fiat-Shamir-sampled challenges to compress a row into a single /// [extension field element][XFieldElement]. struct LinearCombinationWeights { - /// of length [`MasterBaseTable::NUM_COLUMNS`] + /// of length [`MasterMainTable::NUM_COLUMNS`] main: Array1, - /// of length [`MasterExtTable::NUM_COLUMNS`] + /// of length [`MasterAuxTable::NUM_COLUMNS`] aux: Array1, /// of length [`NUM_QUOTIENT_SEGMENTS`] @@ -1277,14 +1277,14 @@ struct LinearCombinationWeights { } impl LinearCombinationWeights { - const NUM: usize = MasterBaseTable::NUM_COLUMNS - + MasterExtTable::NUM_COLUMNS + const NUM: usize = MasterMainTable::NUM_COLUMNS + + MasterAuxTable::NUM_COLUMNS + NUM_QUOTIENT_SEGMENTS + NUM_DEEP_CODEWORD_COMPONENTS; fn sample(proof_stream: &mut ProofStream) -> Self { - const MAIN_END: usize = MasterBaseTable::NUM_COLUMNS; - const AUX_END: usize = MAIN_END + MasterExtTable::NUM_COLUMNS; + const MAIN_END: usize = MasterMainTable::NUM_COLUMNS; + const AUX_END: usize = MAIN_END + MasterAuxTable::NUM_COLUMNS; const QUOT_END: usize = AUX_END + NUM_QUOTIENT_SEGMENTS; let weights = proof_stream.sample_scalars(Self::NUM); @@ -1324,13 +1324,13 @@ pub(crate) mod tests { use air::table::ram::RamTable; use air::table::u32::U32Table; use air::table::TableId; - use air::table_column::MasterBaseTableColumn; - use air::table_column::MasterExtTableColumn; - use air::table_column::OpStackBaseTableColumn; - use air::table_column::ProcessorBaseTableColumn; - use air::table_column::ProcessorExtTableColumn::InputTableEvalArg; - use air::table_column::ProcessorExtTableColumn::OutputTableEvalArg; - use air::table_column::RamBaseTableColumn; + use air::table_column::MasterAuxColumn; + use air::table_column::MasterMainColumn; + use air::table_column::OpStackMainColumn; + use air::table_column::ProcessorAuxColumn::InputTableEvalArg; + use air::table_column::ProcessorAuxColumn::OutputTableEvalArg; + use air::table_column::ProcessorMainColumn; + use air::table_column::RamMainColumn; use air::AIR; use assert2::assert; use assert2::check; @@ -1354,9 +1354,9 @@ pub(crate) mod tests { use crate::error::InstructionError; use crate::example_programs::*; use crate::shared_tests::*; - use crate::table::extension_table; - use crate::table::extension_table::Evaluable; - use crate::table::master_table::MasterExtTable; + use crate::table::auxiliary_table; + use crate::table::auxiliary_table::Evaluable; + use crate::table::master_table::MasterAuxTable; use crate::triton_program; use crate::vm::tests::*; use crate::vm::NonDeterminism; @@ -1367,7 +1367,7 @@ pub(crate) mod tests { pub(crate) fn master_base_table_for_low_security_level( program_and_input: ProgramAndInput, - ) -> (Stark, Claim, MasterBaseTable) { + ) -> (Stark, Claim, MasterMainTable) { let ProgramAndInput { program, public_input, @@ -1380,14 +1380,14 @@ pub(crate) mod tests { let claim = Claim::about_program(&aet.program) .with_input(public_input.individual_tokens) .with_output(stdout); - let master_base_table = construct_master_base_table(stark, &aet); + let master_base_table = construct_master_main_table(stark, &aet); (stark, claim, master_base_table) } pub(crate) fn master_tables_for_low_security_level( program_and_input: ProgramAndInput, - ) -> (Stark, Claim, MasterBaseTable, MasterExtTable, Challenges) { + ) -> (Stark, Claim, MasterMainTable, MasterAuxTable, Challenges) { let (stark, claim, mut master_base_table) = master_base_table_for_low_security_level(program_and_input); @@ -1432,13 +1432,13 @@ pub(crate) mod tests { .into_iter() .take(40) { - let clk = row[ProcessorBaseTableColumn::CLK.base_table_index()].to_string(); - let st0 = row[ProcessorBaseTableColumn::ST0.base_table_index()].to_string(); - let st1 = row[ProcessorBaseTableColumn::ST1.base_table_index()].to_string(); - let st2 = row[ProcessorBaseTableColumn::ST2.base_table_index()].to_string(); - let st3 = row[ProcessorBaseTableColumn::ST3.base_table_index()].to_string(); - let st4 = row[ProcessorBaseTableColumn::ST4.base_table_index()].to_string(); - let st5 = row[ProcessorBaseTableColumn::ST5.base_table_index()].to_string(); + let clk = row[ProcessorMainColumn::CLK.main_index()].to_string(); + let st0 = row[ProcessorMainColumn::ST0.main_index()].to_string(); + let st1 = row[ProcessorMainColumn::ST1.main_index()].to_string(); + let st2 = row[ProcessorMainColumn::ST2.main_index()].to_string(); + let st3 = row[ProcessorMainColumn::ST3.main_index()].to_string(); + let st4 = row[ProcessorMainColumn::ST4.main_index()].to_string(); + let st5 = row[ProcessorMainColumn::ST5.main_index()].to_string(); let (ci, nia) = ci_and_nia_from_master_table_row(row); @@ -1459,20 +1459,18 @@ pub(crate) mod tests { .into_iter() .take(25) { - let clk = row[RamBaseTableColumn::CLK.base_table_index()].to_string(); - let ramp = row[RamBaseTableColumn::RamPointer.base_table_index()].to_string(); - let ramv = row[RamBaseTableColumn::RamValue.base_table_index()].to_string(); - let iord = - row[RamBaseTableColumn::InverseOfRampDifference.base_table_index()].to_string(); - - let instruction_type = - match row[RamBaseTableColumn::InstructionType.base_table_index()] { - ram::INSTRUCTION_TYPE_READ => "read", - ram::INSTRUCTION_TYPE_WRITE => "write", - ram::PADDING_INDICATOR => "pad", - _ => "-", - } - .to_string(); + let clk = row[RamMainColumn::CLK.main_index()].to_string(); + let ramp = row[RamMainColumn::RamPointer.main_index()].to_string(); + let ramv = row[RamMainColumn::RamValue.main_index()].to_string(); + let iord = row[RamMainColumn::InverseOfRampDifference.main_index()].to_string(); + + let instruction_type = match row[RamMainColumn::InstructionType.main_index()] { + ram::INSTRUCTION_TYPE_READ => "read", + ram::INSTRUCTION_TYPE_WRITE => "write", + ram::PADDING_INDICATOR => "pad", + _ => "-", + } + .to_string(); let interesting_cols = [clk, instruction_type, ramp, ramv, iord]; let interesting_cols = interesting_cols @@ -1508,19 +1506,19 @@ pub(crate) mod tests { .into_iter() .take(num_interesting_rows) { - let clk = row[ProcessorBaseTableColumn::CLK.base_table_index()].to_string(); - let st0 = row[ProcessorBaseTableColumn::ST0.base_table_index()].to_string(); - let st1 = row[ProcessorBaseTableColumn::ST1.base_table_index()].to_string(); - let st2 = row[ProcessorBaseTableColumn::ST2.base_table_index()].to_string(); - let st3 = row[ProcessorBaseTableColumn::ST3.base_table_index()].to_string(); - let st4 = row[ProcessorBaseTableColumn::ST4.base_table_index()].to_string(); - let st5 = row[ProcessorBaseTableColumn::ST5.base_table_index()].to_string(); - let st6 = row[ProcessorBaseTableColumn::ST6.base_table_index()].to_string(); - let st7 = row[ProcessorBaseTableColumn::ST7.base_table_index()].to_string(); - let st8 = row[ProcessorBaseTableColumn::ST8.base_table_index()].to_string(); - let st9 = row[ProcessorBaseTableColumn::ST9.base_table_index()].to_string(); - - let osp = row[ProcessorBaseTableColumn::OpStackPointer.base_table_index()]; + let clk = row[ProcessorMainColumn::CLK.main_index()].to_string(); + let st0 = row[ProcessorMainColumn::ST0.main_index()].to_string(); + let st1 = row[ProcessorMainColumn::ST1.main_index()].to_string(); + let st2 = row[ProcessorMainColumn::ST2.main_index()].to_string(); + let st3 = row[ProcessorMainColumn::ST3.main_index()].to_string(); + let st4 = row[ProcessorMainColumn::ST4.main_index()].to_string(); + let st5 = row[ProcessorMainColumn::ST5.main_index()].to_string(); + let st6 = row[ProcessorMainColumn::ST6.main_index()].to_string(); + let st7 = row[ProcessorMainColumn::ST7.main_index()].to_string(); + let st8 = row[ProcessorMainColumn::ST8.main_index()].to_string(); + let st9 = row[ProcessorMainColumn::ST9.main_index()].to_string(); + + let osp = row[ProcessorMainColumn::OpStackPointer.main_index()]; let osp = (osp.value() + fake_op_stack_size).saturating_sub(OpStackElement::COUNT as u64); @@ -1552,16 +1550,15 @@ pub(crate) mod tests { .into_iter() .take(num_interesting_rows) { - let clk = row[OpStackBaseTableColumn::CLK.base_table_index()].to_string(); - let ib1 = row[OpStackBaseTableColumn::IB1ShrinkStack.base_table_index()].to_string(); + let clk = row[OpStackMainColumn::CLK.main_index()].to_string(); + let ib1 = row[OpStackMainColumn::IB1ShrinkStack.main_index()].to_string(); - let osp = row[OpStackBaseTableColumn::StackPointer.base_table_index()]; + let osp = row[OpStackMainColumn::StackPointer.main_index()]; let osp = (osp.value() + fake_op_stack_size).saturating_sub(OpStackElement::COUNT as u64); let osp = osp.to_string(); - let value = - row[OpStackBaseTableColumn::FirstUnderflowElement.base_table_index()].to_string(); + let value = row[OpStackMainColumn::FirstUnderflowElement.main_index()].to_string(); let interesting_cols = [clk, ib1, osp, value]; let interesting_cols = interesting_cols @@ -1572,8 +1569,8 @@ pub(crate) mod tests { } fn ci_and_nia_from_master_table_row(row: ArrayView1) -> (String, String) { - let curr_instruction = row[ProcessorBaseTableColumn::CI.base_table_index()].value(); - let next_instruction_or_arg = row[ProcessorBaseTableColumn::NIA.base_table_index()].value(); + let curr_instruction = row[ProcessorMainColumn::CI.main_index()].value(); + let next_instruction_or_arg = row[ProcessorMainColumn::NIA.main_index()].value(); let curr_instruction = Instruction::try_from(curr_instruction).unwrap(); let nia = curr_instruction @@ -1589,7 +1586,7 @@ pub(crate) mod tests { let padded_height = 2; let num_trace_randomizers = 2; let interpolant_degree = interpolant_degree(padded_height, num_trace_randomizers); - for deg in extension_table::all_degrees_with_origin(interpolant_degree, padded_height) { + for deg in auxiliary_table::all_degrees_with_origin(interpolant_degree, padded_height) { println!("{deg}"); } } @@ -1606,7 +1603,7 @@ pub(crate) mod tests { let processor_table = master_ext_table.table(TableId::Processor); let processor_table_last_row = processor_table.slice(s![-1, ..]); - let ptie = processor_table_last_row[InputTableEvalArg.ext_table_index()]; + let ptie = processor_table_last_row[InputTableEvalArg.aux_index()]; let ine = EvalArg::compute_terminal( &claim.input, EvalArg::default_initial(), @@ -1614,7 +1611,7 @@ pub(crate) mod tests { ); check!(ptie == ine); - let ptoe = processor_table_last_row[OutputTableEvalArg.ext_table_index()]; + let ptoe = processor_table_last_row[OutputTableEvalArg.aux_index()]; let oute = EvalArg::compute_terminal( &claim.output, EvalArg::default_initial(), @@ -1626,16 +1623,16 @@ pub(crate) mod tests { #[test] fn constraint_polynomials_use_right_number_of_variables() { let challenges = Challenges::default(); - let base_row = Array1::::zeros(MasterBaseTable::NUM_COLUMNS); - let ext_row = Array1::zeros(MasterExtTable::NUM_COLUMNS); + let main_row = Array1::::zeros(MasterMainTable::NUM_COLUMNS); + let aux_row = Array1::zeros(MasterAuxTable::NUM_COLUMNS); - let br = base_row.view(); - let er = ext_row.view(); + let br = main_row.view(); + let er = aux_row.view(); - MasterExtTable::evaluate_initial_constraints(br, er, &challenges); - MasterExtTable::evaluate_consistency_constraints(br, er, &challenges); - MasterExtTable::evaluate_transition_constraints(br, er, br, er, &challenges); - MasterExtTable::evaluate_terminal_constraints(br, er, &challenges); + MasterAuxTable::evaluate_initial_constraints(br, er, &challenges); + MasterAuxTable::evaluate_consistency_constraints(br, er, &challenges); + MasterAuxTable::evaluate_transition_constraints(br, er, br, er, &challenges); + MasterAuxTable::evaluate_terminal_constraints(br, er, &challenges); } #[test] @@ -1736,8 +1733,8 @@ pub(crate) mod tests { #[test] fn number_of_quotient_degree_bounds_match_number_of_constraints() { - let base_row = Array1::::zeros(MasterBaseTable::NUM_COLUMNS); - let ext_row = Array1::zeros(MasterExtTable::NUM_COLUMNS); + let main_row = Array1::::zeros(MasterMainTable::NUM_COLUMNS); + let aux_row = Array1::zeros(MasterAuxTable::NUM_COLUMNS); let ch = Challenges::default(); let padded_height = 2; let num_trace_randomizers = 2; @@ -1746,28 +1743,28 @@ pub(crate) mod tests { // Shorten some names for better formatting. This is just a test. let ph = padded_height; let id = interpolant_degree; - let br = base_row.view(); - let er = ext_row.view(); + let br = main_row.view(); + let er = aux_row.view(); - let num_init_quots = MasterExtTable::NUM_INITIAL_CONSTRAINTS; - let num_cons_quots = MasterExtTable::NUM_CONSISTENCY_CONSTRAINTS; - let num_tran_quots = MasterExtTable::NUM_TRANSITION_CONSTRAINTS; - let num_term_quots = MasterExtTable::NUM_TERMINAL_CONSTRAINTS; + let num_init_quots = MasterAuxTable::NUM_INITIAL_CONSTRAINTS; + let num_cons_quots = MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS; + let num_tran_quots = MasterAuxTable::NUM_TRANSITION_CONSTRAINTS; + let num_term_quots = MasterAuxTable::NUM_TERMINAL_CONSTRAINTS; - let eval_init_consts = MasterExtTable::evaluate_initial_constraints(br, er, &ch); - let eval_cons_consts = MasterExtTable::evaluate_consistency_constraints(br, er, &ch); - let eval_tran_consts = MasterExtTable::evaluate_transition_constraints(br, er, br, er, &ch); - let eval_term_consts = MasterExtTable::evaluate_terminal_constraints(br, er, &ch); + let eval_init_consts = MasterAuxTable::evaluate_initial_constraints(br, er, &ch); + let eval_cons_consts = MasterAuxTable::evaluate_consistency_constraints(br, er, &ch); + let eval_tran_consts = MasterAuxTable::evaluate_transition_constraints(br, er, br, er, &ch); + let eval_term_consts = MasterAuxTable::evaluate_terminal_constraints(br, er, &ch); assert!(num_init_quots == eval_init_consts.len()); assert!(num_cons_quots == eval_cons_consts.len()); assert!(num_tran_quots == eval_tran_consts.len()); assert!(num_term_quots == eval_term_consts.len()); - assert!(num_init_quots == MasterExtTable::initial_quotient_degree_bounds(id).len()); - assert!(num_cons_quots == MasterExtTable::consistency_quotient_degree_bounds(id, ph).len()); - assert!(num_tran_quots == MasterExtTable::transition_quotient_degree_bounds(id, ph).len()); - assert!(num_term_quots == MasterExtTable::terminal_quotient_degree_bounds(id).len()); + assert!(num_init_quots == MasterAuxTable::initial_quotient_degree_bounds(id).len()); + assert!(num_cons_quots == MasterAuxTable::consistency_quotient_degree_bounds(id, ph).len()); + assert!(num_tran_quots == MasterAuxTable::transition_quotient_degree_bounds(id, ph).len()); + assert!(num_term_quots == MasterAuxTable::terminal_quotient_degree_bounds(id).len()); } #[test] @@ -2223,7 +2220,7 @@ pub(crate) mod tests { let master_base_trace_table = master_base_table.trace_table(); let master_ext_trace_table = master_ext_table.trace_table(); - let evaluated_initial_constraints = MasterExtTable::evaluate_initial_constraints( + let evaluated_initial_constraints = MasterAuxTable::evaluate_initial_constraints( master_base_trace_table.row(0), master_ext_trace_table.row(0), &challenges, @@ -2239,7 +2236,7 @@ pub(crate) mod tests { for row_idx in 0..master_base_trace_table.nrows() { let evaluated_consistency_constraints = - MasterExtTable::evaluate_consistency_constraints( + MasterAuxTable::evaluate_consistency_constraints( master_base_trace_table.row(row_idx), master_ext_trace_table.row(row_idx), &challenges, @@ -2256,7 +2253,7 @@ pub(crate) mod tests { for curr_row_idx in 0..master_base_trace_table.nrows() - 1 { let next_row_idx = curr_row_idx + 1; - let evaluated_transition_constraints = MasterExtTable::evaluate_transition_constraints( + let evaluated_transition_constraints = MasterAuxTable::evaluate_transition_constraints( master_base_trace_table.row(curr_row_idx), master_ext_trace_table.row(curr_row_idx), master_base_trace_table.row(next_row_idx), @@ -2273,7 +2270,7 @@ pub(crate) mod tests { } } - let evaluated_terminal_constraints = MasterExtTable::evaluate_terminal_constraints( + let evaluated_terminal_constraints = MasterAuxTable::evaluate_terminal_constraints( master_base_trace_table.row(master_base_trace_table.nrows() - 1), master_ext_trace_table.row(master_ext_trace_table.nrows() - 1), &challenges, @@ -2470,11 +2467,11 @@ pub(crate) mod tests { #[filter(!#offset.is_zero())] offset: BFieldElement, #[strategy(arb())] main_polynomials: [Polynomial; - MasterBaseTable::NUM_COLUMNS], + MasterMainTable::NUM_COLUMNS], #[strategy(arb())] aux_polynomials: [Polynomial; - MasterExtTable::NUM_COLUMNS], + MasterAuxTable::NUM_COLUMNS], #[strategy(arb())] challenges: Challenges, - #[strategy(arb())] quotient_weights: [XFieldElement; MasterExtTable::NUM_CONSTRAINTS], + #[strategy(arb())] quotient_weights: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS], ) { // set up let main_polynomials = Array1::from_vec(main_polynomials.to_vec()); @@ -2529,15 +2526,15 @@ pub(crate) mod tests { quotient_weights: &[XFieldElement], ) -> (Array2, Array1>) { let mut base_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, MasterBaseTable::NUM_COLUMNS]); + Array2::::zeros([quotient_domain.length, MasterMainTable::NUM_COLUMNS]); Zip::from(base_quotient_domain_codewords.axis_iter_mut(Axis(1))) .and(main_polynomials.axis_iter(Axis(0))) .for_each(|codeword, polynomial| { Array1::from_vec(quotient_domain.evaluate(&polynomial[()])).move_into(codeword); }); - let mut ext_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, MasterExtTable::NUM_COLUMNS]); - Zip::from(ext_quotient_domain_codewords.axis_iter_mut(Axis(1))) + let mut aux_quotient_domain_codewords = + Array2::::zeros([quotient_domain.length, MasterAuxTable::NUM_COLUMNS]); + Zip::from(aux_quotient_domain_codewords.axis_iter_mut(Axis(1))) .and(aux_polynomials.axis_iter(Axis(0))) .for_each(|codeword, polynomial| { Array1::from_vec(quotient_domain.evaluate(&polynomial[()])).move_into(codeword); @@ -2545,7 +2542,7 @@ pub(crate) mod tests { let quotient_codeword = all_quotients_combined( base_quotient_domain_codewords.view(), - ext_quotient_domain_codewords.view(), + aux_quotient_domain_codewords.view(), trace_domain, quotient_domain, challenges, @@ -2621,12 +2618,12 @@ pub(crate) mod tests { ) { let weights = LinearCombinationWeights::sample(&mut proof_stream); - prop_assert_eq!(MasterBaseTable::NUM_COLUMNS, weights.main.len()); - prop_assert_eq!(MasterExtTable::NUM_COLUMNS, weights.aux.len()); + prop_assert_eq!(MasterMainTable::NUM_COLUMNS, weights.main.len()); + prop_assert_eq!(MasterAuxTable::NUM_COLUMNS, weights.aux.len()); prop_assert_eq!(NUM_QUOTIENT_SEGMENTS, weights.quot_segments.len()); prop_assert_eq!(NUM_DEEP_CODEWORD_COMPONENTS, weights.deep.len()); prop_assert_eq!( - MasterBaseTable::NUM_COLUMNS + MasterExtTable::NUM_COLUMNS, + MasterMainTable::NUM_COLUMNS + MasterAuxTable::NUM_COLUMNS, weights.base_and_ext().len() ); } @@ -2781,7 +2778,7 @@ pub(crate) mod tests { let (aet, _) = VM::trace_execution(&program, public_input, non_determinism).unwrap(); let opcodes_of_all_executed_instructions = aet .processor_trace - .column(ProcessorBaseTableColumn::CI.base_table_index()) + .column(ProcessorMainColumn::CI.main_index()) .iter() .copied() .collect::>(); diff --git a/triton-vm/src/table.rs b/triton-vm/src/table.rs index 146d87e0a..cf6de3928 100644 --- a/triton-vm/src/table.rs +++ b/triton-vm/src/table.rs @@ -12,14 +12,13 @@ use twenty_first::prelude::*; use crate::aet::AlgebraicExecutionTrace; use crate::challenges::Challenges; pub use crate::stark::NUM_QUOTIENT_SEGMENTS; -use crate::table::master_table::MasterBaseTable; -use crate::table::master_table::MasterExtTable; +use crate::table::master_table::MasterAuxTable; +use crate::table::master_table::MasterMainTable; use crate::table::master_table::MasterTable; -pub mod degree_lowering; - +pub mod auxiliary_table; pub mod cascade; -pub mod extension_table; +pub mod degree_lowering; pub mod hash; pub mod jump_stack; pub mod lookup; @@ -30,17 +29,6 @@ pub mod program; pub mod ram; pub mod u32; -/// The total number of main columns across all tables. -/// The degree lowering columns _are_ included. -pub const NUM_MAIN_COLUMNS: usize = - air::table::NUM_BASE_COLUMNS + degree_lowering::DegreeLoweringBaseTableColumn::COUNT; - -/// The total number of auxiliary columns across all tables. -/// The degree lowering columns _are_ included, -/// randomizer polynomials are _not_ included. -pub const NUM_AUX_COLUMNS: usize = - air::table::NUM_EXT_COLUMNS + degree_lowering::DegreeLoweringExtTableColumn::COUNT; - trait TraceTable: AIR { // a nicer design is in order type FillParam; @@ -55,8 +43,8 @@ trait TraceTable: AIR { fn pad(main_table: ArrayViewMut2, table_length: usize); fn extend( - base_table: ArrayView2, - ext_table: ArrayViewMut2, + main_table: ArrayView2, + aux_table: ArrayViewMut2, challenges: &Challenges, ); } @@ -78,15 +66,15 @@ pub enum ConstraintType { Terminal, } -/// A single row of a [`MasterBaseTable`]. +/// A single row of a [`MasterMainTable`]. /// /// Usually, the elements in the table are [`BFieldElement`]s. For out-of-domain rows, which is /// relevant for “Domain Extension to Eliminate Pretenders” (DEEP), the elements are /// [`XFieldElement`]s. -pub type BaseRow = [T; MasterBaseTable::NUM_COLUMNS]; +pub type MainRow = [T; MasterMainTable::NUM_COLUMNS]; -/// A single row of a [`MasterExtTable`]. -pub type ExtensionRow = [XFieldElement; MasterExtTable::NUM_COLUMNS]; +/// A single row of a [`MasterAuxTable`]. +pub type AuxiliaryRow = [XFieldElement; MasterAuxTable::NUM_COLUMNS]; /// An element of the split-up quotient polynomial. /// @@ -104,16 +92,16 @@ mod tests { use air::table::program::ProgramTable; use air::table::ram::RamTable; use air::table::u32::U32Table; + use air::table::AUX_CASCADE_TABLE_END; + use air::table::AUX_HASH_TABLE_END; + use air::table::AUX_JUMP_STACK_TABLE_END; + use air::table::AUX_LOOKUP_TABLE_END; + use air::table::AUX_OP_STACK_TABLE_END; + use air::table::AUX_PROCESSOR_TABLE_END; + use air::table::AUX_PROGRAM_TABLE_END; + use air::table::AUX_RAM_TABLE_END; + use air::table::AUX_U32_TABLE_END; use air::table::CASCADE_TABLE_END; - use air::table::EXT_CASCADE_TABLE_END; - use air::table::EXT_HASH_TABLE_END; - use air::table::EXT_JUMP_STACK_TABLE_END; - use air::table::EXT_LOOKUP_TABLE_END; - use air::table::EXT_OP_STACK_TABLE_END; - use air::table::EXT_PROCESSOR_TABLE_END; - use air::table::EXT_PROGRAM_TABLE_END; - use air::table::EXT_RAM_TABLE_END; - use air::table::EXT_U32_TABLE_END; use air::table::HASH_TABLE_END; use air::table::JUMP_STACK_TABLE_END; use air::table::LOOKUP_TABLE_END; @@ -161,16 +149,16 @@ mod tests { let challenges = &challenges.challenges; let num_rows = 2; - let base_shape = [num_rows, NUM_MAIN_COLUMNS]; - let ext_shape = [num_rows, NUM_AUX_COLUMNS]; - let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); - let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); - let base_rows = base_rows.view(); - let ext_rows = ext_rows.view(); + let base_shape = [num_rows, MasterMainTable::NUM_COLUMNS]; + let aux_shape = [num_rows, MasterAuxTable::NUM_COLUMNS]; + let main_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); + let aux_rows = Array2::from_shape_simple_fn(aux_shape, || rng.gen::()); + let main_rows = main_rows.view(); + let aux_rows = aux_rows.view(); let mut values = HashMap::new(); for c in constraints { - evaluate_assert_unique(c, challenges, base_rows, ext_rows, &mut values); + evaluate_assert_unique(c, challenges, main_rows, aux_rows, &mut values); } let circuit_degree = constraints.iter().map(|c| c.degree()).max().unwrap_or(-1); @@ -188,19 +176,19 @@ mod tests { fn evaluate_assert_unique( constraint: &ConstraintCircuit, challenges: &[XFieldElement], - base_rows: ArrayView2, - ext_rows: ArrayView2, + main_rows: ArrayView2, + aux_rows: ArrayView2, values: &mut HashMap)>, ) -> XFieldElement { let value = match &constraint.expression { CircuitExpression::BinaryOperation(binop, lhs, rhs) => { let lhs = lhs.borrow(); let rhs = rhs.borrow(); - let lhs = evaluate_assert_unique(&lhs, challenges, base_rows, ext_rows, values); - let rhs = evaluate_assert_unique(&rhs, challenges, base_rows, ext_rows, values); + let lhs = evaluate_assert_unique(&lhs, challenges, main_rows, aux_rows, values); + let rhs = evaluate_assert_unique(&rhs, challenges, main_rows, aux_rows, values); binop.operation(lhs, rhs) } - _ => constraint.evaluate(base_rows, ext_rows, challenges), + _ => constraint.evaluate(main_rows, aux_rows, challenges), }; let own_id = constraint.id.to_owned(); @@ -281,21 +269,21 @@ mod tests { println!(" {circuit}"); } - let (new_base_constraints, new_ext_constraints) = + let (new_main_constraints, new_aux_constraints) = ConstraintCircuitMonad::lower_to_degree(multicircuit, info); assert_eq!(num_constraints, multicircuit.len()); let target_deg = info.target_degree; assert!(ConstraintCircuitMonad::multicircuit_degree(multicircuit) <= target_deg); - assert!(ConstraintCircuitMonad::multicircuit_degree(&new_base_constraints) <= target_deg); - assert!(ConstraintCircuitMonad::multicircuit_degree(&new_ext_constraints) <= target_deg); + assert!(ConstraintCircuitMonad::multicircuit_degree(&new_main_constraints) <= target_deg); + assert!(ConstraintCircuitMonad::multicircuit_degree(&new_aux_constraints) <= target_deg); // Check that the new constraints are simple substitutions. let mut substitution_rules = vec![]; for (constraint_type, constraints) in [ - ("base", &new_base_constraints), - ("ext", &new_ext_constraints), + ("base", &new_main_constraints), + ("ext", &new_aux_constraints), ] { for (i, constraint) in constraints.iter().enumerate() { let expression = constraint.circuit.borrow().expression.clone(); @@ -320,20 +308,16 @@ mod tests { let challenges = &challenges.challenges; let num_rows = 2; - let num_new_base_constraints = new_base_constraints.len(); - let num_new_ext_constraints = new_ext_constraints.len(); - let num_base_cols = NUM_MAIN_COLUMNS + num_new_base_constraints; - let num_ext_cols = NUM_AUX_COLUMNS + num_new_ext_constraints; - let base_shape = [num_rows, num_base_cols]; - let ext_shape = [num_rows, num_ext_cols]; - let base_rows = Array2::from_shape_simple_fn(base_shape, || rng.gen::()); - let ext_rows = Array2::from_shape_simple_fn(ext_shape, || rng.gen::()); - let base_rows = base_rows.view(); - let ext_rows = ext_rows.view(); + let main_shape = [num_rows, MasterMainTable::NUM_COLUMNS]; + let aux_shape = [num_rows, MasterAuxTable::NUM_COLUMNS]; + let main_rows = Array2::from_shape_simple_fn(main_shape, || rng.gen::()); + let aux_rows = Array2::from_shape_simple_fn(aux_shape, || rng.gen::()); + let main_rows = main_rows.view(); + let aux_rows = aux_rows.view(); let evaluated_substitution_rules = substitution_rules .iter() - .map(|c| c.evaluate(base_rows, ext_rows, challenges)); + .map(|c| c.evaluate(main_rows, aux_rows, challenges)); let mut values_to_index = HashMap::new(); for (idx, value) in evaluated_substitution_rules.enumerate() { @@ -350,21 +334,23 @@ mod tests { println!(" {circuit}"); } println!("new base constraints:"); - for constraint in &new_base_constraints { + for constraint in &new_main_constraints { println!(" {constraint}"); } println!("new ext constraints:"); - for constraint in &new_ext_constraints { + for constraint in &new_aux_constraints { println!(" {constraint}"); } + let num_new_main_constraints = new_main_constraints.len(); + let num_new_aux_constraints = new_aux_constraints.len(); println!( "Started with {num_constraints} constraints. \ - Derived {num_new_base_constraints} new base, \ - {num_new_ext_constraints} new extension constraints." + Derived {num_new_main_constraints} new main, \ + {num_new_aux_constraints} new auxiliary constraints." ); - (new_base_constraints, new_ext_constraints) + (new_main_constraints, new_aux_constraints) } /// Panics if the given substitution rule uses variables with an index greater than (or equal) @@ -382,8 +368,8 @@ mod tests { assert_substitution_rule_uses_legal_variables(new_var, &rhs); } CircuitExpression::Input(old_var) => { - let new_var_is_base = new_var.is_base_table_column(); - let old_var_is_base = old_var.is_base_table_column(); + let new_var_is_base = new_var.is_main_table_column(); + let old_var_is_base = old_var.is_main_table_column(); let legal_substitute = match (new_var_is_base, old_var_is_base) { (true, false) => false, (false, true) => true, @@ -401,8 +387,8 @@ mod tests { ($table:ident ($base_end:ident, $ext_end:ident)) => {{ let degree_lowering_info = DegreeLoweringInfo { target_degree: air::TARGET_DEGREE, - num_base_cols: $base_end, - num_ext_cols: $ext_end, + num_main_cols: $base_end, + num_aux_cols: $ext_end, }; let circuit_builder = ConstraintCircuitBuilder::new(); let mut init = $table::initial_constraints(&circuit_builder); @@ -422,18 +408,18 @@ mod tests { }}; } - assert_degree_lowering!(ProgramTable(PROGRAM_TABLE_END, EXT_PROGRAM_TABLE_END)); - assert_degree_lowering!(ProcessorTable(PROCESSOR_TABLE_END, EXT_PROCESSOR_TABLE_END)); - assert_degree_lowering!(OpStackTable(OP_STACK_TABLE_END, EXT_OP_STACK_TABLE_END)); - assert_degree_lowering!(RamTable(RAM_TABLE_END, EXT_RAM_TABLE_END)); + assert_degree_lowering!(ProgramTable(PROGRAM_TABLE_END, AUX_PROGRAM_TABLE_END)); + assert_degree_lowering!(ProcessorTable(PROCESSOR_TABLE_END, AUX_PROCESSOR_TABLE_END)); + assert_degree_lowering!(OpStackTable(OP_STACK_TABLE_END, AUX_OP_STACK_TABLE_END)); + assert_degree_lowering!(RamTable(RAM_TABLE_END, AUX_RAM_TABLE_END)); assert_degree_lowering!(JumpStackTable( JUMP_STACK_TABLE_END, - EXT_JUMP_STACK_TABLE_END + AUX_JUMP_STACK_TABLE_END )); - assert_degree_lowering!(HashTable(HASH_TABLE_END, EXT_HASH_TABLE_END)); - assert_degree_lowering!(CascadeTable(CASCADE_TABLE_END, EXT_CASCADE_TABLE_END)); - assert_degree_lowering!(LookupTable(LOOKUP_TABLE_END, EXT_LOOKUP_TABLE_END)); - assert_degree_lowering!(U32Table(U32_TABLE_END, EXT_U32_TABLE_END)); + assert_degree_lowering!(HashTable(HASH_TABLE_END, AUX_HASH_TABLE_END)); + assert_degree_lowering!(CascadeTable(CASCADE_TABLE_END, AUX_CASCADE_TABLE_END)); + assert_degree_lowering!(LookupTable(LOOKUP_TABLE_END, AUX_LOOKUP_TABLE_END)); + assert_degree_lowering!(U32Table(U32_TABLE_END, AUX_U32_TABLE_END)); } /// Fills the derived columns of the degree-lowering table using randomly generated rows and @@ -444,30 +430,32 @@ mod tests { #[ignore = "(probably) requires normalization of circuit expressions"] fn substitution_rules_are_unique() { let challenges = Challenges::default(); - let mut base_table_rows = Array2::from_shape_fn((2, NUM_MAIN_COLUMNS), |_| random()); - let mut ext_table_rows = Array2::from_shape_fn((2, NUM_AUX_COLUMNS), |_| random()); + let mut base_table_rows = + Array2::from_shape_fn((2, MasterMainTable::NUM_COLUMNS), |_| random()); + let mut ext_table_rows = + Array2::from_shape_fn((2, MasterAuxTable::NUM_COLUMNS), |_| random()); - DegreeLoweringTable::fill_derived_base_columns(base_table_rows.view_mut()); - DegreeLoweringTable::fill_derived_ext_columns( + DegreeLoweringTable::fill_derived_main_columns(base_table_rows.view_mut()); + DegreeLoweringTable::fill_derived_aux_columns( base_table_rows.view(), ext_table_rows.view_mut(), &challenges, ); let mut encountered_values = HashMap::new(); - for col_idx in 0..NUM_MAIN_COLUMNS { + for col_idx in 0..MasterMainTable::NUM_COLUMNS { let val = base_table_rows[(0, col_idx)].lift(); let other_entry = encountered_values.insert(val, col_idx); if let Some(other_idx) = other_entry { - panic!("Duplicate value {val} in derived base column {other_idx} and {col_idx}."); + panic!("Duplicate value {val} in derived main column {other_idx} and {col_idx}."); } } println!("Now comparing extension columns…"); - for col_idx in 0..NUM_AUX_COLUMNS { + for col_idx in 0..MasterAuxTable::NUM_COLUMNS { let val = ext_table_rows[(0, col_idx)]; let other_entry = encountered_values.insert(val, col_idx); if let Some(other_idx) = other_entry { - panic!("Duplicate value {val} in derived ext column {other_idx} and {col_idx}."); + panic!("Duplicate value {val} in derived aux column {other_idx} and {col_idx}."); } } } diff --git a/triton-vm/src/table/extension_table.rs b/triton-vm/src/table/auxiliary_table.rs similarity index 85% rename from triton-vm/src/table/extension_table.rs rename to triton-vm/src/table/auxiliary_table.rs index 519189c7a..253a6f8bd 100644 --- a/triton-vm/src/table/extension_table.rs +++ b/triton-vm/src/table/auxiliary_table.rs @@ -8,7 +8,7 @@ use twenty_first::math::traits::FiniteField; use twenty_first::prelude::*; use crate::challenges::Challenges; -use crate::table::master_table::MasterExtTable; +use crate::table::master_table::MasterAuxTable; use crate::table::ConstraintType; include!(concat!(env!("OUT_DIR"), "/evaluate_constraints.rs")); @@ -16,28 +16,28 @@ include!(concat!(env!("OUT_DIR"), "/evaluate_constraints.rs")); // The implementations of these functions are generated in `build.rs`. pub trait Evaluable { fn evaluate_initial_constraints( - base_row: ArrayView1, - ext_row: ArrayView1, + main_row: ArrayView1, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec; fn evaluate_consistency_constraints( - base_row: ArrayView1, - ext_row: ArrayView1, + main_row: ArrayView1, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec; fn evaluate_transition_constraints( - current_base_row: ArrayView1, - current_ext_row: ArrayView1, - next_base_row: ArrayView1, - next_ext_row: ArrayView1, + current_main_row: ArrayView1, + current_aux_row: ArrayView1, + next_main_row: ArrayView1, + next_aux_row: ArrayView1, challenges: &Challenges, ) -> Vec; fn evaluate_terminal_constraints( - base_row: ArrayView1, - ext_row: ArrayView1, + main_row: ArrayView1, + aux_row: ArrayView1, challenges: &Challenges, ) -> Vec; } @@ -77,7 +77,7 @@ pub(crate) fn all_degrees_with_origin( padded_height: usize, ) -> Vec { let initial_degrees_with_origin = - MasterExtTable::initial_quotient_degree_bounds(interpolant_degree) + MasterAuxTable::initial_quotient_degree_bounds(interpolant_degree) .into_iter() .enumerate() .map(|(origin_index, degree)| DegreeWithOrigin { @@ -91,7 +91,7 @@ pub(crate) fn all_degrees_with_origin( .collect_vec(); let consistency_degrees_with_origin = - MasterExtTable::consistency_quotient_degree_bounds(interpolant_degree, padded_height) + MasterAuxTable::consistency_quotient_degree_bounds(interpolant_degree, padded_height) .into_iter() .enumerate() .map(|(origin_index, degree)| DegreeWithOrigin { @@ -105,7 +105,7 @@ pub(crate) fn all_degrees_with_origin( .collect(); let transition_degrees_with_origin = - MasterExtTable::transition_quotient_degree_bounds(interpolant_degree, padded_height) + MasterAuxTable::transition_quotient_degree_bounds(interpolant_degree, padded_height) .into_iter() .enumerate() .map(|(origin_index, degree)| DegreeWithOrigin { @@ -119,7 +119,7 @@ pub(crate) fn all_degrees_with_origin( .collect(); let terminal_degrees_with_origin = - MasterExtTable::terminal_quotient_degree_bounds(interpolant_degree) + MasterAuxTable::terminal_quotient_degree_bounds(interpolant_degree) .into_iter() .enumerate() .map(|(origin_index, degree)| DegreeWithOrigin { diff --git a/triton-vm/src/table/cascade.rs b/triton-vm/src/table/cascade.rs index 434d94294..246c5e487 100644 --- a/triton-vm/src/table/cascade.rs +++ b/triton-vm/src/table/cascade.rs @@ -2,8 +2,8 @@ use air::challenge_id::ChallengeId; use air::cross_table_argument::CrossTableArg; use air::cross_table_argument::LookupArg; use air::table::cascade::CascadeTable; -use air::table_column::MasterBaseTableColumn; -use air::table_column::MasterExtTableColumn; +use air::table_column::MasterAuxColumn; +use air::table_column::MasterMainColumn; use air::AIR; use ndarray::s; use ndarray::ArrayView2; @@ -45,11 +45,11 @@ impl TraceTable for CascadeTable { let to_look_up_hi = ((to_look_up >> 8) & 0xff) as u8; let mut row = main_table.row_mut(row_idx); - row[MainColumn::LookInLo.base_table_index()] = bfe!(to_look_up_lo); - row[MainColumn::LookInHi.base_table_index()] = bfe!(to_look_up_hi); - row[MainColumn::LookOutLo.base_table_index()] = lookup_8_bit_limb(to_look_up_lo); - row[MainColumn::LookOutHi.base_table_index()] = lookup_8_bit_limb(to_look_up_hi); - row[MainColumn::LookupMultiplicity.base_table_index()] = bfe!(multiplicity); + row[MainColumn::LookInLo.main_index()] = bfe!(to_look_up_lo); + row[MainColumn::LookInHi.main_index()] = bfe!(to_look_up_hi); + row[MainColumn::LookOutLo.main_index()] = lookup_8_bit_limb(to_look_up_lo); + row[MainColumn::LookOutHi.main_index()] = lookup_8_bit_limb(to_look_up_hi); + row[MainColumn::LookupMultiplicity.main_index()] = bfe!(multiplicity); } } @@ -57,7 +57,7 @@ impl TraceTable for CascadeTable { main_table .slice_mut(s![ cascade_table_length.., - MainColumn::IsPadding.base_table_index() + MainColumn::IsPadding.main_index() ]) .fill(BFieldElement::ONE); } @@ -86,35 +86,34 @@ impl TraceTable for CascadeTable { let lookup_output_weight = challenges[ChallengeId::LookupTableOutputWeight]; for row_idx in 0..main_table.nrows() { - let base_row = main_table.row(row_idx); - let is_padding = base_row[MainColumn::IsPadding.base_table_index()].is_one(); + let main_row = main_table.row(row_idx); + let is_padding = main_row[MainColumn::IsPadding.main_index()].is_one(); if !is_padding { - let look_in = two_pow_8 * base_row[MainColumn::LookInHi.base_table_index()] - + base_row[MainColumn::LookInLo.base_table_index()]; - let look_out = two_pow_8 * base_row[MainColumn::LookOutHi.base_table_index()] - + base_row[MainColumn::LookOutLo.base_table_index()]; + let look_in = two_pow_8 * main_row[MainColumn::LookInHi.main_index()] + + main_row[MainColumn::LookInLo.main_index()]; + let look_out = two_pow_8 * main_row[MainColumn::LookOutHi.main_index()] + + main_row[MainColumn::LookOutLo.main_index()]; let compressed_row_hash = hash_input_weight * look_in + hash_output_weight * look_out; - let lookup_multiplicity = - base_row[MainColumn::LookupMultiplicity.base_table_index()]; + let lookup_multiplicity = main_row[MainColumn::LookupMultiplicity.main_index()]; hash_table_log_derivative += (hash_indeterminate - compressed_row_hash).inverse() * lookup_multiplicity; let compressed_row_lo = lookup_input_weight - * base_row[MainColumn::LookInLo.base_table_index()] - + lookup_output_weight * base_row[MainColumn::LookOutLo.base_table_index()]; + * main_row[MainColumn::LookInLo.main_index()] + + lookup_output_weight * main_row[MainColumn::LookOutLo.main_index()]; let compressed_row_hi = lookup_input_weight - * base_row[MainColumn::LookInHi.base_table_index()] - + lookup_output_weight * base_row[MainColumn::LookOutHi.base_table_index()]; + * main_row[MainColumn::LookInHi.main_index()] + + lookup_output_weight * main_row[MainColumn::LookOutHi.main_index()]; lookup_table_log_derivative += (lookup_indeterminate - compressed_row_lo).inverse(); lookup_table_log_derivative += (lookup_indeterminate - compressed_row_hi).inverse(); } let mut extension_row = aux_table.row_mut(row_idx); - extension_row[AuxColumn::HashTableServerLogDerivative.ext_table_index()] = + extension_row[AuxColumn::HashTableServerLogDerivative.aux_index()] = hash_table_log_derivative; - extension_row[AuxColumn::LookupTableClientLogDerivative.ext_table_index()] = + extension_row[AuxColumn::LookupTableClientLogDerivative.aux_index()] = lookup_table_log_derivative; } profiler!(stop "cascade table"); diff --git a/triton-vm/src/table/hash.rs b/triton-vm/src/table/hash.rs index 14272035c..ac8b57ae2 100644 --- a/triton-vm/src/table/hash.rs +++ b/triton-vm/src/table/hash.rs @@ -6,10 +6,10 @@ use air::table::hash::HashTable; use air::table::hash::HashTableMode; use air::table::hash::PermutationTrace; use air::table::hash::MONTGOMERY_MODULUS; -use air::table_column::HashBaseTableColumn::*; -use air::table_column::HashExtTableColumn::*; -use air::table_column::MasterBaseTableColumn; -use air::table_column::MasterExtTableColumn; +use air::table_column::HashAuxColumn::*; +use air::table_column::HashMainColumn::*; +use air::table_column::MasterAuxColumn; +use air::table_column::MasterMainColumn; use air::AIR; use isa::instruction::Instruction; use itertools::Itertools; @@ -69,7 +69,7 @@ fn fill_row_with_round_number( mut row: Array1, round_number: usize, ) -> Array1 { - row[RoundNumber.base_table_index()] = bfe!(round_number as u64); + row[RoundNumber.main_index()] = bfe!(round_number as u64); row } @@ -89,16 +89,16 @@ fn fill_split_state_element_0_of_row_using_trace_row( ) -> Array1 { let limbs = base_field_element_into_16_bit_limbs(trace_row[0]); let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State0LowestLkIn.base_table_index()] = look_in_split[0]; - row[State0MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State0MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State0HighestLkIn.base_table_index()] = look_in_split[3]; + row[State0LowestLkIn.main_index()] = look_in_split[0]; + row[State0MidLowLkIn.main_index()] = look_in_split[1]; + row[State0MidHighLkIn.main_index()] = look_in_split[2]; + row[State0HighestLkIn.main_index()] = look_in_split[3]; let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); - row[State0LowestLkOut.base_table_index()] = look_out_split[0]; - row[State0MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State0MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State0HighestLkOut.base_table_index()] = look_out_split[3]; + row[State0LowestLkOut.main_index()] = look_out_split[0]; + row[State0MidLowLkOut.main_index()] = look_out_split[1]; + row[State0MidHighLkOut.main_index()] = look_out_split[2]; + row[State0HighestLkOut.main_index()] = look_out_split[3]; row } @@ -109,16 +109,16 @@ fn fill_split_state_element_1_of_row_using_trace_row( ) -> Array1 { let limbs = base_field_element_into_16_bit_limbs(trace_row[1]); let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State1LowestLkIn.base_table_index()] = look_in_split[0]; - row[State1MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State1MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State1HighestLkIn.base_table_index()] = look_in_split[3]; + row[State1LowestLkIn.main_index()] = look_in_split[0]; + row[State1MidLowLkIn.main_index()] = look_in_split[1]; + row[State1MidHighLkIn.main_index()] = look_in_split[2]; + row[State1HighestLkIn.main_index()] = look_in_split[3]; let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); - row[State1LowestLkOut.base_table_index()] = look_out_split[0]; - row[State1MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State1MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State1HighestLkOut.base_table_index()] = look_out_split[3]; + row[State1LowestLkOut.main_index()] = look_out_split[0]; + row[State1MidLowLkOut.main_index()] = look_out_split[1]; + row[State1MidHighLkOut.main_index()] = look_out_split[2]; + row[State1HighestLkOut.main_index()] = look_out_split[3]; row } @@ -129,16 +129,16 @@ fn fill_split_state_element_2_of_row_using_trace_row( ) -> Array1 { let limbs = base_field_element_into_16_bit_limbs(trace_row[2]); let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State2LowestLkIn.base_table_index()] = look_in_split[0]; - row[State2MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State2MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State2HighestLkIn.base_table_index()] = look_in_split[3]; + row[State2LowestLkIn.main_index()] = look_in_split[0]; + row[State2MidLowLkIn.main_index()] = look_in_split[1]; + row[State2MidHighLkIn.main_index()] = look_in_split[2]; + row[State2HighestLkIn.main_index()] = look_in_split[3]; let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); - row[State2LowestLkOut.base_table_index()] = look_out_split[0]; - row[State2MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State2MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State2HighestLkOut.base_table_index()] = look_out_split[3]; + row[State2LowestLkOut.main_index()] = look_out_split[0]; + row[State2MidLowLkOut.main_index()] = look_out_split[1]; + row[State2MidHighLkOut.main_index()] = look_out_split[2]; + row[State2HighestLkOut.main_index()] = look_out_split[3]; row } @@ -149,16 +149,16 @@ fn fill_split_state_element_3_of_row_using_trace_row( ) -> Array1 { let limbs = base_field_element_into_16_bit_limbs(trace_row[3]); let look_in_split = limbs.map(|limb| bfe!(limb)); - row[State3LowestLkIn.base_table_index()] = look_in_split[0]; - row[State3MidLowLkIn.base_table_index()] = look_in_split[1]; - row[State3MidHighLkIn.base_table_index()] = look_in_split[2]; - row[State3HighestLkIn.base_table_index()] = look_in_split[3]; + row[State3LowestLkIn.main_index()] = look_in_split[0]; + row[State3MidLowLkIn.main_index()] = look_in_split[1]; + row[State3MidHighLkIn.main_index()] = look_in_split[2]; + row[State3HighestLkIn.main_index()] = look_in_split[3]; let look_out_split = limbs.map(crate::table::cascade::lookup_16_bit_limb); - row[State3LowestLkOut.base_table_index()] = look_out_split[0]; - row[State3MidLowLkOut.base_table_index()] = look_out_split[1]; - row[State3MidHighLkOut.base_table_index()] = look_out_split[2]; - row[State3HighestLkOut.base_table_index()] = look_out_split[3]; + row[State3LowestLkOut.main_index()] = look_out_split[0]; + row[State3MidLowLkOut.main_index()] = look_out_split[1]; + row[State3MidHighLkOut.main_index()] = look_out_split[2]; + row[State3HighestLkOut.main_index()] = look_out_split[3]; row } @@ -167,18 +167,18 @@ fn fill_row_with_unsplit_state_elements_using_trace_row( mut row: Array1, trace_row: [BFieldElement; STATE_SIZE], ) -> Array1 { - row[State4.base_table_index()] = trace_row[4]; - row[State5.base_table_index()] = trace_row[5]; - row[State6.base_table_index()] = trace_row[6]; - row[State7.base_table_index()] = trace_row[7]; - row[State8.base_table_index()] = trace_row[8]; - row[State9.base_table_index()] = trace_row[9]; - row[State10.base_table_index()] = trace_row[10]; - row[State11.base_table_index()] = trace_row[11]; - row[State12.base_table_index()] = trace_row[12]; - row[State13.base_table_index()] = trace_row[13]; - row[State14.base_table_index()] = trace_row[14]; - row[State15.base_table_index()] = trace_row[15]; + row[State4.main_index()] = trace_row[4]; + row[State5.main_index()] = trace_row[5]; + row[State6.main_index()] = trace_row[6]; + row[State7.main_index()] = trace_row[7]; + row[State8.main_index()] = trace_row[8]; + row[State9.main_index()] = trace_row[9]; + row[State10.main_index()] = trace_row[10]; + row[State11.main_index()] = trace_row[11]; + row[State12.main_index()] = trace_row[12]; + row[State13.main_index()] = trace_row[13]; + row[State14.main_index()] = trace_row[14]; + row[State15.main_index()] = trace_row[15]; row } @@ -186,10 +186,10 @@ fn fill_row_with_state_inverses_using_trace_row( mut row: Array1, trace_row: [BFieldElement; STATE_SIZE], ) -> Array1 { - row[State0Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[0]); - row[State1Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[1]); - row[State2Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[2]); - row[State3Inv.base_table_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[3]); + row[State0Inv.main_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[0]); + row[State1Inv.main_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[1]); + row[State2Inv.main_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[2]); + row[State3Inv.main_index()] = inverse_or_zero_of_highest_2_limbs(trace_row[3]); row } @@ -212,22 +212,22 @@ fn fill_row_with_round_constants_for_round( ) -> Array1 { let round_constants = HashTable::tip5_round_constants_by_round_number(round_number); let [r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15] = round_constants; - row[Constant0.base_table_index()] = r0; - row[Constant1.base_table_index()] = r1; - row[Constant2.base_table_index()] = r2; - row[Constant3.base_table_index()] = r3; - row[Constant4.base_table_index()] = r4; - row[Constant5.base_table_index()] = r5; - row[Constant6.base_table_index()] = r6; - row[Constant7.base_table_index()] = r7; - row[Constant8.base_table_index()] = r8; - row[Constant9.base_table_index()] = r9; - row[Constant10.base_table_index()] = r10; - row[Constant11.base_table_index()] = r11; - row[Constant12.base_table_index()] = r12; - row[Constant13.base_table_index()] = r13; - row[Constant14.base_table_index()] = r14; - row[Constant15.base_table_index()] = r15; + row[Constant0.main_index()] = r0; + row[Constant1.main_index()] = r1; + row[Constant2.main_index()] = r2; + row[Constant3.main_index()] = r3; + row[Constant4.main_index()] = r4; + row[Constant5.main_index()] = r5; + row[Constant6.main_index()] = r6; + row[Constant7.main_index()] = r7; + row[Constant8.main_index()] = r8; + row[Constant9.main_index()] = r9; + row[Constant10.main_index()] = r10; + row[Constant11.main_index()] = r11; + row[Constant12.main_index()] = r12; + row[Constant13.main_index()] = r13; + row[Constant14.main_index()] = r14; + row[Constant15.main_index()] = r15; row } @@ -253,7 +253,7 @@ impl TraceTable for HashTable { sponge_part.assign(&aet.sponge_trace); hash_part.assign(&aet.hash_trace); - let mode_column_idx = Mode.base_table_index(); + let mode_column_idx = Mode.main_index(); let mut program_hash_mode_column = program_hash_part.column_mut(mode_column_idx); let mut sponge_mode_column = sponge_part.column_mut(mode_column_idx); let mut hash_mode_column = hash_part.column_mut(mode_column_idx); @@ -266,7 +266,7 @@ impl TraceTable for HashTable { fn pad(mut main_table: ArrayViewMut2, table_length: usize) { let inverse_of_high_limbs = inverse_or_zero_of_highest_2_limbs(bfe!(0)); for column_id in [State0Inv, State1Inv, State2Inv, State3Inv] { - let column_index = column_id.base_table_index(); + let column_index = column_id.main_index(); let slice_info = s![table_length.., column_index]; let mut column = main_table.slice_mut(slice_info); column.fill(inverse_of_high_limbs); @@ -276,18 +276,18 @@ impl TraceTable for HashTable { for (round_constant_idx, &round_constant) in round_constants.iter().enumerate() { let round_constant_column = HashTable::round_constant_column_by_index(round_constant_idx); - let round_constant_column_idx = round_constant_column.base_table_index(); + let round_constant_column_idx = round_constant_column.main_index(); let slice_info = s![table_length.., round_constant_column_idx]; let mut column = main_table.slice_mut(slice_info); column.fill(round_constant); } - let mode_column_index = Mode.base_table_index(); + let mode_column_index = Mode.main_index(); let mode_column_slice_info = s![table_length.., mode_column_index]; let mut mode_column = main_table.slice_mut(mode_column_slice_info); mode_column.fill(HashTableMode::Pad.into()); - let instruction_column_index = CI.base_table_index(); + let instruction_column_index = CI.main_index(); let instruction_column_slice_info = s![table_length.., instruction_column_index]; let mut instruction_column = main_table.slice_mut(instruction_column_slice_info); instruction_column.fill(Instruction::Hash.opcode_b()); @@ -342,10 +342,10 @@ impl TraceTable for HashTable { mid_high: Self::MainColumn, mid_low: Self::MainColumn, lowest: Self::MainColumn| { - (row[highest.base_table_index()] * two_pow_48 - + row[mid_high.base_table_index()] * two_pow_32 - + row[mid_low.base_table_index()] * two_pow_16 - + row[lowest.base_table_index()]) + (row[highest.main_index()] * two_pow_48 + + row[mid_high.main_index()] * two_pow_32 + + row[mid_low.main_index()] * two_pow_16 + + row[lowest.main_index()]) * montgomery_modulus_inverse }; @@ -383,12 +383,12 @@ impl TraceTable for HashTable { state_1, state_2, state_3, - row[State4.base_table_index()], - row[State5.base_table_index()], - row[State6.base_table_index()], - row[State7.base_table_index()], - row[State8.base_table_index()], - row[State9.base_table_index()], + row[State4.main_index()], + row[State5.main_index()], + row[State6.main_index()], + row[State7.main_index()], + row[State8.main_index()], + row[State9.main_index()], ] }; @@ -409,25 +409,25 @@ impl TraceTable for HashTable { lk_in_col: Self::MainColumn, lk_out_col: Self::MainColumn| { let compressed_elements = cascade_indeterminate - - cascade_look_in_weight * row[lk_in_col.base_table_index()] - - cascade_look_out_weight * row[lk_out_col.base_table_index()]; + - cascade_look_in_weight * row[lk_in_col.main_index()] + - cascade_look_out_weight * row[lk_out_col.main_index()]; compressed_elements.inverse() }; for row_idx in 0..main_table.nrows() { let row = main_table.row(row_idx); - let mode = row[Mode.base_table_index()]; + let mode = row[Mode.main_index()]; let in_program_hashing_mode = mode == HashTableMode::ProgramHashing.into(); let in_sponge_mode = mode == HashTableMode::Sponge.into(); let in_hash_mode = mode == HashTableMode::Hash.into(); let in_pad_mode = mode == HashTableMode::Pad.into(); - let round_number = row[RoundNumber.base_table_index()]; + let round_number = row[RoundNumber.main_index()]; let in_round_0 = round_number.is_zero(); let in_last_round = round_number == (NUM_ROUNDS as u64).into(); - let current_instruction = row[CI.base_table_index()]; + let current_instruction = row[CI.main_index()]; let current_instruction_is_sponge_init = current_instruction == Instruction::SpongeInit.opcode_b(); @@ -506,44 +506,42 @@ impl TraceTable for HashTable { } let mut extension_row = aux_table.row_mut(row_idx); - extension_row[ReceiveChunkRunningEvaluation.ext_table_index()] = + extension_row[ReceiveChunkRunningEvaluation.aux_index()] = receive_chunk_running_evaluation; - extension_row[HashInputRunningEvaluation.ext_table_index()] = - hash_input_running_evaluation; - extension_row[HashDigestRunningEvaluation.ext_table_index()] = - hash_digest_running_evaluation; - extension_row[SpongeRunningEvaluation.ext_table_index()] = sponge_running_evaluation; - extension_row[CascadeState0HighestClientLogDerivative.ext_table_index()] = + extension_row[HashInputRunningEvaluation.aux_index()] = hash_input_running_evaluation; + extension_row[HashDigestRunningEvaluation.aux_index()] = hash_digest_running_evaluation; + extension_row[SpongeRunningEvaluation.aux_index()] = sponge_running_evaluation; + extension_row[CascadeState0HighestClientLogDerivative.aux_index()] = cascade_state_0_highest_log_derivative; - extension_row[CascadeState0MidHighClientLogDerivative.ext_table_index()] = + extension_row[CascadeState0MidHighClientLogDerivative.aux_index()] = cascade_state_0_mid_high_log_derivative; - extension_row[CascadeState0MidLowClientLogDerivative.ext_table_index()] = + extension_row[CascadeState0MidLowClientLogDerivative.aux_index()] = cascade_state_0_mid_low_log_derivative; - extension_row[CascadeState0LowestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState0LowestClientLogDerivative.aux_index()] = cascade_state_0_lowest_log_derivative; - extension_row[CascadeState1HighestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState1HighestClientLogDerivative.aux_index()] = cascade_state_1_highest_log_derivative; - extension_row[CascadeState1MidHighClientLogDerivative.ext_table_index()] = + extension_row[CascadeState1MidHighClientLogDerivative.aux_index()] = cascade_state_1_mid_high_log_derivative; - extension_row[CascadeState1MidLowClientLogDerivative.ext_table_index()] = + extension_row[CascadeState1MidLowClientLogDerivative.aux_index()] = cascade_state_1_mid_low_log_derivative; - extension_row[CascadeState1LowestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState1LowestClientLogDerivative.aux_index()] = cascade_state_1_lowest_log_derivative; - extension_row[CascadeState2HighestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState2HighestClientLogDerivative.aux_index()] = cascade_state_2_highest_log_derivative; - extension_row[CascadeState2MidHighClientLogDerivative.ext_table_index()] = + extension_row[CascadeState2MidHighClientLogDerivative.aux_index()] = cascade_state_2_mid_high_log_derivative; - extension_row[CascadeState2MidLowClientLogDerivative.ext_table_index()] = + extension_row[CascadeState2MidLowClientLogDerivative.aux_index()] = cascade_state_2_mid_low_log_derivative; - extension_row[CascadeState2LowestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState2LowestClientLogDerivative.aux_index()] = cascade_state_2_lowest_log_derivative; - extension_row[CascadeState3HighestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState3HighestClientLogDerivative.aux_index()] = cascade_state_3_highest_log_derivative; - extension_row[CascadeState3MidHighClientLogDerivative.ext_table_index()] = + extension_row[CascadeState3MidHighClientLogDerivative.aux_index()] = cascade_state_3_mid_high_log_derivative; - extension_row[CascadeState3MidLowClientLogDerivative.ext_table_index()] = + extension_row[CascadeState3MidLowClientLogDerivative.aux_index()] = cascade_state_3_mid_low_log_derivative; - extension_row[CascadeState3LowestClientLogDerivative.ext_table_index()] = + extension_row[CascadeState3LowestClientLogDerivative.aux_index()] = cascade_state_3_lowest_log_derivative; } profiler!(stop "hash table"); @@ -553,7 +551,7 @@ impl TraceTable for HashTable { #[cfg(test)] pub(crate) mod tests { use air::table::TableId; - use air::table_column::HashBaseTableColumn; + use air::table_column::HashMainColumn; use air::AIR; use constraint_circuit::ConstraintCircuitBuilder; use std::collections::HashMap; @@ -600,15 +598,15 @@ pub(crate) mod tests { dbg!(aet.height_of_table(TableId::OpStack)); dbg!(aet.height_of_table(TableId::Cascade)); - let (_, _, master_base_table, master_ext_table, challenges) = + let (_, _, master_main_table, master_aux_table, challenges) = master_tables_for_low_security_level(ProgramAndInput::new(program)); let challenges = &challenges.challenges; - let master_base_trace_table = master_base_table.trace_table(); - let master_ext_trace_table = master_ext_table.trace_table(); + let master_main_trace_table = master_main_table.trace_table(); + let master_aux_trace_table = master_aux_table.trace_table(); - let last_row = master_base_trace_table.slice(s![-1.., ..]); - let last_opcode = last_row[[0, HashBaseTableColumn::CI.master_base_table_index()]]; + let last_row = master_main_trace_table.slice(s![-1.., ..]); + let last_opcode = last_row[[0, HashMainColumn::CI.master_main_index()]]; let last_instruction: Instruction = last_opcode.value().try_into().unwrap(); assert_eq!(Instruction::SpongeInit, last_instruction); @@ -619,8 +617,8 @@ pub(crate) mod tests { .enumerate() { let evaluated_constraint = constraint.evaluate( - master_base_trace_table.slice(s![-1.., ..]), - master_ext_trace_table.slice(s![-1.., ..]), + master_main_trace_table.slice(s![-1.., ..]), + master_aux_trace_table.slice(s![-1.., ..]), challenges, ); assert_eq!( diff --git a/triton-vm/src/table/jump_stack.rs b/triton-vm/src/table/jump_stack.rs index 1f2301d44..8fb59eb20 100644 --- a/triton-vm/src/table/jump_stack.rs +++ b/triton-vm/src/table/jump_stack.rs @@ -34,12 +34,11 @@ fn extension_column_running_product_permutation_argument( let mut running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(main_table.nrows()); for row in main_table.rows() { - let compressed_row = row[MainColumn::CLK.base_table_index()] - * challenges[JumpStackClkWeight] - + row[MainColumn::CI.base_table_index()] * challenges[JumpStackCiWeight] - + row[MainColumn::JSP.base_table_index()] * challenges[JumpStackJspWeight] - + row[MainColumn::JSO.base_table_index()] * challenges[JumpStackJsoWeight] - + row[MainColumn::JSD.base_table_index()] * challenges[JumpStackJsdWeight]; + let compressed_row = row[MainColumn::CLK.main_index()] * challenges[JumpStackClkWeight] + + row[MainColumn::CI.main_index()] * challenges[JumpStackCiWeight] + + row[MainColumn::JSP.main_index()] * challenges[JumpStackJspWeight] + + row[MainColumn::JSO.main_index()] * challenges[JumpStackJsoWeight] + + row[MainColumn::JSD.main_index()] * challenges[JumpStackJsdWeight]; running_product *= challenges[JumpStackIndeterminate] - compressed_row; extension_column.push(running_product); } @@ -67,11 +66,9 @@ fn extension_column_clock_jump_diff_lookup_log_derivative( let mut extension_column = Vec::with_capacity(main_table.nrows()); extension_column.push(cjd_lookup_log_derivative); for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() { - if previous_row[MainColumn::JSP.base_table_index()] - == current_row[MainColumn::JSP.base_table_index()] - { - let previous_clock = previous_row[MainColumn::CLK.base_table_index()]; - let current_clock = current_row[MainColumn::CLK.base_table_index()]; + if previous_row[MainColumn::JSP.main_index()] == current_row[MainColumn::JSP.main_index()] { + let previous_clock = previous_row[MainColumn::CLK.main_index()]; + let current_clock = current_row[MainColumn::CLK.main_index()]; let clock_jump_difference = current_clock - previous_clock; let &mut inverse = inverses_dictionary .entry(clock_jump_difference) @@ -97,11 +94,11 @@ impl TraceTable for JumpStackTable { // rows, which are sorted by CLK. let mut pre_processed_jump_stack_table: Vec> = vec![]; for processor_row in aet.processor_trace.rows() { - let clk = processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; - let ci = processor_row[ProcessorBaseTableColumn::CI.base_table_index()]; - let jsp = processor_row[ProcessorBaseTableColumn::JSP.base_table_index()]; - let jso = processor_row[ProcessorBaseTableColumn::JSO.base_table_index()]; - let jsd = processor_row[ProcessorBaseTableColumn::JSD.base_table_index()]; + let clk = processor_row[ProcessorMainColumn::CLK.main_index()]; + let ci = processor_row[ProcessorMainColumn::CI.main_index()]; + let jsp = processor_row[ProcessorMainColumn::JSP.main_index()]; + let jso = processor_row[ProcessorMainColumn::JSO.main_index()]; + let jsd = processor_row[ProcessorMainColumn::JSD.main_index()]; // The (honest) prover can only grow the Jump Stack's size by at most 1 per execution // step. Hence, the following (a) works, and (b) sorts. let jsp_val = jsp.value() as usize; @@ -119,11 +116,11 @@ impl TraceTable for JumpStackTable { { let jsp = bfe!(jsp_val as u64); for (clk, ci, jso, jsd) in rows_with_this_jsp { - jump_stack_table[(jump_stack_table_row, MainColumn::CLK.base_table_index())] = clk; - jump_stack_table[(jump_stack_table_row, MainColumn::CI.base_table_index())] = ci; - jump_stack_table[(jump_stack_table_row, MainColumn::JSP.base_table_index())] = jsp; - jump_stack_table[(jump_stack_table_row, MainColumn::JSO.base_table_index())] = jso; - jump_stack_table[(jump_stack_table_row, MainColumn::JSD.base_table_index())] = jsd; + jump_stack_table[(jump_stack_table_row, MainColumn::CLK.main_index())] = clk; + jump_stack_table[(jump_stack_table_row, MainColumn::CI.main_index())] = ci; + jump_stack_table[(jump_stack_table_row, MainColumn::JSP.main_index())] = jsp; + jump_stack_table[(jump_stack_table_row, MainColumn::JSO.main_index())] = jso; + jump_stack_table[(jump_stack_table_row, MainColumn::JSD.main_index())] = jsd; jump_stack_table_row += 1; } } @@ -135,11 +132,9 @@ impl TraceTable for JumpStackTable { for row_idx in 0..aet.processor_trace.nrows() - 1 { let curr_row = jump_stack_table.row(row_idx); let next_row = jump_stack_table.row(row_idx + 1); - let clk_diff = next_row[MainColumn::CLK.base_table_index()] - - curr_row[MainColumn::CLK.base_table_index()]; - if curr_row[MainColumn::JSP.base_table_index()] - == next_row[MainColumn::JSP.base_table_index()] - { + let clk_diff = + next_row[MainColumn::CLK.main_index()] - curr_row[MainColumn::CLK.main_index()]; + if curr_row[MainColumn::JSP.main_index()] == next_row[MainColumn::JSP.main_index()] { clock_jump_differences.push(clk_diff); } } @@ -158,7 +153,7 @@ impl TraceTable for JumpStackTable { .into_iter() .enumerate() .find(|(_, row)| { - row[MainColumn::CLK.base_table_index()].value() as usize == max_clk_before_padding + row[MainColumn::CLK.main_index()].value() as usize == max_clk_before_padding }) .map(|(idx, _)| idx) .expect("Jump Stack Table must contain row with clock cycle equal to max cycle."); @@ -199,8 +194,7 @@ impl TraceTable for JumpStackTable { // CLK keeps increasing by 1 also in the padding section. let new_clk_values = Array1::from_iter((table_len..padded_height).map(|clk| bfe!(clk as u64))); - new_clk_values - .move_into(padding_section.slice_mut(s![.., MainColumn::CLK.base_table_index()])); + new_clk_values.move_into(padding_section.slice_mut(s![.., MainColumn::CLK.main_index()])); } fn extend( @@ -214,8 +208,8 @@ impl TraceTable for JumpStackTable { assert_eq!(main_table.nrows(), aux_table.nrows()); // use strum::IntoEnumIterator; - let extension_column_indices = JumpStackExtTableColumn::iter() - .map(|column| column.ext_table_index()) + let extension_column_indices = JumpStackAuxColumn::iter() + .map(|column| column.aux_index()) .collect_vec(); let extension_column_slices = horizontal_multi_slice_mut( aux_table.view_mut(), diff --git a/triton-vm/src/table/lookup.rs b/triton-vm/src/table/lookup.rs index ed5295245..228b26636 100644 --- a/triton-vm/src/table/lookup.rs +++ b/triton-vm/src/table/lookup.rs @@ -3,8 +3,8 @@ use air::cross_table_argument::CrossTableArg; use air::cross_table_argument::EvalArg; use air::cross_table_argument::LookupArg; use air::table::lookup::LookupTable; -use air::table_column::MasterBaseTableColumn; -use air::table_column::MasterExtTableColumn; +use air::table_column::MasterAuxColumn; +use air::table_column::MasterMainColumn; use air::AIR; use itertools::Itertools; use ndarray::prelude::*; @@ -36,15 +36,15 @@ fn extension_column_cascade_running_sum_log_derivative( let mut cascade_table_running_sum_log_derivative = LookupArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[MainColumn::IsPadding.base_table_index()].is_one() { + if row[MainColumn::IsPadding.main_index()].is_one() { break; } - let lookup_input = row[MainColumn::LookIn.base_table_index()]; - let lookup_output = row[MainColumn::LookOut.base_table_index()]; + let lookup_input = row[MainColumn::LookIn.main_index()]; + let lookup_output = row[MainColumn::LookOut.main_index()]; let compressed_row = lookup_input * look_in_weight + lookup_output * look_out_weight; - let lookup_multiplicity = row[MainColumn::LookupMultiplicity.base_table_index()]; + let lookup_multiplicity = row[MainColumn::LookupMultiplicity.main_index()]; cascade_table_running_sum_log_derivative += (indeterminate - compressed_row).inverse() * lookup_multiplicity; @@ -63,13 +63,13 @@ fn extension_column_public_running_evaluation( let mut running_evaluation = EvalArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[MainColumn::IsPadding.base_table_index()].is_one() { + if row[MainColumn::IsPadding.main_index()].is_one() { break; } running_evaluation = running_evaluation * challenges[ChallengeId::LookupTablePublicIndeterminate] - + row[MainColumn::LookOut.base_table_index()]; + + row[MainColumn::LookOut.main_index()]; extension_column.push(running_evaluation); } @@ -88,18 +88,14 @@ impl TraceTable for LookupTable { // Lookup Table input let lookup_input = Array1::from_iter((0..LOOKUP_TABLE_LEN).map(|i| bfe!(i as u64))); - let lookup_input_column = main_table.slice_mut(s![ - ..LOOKUP_TABLE_LEN, - MainColumn::LookIn.base_table_index() - ]); + let lookup_input_column = + main_table.slice_mut(s![..LOOKUP_TABLE_LEN, MainColumn::LookIn.main_index()]); lookup_input.move_into(lookup_input_column); // Lookup Table output let lookup_output = Array1::from_iter(tip5::LOOKUP_TABLE.map(BFieldElement::from)); - let lookup_output_column = main_table.slice_mut(s![ - ..LOOKUP_TABLE_LEN, - MainColumn::LookOut.base_table_index() - ]); + let lookup_output_column = + main_table.slice_mut(s![..LOOKUP_TABLE_LEN, MainColumn::LookOut.main_index()]); lookup_output.move_into(lookup_output_column); // Lookup Table multiplicities @@ -109,14 +105,14 @@ impl TraceTable for LookupTable { ); let lookup_multiplicities_column = main_table.slice_mut(s![ ..LOOKUP_TABLE_LEN, - MainColumn::LookupMultiplicity.base_table_index() + MainColumn::LookupMultiplicity.main_index() ]); lookup_multiplicities.move_into(lookup_multiplicities_column); } fn pad(mut lookup_table: ArrayViewMut2, table_length: usize) { lookup_table - .slice_mut(s![table_length.., MainColumn::IsPadding.base_table_index()]) + .slice_mut(s![table_length.., MainColumn::IsPadding.main_index()]) .fill(BFieldElement::ONE); } @@ -131,7 +127,7 @@ impl TraceTable for LookupTable { assert_eq!(main_table.nrows(), aux_table.nrows()); let extension_column_indices = AuxColumn::iter() - .map(|column| column.ext_table_index()) + .map(|column| column.aux_index()) .collect_vec(); let extension_column_slices = horizontal_multi_slice_mut( aux_table.view_mut(), diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 820a7ecf6..8c5f8dd4e 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -11,26 +11,26 @@ use air::table::program::ProgramTable; use air::table::ram::RamTable; use air::table::u32::U32Table; use air::table::TableId; +use air::table::AUX_CASCADE_TABLE_END; +use air::table::AUX_CASCADE_TABLE_START; +use air::table::AUX_HASH_TABLE_END; +use air::table::AUX_HASH_TABLE_START; +use air::table::AUX_JUMP_STACK_TABLE_END; +use air::table::AUX_JUMP_STACK_TABLE_START; +use air::table::AUX_LOOKUP_TABLE_END; +use air::table::AUX_LOOKUP_TABLE_START; +use air::table::AUX_OP_STACK_TABLE_END; +use air::table::AUX_OP_STACK_TABLE_START; +use air::table::AUX_PROCESSOR_TABLE_END; +use air::table::AUX_PROCESSOR_TABLE_START; +use air::table::AUX_PROGRAM_TABLE_END; +use air::table::AUX_PROGRAM_TABLE_START; +use air::table::AUX_RAM_TABLE_END; +use air::table::AUX_RAM_TABLE_START; +use air::table::AUX_U32_TABLE_END; +use air::table::AUX_U32_TABLE_START; use air::table::CASCADE_TABLE_END; use air::table::CASCADE_TABLE_START; -use air::table::EXT_CASCADE_TABLE_END; -use air::table::EXT_CASCADE_TABLE_START; -use air::table::EXT_HASH_TABLE_END; -use air::table::EXT_HASH_TABLE_START; -use air::table::EXT_JUMP_STACK_TABLE_END; -use air::table::EXT_JUMP_STACK_TABLE_START; -use air::table::EXT_LOOKUP_TABLE_END; -use air::table::EXT_LOOKUP_TABLE_START; -use air::table::EXT_OP_STACK_TABLE_END; -use air::table::EXT_OP_STACK_TABLE_START; -use air::table::EXT_PROCESSOR_TABLE_END; -use air::table::EXT_PROCESSOR_TABLE_START; -use air::table::EXT_PROGRAM_TABLE_END; -use air::table::EXT_PROGRAM_TABLE_START; -use air::table::EXT_RAM_TABLE_END; -use air::table::EXT_RAM_TABLE_START; -use air::table::EXT_U32_TABLE_END; -use air::table::EXT_U32_TABLE_START; use air::table::HASH_TABLE_END; use air::table::HASH_TABLE_START; use air::table::JUMP_STACK_TABLE_END; @@ -49,7 +49,7 @@ use air::table::U32_TABLE_END; use air::table::U32_TABLE_START; use air::table_column::*; use itertools::Itertools; -use master_table::extension_table::Evaluable; +use master_table::auxiliary_table::Evaluable; use ndarray::parallel::prelude::*; use ndarray::prelude::*; use ndarray::s; @@ -57,6 +57,7 @@ use ndarray::Array2; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; use ndarray::Zip; +use num_traits::ConstZero; use num_traits::One; use num_traits::Zero; use rand::distributions::Standard; @@ -78,9 +79,9 @@ use crate::ndarray_helper::horizontal_multi_slice_mut; use crate::ndarray_helper::partial_sums; use crate::profiler::profiler; use crate::stark::NUM_RANDOMIZER_POLYNOMIALS; +use crate::table::auxiliary_table::all_degrees_with_origin; +use crate::table::auxiliary_table::DegreeWithOrigin; use crate::table::degree_lowering::DegreeLoweringTable; -use crate::table::extension_table::all_degrees_with_origin; -use crate::table::extension_table::DegreeWithOrigin; use crate::table::processor::ClkJumpDiffs; use crate::table::*; @@ -94,49 +95,50 @@ use crate::table::*; /// completely separate from each other. Only the [cross-table arguments][cross_arg] link all tables /// together. /// -/// Conceptually, there are two Master Tables: the [`MasterBaseTable`] ("main"), the Master +/// Conceptually, there are two Master Tables: the [`MasterMainTable`] ("main"), the Master /// Extension Table ("auxiliary"). The lifecycle of the Master Tables is /// as follows: -/// 1. The [`MasterBaseTable`] is instantiated and filled using the Algebraic Execution Trace. -/// 2. The [`MasterBaseTable`] is padded using logic from the individual tables. -/// 3. The still-empty entries in the [`MasterBaseTable`] are filled with random elements. This +/// 1. The [`MasterMainTable`] is instantiated and filled using the Algebraic Execution Trace. +/// 2. The [`MasterMainTable`] is padded using logic from the individual tables. +/// 3. The still-empty entries in the [`MasterMainTable`] are filled with random elements. This /// step is also known as “trace randomization.” -/// 4. If there is enough RAM, then each column of the [`MasterBaseTable`] is low-degree extended. -/// The results are stored on the [`MasterBaseTable`] for quick access later. +/// 4. If there is enough RAM, then each column of the [`MasterMainTable`] is low-degree extended. +/// The results are stored on the [`MasterMainTable`] for quick access later. /// If there is not enough RAM, then the low-degree extensions of the trace columns will be /// computed and sometimes recomputed just-in-time, and the memory freed afterward. /// The caching behavior [can be forced][overwrite_cache]. -/// 5. The [`MasterBaseTable`] is used to derive the [`MasterExtensionTable`][master_ext_table] +/// 5. The [`MasterMainTable`] is used to derive the [`MasterAuxiliaryTable`][master_aux_table] /// using logic from the individual tables. -/// 6. The [`MasterExtensionTable`][master_ext_table] is trace-randomized. -/// 7. Each column of the [`MasterExtensionTable`][master_ext_table] is [low-degree extended][lde]. -/// The effects are the same as for the [`MasterBaseTable`]. -/// 8. Using the [`MasterBaseTable`] and the [`MasterExtensionTable`][master_ext_table], the +/// 6. The [`MasterAuxiliaryTable`][master_aux_table] is trace-randomized. +/// 7. Each column of the [`MasterAuxiliaryTable`][master_aux_table] is [low-degree extended][lde]. +/// The effects are the same as for the [`MasterMainTable`]. +/// 8. Using the [`MasterMainTable`] and the [`MasterAuxiliaryTable`][master_aux_table], the /// [quotient codeword][master_quot_table] is derived using the AIR. Each individual table /// defines that part of the AIR that is relevant to it. /// /// The following points are of note: -/// - The [`MasterExtensionTable`][master_ext_table]'s rightmost columns are the randomizer +/// - The [`MasterMainColumns are the randomizer /// codewords. These are necessary for zero-knowledge. -/// - The cross-table argument has zero width for the [`MasterBaseTable`] and -/// [`MasterExtensionTable`][master_ext_table] but does induce a nonzero number of constraints +/// - The cross-table argument has zero width for the [`MasterMainTable`] and +/// [`MasterAuxiliaryTable`][master_aux_table] but does induce a nonzero number of constraints /// and thus terms in the [quotient combination][all_quotients_combined]. /// /// [cross_arg]: air::cross_table_argument::GrandCrossTableArg /// [overwrite_cache]: crate::config::overwrite_lde_trace_caching_to /// [lde]: Self::low_degree_extend_all_columns /// [quot_table]: Self::quotient_domain_table -/// [master_ext_table]: MasterExtTable +/// [master_aux_table]: MasterAuxTable /// [master_quot_table]: all_quotients_combined -pub trait MasterTable: Sync +pub trait MasterTable: Sync where - FF: FiniteField + Standard: Distribution, +{ + type Field: FiniteField + MulAssign + From + BFieldCodec - + Mul, - Standard: Distribution, -{ + + Mul; + const NUM_COLUMNS: usize; fn trace_domain(&self) -> ArithmeticDomain; @@ -161,15 +163,15 @@ where } /// Presents underlying trace data, excluding trace randomizers and randomizer polynomials. - fn trace_table(&self) -> ArrayView2; + fn trace_table(&self) -> ArrayView2; /// Mutably presents underlying trace data, excluding trace randomizers and randomizer /// polynomials. - fn trace_table_mut(&mut self) -> ArrayViewMut2; + fn trace_table_mut(&mut self) -> ArrayViewMut2; - fn randomized_trace_table(&self) -> ArrayView2; + fn randomized_trace_table(&self) -> ArrayView2; - fn randomized_trace_table_mut(&mut self) -> ArrayViewMut2; + fn randomized_trace_table_mut(&mut self) -> ArrayViewMut2; /// The quotient-domain view of the cached low-degree-extended table, if /// 1. the table has been [low-degree extended][lde], and @@ -181,7 +183,7 @@ where // pointer to an array that must live somewhere and cannot live on the stack. // From the trait implementation we cannot access the implementing object's // fields. - fn quotient_domain_table(&self) -> Option>; + fn quotient_domain_table(&self) -> Option>; /// Set all rows _not_ part of the actual (padded) trace to random values. fn randomize_trace(&mut self) { @@ -189,7 +191,7 @@ where (1..unit_distance).for_each(|offset| { self.randomized_trace_table_mut() .slice_mut(s![offset..; unit_distance, ..]) - .par_mapv_inplace(|_| random::()) + .par_mapv_inplace(|_| random::()) }); } @@ -219,7 +221,7 @@ where }); profiler!(stop "interpolation"); - let mut extended_trace = Vec::::with_capacity(0); + let mut extended_trace = Vec::::with_capacity(0); let num_elements = num_rows * Self::NUM_COLUMNS; let should_cache = match crate::config::cache_lde_trace() { Some(CacheDecision::NoCache) => false, @@ -236,7 +238,7 @@ where extended_trace .spare_capacity_mut() .par_iter_mut() - .for_each(|e| *e = MaybeUninit::new(FF::zero())); + .for_each(|e| *e = MaybeUninit::new(Self::Field::ZERO)); unsafe { // Speed up initialization through parallelization. @@ -269,29 +271,32 @@ where /// Not intended for direct use, but through [`Self::low_degree_extend_all_columns`]. #[doc(hidden)] - fn memoize_low_degree_extended_table(&mut self, low_degree_extended_columns: Array2); + fn memoize_low_degree_extended_table( + &mut self, + low_degree_extended_columns: Array2, + ); /// Return the cached low-degree-extended table, if any. - fn low_degree_extended_table(&self) -> Option>; + fn low_degree_extended_table(&self) -> Option>; /// Return the FRI domain view of the cached low-degree-extended table, if any. /// /// This method cannot be implemented generically on the trait because it returns a pointer to /// an array and that array has to live somewhere; it cannot live on stack and from the trait /// implementation we cannot access the implementing object's fields. - fn fri_domain_table(&self) -> Option>; + fn fri_domain_table(&self) -> Option>; /// Memoize the polynomials interpolating the columns. /// Not intended for direct use, but through [`Self::low_degree_extend_all_columns`]. #[doc(hidden)] fn memoize_interpolation_polynomials( &mut self, - interpolation_polynomials: Array1>, + interpolation_polynomials: Array1>, ); /// Requires having called /// [`low_degree_extend_all_columns`](Self::low_degree_extend_all_columns) first. - fn interpolation_polynomials(&self) -> ArrayView1>; + fn interpolation_polynomials(&self) -> ArrayView1>; /// Get one row of the table at an arbitrary index. Notably, the index does not have to be in /// any of the domains. In other words, can be used to compute out-of-domain rows. Requires @@ -322,7 +327,7 @@ where } } - fn hash_one_row(row: ArrayView1) -> Digest { + fn hash_one_row(row: ArrayView1) -> Digest { Tip5::hash_varlen(&row.iter().flat_map(|e| e.encode()).collect_vec()) } @@ -421,7 +426,7 @@ impl SpongeWithPendingAbsorb { /// See [`MasterTable`]. #[derive(Debug, Clone)] -pub struct MasterBaseTable { +pub struct MasterMainTable { pub num_trace_randomizers: usize, program_table_len: usize, @@ -444,7 +449,7 @@ pub struct MasterBaseTable { /// See [`MasterTable`]. #[derive(Debug, Clone)] -pub struct MasterExtTable { +pub struct MasterAuxTable { pub num_trace_randomizers: usize, trace_domain: ArithmeticDomain, @@ -457,8 +462,10 @@ pub struct MasterExtTable { interpolation_polynomials: Option>>, } -impl MasterTable for MasterBaseTable { - const NUM_COLUMNS: usize = NUM_MAIN_COLUMNS; +impl MasterTable for MasterMainTable { + type Field = BFieldElement; + const NUM_COLUMNS: usize = + air::table::NUM_MAIN_COLUMNS + degree_lowering::DegreeLoweringMainColumn::COUNT; fn trace_domain(&self) -> ArithmeticDomain { self.trace_domain @@ -563,8 +570,11 @@ impl MasterTable for MasterBaseTable { } } -impl MasterTable for MasterExtTable { - const NUM_COLUMNS: usize = NUM_AUX_COLUMNS + NUM_RANDOMIZER_POLYNOMIALS; +impl MasterTable for MasterAuxTable { + type Field = XFieldElement; + const NUM_COLUMNS: usize = air::table::NUM_AUX_COLUMNS + + degree_lowering::DegreeLoweringAuxColumn::COUNT + + NUM_RANDOMIZER_POLYNOMIALS; fn trace_domain(&self) -> ArithmeticDomain { self.trace_domain @@ -663,7 +673,7 @@ impl MasterTable for MasterExtTable { type PadFunction = fn(ArrayViewMut2, usize); type ExtendFunction = fn(ArrayView2, ArrayViewMut2, &Challenges); -impl MasterBaseTable { +impl MasterMainTable { pub fn new( aet: &AlgebraicExecutionTrace, num_trace_randomizers: usize, @@ -744,15 +754,15 @@ impl MasterBaseTable { let base_tables: [_; TableId::COUNT] = horizontal_multi_slice_mut( master_table_without_randomizers, &partial_sums(&[ - ProgramBaseTableColumn::COUNT, - ProcessorBaseTableColumn::COUNT, - OpStackBaseTableColumn::COUNT, - RamBaseTableColumn::COUNT, - JumpStackBaseTableColumn::COUNT, - HashBaseTableColumn::COUNT, - CascadeBaseTableColumn::COUNT, - LookupBaseTableColumn::COUNT, - U32BaseTableColumn::COUNT, + ProgramMainColumn::COUNT, + ProcessorMainColumn::COUNT, + OpStackMainColumn::COUNT, + RamMainColumn::COUNT, + JumpStackMainColumn::COUNT, + HashMainColumn::COUNT, + CascadeMainColumn::COUNT, + LookupMainColumn::COUNT, + U32MainColumn::COUNT, ]), ) .try_into() @@ -780,7 +790,7 @@ impl MasterBaseTable { profiler!(stop "pad original tables"); profiler!(start "fill degree-lowering table"); - DegreeLoweringTable::fill_derived_base_columns(self.trace_table_mut()); + DegreeLoweringTable::fill_derived_main_columns(self.trace_table_mut()); profiler!(stop "fill degree-lowering table"); } @@ -801,24 +811,24 @@ impl MasterBaseTable { ] } - /// Create a `MasterExtTable` from a `MasterBaseTable` by `.extend()`ing each individual base + /// Create a `MasterAuxTable` from a `MasterMainTable` by `.extend()`ing each individual base /// table. The `.extend()` for each table is specific to that table, but always involves /// adding some number of columns. - pub fn extend(&self, challenges: &Challenges) -> MasterExtTable { + pub fn extend(&self, challenges: &Challenges) -> MasterAuxTable { // randomizer polynomials let num_rows = self.randomized_trace_table().nrows(); profiler!(start "initialize master table"); - let num_aux_columns = MasterExtTable::NUM_COLUMNS; + let num_aux_columns = MasterAuxTable::NUM_COLUMNS; let mut randomized_trace_extension_table = fast_zeros_column_major(num_rows, num_aux_columns); - let randomizers_start = MasterExtTable::NUM_COLUMNS - NUM_RANDOMIZER_POLYNOMIALS; + let randomizers_start = MasterAuxTable::NUM_COLUMNS - NUM_RANDOMIZER_POLYNOMIALS; randomized_trace_extension_table .slice_mut(s![.., randomizers_start..]) .par_mapv_inplace(|_| random::()); profiler!(stop "initialize master table"); - let mut master_ext_table = MasterExtTable { + let mut master_ext_table = MasterAuxTable { num_trace_randomizers: self.num_trace_randomizers, trace_domain: self.trace_domain(), randomized_trace_domain: self.randomized_trace_domain(), @@ -837,15 +847,15 @@ impl MasterBaseTable { let extension_tables: [_; TableId::COUNT] = horizontal_multi_slice_mut( master_ext_table_without_randomizers, &partial_sums(&[ - ProgramExtTableColumn::COUNT, - ProcessorExtTableColumn::COUNT, - OpStackExtTableColumn::COUNT, - RamExtTableColumn::COUNT, - JumpStackExtTableColumn::COUNT, - HashExtTableColumn::COUNT, - CascadeExtTableColumn::COUNT, - LookupExtTableColumn::COUNT, - U32ExtTableColumn::COUNT, + ProgramAuxColumn::COUNT, + ProcessorAuxColumn::COUNT, + OpStackAuxColumn::COUNT, + RamAuxColumn::COUNT, + JumpStackAuxColumn::COUNT, + HashAuxColumn::COUNT, + CascadeAuxColumn::COUNT, + LookupAuxColumn::COUNT, + U32AuxColumn::COUNT, ]), ) .try_into() @@ -863,7 +873,7 @@ impl MasterBaseTable { profiler!(stop "all tables"); profiler!(start "fill degree lowering table"); - DegreeLoweringTable::fill_derived_ext_columns( + DegreeLoweringTable::fill_derived_aux_columns( self.trace_table(), master_ext_table.trace_table_mut(), challenges, @@ -932,9 +942,9 @@ impl MasterBaseTable { .slice_mut(s![..; unit_distance, column_indices]) } - pub(crate) fn try_to_base_row( + pub(crate) fn try_to_main_row( row: Array1, - ) -> Result, ProvingError> { + ) -> Result, ProvingError> { let err = || ProvingError::TableRowConversionError { expected_len: Self::NUM_COLUMNS, actual_len: row.len(), @@ -943,19 +953,19 @@ impl MasterBaseTable { } } -impl MasterExtTable { +impl MasterAuxTable { fn column_indices_for_table(id: TableId) -> Range { use TableId::*; match id { - Program => EXT_PROGRAM_TABLE_START..EXT_PROGRAM_TABLE_END, - Processor => EXT_PROCESSOR_TABLE_START..EXT_PROCESSOR_TABLE_END, - OpStack => EXT_OP_STACK_TABLE_START..EXT_OP_STACK_TABLE_END, - Ram => EXT_RAM_TABLE_START..EXT_RAM_TABLE_END, - JumpStack => EXT_JUMP_STACK_TABLE_START..EXT_JUMP_STACK_TABLE_END, - Hash => EXT_HASH_TABLE_START..EXT_HASH_TABLE_END, - Cascade => EXT_CASCADE_TABLE_START..EXT_CASCADE_TABLE_END, - Lookup => EXT_LOOKUP_TABLE_START..EXT_LOOKUP_TABLE_END, - U32 => EXT_U32_TABLE_START..EXT_U32_TABLE_END, + Program => AUX_PROGRAM_TABLE_START..AUX_PROGRAM_TABLE_END, + Processor => AUX_PROCESSOR_TABLE_START..AUX_PROCESSOR_TABLE_END, + OpStack => AUX_OP_STACK_TABLE_START..AUX_OP_STACK_TABLE_END, + Ram => AUX_RAM_TABLE_START..AUX_RAM_TABLE_END, + JumpStack => AUX_JUMP_STACK_TABLE_START..AUX_JUMP_STACK_TABLE_END, + Hash => AUX_HASH_TABLE_START..AUX_HASH_TABLE_END, + Cascade => AUX_CASCADE_TABLE_START..AUX_CASCADE_TABLE_END, + Lookup => AUX_LOOKUP_TABLE_START..AUX_LOOKUP_TABLE_END, + U32 => AUX_U32_TABLE_START..AUX_U32_TABLE_END, } } @@ -975,7 +985,7 @@ impl MasterExtTable { .slice_mut(s![..; unit_distance, column_indices]) } - pub(crate) fn try_to_ext_row(row: Array1) -> Result { + pub(crate) fn try_to_aux_row(row: Array1) -> Result { let err = || ProvingError::TableRowConversionError { expected_len: Self::NUM_COLUMNS, actual_len: row.len(), @@ -1082,11 +1092,11 @@ pub fn all_quotients_combined( quotient_domain.length, quotient_domain_master_ext_table.nrows() ); - assert_eq!(MasterExtTable::NUM_CONSTRAINTS, quotient_weights.len()); + assert_eq!(MasterAuxTable::NUM_CONSTRAINTS, quotient_weights.len()); - let init_section_end = MasterExtTable::NUM_INITIAL_CONSTRAINTS; - let cons_section_end = init_section_end + MasterExtTable::NUM_CONSISTENCY_CONSTRAINTS; - let tran_section_end = cons_section_end + MasterExtTable::NUM_TRANSITION_CONSTRAINTS; + let init_section_end = MasterAuxTable::NUM_INITIAL_CONSTRAINTS; + let cons_section_end = init_section_end + MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS; + let tran_section_end = cons_section_end + MasterAuxTable::NUM_TRANSITION_CONSTRAINTS; profiler!(start "zerofier inverse"); let initial_zerofier_inverse = initial_quotient_zerofier_inverse(quotient_domain); @@ -1114,7 +1124,7 @@ pub fn all_quotients_combined( let next_row_main = quotient_domain_master_base_table.row(next_row_index); let next_row_aux = quotient_domain_master_ext_table.row(next_row_index); - let initial_constraint_values = MasterExtTable::evaluate_initial_constraints( + let initial_constraint_values = MasterAuxTable::evaluate_initial_constraints( current_row_main, current_row_aux, challenges, @@ -1125,7 +1135,7 @@ pub fn all_quotients_combined( ); let mut quotient_value = initial_inner_product * initial_zerofier_inverse[row_index]; - let consistency_constraint_values = MasterExtTable::evaluate_consistency_constraints( + let consistency_constraint_values = MasterAuxTable::evaluate_consistency_constraints( current_row_main, current_row_aux, challenges, @@ -1136,7 +1146,7 @@ pub fn all_quotients_combined( ); quotient_value += consistency_inner_product * consistency_zerofier_inverse[row_index]; - let transition_constraint_values = MasterExtTable::evaluate_transition_constraints( + let transition_constraint_values = MasterAuxTable::evaluate_transition_constraints( current_row_main, current_row_aux, next_row_main, @@ -1149,7 +1159,7 @@ pub fn all_quotients_combined( ); quotient_value += transition_inner_product * transition_zerofier_inverse[row_index]; - let terminal_constraint_values = MasterExtTable::evaluate_terminal_constraints( + let terminal_constraint_values = MasterAuxTable::evaluate_terminal_constraints( current_row_main, current_row_aux, challenges, @@ -1213,8 +1223,8 @@ mod tests { use crate::memory_layout::StaticTasmConstraintEvaluationMemoryLayout; use crate::shared_tests::ProgramAndInput; use crate::stark::tests::*; - use crate::table::degree_lowering::DegreeLoweringBaseTableColumn; - use crate::table::degree_lowering::DegreeLoweringExtTableColumn; + use crate::table::degree_lowering::DegreeLoweringAuxColumn; + use crate::table::degree_lowering::DegreeLoweringMainColumn; use crate::table::*; use crate::triton_program; @@ -1461,8 +1471,8 @@ mod tests { let degree_lowering_info = DegreeLoweringInfo { target_degree, - num_base_cols: 0, - num_ext_cols: 0, + num_main_cols: 0, + num_aux_cols: 0, }; // generic closures are not possible; define two variants :( @@ -1568,12 +1578,12 @@ mod tests { pub consistency_constraints: Vec>, pub transition_constraints: Vec>, pub terminal_constraints: Vec>, - pub last_base_column_index: usize, - pub last_ext_column_index: usize, + pub last_main_column_index: usize, + pub last_aux_column_index: usize, } macro_rules! constraint_overview_rows { - ($($table:ident ends at $base_end:ident and $ext_end: ident. + ($($table:ident ends at $main_end:ident and $aux_end: ident. Spec: [$spec_name:literal]($spec_file:literal)),* $(,)?) => {{ let single_row_builder = || ConstraintCircuitBuilder::new(); let dual_row_builder = || ConstraintCircuitBuilder::new(); @@ -1586,8 +1596,8 @@ mod tests { consistency_constraints: $table::consistency_constraints(&single_row_builder()), transition_constraints: $table::transition_constraints(&dual_row_builder()), terminal_constraints: $table::terminal_constraints(&single_row_builder()), - last_base_column_index: $base_end, - last_ext_column_index: $ext_end, + last_main_column_index: $main_end, + last_aux_column_index: $aux_end, }; rows.push(row); )* @@ -1635,23 +1645,23 @@ mod tests { let mut total_max_degree = 0; let mut tables = constraint_overview_rows!( - ProgramTable ends at PROGRAM_TABLE_END and EXT_PROGRAM_TABLE_END. + ProgramTable ends at PROGRAM_TABLE_END and AUX_PROGRAM_TABLE_END. Spec: ["ProgramTable"]("program-table.md"), - ProcessorTable ends at PROCESSOR_TABLE_END and EXT_PROCESSOR_TABLE_END. + ProcessorTable ends at PROCESSOR_TABLE_END and AUX_PROCESSOR_TABLE_END. Spec: ["ProcessorTable"]("processor-table.md"), - OpStackTable ends at OP_STACK_TABLE_END and EXT_OP_STACK_TABLE_END. + OpStackTable ends at OP_STACK_TABLE_END and AUX_OP_STACK_TABLE_END. Spec: ["OpStackTable"]("operational-stack-table.md"), - RamTable ends at RAM_TABLE_END and EXT_RAM_TABLE_END. + RamTable ends at RAM_TABLE_END and AUX_RAM_TABLE_END. Spec: ["RamTable"]("random-access-memory-table.md"), - JumpStackTable ends at JUMP_STACK_TABLE_END and EXT_JUMP_STACK_TABLE_END. + JumpStackTable ends at JUMP_STACK_TABLE_END and AUX_JUMP_STACK_TABLE_END. Spec: ["JumpStackTable"]("jump-stack-table.md"), - HashTable ends at HASH_TABLE_END and EXT_HASH_TABLE_END. + HashTable ends at HASH_TABLE_END and AUX_HASH_TABLE_END. Spec: ["HashTable"]("hash-table.md"), - CascadeTable ends at CASCADE_TABLE_END and EXT_CASCADE_TABLE_END. + CascadeTable ends at CASCADE_TABLE_END and AUX_CASCADE_TABLE_END. Spec: ["CascadeTable"]("cascade-table.md"), - LookupTable ends at LOOKUP_TABLE_END and EXT_LOOKUP_TABLE_END. + LookupTable ends at LOOKUP_TABLE_END and AUX_LOOKUP_TABLE_END. Spec: ["LookupTable"]("lookup-table.md"), - U32Table ends at U32_TABLE_END and EXT_U32_TABLE_END. + U32Table ends at U32_TABLE_END and AUX_U32_TABLE_END. Spec: ["U32Table"]("u32-table.md"), GrandCrossTableArg ends at ZERO and ZERO. Spec: ["Grand Cross-Table Argument"]("table-linking.md"), @@ -1669,8 +1679,8 @@ mod tests { if let Some(target_degree) = target_degree { let info = DegreeLoweringInfo { target_degree, - num_base_cols: table.last_base_column_index, - num_ext_cols: table.last_ext_column_index, + num_main_cols: table.last_main_column_index, + num_aux_cols: table.last_aux_column_index, }; let (new_base_init, new_ext_init) = ConstraintCircuitMonad::lower_to_degree( &mut table.initial_constraints, @@ -1780,10 +1790,10 @@ mod tests { fn generate_tasm_air_evaluation_cost_overview() -> SpecSnippet { let dummy_static_layout = StaticTasmConstraintEvaluationMemoryLayout { free_mem_page_ptr: BFieldElement::ZERO, - curr_base_row_ptr: BFieldElement::ZERO, - curr_ext_row_ptr: BFieldElement::ZERO, - next_base_row_ptr: BFieldElement::ZERO, - next_ext_row_ptr: BFieldElement::ZERO, + curr_main_row_ptr: BFieldElement::ZERO, + curr_aux_row_ptr: BFieldElement::ZERO, + next_main_row_ptr: BFieldElement::ZERO, + next_aux_row_ptr: BFieldElement::ZERO, challenges_ptr: BFieldElement::ZERO, }; let dummy_dynamic_layout = DynamicTasmConstraintEvaluationMemoryLayout { @@ -1929,14 +1939,14 @@ mod tests { macro_rules! print_columns { (base $table:ident for $name:literal) => {{ for column in $table::iter() { - let idx = column.master_base_table_index(); + let idx = column.master_main_index(); let name = $name; println!("| {idx:>3} | {name:<11} | {column}"); } }}; (ext $table:ident for $name:literal) => {{ for column in $table::iter() { - let idx = column.master_ext_table_index(); + let idx = column.master_aux_index(); let name = $name; println!("| {idx:>3} | {name:<11} | {column}"); } @@ -1944,32 +1954,32 @@ mod tests { } println!(); - println!("| idx | table | base column"); + println!("| idx | table | main column"); println!("|----:|:------------|:-----------"); - print_columns!(base ProgramBaseTableColumn for "program"); - print_columns!(base ProcessorBaseTableColumn for "processor"); - print_columns!(base OpStackBaseTableColumn for "op stack"); - print_columns!(base RamBaseTableColumn for "ram"); - print_columns!(base JumpStackBaseTableColumn for "jump stack"); - print_columns!(base HashBaseTableColumn for "hash"); - print_columns!(base CascadeBaseTableColumn for "cascade"); - print_columns!(base LookupBaseTableColumn for "lookup"); - print_columns!(base U32BaseTableColumn for "u32"); - print_columns!(base DegreeLoweringBaseTableColumn for "degree low."); + print_columns!(base ProgramMainColumn for "program"); + print_columns!(base ProcessorMainColumn for "processor"); + print_columns!(base OpStackMainColumn for "op stack"); + print_columns!(base RamMainColumn for "ram"); + print_columns!(base JumpStackMainColumn for "jump stack"); + print_columns!(base HashMainColumn for "hash"); + print_columns!(base CascadeMainColumn for "cascade"); + print_columns!(base LookupMainColumn for "lookup"); + print_columns!(base U32MainColumn for "u32"); + print_columns!(base DegreeLoweringMainColumn for "degree low."); println!(); - println!("| idx | table | extension column"); + println!("| idx | table | auxiliary column"); println!("|----:|:------------|:----------------"); - print_columns!(ext ProgramExtTableColumn for "program"); - print_columns!(ext ProcessorExtTableColumn for "processor"); - print_columns!(ext OpStackExtTableColumn for "op stack"); - print_columns!(ext RamExtTableColumn for "ram"); - print_columns!(ext JumpStackExtTableColumn for "jump stack"); - print_columns!(ext HashExtTableColumn for "hash"); - print_columns!(ext CascadeExtTableColumn for "cascade"); - print_columns!(ext LookupExtTableColumn for "lookup"); - print_columns!(ext U32ExtTableColumn for "u32"); - print_columns!(ext DegreeLoweringExtTableColumn for "degree low."); + print_columns!(ext ProgramAuxColumn for "program"); + print_columns!(ext ProcessorAuxColumn for "processor"); + print_columns!(ext OpStackAuxColumn for "op stack"); + print_columns!(ext RamAuxColumn for "ram"); + print_columns!(ext JumpStackAuxColumn for "jump stack"); + print_columns!(ext HashAuxColumn for "hash"); + print_columns!(ext CascadeAuxColumn for "cascade"); + print_columns!(ext LookupAuxColumn for "lookup"); + print_columns!(ext U32AuxColumn for "u32"); + print_columns!(ext DegreeLoweringAuxColumn for "degree low."); } #[test] @@ -1980,9 +1990,9 @@ mod tests { let fri_domain = ArithmeticDomain::of_length(1 << 11).unwrap(); let randomized_trace_table = - Array2::zeros((randomized_trace_domain.length, NUM_AUX_COLUMNS)); + Array2::zeros((randomized_trace_domain.length, MasterAuxTable::NUM_COLUMNS)); - let mut master_table = MasterExtTable { + let mut master_table = MasterAuxTable { num_trace_randomizers: 16, trace_domain, randomized_trace_domain, @@ -1994,23 +2004,23 @@ mod tests { }; let num_rows = trace_domain.length; - Array2::from_elem((num_rows, ProgramExtTableColumn::COUNT), 1.into()) + Array2::from_elem((num_rows, ProgramAuxColumn::COUNT), 1.into()) .move_into(&mut master_table.table_mut(TableId::Program)); - Array2::from_elem((num_rows, ProcessorExtTableColumn::COUNT), 2.into()) + Array2::from_elem((num_rows, ProcessorAuxColumn::COUNT), 2.into()) .move_into(&mut master_table.table_mut(TableId::Processor)); - Array2::from_elem((num_rows, OpStackExtTableColumn::COUNT), 3.into()) + Array2::from_elem((num_rows, OpStackAuxColumn::COUNT), 3.into()) .move_into(&mut master_table.table_mut(TableId::OpStack)); - Array2::from_elem((num_rows, RamExtTableColumn::COUNT), 4.into()) + Array2::from_elem((num_rows, RamAuxColumn::COUNT), 4.into()) .move_into(&mut master_table.table_mut(TableId::Ram)); - Array2::from_elem((num_rows, JumpStackExtTableColumn::COUNT), 5.into()) + Array2::from_elem((num_rows, JumpStackAuxColumn::COUNT), 5.into()) .move_into(&mut master_table.table_mut(TableId::JumpStack)); - Array2::from_elem((num_rows, HashExtTableColumn::COUNT), 6.into()) + Array2::from_elem((num_rows, HashAuxColumn::COUNT), 6.into()) .move_into(&mut master_table.table_mut(TableId::Hash)); - Array2::from_elem((num_rows, CascadeExtTableColumn::COUNT), 7.into()) + Array2::from_elem((num_rows, CascadeAuxColumn::COUNT), 7.into()) .move_into(&mut master_table.table_mut(TableId::Cascade)); - Array2::from_elem((num_rows, LookupExtTableColumn::COUNT), 8.into()) + Array2::from_elem((num_rows, LookupAuxColumn::COUNT), 8.into()) .move_into(&mut master_table.table_mut(TableId::Lookup)); - Array2::from_elem((num_rows, U32ExtTableColumn::COUNT), 9.into()) + Array2::from_elem((num_rows, U32AuxColumn::COUNT), 9.into()) .move_into(&mut master_table.table_mut(TableId::U32)); let trace_domain_element = |column| { @@ -2024,29 +2034,29 @@ mod tests { xfe.unlift().unwrap().value() }; - assert_eq!(1, trace_domain_element(EXT_PROGRAM_TABLE_START)); - assert_eq!(2, trace_domain_element(EXT_PROCESSOR_TABLE_START)); - assert_eq!(3, trace_domain_element(EXT_OP_STACK_TABLE_START)); - assert_eq!(4, trace_domain_element(EXT_RAM_TABLE_START)); - assert_eq!(5, trace_domain_element(EXT_JUMP_STACK_TABLE_START)); - assert_eq!(6, trace_domain_element(EXT_HASH_TABLE_START)); - assert_eq!(7, trace_domain_element(EXT_CASCADE_TABLE_START)); - assert_eq!(8, trace_domain_element(EXT_LOOKUP_TABLE_START)); - assert_eq!(9, trace_domain_element(EXT_U32_TABLE_START)); - - assert_eq!(0, not_trace_domain_element(EXT_PROGRAM_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_PROCESSOR_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_OP_STACK_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_RAM_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_JUMP_STACK_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_HASH_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_CASCADE_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_LOOKUP_TABLE_START)); - assert_eq!(0, not_trace_domain_element(EXT_U32_TABLE_START)); + assert_eq!(1, trace_domain_element(AUX_PROGRAM_TABLE_START)); + assert_eq!(2, trace_domain_element(AUX_PROCESSOR_TABLE_START)); + assert_eq!(3, trace_domain_element(AUX_OP_STACK_TABLE_START)); + assert_eq!(4, trace_domain_element(AUX_RAM_TABLE_START)); + assert_eq!(5, trace_domain_element(AUX_JUMP_STACK_TABLE_START)); + assert_eq!(6, trace_domain_element(AUX_HASH_TABLE_START)); + assert_eq!(7, trace_domain_element(AUX_CASCADE_TABLE_START)); + assert_eq!(8, trace_domain_element(AUX_LOOKUP_TABLE_START)); + assert_eq!(9, trace_domain_element(AUX_U32_TABLE_START)); + + assert_eq!(0, not_trace_domain_element(AUX_PROGRAM_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_PROCESSOR_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_OP_STACK_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_RAM_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_JUMP_STACK_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_HASH_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_CASCADE_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_LOOKUP_TABLE_START)); + assert_eq!(0, not_trace_domain_element(AUX_U32_TABLE_START)); } #[proptest] - fn test_sponge_with_pending_absorb( + fn sponge_with_pending_absorb_is_equivalent_to_usual_sponge( #[strategy(arb())] elements: Vec, #[strategy(0_usize..=#elements.len())] substring_index: usize, ) { diff --git a/triton-vm/src/table/op_stack.rs b/triton-vm/src/table/op_stack.rs index f6341002f..21df5efb1 100644 --- a/triton-vm/src/table/op_stack.rs +++ b/triton-vm/src/table/op_stack.rs @@ -82,7 +82,7 @@ impl OpStackTableEntry { op_stack_table_entries } - pub fn to_base_table_row(self) -> Array1 { + pub fn to_main_table_row(self) -> Array1 { let shrink_stack_indicator = if self.shrinks_stack() { bfe!(1) } else { @@ -90,10 +90,10 @@ impl OpStackTableEntry { }; let mut row = Array1::zeros(::MainColumn::COUNT); - row[MainColumn::CLK.base_table_index()] = self.clk.into(); - row[MainColumn::IB1ShrinkStack.base_table_index()] = shrink_stack_indicator; - row[MainColumn::StackPointer.base_table_index()] = self.op_stack_pointer; - row[MainColumn::FirstUnderflowElement.base_table_index()] = self.underflow_io.payload(); + row[MainColumn::CLK.main_index()] = self.clk.into(); + row[MainColumn::IB1ShrinkStack.main_index()] = shrink_stack_indicator; + row[MainColumn::StackPointer.main_index()] = self.op_stack_pointer; + row[MainColumn::FirstUnderflowElement.main_index()] = self.underflow_io.payload(); row } } @@ -107,13 +107,11 @@ fn extension_column_running_product_permutation_argument( let mut running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - if row[MainColumn::IB1ShrinkStack.base_table_index()] != PADDING_VALUE { - let compressed_row = row[MainColumn::CLK.base_table_index()] - * challenges[OpStackClkWeight] - + row[MainColumn::IB1ShrinkStack.base_table_index()] * challenges[OpStackIb1Weight] - + row[MainColumn::StackPointer.base_table_index()] - * challenges[OpStackPointerWeight] - + row[MainColumn::FirstUnderflowElement.base_table_index()] + if row[MainColumn::IB1ShrinkStack.main_index()] != PADDING_VALUE { + let compressed_row = row[MainColumn::CLK.main_index()] * challenges[OpStackClkWeight] + + row[MainColumn::IB1ShrinkStack.main_index()] * challenges[OpStackIb1Weight] + + row[MainColumn::StackPointer.main_index()] * challenges[OpStackPointerWeight] + + row[MainColumn::FirstUnderflowElement.main_index()] * challenges[OpStackFirstUnderflowElementWeight]; running_product *= perm_arg_indeterminate - compressed_row; } @@ -144,15 +142,15 @@ fn extension_column_clock_jump_diff_lookup_log_derivative( let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(cjd_lookup_log_derivative); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[MainColumn::IB1ShrinkStack.base_table_index()] == PADDING_VALUE { + if current_row[MainColumn::IB1ShrinkStack.main_index()] == PADDING_VALUE { break; }; - let previous_stack_pointer = previous_row[MainColumn::StackPointer.base_table_index()]; - let current_stack_pointer = current_row[MainColumn::StackPointer.base_table_index()]; + let previous_stack_pointer = previous_row[MainColumn::StackPointer.main_index()]; + let current_stack_pointer = current_row[MainColumn::StackPointer.main_index()]; if previous_stack_pointer == current_stack_pointer { - let previous_clock = previous_row[MainColumn::CLK.base_table_index()]; - let current_clock = current_row[MainColumn::CLK.base_table_index()]; + let previous_clock = previous_row[MainColumn::CLK.main_index()]; + let current_clock = current_row[MainColumn::CLK.main_index()]; let clock_jump_difference = current_clock - previous_clock; let &mut inverse = inverses_dictionary .entry(clock_jump_difference) @@ -192,10 +190,10 @@ impl TraceTable for OpStackTable { fn pad(mut op_stack_table: ArrayViewMut2, op_stack_table_len: usize) { let last_row_index = op_stack_table_len.saturating_sub(1); let mut padding_row = op_stack_table.row(last_row_index).to_owned(); - padding_row[MainColumn::IB1ShrinkStack.base_table_index()] = PADDING_VALUE; + padding_row[MainColumn::IB1ShrinkStack.main_index()] = PADDING_VALUE; if op_stack_table_len == 0 { let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into(); - padding_row[MainColumn::StackPointer.base_table_index()] = first_stack_pointer; + padding_row[MainColumn::StackPointer.main_index()] = first_stack_pointer; } let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]); @@ -215,8 +213,8 @@ impl TraceTable for OpStackTable { assert_eq!(AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); - let extension_column_indices = OpStackExtTableColumn::iter() - .map(|column| column.ext_table_index()) + let extension_column_indices = OpStackAuxColumn::iter() + .map(|column| column.aux_index()) .collect_vec(); let extension_column_slices = horizontal_multi_slice_mut( aux_table.view_mut(), @@ -239,12 +237,12 @@ impl TraceTable for OpStackTable { } fn compare_rows(row_0: ArrayView1, row_1: ArrayView1) -> Ordering { - let stack_pointer_0 = row_0[MainColumn::StackPointer.base_table_index()].value(); - let stack_pointer_1 = row_1[MainColumn::StackPointer.base_table_index()].value(); + let stack_pointer_0 = row_0[MainColumn::StackPointer.main_index()].value(); + let stack_pointer_1 = row_1[MainColumn::StackPointer.main_index()].value(); let compare_stack_pointers = stack_pointer_0.cmp(&stack_pointer_1); - let clk_0 = row_0[MainColumn::CLK.base_table_index()].value(); - let clk_1 = row_1[MainColumn::CLK.base_table_index()].value(); + let clk_0 = row_0[MainColumn::CLK.main_index()].value(); + let clk_1 = row_1[MainColumn::CLK.main_index()].value(); let compare_clocks = clk_0.cmp(&clk_1); compare_stack_pointers.then(compare_clocks) @@ -255,11 +253,11 @@ fn clock_jump_differences(op_stack_table: ArrayView2) -> Vec::MainColumn::COUNT; let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[MainColumn::StackPointer.base_table_index()] = stack_pointer_0.into(); - row_0[MainColumn::CLK.base_table_index()] = clk.into(); + row_0[MainColumn::StackPointer.main_index()] = stack_pointer_0.into(); + row_0[MainColumn::CLK.main_index()] = clk.into(); let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[MainColumn::StackPointer.base_table_index()] = stack_pointer_1.into(); - row_1[MainColumn::CLK.base_table_index()] = clk.into(); + row_1[MainColumn::StackPointer.main_index()] = stack_pointer_1.into(); + row_1[MainColumn::CLK.main_index()] = clk.into(); let stack_pointer_comparison = stack_pointer_0.cmp(&stack_pointer_1); let row_comparison = compare_rows(row_0.view(), row_1.view()); @@ -379,12 +377,12 @@ pub(crate) mod tests { const BASE_WIDTH: usize = ::MainColumn::COUNT; let mut row_0 = Array1::zeros(BASE_WIDTH); - row_0[MainColumn::StackPointer.base_table_index()] = stack_pointer.into(); - row_0[MainColumn::CLK.base_table_index()] = clk_0.into(); + row_0[MainColumn::StackPointer.main_index()] = stack_pointer.into(); + row_0[MainColumn::CLK.main_index()] = clk_0.into(); let mut row_1 = Array1::zeros(BASE_WIDTH); - row_1[MainColumn::StackPointer.base_table_index()] = stack_pointer.into(); - row_1[MainColumn::CLK.base_table_index()] = clk_1.into(); + row_1[MainColumn::StackPointer.main_index()] = stack_pointer.into(); + row_1[MainColumn::CLK.main_index()] = clk_1.into(); let clk_comparison = clk_0.cmp(&clk_1); let row_comparison = compare_rows(row_0.view(), row_1.view()); diff --git a/triton-vm/src/table/processor.rs b/triton-vm/src/table/processor.rs index eb699b05f..5736642b1 100644 --- a/triton-vm/src/table/processor.rs +++ b/triton-vm/src/table/processor.rs @@ -2,7 +2,7 @@ use air::challenge_id::ChallengeId; use air::cross_table_argument::*; use air::table::processor::ProcessorTable; use air::table::ram; -use air::table_column::ProcessorBaseTableColumn::*; +use air::table_column::ProcessorMainColumn::*; use air::table_column::*; use isa::instruction::AnInstruction::*; use isa::instruction::Instruction; @@ -57,15 +57,15 @@ impl TraceTable for ProcessorTable { let mut processor_table = main_table.slice_mut(s![0..num_rows, ..]); processor_table.assign(&aet.processor_trace); processor_table - .column_mut(ClockJumpDifferenceLookupMultiplicity.base_table_index()) + .column_mut(ClockJumpDifferenceLookupMultiplicity.main_index()) .assign(&clk_jump_diff_multiplicities); } fn pad(mut main_table: ArrayViewMut2, table_len: usize) { assert!(table_len > 0, "Processor Table must have at least one row."); let mut padding_template = main_table.row(table_len - 1).to_owned(); - padding_template[IsPadding.base_table_index()] = bfe!(1); - padding_template[ClockJumpDifferenceLookupMultiplicity.base_table_index()] = bfe!(0); + padding_template[IsPadding.main_index()] = bfe!(1); + padding_template[ClockJumpDifferenceLookupMultiplicity.main_index()] = bfe!(0); main_table .slice_mut(s![table_len.., ..]) .axis_iter_mut(Axis(0)) @@ -74,7 +74,7 @@ impl TraceTable for ProcessorTable { let clk_range = table_len..main_table.nrows(); let clk_col = Array1::from_iter(clk_range.map(|a| bfe!(a as u64))); - clk_col.move_into(main_table.slice_mut(s![table_len.., CLK.base_table_index()])); + clk_col.move_into(main_table.slice_mut(s![table_len.., CLK.main_index()])); // The Jump Stack Table does not have a padding indicator. Hence, clock jump differences are // being looked up in its padding sections. The clock jump differences in that section are @@ -84,7 +84,7 @@ impl TraceTable for ProcessorTable { let num_padding_rows = bfe!(num_padding_rows as u64); let mut row_1 = main_table.row_mut(1); - row_1[ClockJumpDifferenceLookupMultiplicity.base_table_index()] += num_padding_rows; + row_1[ClockJumpDifferenceLookupMultiplicity.main_index()] += num_padding_rows; } fn extend( @@ -97,8 +97,8 @@ impl TraceTable for ProcessorTable { assert_eq!(Self::AuxColumn::COUNT, aux_table.ncols()); assert_eq!(main_table.nrows(), aux_table.nrows()); - let all_column_indices = ProcessorExtTableColumn::iter() - .map(|column| column.ext_table_index()) + let all_column_indices = ProcessorAuxColumn::iter() + .map(|column| column.aux_index()) .collect_vec(); let all_column_slices = horizontal_multi_slice_mut( aux_table.view_mut(), @@ -140,7 +140,7 @@ fn extension_column_input_table_eval_argument( if let Some(Instruction::ReadIo(st)) = instruction_from_row(previous_row) { for i in (0..st.num_words()).rev() { let input_symbol_column = ProcessorTable::op_stack_column_by_index(i); - let input_symbol = current_row[input_symbol_column.base_table_index()]; + let input_symbol = current_row[input_symbol_column.main_index()]; input_table_running_evaluation = input_table_running_evaluation * challenges[ChallengeId::StandardInputIndeterminate] + input_symbol; @@ -162,7 +162,7 @@ fn extension_column_output_table_eval_argument( if let Some(Instruction::WriteIo(st)) = instruction_from_row(previous_row) { for i in 0..st.num_words() { let output_symbol_column = ProcessorTable::op_stack_column_by_index(i); - let output_symbol = previous_row[output_symbol_column.base_table_index()]; + let output_symbol = previous_row[output_symbol_column.main_index()]; output_table_running_evaluation = output_table_running_evaluation * challenges[ChallengeId::StandardOutputIndeterminate] + output_symbol; @@ -180,14 +180,13 @@ fn extension_column_instruction_lookup_argument( // collect all to-be-inverted elements for batch inversion let mut to_invert = vec![]; for row in base_table.rows() { - if row[IsPadding.base_table_index()].is_one() { + if row[IsPadding.main_index()].is_one() { break; // padding marks the end of the trace } - let compressed_row = row[IP.base_table_index()] - * challenges[ChallengeId::ProgramAddressWeight] - + row[CI.base_table_index()] * challenges[ChallengeId::ProgramInstructionWeight] - + row[NIA.base_table_index()] * challenges[ChallengeId::ProgramNextInstructionWeight]; + let compressed_row = row[IP.main_index()] * challenges[ChallengeId::ProgramAddressWeight] + + row[CI.main_index()] * challenges[ChallengeId::ProgramInstructionWeight] + + row[NIA.main_index()] * challenges[ChallengeId::ProgramNextInstructionWeight]; to_invert.push(challenges[ChallengeId::InstructionLookupIndeterminate] - compressed_row); } @@ -242,12 +241,11 @@ fn extension_column_jump_stack_table_perm_argument( let mut jump_stack_running_product = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - let compressed_row = row[CLK.base_table_index()] - * challenges[ChallengeId::JumpStackClkWeight] - + row[CI.base_table_index()] * challenges[ChallengeId::JumpStackCiWeight] - + row[JSP.base_table_index()] * challenges[ChallengeId::JumpStackJspWeight] - + row[JSO.base_table_index()] * challenges[ChallengeId::JumpStackJsoWeight] - + row[JSD.base_table_index()] * challenges[ChallengeId::JumpStackJsdWeight]; + let compressed_row = row[CLK.main_index()] * challenges[ChallengeId::JumpStackClkWeight] + + row[CI.main_index()] * challenges[ChallengeId::JumpStackCiWeight] + + row[JSP.main_index()] * challenges[ChallengeId::JumpStackJspWeight] + + row[JSO.main_index()] * challenges[ChallengeId::JumpStackJsoWeight] + + row[JSD.main_index()] * challenges[ChallengeId::JumpStackJsdWeight]; jump_stack_running_product *= challenges[ChallengeId::JumpStackIndeterminate] - compressed_row; extension_column.push(jump_stack_running_product); @@ -269,12 +267,12 @@ fn extension_column_hash_input_eval_argument( let mut hash_input_running_evaluation = EvalArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - let current_instruction = row[CI.base_table_index()]; + let current_instruction = row[CI.main_index()]; if current_instruction == Instruction::Hash.opcode_b() || current_instruction == Instruction::MerkleStep.opcode_b() || current_instruction == Instruction::MerkleStepMem.opcode_b() { - let is_left_sibling = row[ST5.base_table_index()].value() % 2 == 0; + let is_left_sibling = row[ST5.main_index()].value() % 2 == 0; let hash_input = match instruction_from_row(row) { Some(MerkleStep | MerkleStepMem) if is_left_sibling => merkle_step_left_sibling, Some(MerkleStep | MerkleStepMem) => merkle_step_right_sibling, @@ -282,7 +280,7 @@ fn extension_column_hash_input_eval_argument( _ => unreachable!(), }; let compressed_row = hash_input - .map(|st| row[st.base_table_index()]) + .map(|st| row[st.main_index()]) .into_iter() .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) @@ -305,13 +303,13 @@ fn extension_column_hash_digest_eval_argument( let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(hash_digest_running_evaluation); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; + let previous_ci = previous_row[CI.main_index()]; if previous_ci == Instruction::Hash.opcode_b() || previous_ci == Instruction::MerkleStep.opcode_b() || previous_ci == Instruction::MerkleStepMem.opcode_b() { let compressed_row = [ST0, ST1, ST2, ST3, ST4] - .map(|st| current_row[st.base_table_index()]) + .map(|st| current_row[st.main_index()]) .into_iter() .zip_eq(&challenges[ChallengeId::StackWeight0..=ChallengeId::StackWeight4]) .map(|(st, &weight)| weight * st) @@ -337,14 +335,14 @@ fn extension_column_sponge_eval_argument( let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(sponge_running_evaluation); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; + let previous_ci = previous_row[CI.main_index()]; if previous_ci == Instruction::SpongeInit.opcode_b() { sponge_running_evaluation = sponge_running_evaluation * challenges[ChallengeId::SpongeIndeterminate] + challenges[ChallengeId::HashCIWeight] * Instruction::SpongeInit.opcode_b(); } else if previous_ci == Instruction::SpongeAbsorb.opcode_b() { let compressed_row = st0_through_st9 - .map(|st| previous_row[st.base_table_index()]) + .map(|st| previous_row[st.main_index()]) .into_iter() .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) @@ -357,9 +355,9 @@ fn extension_column_sponge_eval_argument( let stack_elements = [ST1, ST2, ST3, ST4]; let helper_variables = [HV0, HV1, HV2, HV3, HV4, HV5]; let compressed_row = stack_elements - .map(|st| current_row[st.base_table_index()]) + .map(|st| current_row[st.main_index()]) .into_iter() - .chain(helper_variables.map(|hv| previous_row[hv.base_table_index()])) + .chain(helper_variables.map(|hv| previous_row[hv.main_index()])) .zip_eq(hash_state_weights.iter()) .map(|(element, &weight)| weight * element) .sum::(); @@ -369,7 +367,7 @@ fn extension_column_sponge_eval_argument( + compressed_row; } else if previous_ci == Instruction::SpongeSqueeze.opcode_b() { let compressed_row = st0_through_st9 - .map(|st| current_row[st.base_table_index()]) + .map(|st| current_row[st.main_index()]) .into_iter() .zip_eq(hash_state_weights.iter()) .map(|(st, &weight)| weight * st) @@ -391,31 +389,31 @@ fn extension_column_for_u32_lookup_argument( // collect elements to be inverted for more performant batch inversion let mut to_invert = vec![]; for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; + let previous_ci = previous_row[CI.main_index()]; if previous_ci == Instruction::Split.opcode_b() { - let compressed_row = current_row[ST0.base_table_index()] + let compressed_row = current_row[ST0.main_index()] * challenges[ChallengeId::U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight]; + + current_row[ST1.main_index()] * challenges[ChallengeId::U32RhsWeight] + + previous_row[CI.main_index()] * challenges[ChallengeId::U32CiWeight]; to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::Lt.opcode_b() || previous_ci == Instruction::And.opcode_b() || previous_ci == Instruction::Pow.opcode_b() { - let compressed_row = previous_row[ST0.base_table_index()] + let compressed_row = previous_row[ST0.main_index()] * challenges[ChallengeId::U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] - + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[ChallengeId::U32ResultWeight]; + + previous_row[ST1.main_index()] * challenges[ChallengeId::U32RhsWeight] + + previous_row[CI.main_index()] * challenges[ChallengeId::U32CiWeight] + + current_row[ST0.main_index()] * challenges[ChallengeId::U32ResultWeight]; to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::Xor.opcode_b() { // Triton VM uses the following equality to compute the results of both the // `and` and `xor` instruction using the u32 coprocessor's `and` capability: // a ^ b = a + b - 2 · (a & b) // <=> a & b = (a + b - a ^ b) / 2 - let st0_prev = previous_row[ST0.base_table_index()]; - let st1_prev = previous_row[ST1.base_table_index()]; - let st0 = current_row[ST0.base_table_index()]; + let st0_prev = previous_row[ST0.main_index()]; + let st1_prev = previous_row[ST1.main_index()]; + let st0 = current_row[ST0.main_index()]; let from_xor_in_processor_to_and_in_u32_coprocessor = (st0_prev + st1_prev - st0) / bfe!(2); let compressed_row = st0_prev * challenges[ChallengeId::U32LhsWeight] @@ -427,20 +425,20 @@ fn extension_column_for_u32_lookup_argument( } else if previous_ci == Instruction::Log2Floor.opcode_b() || previous_ci == Instruction::PopCount.opcode_b() { - let compressed_row = previous_row[ST0.base_table_index()] + let compressed_row = previous_row[ST0.main_index()] * challenges[ChallengeId::U32LhsWeight] - + previous_row[CI.base_table_index()] * challenges[ChallengeId::U32CiWeight] - + current_row[ST0.base_table_index()] * challenges[ChallengeId::U32ResultWeight]; + + previous_row[CI.main_index()] * challenges[ChallengeId::U32CiWeight] + + current_row[ST0.main_index()] * challenges[ChallengeId::U32ResultWeight]; to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } else if previous_ci == Instruction::DivMod.opcode_b() { - let compressed_row_for_lt_check = current_row[ST0.base_table_index()] + let compressed_row_for_lt_check = current_row[ST0.main_index()] * challenges[ChallengeId::U32LhsWeight] - + previous_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + previous_row[ST1.main_index()] * challenges[ChallengeId::U32RhsWeight] + Instruction::Lt.opcode_b() * challenges[ChallengeId::U32CiWeight] + bfe!(1) * challenges[ChallengeId::U32ResultWeight]; - let compressed_row_for_range_check = previous_row[ST0.base_table_index()] + let compressed_row_for_range_check = previous_row[ST0.main_index()] * challenges[ChallengeId::U32LhsWeight] - + current_row[ST1.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + current_row[ST1.main_index()] * challenges[ChallengeId::U32RhsWeight] + Instruction::Split.opcode_b() * challenges[ChallengeId::U32CiWeight]; to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row_for_lt_check); to_invert @@ -448,9 +446,9 @@ fn extension_column_for_u32_lookup_argument( } else if previous_ci == Instruction::MerkleStep.opcode_b() || previous_ci == Instruction::MerkleStepMem.opcode_b() { - let compressed_row = previous_row[ST5.base_table_index()] + let compressed_row = previous_row[ST5.main_index()] * challenges[ChallengeId::U32LhsWeight] - + current_row[ST5.base_table_index()] * challenges[ChallengeId::U32RhsWeight] + + current_row[ST5.main_index()] * challenges[ChallengeId::U32RhsWeight] + Instruction::Split.opcode_b() * challenges[ChallengeId::U32CiWeight]; to_invert.push(challenges[ChallengeId::U32Indeterminate] - compressed_row); } @@ -462,7 +460,7 @@ fn extension_column_for_u32_lookup_argument( let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(u32_table_running_sum_log_derivative); for (previous_row, _) in base_table.rows().into_iter().tuple_windows() { - let previous_ci = previous_row[CI.base_table_index()]; + let previous_ci = previous_row[CI.main_index()]; if Instruction::try_from(previous_ci) .unwrap() .is_u32_instruction() @@ -488,9 +486,9 @@ fn extension_column_for_clock_jump_difference_lookup_argument( // collect inverses to batch invert let mut to_invert = vec![]; for row in base_table.rows() { - let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; + let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.main_index()]; if !lookup_multiplicity.is_zero() { - let clk = row[CLK.base_table_index()]; + let clk = row[CLK.main_index()]; to_invert.push(challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate] - clk); } } @@ -500,7 +498,7 @@ fn extension_column_for_clock_jump_difference_lookup_argument( let mut cjd_lookup_log_derivative = LookupArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.base_table_index()]; + let lookup_multiplicity = row[ClockJumpDifferenceLookupMultiplicity.main_index()]; if !lookup_multiplicity.is_zero() { cjd_lookup_log_derivative += inverses.next().unwrap() * lookup_multiplicity; } @@ -517,7 +515,7 @@ fn factor_for_op_stack_table_running_product( ) -> XFieldElement { let default_factor = xfe!(1); - let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); + let is_padding_row = current_row[IsPadding.main_index()].is_one(); if is_padding_row { return default_factor; } @@ -541,14 +539,14 @@ fn factor_for_op_stack_table_running_product( let max_stack_element_index = OpStackElement::COUNT - 1; let stack_element_index = max_stack_element_index - op_stack_pointer_offset; let stack_element_column = ProcessorTable::op_stack_column_by_index(stack_element_index); - let underflow_element = row_with_shorter_stack[stack_element_column.base_table_index()]; + let underflow_element = row_with_shorter_stack[stack_element_column.main_index()]; - let op_stack_pointer = row_with_shorter_stack[OpStackPointer.base_table_index()]; + let op_stack_pointer = row_with_shorter_stack[OpStackPointer.main_index()]; let offset = bfe!(op_stack_pointer_offset as u64); let offset_op_stack_pointer = op_stack_pointer + offset; - let clk = previous_row[CLK.base_table_index()]; - let ib1_shrink_stack = previous_row[IB1.base_table_index()]; + let clk = previous_row[CLK.main_index()]; + let ib1_shrink_stack = previous_row[IB1.main_index()]; let compressed_row = clk * challenges[ChallengeId::OpStackClkWeight] + ib1_shrink_stack * challenges[ChallengeId::OpStackIb1Weight] + offset_op_stack_pointer * challenges[ChallengeId::OpStackPointerWeight] @@ -563,14 +561,14 @@ fn factor_for_ram_table_running_product( current_row: ArrayView1, challenges: &Challenges, ) -> Option { - let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); + let is_padding_row = current_row[IsPadding.main_index()].is_one(); if is_padding_row { return None; } let instruction = instruction_from_row(previous_row)?; - let clk = previous_row[CLK.base_table_index()]; + let clk = previous_row[CLK.main_index()]; let instruction_type = match instruction { ReadMem(_) => ram::INSTRUCTION_TYPE_READ, WriteMem(_) => ram::INSTRUCTION_TYPE_WRITE, @@ -597,50 +595,50 @@ fn factor_for_ram_table_running_product( for ram_pointer_offset in 0..op_stack_delta { let ram_value_index = ram_pointer_offset + num_ram_pointers; let ram_value_column = ProcessorTable::op_stack_column_by_index(ram_value_index); - let ram_value = row_with_longer_stack[ram_value_column.base_table_index()]; + let ram_value = row_with_longer_stack[ram_value_column.main_index()]; let offset_ram_pointer = offset_ram_pointer(instruction, row_with_longer_stack, ram_pointer_offset); accesses.push((offset_ram_pointer, ram_value)); } } SpongeAbsorbMem => { - let mem_pointer = previous_row[ST0.base_table_index()]; - accesses.push((mem_pointer + bfe!(0), current_row[ST1.base_table_index()])); - accesses.push((mem_pointer + bfe!(1), current_row[ST2.base_table_index()])); - accesses.push((mem_pointer + bfe!(2), current_row[ST3.base_table_index()])); - accesses.push((mem_pointer + bfe!(3), current_row[ST4.base_table_index()])); - accesses.push((mem_pointer + bfe!(4), previous_row[HV0.base_table_index()])); - accesses.push((mem_pointer + bfe!(5), previous_row[HV1.base_table_index()])); - accesses.push((mem_pointer + bfe!(6), previous_row[HV2.base_table_index()])); - accesses.push((mem_pointer + bfe!(7), previous_row[HV3.base_table_index()])); - accesses.push((mem_pointer + bfe!(8), previous_row[HV4.base_table_index()])); - accesses.push((mem_pointer + bfe!(9), previous_row[HV5.base_table_index()])); + let mem_pointer = previous_row[ST0.main_index()]; + accesses.push((mem_pointer + bfe!(0), current_row[ST1.main_index()])); + accesses.push((mem_pointer + bfe!(1), current_row[ST2.main_index()])); + accesses.push((mem_pointer + bfe!(2), current_row[ST3.main_index()])); + accesses.push((mem_pointer + bfe!(3), current_row[ST4.main_index()])); + accesses.push((mem_pointer + bfe!(4), previous_row[HV0.main_index()])); + accesses.push((mem_pointer + bfe!(5), previous_row[HV1.main_index()])); + accesses.push((mem_pointer + bfe!(6), previous_row[HV2.main_index()])); + accesses.push((mem_pointer + bfe!(7), previous_row[HV3.main_index()])); + accesses.push((mem_pointer + bfe!(8), previous_row[HV4.main_index()])); + accesses.push((mem_pointer + bfe!(9), previous_row[HV5.main_index()])); } MerkleStepMem => { - let mem_pointer = previous_row[ST7.base_table_index()]; - accesses.push((mem_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((mem_pointer + bfe!(1), previous_row[HV1.base_table_index()])); - accesses.push((mem_pointer + bfe!(2), previous_row[HV2.base_table_index()])); - accesses.push((mem_pointer + bfe!(3), previous_row[HV3.base_table_index()])); - accesses.push((mem_pointer + bfe!(4), previous_row[HV4.base_table_index()])); + let mem_pointer = previous_row[ST7.main_index()]; + accesses.push((mem_pointer + bfe!(0), previous_row[HV0.main_index()])); + accesses.push((mem_pointer + bfe!(1), previous_row[HV1.main_index()])); + accesses.push((mem_pointer + bfe!(2), previous_row[HV2.main_index()])); + accesses.push((mem_pointer + bfe!(3), previous_row[HV3.main_index()])); + accesses.push((mem_pointer + bfe!(4), previous_row[HV4.main_index()])); } XxDotStep => { - let rhs_pointer = previous_row[ST0.base_table_index()]; - let lhs_pointer = previous_row[ST1.base_table_index()]; - accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((rhs_pointer + bfe!(1), previous_row[HV1.base_table_index()])); - accesses.push((rhs_pointer + bfe!(2), previous_row[HV2.base_table_index()])); - accesses.push((lhs_pointer + bfe!(0), previous_row[HV3.base_table_index()])); - accesses.push((lhs_pointer + bfe!(1), previous_row[HV4.base_table_index()])); - accesses.push((lhs_pointer + bfe!(2), previous_row[HV5.base_table_index()])); + let rhs_pointer = previous_row[ST0.main_index()]; + let lhs_pointer = previous_row[ST1.main_index()]; + accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.main_index()])); + accesses.push((rhs_pointer + bfe!(1), previous_row[HV1.main_index()])); + accesses.push((rhs_pointer + bfe!(2), previous_row[HV2.main_index()])); + accesses.push((lhs_pointer + bfe!(0), previous_row[HV3.main_index()])); + accesses.push((lhs_pointer + bfe!(1), previous_row[HV4.main_index()])); + accesses.push((lhs_pointer + bfe!(2), previous_row[HV5.main_index()])); } XbDotStep => { - let rhs_pointer = previous_row[ST0.base_table_index()]; - let lhs_pointer = previous_row[ST1.base_table_index()]; - accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.base_table_index()])); - accesses.push((lhs_pointer + bfe!(0), previous_row[HV1.base_table_index()])); - accesses.push((lhs_pointer + bfe!(1), previous_row[HV2.base_table_index()])); - accesses.push((lhs_pointer + bfe!(2), previous_row[HV3.base_table_index()])); + let rhs_pointer = previous_row[ST0.main_index()]; + let lhs_pointer = previous_row[ST1.main_index()]; + accesses.push((rhs_pointer + bfe!(0), previous_row[HV0.main_index()])); + accesses.push((lhs_pointer + bfe!(0), previous_row[HV1.main_index()])); + accesses.push((lhs_pointer + bfe!(1), previous_row[HV2.main_index()])); + accesses.push((lhs_pointer + bfe!(2), previous_row[HV3.main_index()])); } _ => unreachable!(), }; @@ -662,7 +660,7 @@ fn offset_ram_pointer( row_with_longer_stack: ArrayView1, ram_pointer_offset: usize, ) -> BFieldElement { - let ram_pointer = row_with_longer_stack[ST0.base_table_index()]; + let ram_pointer = row_with_longer_stack[ST0.main_index()]; let offset = bfe!(ram_pointer_offset as u64); match instruction { @@ -675,11 +673,11 @@ fn offset_ram_pointer( } fn instruction_from_row(row: ArrayView1) -> Option { - let opcode = row[CI.base_table_index()]; + let opcode = row[CI.main_index()]; let instruction = Instruction::try_from(opcode).ok()?; if instruction.arg().is_some() { - let arg = row[NIA.base_table_index()]; + let arg = row[NIA.main_index()]; return instruction.change_arg(arg).ok(); } @@ -733,8 +731,8 @@ pub(crate) mod tests { #[derive(Debug, Clone)] struct TestRowsDebugInfo { pub instruction: Instruction, - pub debug_cols_curr_row: Vec, - pub debug_cols_next_row: Vec, + pub debug_cols_curr_row: Vec, + pub debug_cols_next_row: Vec, } fn test_row_from_program(program: Program, row_num: usize) -> TestRows { @@ -777,16 +775,16 @@ pub(crate) mod tests { println!("Testing all constraints of {instruction} for test case {case_idx}…"); for &c in &debug_info.debug_cols_curr_row { - print!("{c} = {}, ", curr_row[c.master_base_table_index()]); + print!("{c} = {}, ", curr_row[c.master_main_index()]); } println!(); for &c in &debug_info.debug_cols_next_row { - print!("{c}' = {}, ", next_row[c.master_base_table_index()]); + print!("{c}' = {}, ", next_row[c.master_main_index()]); } println!(); assert!( - instruction.opcode_b() == curr_row[CI.master_base_table_index()], + instruction.opcode_b() == curr_row[CI.master_main_index()], "The test is trying to check the wrong transition constraint polynomials." ); diff --git a/triton-vm/src/table/program.rs b/triton-vm/src/table/program.rs index 78d334bf8..46b2d399d 100644 --- a/triton-vm/src/table/program.rs +++ b/triton-vm/src/table/program.rs @@ -6,8 +6,8 @@ use air::cross_table_argument::EvalArg; use air::cross_table_argument::LookupArg; use air::table::program::ProgramTable; use air::table::TableId; -use air::table_column::ProgramBaseTableColumn::*; -use air::table_column::ProgramExtTableColumn::*; +use air::table_column::ProgramAuxColumn::*; +use air::table_column::ProgramMainColumn::*; use air::table_column::*; use ndarray::s; use ndarray::Array1; @@ -61,12 +61,12 @@ impl TraceTable for ProgramTable { }; let mut current_row = program_table.row_mut(row_idx); - current_row[Address.base_table_index()] = address; - current_row[Instruction.base_table_index()] = instruction; - current_row[LookupMultiplicity.base_table_index()] = lookup_multiplicity; - current_row[IndexInChunk.base_table_index()] = index_in_chunk; - current_row[MaxMinusIndexInChunkInv.base_table_index()] = max_minus_index_in_chunk_inv; - current_row[IsHashInputPadding.base_table_index()] = is_hash_input_padding; + current_row[Address.main_index()] = address; + current_row[Instruction.main_index()] = instruction; + current_row[LookupMultiplicity.main_index()] = lookup_multiplicity; + current_row[IndexInChunk.main_index()] = index_in_chunk; + current_row[MaxMinusIndexInChunkInv.main_index()] = max_minus_index_in_chunk_inv; + current_row[IsHashInputPadding.main_index()] = is_hash_input_padding; } } @@ -74,7 +74,7 @@ impl TraceTable for ProgramTable { let addresses = (program_len..program_table.nrows()).map(|a| bfe!(u64::try_from(a).unwrap())); let addresses = Array1::from_iter(addresses); - let address_column = program_table.slice_mut(s![program_len.., Address.base_table_index()]); + let address_column = program_table.slice_mut(s![program_len.., Address.main_index()]); addresses.move_into(address_column); let indices_in_chunks = (program_len..program_table.nrows()) @@ -82,7 +82,7 @@ impl TraceTable for ProgramTable { .map(|ac| bfe!(u64::try_from(ac).unwrap())); let indices_in_chunks = Array1::from_iter(indices_in_chunks); let index_in_chunk_column = - program_table.slice_mut(s![program_len.., IndexInChunk.base_table_index()]); + program_table.slice_mut(s![program_len.., IndexInChunk.main_index()]); indices_in_chunks.move_into(index_in_chunk_column); let max_minus_indices_in_chunks_inverses = (program_len..program_table.nrows()) @@ -91,17 +91,15 @@ impl TraceTable for ProgramTable { .map(|bfe| bfe.inverse_or_zero()); let max_minus_indices_in_chunks_inverses = Array1::from_iter(max_minus_indices_in_chunks_inverses); - let max_minus_index_in_chunk_inv_column = program_table.slice_mut(s![ - program_len.., - MaxMinusIndexInChunkInv.base_table_index() - ]); + let max_minus_index_in_chunk_inv_column = + program_table.slice_mut(s![program_len.., MaxMinusIndexInChunkInv.main_index()]); max_minus_indices_in_chunks_inverses.move_into(max_minus_index_in_chunk_inv_column); program_table - .slice_mut(s![program_len.., IsHashInputPadding.base_table_index()]) + .slice_mut(s![program_len.., IsHashInputPadding.main_index()]) .fill(BFieldElement::one()); program_table - .slice_mut(s![program_len.., IsTablePadding.base_table_index()]) + .slice_mut(s![program_len.., IsTablePadding.main_index()]) .fill(BFieldElement::one()); } @@ -139,7 +137,7 @@ impl TraceTable for ProgramTable { // The logarithmic derivative's final value, allowing for a meaningful cross-table // argument, is recorded in the first padding row. This row is guaranteed to exist // due to the hash-input padding mechanics. - extension_row[InstructionLookupServerLogDerivative.ext_table_index()] = + extension_row[InstructionLookupServerLogDerivative.aux_index()] = instruction_lookup_log_derivative; instruction_lookup_log_derivative = update_instruction_lookup_log_derivative( @@ -160,35 +158,33 @@ impl TraceTable for ProgramTable { prepare_chunk_running_evaluation, ); - extension_row[PrepareChunkRunningEvaluation.ext_table_index()] = + extension_row[PrepareChunkRunningEvaluation.aux_index()] = prepare_chunk_running_evaluation; - extension_row[SendChunkRunningEvaluation.ext_table_index()] = - send_chunk_running_evaluation; + extension_row[SendChunkRunningEvaluation.aux_index()] = send_chunk_running_evaluation; } // special treatment for the last row - let base_rows_iter = main_table.rows().into_iter(); - let ext_rows_iter = aux_table.rows_mut().into_iter(); - let last_base_row = base_rows_iter.last().unwrap(); - let mut last_ext_row = ext_rows_iter.last().unwrap(); + let main_rows_iter = main_table.rows().into_iter(); + let aux_rows_iter = aux_table.rows_mut().into_iter(); + let last_main_row = main_rows_iter.last().unwrap(); + let mut last_aux_row = aux_rows_iter.last().unwrap(); prepare_chunk_running_evaluation = update_prepare_chunk_running_evaluation( - last_base_row, + last_main_row, challenges, prepare_chunk_running_evaluation, ); send_chunk_running_evaluation = update_send_chunk_running_evaluation( - last_base_row, + last_main_row, challenges, send_chunk_running_evaluation, prepare_chunk_running_evaluation, ); - last_ext_row[InstructionLookupServerLogDerivative.ext_table_index()] = + last_aux_row[InstructionLookupServerLogDerivative.aux_index()] = instruction_lookup_log_derivative; - last_ext_row[PrepareChunkRunningEvaluation.ext_table_index()] = - prepare_chunk_running_evaluation; - last_ext_row[SendChunkRunningEvaluation.ext_table_index()] = send_chunk_running_evaluation; + last_aux_row[PrepareChunkRunningEvaluation.aux_index()] = prepare_chunk_running_evaluation; + last_aux_row[SendChunkRunningEvaluation.aux_index()] = send_chunk_running_evaluation; profiler!(stop "program table"); } @@ -200,7 +196,7 @@ fn update_instruction_lookup_log_derivative( next_row: ArrayView1, instruction_lookup_log_derivative: XFieldElement, ) -> XFieldElement { - if row[IsHashInputPadding.base_table_index()].is_one() { + if row[IsHashInputPadding.main_index()].is_one() { return instruction_lookup_log_derivative; } instruction_lookup_log_derivative @@ -212,11 +208,11 @@ fn instruction_lookup_log_derivative_summand( next_row: ArrayView1, challenges: &Challenges, ) -> XFieldElement { - let compressed_row = row[Address.base_table_index()] * challenges[ProgramAddressWeight] - + row[Instruction.base_table_index()] * challenges[ProgramInstructionWeight] - + next_row[Instruction.base_table_index()] * challenges[ProgramNextInstructionWeight]; + let compressed_row = row[Address.main_index()] * challenges[ProgramAddressWeight] + + row[Instruction.main_index()] * challenges[ProgramInstructionWeight] + + next_row[Instruction.main_index()] * challenges[ProgramNextInstructionWeight]; (challenges[InstructionLookupIndeterminate] - compressed_row).inverse() - * row[LookupMultiplicity.base_table_index()] + * row[LookupMultiplicity.main_index()] } fn update_prepare_chunk_running_evaluation( @@ -224,7 +220,7 @@ fn update_prepare_chunk_running_evaluation( challenges: &Challenges, prepare_chunk_running_evaluation: XFieldElement, ) -> XFieldElement { - let running_evaluation_resets = row[IndexInChunk.base_table_index()].is_zero(); + let running_evaluation_resets = row[IndexInChunk.main_index()].is_zero(); let prepare_chunk_running_evaluation = if running_evaluation_resets { EvalArg::default_initial() } else { @@ -232,7 +228,7 @@ fn update_prepare_chunk_running_evaluation( }; prepare_chunk_running_evaluation * challenges[ProgramAttestationPrepareChunkIndeterminate] - + row[Instruction.base_table_index()] + + row[Instruction.main_index()] } fn update_send_chunk_running_evaluation( @@ -241,8 +237,8 @@ fn update_send_chunk_running_evaluation( send_chunk_running_evaluation: XFieldElement, prepare_chunk_running_evaluation: XFieldElement, ) -> XFieldElement { - let index_in_chunk = row[IndexInChunk.base_table_index()]; - let is_table_padding_row = row[IsTablePadding.base_table_index()].is_one(); + let index_in_chunk = row[IndexInChunk.main_index()]; + let is_table_padding_row = row[IsTablePadding.main_index()].is_one(); let max_index_in_chunk = Tip5::RATE as u64 - 1; let running_evaluation_needs_update = !is_table_padding_row && index_in_chunk.value() == max_index_in_chunk; diff --git a/triton-vm/src/table/ram.rs b/triton-vm/src/table/ram.rs index 567d20a51..d5807a24e 100644 --- a/triton-vm/src/table/ram.rs +++ b/triton-vm/src/table/ram.rs @@ -5,7 +5,7 @@ use air::cross_table_argument::*; use air::table::ram::RamTable; use air::table::ram::PADDING_INDICATOR; use air::table::TableId; -use air::table_column::RamBaseTableColumn::*; +use air::table_column::RamMainColumn::*; use air::table_column::*; use air::AIR; use arbitrary::Arbitrary; @@ -46,10 +46,10 @@ impl RamTableCall { }; let mut row = Array1::zeros(::MainColumn::COUNT); - row[CLK.base_table_index()] = self.clk.into(); - row[InstructionType.base_table_index()] = instruction_type; - row[RamPointer.base_table_index()] = self.ram_pointer; - row[RamValue.base_table_index()] = self.ram_value; + row[CLK.main_index()] = self.clk.into(); + row[InstructionType.main_index()] = instruction_type; + row[RamPointer.main_index()] = self.ram_pointer; + row[RamValue.main_index()] = self.ram_value; row } } @@ -72,7 +72,7 @@ impl TraceTable for RamTable { ram_table.row_mut(row_index).assign(&row); } - let all_ram_pointers = ram_table.column(RamPointer.base_table_index()); + let all_ram_pointers = ram_table.column(RamPointer.main_index()); let unique_ram_pointers = all_ram_pointers.iter().unique().copied().collect_vec(); let (bezout_0, bezout_1) = bezout_coefficient_polynomials_coefficients(&unique_ram_pointers); @@ -83,10 +83,9 @@ impl TraceTable for RamTable { fn pad(mut main_table: ArrayViewMut2, table_len: usize) { let last_row_index = table_len.saturating_sub(1); let mut padding_row = main_table.row(last_row_index).to_owned(); - padding_row[InstructionType.base_table_index()] = PADDING_INDICATOR; + padding_row[InstructionType.main_index()] = PADDING_INDICATOR; if table_len == 0 { - padding_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = - BFieldElement::ONE; + padding_row[BezoutCoefficientPolynomialCoefficient1.main_index()] = BFieldElement::ONE; } let mut padding_section = main_table.slice_mut(s![table_len.., ..]); @@ -106,11 +105,11 @@ impl TraceTable for RamTable { assert_eq!(Self::AuxColumn::COUNT, ext_table.ncols()); assert_eq!(base_table.nrows(), ext_table.nrows()); - let extension_column_indices = RamExtTableColumn::iter() + let extension_column_indices = RamAuxColumn::iter() // RunningProductOfRAMP + FormalDerivative are constitute one // slice and are populated by the same function - .filter(|column| *column != RamExtTableColumn::FormalDerivative) - .map(|column| column.ext_table_index()) + .filter(|column| *column != RamAuxColumn::FormalDerivative) + .map(|column| column.aux_index()) .collect_vec(); let extension_column_slices = horizontal_multi_slice_mut( ext_table.view_mut(), @@ -135,12 +134,12 @@ impl TraceTable for RamTable { } fn compare_rows(row_0: ArrayView1, row_1: ArrayView1) -> Ordering { - let ram_pointer_0 = row_0[RamPointer.base_table_index()].value(); - let ram_pointer_1 = row_1[RamPointer.base_table_index()].value(); + let ram_pointer_0 = row_0[RamPointer.main_index()].value(); + let ram_pointer_1 = row_1[RamPointer.main_index()].value(); let compare_ram_pointers = ram_pointer_0.cmp(&ram_pointer_1); - let clk_0 = row_0[CLK.base_table_index()].value(); - let clk_1 = row_1[CLK.base_table_index()].value(); + let clk_0 = row_0[CLK.main_index()].value(); + let clk_1 = row_1[CLK.main_index()].value(); let compare_clocks = clk_0.cmp(&clk_1); compare_ram_pointers.then(compare_clocks) @@ -209,19 +208,16 @@ fn make_ram_table_consistent( let mut current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); let mut current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); - ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = - current_bcpc_0; - ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = - current_bcpc_1; + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient0.main_index()] = current_bcpc_0; + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient1.main_index()] = current_bcpc_1; let mut clock_jump_differences = vec![]; for row_idx in 0..ram_table.nrows() - 1 { let (mut curr_row, mut next_row) = ram_table.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - let ramp_diff = - next_row[RamPointer.base_table_index()] - curr_row[RamPointer.base_table_index()]; - let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; + let ramp_diff = next_row[RamPointer.main_index()] - curr_row[RamPointer.main_index()]; + let clk_diff = next_row[CLK.main_index()] - curr_row[CLK.main_index()]; if ramp_diff.is_zero() { clock_jump_differences.push(clk_diff); @@ -230,9 +226,9 @@ fn make_ram_table_consistent( current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); } - curr_row[InverseOfRampDifference.base_table_index()] = ramp_diff.inverse_or_zero(); - next_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = current_bcpc_0; - next_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = current_bcpc_1; + curr_row[InverseOfRampDifference.main_index()] = ramp_diff.inverse_or_zero(); + next_row[BezoutCoefficientPolynomialCoefficient0.main_index()] = current_bcpc_0; + next_row[BezoutCoefficientPolynomialCoefficient1.main_index()] = current_bcpc_1; } assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); @@ -248,19 +244,19 @@ fn extension_column_running_product_of_ramp_and_formal_derivative( let mut extension_columns = Vec::with_capacity(2 * base_table.nrows()); let mut running_product_ram_pointer = - bezout_indeterminate - base_table.row(0)[RamPointer.base_table_index()]; + bezout_indeterminate - base_table.row(0)[RamPointer.main_index()]; let mut formal_derivative = xfe!(1); extension_columns.push(running_product_ram_pointer); extension_columns.push(formal_derivative); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - let instruction_type = current_row[InstructionType.base_table_index()]; + let instruction_type = current_row[InstructionType.main_index()]; let is_no_padding_row = instruction_type != PADDING_INDICATOR; if is_no_padding_row { - let current_ram_pointer = current_row[RamPointer.base_table_index()]; - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; + let current_ram_pointer = current_row[RamPointer.main_index()]; + let previous_ram_pointer = previous_row[RamPointer.main_index()]; if previous_ram_pointer != current_ram_pointer { formal_derivative = (bezout_indeterminate - current_ram_pointer) * formal_derivative @@ -301,25 +297,24 @@ fn extension_column_bezout_coefficient_1( fn extension_column_bezout_coefficient( base_table: ArrayView2, challenges: &Challenges, - bezout_cefficient_column: RamBaseTableColumn, + bezout_cefficient_column: RamMainColumn, ) -> Array2 { let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; - let mut bezout_coefficient = - base_table.row(0)[bezout_cefficient_column.base_table_index()].lift(); + let mut bezout_coefficient = base_table.row(0)[bezout_cefficient_column.main_index()].lift(); let mut extension_column = Vec::with_capacity(base_table.nrows()); extension_column.push(bezout_coefficient); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { + if current_row[InstructionType.main_index()] == PADDING_INDICATOR { break; // padding marks the end of the trace } - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; - let current_ram_pointer = current_row[RamPointer.base_table_index()]; + let previous_ram_pointer = previous_row[RamPointer.main_index()]; + let current_ram_pointer = current_row[RamPointer.main_index()]; if previous_ram_pointer != current_ram_pointer { bezout_coefficient *= bezout_indeterminate; - bezout_coefficient += current_row[bezout_cefficient_column.base_table_index()]; + bezout_coefficient += current_row[bezout_cefficient_column.main_index()]; } extension_column.push(bezout_coefficient); } @@ -336,14 +331,14 @@ fn extension_column_running_product_perm_arg( let mut running_product_for_perm_arg = PermArg::default_initial(); let mut extension_column = Vec::with_capacity(base_table.nrows()); for row in base_table.rows() { - let instruction_type = row[InstructionType.base_table_index()]; + let instruction_type = row[InstructionType.main_index()]; if instruction_type == PADDING_INDICATOR { break; // padding marks the end of the trace } - let clk = row[CLK.base_table_index()]; - let current_ram_pointer = row[RamPointer.base_table_index()]; - let ram_value = row[RamValue.base_table_index()]; + let clk = row[CLK.main_index()]; + let current_ram_pointer = row[RamPointer.main_index()]; + let ram_value = row[RamValue.main_index()]; let compressed_row = clk * challenges[RamClkWeight] + instruction_type * challenges[RamInstructionTypeWeight] + current_ram_pointer * challenges[RamPointerWeight] @@ -368,15 +363,15 @@ fn extension_column_clock_jump_difference_lookup_log_derivative( extension_column.push(cjd_lookup_log_derivative); for (previous_row, current_row) in base_table.rows().into_iter().tuple_windows() { - if current_row[InstructionType.base_table_index()] == PADDING_INDICATOR { + if current_row[InstructionType.main_index()] == PADDING_INDICATOR { break; // padding marks the end of the trace } - let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; - let current_ram_pointer = current_row[RamPointer.base_table_index()]; + let previous_ram_pointer = previous_row[RamPointer.main_index()]; + let current_ram_pointer = current_row[RamPointer.main_index()]; if previous_ram_pointer == current_ram_pointer { - let previous_clock = previous_row[CLK.base_table_index()]; - let current_clock = current_row[CLK.base_table_index()]; + let previous_clock = previous_row[CLK.main_index()]; + let current_clock = current_row[CLK.main_index()]; let clock_jump_difference = current_clock - previous_clock; let log_derivative_summand = (indeterminate - clock_jump_difference).inverse(); cjd_lookup_log_derivative += log_derivative_summand; diff --git a/triton-vm/src/table/u32.rs b/triton-vm/src/table/u32.rs index e6a92caf5..85f4bb2b7 100644 --- a/triton-vm/src/table/u32.rs +++ b/triton-vm/src/table/u32.rs @@ -4,8 +4,8 @@ use air::challenge_id::ChallengeId::*; use air::cross_table_argument::CrossTableArg; use air::cross_table_argument::LookupArg; use air::table::u32::U32Table; -use air::table_column::MasterBaseTableColumn; -use air::table_column::MasterExtTableColumn; +use air::table_column::MasterAuxColumn; +use air::table_column::MasterMainColumn; use arbitrary::Arbitrary; use isa::instruction::Instruction; use ndarray::parallel::prelude::*; @@ -72,14 +72,13 @@ impl TraceTable for U32Table { let mut next_section_start = 0; for (&u32_table_entry, &multiplicity) in &aet.u32_entries { let mut first_row = Array2::zeros([1, MainColumn::COUNT]); - first_row[[0, MainColumn::CopyFlag.base_table_index()]] = bfe!(1); - first_row[[0, MainColumn::Bits.base_table_index()]] = bfe!(0); - first_row[[0, MainColumn::BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); - first_row[[0, MainColumn::CI.base_table_index()]] = - u32_table_entry.instruction.opcode_b(); - first_row[[0, MainColumn::LHS.base_table_index()]] = u32_table_entry.left_operand; - first_row[[0, MainColumn::RHS.base_table_index()]] = u32_table_entry.right_operand; - first_row[[0, MainColumn::LookupMultiplicity.base_table_index()]] = multiplicity.into(); + first_row[[0, MainColumn::CopyFlag.main_index()]] = bfe!(1); + first_row[[0, MainColumn::Bits.main_index()]] = bfe!(0); + first_row[[0, MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse(); + first_row[[0, MainColumn::CI.main_index()]] = u32_table_entry.instruction.opcode_b(); + first_row[[0, MainColumn::LHS.main_index()]] = u32_table_entry.left_operand; + first_row[[0, MainColumn::RHS.main_index()]] = u32_table_entry.right_operand; + first_row[[0, MainColumn::LookupMultiplicity.main_index()]] = multiplicity.into(); let u32_section = u32_section_next_row(first_row); let next_section_end = next_section_start + u32_section.nrows(); @@ -92,25 +91,23 @@ impl TraceTable for U32Table { fn pad(mut main_table: ArrayViewMut2, table_len: usize) { let mut padding_row = Array1::zeros([MainColumn::COUNT]); - padding_row[[MainColumn::CI.base_table_index()]] = Instruction::Split.opcode_b(); - padding_row[[MainColumn::BitsMinus33Inv.base_table_index()]] = bfe!(-33).inverse(); + padding_row[[MainColumn::CI.main_index()]] = Instruction::Split.opcode_b(); + padding_row[[MainColumn::BitsMinus33Inv.main_index()]] = bfe!(-33).inverse(); if table_len > 0 { let last_row = main_table.row(table_len - 1); - padding_row[[MainColumn::CI.base_table_index()]] = - last_row[MainColumn::CI.base_table_index()]; - padding_row[[MainColumn::LHS.base_table_index()]] = - last_row[MainColumn::LHS.base_table_index()]; - padding_row[[MainColumn::LhsInv.base_table_index()]] = - last_row[MainColumn::LhsInv.base_table_index()]; - padding_row[[MainColumn::Result.base_table_index()]] = - last_row[MainColumn::Result.base_table_index()]; + padding_row[[MainColumn::CI.main_index()]] = last_row[MainColumn::CI.main_index()]; + padding_row[[MainColumn::LHS.main_index()]] = last_row[MainColumn::LHS.main_index()]; + padding_row[[MainColumn::LhsInv.main_index()]] = + last_row[MainColumn::LhsInv.main_index()]; + padding_row[[MainColumn::Result.main_index()]] = + last_row[MainColumn::Result.main_index()]; // In the edge case that the last non-padding row comes from executing instruction // `lt` on operands 0 and 0, the `Result` column is 0. For the padding section, // where the `CopyFlag` is always 0, the `Result` needs to be set to 2 instead. - if padding_row[[MainColumn::CI.base_table_index()]] == Instruction::Lt.opcode_b() { - padding_row[[MainColumn::Result.base_table_index()]] = bfe!(2); + if padding_row[[MainColumn::CI.main_index()]] == Instruction::Lt.opcode_b() { + padding_row[[MainColumn::Result.main_index()]] = bfe!(2); } } @@ -140,19 +137,18 @@ impl TraceTable for U32Table { let mut running_sum_log_derivative = LookupArg::default_initial(); for row_idx in 0..base_table.nrows() { let current_row = base_table.row(row_idx); - if current_row[MainColumn::CopyFlag.base_table_index()].is_one() { - let lookup_multiplicity = - current_row[MainColumn::LookupMultiplicity.base_table_index()]; - let compressed_row = ci_weight * current_row[MainColumn::CI.base_table_index()] - + lhs_weight * current_row[MainColumn::LHS.base_table_index()] - + rhs_weight * current_row[MainColumn::RHS.base_table_index()] - + result_weight * current_row[MainColumn::Result.base_table_index()]; + if current_row[MainColumn::CopyFlag.main_index()].is_one() { + let lookup_multiplicity = current_row[MainColumn::LookupMultiplicity.main_index()]; + let compressed_row = ci_weight * current_row[MainColumn::CI.main_index()] + + lhs_weight * current_row[MainColumn::LHS.main_index()] + + rhs_weight * current_row[MainColumn::RHS.main_index()] + + result_weight * current_row[MainColumn::Result.main_index()]; running_sum_log_derivative += lookup_multiplicity * (lookup_indeterminate - compressed_row).inverse(); } let mut extension_row = ext_table.row_mut(row_idx); - extension_row[AuxColumn::LookupServerLogDerivative.ext_table_index()] = + extension_row[AuxColumn::LookupServerLogDerivative.aux_index()] = running_sum_log_derivative; } profiler!(stop "u32 table"); @@ -161,17 +157,17 @@ impl TraceTable for U32Table { fn u32_section_next_row(mut section: Array2) -> Array2 { let row_idx = section.nrows() - 1; - let current_instruction: Instruction = section[[row_idx, MainColumn::CI.base_table_index()]] + let current_instruction: Instruction = section[[row_idx, MainColumn::CI.main_index()]] .value() .try_into() .expect("Unknown instruction"); // Is the last row in this section reached? - if (section[[row_idx, MainColumn::LHS.base_table_index()]].is_zero() + if (section[[row_idx, MainColumn::LHS.main_index()]].is_zero() || current_instruction == Instruction::Pow) - && section[[row_idx, MainColumn::RHS.base_table_index()]].is_zero() + && section[[row_idx, MainColumn::RHS.main_index()]].is_zero() { - section[[row_idx, MainColumn::Result.base_table_index()]] = match current_instruction { + section[[row_idx, MainColumn::Result.main_index()]] = match current_instruction { Instruction::Split => bfe!(0), Instruction::Lt => bfe!(2), Instruction::And => bfe!(0), @@ -183,52 +179,50 @@ fn u32_section_next_row(mut section: Array2) -> Array2 section[[row_idx, MainColumn::LHS.base_table_index()]], - false => (section[[row_idx, MainColumn::LHS.base_table_index()]] - lhs_lsb) / bfe!(2), + next_row[MainColumn::CopyFlag.main_index()] = bfe!(0); + next_row[MainColumn::Bits.main_index()] += bfe!(1); + next_row[MainColumn::BitsMinus33Inv.main_index()] = + (next_row[MainColumn::Bits.main_index()] - bfe!(33)).inverse(); + next_row[MainColumn::LHS.main_index()] = match current_instruction == Instruction::Pow { + true => section[[row_idx, MainColumn::LHS.main_index()]], + false => (section[[row_idx, MainColumn::LHS.main_index()]] - lhs_lsb) / bfe!(2), }; - next_row[MainColumn::RHS.base_table_index()] = - (section[[row_idx, MainColumn::RHS.base_table_index()]] - rhs_lsb) / bfe!(2); - next_row[MainColumn::LookupMultiplicity.base_table_index()] = bfe!(0); + next_row[MainColumn::RHS.main_index()] = + (section[[row_idx, MainColumn::RHS.main_index()]] - rhs_lsb) / bfe!(2); + next_row[MainColumn::LookupMultiplicity.main_index()] = bfe!(0); section.push_row(next_row.view()).unwrap(); section = u32_section_next_row(section); let (mut row, next_row) = section.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - row[MainColumn::LhsInv.base_table_index()] = - row[MainColumn::LHS.base_table_index()].inverse_or_zero(); - row[MainColumn::RhsInv.base_table_index()] = - row[MainColumn::RHS.base_table_index()].inverse_or_zero(); + row[MainColumn::LhsInv.main_index()] = row[MainColumn::LHS.main_index()].inverse_or_zero(); + row[MainColumn::RhsInv.main_index()] = row[MainColumn::RHS.main_index()].inverse_or_zero(); - let next_row_result = next_row[MainColumn::Result.base_table_index()]; - row[MainColumn::Result.base_table_index()] = match current_instruction { + let next_row_result = next_row[MainColumn::Result.main_index()]; + row[MainColumn::Result.main_index()] = match current_instruction { Instruction::Split => next_row_result, Instruction::Lt => { match ( next_row_result.value(), lhs_lsb.value(), rhs_lsb.value(), - row[MainColumn::CopyFlag.base_table_index()].value(), + row[MainColumn::CopyFlag.main_index()].value(), ) { (0 | 1, _, _, _) => next_row_result, // result already known (2, 0, 1, _) => bfe!(1), // LHS < RHS @@ -240,18 +234,18 @@ fn u32_section_next_row(mut section: Array2) -> Array2 bfe!(2) * next_row_result + lhs_lsb * rhs_lsb, Instruction::Log2Floor => { - if row[MainColumn::LHS.base_table_index()].is_zero() { + if row[MainColumn::LHS.main_index()].is_zero() { bfe!(-1) - } else if !next_row[MainColumn::LHS.base_table_index()].is_zero() { + } else if !next_row[MainColumn::LHS.main_index()].is_zero() { next_row_result } else { // LHS != 0 && LHS' == 0 - row[MainColumn::Bits.base_table_index()] + row[MainColumn::Bits.main_index()] } } Instruction::Pow => match rhs_lsb.is_zero() { true => next_row_result * next_row_result, - false => next_row_result * next_row_result * row[MainColumn::LHS.base_table_index()], + false => next_row_result * next_row_result * row[MainColumn::LHS.main_index()], }, Instruction::PopCount => next_row_result + lhs_lsb, _ => panic!("Must be u32 instruction, not {current_instruction}."), diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index da5162ac0..975e96848 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -1049,49 +1049,49 @@ impl VMState { pub fn to_processor_row(&self) -> Array1 { use isa::instruction::InstructionBit; - use ProcessorBaseTableColumn::*; + use ProcessorMainColumn::*; let mut processor_row = Array1::zeros(::MainColumn::COUNT); let current_instruction = self.current_instruction().unwrap_or(Nop); let helper_variables = self.derive_helper_variables(); - processor_row[CLK.base_table_index()] = u64::from(self.cycle_count).into(); - processor_row[IP.base_table_index()] = (self.instruction_pointer as u32).into(); - processor_row[CI.base_table_index()] = current_instruction.opcode_b(); - processor_row[NIA.base_table_index()] = self.next_instruction_or_argument(); - processor_row[IB0.base_table_index()] = current_instruction.ib(InstructionBit::IB0); - processor_row[IB1.base_table_index()] = current_instruction.ib(InstructionBit::IB1); - processor_row[IB2.base_table_index()] = current_instruction.ib(InstructionBit::IB2); - processor_row[IB3.base_table_index()] = current_instruction.ib(InstructionBit::IB3); - processor_row[IB4.base_table_index()] = current_instruction.ib(InstructionBit::IB4); - processor_row[IB5.base_table_index()] = current_instruction.ib(InstructionBit::IB5); - processor_row[IB6.base_table_index()] = current_instruction.ib(InstructionBit::IB6); - processor_row[JSP.base_table_index()] = self.jump_stack_pointer(); - processor_row[JSO.base_table_index()] = self.jump_stack_origin(); - processor_row[JSD.base_table_index()] = self.jump_stack_destination(); - processor_row[ST0.base_table_index()] = self.op_stack[OpStackElement::ST0]; - processor_row[ST1.base_table_index()] = self.op_stack[OpStackElement::ST1]; - processor_row[ST2.base_table_index()] = self.op_stack[OpStackElement::ST2]; - processor_row[ST3.base_table_index()] = self.op_stack[OpStackElement::ST3]; - processor_row[ST4.base_table_index()] = self.op_stack[OpStackElement::ST4]; - processor_row[ST5.base_table_index()] = self.op_stack[OpStackElement::ST5]; - processor_row[ST6.base_table_index()] = self.op_stack[OpStackElement::ST6]; - processor_row[ST7.base_table_index()] = self.op_stack[OpStackElement::ST7]; - processor_row[ST8.base_table_index()] = self.op_stack[OpStackElement::ST8]; - processor_row[ST9.base_table_index()] = self.op_stack[OpStackElement::ST9]; - processor_row[ST10.base_table_index()] = self.op_stack[OpStackElement::ST10]; - processor_row[ST11.base_table_index()] = self.op_stack[OpStackElement::ST11]; - processor_row[ST12.base_table_index()] = self.op_stack[OpStackElement::ST12]; - processor_row[ST13.base_table_index()] = self.op_stack[OpStackElement::ST13]; - processor_row[ST14.base_table_index()] = self.op_stack[OpStackElement::ST14]; - processor_row[ST15.base_table_index()] = self.op_stack[OpStackElement::ST15]; - processor_row[OpStackPointer.base_table_index()] = self.op_stack.pointer(); - processor_row[HV0.base_table_index()] = helper_variables[0]; - processor_row[HV1.base_table_index()] = helper_variables[1]; - processor_row[HV2.base_table_index()] = helper_variables[2]; - processor_row[HV3.base_table_index()] = helper_variables[3]; - processor_row[HV4.base_table_index()] = helper_variables[4]; - processor_row[HV5.base_table_index()] = helper_variables[5]; + processor_row[CLK.main_index()] = u64::from(self.cycle_count).into(); + processor_row[IP.main_index()] = (self.instruction_pointer as u32).into(); + processor_row[CI.main_index()] = current_instruction.opcode_b(); + processor_row[NIA.main_index()] = self.next_instruction_or_argument(); + processor_row[IB0.main_index()] = current_instruction.ib(InstructionBit::IB0); + processor_row[IB1.main_index()] = current_instruction.ib(InstructionBit::IB1); + processor_row[IB2.main_index()] = current_instruction.ib(InstructionBit::IB2); + processor_row[IB3.main_index()] = current_instruction.ib(InstructionBit::IB3); + processor_row[IB4.main_index()] = current_instruction.ib(InstructionBit::IB4); + processor_row[IB5.main_index()] = current_instruction.ib(InstructionBit::IB5); + processor_row[IB6.main_index()] = current_instruction.ib(InstructionBit::IB6); + processor_row[JSP.main_index()] = self.jump_stack_pointer(); + processor_row[JSO.main_index()] = self.jump_stack_origin(); + processor_row[JSD.main_index()] = self.jump_stack_destination(); + processor_row[ST0.main_index()] = self.op_stack[OpStackElement::ST0]; + processor_row[ST1.main_index()] = self.op_stack[OpStackElement::ST1]; + processor_row[ST2.main_index()] = self.op_stack[OpStackElement::ST2]; + processor_row[ST3.main_index()] = self.op_stack[OpStackElement::ST3]; + processor_row[ST4.main_index()] = self.op_stack[OpStackElement::ST4]; + processor_row[ST5.main_index()] = self.op_stack[OpStackElement::ST5]; + processor_row[ST6.main_index()] = self.op_stack[OpStackElement::ST6]; + processor_row[ST7.main_index()] = self.op_stack[OpStackElement::ST7]; + processor_row[ST8.main_index()] = self.op_stack[OpStackElement::ST8]; + processor_row[ST9.main_index()] = self.op_stack[OpStackElement::ST9]; + processor_row[ST10.main_index()] = self.op_stack[OpStackElement::ST10]; + processor_row[ST11.main_index()] = self.op_stack[OpStackElement::ST11]; + processor_row[ST12.main_index()] = self.op_stack[OpStackElement::ST12]; + processor_row[ST13.main_index()] = self.op_stack[OpStackElement::ST13]; + processor_row[ST14.main_index()] = self.op_stack[OpStackElement::ST14]; + processor_row[ST15.main_index()] = self.op_stack[OpStackElement::ST15]; + processor_row[OpStackPointer.main_index()] = self.op_stack.pointer(); + processor_row[HV0.main_index()] = helper_variables[0]; + processor_row[HV1.main_index()] = helper_variables[1]; + processor_row[HV2.main_index()] = helper_variables[2]; + processor_row[HV3.main_index()] = helper_variables[3]; + processor_row[HV4.main_index()] = helper_variables[4]; + processor_row[HV5.main_index()] = helper_variables[5]; processor_row } @@ -1180,7 +1180,7 @@ impl VMState { impl Display for VMState { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - use ProcessorBaseTableColumn as ProcCol; + use ProcessorMainColumn as ProcCol; let Ok(instruction) = self.current_instruction() else { return write!(f, "END-OF-FILE"); @@ -1197,8 +1197,8 @@ impl Display for VMState { let row = self.to_processor_row(); - let register = |reg: ProcessorBaseTableColumn| { - let reg_string = format!("{}", row[reg.base_table_index()]); + let register = |reg: ProcessorMainColumn| { + let reg_string = format!("{}", row[reg.main_index()]); format!("{reg_string:>register_width$}") }; let multi_register = |regs: [_; 4]| regs.map(register).join(" | "); @@ -1219,7 +1219,7 @@ impl Display for VMState { let jso = register(ProcCol::JSO); let jsd = register(ProcCol::JSD); let osp = register(ProcCol::OpStackPointer); - let clk = row[ProcCol::CLK.base_table_index()].to_string(); + let clk = row[ProcCol::CLK.main_index()].to_string(); let clk = clk.trim_start_matches('0'); let first_line = format!("ip: {ip} ╷ ci: {ci} ╷ nia: {nia} │ {clk: >clk_width$}"); @@ -1264,7 +1264,7 @@ impl Display for VMState { ProcCol::IB1, ProcCol::IB0, ] - .map(|reg| row[reg.base_table_index()]) + .map(|reg| row[reg.main_index()]) .map(|bfe| format!("{bfe:>2}")) .join(" | "); print_row(f, format!("ib6-0: [ {ib_registers} ]",))?; @@ -2752,10 +2752,10 @@ pub(crate) mod tests { let_assert!(Ok((aet, _)) = VM::trace_execution(&program, [].into(), [].into())); let_assert!(Some(last_processor_row) = aet.processor_trace.rows().into_iter().last()); - let clk_count = last_processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; + let clk_count = last_processor_row[ProcessorMainColumn::CLK.main_index()]; assert!(BFieldElement::ZERO == clk_count); - let last_instruction = last_processor_row[ProcessorBaseTableColumn::CI.base_table_index()]; + let last_instruction = last_processor_row[ProcessorMainColumn::CI.main_index()]; assert!(Instruction::Halt.opcode_b() == last_instruction); println!("{last_processor_row}"); From 7926a12350968039b86394aa1ece515c6ed6af28 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 10 Sep 2024 15:06:39 +0200 Subject: [PATCH 15/15] docs: Describe purpose of individual crates changelog: ignore --- triton-air/Cargo.toml | 2 +- triton-air/README.md | 4 ++++ triton-constraint-builder/Cargo.toml | 2 +- triton-constraint-builder/README.md | 5 +++++ triton-constraint-circuit/Cargo.toml | 2 +- triton-constraint-circuit/README.md | 4 ++++ triton-isa/Cargo.toml | 2 +- triton-isa/README.md | 4 ++++ 8 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 triton-air/README.md create mode 100644 triton-constraint-builder/README.md create mode 100644 triton-constraint-circuit/README.md create mode 100644 triton-isa/README.md diff --git a/triton-air/Cargo.toml b/triton-air/Cargo.toml index 1ba8fb87d..9d5eca449 100644 --- a/triton-air/Cargo.toml +++ b/triton-air/Cargo.toml @@ -4,6 +4,7 @@ description = """ The Arithmetic Intermediate Representation (AIR) for Triton VM. """ +readme = "README.md" version.workspace = true edition.workspace = true authors.workspace = true @@ -11,7 +12,6 @@ license.workspace = true homepage.workspace = true documentation.workspace = true repository.workspace = true -readme.workspace = true [dependencies] arbitrary.workspace = true diff --git a/triton-air/README.md b/triton-air/README.md new file mode 100644 index 000000000..7c1e89469 --- /dev/null +++ b/triton-air/README.md @@ -0,0 +1,4 @@ +# Triton VM AIR + +This crate is part of the [Triton VM](https://triton-vm.org) ecosystem. It contains the definition +of the AIR constraints, which are part of the STARK proving system. diff --git a/triton-constraint-builder/Cargo.toml b/triton-constraint-builder/Cargo.toml index bc07adc4d..5e0641090 100644 --- a/triton-constraint-builder/Cargo.toml +++ b/triton-constraint-builder/Cargo.toml @@ -4,6 +4,7 @@ description = """ Emits efficient code from Triton VM's AIR. """ +readme = "README.md" version.workspace = true edition.workspace = true authors.workspace = true @@ -11,7 +12,6 @@ license.workspace = true homepage.workspace = true documentation.workspace = true repository.workspace = true -readme.workspace = true [dependencies] air.workspace = true diff --git a/triton-constraint-builder/README.md b/triton-constraint-builder/README.md new file mode 100644 index 000000000..d43a38513 --- /dev/null +++ b/triton-constraint-builder/README.md @@ -0,0 +1,5 @@ +# Constraint Circuit Builder + +This crate is part of the [Triton VM](https://triton-vm.org) ecosystem. It contains the code +generator emitting efficient Rust code for Triton VM's AIR constraints, which are part of the STARK +proving system. diff --git a/triton-constraint-circuit/Cargo.toml b/triton-constraint-circuit/Cargo.toml index 1ef2d0506..0885c72b4 100644 --- a/triton-constraint-circuit/Cargo.toml +++ b/triton-constraint-circuit/Cargo.toml @@ -4,6 +4,7 @@ description = """ AIR constraints build helper for Triton VM. """ +readme = "README.md" version.workspace = true edition.workspace = true authors.workspace = true @@ -11,7 +12,6 @@ license.workspace = true homepage.workspace = true documentation.workspace = true repository.workspace = true -readme.workspace = true [dependencies] arbitrary.workspace = true diff --git a/triton-constraint-circuit/README.md b/triton-constraint-circuit/README.md new file mode 100644 index 000000000..90db91982 --- /dev/null +++ b/triton-constraint-circuit/README.md @@ -0,0 +1,4 @@ +# Constraint Circuit + +This crate is part of the [Triton VM](https://triton-vm.org) ecosystem. It contains logic that helps +building efficient AIR constraints, which are part of the STARK proving system. diff --git a/triton-isa/Cargo.toml b/triton-isa/Cargo.toml index d697d3175..adf6fffc6 100644 --- a/triton-isa/Cargo.toml +++ b/triton-isa/Cargo.toml @@ -3,6 +3,7 @@ name = "triton-isa" description = """ The instruction set architecture for Triton VM. """ +readme = "README.md" version.workspace = true edition.workspace = true @@ -11,7 +12,6 @@ license.workspace = true homepage.workspace = true documentation.workspace = true repository.workspace = true -readme.workspace = true [dependencies] arbitrary.workspace = true diff --git a/triton-isa/README.md b/triton-isa/README.md new file mode 100644 index 000000000..65fb42a66 --- /dev/null +++ b/triton-isa/README.md @@ -0,0 +1,4 @@ +# Triton VM's Instruction Set Architecture + +This crate is part of the [Triton VM](https://triton-vm.org) ecosystem. It contains the basic +definitions for Triton VM's instructions as well as program parsing functionality.