diff --git a/triton-vm/src/table/constraint_circuit.rs b/triton-vm/src/table/constraint_circuit.rs index 2e2271aae..8c0986bb9 100644 --- a/triton-vm/src/table/constraint_circuit.rs +++ b/triton-vm/src/table/constraint_circuit.rs @@ -8,21 +8,30 @@ use std::cell::RefCell; use std::cmp; -use std::collections::*; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; use std::fmt::Result as FmtResult; -use std::fmt::*; use std::hash::Hash; use std::hash::Hasher; use std::iter::Sum; -use std::ops::*; +use std::ops::Add; +use std::ops::Mul; +use std::ops::Neg; +use std::ops::Sub; use std::rc::Rc; +use arbitrary::Arbitrary; +use arbitrary::Unstructured; use itertools::Itertools; use ndarray::ArrayView2; use num_traits::One; use num_traits::Zero; use quote::quote; use quote::ToTokens; +use strum::IntoEnumIterator; use twenty_first::prelude::*; use twenty_first::shared_math::mpolynomial::Degree; @@ -567,6 +576,8 @@ fn binop( lhs: ConstraintCircuitMonad, rhs: ConstraintCircuitMonad, ) -> ConstraintCircuitMonad { + assert_eq!(lhs.builder, rhs.builder); + // all `BinOp`s are commutative – try both orders of the operands let new_node = binop_new_node(binop, &rhs, &lhs); if let Some(node) = lhs.builder.all_nodes.borrow().get(&new_node) { @@ -725,7 +736,7 @@ impl ConstraintCircuitMonad { /// Reduce size of multitree by simplifying constant expressions such as `1 * MPol(_,_)` pub fn constant_folding(circuits: &mut [ConstraintCircuitMonad]) { - for circuit in circuits.iter_mut() { + for circuit in circuits { let mut mutated = true; while mutated { let (mutated_inner, maybe_new_root) = circuit.constant_fold_inner(); @@ -914,7 +925,7 @@ impl ConstraintCircuitMonad { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] /// Helper struct to construct new leaf nodes in the circuit multitree. Ensures that each newly /// created node gets a unique ID. pub struct ConstraintCircuitBuilder { @@ -1010,13 +1021,14 @@ impl ConstraintCircuitBuilder { } *self.id_counter.borrow_mut() += 1; - self.all_nodes.borrow_mut().insert(new_node.clone()); + let was_inserted = self.all_nodes.borrow_mut().insert(new_node.clone()); + assert!(was_inserted, "Leaf-created value must be new… {new_node}"); new_node } /// Substitute all nodes with ID `old_id` with the given `new` node. pub fn substitute(&self, old_id: usize, new: &Rc>>) { - for node in self.all_nodes.borrow().clone() { + for node in self.all_nodes.borrow().iter() { if node.circuit.borrow().id == old_id { continue; } @@ -1036,6 +1048,71 @@ impl ConstraintCircuitBuilder { } } +impl<'a, II: InputIndicator + Arbitrary<'a>> Arbitrary<'a> for ConstraintCircuitMonad { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + let builder = ConstraintCircuitBuilder::new(); + let mut random_circuit = random_circuit_leaf(&builder, u)?; + + let num_nodes_in_circuit = u.arbitrary_len::()?; + for _ in 0..num_nodes_in_circuit { + let leaf = random_circuit_leaf(&builder, u)?; + match u.int_in_range(0..=5)? { + 0 => random_circuit = random_circuit * leaf, + 1 => random_circuit = random_circuit + leaf, + 2 => random_circuit = random_circuit - leaf, + 3 => random_circuit = leaf * random_circuit, + 4 => random_circuit = leaf + random_circuit, + 5 => random_circuit = leaf - random_circuit, + _ => unreachable!(), + } + } + + Ok(random_circuit) + } +} + +fn random_circuit_leaf<'a, II: InputIndicator + Arbitrary<'a>>( + builder: &ConstraintCircuitBuilder, + u: &mut Unstructured<'a>, +) -> arbitrary::Result> { + let challenge_ids = ChallengeId::iter().collect_vec(); + let leaf = match u.int_in_range(0..=5)? { + 0 => builder.input(u.arbitrary()?), + 1 => builder.challenge(*u.choose(&challenge_ids)?), + 2 => builder.b_constant(u.arbitrary::()?), + 3 => builder.x_constant(u.arbitrary::()?), + 4 => builder.one(), + 5 => builder.zero(), + _ => unreachable!(), + }; + Ok(leaf) +} + +impl<'a> Arbitrary<'a> for SingleRowIndicator { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + let col_idx = u.arbitrary()?; + let indicator = match u.arbitrary()? { + true => Self::BaseRow(col_idx), + false => Self::ExtRow(col_idx), + }; + Ok(indicator) + } +} + +impl<'a> Arbitrary<'a> for DualRowIndicator { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + let col_idx = u.arbitrary()?; + let indicator = match u.int_in_range(0..=3)? { + 0 => Self::CurrentBaseRow(col_idx), + 1 => Self::CurrentExtRow(col_idx), + 2 => Self::NextBaseRow(col_idx), + 3 => Self::NextExtRow(col_idx), + _ => unreachable!(), + }; + Ok(indicator) + } +} + #[cfg(test)] mod tests { use std::collections::hash_map::DefaultHasher; @@ -1043,13 +1120,13 @@ mod tests { use itertools::Itertools; use ndarray::Array2; + use proptest::prelude::*; + use proptest_arbitrary_interop::arb; use rand::random; use rand::rngs::StdRng; - use rand::thread_rng; use rand::Rng; use rand::SeedableRng; - use strum::EnumCount; - use strum::IntoEnumIterator; + use test_strategy::proptest; use crate::table::cascade_table::ExtCascadeTable; use crate::table::challenges::Challenges; @@ -1068,290 +1145,129 @@ mod tests { use super::*; - fn random_circuit() -> ConstraintCircuitMonad { - let mut rng = thread_rng(); - let num_base_columns = rng.gen_range(1..120); - let num_ext_columns = rng.gen_range(1..40); - let circuit_builder = ConstraintCircuitBuilder::new(); - let initial_input = DualRowIndicator::NextBaseRow(rng.gen_range(0..num_base_columns)); - let mut random_circuit = circuit_builder.input(initial_input); + /// Circuit monads are put into hash sets. Hence, it is important that `Eq` and `Hash` + /// agree whether two nodes are equal: k1 == k2 => h(k1) == h(k2) + #[proptest] + fn equality_and_hash_agree( + #[strategy(arb())] circuit: ConstraintCircuitMonad, + ) { + let hash0 = hash_circuit(&circuit); + let other_circuit = circuit.clone() + circuit.builder.zero(); + let hash1 = hash_circuit(&other_circuit); + prop_assert_eq!(circuit == other_circuit, hash0 == hash1); + } + + /// The hash of a node may not depend on `ref_count`, `counter`, `id_counter_ref`, or + /// `all_nodes`, since `all_nodes` contains the digest of all nodes in the multi tree. + /// For more details, see [`HashSet`]. + #[proptest] + fn multi_circuit_hash_is_unchanged_by_meta_data( + #[strategy(arb())] circuit: ConstraintCircuitMonad, + new_ref_count: usize, + new_id_counter: usize, + ) { + let original_digest = hash_circuit(&circuit); - let num_nodes_in_circuit = rng.gen_range(50..300); - for _ in 0..num_nodes_in_circuit { - let node = random_circuit_node(&circuit_builder, num_base_columns, num_ext_columns); - match rng.gen_range(0..3) { - 0 => random_circuit = random_circuit * node, - 1 => random_circuit = random_circuit + node, - 2 => random_circuit = random_circuit - node, - _ => unreachable!(), - } - } - random_circuit - } - - fn random_circuit_node( - circuit_builder: &ConstraintCircuitBuilder, - num_base_columns: usize, - num_ext_columns: usize, - ) -> ConstraintCircuitMonad { - let mut rng = thread_rng(); - let base_col_index = rng.gen_range(0..num_base_columns); - let ext_col_index = rng.gen_range(0..num_ext_columns); - match rng.gen_range(0..39) { - 0..=4 => circuit_builder.input(DualRowIndicator::CurrentBaseRow(base_col_index)), - 5..=9 => circuit_builder.input(DualRowIndicator::NextBaseRow(base_col_index)), - 10..=14 => circuit_builder.input(DualRowIndicator::CurrentExtRow(ext_col_index)), - 15..=19 => circuit_builder.input(DualRowIndicator::NextExtRow(ext_col_index)), - 20..=24 => circuit_builder.b_constant(rng.gen::()), - 25..=29 => circuit_builder.x_constant(rng.gen::()), - 30..=34 => circuit_builder.challenge(random_challenge_id()), - 35 => circuit_builder.b_constant(0), - 36 => circuit_builder.x_constant(0), - 37 => circuit_builder.b_constant(1), - 38 => circuit_builder.x_constant(1), - _ => unreachable!(), - } - } + circuit.circuit.borrow_mut().ref_count = new_ref_count; + prop_assert_eq!(original_digest, hash_circuit(&circuit)); - fn random_challenge_id() -> ChallengeId { - let random_index = thread_rng().gen_range(0..ChallengeId::COUNT); - let all_challenge_ids = ChallengeId::iter().collect_vec(); - all_challenge_ids[random_index] + circuit.builder.id_counter.replace(new_id_counter); + prop_assert_eq!(original_digest, hash_circuit(&circuit)); } - // Make a deep copy of a Multicircuit and return it as a ConstraintCircuitMonad - fn deep_copy_inner( - val: &ConstraintCircuit, - builder: &mut ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - match &val.expression { - BinaryOperation(op, lhs, rhs) => { - let lhs_ref = deep_copy_inner(&lhs.borrow(), builder); - let rhs_ref = deep_copy_inner(&rhs.borrow(), builder); - binop(*op, lhs_ref, rhs_ref) - } - XConstant(xfe) => builder.x_constant(*xfe), - BConstant(bfe) => builder.b_constant(*bfe), - Input(input_index) => builder.input(*input_index), - Challenge(challenge_id) => builder.challenge(*challenge_id), - } + fn hash_circuit(circuit: &ConstraintCircuitMonad) -> u64 { + let mut hasher = DefaultHasher::new(); + circuit.hash(&mut hasher); + hasher.finish() } - fn deep_copy(val: &ConstraintCircuit) -> ConstraintCircuitMonad { - let mut builder = ConstraintCircuitBuilder::new(); - deep_copy_inner(val, &mut builder) + #[proptest] + fn constant_folding_can_deal_with_multiplication_by_one( + #[strategy(arb())] c: ConstraintCircuitMonad, + ) { + let one = || c.builder.one(); + check_constant_folding_property(c.clone(), c.clone() * one())?; + check_constant_folding_property(c.clone(), one() * c.clone())?; + check_constant_folding_property(c.clone(), one() * c.clone() * one())?; } - #[test] - fn equality_and_hash_agree() { - // The Multicircuits are put into a hash set. Hence, it is important that `Eq` and `Hash` - // agree whether two nodes are equal: k1 == k2 => h(k1) == h(k2) - for _ in 0..100 { - let circuit = random_circuit(); - let mut hasher0 = DefaultHasher::new(); - circuit.hash(&mut hasher0); - let hash0 = hasher0.finish(); - assert_eq!(circuit, circuit); - - let zero = circuit.builder.x_constant(0); - let same_circuit = circuit.clone() + zero; - let mut hasher1 = DefaultHasher::new(); - same_circuit.hash(&mut hasher1); - let hash1 = hasher1.finish(); - let eq_eq = circuit == same_circuit; - let hash_eq = hash0 == hash1; - - assert_eq!(eq_eq, hash_eq); - } + #[proptest] + fn constant_folding_can_deal_with_adding_zero( + #[strategy(arb())] c: ConstraintCircuitMonad, + ) { + let zero = || c.builder.zero(); + check_constant_folding_property(c.clone(), c.clone() + zero())?; + check_constant_folding_property(c.clone(), zero() + c.clone())?; + check_constant_folding_property(c.clone(), zero() + c.clone() + zero())?; } - #[test] - fn multi_circuit_hash_is_unchanged_by_meta_data() { - // From https://doc.rust-lang.org/std/collections/struct.HashSet.html - // "It is a logic error for a key to be modified in such a way that the key’s hash, as - // determined by the Hash trait, or its equality, as determined by the Eq trait, changes - // while it is in the map. This is normally only possible through Cell, RefCell, global - // state, I/O, or unsafe code. The behavior resulting from such a logic error is not - // specified, but will be encapsulated to the HashSet that observed the logic error and not - // result in undefined behavior. This could include panics, incorrect results, aborts, - // memory leaks, and non-termination." - // This means that the hash of a node may not depend on: `ref_count`, `counter`, - // `id_counter_ref`, or `all_nodes`. The reason for this constraint is that `all_nodes` - // contains the digest of all nodes in the multi tree. - let circuit = random_circuit(); - let mut hasher0 = DefaultHasher::new(); - circuit.hash(&mut hasher0); - let digest_prior = hasher0.finish(); - - // Increase ref counter and verify digest is unchanged - circuit.circuit.borrow_mut().ref_count += 1; - let mut hasher1 = DefaultHasher::new(); - circuit.hash(&mut hasher1); - let digest_after = hasher1.finish(); - assert_eq!( - digest_prior, digest_after, - "Digest must be unchanged by traversal" - ); - - // id counter and verify digest is unchanged - let _dummy = circuit.clone() + circuit.clone(); - let mut hasher2 = DefaultHasher::new(); - circuit.hash(&mut hasher2); - let digest_after2 = hasher2.finish(); - assert_eq!( - digest_prior, digest_after2, - "Digest must be unchanged by Id counter increase" - ); + #[proptest] + fn constant_folding_can_deal_with_subtracting_zero( + #[strategy(arb())] c: ConstraintCircuitMonad, + ) { + check_constant_folding_property(c.clone(), c.clone() - c.builder.zero())?; } - #[test] - fn circuit_equality_check_and_constant_folding() { - let circuit_builder: ConstraintCircuitBuilder = - ConstraintCircuitBuilder::new(); - let var_0 = circuit_builder.input(DualRowIndicator::CurrentBaseRow(0)); - let var_4 = circuit_builder.input(DualRowIndicator::NextBaseRow(4)); - let four = circuit_builder.x_constant(4); - let one = circuit_builder.x_constant(1); - let zero = circuit_builder.x_constant(0); - - assert_ne!(var_0, var_4); - assert_ne!(var_0, four); - assert_ne!(one, four); - assert_ne!(one, zero); - assert_ne!(zero, one); - - // Verify that constant folding can handle a = a * 1 - let var_0_copy_0 = deep_copy(&var_0.circuit.borrow()); - let var_0_mul_one_0 = var_0_copy_0.clone() * one.clone(); - assert_ne!(var_0_copy_0, var_0_mul_one_0); - let mut circuits = [var_0_copy_0, var_0_mul_one_0]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding can handle a = 1 * a - let var_0_copy_1 = deep_copy(&var_0.circuit.borrow()); - let var_0_one_mul_1 = one.clone() * var_0_copy_1.clone(); - assert_ne!(var_0_copy_1, var_0_one_mul_1); - let mut circuits = [var_0_copy_1, var_0_one_mul_1]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); + /// Terribly confusing, super rare bug that's extremely difficult to reproduce or pin down: + /// 1. apply constant folding + /// 1. introduce a new redundant circuit, + /// 1. apply constant folding again. + /// + /// As a workaround, only _one_ redundant circuit is produced below. + /// + /// If you, dear reader, feel like diving into a rabbit hole of confusion and frustration, + /// try checking the constant-folding property of all 4 possible combinations in the same test. + #[proptest] + fn constant_folding_can_deal_with_adding_effectively_zero_term( + #[strategy(arb())] c: ConstraintCircuitMonad, + #[strategy(0_usize..4)] test_case: usize, + ) { + let zero = || c.builder.zero(); + let redundant_circuit = match test_case { + 0 => c.clone() + (c.clone() * zero()), + 1 => c.clone() + (zero() * c.clone()), + 2 => (c.clone() * zero()) + c.clone(), + 3 => (zero() * c.clone()) + c.clone(), + _ => unreachable!(), + }; - // Verify that constant folding can handle a = 1 * a * 1 - let var_0_copy_2 = deep_copy(&var_0.circuit.borrow()); - let var_0_one_mul_2 = one.clone() * var_0_copy_2.clone() * one; - assert_ne!(var_0_copy_2, var_0_one_mul_2); - let mut circuits = [var_0_copy_2, var_0_one_mul_2]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); + check_constant_folding_property(c, redundant_circuit)?; + } - // Verify that constant folding handles a + 0 = a - let var_0_copy_3 = deep_copy(&var_0.circuit.borrow()); - let var_0_plus_zero_3 = var_0_copy_3.clone() + zero.clone(); - assert_ne!(var_0_copy_3, var_0_plus_zero_3); - let mut circuits = [var_0_copy_3, var_0_plus_zero_3]; + fn check_constant_folding_property( + circuit: ConstraintCircuitMonad, + circuit_with_redundancy: ConstraintCircuitMonad, + ) -> Result<(), TestCaseError> { + prop_assert_ne!(&circuit, &circuit_with_redundancy); + let mut circuits = [circuit, circuit_with_redundancy]; ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); + let [circuit_0, circuit_1] = circuits; + prop_assert_eq!(circuit_0, circuit_1); + Ok(()) + } - // Verify that constant folding handles a + (a * 0) = a - let var_0_copy_4 = deep_copy(&var_0.circuit.borrow()); - let var_0_plus_zero_4 = var_0_copy_4.clone() + var_0_copy_4.clone() * zero.clone(); - assert_ne!(var_0_copy_4, var_0_plus_zero_4); - let mut circuits = [var_0_copy_4, var_0_plus_zero_4]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); + #[proptest] + fn constant_folding_does_not_replace_0_minus_circuit_with_the_circuit( + #[strategy(arb())] circuit: ConstraintCircuitMonad, + ) { + let zero_minus_circuit = circuit.builder.zero() - circuit.clone(); + prop_assert_ne!(&circuit, &zero_minus_circuit); - // Verify that constant folding does not equate `0 - a` with `a` - let var_0_copy_5 = deep_copy(&var_0.circuit.borrow()); - let zero_minus_var_0 = zero - var_0_copy_5.clone(); - assert_ne!(var_0_copy_5, zero_minus_var_0); - let mut circuits = [var_0_copy_5, zero_minus_var_0]; + let mut circuits = [circuit, zero_minus_circuit]; ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_ne!(circuits[0], circuits[1]); - } - - #[test] - fn constant_folding_pbt() { - for _ in 0..200 { - let circuit = random_circuit(); - let one = circuit.builder.x_constant(1); - let zero = circuit.builder.x_constant(0); - - // Verify that constant folding can handle a = a * 1 - let copy_0 = deep_copy(&circuit.circuit.borrow()); - let copy_0_alt = copy_0.clone() * one.clone(); - assert_ne!(copy_0, copy_0_alt); - let mut circuits = [copy_0.clone(), copy_0_alt.clone()]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding can handle a = 1 * a - let copy_1 = deep_copy(&circuit.circuit.borrow()); - let copy_1_alt = one.clone() * copy_1.clone(); - assert_ne!(copy_1, copy_1_alt); - let mut circuits = [copy_1, copy_1_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding can handle a = 1 * a * 1 - let copy_2 = deep_copy(&circuit.circuit.borrow()); - let copy_2_alt = one.clone() * copy_2.clone() * one.clone(); - assert_ne!(copy_2, copy_2_alt); - let mut circuits = [copy_2, copy_2_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding handles a + 0 = a - let copy_3 = deep_copy(&circuit.circuit.borrow()); - let copy_3_alt = copy_3.clone() + zero.clone(); - assert_ne!(copy_3, copy_3_alt); - let mut circuits = [copy_3, copy_3_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding handles a + (a * 0) = a - let copy_4 = deep_copy(&circuit.circuit.borrow()); - let copy_4_alt = copy_4.clone() + copy_4.clone() * zero.clone(); - assert_ne!(copy_4, copy_4_alt); - let mut circuits = [copy_4, copy_4_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding handles a + (0 * a) = a - let copy_5 = deep_copy(&circuit.circuit.borrow()); - let copy_5_alt = copy_5.clone() + copy_5.clone() * zero.clone(); - assert_ne!(copy_5, copy_5_alt); - let mut circuits = [copy_5, copy_5_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); - - // Verify that constant folding does not equate `0 - a` with `a` - // But only if `a != 0` - let copy_6 = deep_copy(&circuit.circuit.borrow()); - let zero_minus_copy_6 = zero.clone() - copy_6.clone(); - assert_ne!(copy_6, zero_minus_copy_6); - let mut circuits = [copy_6, zero_minus_copy_6]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - let copy_6_is_zero = circuits[0].circuit.borrow().is_zero(); - let copy_6_expr = circuits[0].circuit.borrow().expression.clone(); - let zero_minus_copy_6_expr = circuits[1].circuit.borrow().expression.clone(); - - // An X field and a B field leaf will never be equal - let copy_6_and_zero_minus_copy_6_have_same_constant_type = matches!( - (copy_6_expr, zero_minus_copy_6_expr), - (BConstant(_), BConstant(_)) | (XConstant(_), XConstant(_)) - ); - match copy_6_is_zero && copy_6_and_zero_minus_copy_6_have_same_constant_type { - true => assert_eq!(circuits[0], circuits[1]), - false => assert_ne!(circuits[0], circuits[1]), + let circuit = circuits[0].circuit.borrow().expression.clone(); + let zero_minus_circuit = circuits[1].circuit.borrow().expression.clone(); + + match (circuit, zero_minus_circuit) { + // Takes care of special case `circuit == 0`, where `0 - c == c`. Also gives stronger + // guarantees if folding reduces circuit to constant, but that's just bonus. + (BConstant(l), BConstant(r)) => prop_assert_eq!(l, -r), + (XConstant(l), XConstant(r)) => prop_assert_eq!(l, -r), + (BConstant(_), XConstant(_)) | (XConstant(_), BConstant(_)) => { + let reason = "`circuit` and `0 - circuit` must be of same type"; + return Err(TestCaseError::fail(reason)); } - - // Verify that constant folding handles a - 0 = a - let copy_7 = deep_copy(&circuit.circuit.borrow()); - let copy_7_alt = copy_7.clone() - zero.clone(); - assert_ne!(copy_7, copy_7_alt); - let mut circuits = [copy_7, copy_7_alt]; - ConstraintCircuitMonad::constant_folding(&mut circuits); - assert_eq!(circuits[0], circuits[1]); + _ => prop_assert_ne!(&circuits[0], &circuits[1]), } }