From eed044d83ccf5dce767d2dd09f0f1184b0ca4b8e Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 29 May 2024 16:49:13 -0400 Subject: [PATCH 01/46] feat: sorted_limbs chip checking each limb less than limb_bits bits --- chips/src/lib.rs | 1 + chips/src/sorted_limbs/air.rs | 55 ++++++++++++++++++ chips/src/sorted_limbs/chip.rs | 22 +++++++ chips/src/sorted_limbs/columns.rs | 48 ++++++++++++++++ chips/src/sorted_limbs/mod.rs | 95 +++++++++++++++++++++++++++++++ chips/src/sorted_limbs/trace.rs | 46 +++++++++++++++ chips/tests/integration_test.rs | 41 ++++++++++++- 7 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 chips/src/sorted_limbs/air.rs create mode 100644 chips/src/sorted_limbs/chip.rs create mode 100644 chips/src/sorted_limbs/columns.rs create mode 100644 chips/src/sorted_limbs/mod.rs create mode 100644 chips/src/sorted_limbs/trace.rs diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 0c5c4cfc04..64c11e54d2 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -1,6 +1,7 @@ /// Chip to range check a value has less than a fixed number of bits pub mod range; pub mod range_gate; +pub mod sorted_limbs; pub mod xor_bits; pub mod xor_limbs; pub mod xor_lookup; diff --git a/chips/src/sorted_limbs/air.rs b/chips/src/sorted_limbs/air.rs new file mode 100644 index 0000000000..9953366f46 --- /dev/null +++ b/chips/src/sorted_limbs/air.rs @@ -0,0 +1,55 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use super::columns::SortedLimbsCols; +use super::SortedLimbsChip; + +impl BaseAir for SortedLimbsChip { + fn width(&self) -> usize { + SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + } +} + +impl Air for SortedLimbsChip +where + AB: AirBuilder, + AB::Var: Clone, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let _pis = builder.public_values(); + + let (local, _next) = (main.row_slice(0), main.row_slice(1)); + let local: &[AB::Var] = (*local).borrow(); + + let sort_limbs_cols = SortedLimbsCols::::from_slice( + local, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let key_len = self.key_vec_len(); + + for i in 0..key_len { + let mut key_from_limbs: AB::Expr = AB::Expr::zero(); + // constrain that the decomposition is correct + for j in 0..num_limbs { + key_from_limbs += sort_limbs_cols.keys_decomp[i][j] + * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); + } + + // constrain that the shifted last sublimb is shifted correctly + let shifted_val = sort_limbs_cols.keys_decomp[i][num_limbs - 1] + * AB::Expr::from_canonical_u64( + 1 << (self.decomp() - (self.limb_bits() % self.decomp())), + ); + + builder.assert_eq(sort_limbs_cols.keys_decomp[i][num_limbs], shifted_val); + builder.assert_eq(key_from_limbs, sort_limbs_cols.key[i]); + } + } +} diff --git a/chips/src/sorted_limbs/chip.rs b/chips/src/sorted_limbs/chip.rs new file mode 100644 index 0000000000..20ae923190 --- /dev/null +++ b/chips/src/sorted_limbs/chip.rs @@ -0,0 +1,22 @@ +use super::columns::SortedLimbsCols; +use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_field::PrimeField64; + +use super::SortedLimbsChip; + +impl Chip for SortedLimbsChip { + fn sends(&self) -> Vec> { + let num_cols = + SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = SortedLimbsCols::::cols_numbered( + &all_cols, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + self.sends_custom(cols_numbered) + } +} diff --git a/chips/src/sorted_limbs/columns.rs b/chips/src/sorted_limbs/columns.rs new file mode 100644 index 0000000000..afce8f7637 --- /dev/null +++ b/chips/src/sorted_limbs/columns.rs @@ -0,0 +1,48 @@ +use afs_derive::AlignedBorrow; + +#[derive(Default, AlignedBorrow)] +pub struct SortedLimbsCols { + pub key: Vec, + pub keys_decomp: Vec>, +} + +impl SortedLimbsCols { + pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (limb_bits + decomp - 1) / decomp; + + let key = slc[..key_vec_len].to_vec(); + // num_limbs + 1 to account for the shifted last sublimb + let keys_decomp = slc[key_vec_len..] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + + Self { key, keys_decomp } + } + + pub fn get_width(limb_bits: usize, decomp: usize, key_vec_len: usize) -> usize { + // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to + // account for the sublimb itself, and another 1 to account for the shifted + // last sublimb + key_vec_len * ((limb_bits + decomp - 1) / decomp + 2) + } + + pub fn cols_numbered( + cols: &[usize], + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + ) -> SortedLimbsCols { + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (limb_bits + decomp - 1) / decomp; + let key = cols[..key_vec_len].to_vec(); + // num_limbs + 1 to account for the shifted last sublimb + let keys_decomp = cols[key_vec_len..] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + + SortedLimbsCols { key, keys_decomp } + } +} diff --git a/chips/src/sorted_limbs/mod.rs b/chips/src/sorted_limbs/mod.rs new file mode 100644 index 0000000000..f3c5e2a78c --- /dev/null +++ b/chips/src/sorted_limbs/mod.rs @@ -0,0 +1,95 @@ +use crate::range_gate::RangeCheckerGateChip; + +use afs_stark_backend::interaction::Interaction; +use columns::SortedLimbsCols; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +/** + * This Chip constrains that consecutive rows are sorted lexicographically. + * + * Each row consists of a key decomposed into limbs, and the chip constrains + * each limb has at most limb_bits bits, where limb_bits is at most 31. It + * does this by interacting with a RangeCheckerGateChip. Because the range checker + * gate can take MAX up to 2^20, we further decompose each limb into sublimbs + * of size decomp bits. + */ +#[derive(Default)] +pub struct SortedLimbsChip { + bus_index: usize, + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + keys: Vec>, + + pub range_checker_gate: RangeCheckerGateChip, +} + +impl SortedLimbsChip { + pub fn new( + bus_index: usize, + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + keys: Vec>, + ) -> Self { + Self { + bus_index, + limb_bits, + decomp, + key_vec_len, + keys, + range_checker_gate: RangeCheckerGateChip::::new(bus_index), + } + } + + pub fn bus_index(&self) -> usize { + self.bus_index + } + + pub fn limb_bits(&self) -> usize { + self.limb_bits + } + + pub fn decomp(&self) -> usize { + self.decomp + } + + pub fn key_vec_len(&self) -> usize { + self.key_vec_len + } + + pub fn keys(&self) -> Vec> { + self.keys.clone() + } + + pub fn sends_custom( + &self, + cols: SortedLimbsCols, + ) -> Vec> { + // num_limbs is the number of sublimbs per limb of key, not including the + // shifted last sublimb + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let num_keys = self.key_vec_len(); + + let mut interactions = vec![]; + + for i in 0..num_keys { + // add 1 to account for the shifted last sublimb + for j in 0..(num_limbs + 1) { + interactions.push(Interaction { + fields: vec![VirtualPairCol::single_main(cols.keys_decomp[i][j])], + count: VirtualPairCol::constant(F::one()), + argument_index: self.bus_index(), + }); + } + } + + interactions + } +} diff --git a/chips/src/sorted_limbs/trace.rs b/chips/src/sorted_limbs/trace.rs new file mode 100644 index 0000000000..aee1a05a2b --- /dev/null +++ b/chips/src/sorted_limbs/trace.rs @@ -0,0 +1,46 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use super::{columns::SortedLimbsCols, SortedLimbsChip}; + +impl SortedLimbsChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + let num_xor_cols: usize = + SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + + let keys = self.keys(); + + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + + let rows = keys + .iter() + .map(|x| { + let mut row = vec![]; + for &item in x.iter() { + row.push(F::from_canonical_u32(item)); + } + for &val in x.iter() { + // decompose each limb into sublimbs of size self.decomp() bits + for j in 0..num_limbs { + let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); + row.push(F::from_canonical_u32(bits)); + self.range_checker_gate.add_count(bits) + } + // the last sublimb should be of size self.limb_bits() % self.decomp() bits, + // so we need to shift it to constrain this + let bits = + (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + self.range_checker_gate + .add_count(bits << (self.decomp() - (self.limb_bits() % self.decomp()))); + + row.push(F::from_canonical_u32( + bits << (self.decomp() - (self.limb_bits() % self.decomp())), + )); + } + row + }) + .collect::>(); + + RowMajorMatrix::new(rows.concat(), num_xor_cols) + } +} diff --git a/chips/tests/integration_test.rs b/chips/tests/integration_test.rs index 818e8c049d..7740dc0d8a 100644 --- a/chips/tests/integration_test.rs +++ b/chips/tests/integration_test.rs @@ -1,6 +1,6 @@ use std::{iter, sync::Arc}; -use afs_chips::{range, range_gate, xor_bits, xor_limbs}; +use afs_chips::{range, range_gate, sorted_limbs, xor_bits, xor_limbs}; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; @@ -73,6 +73,45 @@ fn test_list_range_checker() { run_simple_test_no_pis(all_chips, all_traces).expect("Verification failed"); } +#[test] +fn test_sorted_limbs_chip() { + let mut rng = create_seeded_rng(); + + use sorted_limbs::SortedLimbsChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 20; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const LOG_REQUESTS: usize = 2; + + const MAX: u32 = 1 << 8; + const MAX_LIMB: u32 = 1 << LIMB_BITS; + const REQUESTS: usize = 1 << LOG_REQUESTS; + + let requests = (0..REQUESTS) + .map(|_| { + (0..KEY_VEC_LEN) + .map(|_| rng.gen::() % MAX_LIMB) + .collect::>() + }) + .collect::>>(); + + let sorted_limbs_chip = + SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = + sorted_limbs_chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + ) + .expect("Verification failed"); +} + #[test] fn test_xor_bits_chip() { let mut rng = create_seeded_rng(); From 023c03446e05e34b9f84a5065508dd5bc3f8334e Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 30 May 2024 12:30:34 -0400 Subject: [PATCH 02/46] feat: completed sorted_limbs chip with tests --- chips/src/sorted_limbs/air.rs | 30 +++-- chips/src/sorted_limbs/columns.rs | 159 +++++++++++++++++++++++-- chips/src/sorted_limbs/mod.rs | 117 ++++++++++++++++++- chips/src/sorted_limbs/tests/mod.rs | 173 ++++++++++++++++++++++++++++ chips/src/sorted_limbs/trace.rs | 110 ++++++++++++++++-- chips/tests/integration_test.rs | 41 +------ 6 files changed, 559 insertions(+), 71 deletions(-) create mode 100644 chips/src/sorted_limbs/tests/mod.rs diff --git a/chips/src/sorted_limbs/air.rs b/chips/src/sorted_limbs/air.rs index 9953366f46..5caccf889e 100644 --- a/chips/src/sorted_limbs/air.rs +++ b/chips/src/sorted_limbs/air.rs @@ -22,34 +22,46 @@ where let main = builder.main(); let _pis = builder.public_values(); - let (local, _next) = (main.row_slice(0), main.row_slice(1)); + let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); + let next: &[AB::Var] = (*next).borrow(); - let sort_limbs_cols = SortedLimbsCols::::from_slice( + let local_cols = SortedLimbsCols::::from_slice( local, self.limb_bits(), self.decomp(), self.key_vec_len(), ); + let next_cols = SortedLimbsCols::::from_slice( + next, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); let key_len = self.key_vec_len(); + // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + for i in 0..key_len { let mut key_from_limbs: AB::Expr = AB::Expr::zero(); // constrain that the decomposition is correct for j in 0..num_limbs { - key_from_limbs += sort_limbs_cols.keys_decomp[i][j] + key_from_limbs += local_cols.keys_decomp[i][j] * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); } // constrain that the shifted last sublimb is shifted correctly - let shifted_val = sort_limbs_cols.keys_decomp[i][num_limbs - 1] - * AB::Expr::from_canonical_u64( - 1 << (self.decomp() - (self.limb_bits() % self.decomp())), - ); + let shifted_val = local_cols.keys_decomp[i][num_limbs - 1] + * AB::Expr::from_canonical_u64(1 << last_limb_shift); - builder.assert_eq(sort_limbs_cols.keys_decomp[i][num_limbs], shifted_val); - builder.assert_eq(key_from_limbs, sort_limbs_cols.key[i]); + builder.assert_eq(local_cols.keys_decomp[i][num_limbs], shifted_val); + builder.assert_eq(key_from_limbs, local_cols.key[i]); } + + self.is_less_than(builder, local_cols, next_cols); } } diff --git a/chips/src/sorted_limbs/columns.rs b/chips/src/sorted_limbs/columns.rs index afce8f7637..52c025cdf7 100644 --- a/chips/src/sorted_limbs/columns.rs +++ b/chips/src/sorted_limbs/columns.rs @@ -4,28 +4,116 @@ use afs_derive::AlignedBorrow; pub struct SortedLimbsCols { pub key: Vec, pub keys_decomp: Vec>, + pub intermed_sum: Vec, + pub lower_bits: Vec, + pub upper_bit: Vec, + pub lower_bits_decomp: Vec>, + pub diff: Vec, + pub is_zero: Vec, + pub inverses: Vec, } impl SortedLimbsCols { pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; + let mut cur_start_idx = 0; + let mut cur_end_idx = key_vec_len; - let key = slc[..key_vec_len].to_vec(); - // num_limbs + 1 to account for the shifted last sublimb - let keys_decomp = slc[key_vec_len..] + // the first key_vec_len elements are the key itself + let key = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); + + // the next key_vec_len * (num_limbs + 1) elements are the decomposed keys (with each having + // an extra shifted last sublimb) + let keys_decomp = slc[cur_start_idx..cur_end_idx] .chunks(num_limbs + 1) .map(|chunk| chunk.to_vec()) .collect(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of 2^num_limbs + b - a - 1 where a and b are limbs + // on consecutive rows and b is the row after a + let intermed_sum = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum + let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of the upper bit of the intermediate sum; note that + // b > a <=> upper_bit = 1 + let upper_bit = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); - Self { key, keys_decomp } + // the next key_vec_len * (num_limbs + 1) elements are the decomposed limbs of the lower bits of the + // intermediate sum + let lower_bits_decomp = slc[cur_start_idx..cur_end_idx] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the difference between consecutive limbs of rows + let diff = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the indicator whether the difference is zero; if difference is + // zero then the two limbs must be equal + let is_zero = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements contain the inverses of the corresponding sum of diff and is_zero; + // note that this sum will always be nonzero so the inverse will exist + let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); + + Self { + key, + keys_decomp, + intermed_sum, + lower_bits, + upper_bit, + lower_bits_decomp, + diff, + is_zero, + inverses, + } } pub fn get_width(limb_bits: usize, decomp: usize, key_vec_len: usize) -> usize { // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to // account for the sublimb itself, and another 1 to account for the shifted // last sublimb - key_vec_len * ((limb_bits + decomp - 1) / decomp + 2) + let mut width = 0; + // for the key itself + width += key_vec_len; + // for the decomposed keys + let num_limbs = (limb_bits + decomp - 1) / decomp; + width += key_vec_len * (num_limbs + 1); + // for the 2^limb_bits + b - a values + width += key_vec_len; + // for the lower_bits + width += key_vec_len; + // for the upper_bit + width += key_vec_len; + // for the decomposed lower_bits + width += key_vec_len * (num_limbs + 1); + // for the difference between consecutive rows + width += key_vec_len; + // for the indicator whether difference is zero + width += key_vec_len; + // for the y such that y * (i + x) = 1 + width += key_vec_len; + + width } pub fn cols_numbered( @@ -36,13 +124,66 @@ impl SortedLimbsCols { ) -> SortedLimbsCols { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; - let key = cols[..key_vec_len].to_vec(); - // num_limbs + 1 to account for the shifted last sublimb - let keys_decomp = cols[key_vec_len..] + let mut cur_start_idx = 0; + let mut cur_end_idx = key_vec_len; + + // the first key_vec_len elements are the key itself + let key = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); + + // the next key_vec_len * (num_limbs + 1) elements are the decomposed keys + let keys_decomp = cols[cur_start_idx..cur_end_idx] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the intermediate sum + let intermed_sum = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the lower_bits + let lower_bits = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the upper_bit + let upper_bit = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); + + // the next key_vec_len * (num_limbs + 1) elements are the decomposed lower_bits + let lower_bits_decomp = cols[cur_start_idx..cur_end_idx] .chunks(num_limbs + 1) .map(|chunk| chunk.to_vec()) .collect(); - SortedLimbsCols { key, keys_decomp } + // the next key_vec_len elements are the difference between consecutive rows + let diff = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the indicator whether difference is zero + let is_zero = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the inverses + let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); + + SortedLimbsCols { + key, + keys_decomp, + intermed_sum, + lower_bits, + upper_bit, + lower_bits_decomp, + diff, + is_zero, + inverses, + } } } diff --git a/chips/src/sorted_limbs/mod.rs b/chips/src/sorted_limbs/mod.rs index f3c5e2a78c..0145ec5464 100644 --- a/chips/src/sorted_limbs/mod.rs +++ b/chips/src/sorted_limbs/mod.rs @@ -2,8 +2,11 @@ use crate::range_gate::RangeCheckerGateChip; use afs_stark_backend::interaction::Interaction; use columns::SortedLimbsCols; -use p3_air::VirtualPairCol; -use p3_field::PrimeField64; +use p3_air::{AirBuilder, VirtualPairCol}; +use p3_field::{AbstractField, PrimeField64}; + +#[cfg(test)] +pub mod tests; pub mod air; pub mod chip; @@ -79,6 +82,7 @@ impl SortedLimbsChip { let mut interactions = vec![]; + // we will range check the decomposed limbs of the key for i in 0..num_keys { // add 1 to account for the shifted last sublimb for j in 0..(num_limbs + 1) { @@ -90,6 +94,115 @@ impl SortedLimbsChip { } } + // we also range check the limbs of the lower_bits so that we know each element + // of lower_bits has at most limb_bits bits + for i in 0..num_keys { + for j in 0..(num_limbs + 1) { + interactions.push(Interaction { + fields: vec![VirtualPairCol::single_main(cols.lower_bits_decomp[i][j])], + count: VirtualPairCol::constant(F::one()), + argument_index: self.bus_index(), + }); + } + } + interactions } + + // sub-chip with constraints to check whether one key is less than the next (row-wise) + pub fn is_less_than( + &self, + builder: &mut AB, + local_cols: SortedLimbsCols, + next_cols: SortedLimbsCols, + ) where + AB::Var: Clone, + { + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + + let intermed_sum = local_cols.intermed_sum; + let lower_bits = local_cols.lower_bits; + let upper_bit = local_cols.upper_bit; + let lower_bits_decomp = local_cols.lower_bits_decomp; + + // we want to check these constraints for each row except the last one + let mut when_transition = builder.when_transition(); + + // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + + for (i, (key_local, key_next)) in + local_cols.key.iter().zip(next_cols.key.iter()).enumerate() + { + // this is the desired intermediate value (i.e. 2^limb_bits + b - a - 1) + let intermed_val = *key_next - *key_local + + AB::Expr::from_canonical_u64(1 << self.limb_bits()) + - AB::Expr::one(); + + // constrain that the intermed val (2^limb_bits + key_next - key_local) is correct + when_transition.assert_eq(intermed_sum[i], intermed_val); + + // constrain that lower_bits[i] + upper_bit[i] * 2^limb_bits is the correct intermediate sum + let check_val = + lower_bits[i] + upper_bit[i] * AB::Expr::from_canonical_u64(1 << self.limb_bits()); + when_transition.assert_eq(intermed_sum[i], check_val); + + // constrain that diff is the difference between the two elements of consecutive rows + let diff = *key_next - *key_local; + //when_transition.assert_zero(local_cols.diff[i]); + when_transition.assert_eq(diff, local_cols.diff[i]); + } + + for i in 0..self.key_vec_len() { + let mut lower_bits_from_decomp: AB::Expr = AB::Expr::zero(); + // constrain that the decomposition of each lower_bits element is correct + for j in 0..num_limbs { + lower_bits_from_decomp += lower_bits_decomp[i][j] + * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); + } + + // constrain that the shifted last limb is shifted correctly + let shifted_val = lower_bits_decomp[i][num_limbs - 1] + * AB::Expr::from_canonical_u64(1 << last_limb_shift); + + when_transition.assert_eq(lower_bits_decomp[i][num_limbs], shifted_val); + when_transition.assert_eq(lower_bits_from_decomp, lower_bits[i]); + } + + for upper_bit_value in &upper_bit { + // constrain that each element in upper_bit is a boolean + let is_bool = *upper_bit_value * (AB::Expr::one() - *upper_bit_value); + when_transition.assert_zero(is_bool); + } + + for i in 0..self.key_vec_len() { + let diff = local_cols.diff[i]; + let is_equal = local_cols.is_zero[i]; + let inverse = local_cols.inverses[i]; + + // check that diff * is_equal = 0 + when_transition.assert_zero(diff * is_equal); + // check that is_equal is boolean + when_transition.assert_zero(is_equal * (AB::Expr::one() - is_equal)); + // check that inverse * (diff + is_equal) = 1 + when_transition.assert_one(inverse * (diff + is_equal)); + } + + // to check whether one row is less than another, we can use the indicators to generate a boolean + // expression; the idea is that, starting at the most significant limb, a row is less than the next + // if all the limbs more significant are equal and the current limb is less than the corresponding + // limb in the next row + let mut check_less_than: AB::Expr = AB::Expr::zero(); + + for (i, &upper_bit_value) in upper_bit.iter().enumerate() { + let mut curr_expr: AB::Expr = upper_bit_value.into(); + for &is_zero_value in &local_cols.is_zero[i + 1..] { + curr_expr *= is_zero_value.into(); + } + check_less_than += curr_expr; + } + when_transition.assert_one(check_less_than); + } } diff --git a/chips/src/sorted_limbs/tests/mod.rs b/chips/src/sorted_limbs/tests/mod.rs new file mode 100644 index 0000000000..668ca6bf70 --- /dev/null +++ b/chips/src/sorted_limbs/tests/mod.rs @@ -0,0 +1,173 @@ +use super::super::sorted_limbs; + +use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::verifier::VerificationError; +use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use p3_baby_bear::BabyBear; +use p3_matrix::dense::DenseMatrix; + +/** + * Testing strategy for the sorted limbs chip: + * partition on limb_bits: + * limb_bits < 20 + * limb_bits >= 20 + * partition on key_vec_len: + * key_vec_len < 4 + * key_vec_len >= 4 + * partition on decomp: + * limb_bits % decomp == 0 + * limb_bits % decomp != 0 + * partition on number of rows: + * number of rows < 4 + * number of rows >= 4 + * partition on size of each limb: + * each limb has at most limb_bits bits + * at least one limb has more than limb_bits bits + * partition on row order: + * rows are sorted lexicographically + * rows are not sorted lexicographically + */ + +// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, each limb has at +// most limb_bits bits, rows are sorted lexicographically +#[test] +fn test_sorted_limbs_chip_small_positive() { + use sorted_limbs::SortedLimbsChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 2; + + const MAX: u32 = 1 << DECOMP; + + let requests = vec![vec![7784, 35423], vec![17558, 44832]]; + + let sorted_limbs_chip = + SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = + sorted_limbs_chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + ) + .expect("Verification failed"); +} + +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at +// most limb_bits bits, rows are sorted lexicographically +#[test] +fn test_sorted_limbs_chip_large_positive() { + use sorted_limbs::SortedLimbsChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 30; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const MAX: u32 = 1 << DECOMP; + + let requests = vec![ + vec![35867, 318434, 12786, 44832], + vec![704210, 369315, 42421, 487111], + vec![370183, 37202, 729789, 783571], + vec![875005, 767547, 196209, 887921], + ]; + + let sorted_limbs_chip = + SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = + sorted_limbs_chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + ) + .expect("Verification failed"); +} + +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, at least one limb +// has more than limb_bits bits, rows are sorted lexicographically +#[test] +fn test_sorted_limbs_chip_largelimb_negative() { + use sorted_limbs::SortedLimbsChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 10; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const MAX: u32 = 1 << DECOMP; + + // the first and second rows are not in sorted order + let requests = vec![ + vec![223, 448, 15, 587], + vec![883, 168, 772, 673], + vec![57, 386, 1025, 694], + vec![128, 767, 196, 953], + ]; + + let sorted_limbs_chip = + SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = + sorted_limbs_chip.range_checker_gate.generate_trace(); + + let result = run_simple_test_no_pis( + vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + ); + + assert_eq!( + result, + Err(VerificationError::NonZeroCumulativeSum), + "Expected verification to fail, but it passed" + ); +} + +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at +// most limb_bits bits, rows are not sorted lexicographically +#[test] +fn test_sorted_limbs_chip_unsorted_negative() { + use sorted_limbs::SortedLimbsChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 30; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const MAX: u32 = 1 << DECOMP; + + // the first and second rows are not in sorted order + let requests = vec![ + vec![704210, 369315, 42421, 44832], + vec![35867, 318434, 12786, 44832], + vec![370183, 37202, 729789, 783571], + vec![875005, 767547, 196209, 887921], + ]; + + let sorted_limbs_chip = + SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = + sorted_limbs_chip.range_checker_gate.generate_trace(); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis( + vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} diff --git a/chips/src/sorted_limbs/trace.rs b/chips/src/sorted_limbs/trace.rs index aee1a05a2b..7178f12c61 100644 --- a/chips/src/sorted_limbs/trace.rs +++ b/chips/src/sorted_limbs/trace.rs @@ -5,22 +5,29 @@ use super::{columns::SortedLimbsCols, SortedLimbsChip}; impl SortedLimbsChip { pub fn generate_trace(&self) -> RowMajorMatrix { - let num_xor_cols: usize = + let num_cols: usize = SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); let keys = self.keys(); let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + let rows = keys .iter() - .map(|x| { + .enumerate() + .map(|(i, key)| { + // put the key itself into the trace let mut row = vec![]; - for &item in x.iter() { + for &item in key.iter() { row.push(F::from_canonical_u32(item)); } - for &val in x.iter() { - // decompose each limb into sublimbs of size self.decomp() bits + + // decompose each limb into sublimbs of size self.decomp() bits + for &val in key.iter() { for j in 0..num_limbs { let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); row.push(F::from_canonical_u32(bits)); @@ -30,17 +37,98 @@ impl SortedLimbsChip { // so we need to shift it to constrain this let bits = (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); - self.range_checker_gate - .add_count(bits << (self.decomp() - (self.limb_bits() % self.decomp()))); + if (bits << last_limb_shift) < MAX { + self.range_checker_gate.add_count(bits << last_limb_shift); + } + row.push(F::from_canonical_u32(bits << last_limb_shift)); + } + + // this will contain 2^limb_bits + b - a + let mut checks: Vec = vec![]; + // the lower limb_bits bits of the corresponding check value + let mut lower_bits: Vec = vec![]; + let mut lower_bits_u32: Vec = vec![]; + // the (n + 1)st bits of the corresponding check value, will be 1 if a < b + let mut upper_bit: Vec = vec![]; + + // contains the difference between consecutive rows + let mut diff: Vec = vec![]; + // contains indicator whether difference is zero + let mut is_zero: Vec = vec![]; + // contains y such that y * (i + x) = 1 + let mut inverses: Vec = vec![]; + + // we compute the indicators, which only matter if the row is not the last + if i + 1 < keys.len() { + let next_key = &keys[i + 1]; + for (j, &val) in key.iter().enumerate() { + let next_val = next_key[j]; + // compute 2^limb_bits + next_val - val - 1 + let check_less_than = (1 << self.limb_bits()) + next_val - val - 1; + checks.push(F::from_canonical_u32(check_less_than)); + // the lower limb_bits bits of the check value + lower_bits.push(F::from_canonical_u32( + check_less_than & ((1 << self.limb_bits()) - 1), + )); + // we also need the u32 value to compute the decomposition later + lower_bits_u32.push(check_less_than & ((1 << self.limb_bits()) - 1)); + // the (n + 1)st bit of the check value, will be 1 if a < b + upper_bit.push(F::from_canonical_u32(check_less_than >> self.limb_bits())); + + // the difference between the two limbs + let curr_diff = + F::from_canonical_u32(next_val) - F::from_canonical_u32(val); + diff.push(curr_diff); + + // compute the equal indicator and inverses + if next_val == val { + is_zero.push(F::one()); + inverses.push((curr_diff + F::one()).inverse()); + } else { + is_zero.push(F::zero()); + inverses.push(curr_diff.inverse()); + } + } + } else { + for _ in 0..self.key_vec_len() { + checks.push(F::zero()); + lower_bits.push(F::zero()); + lower_bits_u32.push(0); + upper_bit.push(F::zero()); + diff.push(F::zero()); + is_zero.push(F::zero()); + inverses.push(F::zero()); + } + } + + row.extend(checks); + row.extend(lower_bits); + row.extend(upper_bit); - row.push(F::from_canonical_u32( - bits << (self.decomp() - (self.limb_bits() % self.decomp())), - )); + // decompose each element of lower_bits so we can range check that the element + // has at most limb_bits bits + for val in lower_bits_u32 { + for j in 0..num_limbs { + let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); + row.push(F::from_canonical_u32(bits)); + self.range_checker_gate.add_count(bits); + } + let bits = + (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + if (bits << last_limb_shift) < MAX { + self.range_checker_gate.add_count(bits << last_limb_shift); + } + row.push(F::from_canonical_u32(bits << last_limb_shift)); } + + row.extend(diff); + row.extend(is_zero); + row.extend(inverses); + row }) .collect::>(); - RowMajorMatrix::new(rows.concat(), num_xor_cols) + RowMajorMatrix::new(rows.concat(), num_cols) } } diff --git a/chips/tests/integration_test.rs b/chips/tests/integration_test.rs index 7740dc0d8a..818e8c049d 100644 --- a/chips/tests/integration_test.rs +++ b/chips/tests/integration_test.rs @@ -1,6 +1,6 @@ use std::{iter, sync::Arc}; -use afs_chips::{range, range_gate, sorted_limbs, xor_bits, xor_limbs}; +use afs_chips::{range, range_gate, xor_bits, xor_limbs}; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; @@ -73,45 +73,6 @@ fn test_list_range_checker() { run_simple_test_no_pis(all_chips, all_traces).expect("Verification failed"); } -#[test] -fn test_sorted_limbs_chip() { - let mut rng = create_seeded_rng(); - - use sorted_limbs::SortedLimbsChip; - - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 20; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; - - const LOG_REQUESTS: usize = 2; - - const MAX: u32 = 1 << 8; - const MAX_LIMB: u32 = 1 << LIMB_BITS; - const REQUESTS: usize = 1 << LOG_REQUESTS; - - let requests = (0..REQUESTS) - .map(|_| { - (0..KEY_VEC_LEN) - .map(|_| rng.gen::() % MAX_LIMB) - .collect::>() - }) - .collect::>>(); - - let sorted_limbs_chip = - SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); - - let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = - sorted_limbs_chip.range_checker_gate.generate_trace(); - - run_simple_test_no_pis( - vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], - vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], - ) - .expect("Verification failed"); -} - #[test] fn test_xor_bits_chip() { let mut rng = create_seeded_rng(); From a5c098234472ecf307ffa1ad7c85b9949ffe218a Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 31 May 2024 14:38:59 -0400 Subject: [PATCH 03/46] feat: SortedLimbsChip with LessThan subchip --- chips/src/less_than/air.rs | 139 +++++++++++++++++++++ chips/src/less_than/chip.rs | 52 ++++++++ chips/src/less_than/columns.rs | 183 ++++++++++++++++++++++++++++ chips/src/less_than/mod.rs | 65 ++++++++++ chips/src/less_than/trace.rs | 116 ++++++++++++++++++ chips/src/lib.rs | 2 + chips/src/sorted_limbs/air.rs | 41 +++++-- chips/src/sorted_limbs/chip.rs | 27 +++- chips/src/sorted_limbs/columns.rs | 29 +++-- chips/src/sorted_limbs/mod.rs | 151 +++-------------------- chips/src/sorted_limbs/tests/mod.rs | 44 +++++-- chips/src/sorted_limbs/trace.rs | 150 ++++++----------------- chips/src/sub_chip.rs | 34 ++++++ chips/src/xor_bits/air.rs | 52 +++++++- chips/src/xor_bits/chip.rs | 25 +++- chips/src/xor_bits/mod.rs | 59 --------- chips/src/xor_bits/trace.rs | 50 +++++--- 17 files changed, 855 insertions(+), 364 deletions(-) create mode 100644 chips/src/less_than/air.rs create mode 100644 chips/src/less_than/chip.rs create mode 100644 chips/src/less_than/columns.rs create mode 100644 chips/src/less_than/mod.rs create mode 100644 chips/src/less_than/trace.rs create mode 100644 chips/src/sub_chip.rs diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs new file mode 100644 index 0000000000..97ba13d267 --- /dev/null +++ b/chips/src/less_than/air.rs @@ -0,0 +1,139 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use crate::sub_chip::{AirConfig, SubAir}; + +use super::{columns::LessThanCols, LessThanChip}; + +impl BaseAir for LessThanChip { + fn width(&self) -> usize { + LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + } +} + +impl Air for LessThanChip +where + AB: AirBuilder, + AB::Var: Clone, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let _pis = builder.public_values(); + + let (local, _next) = (main.row_slice(0), main.row_slice(1)); + let local: &[AB::Var] = (*local).borrow(); + + let local_cols = LessThanCols::::from_slice( + local, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + SubAir::eval(self, builder, vec![local_cols]); + } +} + +impl AirConfig for LessThanChip { + type Cols = LessThanCols; +} + +// sub-chip with constraints to check whether one key is less than the next (row-wise) +impl SubAir for LessThanChip { + type ColsPassed = Vec>; + + fn eval(&self, builder: &mut AB, cols: Self::ColsPassed) { + let local_cols = &cols[0]; + let next_cols = &cols[1]; + + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + + let intermed_sum = local_cols.intermed_sum.clone(); + let lower_bits = local_cols.lower_bits.clone(); + let upper_bit = local_cols.upper_bit.clone(); + let lower_bits_decomp = local_cols.lower_bits_decomp.clone(); + + // we want to check these constraints for each row except the last one + let mut when_transition = builder.when_transition(); + + // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + + for (i, (key_local, key_next)) in + local_cols.key.iter().zip(next_cols.key.iter()).enumerate() + { + // this is the desired intermediate value (i.e. 2^limb_bits + b - a - 1) + let intermed_val = *key_next - *key_local + + AB::Expr::from_canonical_u64(1 << self.limb_bits()) + - AB::Expr::one(); + + // constrain that the intermed val (2^limb_bits + key_next - key_local) is correct + when_transition.assert_eq(intermed_sum[i], intermed_val); + + // constrain that lower_bits[i] + upper_bit[i] * 2^limb_bits is the correct intermediate sum + let check_val = + lower_bits[i] + upper_bit[i] * AB::Expr::from_canonical_u64(1 << self.limb_bits()); + when_transition.assert_eq(intermed_sum[i], check_val); + + // constrain that diff is the difference between the two elements of consecutive rows + let diff = *key_next - *key_local; + //when_transition.assert_zero(local_cols.diff[i]); + when_transition.assert_eq(diff, local_cols.diff[i]); + } + + for i in 0..self.key_vec_len() { + let mut lower_bits_from_decomp: AB::Expr = AB::Expr::zero(); + // constrain that the decomposition of each lower_bits element is correct + for j in 0..num_limbs { + lower_bits_from_decomp += lower_bits_decomp[i][j] + * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); + } + + // constrain that the shifted last limb is shifted correctly + let shifted_val = lower_bits_decomp[i][num_limbs - 1] + * AB::Expr::from_canonical_u64(1 << last_limb_shift); + + when_transition.assert_eq(lower_bits_decomp[i][num_limbs], shifted_val); + when_transition.assert_eq(lower_bits_from_decomp, lower_bits[i]); + } + + for upper_bit_value in &upper_bit { + // constrain that each element in upper_bit is a boolean + let is_bool = *upper_bit_value * (AB::Expr::one() - *upper_bit_value); + when_transition.assert_zero(is_bool); + } + + for i in 0..self.key_vec_len() { + let diff = local_cols.diff[i]; + let is_equal = local_cols.is_zero[i]; + let inverse = local_cols.inverses[i]; + + // check that diff * is_equal = 0 + when_transition.assert_zero(diff * is_equal); + // check that is_equal is boolean + when_transition.assert_zero(is_equal * (AB::Expr::one() - is_equal)); + // check that inverse * (diff + is_equal) = 1 + when_transition.assert_one(inverse * (diff + is_equal)); + } + + // to check whether one row is less than another, we can use the indicators to generate a boolean + // expression; the idea is that, starting at the most significant limb, a row is less than the next + // if all the limbs more significant are equal and the current limb is less than the corresponding + // limb in the next row + let mut check_less_than: AB::Expr = AB::Expr::zero(); + + for (i, &upper_bit_value) in upper_bit.iter().enumerate() { + let mut curr_expr: AB::Expr = upper_bit_value.into(); + for &is_zero_value in &local_cols.is_zero[i + 1..] { + curr_expr *= is_zero_value.into(); + } + check_less_than += curr_expr; + } + when_transition.assert_one(check_less_than); + } +} diff --git a/chips/src/less_than/chip.rs b/chips/src/less_than/chip.rs new file mode 100644 index 0000000000..df76a04840 --- /dev/null +++ b/chips/src/less_than/chip.rs @@ -0,0 +1,52 @@ +use crate::sub_chip::SubAirWithInteractions; + +use super::columns::LessThanCols; +use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +use super::LessThanChip; + +impl Chip for LessThanChip { + fn sends(&self) -> Vec> { + let num_cols = + LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = LessThanCols::::cols_numbered( + &all_cols, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + SubAirWithInteractions::sends(self, cols_numbered) + } +} + +impl SubAirWithInteractions for LessThanChip { + fn sends(&self, col_indices: LessThanCols) -> Vec> { + // num_limbs is the number of sublimbs per limb of key, not including the + // shifted last sublimb + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let num_keys = self.key_vec_len(); + + let mut interactions = vec![]; + + // we range check the limbs of the lower_bits so that we know each element + // of lower_bits has at most limb_bits bits + for i in 0..num_keys { + for j in 0..(num_limbs + 1) { + interactions.push(Interaction { + fields: vec![VirtualPairCol::single_main( + col_indices.lower_bits_decomp[i][j], + )], + count: VirtualPairCol::constant(F::one()), + argument_index: self.bus_index(), + }); + } + } + + interactions + } +} diff --git a/chips/src/less_than/columns.rs b/chips/src/less_than/columns.rs new file mode 100644 index 0000000000..b47d3ef2e1 --- /dev/null +++ b/chips/src/less_than/columns.rs @@ -0,0 +1,183 @@ +use afs_derive::AlignedBorrow; + +#[derive(Default, AlignedBorrow)] +pub struct LessThanCols { + pub key: Vec, + pub intermed_sum: Vec, + pub lower_bits: Vec, + pub upper_bit: Vec, + pub lower_bits_decomp: Vec>, + pub diff: Vec, + pub is_zero: Vec, + pub inverses: Vec, +} + +impl LessThanCols { + pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (limb_bits + decomp - 1) / decomp; + let mut cur_start_idx = 0; + let mut cur_end_idx = key_vec_len; + + // the first key_vec_len elements are the key itself + let key = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of 2^num_limbs + b - a - 1 where a and b are limbs + // on consecutive rows and b is the row after a + let intermed_sum = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum + let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of the upper bit of the intermediate sum; note that + // b > a <=> upper_bit = 1 + let upper_bit = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); + + // the next key_vec_len * (num_limbs + 1) elements are the decomposed limbs of the lower bits of the + // intermediate sum + let lower_bits_decomp = slc[cur_start_idx..cur_end_idx] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the difference between consecutive limbs of rows + let diff = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the indicator whether the difference is zero; if difference is + // zero then the two limbs must be equal + let is_zero = slc[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements contain the inverses of the corresponding sum of diff and is_zero; + // note that this sum will always be nonzero so the inverse will exist + let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); + + Self { + key, + intermed_sum, + lower_bits, + upper_bit, + lower_bits_decomp, + diff, + is_zero, + inverses, + } + } + + pub fn flatten(&self) -> Vec { + let mut flattened = vec![]; + flattened.extend_from_slice(&self.key); + flattened.extend_from_slice(&self.intermed_sum); + flattened.extend_from_slice(&self.lower_bits); + flattened.extend_from_slice(&self.upper_bit); + for decomp_vec in &self.lower_bits_decomp { + flattened.extend_from_slice(decomp_vec); + } + flattened.extend_from_slice(&self.diff); + flattened.extend_from_slice(&self.is_zero); + flattened.extend_from_slice(&self.inverses); + + flattened + } + + pub fn get_width(limb_bits: usize, decomp: usize, key_vec_len: usize) -> usize { + // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to + // account for the sublimb itself, and another 1 to account for the shifted + // last sublimb + let mut width = 0; + // for the key itself + width += key_vec_len; + // for the 2^limb_bits + b - a values + width += key_vec_len; + // for the lower_bits + width += key_vec_len; + // for the upper_bit + width += key_vec_len; + // for the decomposed lower_bits + let num_limbs = (limb_bits + decomp - 1) / decomp; + width += key_vec_len * (num_limbs + 1); + // for the difference between consecutive rows + width += key_vec_len; + // for the indicator whether difference is zero + width += key_vec_len; + // for the y such that y * (i + x) = 1 + width += key_vec_len; + + width + } + + pub fn cols_numbered( + cols: &[usize], + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + ) -> LessThanCols { + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (limb_bits + decomp - 1) / decomp; + let mut cur_start_idx = 0; + let mut cur_end_idx = key_vec_len; + + // the first key_vec_len elements are the key itself + let key = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the intermediate sum + let intermed_sum = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the lower_bits + let lower_bits = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the upper_bit + let upper_bit = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len * (num_limbs + 1); + + // the next key_vec_len * (num_limbs + 1) elements are the decomposed lower_bits + let lower_bits_decomp = cols[cur_start_idx..cur_end_idx] + .chunks(num_limbs + 1) + .map(|chunk| chunk.to_vec()) + .collect(); + + // the next key_vec_len elements are the difference between consecutive rows + let diff = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the indicator whether difference is zero + let is_zero = cols[cur_start_idx..cur_end_idx].to_vec(); + cur_start_idx = cur_end_idx; + cur_end_idx += key_vec_len; + + // the next key_vec_len elements are the inverses + let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); + + LessThanCols { + key, + intermed_sum, + lower_bits, + upper_bit, + lower_bits_decomp, + diff, + is_zero, + inverses, + } + } +} diff --git a/chips/src/less_than/mod.rs b/chips/src/less_than/mod.rs new file mode 100644 index 0000000000..e6f6005cbb --- /dev/null +++ b/chips/src/less_than/mod.rs @@ -0,0 +1,65 @@ +use crate::range_gate::RangeCheckerGateChip; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +/** + * This Chip constrains that consecutive rows are sorted lexicographically. + * + * Each row consists of a key decomposed into limbs, and the chip constrains + * each limb has at most limb_bits bits, where limb_bits is at most 31. It + * does this by interacting with a RangeCheckerGateChip. Because the range checker + * gate can take MAX up to 2^20, we further decompose each limb into sublimbs + * of size decomp bits. + */ +#[derive(Default)] +pub struct LessThanChip { + bus_index: usize, + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + keys: Vec>, + + pub range_checker_gate: RangeCheckerGateChip, +} + +impl LessThanChip { + pub fn new( + bus_index: usize, + limb_bits: usize, + decomp: usize, + key_vec_len: usize, + keys: Vec>, + ) -> Self { + Self { + bus_index, + limb_bits, + decomp, + key_vec_len, + keys, + range_checker_gate: RangeCheckerGateChip::::new(bus_index), + } + } + + pub fn bus_index(&self) -> usize { + self.bus_index + } + + pub fn limb_bits(&self) -> usize { + self.limb_bits + } + + pub fn decomp(&self) -> usize { + self.decomp + } + + pub fn key_vec_len(&self) -> usize { + self.key_vec_len + } + + pub fn keys(&self) -> Vec> { + self.keys.clone() + } +} diff --git a/chips/src/less_than/trace.rs b/chips/src/less_than/trace.rs new file mode 100644 index 0000000000..d4229f476d --- /dev/null +++ b/chips/src/less_than/trace.rs @@ -0,0 +1,116 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use crate::sub_chip::LocalTraceInstructions; + +use super::{columns::LessThanCols, LessThanChip}; + +impl LessThanChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + let num_cols: usize = + LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + + let mut rows: Vec = vec![]; + for i in 0..self.key_vec_len() { + let key = self.keys[i].clone(); + let next_key: Vec = if i == self.key_vec_len() - 1 { + vec![0; self.key_vec_len()] + } else { + self.keys[i + 1].clone() + }; + let row = self.generate_trace_row((key, next_key)).flatten(); + rows.extend_from_slice(&row); + } + + RowMajorMatrix::new(rows, num_cols) + } +} + +impl LocalTraceInstructions for LessThanChip { + type LocalInput = (Vec, Vec); + + fn generate_trace_row(&self, consecutive_keys: (Vec, Vec)) -> Self::Cols { + let (key, next_key) = consecutive_keys; + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + + // this will contain 2^limb_bits + b - a + let mut intermed_sum: Vec = vec![]; + // the lower limb_bits bits of the corresponding check value + let mut lower_bits: Vec = vec![]; + let mut lower_bits_u32: Vec = vec![]; + // the (n + 1)st bits of the corresponding check value, will be 1 if a < b + let mut upper_bit: Vec = vec![]; + + // contains the difference between consecutive rows + let mut diff: Vec = vec![]; + // contains indicator whether difference is zero + let mut is_zero: Vec = vec![]; + // contains y such that y * (i + x) = 1 + let mut inverses: Vec = vec![]; + + // we compute the indicators, which only matter if the row is not the last + for (j, &val) in key.iter().enumerate() { + let next_val = next_key[j]; + // compute 2^limb_bits + next_val - val - 1 + let check_less_than = (1 << self.limb_bits()) + next_val - val - 1; + intermed_sum.push(F::from_canonical_u32(check_less_than)); + // the lower limb_bits bits of the check value + lower_bits.push(F::from_canonical_u32( + check_less_than & ((1 << self.limb_bits()) - 1), + )); + // we also need the u32 value to compute the decomposition later + lower_bits_u32.push(check_less_than & ((1 << self.limb_bits()) - 1)); + // the (n + 1)st bit of the check value, will be 1 if a < b + upper_bit.push(F::from_canonical_u32(check_less_than >> self.limb_bits())); + + // the difference between the two limbs + let curr_diff = F::from_canonical_u32(next_val) - F::from_canonical_u32(val); + diff.push(curr_diff); + + // compute the equal indicator and inverses + if next_val == val { + is_zero.push(F::one()); + inverses.push((curr_diff + F::one()).inverse()); + } else { + is_zero.push(F::zero()); + inverses.push(curr_diff.inverse()); + } + } + + let mut lower_bits_decomp: Vec> = vec![]; + + // decompose each element of lower_bits so we can range check that the element + // has at most limb_bits bits + for i in 0..lower_bits_u32.len() { + let val = lower_bits_u32[i]; + if i != lower_bits_u32.len() { + let mut curr_decomp: Vec = vec![]; + for j in 0..num_limbs { + let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); + curr_decomp.push(F::from_canonical_u32(bits)); + self.range_checker_gate.add_count(bits); + } + let bits = (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + if (bits << last_limb_shift) < MAX { + self.range_checker_gate.add_count(bits << last_limb_shift); + } + curr_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); + lower_bits_decomp.push(curr_decomp); + } else { + lower_bits_decomp.push(vec![F::zero(); num_limbs + 1]); + } + } + + LessThanCols { + key: key.into_iter().map(F::from_canonical_u32).collect(), + intermed_sum, + lower_bits, + upper_bit, + lower_bits_decomp, + diff, + is_zero, + inverses, + } + } +} diff --git a/chips/src/lib.rs b/chips/src/lib.rs index a24ba486fb..2b6b185133 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -1,9 +1,11 @@ +pub mod less_than; pub mod page_controller; pub mod page_read; /// Chip to range check a value has less than a fixed number of bits pub mod range; pub mod range_gate; pub mod sorted_limbs; +pub mod sub_chip; pub mod xor_bits; pub mod xor_limbs; pub mod xor_lookup; diff --git a/chips/src/sorted_limbs/air.rs b/chips/src/sorted_limbs/air.rs index 5caccf889e..665454cb6a 100644 --- a/chips/src/sorted_limbs/air.rs +++ b/chips/src/sorted_limbs/air.rs @@ -1,9 +1,12 @@ use std::borrow::Borrow; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; +use crate::less_than::columns::LessThanCols; +use crate::sub_chip::SubAir; + use super::columns::SortedLimbsCols; use super::SortedLimbsChip; @@ -13,18 +16,16 @@ impl BaseAir for SortedLimbsChip { } } -impl Air for SortedLimbsChip +impl Air for SortedLimbsChip where AB: AirBuilder, AB::Var: Clone, { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let _pis = builder.public_values(); let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); - let next: &[AB::Var] = (*next).borrow(); let local_cols = SortedLimbsCols::::from_slice( local, @@ -32,12 +33,6 @@ where self.decomp(), self.key_vec_len(), ); - let next_cols = SortedLimbsCols::::from_slice( - next, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), - ); let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); let key_len = self.key_vec_len(); @@ -59,9 +54,31 @@ where * AB::Expr::from_canonical_u64(1 << last_limb_shift); builder.assert_eq(local_cols.keys_decomp[i][num_limbs], shifted_val); - builder.assert_eq(key_from_limbs, local_cols.key[i]); + builder.assert_eq(key_from_limbs, local_cols.less_than_cols.key[i]); } - self.is_less_than(builder, local_cols, next_cols); + // generate LessThanCols struct for current row and next row + let mut local_slice: Vec = local[0..self.key_vec_len()].to_vec(); + local_slice.extend_from_slice(&local[(self.key_vec_len() * (num_limbs + 2))..]); + + let mut next_slice: Vec = next[0..self.key_vec_len()].to_vec(); + next_slice.extend_from_slice(&next[(self.key_vec_len() * (num_limbs + 2))..]); + + let local_cols = LessThanCols::::from_slice( + &local_slice, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + let next_cols = LessThanCols::::from_slice( + &next_slice, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + // constrain the current row is less than the next row + SubAir::eval(&self.less_than_chip, builder, vec![local_cols, next_cols]); } } diff --git a/chips/src/sorted_limbs/chip.rs b/chips/src/sorted_limbs/chip.rs index 20ae923190..d970b70c2c 100644 --- a/chips/src/sorted_limbs/chip.rs +++ b/chips/src/sorted_limbs/chip.rs @@ -1,5 +1,8 @@ +use crate::sub_chip::SubAirWithInteractions; + use super::columns::SortedLimbsCols; use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_air::VirtualPairCol; use p3_field::PrimeField64; use super::SortedLimbsChip; @@ -17,6 +20,28 @@ impl Chip for SortedLimbsChip { self.key_vec_len(), ); - self.sends_custom(cols_numbered) + let mut interactions: Vec> = vec![]; + + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let num_keys = self.key_vec_len(); + + // we will range check the decomposed limbs of the key + for i in 0..num_keys { + // add 1 to account for the shifted last sublimb + for j in 0..(num_limbs + 1) { + interactions.push(Interaction { + fields: vec![VirtualPairCol::single_main(cols_numbered.keys_decomp[i][j])], + count: VirtualPairCol::constant(F::one()), + argument_index: self.bus_index(), + }); + } + } + + // append the interactions from the subchip + let mut less_than_interactions: Vec> = + SubAirWithInteractions::::sends(&self.less_than_chip, cols_numbered.less_than_cols); + interactions.append(&mut less_than_interactions); + + interactions } } diff --git a/chips/src/sorted_limbs/columns.rs b/chips/src/sorted_limbs/columns.rs index 52c025cdf7..7ba71313ee 100644 --- a/chips/src/sorted_limbs/columns.rs +++ b/chips/src/sorted_limbs/columns.rs @@ -1,16 +1,13 @@ use afs_derive::AlignedBorrow; +use crate::less_than::columns::LessThanCols; + +// Since SortedLimbsChip contains a LessThanChip subchip, a subset of the columns are those of the +// LessThanChip #[derive(Default, AlignedBorrow)] pub struct SortedLimbsCols { - pub key: Vec, pub keys_decomp: Vec>, - pub intermed_sum: Vec, - pub lower_bits: Vec, - pub upper_bit: Vec, - pub lower_bits_decomp: Vec>, - pub diff: Vec, - pub is_zero: Vec, - pub inverses: Vec, + pub less_than_cols: LessThanCols, } impl SortedLimbsCols { @@ -75,9 +72,8 @@ impl SortedLimbsCols { // note that this sum will always be nonzero so the inverse will exist let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); - Self { + let less_than_cols = LessThanCols { key, - keys_decomp, intermed_sum, lower_bits, upper_bit, @@ -85,6 +81,11 @@ impl SortedLimbsCols { diff, is_zero, inverses, + }; + + Self { + keys_decomp, + less_than_cols, } } @@ -174,9 +175,8 @@ impl SortedLimbsCols { // the next key_vec_len elements are the inverses let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); - SortedLimbsCols { + let less_than_cols = LessThanCols { key, - keys_decomp, intermed_sum, lower_bits, upper_bit, @@ -184,6 +184,11 @@ impl SortedLimbsCols { diff, is_zero, inverses, + }; + + SortedLimbsCols { + keys_decomp, + less_than_cols, } } } diff --git a/chips/src/sorted_limbs/mod.rs b/chips/src/sorted_limbs/mod.rs index 0145ec5464..4c205865f3 100644 --- a/chips/src/sorted_limbs/mod.rs +++ b/chips/src/sorted_limbs/mod.rs @@ -1,9 +1,9 @@ -use crate::range_gate::RangeCheckerGateChip; +use crate::less_than::LessThanChip; use afs_stark_backend::interaction::Interaction; use columns::SortedLimbsCols; -use p3_air::{AirBuilder, VirtualPairCol}; -use p3_field::{AbstractField, PrimeField64}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; #[cfg(test)] pub mod tests; @@ -21,16 +21,13 @@ pub mod trace; * does this by interacting with a RangeCheckerGateChip. Because the range checker * gate can take MAX up to 2^20, we further decompose each limb into sublimbs * of size decomp bits. + * + * The SortedLimbsChip contains a LessThanChip subchip, which is used to constrain + * that the rows are sorted lexicographically. */ #[derive(Default)] pub struct SortedLimbsChip { - bus_index: usize, - limb_bits: usize, - decomp: usize, - key_vec_len: usize, - keys: Vec>, - - pub range_checker_gate: RangeCheckerGateChip, + less_than_chip: LessThanChip, } impl SortedLimbsChip { @@ -42,38 +39,39 @@ impl SortedLimbsChip { keys: Vec>, ) -> Self { Self { - bus_index, - limb_bits, - decomp, - key_vec_len, - keys, - range_checker_gate: RangeCheckerGateChip::::new(bus_index), + less_than_chip: LessThanChip::::new( + bus_index, + limb_bits, + decomp, + key_vec_len, + keys, + ), } } pub fn bus_index(&self) -> usize { - self.bus_index + self.less_than_chip.bus_index() } pub fn limb_bits(&self) -> usize { - self.limb_bits + self.less_than_chip.limb_bits() } pub fn decomp(&self) -> usize { - self.decomp + self.less_than_chip.decomp() } pub fn key_vec_len(&self) -> usize { - self.key_vec_len + self.less_than_chip.key_vec_len() } pub fn keys(&self) -> Vec> { - self.keys.clone() + self.less_than_chip.keys().clone() } pub fn sends_custom( &self, - cols: SortedLimbsCols, + cols: &SortedLimbsCols, ) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb @@ -94,115 +92,6 @@ impl SortedLimbsChip { } } - // we also range check the limbs of the lower_bits so that we know each element - // of lower_bits has at most limb_bits bits - for i in 0..num_keys { - for j in 0..(num_limbs + 1) { - interactions.push(Interaction { - fields: vec![VirtualPairCol::single_main(cols.lower_bits_decomp[i][j])], - count: VirtualPairCol::constant(F::one()), - argument_index: self.bus_index(), - }); - } - } - interactions } - - // sub-chip with constraints to check whether one key is less than the next (row-wise) - pub fn is_less_than( - &self, - builder: &mut AB, - local_cols: SortedLimbsCols, - next_cols: SortedLimbsCols, - ) where - AB::Var: Clone, - { - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - - let intermed_sum = local_cols.intermed_sum; - let lower_bits = local_cols.lower_bits; - let upper_bit = local_cols.upper_bit; - let lower_bits_decomp = local_cols.lower_bits_decomp; - - // we want to check these constraints for each row except the last one - let mut when_transition = builder.when_transition(); - - // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in - // the correct range - let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); - - for (i, (key_local, key_next)) in - local_cols.key.iter().zip(next_cols.key.iter()).enumerate() - { - // this is the desired intermediate value (i.e. 2^limb_bits + b - a - 1) - let intermed_val = *key_next - *key_local - + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - - AB::Expr::one(); - - // constrain that the intermed val (2^limb_bits + key_next - key_local) is correct - when_transition.assert_eq(intermed_sum[i], intermed_val); - - // constrain that lower_bits[i] + upper_bit[i] * 2^limb_bits is the correct intermediate sum - let check_val = - lower_bits[i] + upper_bit[i] * AB::Expr::from_canonical_u64(1 << self.limb_bits()); - when_transition.assert_eq(intermed_sum[i], check_val); - - // constrain that diff is the difference between the two elements of consecutive rows - let diff = *key_next - *key_local; - //when_transition.assert_zero(local_cols.diff[i]); - when_transition.assert_eq(diff, local_cols.diff[i]); - } - - for i in 0..self.key_vec_len() { - let mut lower_bits_from_decomp: AB::Expr = AB::Expr::zero(); - // constrain that the decomposition of each lower_bits element is correct - for j in 0..num_limbs { - lower_bits_from_decomp += lower_bits_decomp[i][j] - * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); - } - - // constrain that the shifted last limb is shifted correctly - let shifted_val = lower_bits_decomp[i][num_limbs - 1] - * AB::Expr::from_canonical_u64(1 << last_limb_shift); - - when_transition.assert_eq(lower_bits_decomp[i][num_limbs], shifted_val); - when_transition.assert_eq(lower_bits_from_decomp, lower_bits[i]); - } - - for upper_bit_value in &upper_bit { - // constrain that each element in upper_bit is a boolean - let is_bool = *upper_bit_value * (AB::Expr::one() - *upper_bit_value); - when_transition.assert_zero(is_bool); - } - - for i in 0..self.key_vec_len() { - let diff = local_cols.diff[i]; - let is_equal = local_cols.is_zero[i]; - let inverse = local_cols.inverses[i]; - - // check that diff * is_equal = 0 - when_transition.assert_zero(diff * is_equal); - // check that is_equal is boolean - when_transition.assert_zero(is_equal * (AB::Expr::one() - is_equal)); - // check that inverse * (diff + is_equal) = 1 - when_transition.assert_one(inverse * (diff + is_equal)); - } - - // to check whether one row is less than another, we can use the indicators to generate a boolean - // expression; the idea is that, starting at the most significant limb, a row is less than the next - // if all the limbs more significant are equal and the current limb is less than the corresponding - // limb in the next row - let mut check_less_than: AB::Expr = AB::Expr::zero(); - - for (i, &upper_bit_value) in upper_bit.iter().enumerate() { - let mut curr_expr: AB::Expr = upper_bit_value.into(); - for &is_zero_value in &local_cols.is_zero[i + 1..] { - curr_expr *= is_zero_value.into(); - } - check_less_than += curr_expr; - } - when_transition.assert_one(check_less_than); - } } diff --git a/chips/src/sorted_limbs/tests/mod.rs b/chips/src/sorted_limbs/tests/mod.rs index 668ca6bf70..526e81af9d 100644 --- a/chips/src/sorted_limbs/tests/mod.rs +++ b/chips/src/sorted_limbs/tests/mod.rs @@ -47,11 +47,16 @@ fn test_sorted_limbs_chip_small_positive() { SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = - sorted_limbs_chip.range_checker_gate.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + .less_than_chip + .range_checker_gate + .generate_trace(); run_simple_test_no_pis( - vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![ + &sorted_limbs_chip, + &sorted_limbs_chip.less_than_chip.range_checker_gate, + ], vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], ) .expect("Verification failed"); @@ -81,11 +86,16 @@ fn test_sorted_limbs_chip_large_positive() { SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = - sorted_limbs_chip.range_checker_gate.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + .less_than_chip + .range_checker_gate + .generate_trace(); run_simple_test_no_pis( - vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![ + &sorted_limbs_chip, + &sorted_limbs_chip.less_than_chip.range_checker_gate, + ], vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], ) .expect("Verification failed"); @@ -116,11 +126,16 @@ fn test_sorted_limbs_chip_largelimb_negative() { SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = - sorted_limbs_chip.range_checker_gate.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + .less_than_chip + .range_checker_gate + .generate_trace(); let result = run_simple_test_no_pis( - vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![ + &sorted_limbs_chip, + &sorted_limbs_chip.less_than_chip.range_checker_gate, + ], vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], ); @@ -156,15 +171,20 @@ fn test_sorted_limbs_chip_unsorted_negative() { SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = - sorted_limbs_chip.range_checker_gate.generate_trace(); + let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + .less_than_chip + .range_checker_gate + .generate_trace(); USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; }); assert_eq!( run_simple_test_no_pis( - vec![&sorted_limbs_chip, &sorted_limbs_chip.range_checker_gate], + vec![ + &sorted_limbs_chip, + &sorted_limbs_chip.less_than_chip.range_checker_gate, + ], vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], ), Err(VerificationError::OodEvaluationMismatch), diff --git a/chips/src/sorted_limbs/trace.rs b/chips/src/sorted_limbs/trace.rs index 7178f12c61..7f4fafd214 100644 --- a/chips/src/sorted_limbs/trace.rs +++ b/chips/src/sorted_limbs/trace.rs @@ -1,6 +1,8 @@ use p3_field::PrimeField64; use p3_matrix::dense::RowMajorMatrix; +use crate::sub_chip::LocalTraceInstructions; + use super::{columns::SortedLimbsCols, SortedLimbsChip}; impl SortedLimbsChip { @@ -8,127 +10,53 @@ impl SortedLimbsChip { let num_cols: usize = SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); - let keys = self.keys(); - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in // the correct range let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); - let rows = keys - .iter() - .enumerate() - .map(|(i, key)| { - // put the key itself into the trace - let mut row = vec![]; - for &item in key.iter() { - row.push(F::from_canonical_u32(item)); - } - - // decompose each limb into sublimbs of size self.decomp() bits - for &val in key.iter() { - for j in 0..num_limbs { - let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); - row.push(F::from_canonical_u32(bits)); - self.range_checker_gate.add_count(bits) - } - // the last sublimb should be of size self.limb_bits() % self.decomp() bits, - // so we need to shift it to constrain this - let bits = - (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); - if (bits << last_limb_shift) < MAX { - self.range_checker_gate.add_count(bits << last_limb_shift); - } - row.push(F::from_canonical_u32(bits << last_limb_shift)); - } - - // this will contain 2^limb_bits + b - a - let mut checks: Vec = vec![]; - // the lower limb_bits bits of the corresponding check value - let mut lower_bits: Vec = vec![]; - let mut lower_bits_u32: Vec = vec![]; - // the (n + 1)st bits of the corresponding check value, will be 1 if a < b - let mut upper_bit: Vec = vec![]; - - // contains the difference between consecutive rows - let mut diff: Vec = vec![]; - // contains indicator whether difference is zero - let mut is_zero: Vec = vec![]; - // contains y such that y * (i + x) = 1 - let mut inverses: Vec = vec![]; - - // we compute the indicators, which only matter if the row is not the last - if i + 1 < keys.len() { - let next_key = &keys[i + 1]; - for (j, &val) in key.iter().enumerate() { - let next_val = next_key[j]; - // compute 2^limb_bits + next_val - val - 1 - let check_less_than = (1 << self.limb_bits()) + next_val - val - 1; - checks.push(F::from_canonical_u32(check_less_than)); - // the lower limb_bits bits of the check value - lower_bits.push(F::from_canonical_u32( - check_less_than & ((1 << self.limb_bits()) - 1), - )); - // we also need the u32 value to compute the decomposition later - lower_bits_u32.push(check_less_than & ((1 << self.limb_bits()) - 1)); - // the (n + 1)st bit of the check value, will be 1 if a < b - upper_bit.push(F::from_canonical_u32(check_less_than >> self.limb_bits())); - - // the difference between the two limbs - let curr_diff = - F::from_canonical_u32(next_val) - F::from_canonical_u32(val); - diff.push(curr_diff); - - // compute the equal indicator and inverses - if next_val == val { - is_zero.push(F::one()); - inverses.push((curr_diff + F::one()).inverse()); - } else { - is_zero.push(F::zero()); - inverses.push(curr_diff.inverse()); - } - } - } else { - for _ in 0..self.key_vec_len() { - checks.push(F::zero()); - lower_bits.push(F::zero()); - lower_bits_u32.push(0); - upper_bit.push(F::zero()); - diff.push(F::zero()); - is_zero.push(F::zero()); - inverses.push(F::zero()); - } + let mut rows: Vec = vec![]; + for i in 0..self.key_vec_len() { + let key = self.keys()[i].clone(); + let next_key: Vec = if i == self.key_vec_len() - 1 { + vec![0; self.key_vec_len()] + } else { + self.keys()[i + 1].clone() + }; + + let less_than_trace = LocalTraceInstructions::generate_trace_row( + &self.less_than_chip, + (key.clone(), next_key.clone()), + ) + .flatten(); + + let mut key_decomp_trace: Vec = vec![]; + // decompose each limb into sublimbs of size self.decomp() bits + for &val in key.iter() { + for i in 0..num_limbs { + let bits = (val >> (i * self.decomp())) & ((1 << self.decomp()) - 1); + key_decomp_trace.push(F::from_canonical_u32(bits)); + self.less_than_chip.range_checker_gate.add_count(bits); } - - row.extend(checks); - row.extend(lower_bits); - row.extend(upper_bit); - - // decompose each element of lower_bits so we can range check that the element - // has at most limb_bits bits - for val in lower_bits_u32 { - for j in 0..num_limbs { - let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); - row.push(F::from_canonical_u32(bits)); - self.range_checker_gate.add_count(bits); - } - let bits = - (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); - if (bits << last_limb_shift) < MAX { - self.range_checker_gate.add_count(bits << last_limb_shift); - } - row.push(F::from_canonical_u32(bits << last_limb_shift)); + // the last sublimb should be of size self.limb_bits() % self.decomp() bits, + // so we need to shift it to constrain this + let bits = (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + if (bits << last_limb_shift) < MAX { + self.less_than_chip + .range_checker_gate + .add_count(bits << last_limb_shift); } + key_decomp_trace.push(F::from_canonical_u32(bits << last_limb_shift)); + } - row.extend(diff); - row.extend(is_zero); - row.extend(inverses); + let mut row: Vec = less_than_trace[0..self.key_vec_len()].to_vec(); + row.extend_from_slice(&key_decomp_trace); + row.extend_from_slice(&less_than_trace[self.key_vec_len()..]); - row - }) - .collect::>(); + rows.extend_from_slice(&row); + } - RowMajorMatrix::new(rows.concat(), num_cols) + RowMajorMatrix::new(rows, num_cols) } } diff --git a/chips/src/sub_chip.rs b/chips/src/sub_chip.rs new file mode 100644 index 0000000000..701f88e61b --- /dev/null +++ b/chips/src/sub_chip.rs @@ -0,0 +1,34 @@ +use afs_stark_backend::interaction::Interaction; +use p3_air::AirBuilder; +use p3_field::Field; + +pub trait AirConfig { + /// Column struct over generic type + type Cols; +} + +/// Trait with associated types intended to allow re-use of constraint logic +/// inside other AIRs. +pub trait SubAir: AirConfig { + type ColsPassed; + + fn eval(&self, builder: &mut AB, cols: Self::ColsPassed); +} + +pub trait LocalTraceInstructions: AirConfig { + /// Logical inputs needed to generate a single row of the trace. + type LocalInput; + + fn generate_trace_row(&self, local_input: Self::LocalInput) -> Self::Cols; +} + +pub trait SubAirWithInteractions: AirConfig { + fn sends(&self, col_indices: Self::Cols) -> Vec> { + let _ = col_indices; + vec![] + } + fn receives(&self, col_indices: Self::Cols) -> Vec> { + let _ = col_indices; + vec![] + } +} diff --git a/chips/src/xor_bits/air.rs b/chips/src/xor_bits/air.rs index 04bbcf3952..c26d7ebbdc 100644 --- a/chips/src/xor_bits/air.rs +++ b/chips/src/xor_bits/air.rs @@ -1,11 +1,12 @@ use std::borrow::Borrow; use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; -use p3_field::Field; +use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; -use super::columns::XorCols; -use super::XorBitsChip; +use crate::sub_chip::{AirConfig, SubAir}; + +use super::{columns::XorCols, XorBitsChip}; impl BaseAir for XorBitsChip { fn width(&self) -> usize { @@ -26,6 +27,49 @@ where let xor_cols = XorCols::::from_slice(local); - self.impose_constraints(builder, xor_cols); + SubAir::eval(self, builder, xor_cols); + } +} + +impl AirConfig for XorBitsChip { + type Cols = XorCols; +} + +/// Imposes AIR constraints within each row of the trace +/// Constrains x, y, z to be equal to their bit representation in x_bits, y_bits, z_bits. +/// For each x_bit[i], y_bit[i], and z_bit[i], constraints x_bit[i] + y_bit[i] - 2 * x_bit[i] * y_bit[i] == z_bit[i], +/// which is equivalent to ensuring that x_bit[i] ^ y_bit[i] == z_bit[i]. +/// Overall, this ensures that x^y == z. +impl SubAir for XorBitsChip { + type ColsPassed = XorCols; + + fn eval(&self, builder: &mut AB, cols: Self::ColsPassed) { + let xor_cols = &cols; + + let mut x_from_bits: AB::Expr = AB::Expr::zero(); + for i in 0..N { + x_from_bits += xor_cols.x_bits[i] * AB::Expr::from_canonical_u64(1 << i); + } + builder.assert_eq(x_from_bits, xor_cols.io.x); + + let mut y_from_bits: AB::Expr = AB::Expr::zero(); + for i in 0..N { + y_from_bits += xor_cols.y_bits[i] * AB::Expr::from_canonical_u64(1 << i); + } + builder.assert_eq(y_from_bits, xor_cols.io.y); + + let mut z_from_bits: AB::Expr = AB::Expr::zero(); + for i in 0..N { + z_from_bits += xor_cols.z_bits[i] * AB::Expr::from_canonical_u64(1 << i); + } + builder.assert_eq(z_from_bits, xor_cols.io.z); + + for i in 0..N { + builder.assert_eq( + xor_cols.x_bits[i] + xor_cols.y_bits[i] + - AB::Expr::two() * xor_cols.x_bits[i] * xor_cols.y_bits[i], + xor_cols.z_bits[i], + ); + } } } diff --git a/chips/src/xor_bits/chip.rs b/chips/src/xor_bits/chip.rs index c3e8b0b537..34a4a5a408 100644 --- a/chips/src/xor_bits/chip.rs +++ b/chips/src/xor_bits/chip.rs @@ -1,15 +1,32 @@ use afs_stark_backend::interaction::{Chip, Interaction}; -use p3_field::PrimeField64; +use p3_air::VirtualPairCol; +use p3_field::{Field, PrimeField64}; + +use crate::sub_chip::SubAirWithInteractions; use super::{columns::XorCols, XorBitsChip}; impl Chip for XorBitsChip { fn receives(&self) -> Vec> { let num_cols = XorCols::::get_width(); - let all_cols = (0..num_cols).collect::>(); + let indices = (0..num_cols).collect::>(); + let col_indices = XorCols::::from_slice(&indices); - let cols_to_receive = XorCols::::cols_to_receive(&all_cols); + SubAirWithInteractions::receives(self, col_indices) + } +} - vec![self.receives_custom(cols_to_receive)] +impl SubAirWithInteractions for XorBitsChip { + fn receives(&self, col_indices: XorCols) -> Vec> { + let io_indices = col_indices.io; + vec![Interaction { + fields: vec![ + VirtualPairCol::single_main(io_indices.x), + VirtualPairCol::single_main(io_indices.y), + VirtualPairCol::single_main(io_indices.z), + ], + count: VirtualPairCol::constant(F::one()), + argument_index: self.bus_index(), + }] } } diff --git a/chips/src/xor_bits/mod.rs b/chips/src/xor_bits/mod.rs index 86b0a2c976..5750460cf4 100644 --- a/chips/src/xor_bits/mod.rs +++ b/chips/src/xor_bits/mod.rs @@ -1,13 +1,5 @@ -use afs_stark_backend::interaction::Interaction; -use columns::XorCols; -use p3_air::AirBuilder; -use p3_air::VirtualPairCol; -use p3_field::AbstractField; -use p3_field::PrimeField64; use parking_lot::Mutex; -use self::columns::XorIOCols; - pub mod air; pub mod chip; pub mod columns; @@ -43,55 +35,4 @@ impl XorBitsChip { pairs_locked.push((a, b)); self.calc_xor(a, b) } - - /// Imposes AIR constraints within each row of the trace - /// Constraints x, y, z to be equal to their bit representation in x_bits, y_bits, z_bits. - /// For each x_bit[i], y_bit[i], and z_bit[i], constraints x_bit[i] + y_bit[i] - 2 * x_bit[i] * y_bit[i] == z_bit[i], - /// which is equivalent to ensuring that x_bit[i] ^ y_bit[i] == z_bit[i]. - /// Overall, this ensures that x^y == z. - pub fn impose_constraints( - &self, - builder: &mut AB, - xor_cols: XorCols, - ) where - AB::Var: Clone, - { - let mut x_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - x_from_bits += xor_cols.x_bits[i] * AB::Expr::from_canonical_u64(1 << i); - } - builder.assert_eq(x_from_bits, xor_cols.io.x); - - let mut y_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - y_from_bits += xor_cols.y_bits[i] * AB::Expr::from_canonical_u64(1 << i); - } - builder.assert_eq(y_from_bits, xor_cols.io.y); - - let mut z_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - z_from_bits += xor_cols.z_bits[i] * AB::Expr::from_canonical_u64(1 << i); - } - builder.assert_eq(z_from_bits, xor_cols.io.z); - - for i in 0..N { - builder.assert_eq( - xor_cols.x_bits[i] + xor_cols.y_bits[i] - - AB::Expr::two() * xor_cols.x_bits[i] * xor_cols.y_bits[i], - xor_cols.z_bits[i], - ); - } - } - - pub fn receives_custom(&self, cols: XorIOCols) -> Interaction { - Interaction { - fields: vec![ - VirtualPairCol::single_main(cols.x), - VirtualPairCol::single_main(cols.y), - VirtualPairCol::single_main(cols.z), - ], - count: VirtualPairCol::constant(F::one()), - argument_index: self.bus_index(), - } - } } diff --git a/chips/src/xor_bits/trace.rs b/chips/src/xor_bits/trace.rs index 9453a00373..6a0246db33 100644 --- a/chips/src/xor_bits/trace.rs +++ b/chips/src/xor_bits/trace.rs @@ -1,7 +1,12 @@ -use p3_field::PrimeField64; +use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; -use super::{columns::XorCols, XorBitsChip}; +use crate::sub_chip::LocalTraceInstructions; + +use super::{ + columns::{XorCols, XorIOCols}, + XorBitsChip, +}; impl XorBitsChip { pub fn generate_trace(&self) -> RowMajorMatrix { @@ -12,23 +17,32 @@ impl XorBitsChip { let rows = pairs_locked .iter() - .map(|(x, y)| { - let z = self.calc_xor(*x, *y); - - let mut row = vec![ - F::from_canonical_u32(*x), - F::from_canonical_u32(*y), - F::from_canonical_u32(z), - ]; - - row.extend((0..N).map(|i| (x >> i) & 1).map(F::from_canonical_u32)); - row.extend((0..N).map(|i| (y >> i) & 1).map(F::from_canonical_u32)); - row.extend((0..N).map(|i| (z >> i) & 1).map(F::from_canonical_u32)); + .flat_map(|(x, y)| self.generate_trace_row((*x, *y)).flatten()) + .collect(); - row - }) - .collect::>(); + RowMajorMatrix::new(rows, num_xor_cols) + } +} - RowMajorMatrix::new(rows.concat(), num_xor_cols) +impl LocalTraceInstructions for XorBitsChip { + /// The input is (x, y) to be XOR-ed. + type LocalInput = (u32, u32); + + fn generate_trace_row(&self, (x, y): (u32, u32)) -> Self::Cols { + let z = self.calc_xor(x, y); + let [x_bits, y_bits, z_bits] = [x, y, z].map(|x| { + (0..N) + .map(|i| (x >> i) & 1) + .map(F::from_canonical_u32) + .collect() + }); + let [x, y, z] = [x, y, z].map(F::from_canonical_u32); + + XorCols { + io: XorIOCols { x, y, z }, + x_bits, + y_bits, + z_bits, + } } } From 06ada4b61fd36f3a6ab4f3c6e0ade359627ccbbe Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 31 May 2024 15:06:01 -0400 Subject: [PATCH 04/46] feat: less_than subchip refactored --- chips/Cargo.toml | 1 + chips/src/less_than/air.rs | 11 ++++---- chips/src/sorted_limbs/air.rs | 7 +++++- chips/src/sub_chip.rs | 12 ++++++--- chips/src/xor_bits/air.rs | 47 ++++++++++++++--------------------- chips/src/xor_bits/columns.rs | 25 ++++++++++++------- chips/src/xor_bits/trace.rs | 10 +++++--- 7 files changed, 62 insertions(+), 51 deletions(-) diff --git a/chips/Cargo.toml b/chips/Cargo.toml index a5a2a2504c..f16a1c0afe 100644 --- a/chips/Cargo.toml +++ b/chips/Cargo.toml @@ -20,6 +20,7 @@ afs-derive = { path = "../derive" } afs-test-utils = { path = "../test-utils" } parking_lot = "0.12.2" tracing = "0.1.40" +itertools = "0.13.0" [dev-dependencies] p3-uni-stark = { workspace = true } diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs index 97ba13d267..e86b8b1947 100644 --- a/chips/src/less_than/air.rs +++ b/chips/src/less_than/air.rs @@ -33,7 +33,7 @@ where self.key_vec_len(), ); - SubAir::eval(self, builder, vec![local_cols]); + SubAir::eval(self, builder, (), vec![local_cols]); } } @@ -43,11 +43,12 @@ impl AirConfig for LessThanChip { // sub-chip with constraints to check whether one key is less than the next (row-wise) impl SubAir for LessThanChip { - type ColsPassed = Vec>; + type IoView = (); + type AuxView = Vec>; - fn eval(&self, builder: &mut AB, cols: Self::ColsPassed) { - let local_cols = &cols[0]; - let next_cols = &cols[1]; + fn eval(&self, builder: &mut AB, _io: Self::IoView, aux: Self::AuxView) { + let local_cols = &aux[0]; + let next_cols = &aux[1]; // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); diff --git a/chips/src/sorted_limbs/air.rs b/chips/src/sorted_limbs/air.rs index 665454cb6a..a7c8f33bda 100644 --- a/chips/src/sorted_limbs/air.rs +++ b/chips/src/sorted_limbs/air.rs @@ -79,6 +79,11 @@ where ); // constrain the current row is less than the next row - SubAir::eval(&self.less_than_chip, builder, vec![local_cols, next_cols]); + SubAir::eval( + &self.less_than_chip, + builder, + (), + vec![local_cols, next_cols], + ); } } diff --git a/chips/src/sub_chip.rs b/chips/src/sub_chip.rs index 701f88e61b..40c6bd6f33 100644 --- a/chips/src/sub_chip.rs +++ b/chips/src/sub_chip.rs @@ -9,12 +9,18 @@ pub trait AirConfig { /// Trait with associated types intended to allow re-use of constraint logic /// inside other AIRs. -pub trait SubAir: AirConfig { - type ColsPassed; +pub trait SubAir { + /// View of the parts of matrix relevant for IO. + /// Typically this is either 'local' IO columns or 'local' and 'next' IO columns. + type IoView; + /// View of auxiliary parts of matrix necessary for constraint evaluation. + /// Typically this is either a subset of 'local' columns or subset of 'local' and 'next' columns. + type AuxView; - fn eval(&self, builder: &mut AB, cols: Self::ColsPassed); + fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView); } +// This is a helper for simple trace row generation. Not every AIR will need this. pub trait LocalTraceInstructions: AirConfig { /// Logical inputs needed to generate a single row of the trace. type LocalInput; diff --git a/chips/src/xor_bits/air.rs b/chips/src/xor_bits/air.rs index c26d7ebbdc..e10f652e83 100644 --- a/chips/src/xor_bits/air.rs +++ b/chips/src/xor_bits/air.rs @@ -1,12 +1,16 @@ -use std::borrow::Borrow; +use std::{borrow::Borrow, iter::zip}; +use itertools::Itertools; use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; use crate::sub_chip::{AirConfig, SubAir}; -use super::{columns::XorCols, XorBitsChip}; +use super::{ + columns::{XorBitCols, XorCols, XorIOCols}, + XorBitsChip, +}; impl BaseAir for XorBitsChip { fn width(&self) -> usize { @@ -27,7 +31,7 @@ where let xor_cols = XorCols::::from_slice(local); - SubAir::eval(self, builder, xor_cols); + SubAir::eval(self, builder, xor_cols.io, xor_cols.bits); } } @@ -41,35 +45,20 @@ impl AirConfig for XorBitsChip { /// which is equivalent to ensuring that x_bit[i] ^ y_bit[i] == z_bit[i]. /// Overall, this ensures that x^y == z. impl SubAir for XorBitsChip { - type ColsPassed = XorCols; + type IoView = XorIOCols; + type AuxView = XorBitCols; - fn eval(&self, builder: &mut AB, cols: Self::ColsPassed) { - let xor_cols = &cols; - - let mut x_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - x_from_bits += xor_cols.x_bits[i] * AB::Expr::from_canonical_u64(1 << i); - } - builder.assert_eq(x_from_bits, xor_cols.io.x); - - let mut y_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - y_from_bits += xor_cols.y_bits[i] * AB::Expr::from_canonical_u64(1 << i); - } - builder.assert_eq(y_from_bits, xor_cols.io.y); - - let mut z_from_bits: AB::Expr = AB::Expr::zero(); - for i in 0..N { - z_from_bits += xor_cols.z_bits[i] * AB::Expr::from_canonical_u64(1 << i); + fn eval(&self, builder: &mut AB, io: Self::IoView, bits: Self::AuxView) { + for (x, bit_decomp) in zip([io.x, io.y, io.z], [&bits.x, &bits.y, &bits.z]) { + let mut from_bits = AB::Expr::zero(); + for (i, &bit) in bit_decomp.iter().enumerate() { + from_bits += bit * AB::Expr::from_canonical_u32(1 << i); + } + builder.assert_eq(from_bits, x); } - builder.assert_eq(z_from_bits, xor_cols.io.z); - for i in 0..N { - builder.assert_eq( - xor_cols.x_bits[i] + xor_cols.y_bits[i] - - AB::Expr::two() * xor_cols.x_bits[i] * xor_cols.y_bits[i], - xor_cols.z_bits[i], - ); + for ((x, y), z) in bits.x.into_iter().zip_eq(bits.y).zip_eq(bits.z) { + builder.assert_eq(x + y - AB::Expr::two() * x * y, z); } } } diff --git a/chips/src/xor_bits/columns.rs b/chips/src/xor_bits/columns.rs index a6c8179aac..baff6118bd 100644 --- a/chips/src/xor_bits/columns.rs +++ b/chips/src/xor_bits/columns.rs @@ -7,11 +7,16 @@ pub struct XorIOCols { pub z: T, } +/// Bit decompositions +pub struct XorBitCols { + pub x: Vec, + pub y: Vec, + pub z: Vec, +} + pub struct XorCols { pub io: XorIOCols, - pub x_bits: Vec, - pub y_bits: Vec, - pub z_bits: Vec, + pub bits: XorBitCols, } impl XorCols { @@ -26,9 +31,11 @@ impl XorCols { Self { io: XorIOCols { x, y, z }, - x_bits, - y_bits, - z_bits, + bits: XorBitCols { + x: x_bits, + y: y_bits, + z: z_bits, + }, } } @@ -37,9 +44,9 @@ impl XorCols { flattened.extend_from_slice(&[self.io.x.clone(), self.io.y.clone(), self.io.z.clone()]); - flattened.extend_from_slice(&self.x_bits); - flattened.extend_from_slice(&self.y_bits); - flattened.extend_from_slice(&self.z_bits); + flattened.extend_from_slice(&self.bits.x); + flattened.extend_from_slice(&self.bits.y); + flattened.extend_from_slice(&self.bits.z); flattened } diff --git a/chips/src/xor_bits/trace.rs b/chips/src/xor_bits/trace.rs index 6a0246db33..cbf1567850 100644 --- a/chips/src/xor_bits/trace.rs +++ b/chips/src/xor_bits/trace.rs @@ -4,7 +4,7 @@ use p3_matrix::dense::RowMajorMatrix; use crate::sub_chip::LocalTraceInstructions; use super::{ - columns::{XorCols, XorIOCols}, + columns::{XorBitCols, XorCols, XorIOCols}, XorBitsChip, }; @@ -40,9 +40,11 @@ impl LocalTraceInstructions for XorBitsChip XorCols { io: XorIOCols { x, y, z }, - x_bits, - y_bits, - z_bits, + bits: XorBitCols { + x: x_bits, + y: y_bits, + z: z_bits, + }, } } } From b349571dfe0437ed7c18f1031aed70fdd41758cf Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 31 May 2024 15:50:38 -0400 Subject: [PATCH 05/46] feat: rename SortedLimbsChip to AssertSortedChip and write LessThanChip tests --- .../{sorted_limbs => assert_sorted}/air.rs | 18 +-- .../{sorted_limbs => assert_sorted}/chip.rs | 10 +- .../columns.rs | 26 ++-- .../{sorted_limbs => assert_sorted}/mod.rs | 10 +- .../tests/mod.rs | 20 +-- .../{sorted_limbs => assert_sorted}/trace.rs | 6 +- chips/src/less_than/air.rs | 57 +++++--- chips/src/less_than/chip.rs | 2 +- chips/src/less_than/columns.rs | 42 ++++-- chips/src/less_than/mod.rs | 9 +- chips/src/less_than/tests/mod.rs | 133 ++++++++++++++++++ chips/src/less_than/trace.rs | 13 +- chips/src/lib.rs | 2 +- 13 files changed, 259 insertions(+), 89 deletions(-) rename chips/src/{sorted_limbs => assert_sorted}/air.rs (84%) rename chips/src/{sorted_limbs => assert_sorted}/chip.rs (81%) rename chips/src/{sorted_limbs => assert_sorted}/columns.rs (91%) rename chips/src/{sorted_limbs => assert_sorted}/mod.rs (90%) rename chips/src/{sorted_limbs => assert_sorted}/tests/mod.rs (89%) rename chips/src/{sorted_limbs => assert_sorted}/trace.rs (92%) create mode 100644 chips/src/less_than/tests/mod.rs diff --git a/chips/src/sorted_limbs/air.rs b/chips/src/assert_sorted/air.rs similarity index 84% rename from chips/src/sorted_limbs/air.rs rename to chips/src/assert_sorted/air.rs index a7c8f33bda..865761d331 100644 --- a/chips/src/sorted_limbs/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -7,16 +7,16 @@ use p3_matrix::Matrix; use crate::less_than::columns::LessThanCols; use crate::sub_chip::SubAir; -use super::columns::SortedLimbsCols; -use super::SortedLimbsChip; +use super::columns::AssertedSortedCols; +use super::AssertedSortedChip; -impl BaseAir for SortedLimbsChip { +impl BaseAir for AssertedSortedChip { fn width(&self) -> usize { - SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) } } -impl Air for SortedLimbsChip +impl Air for AssertedSortedChip where AB: AirBuilder, AB::Var: Clone, @@ -27,7 +27,7 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); - let local_cols = SortedLimbsCols::::from_slice( + let local_cols = AssertedSortedCols::::from_slice( local, self.limb_bits(), self.decomp(), @@ -54,7 +54,7 @@ where * AB::Expr::from_canonical_u64(1 << last_limb_shift); builder.assert_eq(local_cols.keys_decomp[i][num_limbs], shifted_val); - builder.assert_eq(key_from_limbs, local_cols.less_than_cols.key[i]); + builder.assert_eq(key_from_limbs, local_cols.less_than_cols.io.key[i]); } // generate LessThanCols struct for current row and next row @@ -82,8 +82,8 @@ where SubAir::eval( &self.less_than_chip, builder, - (), - vec![local_cols, next_cols], + vec![local_cols.io, next_cols.io], + local_cols.aux, ); } } diff --git a/chips/src/sorted_limbs/chip.rs b/chips/src/assert_sorted/chip.rs similarity index 81% rename from chips/src/sorted_limbs/chip.rs rename to chips/src/assert_sorted/chip.rs index d970b70c2c..531b4f1a7b 100644 --- a/chips/src/sorted_limbs/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -1,19 +1,19 @@ use crate::sub_chip::SubAirWithInteractions; -use super::columns::SortedLimbsCols; +use super::columns::AssertedSortedCols; use afs_stark_backend::interaction::{Chip, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; -use super::SortedLimbsChip; +use super::AssertedSortedChip; -impl Chip for SortedLimbsChip { +impl Chip for AssertedSortedChip { fn sends(&self) -> Vec> { let num_cols = - SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = SortedLimbsCols::::cols_numbered( + let cols_numbered = AssertedSortedCols::::cols_numbered( &all_cols, self.limb_bits(), self.decomp(), diff --git a/chips/src/sorted_limbs/columns.rs b/chips/src/assert_sorted/columns.rs similarity index 91% rename from chips/src/sorted_limbs/columns.rs rename to chips/src/assert_sorted/columns.rs index 7ba71313ee..a5535dea0b 100644 --- a/chips/src/sorted_limbs/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -1,16 +1,16 @@ use afs_derive::AlignedBorrow; -use crate::less_than::columns::LessThanCols; +use crate::less_than::columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}; -// Since SortedLimbsChip contains a LessThanChip subchip, a subset of the columns are those of the +// Since AssertedSortedChip contains a LessThanChip subchip, a subset of the columns are those of the // LessThanChip -#[derive(Default, AlignedBorrow)] -pub struct SortedLimbsCols { +#[derive(AlignedBorrow)] +pub struct AssertedSortedCols { pub keys_decomp: Vec>, pub less_than_cols: LessThanCols, } -impl SortedLimbsCols { +impl AssertedSortedCols { pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; @@ -72,8 +72,8 @@ impl SortedLimbsCols { // note that this sum will always be nonzero so the inverse will exist let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); - let less_than_cols = LessThanCols { - key, + let io = LessThanIOCols { key }; + let aux = LessThanAuxCols { intermed_sum, lower_bits, upper_bit, @@ -83,6 +83,8 @@ impl SortedLimbsCols { inverses, }; + let less_than_cols = LessThanCols { io, aux }; + Self { keys_decomp, less_than_cols, @@ -122,7 +124,7 @@ impl SortedLimbsCols { limb_bits: usize, decomp: usize, key_vec_len: usize, - ) -> SortedLimbsCols { + ) -> AssertedSortedCols { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; let mut cur_start_idx = 0; @@ -175,8 +177,8 @@ impl SortedLimbsCols { // the next key_vec_len elements are the inverses let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); - let less_than_cols = LessThanCols { - key, + let io = LessThanIOCols { key }; + let aux = LessThanAuxCols { intermed_sum, lower_bits, upper_bit, @@ -186,7 +188,9 @@ impl SortedLimbsCols { inverses, }; - SortedLimbsCols { + let less_than_cols = LessThanCols { io, aux }; + + AssertedSortedCols { keys_decomp, less_than_cols, } diff --git a/chips/src/sorted_limbs/mod.rs b/chips/src/assert_sorted/mod.rs similarity index 90% rename from chips/src/sorted_limbs/mod.rs rename to chips/src/assert_sorted/mod.rs index 4c205865f3..2dd75b0f16 100644 --- a/chips/src/sorted_limbs/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,7 +1,7 @@ use crate::less_than::LessThanChip; use afs_stark_backend::interaction::Interaction; -use columns::SortedLimbsCols; +use columns::AssertedSortedCols; use p3_air::VirtualPairCol; use p3_field::PrimeField64; @@ -22,15 +22,15 @@ pub mod trace; * gate can take MAX up to 2^20, we further decompose each limb into sublimbs * of size decomp bits. * - * The SortedLimbsChip contains a LessThanChip subchip, which is used to constrain + * The AssertedSortedChip contains a LessThanChip subchip, which is used to constrain * that the rows are sorted lexicographically. */ #[derive(Default)] -pub struct SortedLimbsChip { +pub struct AssertedSortedChip { less_than_chip: LessThanChip, } -impl SortedLimbsChip { +impl AssertedSortedChip { pub fn new( bus_index: usize, limb_bits: usize, @@ -71,7 +71,7 @@ impl SortedLimbsChip { pub fn sends_custom( &self, - cols: &SortedLimbsCols, + cols: &AssertedSortedCols, ) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb diff --git a/chips/src/sorted_limbs/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs similarity index 89% rename from chips/src/sorted_limbs/tests/mod.rs rename to chips/src/assert_sorted/tests/mod.rs index 526e81af9d..1799a29eec 100644 --- a/chips/src/sorted_limbs/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -1,4 +1,4 @@ -use super::super::sorted_limbs; +use super::super::assert_sorted; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; @@ -7,7 +7,7 @@ use p3_baby_bear::BabyBear; use p3_matrix::dense::DenseMatrix; /** - * Testing strategy for the sorted limbs chip: + * Testing strategy for the assert sorted chip: * partition on limb_bits: * limb_bits < 20 * limb_bits >= 20 @@ -32,7 +32,7 @@ use p3_matrix::dense::DenseMatrix; // most limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_small_positive() { - use sorted_limbs::SortedLimbsChip; + use assert_sorted::AssertedSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 16; @@ -44,7 +44,7 @@ fn test_sorted_limbs_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; let sorted_limbs_chip = - SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -66,7 +66,7 @@ fn test_sorted_limbs_chip_small_positive() { // most limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_large_positive() { - use sorted_limbs::SortedLimbsChip; + use assert_sorted::AssertedSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; @@ -83,7 +83,7 @@ fn test_sorted_limbs_chip_large_positive() { ]; let sorted_limbs_chip = - SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -105,7 +105,7 @@ fn test_sorted_limbs_chip_large_positive() { // has more than limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_largelimb_negative() { - use sorted_limbs::SortedLimbsChip; + use assert_sorted::AssertedSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 10; @@ -123,7 +123,7 @@ fn test_sorted_limbs_chip_largelimb_negative() { ]; let sorted_limbs_chip = - SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -150,7 +150,7 @@ fn test_sorted_limbs_chip_largelimb_negative() { // most limb_bits bits, rows are not sorted lexicographically #[test] fn test_sorted_limbs_chip_unsorted_negative() { - use sorted_limbs::SortedLimbsChip; + use assert_sorted::AssertedSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; @@ -168,7 +168,7 @@ fn test_sorted_limbs_chip_unsorted_negative() { ]; let sorted_limbs_chip = - SortedLimbsChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip diff --git a/chips/src/sorted_limbs/trace.rs b/chips/src/assert_sorted/trace.rs similarity index 92% rename from chips/src/sorted_limbs/trace.rs rename to chips/src/assert_sorted/trace.rs index 7f4fafd214..78a835d808 100644 --- a/chips/src/sorted_limbs/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -3,12 +3,12 @@ use p3_matrix::dense::RowMajorMatrix; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::SortedLimbsCols, SortedLimbsChip}; +use super::{columns::AssertedSortedCols, AssertedSortedChip}; -impl SortedLimbsChip { +impl AssertedSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = - SortedLimbsCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs index e86b8b1947..135b95bab3 100644 --- a/chips/src/less_than/air.rs +++ b/chips/src/less_than/air.rs @@ -6,7 +6,10 @@ use p3_matrix::Matrix; use crate::sub_chip::{AirConfig, SubAir}; -use super::{columns::LessThanCols, LessThanChip}; +use super::{ + columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}, + LessThanChip, +}; impl BaseAir for LessThanChip { fn width(&self) -> usize { @@ -23,8 +26,9 @@ where let main = builder.main(); let _pis = builder.public_values(); - let (local, _next) = (main.row_slice(0), main.row_slice(1)); + let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); + let next: &[AB::Var] = (*next).borrow(); let local_cols = LessThanCols::::from_slice( local, @@ -33,7 +37,19 @@ where self.key_vec_len(), ); - SubAir::eval(self, builder, (), vec![local_cols]); + let next_cols = LessThanCols::::from_slice( + next, + self.limb_bits(), + self.decomp(), + self.key_vec_len(), + ); + + SubAir::eval( + self, + builder, + vec![local_cols.io, next_cols.io], + local_cols.aux, + ); } } @@ -43,20 +59,22 @@ impl AirConfig for LessThanChip { // sub-chip with constraints to check whether one key is less than the next (row-wise) impl SubAir for LessThanChip { - type IoView = (); - type AuxView = Vec>; + type IoView = Vec>; + type AuxView = LessThanAuxCols; + + fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { + let local_key = io[0].key.clone(); + let next_key = io[1].key.clone(); - fn eval(&self, builder: &mut AB, _io: Self::IoView, aux: Self::AuxView) { - let local_cols = &aux[0]; - let next_cols = &aux[1]; + let local_aux = &aux; // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let intermed_sum = local_cols.intermed_sum.clone(); - let lower_bits = local_cols.lower_bits.clone(); - let upper_bit = local_cols.upper_bit.clone(); - let lower_bits_decomp = local_cols.lower_bits_decomp.clone(); + let intermed_sum = local_aux.intermed_sum.clone(); + let lower_bits = local_aux.lower_bits.clone(); + let upper_bit = local_aux.upper_bit.clone(); + let lower_bits_decomp = local_aux.lower_bits_decomp.clone(); // we want to check these constraints for each row except the last one let mut when_transition = builder.when_transition(); @@ -65,9 +83,7 @@ impl SubAir for LessThanChip { // the correct range let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); - for (i, (key_local, key_next)) in - local_cols.key.iter().zip(next_cols.key.iter()).enumerate() - { + for (i, (key_local, key_next)) in local_key.iter().zip(next_key.iter()).enumerate() { // this is the desired intermediate value (i.e. 2^limb_bits + b - a - 1) let intermed_val = *key_next - *key_local + AB::Expr::from_canonical_u64(1 << self.limb_bits()) @@ -83,8 +99,7 @@ impl SubAir for LessThanChip { // constrain that diff is the difference between the two elements of consecutive rows let diff = *key_next - *key_local; - //when_transition.assert_zero(local_cols.diff[i]); - when_transition.assert_eq(diff, local_cols.diff[i]); + when_transition.assert_eq(diff, local_aux.diff[i]); } for i in 0..self.key_vec_len() { @@ -110,9 +125,9 @@ impl SubAir for LessThanChip { } for i in 0..self.key_vec_len() { - let diff = local_cols.diff[i]; - let is_equal = local_cols.is_zero[i]; - let inverse = local_cols.inverses[i]; + let diff = local_aux.diff[i]; + let is_equal = local_aux.is_zero[i]; + let inverse = local_aux.inverses[i]; // check that diff * is_equal = 0 when_transition.assert_zero(diff * is_equal); @@ -130,7 +145,7 @@ impl SubAir for LessThanChip { for (i, &upper_bit_value) in upper_bit.iter().enumerate() { let mut curr_expr: AB::Expr = upper_bit_value.into(); - for &is_zero_value in &local_cols.is_zero[i + 1..] { + for &is_zero_value in &local_aux.is_zero[i + 1..] { curr_expr *= is_zero_value.into(); } check_less_than += curr_expr; diff --git a/chips/src/less_than/chip.rs b/chips/src/less_than/chip.rs index df76a04840..44f09f9880 100644 --- a/chips/src/less_than/chip.rs +++ b/chips/src/less_than/chip.rs @@ -39,7 +39,7 @@ impl SubAirWithInteractions for LessThanChip for j in 0..(num_limbs + 1) { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main( - col_indices.lower_bits_decomp[i][j], + col_indices.aux.lower_bits_decomp[i][j], )], count: VirtualPairCol::constant(F::one()), argument_index: self.bus_index(), diff --git a/chips/src/less_than/columns.rs b/chips/src/less_than/columns.rs index b47d3ef2e1..5e39b9880e 100644 --- a/chips/src/less_than/columns.rs +++ b/chips/src/less_than/columns.rs @@ -1,8 +1,11 @@ use afs_derive::AlignedBorrow; #[derive(Default, AlignedBorrow)] -pub struct LessThanCols { +pub struct LessThanIOCols { pub key: Vec, +} + +pub struct LessThanAuxCols { pub intermed_sum: Vec, pub lower_bits: Vec, pub upper_bit: Vec, @@ -12,6 +15,11 @@ pub struct LessThanCols { pub inverses: Vec, } +pub struct LessThanCols { + pub io: LessThanIOCols, + pub aux: LessThanAuxCols, +} + impl LessThanCols { pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb @@ -65,8 +73,8 @@ impl LessThanCols { // note that this sum will always be nonzero so the inverse will exist let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); - Self { - key, + let io = LessThanIOCols { key }; + let aux = LessThanAuxCols { intermed_sum, lower_bits, upper_bit, @@ -74,21 +82,23 @@ impl LessThanCols { diff, is_zero, inverses, - } + }; + + Self { io, aux } } pub fn flatten(&self) -> Vec { let mut flattened = vec![]; - flattened.extend_from_slice(&self.key); - flattened.extend_from_slice(&self.intermed_sum); - flattened.extend_from_slice(&self.lower_bits); - flattened.extend_from_slice(&self.upper_bit); - for decomp_vec in &self.lower_bits_decomp { + flattened.extend_from_slice(&self.io.key); + flattened.extend_from_slice(&self.aux.intermed_sum); + flattened.extend_from_slice(&self.aux.lower_bits); + flattened.extend_from_slice(&self.aux.upper_bit); + for decomp_vec in &self.aux.lower_bits_decomp { flattened.extend_from_slice(decomp_vec); } - flattened.extend_from_slice(&self.diff); - flattened.extend_from_slice(&self.is_zero); - flattened.extend_from_slice(&self.inverses); + flattened.extend_from_slice(&self.aux.diff); + flattened.extend_from_slice(&self.aux.is_zero); + flattened.extend_from_slice(&self.aux.inverses); flattened } @@ -169,8 +179,8 @@ impl LessThanCols { // the next key_vec_len elements are the inverses let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); - LessThanCols { - key, + let io = LessThanIOCols { key }; + let aux = LessThanAuxCols { intermed_sum, lower_bits, upper_bit, @@ -178,6 +188,8 @@ impl LessThanCols { diff, is_zero, inverses, - } + }; + + LessThanCols { io, aux } } } diff --git a/chips/src/less_than/mod.rs b/chips/src/less_than/mod.rs index e6f6005cbb..9e688ae3e5 100644 --- a/chips/src/less_than/mod.rs +++ b/chips/src/less_than/mod.rs @@ -1,5 +1,8 @@ use crate::range_gate::RangeCheckerGateChip; +#[cfg(test)] +pub mod tests; + pub mod air; pub mod chip; pub mod columns; @@ -8,11 +11,7 @@ pub mod trace; /** * This Chip constrains that consecutive rows are sorted lexicographically. * - * Each row consists of a key decomposed into limbs, and the chip constrains - * each limb has at most limb_bits bits, where limb_bits is at most 31. It - * does this by interacting with a RangeCheckerGateChip. Because the range checker - * gate can take MAX up to 2^20, we further decompose each limb into sublimbs - * of size decomp bits. + * Each row consists of a key decomposed into limbs with at most limb_bits bits */ #[derive(Default)] pub struct LessThanChip { diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs new file mode 100644 index 0000000000..b2779e9c83 --- /dev/null +++ b/chips/src/less_than/tests/mod.rs @@ -0,0 +1,133 @@ +use super::super::less_than; + +use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::verifier::VerificationError; +use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use p3_baby_bear::BabyBear; +use p3_matrix::dense::DenseMatrix; + +/** + * Testing strategy for the less than chip: + * partition on limb_bits: + * limb_bits < 20 + * limb_bits >= 20 + * partition on key_vec_len: + * key_vec_len < 4 + * key_vec_len >= 4 + * partition on decomp: + * limb_bits % decomp == 0 + * limb_bits % decomp != 0 + * partition on number of rows: + * number of rows < 4 + * number of rows >= 4 + * partition on size of each limb: + * each limb has at most limb_bits bits + * at least one limb has more than limb_bits bits + * partition on row order: + * rows are sorted lexicographically + * rows are not sorted lexicographically + */ + +// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, each limb has at +// most limb_bits bits, rows are sorted lexicographically +#[test] +fn test_less_than_chip_small_positive() { + use less_than::LessThanChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 2; + + const MAX: u32 = 1 << DECOMP; + + let requests = vec![vec![7784, 35423], vec![17558, 44832]]; + + let less_than_chip = + LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); + let less_than_range_chip_trace: DenseMatrix = + less_than_chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&less_than_chip, &less_than_chip.range_checker_gate], + vec![less_than_chip_trace, less_than_range_chip_trace], + ) + .expect("Verification failed"); +} + +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at +// most limb_bits bits, rows are sorted lexicographically +#[test] +fn test_less_than_chip_large_positive() { + use less_than::LessThanChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 30; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const MAX: u32 = 1 << DECOMP; + + let requests = vec![ + vec![35867, 318434, 12786, 44832], + vec![704210, 369315, 42421, 487111], + vec![370183, 37202, 729789, 783571], + vec![875005, 767547, 196209, 887921], + ]; + + let less_than_chip = + LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); + let less_than_range_chip_trace: DenseMatrix = + less_than_chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&less_than_chip, &less_than_chip.range_checker_gate], + vec![less_than_chip_trace, less_than_range_chip_trace], + ) + .expect("Verification failed"); +} + +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at +// most limb_bits bits, rows are not sorted lexicographically +#[test] +fn test_less_than_chip_unsorted_negative() { + use less_than::LessThanChip; + + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 30; + const DECOMP: usize = 8; + const KEY_VEC_LEN: usize = 4; + + const MAX: u32 = 1 << DECOMP; + + // the first and second rows are not in sorted order + let requests = vec![ + vec![704210, 369315, 42421, 44832], + vec![35867, 318434, 12786, 44832], + vec![370183, 37202, 729789, 783571], + vec![875005, 767547, 196209, 887921], + ]; + + let less_than_chip = + LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + + let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); + let less_than_range_chip_trace: DenseMatrix = + less_than_chip.range_checker_gate.generate_trace(); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis( + vec![&less_than_chip, &less_than_chip.range_checker_gate,], + vec![less_than_chip_trace, less_than_range_chip_trace], + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} diff --git a/chips/src/less_than/trace.rs b/chips/src/less_than/trace.rs index d4229f476d..9bf4350084 100644 --- a/chips/src/less_than/trace.rs +++ b/chips/src/less_than/trace.rs @@ -3,7 +3,10 @@ use p3_matrix::dense::RowMajorMatrix; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::LessThanCols, LessThanChip}; +use super::{ + columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}, + LessThanChip, +}; impl LessThanChip { pub fn generate_trace(&self) -> RowMajorMatrix { @@ -102,8 +105,10 @@ impl LocalTraceInstructions for LessThanChip } } - LessThanCols { + let io = LessThanIOCols { key: key.into_iter().map(F::from_canonical_u32).collect(), + }; + let aux = LessThanAuxCols { intermed_sum, lower_bits, upper_bit, @@ -111,6 +116,8 @@ impl LocalTraceInstructions for LessThanChip diff, is_zero, inverses, - } + }; + + LessThanCols { io, aux } } } diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 59472509bd..55773c2b15 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -1,3 +1,4 @@ +pub mod assert_sorted; pub mod keccak_permute; pub mod less_than; pub mod merkle_proof; @@ -6,7 +7,6 @@ pub mod page_read; /// Chip to range check a value has less than a fixed number of bits pub mod range; pub mod range_gate; -pub mod sorted_limbs; pub mod sub_chip; pub mod utils; pub mod xor_bits; From dbada8353a377ff82354ceaf1d05dfd09b708547 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 31 May 2024 17:36:36 -0400 Subject: [PATCH 06/46] chore: change name of assert sorted chip --- chips/src/assert_sorted/air.rs | 12 ++++++------ chips/src/assert_sorted/chip.rs | 10 +++++----- chips/src/assert_sorted/columns.rs | 10 +++++----- chips/src/assert_sorted/mod.rs | 10 +++++----- chips/src/assert_sorted/tests/mod.rs | 16 ++++++++-------- chips/src/assert_sorted/trace.rs | 6 +++--- chips/src/less_than/tests/mod.rs | 3 --- 7 files changed, 32 insertions(+), 35 deletions(-) diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index 865761d331..a7db1dd053 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -7,16 +7,16 @@ use p3_matrix::Matrix; use crate::less_than::columns::LessThanCols; use crate::sub_chip::SubAir; -use super::columns::AssertedSortedCols; -use super::AssertedSortedChip; +use super::columns::AssertSortedCols; +use super::AssertSortedChip; -impl BaseAir for AssertedSortedChip { +impl BaseAir for AssertSortedChip { fn width(&self) -> usize { - AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) } } -impl Air for AssertedSortedChip +impl Air for AssertSortedChip where AB: AirBuilder, AB::Var: Clone, @@ -27,7 +27,7 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); - let local_cols = AssertedSortedCols::::from_slice( + let local_cols = AssertSortedCols::::from_slice( local, self.limb_bits(), self.decomp(), diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index 531b4f1a7b..f5ee4a3ea6 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -1,19 +1,19 @@ use crate::sub_chip::SubAirWithInteractions; -use super::columns::AssertedSortedCols; +use super::columns::AssertSortedCols; use afs_stark_backend::interaction::{Chip, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; -use super::AssertedSortedChip; +use super::AssertSortedChip; -impl Chip for AssertedSortedChip { +impl Chip for AssertSortedChip { fn sends(&self) -> Vec> { let num_cols = - AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = AssertedSortedCols::::cols_numbered( + let cols_numbered = AssertSortedCols::::cols_numbered( &all_cols, self.limb_bits(), self.decomp(), diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index a5535dea0b..d4933a97f3 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -2,15 +2,15 @@ use afs_derive::AlignedBorrow; use crate::less_than::columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}; -// Since AssertedSortedChip contains a LessThanChip subchip, a subset of the columns are those of the +// Since AssertSortedChip contains a LessThanChip subchip, a subset of the columns are those of the // LessThanChip #[derive(AlignedBorrow)] -pub struct AssertedSortedCols { +pub struct AssertSortedCols { pub keys_decomp: Vec>, pub less_than_cols: LessThanCols, } -impl AssertedSortedCols { +impl AssertSortedCols { pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; @@ -124,7 +124,7 @@ impl AssertedSortedCols { limb_bits: usize, decomp: usize, key_vec_len: usize, - ) -> AssertedSortedCols { + ) -> AssertSortedCols { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (limb_bits + decomp - 1) / decomp; let mut cur_start_idx = 0; @@ -190,7 +190,7 @@ impl AssertedSortedCols { let less_than_cols = LessThanCols { io, aux }; - AssertedSortedCols { + AssertSortedCols { keys_decomp, less_than_cols, } diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 2dd75b0f16..00f48897d1 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,7 +1,7 @@ use crate::less_than::LessThanChip; use afs_stark_backend::interaction::Interaction; -use columns::AssertedSortedCols; +use columns::AssertSortedCols; use p3_air::VirtualPairCol; use p3_field::PrimeField64; @@ -22,15 +22,15 @@ pub mod trace; * gate can take MAX up to 2^20, we further decompose each limb into sublimbs * of size decomp bits. * - * The AssertedSortedChip contains a LessThanChip subchip, which is used to constrain + * The AssertSortedChip contains a LessThanChip subchip, which is used to constrain * that the rows are sorted lexicographically. */ #[derive(Default)] -pub struct AssertedSortedChip { +pub struct AssertSortedChip { less_than_chip: LessThanChip, } -impl AssertedSortedChip { +impl AssertSortedChip { pub fn new( bus_index: usize, limb_bits: usize, @@ -71,7 +71,7 @@ impl AssertedSortedChip { pub fn sends_custom( &self, - cols: &AssertedSortedCols, + cols: &AssertSortedCols, ) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index 1799a29eec..db1897c0e5 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -32,7 +32,7 @@ use p3_matrix::dense::DenseMatrix; // most limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_small_positive() { - use assert_sorted::AssertedSortedChip; + use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 16; @@ -44,7 +44,7 @@ fn test_sorted_limbs_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; let sorted_limbs_chip = - AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -66,7 +66,7 @@ fn test_sorted_limbs_chip_small_positive() { // most limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_large_positive() { - use assert_sorted::AssertedSortedChip; + use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; @@ -83,7 +83,7 @@ fn test_sorted_limbs_chip_large_positive() { ]; let sorted_limbs_chip = - AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -105,7 +105,7 @@ fn test_sorted_limbs_chip_large_positive() { // has more than limb_bits bits, rows are sorted lexicographically #[test] fn test_sorted_limbs_chip_largelimb_negative() { - use assert_sorted::AssertedSortedChip; + use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 10; @@ -123,7 +123,7 @@ fn test_sorted_limbs_chip_largelimb_negative() { ]; let sorted_limbs_chip = - AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip @@ -150,7 +150,7 @@ fn test_sorted_limbs_chip_largelimb_negative() { // most limb_bits bits, rows are not sorted lexicographically #[test] fn test_sorted_limbs_chip_unsorted_negative() { - use assert_sorted::AssertedSortedChip; + use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; @@ -168,7 +168,7 @@ fn test_sorted_limbs_chip_unsorted_negative() { ]; let sorted_limbs_chip = - AssertedSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index 78a835d808..c28790d1a2 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -3,12 +3,12 @@ use p3_matrix::dense::RowMajorMatrix; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::AssertedSortedCols, AssertedSortedChip}; +use super::{columns::AssertSortedCols, AssertSortedChip}; -impl AssertedSortedChip { +impl AssertSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = - AssertedSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs index b2779e9c83..726c904815 100644 --- a/chips/src/less_than/tests/mod.rs +++ b/chips/src/less_than/tests/mod.rs @@ -20,9 +20,6 @@ use p3_matrix::dense::DenseMatrix; * partition on number of rows: * number of rows < 4 * number of rows >= 4 - * partition on size of each limb: - * each limb has at most limb_bits bits - * at least one limb has more than limb_bits bits * partition on row order: * rows are sorted lexicographically * rows are not sorted lexicographically From b62172af6a60fb77b8f3c634b1cc0819b3a05ae5 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 31 May 2024 17:39:05 -0400 Subject: [PATCH 07/46] chore: fix names in tests for AssertSortedChip --- chips/src/assert_sorted/tests/mod.rs | 56 ++++++++++++++-------------- chips/src/less_than/tests/mod.rs | 9 ++--- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index db1897c0e5..978b18cad9 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -31,7 +31,7 @@ use p3_matrix::dense::DenseMatrix; // covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, each limb has at // most limb_bits bits, rows are sorted lexicographically #[test] -fn test_sorted_limbs_chip_small_positive() { +fn test_assert_sorted_chip_small_positive() { use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; @@ -43,21 +43,21 @@ fn test_sorted_limbs_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; - let sorted_limbs_chip = + let assert_sorted_chip = AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); - let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); + let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip .less_than_chip .range_checker_gate .generate_trace(); run_simple_test_no_pis( vec![ - &sorted_limbs_chip, - &sorted_limbs_chip.less_than_chip.range_checker_gate, + &assert_sorted_chip, + &assert_sorted_chip.less_than_chip.range_checker_gate, ], - vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], ) .expect("Verification failed"); } @@ -65,7 +65,7 @@ fn test_sorted_limbs_chip_small_positive() { // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at // most limb_bits bits, rows are sorted lexicographically #[test] -fn test_sorted_limbs_chip_large_positive() { +fn test_assert_sorted_chip_large_positive() { use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; @@ -82,21 +82,21 @@ fn test_sorted_limbs_chip_large_positive() { vec![875005, 767547, 196209, 887921], ]; - let sorted_limbs_chip = + let assert_sorted_chip = AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); - let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); + let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip .less_than_chip .range_checker_gate .generate_trace(); run_simple_test_no_pis( vec![ - &sorted_limbs_chip, - &sorted_limbs_chip.less_than_chip.range_checker_gate, + &assert_sorted_chip, + &assert_sorted_chip.less_than_chip.range_checker_gate, ], - vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], ) .expect("Verification failed"); } @@ -104,7 +104,7 @@ fn test_sorted_limbs_chip_large_positive() { // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, at least one limb // has more than limb_bits bits, rows are sorted lexicographically #[test] -fn test_sorted_limbs_chip_largelimb_negative() { +fn test_assert_sorted_chip_largelimb_negative() { use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; @@ -122,21 +122,21 @@ fn test_sorted_limbs_chip_largelimb_negative() { vec![128, 767, 196, 953], ]; - let sorted_limbs_chip = + let assert_sorted_chip = AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); - let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); + let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip .less_than_chip .range_checker_gate .generate_trace(); let result = run_simple_test_no_pis( vec![ - &sorted_limbs_chip, - &sorted_limbs_chip.less_than_chip.range_checker_gate, + &assert_sorted_chip, + &assert_sorted_chip.less_than_chip.range_checker_gate, ], - vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], ); assert_eq!( @@ -149,7 +149,7 @@ fn test_sorted_limbs_chip_largelimb_negative() { // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at // most limb_bits bits, rows are not sorted lexicographically #[test] -fn test_sorted_limbs_chip_unsorted_negative() { +fn test_assert_sorted_chip_unsorted_negative() { use assert_sorted::AssertSortedChip; const BUS_INDEX: usize = 0; @@ -167,11 +167,11 @@ fn test_sorted_limbs_chip_unsorted_negative() { vec![875005, 767547, 196209, 887921], ]; - let sorted_limbs_chip = + let assert_sorted_chip = AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); - let sorted_limbs_chip_trace: DenseMatrix = sorted_limbs_chip.generate_trace(); - let sorted_limbs_range_chip_trace: DenseMatrix = sorted_limbs_chip + let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); + let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip .less_than_chip .range_checker_gate .generate_trace(); @@ -182,10 +182,10 @@ fn test_sorted_limbs_chip_unsorted_negative() { assert_eq!( run_simple_test_no_pis( vec![ - &sorted_limbs_chip, - &sorted_limbs_chip.less_than_chip.range_checker_gate, + &assert_sorted_chip, + &assert_sorted_chip.less_than_chip.range_checker_gate, ], - vec![sorted_limbs_chip_trace, sorted_limbs_range_chip_trace], + vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs index 726c904815..6ec8cb501e 100644 --- a/chips/src/less_than/tests/mod.rs +++ b/chips/src/less_than/tests/mod.rs @@ -25,8 +25,7 @@ use p3_matrix::dense::DenseMatrix; * rows are not sorted lexicographically */ -// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, each limb has at -// most limb_bits bits, rows are sorted lexicographically +// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, rows are sorted lexicographically #[test] fn test_less_than_chip_small_positive() { use less_than::LessThanChip; @@ -54,8 +53,7 @@ fn test_less_than_chip_small_positive() { .expect("Verification failed"); } -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at -// most limb_bits bits, rows are sorted lexicographically +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are sorted lexicographically #[test] fn test_less_than_chip_large_positive() { use less_than::LessThanChip; @@ -88,8 +86,7 @@ fn test_less_than_chip_large_positive() { .expect("Verification failed"); } -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at -// most limb_bits bits, rows are not sorted lexicographically +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are not sorted lexicographically #[test] fn test_less_than_chip_unsorted_negative() { use less_than::LessThanChip; From 451f4c462cb4b549a344ea836b14f9d69e56506e Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:48:14 -0400 Subject: [PATCH 08/46] chore: address comments --- Cargo.lock | 43 +++++++++++++++ chips/Cargo.toml | 1 + chips/src/assert_sorted/air.rs | 49 ++++++++++------- chips/src/assert_sorted/chip.rs | 23 ++++---- chips/src/assert_sorted/columns.rs | 86 ------------------------------ chips/src/assert_sorted/mod.rs | 38 ++++++------- chips/src/assert_sorted/trace.rs | 36 ++++++++----- chips/src/less_than/air.rs | 73 +++++++++++-------------- chips/src/less_than/chip.rs | 21 ++++---- chips/src/less_than/columns.rs | 75 -------------------------- chips/src/less_than/mod.rs | 50 ++++++++--------- chips/src/less_than/trace.rs | 42 ++++++++------- 12 files changed, 220 insertions(+), 317 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4f584ef820..17ff15ad1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ dependencies = [ "afs-derive", "afs-stark-backend", "afs-test-utils", + "getset", "itertools 0.13.0", "p3-air", "p3-baby-bear", @@ -239,6 +240,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getset" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "indenter" version = "0.3.3" @@ -636,6 +649,30 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.82" @@ -962,6 +999,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/chips/Cargo.toml b/chips/Cargo.toml index c1a85e9604..8e7da3b73d 100644 --- a/chips/Cargo.toml +++ b/chips/Cargo.toml @@ -25,6 +25,7 @@ afs-test-utils = { path = "../test-utils" } parking_lot = "0.12.2" tracing = "0.1.40" itertools = "0.13.0" +getset = "0.1.2" [dev-dependencies] p3-uni-stark = { workspace = true } diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index a7db1dd053..a94514ba30 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -12,7 +12,11 @@ use super::AssertSortedChip; impl BaseAir for AssertSortedChip { fn width(&self) -> usize { - AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + AssertSortedCols::::get_width( + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), + ) } } @@ -29,24 +33,28 @@ where let local_cols = AssertSortedCols::::from_slice( local, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), ); - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let key_len = self.key_vec_len(); + let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() + - 1) + / *self.less_than_chip.air.decomp(); + let key_len = *self.less_than_chip.air.key_vec_len(); // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in // the correct range - let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + let last_limb_shift = (*self.less_than_chip.air.decomp() + - (*self.less_than_chip.air.limb_bits() % *self.less_than_chip.air.decomp())) + % *self.less_than_chip.air.decomp(); for i in 0..key_len { let mut key_from_limbs: AB::Expr = AB::Expr::zero(); // constrain that the decomposition is correct for j in 0..num_limbs { key_from_limbs += local_cols.keys_decomp[i][j] - * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); + * AB::Expr::from_canonical_u64(1 << (j * self.less_than_chip.air.decomp())); } // constrain that the shifted last sublimb is shifted correctly @@ -58,31 +66,32 @@ where } // generate LessThanCols struct for current row and next row - let mut local_slice: Vec = local[0..self.key_vec_len()].to_vec(); - local_slice.extend_from_slice(&local[(self.key_vec_len() * (num_limbs + 2))..]); + let mut local_slice: Vec = local[0..key_len].to_vec(); + local_slice.extend_from_slice(&local[key_len * (num_limbs + 2)..]); - let mut next_slice: Vec = next[0..self.key_vec_len()].to_vec(); - next_slice.extend_from_slice(&next[(self.key_vec_len() * (num_limbs + 2))..]); + let mut next_slice: Vec = next[0..key_len].to_vec(); + next_slice + .extend_from_slice(&next[(self.less_than_chip.air.key_vec_len() * (num_limbs + 2))..]); let local_cols = LessThanCols::::from_slice( &local_slice, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), ); let next_cols = LessThanCols::::from_slice( &next_slice, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), ); // constrain the current row is less than the next row SubAir::eval( - &self.less_than_chip, + &self.less_than_chip.air, builder, - vec![local_cols.io, next_cols.io], + [local_cols.io, next_cols.io], local_cols.aux, ); } diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index f5ee4a3ea6..331fe632ba 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -9,21 +9,26 @@ use super::AssertSortedChip; impl Chip for AssertSortedChip { fn sends(&self) -> Vec> { - let num_cols = - AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let num_cols = AssertSortedCols::::get_width( + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), + ); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = AssertSortedCols::::cols_numbered( + let cols_numbered = AssertSortedCols::::from_slice( &all_cols, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), ); let mut interactions: Vec> = vec![]; - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let num_keys = self.key_vec_len(); + let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() + - 1) + / *self.less_than_chip.air.decomp(); + let num_keys = *self.less_than_chip.air.key_vec_len(); // we will range check the decomposed limbs of the key for i in 0..num_keys { @@ -32,7 +37,7 @@ impl Chip for AssertSortedChip { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main(cols_numbered.keys_decomp[i][j])], count: VirtualPairCol::constant(F::one()), - argument_index: self.bus_index(), + argument_index: *self.less_than_chip.bus_index(), }); } } diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index d4933a97f3..b9b59e6701 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -31,12 +31,6 @@ impl AssertSortedCols { cur_start_idx = cur_end_idx; cur_end_idx += key_vec_len; - // the next key_vec_len elements are the values of 2^num_limbs + b - a - 1 where a and b are limbs - // on consecutive rows and b is the row after a - let intermed_sum = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); cur_start_idx = cur_end_idx; @@ -74,7 +68,6 @@ impl AssertSortedCols { let io = LessThanIOCols { key }; let aux = LessThanAuxCols { - intermed_sum, lower_bits, upper_bit, lower_bits_decomp, @@ -101,8 +94,6 @@ impl AssertSortedCols { // for the decomposed keys let num_limbs = (limb_bits + decomp - 1) / decomp; width += key_vec_len * (num_limbs + 1); - // for the 2^limb_bits + b - a values - width += key_vec_len; // for the lower_bits width += key_vec_len; // for the upper_bit @@ -118,81 +109,4 @@ impl AssertSortedCols { width } - - pub fn cols_numbered( - cols: &[usize], - limb_bits: usize, - decomp: usize, - key_vec_len: usize, - ) -> AssertSortedCols { - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (limb_bits + decomp - 1) / decomp; - let mut cur_start_idx = 0; - let mut cur_end_idx = key_vec_len; - - // the first key_vec_len elements are the key itself - let key = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed keys - let keys_decomp = cols[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the intermediate sum - let intermed_sum = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the lower_bits - let lower_bits = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the upper_bit - let upper_bit = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed lower_bits - let lower_bits_decomp = cols[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - - // the next key_vec_len elements are the difference between consecutive rows - let diff = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the indicator whether difference is zero - let is_zero = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the inverses - let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); - - let io = LessThanIOCols { key }; - let aux = LessThanAuxCols { - intermed_sum, - lower_bits, - upper_bit, - lower_bits_decomp, - diff, - is_zero, - inverses, - }; - - let less_than_cols = LessThanCols { io, aux }; - - AssertSortedCols { - keys_decomp, - less_than_cols, - } - } } diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 00f48897d1..8f571f4dfd 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -49,25 +49,25 @@ impl AssertSortedChip { } } - pub fn bus_index(&self) -> usize { - self.less_than_chip.bus_index() - } + // pub fn bus_index(&self) -> &usize { + // self.less_than_chip.bus_index() + // } - pub fn limb_bits(&self) -> usize { - self.less_than_chip.limb_bits() - } + // pub fn limb_bits(&self) -> &usize { + // self.less_than_chip.air.limb_bits() + // } - pub fn decomp(&self) -> usize { - self.less_than_chip.decomp() - } + // pub fn decomp(&self) -> &usize { + // self.less_than_chip.air.decomp() + // } - pub fn key_vec_len(&self) -> usize { - self.less_than_chip.key_vec_len() - } + // pub fn key_vec_len(&self) -> &usize { + // self.less_than_chip.air.key_vec_len() + // } - pub fn keys(&self) -> Vec> { - self.less_than_chip.keys().clone() - } + // pub fn keys(&self) -> &Vec> { + // self.less_than_chip.air.keys() + // } pub fn sends_custom( &self, @@ -75,8 +75,10 @@ impl AssertSortedChip { ) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let num_keys = self.key_vec_len(); + let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() + - 1) + / *self.less_than_chip.air.decomp(); + let num_keys = *self.less_than_chip.air.key_vec_len(); let mut interactions = vec![]; @@ -87,7 +89,7 @@ impl AssertSortedChip { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main(cols.keys_decomp[i][j])], count: VirtualPairCol::constant(F::one()), - argument_index: self.bus_index(), + argument_index: *self.less_than_chip.bus_index(), }); } } diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index c28790d1a2..1b88635544 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -7,22 +7,29 @@ use super::{columns::AssertSortedCols, AssertSortedChip}; impl AssertSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { - let num_cols: usize = - AssertSortedCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let num_cols: usize = AssertSortedCols::::get_width( + *self.less_than_chip.air.limb_bits(), + *self.less_than_chip.air.decomp(), + *self.less_than_chip.air.key_vec_len(), + ); - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() + - 1) + / *self.less_than_chip.air.decomp(); // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in // the correct range - let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + let last_limb_shift = (self.less_than_chip.air.decomp() + - (self.less_than_chip.air.limb_bits() % self.less_than_chip.air.decomp())) + % self.less_than_chip.air.decomp(); let mut rows: Vec = vec![]; - for i in 0..self.key_vec_len() { - let key = self.keys()[i].clone(); - let next_key: Vec = if i == self.key_vec_len() - 1 { - vec![0; self.key_vec_len()] + for i in 0..*self.less_than_chip.air.key_vec_len() { + let key = self.less_than_chip.air.keys()[i].clone(); + let next_key: Vec = if i == *self.less_than_chip.air.key_vec_len() - 1 { + vec![0; *self.less_than_chip.air.key_vec_len()] } else { - self.keys()[i + 1].clone() + self.less_than_chip.air.keys()[i + 1].clone() }; let less_than_trace = LocalTraceInstructions::generate_trace_row( @@ -35,13 +42,15 @@ impl AssertSortedChip { // decompose each limb into sublimbs of size self.decomp() bits for &val in key.iter() { for i in 0..num_limbs { - let bits = (val >> (i * self.decomp())) & ((1 << self.decomp()) - 1); + let bits = (val >> (i * self.less_than_chip.air.decomp())) + & ((1 << self.less_than_chip.air.decomp()) - 1); key_decomp_trace.push(F::from_canonical_u32(bits)); self.less_than_chip.range_checker_gate.add_count(bits); } // the last sublimb should be of size self.limb_bits() % self.decomp() bits, // so we need to shift it to constrain this - let bits = (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + let bits = (val >> ((num_limbs - 1) * self.less_than_chip.air.decomp())) + & ((1 << self.less_than_chip.air.decomp()) - 1); if (bits << last_limb_shift) < MAX { self.less_than_chip .range_checker_gate @@ -50,9 +59,10 @@ impl AssertSortedChip { key_decomp_trace.push(F::from_canonical_u32(bits << last_limb_shift)); } - let mut row: Vec = less_than_trace[0..self.key_vec_len()].to_vec(); + let mut row: Vec = + less_than_trace[0..*self.less_than_chip.air.key_vec_len()].to_vec(); row.extend_from_slice(&key_decomp_trace); - row.extend_from_slice(&less_than_trace[self.key_vec_len()..]); + row.extend_from_slice(&less_than_trace[*self.less_than_chip.air.key_vec_len()..]); rows.extend_from_slice(&row); } diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs index 135b95bab3..181257772f 100644 --- a/chips/src/less_than/air.rs +++ b/chips/src/less_than/air.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; @@ -8,60 +8,55 @@ use crate::sub_chip::{AirConfig, SubAir}; use super::{ columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}, - LessThanChip, + LessThanAir, LessThanChip, }; +impl AirConfig for LessThanChip { + type Cols = LessThanCols; +} + impl BaseAir for LessThanChip { fn width(&self) -> usize { - LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()) + LessThanCols::::get_width( + *self.air.limb_bits(), + *self.air.decomp(), + *self.air.key_vec_len(), + ) } } -impl Air for LessThanChip -where - AB: AirBuilder, - AB::Var: Clone, -{ +impl Air for LessThanChip { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let _pis = builder.public_values(); let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); let next: &[AB::Var] = (*next).borrow(); - let local_cols = LessThanCols::::from_slice( - local, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), - ); - - let next_cols = LessThanCols::::from_slice( - next, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), - ); + let [local_cols, next_cols] = [local, next].map(|view| { + LessThanCols::::from_slice( + view, + *self.air.limb_bits(), + *self.air.decomp(), + *self.air.key_vec_len(), + ) + }); SubAir::eval( - self, + &self.air, builder, - vec![local_cols.io, next_cols.io], + [local_cols.io, next_cols.io], local_cols.aux, ); } } -impl AirConfig for LessThanChip { - type Cols = LessThanCols; -} - // sub-chip with constraints to check whether one key is less than the next (row-wise) -impl SubAir for LessThanChip { - type IoView = Vec>; +impl SubAir for LessThanAir { + type IoView = [LessThanIOCols; 2]; type AuxView = LessThanAuxCols; + // constrain that local_key < next_key lexicographically fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { let local_key = io[0].key.clone(); let next_key = io[1].key.clone(); @@ -71,10 +66,9 @@ impl SubAir for LessThanChip { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let intermed_sum = local_aux.intermed_sum.clone(); - let lower_bits = local_aux.lower_bits.clone(); - let upper_bit = local_aux.upper_bit.clone(); - let lower_bits_decomp = local_aux.lower_bits_decomp.clone(); + let lower_bits = &local_aux.lower_bits; + let upper_bit = &local_aux.upper_bit; + let lower_bits_decomp = &local_aux.lower_bits_decomp; // we want to check these constraints for each row except the last one let mut when_transition = builder.when_transition(); @@ -89,20 +83,17 @@ impl SubAir for LessThanChip { + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - AB::Expr::one(); - // constrain that the intermed val (2^limb_bits + key_next - key_local) is correct - when_transition.assert_eq(intermed_sum[i], intermed_val); - // constrain that lower_bits[i] + upper_bit[i] * 2^limb_bits is the correct intermediate sum let check_val = lower_bits[i] + upper_bit[i] * AB::Expr::from_canonical_u64(1 << self.limb_bits()); - when_transition.assert_eq(intermed_sum[i], check_val); + when_transition.assert_eq(intermed_val, check_val); // constrain that diff is the difference between the two elements of consecutive rows let diff = *key_next - *key_local; when_transition.assert_eq(diff, local_aux.diff[i]); } - for i in 0..self.key_vec_len() { + for i in 0..*self.key_vec_len() { let mut lower_bits_from_decomp: AB::Expr = AB::Expr::zero(); // constrain that the decomposition of each lower_bits element is correct for j in 0..num_limbs { @@ -118,13 +109,13 @@ impl SubAir for LessThanChip { when_transition.assert_eq(lower_bits_from_decomp, lower_bits[i]); } - for upper_bit_value in &upper_bit { + for upper_bit_value in upper_bit { // constrain that each element in upper_bit is a boolean let is_bool = *upper_bit_value * (AB::Expr::one() - *upper_bit_value); when_transition.assert_zero(is_bool); } - for i in 0..self.key_vec_len() { + for i in 0..*self.key_vec_len() { let diff = local_aux.diff[i]; let is_equal = local_aux.is_zero[i]; let inverse = local_aux.inverses[i]; diff --git a/chips/src/less_than/chip.rs b/chips/src/less_than/chip.rs index 44f09f9880..82752886ae 100644 --- a/chips/src/less_than/chip.rs +++ b/chips/src/less_than/chip.rs @@ -9,15 +9,18 @@ use super::LessThanChip; impl Chip for LessThanChip { fn sends(&self) -> Vec> { - let num_cols = - LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let num_cols = LessThanCols::::get_width( + *self.air.limb_bits(), + *self.air.decomp(), + *self.air.key_vec_len(), + ); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = LessThanCols::::cols_numbered( + let cols_numbered = LessThanCols::::from_slice( &all_cols, - self.limb_bits(), - self.decomp(), - self.key_vec_len(), + *self.air.limb_bits(), + *self.air.decomp(), + *self.air.key_vec_len(), ); SubAirWithInteractions::sends(self, cols_numbered) @@ -28,8 +31,8 @@ impl SubAirWithInteractions for LessThanChip fn sends(&self, col_indices: LessThanCols) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let num_keys = self.key_vec_len(); + let num_limbs = (*self.air.limb_bits() + *self.air.decomp() - 1) / *self.air.decomp(); + let num_keys = *self.air.key_vec_len(); let mut interactions = vec![]; @@ -42,7 +45,7 @@ impl SubAirWithInteractions for LessThanChip col_indices.aux.lower_bits_decomp[i][j], )], count: VirtualPairCol::constant(F::one()), - argument_index: self.bus_index(), + argument_index: *self.bus_index(), }); } } diff --git a/chips/src/less_than/columns.rs b/chips/src/less_than/columns.rs index 5e39b9880e..5bd687eb2a 100644 --- a/chips/src/less_than/columns.rs +++ b/chips/src/less_than/columns.rs @@ -6,7 +6,6 @@ pub struct LessThanIOCols { } pub struct LessThanAuxCols { - pub intermed_sum: Vec, pub lower_bits: Vec, pub upper_bit: Vec, pub lower_bits_decomp: Vec>, @@ -32,12 +31,6 @@ impl LessThanCols { cur_start_idx = cur_end_idx; cur_end_idx += key_vec_len; - // the next key_vec_len elements are the values of 2^num_limbs + b - a - 1 where a and b are limbs - // on consecutive rows and b is the row after a - let intermed_sum = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); cur_start_idx = cur_end_idx; @@ -75,7 +68,6 @@ impl LessThanCols { let io = LessThanIOCols { key }; let aux = LessThanAuxCols { - intermed_sum, lower_bits, upper_bit, lower_bits_decomp, @@ -90,7 +82,6 @@ impl LessThanCols { pub fn flatten(&self) -> Vec { let mut flattened = vec![]; flattened.extend_from_slice(&self.io.key); - flattened.extend_from_slice(&self.aux.intermed_sum); flattened.extend_from_slice(&self.aux.lower_bits); flattened.extend_from_slice(&self.aux.upper_bit); for decomp_vec in &self.aux.lower_bits_decomp { @@ -110,8 +101,6 @@ impl LessThanCols { let mut width = 0; // for the key itself width += key_vec_len; - // for the 2^limb_bits + b - a values - width += key_vec_len; // for the lower_bits width += key_vec_len; // for the upper_bit @@ -128,68 +117,4 @@ impl LessThanCols { width } - - pub fn cols_numbered( - cols: &[usize], - limb_bits: usize, - decomp: usize, - key_vec_len: usize, - ) -> LessThanCols { - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (limb_bits + decomp - 1) / decomp; - let mut cur_start_idx = 0; - let mut cur_end_idx = key_vec_len; - - // the first key_vec_len elements are the key itself - let key = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the intermediate sum - let intermed_sum = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the lower_bits - let lower_bits = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the upper_bit - let upper_bit = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed lower_bits - let lower_bits_decomp = cols[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - - // the next key_vec_len elements are the difference between consecutive rows - let diff = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the indicator whether difference is zero - let is_zero = cols[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the inverses - let inverses = cols[cur_start_idx..cur_end_idx].to_vec(); - - let io = LessThanIOCols { key }; - let aux = LessThanAuxCols { - intermed_sum, - lower_bits, - upper_bit, - lower_bits_decomp, - diff, - is_zero, - inverses, - }; - - LessThanCols { io, aux } - } } diff --git a/chips/src/less_than/mod.rs b/chips/src/less_than/mod.rs index 9e688ae3e5..596ed65818 100644 --- a/chips/src/less_than/mod.rs +++ b/chips/src/less_than/mod.rs @@ -1,4 +1,5 @@ use crate::range_gate::RangeCheckerGateChip; +use getset::Getters; #[cfg(test)] pub mod tests; @@ -8,18 +9,29 @@ pub mod chip; pub mod columns; pub mod trace; +#[derive(Default, Getters)] +pub struct LessThanAir { + #[getset(get = "pub")] + limb_bits: usize, + #[getset(get = "pub")] + decomp: usize, + #[getset(get = "pub")] + key_vec_len: usize, + #[getset(get = "pub")] + keys: Vec>, +} + /** * This Chip constrains that consecutive rows are sorted lexicographically. * * Each row consists of a key decomposed into limbs with at most limb_bits bits */ -#[derive(Default)] +#[derive(Default, Getters)] pub struct LessThanChip { + pub air: LessThanAir, + + #[getset(get = "pub")] bus_index: usize, - limb_bits: usize, - decomp: usize, - key_vec_len: usize, - keys: Vec>, pub range_checker_gate: RangeCheckerGateChip, } @@ -32,33 +44,17 @@ impl LessThanChip { key_vec_len: usize, keys: Vec>, ) -> Self { - Self { - bus_index, + let air = LessThanAir { limb_bits, decomp, key_vec_len, keys, + }; + + Self { + air, + bus_index, range_checker_gate: RangeCheckerGateChip::::new(bus_index), } } - - pub fn bus_index(&self) -> usize { - self.bus_index - } - - pub fn limb_bits(&self) -> usize { - self.limb_bits - } - - pub fn decomp(&self) -> usize { - self.decomp - } - - pub fn key_vec_len(&self) -> usize { - self.key_vec_len - } - - pub fn keys(&self) -> Vec> { - self.keys.clone() - } } diff --git a/chips/src/less_than/trace.rs b/chips/src/less_than/trace.rs index 9bf4350084..d36dd87229 100644 --- a/chips/src/less_than/trace.rs +++ b/chips/src/less_than/trace.rs @@ -10,16 +10,19 @@ use super::{ impl LessThanChip { pub fn generate_trace(&self) -> RowMajorMatrix { - let num_cols: usize = - LessThanCols::::get_width(self.limb_bits(), self.decomp(), self.key_vec_len()); + let num_cols: usize = LessThanCols::::get_width( + *self.air.limb_bits(), + *self.air.decomp(), + *self.air.key_vec_len(), + ); let mut rows: Vec = vec![]; - for i in 0..self.key_vec_len() { - let key = self.keys[i].clone(); - let next_key: Vec = if i == self.key_vec_len() - 1 { - vec![0; self.key_vec_len()] + for i in 0..*self.air.key_vec_len() { + let key = self.air.keys[i].clone(); + let next_key: Vec = if i == *self.air.key_vec_len() - 1 { + vec![0; *self.air.key_vec_len()] } else { - self.keys[i + 1].clone() + self.air.keys[i + 1].clone() }; let row = self.generate_trace_row((key, next_key)).flatten(); rows.extend_from_slice(&row); @@ -34,11 +37,10 @@ impl LocalTraceInstructions for LessThanChip fn generate_trace_row(&self, consecutive_keys: (Vec, Vec)) -> Self::Cols { let (key, next_key) = consecutive_keys; - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + let num_limbs = (self.air.limb_bits() + self.air.decomp() - 1) / self.air.decomp(); + let last_limb_shift = + (self.air.decomp() - (self.air.limb_bits() % self.air.decomp())) % self.air.decomp(); - // this will contain 2^limb_bits + b - a - let mut intermed_sum: Vec = vec![]; // the lower limb_bits bits of the corresponding check value let mut lower_bits: Vec = vec![]; let mut lower_bits_u32: Vec = vec![]; @@ -56,16 +58,18 @@ impl LocalTraceInstructions for LessThanChip for (j, &val) in key.iter().enumerate() { let next_val = next_key[j]; // compute 2^limb_bits + next_val - val - 1 - let check_less_than = (1 << self.limb_bits()) + next_val - val - 1; - intermed_sum.push(F::from_canonical_u32(check_less_than)); + let check_less_than = (1 << self.air.limb_bits()) + next_val - val - 1; + // the lower limb_bits bits of the check value lower_bits.push(F::from_canonical_u32( - check_less_than & ((1 << self.limb_bits()) - 1), + check_less_than & ((1 << self.air.limb_bits()) - 1), )); // we also need the u32 value to compute the decomposition later - lower_bits_u32.push(check_less_than & ((1 << self.limb_bits()) - 1)); + lower_bits_u32.push(check_less_than & ((1 << self.air.limb_bits()) - 1)); // the (n + 1)st bit of the check value, will be 1 if a < b - upper_bit.push(F::from_canonical_u32(check_less_than >> self.limb_bits())); + upper_bit.push(F::from_canonical_u32( + check_less_than >> self.air.limb_bits(), + )); // the difference between the two limbs let curr_diff = F::from_canonical_u32(next_val) - F::from_canonical_u32(val); @@ -90,11 +94,12 @@ impl LocalTraceInstructions for LessThanChip if i != lower_bits_u32.len() { let mut curr_decomp: Vec = vec![]; for j in 0..num_limbs { - let bits = (val >> (j * self.decomp())) & ((1 << self.decomp()) - 1); + let bits = (val >> (j * self.air.decomp())) & ((1 << self.air.decomp()) - 1); curr_decomp.push(F::from_canonical_u32(bits)); self.range_checker_gate.add_count(bits); } - let bits = (val >> ((num_limbs - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + let bits = + (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); if (bits << last_limb_shift) < MAX { self.range_checker_gate.add_count(bits << last_limb_shift); } @@ -109,7 +114,6 @@ impl LocalTraceInstructions for LessThanChip key: key.into_iter().map(F::from_canonical_u32).collect(), }; let aux = LessThanAuxCols { - intermed_sum, lower_bits, upper_bit, lower_bits_decomp, From d77fd080cefb4e250d209076405217a53a3e2ba0 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:50:58 -0400 Subject: [PATCH 09/46] chore: cleanup --- chips/src/assert_sorted/air.rs | 6 +----- chips/src/assert_sorted/mod.rs | 20 -------------------- 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index a94514ba30..75a788c518 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -20,11 +20,7 @@ impl BaseAir for AssertSortedChip { } } -impl Air for AssertSortedChip -where - AB: AirBuilder, - AB::Var: Clone, -{ +impl Air for AssertSortedChip { fn eval(&self, builder: &mut AB) { let main = builder.main(); diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 8f571f4dfd..5253ba85c8 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -49,26 +49,6 @@ impl AssertSortedChip { } } - // pub fn bus_index(&self) -> &usize { - // self.less_than_chip.bus_index() - // } - - // pub fn limb_bits(&self) -> &usize { - // self.less_than_chip.air.limb_bits() - // } - - // pub fn decomp(&self) -> &usize { - // self.less_than_chip.air.decomp() - // } - - // pub fn key_vec_len(&self) -> &usize { - // self.less_than_chip.air.key_vec_len() - // } - - // pub fn keys(&self) -> &Vec> { - // self.less_than_chip.air.keys() - // } - pub fn sends_custom( &self, cols: &AssertSortedCols, From 3c8131824152aff083fe34a196838a03f10c6de6 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:07:02 -0400 Subject: [PATCH 10/46] chore: change MAX from generic to instance field for LessThanChip and AssertSortedChip --- chips/src/assert_sorted/air.rs | 4 +-- chips/src/assert_sorted/chip.rs | 2 +- chips/src/assert_sorted/mod.rs | 16 +++++++---- chips/src/assert_sorted/tests/mod.rs | 40 ++++++++++++++++++++++------ chips/src/assert_sorted/trace.rs | 4 +-- chips/src/less_than/air.rs | 8 +++--- chips/src/less_than/chip.rs | 4 +-- chips/src/less_than/mod.rs | 16 ++++++----- chips/src/less_than/tests/mod.rs | 30 ++++++++++++++++----- chips/src/less_than/trace.rs | 6 ++--- 10 files changed, 91 insertions(+), 39 deletions(-) diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index 75a788c518..2b976b0322 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -10,7 +10,7 @@ use crate::sub_chip::SubAir; use super::columns::AssertSortedCols; use super::AssertSortedChip; -impl BaseAir for AssertSortedChip { +impl BaseAir for AssertSortedChip { fn width(&self) -> usize { AssertSortedCols::::get_width( *self.less_than_chip.air.limb_bits(), @@ -20,7 +20,7 @@ impl BaseAir for AssertSortedChip { } } -impl Air for AssertSortedChip { +impl Air for AssertSortedChip { fn eval(&self, builder: &mut AB) { let main = builder.main(); diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index 331fe632ba..5551445cb5 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -7,7 +7,7 @@ use p3_field::PrimeField64; use super::AssertSortedChip; -impl Chip for AssertSortedChip { +impl Chip for AssertSortedChip { fn sends(&self) -> Vec> { let num_cols = AssertSortedCols::::get_width( *self.less_than_chip.air.limb_bits(), diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 5253ba85c8..098600fa86 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,4 +1,5 @@ use crate::less_than::LessThanChip; +use getset::Getters; use afs_stark_backend::interaction::Interaction; use columns::AssertSortedCols; @@ -25,22 +26,27 @@ pub mod trace; * The AssertSortedChip contains a LessThanChip subchip, which is used to constrain * that the rows are sorted lexicographically. */ -#[derive(Default)] -pub struct AssertSortedChip { - less_than_chip: LessThanChip, +#[derive(Default, Getters)] +pub struct AssertSortedChip { + #[getset(get = "pub")] + range_max: u32, + less_than_chip: LessThanChip, } -impl AssertSortedChip { +impl AssertSortedChip { pub fn new( bus_index: usize, + range_max: u32, limb_bits: usize, decomp: usize, key_vec_len: usize, keys: Vec>, ) -> Self { Self { - less_than_chip: LessThanChip::::new( + range_max, + less_than_chip: LessThanChip::new( bus_index, + range_max, limb_bits, decomp, key_vec_len, diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index 978b18cad9..d0394e2fda 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -43,8 +43,14 @@ fn test_assert_sorted_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; - let assert_sorted_chip = - AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let assert_sorted_chip = AssertSortedChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip @@ -82,8 +88,14 @@ fn test_assert_sorted_chip_large_positive() { vec![875005, 767547, 196209, 887921], ]; - let assert_sorted_chip = - AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let assert_sorted_chip = AssertSortedChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip @@ -122,8 +134,14 @@ fn test_assert_sorted_chip_largelimb_negative() { vec![128, 767, 196, 953], ]; - let assert_sorted_chip = - AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let assert_sorted_chip = AssertSortedChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip @@ -167,8 +185,14 @@ fn test_assert_sorted_chip_unsorted_negative() { vec![875005, 767547, 196209, 887921], ]; - let assert_sorted_chip = - AssertSortedChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let assert_sorted_chip = AssertSortedChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index 1b88635544..f5c8de8f6e 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -5,7 +5,7 @@ use crate::sub_chip::LocalTraceInstructions; use super::{columns::AssertSortedCols, AssertSortedChip}; -impl AssertSortedChip { +impl AssertSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = AssertSortedCols::::get_width( *self.less_than_chip.air.limb_bits(), @@ -51,7 +51,7 @@ impl AssertSortedChip { // so we need to shift it to constrain this let bits = (val >> ((num_limbs - 1) * self.less_than_chip.air.decomp())) & ((1 << self.less_than_chip.air.decomp()) - 1); - if (bits << last_limb_shift) < MAX { + if (bits << last_limb_shift) < *self.range_max() { self.less_than_chip .range_checker_gate .add_count(bits << last_limb_shift); diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs index 181257772f..1623f684b8 100644 --- a/chips/src/less_than/air.rs +++ b/chips/src/less_than/air.rs @@ -11,11 +11,11 @@ use super::{ LessThanAir, LessThanChip, }; -impl AirConfig for LessThanChip { +impl AirConfig for LessThanChip { type Cols = LessThanCols; } -impl BaseAir for LessThanChip { +impl BaseAir for LessThanChip { fn width(&self) -> usize { LessThanCols::::get_width( *self.air.limb_bits(), @@ -25,7 +25,7 @@ impl BaseAir for LessThanChip { } } -impl Air for LessThanChip { +impl Air for LessThanChip { fn eval(&self, builder: &mut AB) { let main = builder.main(); @@ -52,7 +52,7 @@ impl Air for LessThanChip { } // sub-chip with constraints to check whether one key is less than the next (row-wise) -impl SubAir for LessThanAir { +impl SubAir for LessThanAir { type IoView = [LessThanIOCols; 2]; type AuxView = LessThanAuxCols; diff --git a/chips/src/less_than/chip.rs b/chips/src/less_than/chip.rs index 82752886ae..726f352e3c 100644 --- a/chips/src/less_than/chip.rs +++ b/chips/src/less_than/chip.rs @@ -7,7 +7,7 @@ use p3_field::PrimeField64; use super::LessThanChip; -impl Chip for LessThanChip { +impl Chip for LessThanChip { fn sends(&self) -> Vec> { let num_cols = LessThanCols::::get_width( *self.air.limb_bits(), @@ -27,7 +27,7 @@ impl Chip for LessThanChip { } } -impl SubAirWithInteractions for LessThanChip { +impl SubAirWithInteractions for LessThanChip { fn sends(&self, col_indices: LessThanCols) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb diff --git a/chips/src/less_than/mod.rs b/chips/src/less_than/mod.rs index 596ed65818..28643381a3 100644 --- a/chips/src/less_than/mod.rs +++ b/chips/src/less_than/mod.rs @@ -10,7 +10,9 @@ pub mod columns; pub mod trace; #[derive(Default, Getters)] -pub struct LessThanAir { +pub struct LessThanAir { + #[getset(get = "pub")] + range_max: u32, #[getset(get = "pub")] limb_bits: usize, #[getset(get = "pub")] @@ -27,24 +29,26 @@ pub struct LessThanAir { * Each row consists of a key decomposed into limbs with at most limb_bits bits */ #[derive(Default, Getters)] -pub struct LessThanChip { - pub air: LessThanAir, +pub struct LessThanChip { + pub air: LessThanAir, #[getset(get = "pub")] bus_index: usize, - pub range_checker_gate: RangeCheckerGateChip, + pub range_checker_gate: RangeCheckerGateChip, } -impl LessThanChip { +impl LessThanChip { pub fn new( bus_index: usize, + range_max: u32, limb_bits: usize, decomp: usize, key_vec_len: usize, keys: Vec>, ) -> Self { let air = LessThanAir { + range_max, limb_bits, decomp, key_vec_len, @@ -54,7 +58,7 @@ impl LessThanChip { Self { air, bus_index, - range_checker_gate: RangeCheckerGateChip::::new(bus_index), + range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), } } } diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs index 6ec8cb501e..d9253e6470 100644 --- a/chips/src/less_than/tests/mod.rs +++ b/chips/src/less_than/tests/mod.rs @@ -39,8 +39,14 @@ fn test_less_than_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; - let less_than_chip = - LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let less_than_chip = LessThanChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); let less_than_range_chip_trace: DenseMatrix = @@ -72,8 +78,14 @@ fn test_less_than_chip_large_positive() { vec![875005, 767547, 196209, 887921], ]; - let less_than_chip = - LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let less_than_chip = LessThanChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); let less_than_range_chip_trace: DenseMatrix = @@ -106,8 +118,14 @@ fn test_less_than_chip_unsorted_negative() { vec![875005, 767547, 196209, 887921], ]; - let less_than_chip = - LessThanChip::::new(BUS_INDEX, LIMB_BITS, DECOMP, KEY_VEC_LEN, requests.clone()); + let less_than_chip = LessThanChip::new( + BUS_INDEX, + MAX, + LIMB_BITS, + DECOMP, + KEY_VEC_LEN, + requests.clone(), + ); let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); let less_than_range_chip_trace: DenseMatrix = diff --git a/chips/src/less_than/trace.rs b/chips/src/less_than/trace.rs index d36dd87229..00c454f7dd 100644 --- a/chips/src/less_than/trace.rs +++ b/chips/src/less_than/trace.rs @@ -8,7 +8,7 @@ use super::{ LessThanChip, }; -impl LessThanChip { +impl LessThanChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = LessThanCols::::get_width( *self.air.limb_bits(), @@ -32,7 +32,7 @@ impl LessThanChip { } } -impl LocalTraceInstructions for LessThanChip { +impl LocalTraceInstructions for LessThanChip { type LocalInput = (Vec, Vec); fn generate_trace_row(&self, consecutive_keys: (Vec, Vec)) -> Self::Cols { @@ -100,7 +100,7 @@ impl LocalTraceInstructions for LessThanChip } let bits = (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - if (bits << last_limb_shift) < MAX { + if (bits << last_limb_shift) < *self.air.range_max() { self.range_checker_gate.add_count(bits << last_limb_shift); } curr_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); From 296b66fbd627b97102ddab614befe8ae33cd8921 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:41:50 -0400 Subject: [PATCH 11/46] feat: IsLessThanChip to compare two numbers --- chips/src/assert_sorted/tests/mod.rs | 9 +-- chips/src/is_less_than/air.rs | 95 ++++++++++++++++++++++++++++ chips/src/is_less_than/chip.rs | 46 ++++++++++++++ chips/src/is_less_than/columns.rs | 81 ++++++++++++++++++++++++ chips/src/is_less_than/mod.rs | 57 +++++++++++++++++ chips/src/is_less_than/tests/mod.rs | 88 ++++++++++++++++++++++++++ chips/src/is_less_than/trace.rs | 71 +++++++++++++++++++++ chips/src/less_than/tests/mod.rs | 8 +-- chips/src/lib.rs | 1 + 9 files changed, 441 insertions(+), 15 deletions(-) create mode 100644 chips/src/is_less_than/air.rs create mode 100644 chips/src/is_less_than/chip.rs create mode 100644 chips/src/is_less_than/columns.rs create mode 100644 chips/src/is_less_than/mod.rs create mode 100644 chips/src/is_less_than/tests/mod.rs create mode 100644 chips/src/is_less_than/trace.rs diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index d0394e2fda..03524522bd 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -3,6 +3,7 @@ use super::super::assert_sorted; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use assert_sorted::AssertSortedChip; use p3_baby_bear::BabyBear; use p3_matrix::dense::DenseMatrix; @@ -32,8 +33,6 @@ use p3_matrix::dense::DenseMatrix; // most limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_small_positive() { - use assert_sorted::AssertSortedChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 16; const DECOMP: usize = 8; @@ -72,8 +71,6 @@ fn test_assert_sorted_chip_small_positive() { // most limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_large_positive() { - use assert_sorted::AssertSortedChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; const DECOMP: usize = 8; @@ -117,8 +114,6 @@ fn test_assert_sorted_chip_large_positive() { // has more than limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_largelimb_negative() { - use assert_sorted::AssertSortedChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 10; const DECOMP: usize = 8; @@ -168,8 +163,6 @@ fn test_assert_sorted_chip_largelimb_negative() { // most limb_bits bits, rows are not sorted lexicographically #[test] fn test_assert_sorted_chip_unsorted_negative() { - use assert_sorted::AssertSortedChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; const DECOMP: usize = 8; diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs new file mode 100644 index 0000000000..4ed126b286 --- /dev/null +++ b/chips/src/is_less_than/air.rs @@ -0,0 +1,95 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use crate::sub_chip::{AirConfig, SubAir}; + +use super::{ + columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, + IsLessThanAir, IsLessThanChip, +}; + +impl AirConfig for IsLessThanChip { + type Cols = IsLessThanCols; +} + +impl BaseAir for IsLessThanChip { + fn width(&self) -> usize { + IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()) + } +} + +impl Air for IsLessThanChip { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let local = main.row_slice(0); + let local: &[AB::Var] = (*local).borrow(); + + let local_cols = + IsLessThanCols::::from_slice(local, *self.air.limb_bits(), *self.air.decomp()); + + SubAir::eval(&self.air, builder, local_cols.io, local_cols.aux); + } +} + +// sub-chip with constraints to check whether one key is less than the next (row-wise) +impl SubAir for IsLessThanAir { + type IoView = IsLessThanIOCols; + type AuxView = IsLessThanAuxCols; + + // constrain that local_key < next_key lexicographically + fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { + let x = io.x; + let y = io.y; + let less_than = io.less_than; + + let local_aux = &aux; + + // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); + + let lower_bits = local_aux.lower_bits; + let upper_bit = local_aux.upper_bit; + let lower_bits_decomp = local_aux.lower_bits_decomp.clone(); + + // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); + + // this is the desired intermediate value (i.e. 2^limb_bits + y - x - 1) + let intermed_val = + y - x + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - AB::Expr::one(); + + // constrain that the lower bits + upper bit * 2^limb_bits is the correct intermediate sum + let check_val = + lower_bits + upper_bit * AB::Expr::from_canonical_u64(1 << self.limb_bits()); + + builder.assert_eq(intermed_val, check_val); + + // constrain that the decomposition of lower_bits is correct + let lower_bits_from_decomp = lower_bits_decomp + .iter() + .enumerate() + .take(num_limbs) + .fold(AB::Expr::zero(), |acc, (i, &val)| { + acc + val * AB::Expr::from_canonical_u64(1 << (i * self.decomp())) + }); + + builder.assert_eq(lower_bits_from_decomp, lower_bits); + + let shifted_val = + lower_bits_decomp[num_limbs - 1] * AB::Expr::from_canonical_u64(1 << last_limb_shift); + + // constrain that the shifted last limb is shifted correctly + builder.assert_eq(lower_bits_decomp[num_limbs], shifted_val); + + // constrain that upper_bit is a boolean + let is_bool = upper_bit * (AB::Expr::one() - upper_bit); + builder.assert_zero(is_bool); + + builder.assert_eq(less_than, upper_bit); + } +} diff --git a/chips/src/is_less_than/chip.rs b/chips/src/is_less_than/chip.rs new file mode 100644 index 0000000000..56d8d0a01f --- /dev/null +++ b/chips/src/is_less_than/chip.rs @@ -0,0 +1,46 @@ +use crate::sub_chip::SubAirWithInteractions; + +use super::columns::IsLessThanCols; +use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +use super::IsLessThanChip; + +impl Chip for IsLessThanChip { + fn sends(&self) -> Vec> { + let num_cols = IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = IsLessThanCols::::from_slice( + &all_cols, + *self.air.limb_bits(), + *self.air.decomp(), + ); + + SubAirWithInteractions::sends(self, cols_numbered) + } +} + +impl SubAirWithInteractions for IsLessThanChip { + fn sends(&self, col_indices: IsLessThanCols) -> Vec> { + // num_limbs is the number of limbs, not including the last shifted limb + let num_limbs = (*self.air.limb_bits() + *self.air.decomp() - 1) / *self.air.decomp(); + + let mut interactions = vec![]; + + // we range check the limbs of the lower_bits so that we know each element + // of lower_bits has at most limb_bits bits + for i in 0..(num_limbs + 1) { + interactions.push(Interaction { + fields: vec![VirtualPairCol::single_main( + col_indices.aux.lower_bits_decomp[i], + )], + count: VirtualPairCol::constant(F::one()), + argument_index: *self.bus_index(), + }); + } + + interactions + } +} diff --git a/chips/src/is_less_than/columns.rs b/chips/src/is_less_than/columns.rs new file mode 100644 index 0000000000..c0363ae9c1 --- /dev/null +++ b/chips/src/is_less_than/columns.rs @@ -0,0 +1,81 @@ +use afs_derive::AlignedBorrow; + +#[derive(Default, AlignedBorrow)] +pub struct IsLessThanIOCols { + pub x: T, + pub y: T, + pub less_than: T, +} + +pub struct IsLessThanAuxCols { + pub lower_bits: T, + pub upper_bit: T, + pub lower_bits_decomp: Vec, +} + +pub struct IsLessThanCols { + pub io: IsLessThanIOCols, + pub aux: IsLessThanAuxCols, +} + +impl IsLessThanCols { + pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize) -> Self { + // num_limbs is the number of limbs, not including the last shifted limb + let num_limbs = (limb_bits + decomp - 1) / decomp; + + // the first and second elements are x and y, respectively + let x = slc[0].clone(); + let y = slc[1].clone(); + // the third element is the less_than indicator + let less_than = slc[2].clone(); + + // the next element is the value of the lower num_limbs bits of the intermediate sum + let lower_bits = slc[3].clone(); + + // the next element is the value of the upper bit of the intermediate sum; note that + // y > x <=> upper_bit = 1 + let upper_bit = slc[4].clone(); + + // the next num_limbs + 1 elements are the decomposed limbs of the lower bits of the + // intermediate sum + let lower_bits_decomp = slc[5..5 + num_limbs + 1].to_vec(); + + let io = IsLessThanIOCols { x, y, less_than }; + let aux = IsLessThanAuxCols { + lower_bits, + upper_bit, + lower_bits_decomp, + }; + + Self { io, aux } + } + + pub fn flatten(&self) -> Vec { + let mut flattened = vec![ + self.io.x.clone(), + self.io.y.clone(), + self.io.less_than.clone(), + self.aux.lower_bits.clone(), + self.aux.upper_bit.clone(), + ]; + flattened.extend(self.aux.lower_bits_decomp.iter().cloned()); + flattened + } + + pub fn get_width(limb_bits: usize, decomp: usize) -> usize { + let mut width = 0; + // for the x and y + width += 2; + // for the less_than indicator + width += 1; + // for the lower_bits + width += 1; + // for the upper_bit + width += 1; + // for the decomposed lower_bits + let num_limbs = (limb_bits + decomp - 1) / decomp; + width += num_limbs + 1; + + width + } +} diff --git a/chips/src/is_less_than/mod.rs b/chips/src/is_less_than/mod.rs new file mode 100644 index 0000000000..bdd79e8852 --- /dev/null +++ b/chips/src/is_less_than/mod.rs @@ -0,0 +1,57 @@ +use crate::range_gate::RangeCheckerGateChip; +use getset::Getters; + +#[cfg(test)] +pub mod tests; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +#[derive(Default, Getters)] +pub struct IsLessThanAir { + #[getset(get = "pub")] + range_max: u32, + #[getset(get = "pub")] + limb_bits: usize, + #[getset(get = "pub")] + decomp: usize, +} + +/** + * This chip computes whether one number is less than another. + */ +#[derive(Default, Getters)] +pub struct IsLessThanChip { + pub air: IsLessThanAir, + + #[getset(get = "pub")] + bus_index: usize, + + pub range_checker_gate: RangeCheckerGateChip, +} + +impl IsLessThanChip { + pub fn new(bus_index: usize, range_max: u32, limb_bits: usize, decomp: usize) -> Self { + let air = IsLessThanAir { + range_max, + limb_bits, + decomp, + }; + + Self { + air, + bus_index, + range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), + } + } + + fn calc_less_than(&self, x: u32, y: u32) -> u32 { + if x < y { + 1 + } else { + 0 + } + } +} diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs new file mode 100644 index 0000000000..607943967e --- /dev/null +++ b/chips/src/is_less_than/tests/mod.rs @@ -0,0 +1,88 @@ +use super::super::is_less_than::IsLessThanChip; + +use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::verifier::VerificationError; +use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use p3_baby_bear::BabyBear; +use p3_field::AbstractField; +use p3_matrix::dense::DenseMatrix; + +#[test] +fn test_is_less_than_chip_lt() { + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const MAX: u32 = 1 << DECOMP; + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let trace = chip.generate_trace(14321, 26883); + let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&chip, &chip.range_checker_gate], + vec![trace, range_trace], + ) + .expect("Verification failed"); +} + +#[test] +fn test_is_less_than_chip_gt() { + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const MAX: u32 = 1 << DECOMP; + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let trace = chip.generate_trace(1, 0); + let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&chip, &chip.range_checker_gate], + vec![trace, range_trace], + ) + .expect("Verification failed"); +} + +#[test] +fn test_is_less_than_chip_eq() { + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const MAX: u32 = 1 << DECOMP; + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let trace = chip.generate_trace(773, 773); + let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + + run_simple_test_no_pis( + vec![&chip, &chip.range_checker_gate], + vec![trace, range_trace], + ) + .expect("Verification failed"); +} + +#[test] +fn test_is_less_than_negative() { + const BUS_INDEX: usize = 0; + const LIMB_BITS: usize = 16; + const DECOMP: usize = 8; + const MAX: u32 = 1 << DECOMP; + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let mut trace = chip.generate_trace(446, 553); + let range_trace = chip.range_checker_gate.generate_trace(); + + trace.values[2] = AbstractField::from_canonical_u64(0); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis( + vec![&chip, &chip.range_checker_gate], + vec![trace, range_trace], + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs new file mode 100644 index 0000000000..df48b13de3 --- /dev/null +++ b/chips/src/is_less_than/trace.rs @@ -0,0 +1,71 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use crate::sub_chip::LocalTraceInstructions; + +use super::{ + columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, + IsLessThanChip, +}; + +impl IsLessThanChip { + pub fn generate_trace(&self, x: u32, y: u32) -> RowMajorMatrix { + let num_cols: usize = + IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); + + let row = self.generate_trace_row((x, y)).flatten(); + + RowMajorMatrix::new(row, num_cols) + } +} + +impl LocalTraceInstructions for IsLessThanChip { + type LocalInput = (u32, u32); + + fn generate_trace_row(&self, input: (u32, u32)) -> Self::Cols { + let (x, y) = input; + let less_than = self.calc_less_than(x, y); + + // num_limbs is the number of limbs, not including the last shifted limb + let num_limbs = (self.air.limb_bits() + self.air.decomp() - 1) / self.air.decomp(); + // to range check the last limb of the decomposed lower_bits, we need to shift it to make sure it is in + // the correct range + let last_limb_shift = + (self.air.decomp() - (self.air.limb_bits() % self.air.decomp())) % self.air.decomp(); + + // obtain the lower_bits and upper_bit + let check_less_than = (1 << self.air.limb_bits()) + y - x - 1; + let lower_bits = F::from_canonical_u32(check_less_than & ((1 << self.air.limb_bits()) - 1)); + let lower_bits_u32 = check_less_than & ((1 << self.air.limb_bits()) - 1); + let upper_bit = F::from_canonical_u32(check_less_than >> self.air.limb_bits()); + + // decompose lower_bits into limbs and range check + let mut lower_bits_decomp: Vec = vec![]; + for i in 0..num_limbs { + let bits = (lower_bits_u32 >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); + lower_bits_decomp.push(F::from_canonical_u32(bits)); + self.range_checker_gate.add_count(bits); + } + + // shift the last limb and range check + let bits = (lower_bits_u32 >> ((num_limbs - 1) * self.air.decomp())) + & ((1 << self.air.decomp()) - 1); + if (bits << last_limb_shift) < *self.air.range_max() { + self.range_checker_gate.add_count(bits << last_limb_shift); + } + lower_bits_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); + + let io = IsLessThanIOCols { + x: F::from_canonical_u32(x), + y: F::from_canonical_u32(y), + less_than: F::from_canonical_u32(less_than), + }; + let aux = IsLessThanAuxCols { + lower_bits, + upper_bit, + lower_bits_decomp, + }; + + IsLessThanCols { io, aux } + } +} diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs index d9253e6470..181c713333 100644 --- a/chips/src/less_than/tests/mod.rs +++ b/chips/src/less_than/tests/mod.rs @@ -1,4 +1,4 @@ -use super::super::less_than; +use super::super::less_than::LessThanChip; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; @@ -28,8 +28,6 @@ use p3_matrix::dense::DenseMatrix; // covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, rows are sorted lexicographically #[test] fn test_less_than_chip_small_positive() { - use less_than::LessThanChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 16; const DECOMP: usize = 8; @@ -62,8 +60,6 @@ fn test_less_than_chip_small_positive() { // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are sorted lexicographically #[test] fn test_less_than_chip_large_positive() { - use less_than::LessThanChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; const DECOMP: usize = 8; @@ -101,8 +97,6 @@ fn test_less_than_chip_large_positive() { // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are not sorted lexicographically #[test] fn test_less_than_chip_unsorted_negative() { - use less_than::LessThanChip; - const BUS_INDEX: usize = 0; const LIMB_BITS: usize = 30; const DECOMP: usize = 8; diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 630a92e314..ceed5fe52c 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -1,4 +1,5 @@ pub mod assert_sorted; +pub mod is_less_than; pub mod keccak_permute; pub mod less_than; pub mod merkle_proof; From 3610e0d913c3d00cc237203a72163e97aec4443b Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 4 Jun 2024 15:51:00 -0400 Subject: [PATCH 12/46] feat: IsLessThanTuple subchip for different limb_bits --- chips/src/is_less_than/air.rs | 19 ++- chips/src/is_less_than/columns.rs | 11 +- chips/src/is_less_than/trace.rs | 4 +- chips/src/is_less_than_tuple/air.rs | 127 +++++++++++++++++ chips/src/is_less_than_tuple/chip.rs | 9 ++ chips/src/is_less_than_tuple/columns.rs | 158 ++++++++++++++++++++++ chips/src/is_less_than_tuple/mod.rs | 66 +++++++++ chips/src/is_less_than_tuple/tests/mod.rs | 73 ++++++++++ chips/src/is_less_than_tuple/trace.rs | 100 ++++++++++++++ chips/src/lib.rs | 8 +- 10 files changed, 547 insertions(+), 28 deletions(-) create mode 100644 chips/src/is_less_than_tuple/air.rs create mode 100644 chips/src/is_less_than_tuple/chip.rs create mode 100644 chips/src/is_less_than_tuple/columns.rs create mode 100644 chips/src/is_less_than_tuple/mod.rs create mode 100644 chips/src/is_less_than_tuple/tests/mod.rs create mode 100644 chips/src/is_less_than_tuple/trace.rs diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index 4ed126b286..f593ace38f 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -35,12 +35,12 @@ impl Air for IsLessThanChip { } } -// sub-chip with constraints to check whether one key is less than the next (row-wise) +// sub-chip with constraints to check whether one number is less than another impl SubAir for IsLessThanAir { type IoView = IsLessThanIOCols; type AuxView = IsLessThanAuxCols; - // constrain that local_key < next_key lexicographically + // constrain that the result of x < y is given by less_than fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { let x = io.x; let y = io.y; @@ -48,14 +48,13 @@ impl SubAir for IsLessThanAir { let local_aux = &aux; - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb + // num_limbs is the number of limbs, not including the last shifted limb let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); let lower_bits = local_aux.lower_bits; - let upper_bit = local_aux.upper_bit; let lower_bits_decomp = local_aux.lower_bits_decomp.clone(); - // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in + // to range check the last limb of the decomposed lower_bits, we need to shift it to make sure it is in // the correct range let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); @@ -63,9 +62,9 @@ impl SubAir for IsLessThanAir { let intermed_val = y - x + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - AB::Expr::one(); - // constrain that the lower bits + upper bit * 2^limb_bits is the correct intermediate sum + // constrain that the lower_bits + less_than * 2^limb_bits is the correct intermediate sum let check_val = - lower_bits + upper_bit * AB::Expr::from_canonical_u64(1 << self.limb_bits()); + lower_bits + less_than * AB::Expr::from_canonical_u64(1 << self.limb_bits()); builder.assert_eq(intermed_val, check_val); @@ -86,10 +85,8 @@ impl SubAir for IsLessThanAir { // constrain that the shifted last limb is shifted correctly builder.assert_eq(lower_bits_decomp[num_limbs], shifted_val); - // constrain that upper_bit is a boolean - let is_bool = upper_bit * (AB::Expr::one() - upper_bit); + // constrain that less_than is a boolean + let is_bool = less_than * (AB::Expr::one() - less_than); builder.assert_zero(is_bool); - - builder.assert_eq(less_than, upper_bit); } } diff --git a/chips/src/is_less_than/columns.rs b/chips/src/is_less_than/columns.rs index c0363ae9c1..89c5d6438c 100644 --- a/chips/src/is_less_than/columns.rs +++ b/chips/src/is_less_than/columns.rs @@ -9,7 +9,6 @@ pub struct IsLessThanIOCols { pub struct IsLessThanAuxCols { pub lower_bits: T, - pub upper_bit: T, pub lower_bits_decomp: Vec, } @@ -32,18 +31,13 @@ impl IsLessThanCols { // the next element is the value of the lower num_limbs bits of the intermediate sum let lower_bits = slc[3].clone(); - // the next element is the value of the upper bit of the intermediate sum; note that - // y > x <=> upper_bit = 1 - let upper_bit = slc[4].clone(); - // the next num_limbs + 1 elements are the decomposed limbs of the lower bits of the // intermediate sum - let lower_bits_decomp = slc[5..5 + num_limbs + 1].to_vec(); + let lower_bits_decomp = slc[4..4 + num_limbs + 1].to_vec(); let io = IsLessThanIOCols { x, y, less_than }; let aux = IsLessThanAuxCols { lower_bits, - upper_bit, lower_bits_decomp, }; @@ -56,7 +50,6 @@ impl IsLessThanCols { self.io.y.clone(), self.io.less_than.clone(), self.aux.lower_bits.clone(), - self.aux.upper_bit.clone(), ]; flattened.extend(self.aux.lower_bits_decomp.iter().cloned()); flattened @@ -70,8 +63,6 @@ impl IsLessThanCols { width += 1; // for the lower_bits width += 1; - // for the upper_bit - width += 1; // for the decomposed lower_bits let num_limbs = (limb_bits + decomp - 1) / decomp; width += num_limbs + 1; diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs index df48b13de3..3409f25bcb 100644 --- a/chips/src/is_less_than/trace.rs +++ b/chips/src/is_less_than/trace.rs @@ -33,11 +33,10 @@ impl LocalTraceInstructions for IsLessThanChip { let last_limb_shift = (self.air.decomp() - (self.air.limb_bits() % self.air.decomp())) % self.air.decomp(); - // obtain the lower_bits and upper_bit + // obtain the lower_bits let check_less_than = (1 << self.air.limb_bits()) + y - x - 1; let lower_bits = F::from_canonical_u32(check_less_than & ((1 << self.air.limb_bits()) - 1)); let lower_bits_u32 = check_less_than & ((1 << self.air.limb_bits()) - 1); - let upper_bit = F::from_canonical_u32(check_less_than >> self.air.limb_bits()); // decompose lower_bits into limbs and range check let mut lower_bits_decomp: Vec = vec![]; @@ -62,7 +61,6 @@ impl LocalTraceInstructions for IsLessThanChip { }; let aux = IsLessThanAuxCols { lower_bits, - upper_bit, lower_bits_decomp, }; diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs new file mode 100644 index 0000000000..0d56e3f6cc --- /dev/null +++ b/chips/src/is_less_than_tuple/air.rs @@ -0,0 +1,127 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use crate::{ + is_less_than::{columns::IsLessThanCols, IsLessThanChip}, + sub_chip::{AirConfig, SubAir}, +}; + +use super::{ + columns::{IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols}, + IsLessThanTupleAir, IsLessThanTupleChip, +}; + +impl AirConfig for IsLessThanTupleChip { + type Cols = IsLessThanTupleCols; +} + +impl BaseAir for IsLessThanTupleChip { + fn width(&self) -> usize { + IsLessThanTupleCols::::get_width( + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.tuple_len(), + ) + } +} + +impl Air for IsLessThanTupleChip { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let local = main.row_slice(0); + let local: &[AB::Var] = (*local).borrow(); + + let local_cols = IsLessThanTupleCols::::from_slice( + local, + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.tuple_len(), + ); + + SubAir::eval(&self.air, builder, local_cols.io, local_cols.aux); + } +} + +// sub-chip with constraints to check whether one tuple is less than the another +impl SubAir for IsLessThanTupleAir { + type IoView = IsLessThanTupleIOCols; + type AuxView = IsLessThanTupleAuxCols; + + // constrain that x < y lexicographically + fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { + let x = io.x.clone(); + let y = io.y.clone(); + + for i in 0..x.len() { + let x_val = x[i]; + let y_val = y[i]; + + let is_less_than_chip_dummy = IsLessThanChip::new( + *self.bus_index(), + *self.range_max(), + self.limb_bits()[i], + *self.decomp(), + ); + + // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip + let mut is_less_than_slice = vec![x_val, y_val]; + is_less_than_slice.push(aux.less_than[i]); + is_less_than_slice.push(aux.lower_bits[i]); + is_less_than_slice.extend_from_slice(&aux.lower_bits_decomp[i]); + + let is_less_than_cols = IsLessThanCols::::from_slice( + &is_less_than_slice, + self.limb_bits()[i], + *self.decomp(), + ); + + SubAir::eval( + &is_less_than_chip_dummy.air, + builder, + is_less_than_cols.io, + is_less_than_cols.aux, + ); + } + + for i in 0..x.len() { + // constrain that diff is the difference between the two elements of consecutive rows + let diff = y[i] - x[i]; + builder.assert_eq(diff, aux.diff[i]); + } + + // together, these constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] + for i in 0..*self.tuple_len() { + let diff = aux.diff[i]; + let is_equal = aux.is_zero[i]; + let inverse = aux.inverses[i]; + + // check that diff * is_equal = 0 + builder.assert_zero(diff * is_equal); + // check that is_equal is boolean + builder.assert_zero(is_equal * (AB::Expr::one() - is_equal)); + // check that inverse * (diff + is_equal) = 1 + builder.assert_one(inverse * (diff + is_equal)); + } + + // to check whether one row is less than another, we can use the indicators to generate a boolean + // expression; the idea is that, starting at the most significant limb, a row is less than the next + // if all the limbs more significant are equal and the current limb is less than the corresponding + // limb in the next row + let mut check_less_than: AB::Expr = AB::Expr::zero(); + let less_than = aux.less_than.clone(); + + for (i, &less_than_value) in less_than.iter().enumerate() { + let mut curr_expr: AB::Expr = less_than_value.into(); + for &is_zero_value in &aux.is_zero[i + 1..] { + curr_expr *= is_zero_value.into(); + } + check_less_than += curr_expr; + } + + builder.assert_eq(io.tuple_less_than, check_less_than); + } +} diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs new file mode 100644 index 0000000000..cf79a648f9 --- /dev/null +++ b/chips/src/is_less_than_tuple/chip.rs @@ -0,0 +1,9 @@ +use crate::sub_chip::SubAirWithInteractions; +use afs_stark_backend::interaction::Chip; +use p3_field::PrimeField64; + +use super::IsLessThanTupleChip; + +impl Chip for IsLessThanTupleChip {} + +impl SubAirWithInteractions for IsLessThanTupleChip {} diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs new file mode 100644 index 0000000000..7078ad62f6 --- /dev/null +++ b/chips/src/is_less_than_tuple/columns.rs @@ -0,0 +1,158 @@ +use afs_derive::AlignedBorrow; + +#[derive(Default, AlignedBorrow)] +pub struct IsLessThanTupleIOCols { + pub x: Vec, + pub y: Vec, + pub tuple_less_than: T, +} + +pub struct IsLessThanTupleAuxCols { + pub less_than: Vec, + pub lower_bits: Vec, + pub lower_bits_decomp: Vec>, + pub diff: Vec, + pub is_zero: Vec, + pub inverses: Vec, +} + +pub struct IsLessThanTupleCols { + pub io: IsLessThanTupleIOCols, + pub aux: IsLessThanTupleAuxCols, +} + +impl IsLessThanTupleCols { + pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, tuple_len: usize) -> Self { + let mut x: Vec = vec![]; + let mut y: Vec = vec![]; + + let mut less_than: Vec = vec![]; + let mut lower_bits: Vec = vec![]; + let mut lower_bits_decomp: Vec> = vec![]; + let mut diff: Vec = vec![]; + let mut is_zero: Vec = vec![]; + let mut inverses: Vec = vec![]; + + let mut curr_start_idx = 0; + let mut curr_end_idx = tuple_len; + + // get the actual tuples, which are x and y + x.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + y.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += 1; + + // get the indicator for whether x < y, lexicographically + let tuple_less_than = slc[curr_start_idx].clone(); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + // get the indicators for whether x[i] < y[i] for all indices + less_than.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + // get the lower bits for each 2^limb_bits[i] + y[i] - x[i] - 1 + lower_bits.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + // get the lower bits decompositions + for &limb_bit in limb_bits.iter() { + let num_limbs = (limb_bit + decomp - 1) / decomp; + curr_start_idx = curr_end_idx; + curr_end_idx += num_limbs + 1; + + let mut lower_bits_curr: Vec = vec![]; + + for j in 0..(num_limbs + 1) { + lower_bits_curr.push(slc[curr_start_idx + j].clone()); + } + + lower_bits_decomp.push(lower_bits_curr); + } + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + // get the differences y[i] - x[i] + diff.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + // get whether y[i] - x[i] == 0 + is_zero.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + // get the inverses k such that k * (diff[i] + is_zero[i]) = 1 + inverses.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x, + y, + tuple_less_than, + }, + aux: IsLessThanTupleAuxCols { + less_than, + lower_bits, + lower_bits_decomp, + diff, + is_zero, + inverses, + }, + } + } + + pub fn flatten(&self) -> Vec { + let mut flattened = vec![]; + flattened.extend_from_slice(&self.io.x); + flattened.extend_from_slice(&self.io.y); + flattened.push(self.io.tuple_less_than.clone()); + flattened.extend_from_slice(&self.aux.less_than); + flattened.extend_from_slice(&self.aux.lower_bits); + for i in 0..self.aux.lower_bits_decomp.len() { + flattened.extend_from_slice(&self.aux.lower_bits_decomp[i]); + } + flattened.extend_from_slice(&self.aux.diff); + flattened.extend_from_slice(&self.aux.is_zero); + flattened.extend_from_slice(&self.aux.inverses); + + flattened + } + + pub fn get_width(limb_bits: Vec, decomp: usize, tuple_len: usize) -> usize { + let mut width = 0; + // for the x and y tuples + width += 2 * tuple_len; + // for the tuple less than indicator + width += 1; + // for the less than indicator + width += tuple_len; + // for the lower bits + width += tuple_len; + + // for the lower bits decomposition + for &limb_bit in limb_bits.iter() { + let num_limbs = (limb_bit + decomp - 1) / decomp; + width += num_limbs + 1; + } + + // for the difference between consecutive rows + width += tuple_len; + // for the indicator whether difference is zero + width += tuple_len; + // for the inverses k such that k * (diff[i] + is_zero[i]) = 1 + width += tuple_len; + + width + } +} diff --git a/chips/src/is_less_than_tuple/mod.rs b/chips/src/is_less_than_tuple/mod.rs new file mode 100644 index 0000000000..3f1c889f7a --- /dev/null +++ b/chips/src/is_less_than_tuple/mod.rs @@ -0,0 +1,66 @@ +use getset::Getters; + +use crate::is_less_than::IsLessThanChip; + +#[cfg(test)] +pub mod tests; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +#[derive(Default, Getters)] +pub struct IsLessThanTupleAir { + #[getset(get = "pub")] + bus_index: usize, + #[getset(get = "pub")] + range_max: u32, + #[getset(get = "pub")] + limb_bits: Vec, + #[getset(get = "pub")] + decomp: usize, + #[getset(get = "pub")] + tuple_len: usize, +} + +/** + * This Chip constrains that consecutive rows are sorted lexicographically. + * + * Each row consists of a key decomposed into limbs with at most limb_bits bits + */ +#[derive(Default, Getters)] +pub struct IsLessThanTupleChip { + pub air: IsLessThanTupleAir, + + pub is_less_than_chips: Vec, +} + +impl IsLessThanTupleChip { + pub fn new( + bus_index: usize, + range_max: u32, + limb_bits: Vec, + decomp: usize, + tuple_len: usize, + ) -> Self { + let air = IsLessThanTupleAir { + bus_index, + range_max, + limb_bits: limb_bits.clone(), + decomp, + tuple_len, + }; + + // create less_than_chips which will be used to compare individual tuple elements + let is_less_than_chips = limb_bits + .iter() + .map(|&limb_bit| IsLessThanChip::new(bus_index, range_max, limb_bit, decomp)) + .collect::>(); + + Self { + air, + is_less_than_chips, + } + } +} diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs new file mode 100644 index 0000000000..20fa58f146 --- /dev/null +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -0,0 +1,73 @@ +use super::super::is_less_than_tuple::IsLessThanTupleChip; + +use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::verifier::VerificationError; +use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use p3_field::AbstractField; + +#[test] +fn test_is_less_than_tuple_chip_lt() { + let bus_index: usize = 0; + let limb_bits: Vec = vec![16, 8]; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; + let tuple_len: usize = 2; + + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); + let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + + run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); +} + +#[test] +fn test_is_less_than_tuple_chip_gt() { + let bus_index: usize = 0; + let limb_bits: Vec = vec![16, 8]; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; + let tuple_len: usize = 2; + + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); + let trace = chip.generate_trace(vec![14321, 244], vec![26678, 233]); + + println!("{:?}", trace); + + run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); +} + +#[test] +fn test_is_less_than_tuple_chip_eq() { + let bus_index: usize = 0; + let limb_bits: Vec = vec![16, 8]; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; + let tuple_len: usize = 2; + + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); + let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); + + run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); +} + +#[test] +fn test_is_less_than_tuple_chip_negative() { + let bus_index: usize = 0; + let limb_bits: Vec = vec![16, 8]; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; + let tuple_len: usize = 2; + + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); + let mut trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + + trace.values[2] = AbstractField::from_canonical_u64(0); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis(vec![&chip], vec![trace],), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs new file mode 100644 index 0000000000..6c594f5756 --- /dev/null +++ b/chips/src/is_less_than_tuple/trace.rs @@ -0,0 +1,100 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use crate::sub_chip::LocalTraceInstructions; + +use super::{ + columns::{IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols}, + IsLessThanTupleChip, +}; + +impl IsLessThanTupleChip { + pub fn generate_trace(&self, x: Vec, y: Vec) -> RowMajorMatrix { + let num_cols: usize = IsLessThanTupleCols::::get_width( + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.tuple_len(), + ); + + let row: Vec = self.generate_trace_row((x, y)).flatten(); + + RowMajorMatrix::new(row, num_cols) + } +} + +impl LocalTraceInstructions for IsLessThanTupleChip { + type LocalInput = (Vec, Vec); + + fn generate_trace_row(&self, input: Self::LocalInput) -> Self::Cols { + let (x, y) = input; + + let mut less_than: Vec = vec![]; + let mut lower_bits: Vec = vec![]; + let mut lower_bits_decomp: Vec> = vec![]; + + let mut valid = true; + let mut tuple_less_than = F::zero(); + + // use subchip to generate relevant columns + for i in 0..x.len() { + let is_less_than_chip = &self.is_less_than_chips[i]; + + let curr_less_than_row = + LocalTraceInstructions::generate_trace_row(is_less_than_chip, (x[i], y[i])) + .flatten(); + less_than.push(curr_less_than_row[2]); + lower_bits.push(curr_less_than_row[3]); + lower_bits_decomp.push(curr_less_than_row[4..].to_vec()); + } + + // compute whether the x < y + for i in (0..x.len()).rev() { + if x[i] < y[i] && valid { + tuple_less_than = F::one(); + } else if x[i] > y[i] && valid { + valid = false; + } + } + + // contains the difference between consecutive rows + let mut diff: Vec = vec![]; + // contains indicator whether difference is zero + let mut is_zero: Vec = vec![]; + // contains y such that y * (i + x) = 1 + let mut inverses: Vec = vec![]; + + // we compute the indicators, which only matter if the row is not the last + for (i, &val) in x.iter().enumerate() { + let next_val = y[i]; + + // the difference between the two limbs + let curr_diff = F::from_canonical_u32(next_val) - F::from_canonical_u32(val); + diff.push(curr_diff); + + // compute the equal indicator and inverses + if next_val == val { + is_zero.push(F::one()); + inverses.push((curr_diff + F::one()).inverse()); + } else { + is_zero.push(F::zero()); + inverses.push(curr_diff.inverse()); + } + } + + let io = IsLessThanTupleIOCols { + x: x.into_iter().map(F::from_canonical_u32).collect(), + y: y.into_iter().map(F::from_canonical_u32).collect(), + tuple_less_than, + }; + let aux = IsLessThanTupleAuxCols { + less_than, + lower_bits, + lower_bits_decomp, + diff, + is_zero, + inverses, + }; + + IsLessThanTupleCols { io, aux } + } +} diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 4b154312f1..e9cd8fe456 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -1,5 +1,9 @@ pub mod assert_sorted; +pub mod is_equal; +pub mod is_equal_vec; pub mod is_less_than; +pub mod is_less_than_tuple; +pub mod is_zero; pub mod keccak_permute; pub mod less_than; pub mod merkle_proof; @@ -13,7 +17,3 @@ mod utils; pub mod xor_bits; pub mod xor_limbs; pub mod xor_lookup; - -pub mod is_equal; -pub mod is_equal_vec; -pub mod is_zero; From 50d54a49b7e90a8e6e22048ffef8174316a9292c Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 4 Jun 2024 18:31:19 -0400 Subject: [PATCH 13/46] feat: IsLessThanTupleChip subchip in AssertSortedChip --- chips/src/assert_sorted/air.rs | 99 +++++---- chips/src/assert_sorted/chip.rs | 32 ++- chips/src/assert_sorted/columns.rs | 150 ++++++++------ chips/src/assert_sorted/mod.rs | 49 +++-- chips/src/assert_sorted/tests/mod.rs | 237 ++++++++++++++-------- chips/src/assert_sorted/trace.rs | 58 +++--- chips/src/is_less_than_tuple/chip.rs | 56 ++++- chips/src/is_less_than_tuple/tests/mod.rs | 67 +++++- chips/src/less_than/air.rs | 146 ------------- chips/src/less_than/chip.rs | 55 ----- chips/src/less_than/columns.rs | 120 ----------- chips/src/less_than/mod.rs | 64 ------ chips/src/less_than/tests/mod.rs | 139 ------------- chips/src/less_than/trace.rs | 127 ------------ chips/src/lib.rs | 1 - 15 files changed, 488 insertions(+), 912 deletions(-) delete mode 100644 chips/src/less_than/air.rs delete mode 100644 chips/src/less_than/chip.rs delete mode 100644 chips/src/less_than/columns.rs delete mode 100644 chips/src/less_than/mod.rs delete mode 100644 chips/src/less_than/tests/mod.rs delete mode 100644 chips/src/less_than/trace.rs diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index 2b976b0322..dc47aabd62 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -4,7 +4,7 @@ use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; -use crate::less_than::columns::LessThanCols; +use crate::is_less_than_tuple::columns::IsLessThanTupleCols; use crate::sub_chip::SubAir; use super::columns::AssertSortedCols; @@ -13,9 +13,9 @@ use super::AssertSortedChip; impl BaseAir for AssertSortedChip { fn width(&self) -> usize { AssertSortedCols::::get_width( - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ) } } @@ -24,33 +24,41 @@ impl Air for AssertSortedChip { fn eval(&self, builder: &mut AB) { let main = builder.main(); + // get the current row and the next row let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &[AB::Var] = (*local).borrow(); + let next: &[AB::Var] = (*next).borrow(); let local_cols = AssertSortedCols::::from_slice( local, - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ); - let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() - - 1) - / *self.less_than_chip.air.decomp(); - let key_len = *self.less_than_chip.air.key_vec_len(); + let next_cols = AssertSortedCols::::from_slice( + next, + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), + ); - // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in - // the correct range - let last_limb_shift = (*self.less_than_chip.air.decomp() - - (*self.less_than_chip.air.limb_bits() % *self.less_than_chip.air.decomp())) - % *self.less_than_chip.air.decomp(); + let key_len = *self.air.key_vec_len(); for i in 0..key_len { let mut key_from_limbs: AB::Expr = AB::Expr::zero(); + + // num_limbs is the number of sublimbs the current limb should be decomposed into + let num_limbs = (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); + // to range check the last sublimb, we need to shift it + let last_limb_shift = (self.air.decomp() + - (self.air.limb_bits()[i] % self.air.decomp())) + % self.air.decomp(); + // constrain that the decomposition is correct for j in 0..num_limbs { key_from_limbs += local_cols.keys_decomp[i][j] - * AB::Expr::from_canonical_u64(1 << (j * self.less_than_chip.air.decomp())); + * AB::Expr::from_canonical_u64(1 << (j * self.air.decomp())); } // constrain that the shifted last sublimb is shifted correctly @@ -58,36 +66,49 @@ impl Air for AssertSortedChip { * AB::Expr::from_canonical_u64(1 << last_limb_shift); builder.assert_eq(local_cols.keys_decomp[i][num_limbs], shifted_val); - builder.assert_eq(key_from_limbs, local_cols.less_than_cols.io.key[i]); + builder.assert_eq(key_from_limbs, local_cols.is_less_than_tuple_cols.io.x[i]); + + // constrain that the keys are consistent across rows + builder.when_transition().assert_eq( + local_cols.is_less_than_tuple_cols.io.y[i], + next_cols.is_less_than_tuple_cols.io.x[i], + ); } - // generate LessThanCols struct for current row and next row - let mut local_slice: Vec = local[0..key_len].to_vec(); - local_slice.extend_from_slice(&local[key_len * (num_limbs + 2)..]); + // constrain that the current key is less than the next + builder + .when_transition() + .assert_one(local_cols.is_less_than_tuple_cols.io.tuple_less_than); - let mut next_slice: Vec = next[0..key_len].to_vec(); - next_slice - .extend_from_slice(&next[(self.less_than_chip.air.key_vec_len() * (num_limbs + 2))..]); + // generate IsLessThanTupleCols struct for current row and next row + let mut curr_start_idx = 0; + let mut curr_end_idx = 2 * key_len; + // get the current key and next key + let mut local_slice: Vec = local[curr_start_idx..curr_end_idx].to_vec(); - let local_cols = LessThanCols::::from_slice( - &local_slice, - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), - ); + // skip the key decomposition + for i in 0..key_len { + let num_limbs = (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); + curr_end_idx += num_limbs + 1; + } + + // get the rest of the columns + curr_start_idx = curr_end_idx; - let next_cols = LessThanCols::::from_slice( - &next_slice, - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + local_slice.extend_from_slice(&local[curr_start_idx..]); + + let local_cols = IsLessThanTupleCols::::from_slice( + &local_slice, + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ); - // constrain the current row is less than the next row + // constrain the indicator that we used to check whether the current key < next key is correct SubAir::eval( - &self.less_than_chip.air, - builder, - [local_cols.io, next_cols.io], + &self.is_less_than_tuple_chip.air, + &mut builder.when_transition(), + local_cols.io, local_cols.aux, ); } diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index 5551445cb5..b767d0b14b 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -10,42 +10,40 @@ use super::AssertSortedChip; impl Chip for AssertSortedChip { fn sends(&self) -> Vec> { let num_cols = AssertSortedCols::::get_width( - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ); let all_cols = (0..num_cols).collect::>(); let cols_numbered = AssertSortedCols::::from_slice( &all_cols, - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ); let mut interactions: Vec> = vec![]; - let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() - - 1) - / *self.less_than_chip.air.decomp(); - let num_keys = *self.less_than_chip.air.key_vec_len(); - // we will range check the decomposed limbs of the key - for i in 0..num_keys { + for i in 0..*self.air.key_vec_len() { + let num_limbs = (self.air.limb_bits()[i] + *self.air.decomp() - 1) / *self.air.decomp(); // add 1 to account for the shifted last sublimb for j in 0..(num_limbs + 1) { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main(cols_numbered.keys_decomp[i][j])], count: VirtualPairCol::constant(F::one()), - argument_index: *self.less_than_chip.bus_index(), + argument_index: self.range_checker_gate.bus_index(), }); } } - // append the interactions from the subchip - let mut less_than_interactions: Vec> = - SubAirWithInteractions::::sends(&self.less_than_chip, cols_numbered.less_than_cols); - interactions.append(&mut less_than_interactions); + let subchip_interactions = SubAirWithInteractions::::sends( + &self.is_less_than_tuple_chip, + cols_numbered.is_less_than_tuple_cols, + ); + + interactions.extend(subchip_interactions); interactions } diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index b9b59e6701..92b44812c4 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -1,105 +1,139 @@ use afs_derive::AlignedBorrow; -use crate::less_than::columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}; +use crate::is_less_than_tuple::columns::{ + IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, +}; // Since AssertSortedChip contains a LessThanChip subchip, a subset of the columns are those of the // LessThanChip #[derive(AlignedBorrow)] pub struct AssertSortedCols { pub keys_decomp: Vec>, - pub less_than_cols: LessThanCols, + pub is_less_than_tuple_cols: IsLessThanTupleCols, } impl AssertSortedCols { - pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { + pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, key_vec_len: usize) -> Self { // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (limb_bits + decomp - 1) / decomp; - let mut cur_start_idx = 0; - let mut cur_end_idx = key_vec_len; + let mut curr_start_idx = 0; + let mut curr_end_idx = key_vec_len; // the first key_vec_len elements are the key itself - let key = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed keys (with each having - // an extra shifted last sublimb) - let keys_decomp = slc[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum - let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the values of the upper bit of the intermediate sum; note that - // b > a <=> upper_bit = 1 - let upper_bit = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed limbs of the lower bits of the - // intermediate sum - let lower_bits_decomp = slc[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; + let x = slc[curr_start_idx..curr_end_idx].to_vec(); + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + // the next key_vec_len elements are the next key (the following row) + let y = slc[curr_start_idx..curr_end_idx].to_vec(); + + // the next elements are the decomposed key (with each having an extra shifted last sublimb) + let mut keys_decomp: Vec> = vec![]; + + for curr_limb_bits in limb_bits.iter() { + let num_limbs = (curr_limb_bits + decomp - 1) / decomp; + + curr_start_idx = curr_end_idx; + curr_end_idx += num_limbs + 1; + + keys_decomp.push(slc[curr_start_idx..curr_end_idx].to_vec()); + } + + curr_start_idx = curr_end_idx; + curr_end_idx += 1; + + // the next element is the indicator for whether the key is less than the next key + let tuple_less_than = slc[curr_start_idx].clone(); + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + // the next key_vec_len elements are the indicators for the individual tuple element less thans + let less_than = slc[curr_start_idx..curr_end_idx].to_vec(); + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + // the next key_vec_len elements are the values of the lower bits of each intermediate sum + // (i.e. 2^limb_bits[i] + y[i] - x[i] - 1) + let lower_bits = slc[curr_start_idx..curr_end_idx].to_vec(); + + // the next elements are the decomposed lower bits + let mut lower_bits_decomp: Vec> = vec![]; + for curr_limb_bits in limb_bits.iter() { + let num_limbs = (curr_limb_bits + decomp - 1) / decomp; + + curr_start_idx = curr_end_idx; + curr_end_idx += num_limbs + 1; + + lower_bits_decomp.push(slc[curr_start_idx..curr_end_idx].to_vec()); + } + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; // the next key_vec_len elements are the difference between consecutive limbs of rows - let diff = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; + let diff = slc[curr_start_idx..curr_end_idx].to_vec(); + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; // the next key_vec_len elements are the indicator whether the difference is zero; if difference is // zero then the two limbs must be equal - let is_zero = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; + let is_zero = slc[curr_start_idx..curr_end_idx].to_vec(); + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; // the next key_vec_len elements contain the inverses of the corresponding sum of diff and is_zero; // note that this sum will always be nonzero so the inverse will exist - let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); + let inverses = slc[curr_start_idx..curr_end_idx].to_vec(); - let io = LessThanIOCols { key }; - let aux = LessThanAuxCols { + let io = IsLessThanTupleIOCols { + x, + y, + tuple_less_than, + }; + let aux = IsLessThanTupleAuxCols { + less_than, lower_bits, - upper_bit, lower_bits_decomp, diff, is_zero, inverses, }; - let less_than_cols = LessThanCols { io, aux }; + let is_less_than_tuple_cols = IsLessThanTupleCols { io, aux }; Self { keys_decomp, - less_than_cols, + is_less_than_tuple_cols, } } - pub fn get_width(limb_bits: usize, decomp: usize, key_vec_len: usize) -> usize { + pub fn get_width(limb_bits: Vec, decomp: usize, key_vec_len: usize) -> usize { // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to // account for the sublimb itself, and another 1 to account for the shifted // last sublimb let mut width = 0; - // for the key itself + // for the x and y keys + width += 2 * key_vec_len; + + // for the decomposed key + for &limb_bit in limb_bits.iter() { + let num_limbs = (limb_bit + decomp - 1) / decomp; + width += num_limbs + 1; + } + + // for the tuple less than indicator + width += 1; + + // for the less_than indicators width += key_vec_len; - // for the decomposed keys - let num_limbs = (limb_bits + decomp - 1) / decomp; - width += key_vec_len * (num_limbs + 1); + // for the lower_bits width += key_vec_len; - // for the upper_bit - width += key_vec_len; + // for the decomposed lower_bits - width += key_vec_len * (num_limbs + 1); + for &limb_bit in limb_bits.iter() { + let num_limbs = (limb_bit + decomp - 1) / decomp; + width += num_limbs + 1; + } + // for the difference between consecutive rows width += key_vec_len; // for the indicator whether difference is zero diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 098600fa86..798491b4f7 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,4 +1,4 @@ -use crate::less_than::LessThanChip; +use crate::{is_less_than_tuple::IsLessThanTupleChip, range_gate::RangeCheckerGateChip}; use getset::Getters; use afs_stark_backend::interaction::Interaction; @@ -14,6 +14,22 @@ pub mod chip; pub mod columns; pub mod trace; +#[derive(Default, Getters)] +pub struct AssertedSortedAir { + #[getset(get = "pub")] + bus_index: usize, + #[getset(get = "pub")] + range_max: u32, + #[getset(get = "pub")] + limb_bits: Vec, + #[getset(get = "pub")] + decomp: usize, + #[getset(get = "pub")] + key_vec_len: usize, + #[getset(get = "pub")] + keys: Vec>, +} + /** * This Chip constrains that consecutive rows are sorted lexicographically. * @@ -26,32 +42,39 @@ pub mod trace; * The AssertSortedChip contains a LessThanChip subchip, which is used to constrain * that the rows are sorted lexicographically. */ -#[derive(Default, Getters)] +#[derive(Default)] pub struct AssertSortedChip { - #[getset(get = "pub")] - range_max: u32, - less_than_chip: LessThanChip, + air: AssertedSortedAir, + is_less_than_tuple_chip: IsLessThanTupleChip, + range_checker_gate: RangeCheckerGateChip, } impl AssertSortedChip { pub fn new( bus_index: usize, range_max: u32, - limb_bits: usize, + limb_bits: Vec, decomp: usize, key_vec_len: usize, keys: Vec>, ) -> Self { Self { - range_max, - less_than_chip: LessThanChip::new( + air: AssertedSortedAir { bus_index, range_max, - limb_bits, + limb_bits: limb_bits.clone(), decomp, key_vec_len, keys, + }, + is_less_than_tuple_chip: IsLessThanTupleChip::new( + bus_index, + range_max, + limb_bits, + decomp, + key_vec_len, ), + range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), } } @@ -61,21 +84,19 @@ impl AssertSortedChip { ) -> Vec> { // num_limbs is the number of sublimbs per limb of key, not including the // shifted last sublimb - let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() - - 1) - / *self.less_than_chip.air.decomp(); - let num_keys = *self.less_than_chip.air.key_vec_len(); + let num_keys = *self.air.key_vec_len(); let mut interactions = vec![]; // we will range check the decomposed limbs of the key for i in 0..num_keys { + let num_limbs = (self.air.limb_bits()[i] + *self.air.decomp() - 1) / *self.air.decomp(); // add 1 to account for the shifted last sublimb for j in 0..(num_limbs + 1) { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main(cols.keys_decomp[i][j])], count: VirtualPairCol::constant(F::one()), - argument_index: *self.less_than_chip.bus_index(), + argument_index: *self.air.bus_index(), }); } } diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index 03524522bd..aeb029a431 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -1,6 +1,7 @@ use super::super::assert_sorted; use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; use assert_sorted::AssertSortedChip; @@ -33,50 +34,64 @@ use p3_matrix::dense::DenseMatrix; // most limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_small_positive() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 2; + let bus_index: usize = 0; + let limb_bits: Vec = vec![16, 16]; + let decomp: usize = 8; + let key_vec_len: usize = 2; - const MAX: u32 = 1 << DECOMP; + let range_max: u32 = 1 << decomp; let requests = vec![vec![7784, 35423], vec![17558, 44832]]; let assert_sorted_chip = AssertSortedChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, + bus_index, + range_max, + limb_bits, + decomp, + key_vec_len, requests.clone(), ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip - .less_than_chip - .range_checker_gate - .generate_trace(); - - run_simple_test_no_pis( - vec![ - &assert_sorted_chip, - &assert_sorted_chip.less_than_chip.range_checker_gate, - ], - vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], - ) - .expect("Verification failed"); + let assert_sorted_range_chip_trace: DenseMatrix = + assert_sorted_chip.range_checker_gate.generate_trace(); + + let mut chips: Vec<&dyn AnyRap<_>> = + vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + run_simple_test_no_pis(chips, traces).expect("Verification failed"); } // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at // most limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_large_positive() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 30; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; + let bus_index: usize = 0; + let limb_bits: Vec = vec![30, 30, 30, 30]; + let decomp: usize = 8; + let key_vec_len: usize = 4; - const MAX: u32 = 1 << DECOMP; + let range_max: u32 = 1 << decomp; let requests = vec![ vec![35867, 318434, 12786, 44832], @@ -86,40 +101,54 @@ fn test_assert_sorted_chip_large_positive() { ]; let assert_sorted_chip = AssertSortedChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, + bus_index, + range_max, + limb_bits, + decomp, + key_vec_len, requests.clone(), ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip - .less_than_chip - .range_checker_gate - .generate_trace(); - - run_simple_test_no_pis( - vec![ - &assert_sorted_chip, - &assert_sorted_chip.less_than_chip.range_checker_gate, - ], - vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], - ) - .expect("Verification failed"); + let assert_sorted_range_chip_trace: DenseMatrix = + assert_sorted_chip.range_checker_gate.generate_trace(); + + let mut chips: Vec<&dyn AnyRap<_>> = + vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + run_simple_test_no_pis(chips, traces).expect("Verification failed"); } // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, at least one limb // has more than limb_bits bits, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_largelimb_negative() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 10; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; + let bus_index: usize = 0; + let limb_bits: Vec = vec![10, 10, 10, 10]; + let decomp: usize = 8; + let key_vec_len: usize = 4; - const MAX: u32 = 1 << DECOMP; + let range_max: u32 = 1 << decomp; // the first and second rows are not in sorted order let requests = vec![ @@ -130,27 +159,42 @@ fn test_assert_sorted_chip_largelimb_negative() { ]; let assert_sorted_chip = AssertSortedChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, + bus_index, + range_max, + limb_bits, + decomp, + key_vec_len, requests.clone(), ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip - .less_than_chip - .range_checker_gate - .generate_trace(); - - let result = run_simple_test_no_pis( - vec![ - &assert_sorted_chip, - &assert_sorted_chip.less_than_chip.range_checker_gate, - ], - vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], - ); + let assert_sorted_range_chip_trace: DenseMatrix = + assert_sorted_chip.range_checker_gate.generate_trace(); + + let mut chips: Vec<&dyn AnyRap<_>> = + vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + let result = run_simple_test_no_pis(chips, traces); assert_eq!( result, @@ -163,12 +207,12 @@ fn test_assert_sorted_chip_largelimb_negative() { // most limb_bits bits, rows are not sorted lexicographically #[test] fn test_assert_sorted_chip_unsorted_negative() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 30; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; + let bus_index: usize = 0; + let limb_bits: Vec = vec![30, 30, 30, 30]; + let decomp: usize = 8; + let key_vec_len: usize = 4; - const MAX: u32 = 1 << DECOMP; + let range_max: u32 = 1 << decomp; // the first and second rows are not in sorted order let requests = vec![ @@ -179,31 +223,46 @@ fn test_assert_sorted_chip_unsorted_negative() { ]; let assert_sorted_chip = AssertSortedChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, + bus_index, + range_max, + limb_bits, + decomp, + key_vec_len, requests.clone(), ); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = assert_sorted_chip - .less_than_chip - .range_checker_gate - .generate_trace(); + let assert_sorted_range_chip_trace: DenseMatrix = + assert_sorted_chip.range_checker_gate.generate_trace(); + + let mut chips: Vec<&dyn AnyRap<_>> = + vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; + + for is_less_than_chip in assert_sorted_chip + .is_less_than_tuple_chip + .is_less_than_chips + .iter() + { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; }); assert_eq!( - run_simple_test_no_pis( - vec![ - &assert_sorted_chip, - &assert_sorted_chip.less_than_chip.range_checker_gate, - ], - vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace], - ), + run_simple_test_no_pis(chips, traces), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" ); diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index f5c8de8f6e..4465e45577 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -8,61 +8,53 @@ use super::{columns::AssertSortedCols, AssertSortedChip}; impl AssertSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = AssertSortedCols::::get_width( - *self.less_than_chip.air.limb_bits(), - *self.less_than_chip.air.decomp(), - *self.less_than_chip.air.key_vec_len(), + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.key_vec_len(), ); - let num_limbs = (*self.less_than_chip.air.limb_bits() + *self.less_than_chip.air.decomp() - - 1) - / *self.less_than_chip.air.decomp(); - - // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in - // the correct range - let last_limb_shift = (self.less_than_chip.air.decomp() - - (self.less_than_chip.air.limb_bits() % self.less_than_chip.air.decomp())) - % self.less_than_chip.air.decomp(); - let mut rows: Vec = vec![]; - for i in 0..*self.less_than_chip.air.key_vec_len() { - let key = self.less_than_chip.air.keys()[i].clone(); - let next_key: Vec = if i == *self.less_than_chip.air.key_vec_len() - 1 { - vec![0; *self.less_than_chip.air.key_vec_len()] + for i in 0..*self.air.key_vec_len() { + let key = self.air.keys()[i].clone(); + let next_key: Vec = if i == *self.air.key_vec_len() - 1 { + vec![0; *self.air.key_vec_len()] } else { - self.less_than_chip.air.keys()[i + 1].clone() + self.air.keys()[i + 1].clone() }; - let less_than_trace = LocalTraceInstructions::generate_trace_row( - &self.less_than_chip, + let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( + &self.is_less_than_tuple_chip, (key.clone(), next_key.clone()), ) .flatten(); let mut key_decomp_trace: Vec = vec![]; // decompose each limb into sublimbs of size self.decomp() bits - for &val in key.iter() { + for (i, &val) in key.iter().enumerate() { + let num_limbs = + (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); + let last_limb_shift = (self.air.decomp() + - (self.air.limb_bits()[i] % self.air.decomp())) + % self.air.decomp(); + for i in 0..num_limbs { - let bits = (val >> (i * self.less_than_chip.air.decomp())) - & ((1 << self.less_than_chip.air.decomp()) - 1); + let bits = (val >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); key_decomp_trace.push(F::from_canonical_u32(bits)); - self.less_than_chip.range_checker_gate.add_count(bits); + self.range_checker_gate.add_count(bits); } // the last sublimb should be of size self.limb_bits() % self.decomp() bits, // so we need to shift it to constrain this - let bits = (val >> ((num_limbs - 1) * self.less_than_chip.air.decomp())) - & ((1 << self.less_than_chip.air.decomp()) - 1); - if (bits << last_limb_shift) < *self.range_max() { - self.less_than_chip - .range_checker_gate - .add_count(bits << last_limb_shift); + let bits = + (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); + if (bits << last_limb_shift) < *self.air.range_max() { + self.range_checker_gate.add_count(bits << last_limb_shift); } key_decomp_trace.push(F::from_canonical_u32(bits << last_limb_shift)); } - let mut row: Vec = - less_than_trace[0..*self.less_than_chip.air.key_vec_len()].to_vec(); + let mut row: Vec = is_less_than_tuple_trace[0..2 * *self.air.key_vec_len()].to_vec(); row.extend_from_slice(&key_decomp_trace); - row.extend_from_slice(&less_than_trace[*self.less_than_chip.air.key_vec_len()..]); + row.extend_from_slice(&is_less_than_tuple_trace[2 * *self.air.key_vec_len()..]); rows.extend_from_slice(&row); } diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index cf79a648f9..5eb6d30f1c 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -1,9 +1,55 @@ -use crate::sub_chip::SubAirWithInteractions; -use afs_stark_backend::interaction::Chip; +use crate::{is_less_than::columns::IsLessThanCols, sub_chip::SubAirWithInteractions}; +use afs_stark_backend::interaction::{Chip, Interaction}; use p3_field::PrimeField64; -use super::IsLessThanTupleChip; +use super::{columns::IsLessThanTupleCols, IsLessThanTupleChip}; -impl Chip for IsLessThanTupleChip {} +impl Chip for IsLessThanTupleChip { + fn sends(&self) -> Vec> { + let num_cols = IsLessThanTupleCols::::get_width( + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.tuple_len(), + ); + let all_cols = (0..num_cols).collect::>(); -impl SubAirWithInteractions for IsLessThanTupleChip {} + let cols_numbered = IsLessThanTupleCols::::from_slice( + &all_cols, + self.air.limb_bits().clone(), + *self.air.decomp(), + *self.air.tuple_len(), + ); + + SubAirWithInteractions::sends(self, cols_numbered) + } +} + +impl SubAirWithInteractions for IsLessThanTupleChip { + fn sends(&self, col_indices: IsLessThanTupleCols) -> Vec> { + // num_limbs is the number of limbs, not including the last shifted limb + let mut interactions = vec![]; + + for i in 0..*self.air.tuple_len() { + let mut is_less_than_cols = vec![ + col_indices.io.x[i], + col_indices.io.y[i], + col_indices.aux.less_than[i], + col_indices.aux.lower_bits[i], + ]; + + is_less_than_cols.extend_from_slice(&col_indices.aux.lower_bits_decomp[i]); + + let is_less_than_cols = IsLessThanCols::::from_slice( + &is_less_than_cols, + self.air.limb_bits().clone()[i], + *self.air.decomp(), + ); + + let curr_interactions = + SubAirWithInteractions::::sends(&self.is_less_than_chips[i], is_less_than_cols); + interactions.extend(curr_interactions); + } + + interactions + } +} diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 20fa58f146..fa80af1761 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -1,9 +1,12 @@ use super::super::is_less_than_tuple::IsLessThanTupleChip; use afs_stark_backend::prover::USE_DEBUG_BUILDER; +use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; +use p3_baby_bear::BabyBear; use p3_field::AbstractField; +use p3_matrix::dense::DenseMatrix; #[test] fn test_is_less_than_tuple_chip_lt() { @@ -16,7 +19,21 @@ fn test_is_less_than_tuple_chip_lt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); - run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); + let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![trace]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + run_simple_test_no_pis(chips, traces).expect("Verification failed"); } #[test] @@ -30,9 +47,21 @@ fn test_is_less_than_tuple_chip_gt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); let trace = chip.generate_trace(vec![14321, 244], vec![26678, 233]); - println!("{:?}", trace); + let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; - run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); + for is_less_than_chip in chip.is_less_than_chips.iter() { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![trace]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + run_simple_test_no_pis(chips, traces).expect("Verification failed"); } #[test] @@ -46,7 +75,21 @@ fn test_is_less_than_tuple_chip_eq() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); - run_simple_test_no_pis(vec![&chip], vec![trace]).expect("Verification failed"); + let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![trace]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + + run_simple_test_no_pis(chips, traces).expect("Verification failed"); } #[test] @@ -62,11 +105,25 @@ fn test_is_less_than_tuple_chip_negative() { trace.values[2] = AbstractField::from_canonical_u64(0); + let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + chips.push(&is_less_than_chip.range_checker_gate); + } + + let mut traces = vec![trace]; + + for is_less_than_chip in chip.is_less_than_chips.iter() { + let range_trace: DenseMatrix = + is_less_than_chip.range_checker_gate.generate_trace(); + traces.push(range_trace); + } + USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; }); assert_eq!( - run_simple_test_no_pis(vec![&chip], vec![trace],), + run_simple_test_no_pis(chips, traces), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" ); diff --git a/chips/src/less_than/air.rs b/chips/src/less_than/air.rs deleted file mode 100644 index 1623f684b8..0000000000 --- a/chips/src/less_than/air.rs +++ /dev/null @@ -1,146 +0,0 @@ -use std::borrow::Borrow; - -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; - -use crate::sub_chip::{AirConfig, SubAir}; - -use super::{ - columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}, - LessThanAir, LessThanChip, -}; - -impl AirConfig for LessThanChip { - type Cols = LessThanCols; -} - -impl BaseAir for LessThanChip { - fn width(&self) -> usize { - LessThanCols::::get_width( - *self.air.limb_bits(), - *self.air.decomp(), - *self.air.key_vec_len(), - ) - } -} - -impl Air for LessThanChip { - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local: &[AB::Var] = (*local).borrow(); - let next: &[AB::Var] = (*next).borrow(); - - let [local_cols, next_cols] = [local, next].map(|view| { - LessThanCols::::from_slice( - view, - *self.air.limb_bits(), - *self.air.decomp(), - *self.air.key_vec_len(), - ) - }); - - SubAir::eval( - &self.air, - builder, - [local_cols.io, next_cols.io], - local_cols.aux, - ); - } -} - -// sub-chip with constraints to check whether one key is less than the next (row-wise) -impl SubAir for LessThanAir { - type IoView = [LessThanIOCols; 2]; - type AuxView = LessThanAuxCols; - - // constrain that local_key < next_key lexicographically - fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { - let local_key = io[0].key.clone(); - let next_key = io[1].key.clone(); - - let local_aux = &aux; - - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - - let lower_bits = &local_aux.lower_bits; - let upper_bit = &local_aux.upper_bit; - let lower_bits_decomp = &local_aux.lower_bits_decomp; - - // we want to check these constraints for each row except the last one - let mut when_transition = builder.when_transition(); - - // to range check the last sublimb of the decomposed limb, we need to shift it to make sure it is in - // the correct range - let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); - - for (i, (key_local, key_next)) in local_key.iter().zip(next_key.iter()).enumerate() { - // this is the desired intermediate value (i.e. 2^limb_bits + b - a - 1) - let intermed_val = *key_next - *key_local - + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - - AB::Expr::one(); - - // constrain that lower_bits[i] + upper_bit[i] * 2^limb_bits is the correct intermediate sum - let check_val = - lower_bits[i] + upper_bit[i] * AB::Expr::from_canonical_u64(1 << self.limb_bits()); - when_transition.assert_eq(intermed_val, check_val); - - // constrain that diff is the difference between the two elements of consecutive rows - let diff = *key_next - *key_local; - when_transition.assert_eq(diff, local_aux.diff[i]); - } - - for i in 0..*self.key_vec_len() { - let mut lower_bits_from_decomp: AB::Expr = AB::Expr::zero(); - // constrain that the decomposition of each lower_bits element is correct - for j in 0..num_limbs { - lower_bits_from_decomp += lower_bits_decomp[i][j] - * AB::Expr::from_canonical_u64(1 << (j * self.decomp())); - } - - // constrain that the shifted last limb is shifted correctly - let shifted_val = lower_bits_decomp[i][num_limbs - 1] - * AB::Expr::from_canonical_u64(1 << last_limb_shift); - - when_transition.assert_eq(lower_bits_decomp[i][num_limbs], shifted_val); - when_transition.assert_eq(lower_bits_from_decomp, lower_bits[i]); - } - - for upper_bit_value in upper_bit { - // constrain that each element in upper_bit is a boolean - let is_bool = *upper_bit_value * (AB::Expr::one() - *upper_bit_value); - when_transition.assert_zero(is_bool); - } - - for i in 0..*self.key_vec_len() { - let diff = local_aux.diff[i]; - let is_equal = local_aux.is_zero[i]; - let inverse = local_aux.inverses[i]; - - // check that diff * is_equal = 0 - when_transition.assert_zero(diff * is_equal); - // check that is_equal is boolean - when_transition.assert_zero(is_equal * (AB::Expr::one() - is_equal)); - // check that inverse * (diff + is_equal) = 1 - when_transition.assert_one(inverse * (diff + is_equal)); - } - - // to check whether one row is less than another, we can use the indicators to generate a boolean - // expression; the idea is that, starting at the most significant limb, a row is less than the next - // if all the limbs more significant are equal and the current limb is less than the corresponding - // limb in the next row - let mut check_less_than: AB::Expr = AB::Expr::zero(); - - for (i, &upper_bit_value) in upper_bit.iter().enumerate() { - let mut curr_expr: AB::Expr = upper_bit_value.into(); - for &is_zero_value in &local_aux.is_zero[i + 1..] { - curr_expr *= is_zero_value.into(); - } - check_less_than += curr_expr; - } - when_transition.assert_one(check_less_than); - } -} diff --git a/chips/src/less_than/chip.rs b/chips/src/less_than/chip.rs deleted file mode 100644 index 726f352e3c..0000000000 --- a/chips/src/less_than/chip.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::sub_chip::SubAirWithInteractions; - -use super::columns::LessThanCols; -use afs_stark_backend::interaction::{Chip, Interaction}; -use p3_air::VirtualPairCol; -use p3_field::PrimeField64; - -use super::LessThanChip; - -impl Chip for LessThanChip { - fn sends(&self) -> Vec> { - let num_cols = LessThanCols::::get_width( - *self.air.limb_bits(), - *self.air.decomp(), - *self.air.key_vec_len(), - ); - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = LessThanCols::::from_slice( - &all_cols, - *self.air.limb_bits(), - *self.air.decomp(), - *self.air.key_vec_len(), - ); - - SubAirWithInteractions::sends(self, cols_numbered) - } -} - -impl SubAirWithInteractions for LessThanChip { - fn sends(&self, col_indices: LessThanCols) -> Vec> { - // num_limbs is the number of sublimbs per limb of key, not including the - // shifted last sublimb - let num_limbs = (*self.air.limb_bits() + *self.air.decomp() - 1) / *self.air.decomp(); - let num_keys = *self.air.key_vec_len(); - - let mut interactions = vec![]; - - // we range check the limbs of the lower_bits so that we know each element - // of lower_bits has at most limb_bits bits - for i in 0..num_keys { - for j in 0..(num_limbs + 1) { - interactions.push(Interaction { - fields: vec![VirtualPairCol::single_main( - col_indices.aux.lower_bits_decomp[i][j], - )], - count: VirtualPairCol::constant(F::one()), - argument_index: *self.bus_index(), - }); - } - } - - interactions - } -} diff --git a/chips/src/less_than/columns.rs b/chips/src/less_than/columns.rs deleted file mode 100644 index 5bd687eb2a..0000000000 --- a/chips/src/less_than/columns.rs +++ /dev/null @@ -1,120 +0,0 @@ -use afs_derive::AlignedBorrow; - -#[derive(Default, AlignedBorrow)] -pub struct LessThanIOCols { - pub key: Vec, -} - -pub struct LessThanAuxCols { - pub lower_bits: Vec, - pub upper_bit: Vec, - pub lower_bits_decomp: Vec>, - pub diff: Vec, - pub is_zero: Vec, - pub inverses: Vec, -} - -pub struct LessThanCols { - pub io: LessThanIOCols, - pub aux: LessThanAuxCols, -} - -impl LessThanCols { - pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize, key_vec_len: usize) -> Self { - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb - let num_limbs = (limb_bits + decomp - 1) / decomp; - let mut cur_start_idx = 0; - let mut cur_end_idx = key_vec_len; - - // the first key_vec_len elements are the key itself - let key = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the values of the lower num_limbs bits of the intermediate sum - let lower_bits = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the values of the upper bit of the intermediate sum; note that - // b > a <=> upper_bit = 1 - let upper_bit = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len * (num_limbs + 1); - - // the next key_vec_len * (num_limbs + 1) elements are the decomposed limbs of the lower bits of the - // intermediate sum - let lower_bits_decomp = slc[cur_start_idx..cur_end_idx] - .chunks(num_limbs + 1) - .map(|chunk| chunk.to_vec()) - .collect(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the difference between consecutive limbs of rows - let diff = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements are the indicator whether the difference is zero; if difference is - // zero then the two limbs must be equal - let is_zero = slc[cur_start_idx..cur_end_idx].to_vec(); - cur_start_idx = cur_end_idx; - cur_end_idx += key_vec_len; - - // the next key_vec_len elements contain the inverses of the corresponding sum of diff and is_zero; - // note that this sum will always be nonzero so the inverse will exist - let inverses = slc[cur_start_idx..cur_end_idx].to_vec(); - - let io = LessThanIOCols { key }; - let aux = LessThanAuxCols { - lower_bits, - upper_bit, - lower_bits_decomp, - diff, - is_zero, - inverses, - }; - - Self { io, aux } - } - - pub fn flatten(&self) -> Vec { - let mut flattened = vec![]; - flattened.extend_from_slice(&self.io.key); - flattened.extend_from_slice(&self.aux.lower_bits); - flattened.extend_from_slice(&self.aux.upper_bit); - for decomp_vec in &self.aux.lower_bits_decomp { - flattened.extend_from_slice(decomp_vec); - } - flattened.extend_from_slice(&self.aux.diff); - flattened.extend_from_slice(&self.aux.is_zero); - flattened.extend_from_slice(&self.aux.inverses); - - flattened - } - - pub fn get_width(limb_bits: usize, decomp: usize, key_vec_len: usize) -> usize { - // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to - // account for the sublimb itself, and another 1 to account for the shifted - // last sublimb - let mut width = 0; - // for the key itself - width += key_vec_len; - // for the lower_bits - width += key_vec_len; - // for the upper_bit - width += key_vec_len; - // for the decomposed lower_bits - let num_limbs = (limb_bits + decomp - 1) / decomp; - width += key_vec_len * (num_limbs + 1); - // for the difference between consecutive rows - width += key_vec_len; - // for the indicator whether difference is zero - width += key_vec_len; - // for the y such that y * (i + x) = 1 - width += key_vec_len; - - width - } -} diff --git a/chips/src/less_than/mod.rs b/chips/src/less_than/mod.rs deleted file mode 100644 index 28643381a3..0000000000 --- a/chips/src/less_than/mod.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::range_gate::RangeCheckerGateChip; -use getset::Getters; - -#[cfg(test)] -pub mod tests; - -pub mod air; -pub mod chip; -pub mod columns; -pub mod trace; - -#[derive(Default, Getters)] -pub struct LessThanAir { - #[getset(get = "pub")] - range_max: u32, - #[getset(get = "pub")] - limb_bits: usize, - #[getset(get = "pub")] - decomp: usize, - #[getset(get = "pub")] - key_vec_len: usize, - #[getset(get = "pub")] - keys: Vec>, -} - -/** - * This Chip constrains that consecutive rows are sorted lexicographically. - * - * Each row consists of a key decomposed into limbs with at most limb_bits bits - */ -#[derive(Default, Getters)] -pub struct LessThanChip { - pub air: LessThanAir, - - #[getset(get = "pub")] - bus_index: usize, - - pub range_checker_gate: RangeCheckerGateChip, -} - -impl LessThanChip { - pub fn new( - bus_index: usize, - range_max: u32, - limb_bits: usize, - decomp: usize, - key_vec_len: usize, - keys: Vec>, - ) -> Self { - let air = LessThanAir { - range_max, - limb_bits, - decomp, - key_vec_len, - keys, - }; - - Self { - air, - bus_index, - range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), - } - } -} diff --git a/chips/src/less_than/tests/mod.rs b/chips/src/less_than/tests/mod.rs deleted file mode 100644 index 181c713333..0000000000 --- a/chips/src/less_than/tests/mod.rs +++ /dev/null @@ -1,139 +0,0 @@ -use super::super::less_than::LessThanChip; - -use afs_stark_backend::prover::USE_DEBUG_BUILDER; -use afs_stark_backend::verifier::VerificationError; -use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; -use p3_baby_bear::BabyBear; -use p3_matrix::dense::DenseMatrix; - -/** - * Testing strategy for the less than chip: - * partition on limb_bits: - * limb_bits < 20 - * limb_bits >= 20 - * partition on key_vec_len: - * key_vec_len < 4 - * key_vec_len >= 4 - * partition on decomp: - * limb_bits % decomp == 0 - * limb_bits % decomp != 0 - * partition on number of rows: - * number of rows < 4 - * number of rows >= 4 - * partition on row order: - * rows are sorted lexicographically - * rows are not sorted lexicographically - */ - -// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, rows are sorted lexicographically -#[test] -fn test_less_than_chip_small_positive() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 2; - - const MAX: u32 = 1 << DECOMP; - - let requests = vec![vec![7784, 35423], vec![17558, 44832]]; - - let less_than_chip = LessThanChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, - requests.clone(), - ); - - let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); - let less_than_range_chip_trace: DenseMatrix = - less_than_chip.range_checker_gate.generate_trace(); - - run_simple_test_no_pis( - vec![&less_than_chip, &less_than_chip.range_checker_gate], - vec![less_than_chip_trace, less_than_range_chip_trace], - ) - .expect("Verification failed"); -} - -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are sorted lexicographically -#[test] -fn test_less_than_chip_large_positive() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 30; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; - - const MAX: u32 = 1 << DECOMP; - - let requests = vec![ - vec![35867, 318434, 12786, 44832], - vec![704210, 369315, 42421, 487111], - vec![370183, 37202, 729789, 783571], - vec![875005, 767547, 196209, 887921], - ]; - - let less_than_chip = LessThanChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, - requests.clone(), - ); - - let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); - let less_than_range_chip_trace: DenseMatrix = - less_than_chip.range_checker_gate.generate_trace(); - - run_simple_test_no_pis( - vec![&less_than_chip, &less_than_chip.range_checker_gate], - vec![less_than_chip_trace, less_than_range_chip_trace], - ) - .expect("Verification failed"); -} - -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are not sorted lexicographically -#[test] -fn test_less_than_chip_unsorted_negative() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 30; - const DECOMP: usize = 8; - const KEY_VEC_LEN: usize = 4; - - const MAX: u32 = 1 << DECOMP; - - // the first and second rows are not in sorted order - let requests = vec![ - vec![704210, 369315, 42421, 44832], - vec![35867, 318434, 12786, 44832], - vec![370183, 37202, 729789, 783571], - vec![875005, 767547, 196209, 887921], - ]; - - let less_than_chip = LessThanChip::new( - BUS_INDEX, - MAX, - LIMB_BITS, - DECOMP, - KEY_VEC_LEN, - requests.clone(), - ); - - let less_than_chip_trace: DenseMatrix = less_than_chip.generate_trace(); - let less_than_range_chip_trace: DenseMatrix = - less_than_chip.range_checker_gate.generate_trace(); - - USE_DEBUG_BUILDER.with(|debug| { - *debug.lock().unwrap() = false; - }); - assert_eq!( - run_simple_test_no_pis( - vec![&less_than_chip, &less_than_chip.range_checker_gate,], - vec![less_than_chip_trace, less_than_range_chip_trace], - ), - Err(VerificationError::OodEvaluationMismatch), - "Expected verification to fail, but it passed" - ); -} diff --git a/chips/src/less_than/trace.rs b/chips/src/less_than/trace.rs deleted file mode 100644 index 00c454f7dd..0000000000 --- a/chips/src/less_than/trace.rs +++ /dev/null @@ -1,127 +0,0 @@ -use p3_field::PrimeField64; -use p3_matrix::dense::RowMajorMatrix; - -use crate::sub_chip::LocalTraceInstructions; - -use super::{ - columns::{LessThanAuxCols, LessThanCols, LessThanIOCols}, - LessThanChip, -}; - -impl LessThanChip { - pub fn generate_trace(&self) -> RowMajorMatrix { - let num_cols: usize = LessThanCols::::get_width( - *self.air.limb_bits(), - *self.air.decomp(), - *self.air.key_vec_len(), - ); - - let mut rows: Vec = vec![]; - for i in 0..*self.air.key_vec_len() { - let key = self.air.keys[i].clone(); - let next_key: Vec = if i == *self.air.key_vec_len() - 1 { - vec![0; *self.air.key_vec_len()] - } else { - self.air.keys[i + 1].clone() - }; - let row = self.generate_trace_row((key, next_key)).flatten(); - rows.extend_from_slice(&row); - } - - RowMajorMatrix::new(rows, num_cols) - } -} - -impl LocalTraceInstructions for LessThanChip { - type LocalInput = (Vec, Vec); - - fn generate_trace_row(&self, consecutive_keys: (Vec, Vec)) -> Self::Cols { - let (key, next_key) = consecutive_keys; - let num_limbs = (self.air.limb_bits() + self.air.decomp() - 1) / self.air.decomp(); - let last_limb_shift = - (self.air.decomp() - (self.air.limb_bits() % self.air.decomp())) % self.air.decomp(); - - // the lower limb_bits bits of the corresponding check value - let mut lower_bits: Vec = vec![]; - let mut lower_bits_u32: Vec = vec![]; - // the (n + 1)st bits of the corresponding check value, will be 1 if a < b - let mut upper_bit: Vec = vec![]; - - // contains the difference between consecutive rows - let mut diff: Vec = vec![]; - // contains indicator whether difference is zero - let mut is_zero: Vec = vec![]; - // contains y such that y * (i + x) = 1 - let mut inverses: Vec = vec![]; - - // we compute the indicators, which only matter if the row is not the last - for (j, &val) in key.iter().enumerate() { - let next_val = next_key[j]; - // compute 2^limb_bits + next_val - val - 1 - let check_less_than = (1 << self.air.limb_bits()) + next_val - val - 1; - - // the lower limb_bits bits of the check value - lower_bits.push(F::from_canonical_u32( - check_less_than & ((1 << self.air.limb_bits()) - 1), - )); - // we also need the u32 value to compute the decomposition later - lower_bits_u32.push(check_less_than & ((1 << self.air.limb_bits()) - 1)); - // the (n + 1)st bit of the check value, will be 1 if a < b - upper_bit.push(F::from_canonical_u32( - check_less_than >> self.air.limb_bits(), - )); - - // the difference between the two limbs - let curr_diff = F::from_canonical_u32(next_val) - F::from_canonical_u32(val); - diff.push(curr_diff); - - // compute the equal indicator and inverses - if next_val == val { - is_zero.push(F::one()); - inverses.push((curr_diff + F::one()).inverse()); - } else { - is_zero.push(F::zero()); - inverses.push(curr_diff.inverse()); - } - } - - let mut lower_bits_decomp: Vec> = vec![]; - - // decompose each element of lower_bits so we can range check that the element - // has at most limb_bits bits - for i in 0..lower_bits_u32.len() { - let val = lower_bits_u32[i]; - if i != lower_bits_u32.len() { - let mut curr_decomp: Vec = vec![]; - for j in 0..num_limbs { - let bits = (val >> (j * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - curr_decomp.push(F::from_canonical_u32(bits)); - self.range_checker_gate.add_count(bits); - } - let bits = - (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - if (bits << last_limb_shift) < *self.air.range_max() { - self.range_checker_gate.add_count(bits << last_limb_shift); - } - curr_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); - lower_bits_decomp.push(curr_decomp); - } else { - lower_bits_decomp.push(vec![F::zero(); num_limbs + 1]); - } - } - - let io = LessThanIOCols { - key: key.into_iter().map(F::from_canonical_u32).collect(), - }; - let aux = LessThanAuxCols { - lower_bits, - upper_bit, - lower_bits_decomp, - diff, - is_zero, - inverses, - }; - - LessThanCols { io, aux } - } -} diff --git a/chips/src/lib.rs b/chips/src/lib.rs index e9cd8fe456..a1b44eceb8 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -5,7 +5,6 @@ pub mod is_less_than; pub mod is_less_than_tuple; pub mod is_zero; pub mod keccak_permute; -pub mod less_than; pub mod merkle_proof; pub mod page_controller; pub mod page_read; From d0ae2cfb430e10a8788c2534fc53cabcaa5557bd Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:41:08 -0400 Subject: [PATCH 14/46] chore: address comments first pass --- Cargo.lock | 34 +++++ chips/Cargo.toml | 1 + chips/src/assert_sorted/chip.rs | 2 +- chips/src/assert_sorted/mod.rs | 9 +- chips/src/assert_sorted/tests/mod.rs | 153 +++++++--------------- chips/src/assert_sorted/trace.rs | 4 +- chips/src/is_less_than/air.rs | 32 ++--- chips/src/is_less_than/chip.rs | 15 +-- chips/src/is_less_than/columns.rs | 22 ++-- chips/src/is_less_than/mod.rs | 44 +++++-- chips/src/is_less_than/tests/mod.rs | 36 +++-- chips/src/is_less_than/trace.rs | 26 ++-- chips/src/is_less_than_tuple/air.rs | 15 ++- chips/src/is_less_than_tuple/chip.rs | 8 +- chips/src/is_less_than_tuple/mod.rs | 41 +++--- chips/src/is_less_than_tuple/tests/mod.rs | 102 +++++---------- chips/src/is_less_than_tuple/trace.rs | 14 +- 17 files changed, 272 insertions(+), 286 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 17ff15ad1b..eefecc346f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,6 +30,7 @@ dependencies = [ "p3-util", "parking_lot", "rand", + "test-case", "tracing", "tracing-forest", "tracing-subscriber", @@ -874,6 +875,39 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.63", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.63", + "test-case-core", +] + [[package]] name = "thiserror" version = "1.0.60" diff --git a/chips/Cargo.toml b/chips/Cargo.toml index 8e7da3b73d..ab9d43674e 100644 --- a/chips/Cargo.toml +++ b/chips/Cargo.toml @@ -40,6 +40,7 @@ p3-poseidon2 = { workspace = true } p3-symmetric = { workspace = true } tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } +test-case = "3.3.1" [features] default = ["test-traits"] diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index b767d0b14b..15ef662e84 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -33,7 +33,7 @@ impl Chip for AssertSortedChip { interactions.push(Interaction { fields: vec![VirtualPairCol::single_main(cols_numbered.keys_decomp[i][j])], count: VirtualPairCol::constant(F::one()), - argument_index: self.range_checker_gate.bus_index(), + argument_index: self.range_checker.bus_index(), }); } } diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 798491b4f7..a0ad52d973 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{is_less_than_tuple::IsLessThanTupleChip, range_gate::RangeCheckerGateChip}; use getset::Getters; @@ -46,7 +48,7 @@ pub struct AssertedSortedAir { pub struct AssertSortedChip { air: AssertedSortedAir, is_less_than_tuple_chip: IsLessThanTupleChip, - range_checker_gate: RangeCheckerGateChip, + range_checker: Arc, } impl AssertSortedChip { @@ -57,6 +59,7 @@ impl AssertSortedChip { decomp: usize, key_vec_len: usize, keys: Vec>, + range_checker: Arc, ) -> Self { Self { air: AssertedSortedAir { @@ -72,9 +75,9 @@ impl AssertSortedChip { range_max, limb_bits, decomp, - key_vec_len, + range_checker.clone(), ), - range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), + range_checker, } } diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index aeb029a431..8c260f30f1 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + +use crate::range_gate::RangeCheckerGateChip; + use super::super::assert_sorted; use afs_stark_backend::prover::USE_DEBUG_BUILDER; -use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; use assert_sorted::AssertSortedChip; @@ -43,6 +46,8 @@ fn test_assert_sorted_chip_small_positive() { let requests = vec![vec![7784, 35423], vec![17558, 44832]]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let assert_sorted_chip = AssertSortedChip::new( bus_index, range_max, @@ -50,36 +55,18 @@ fn test_assert_sorted_chip_small_positive() { decomp, key_vec_len, requests.clone(), + range_checker.clone(), ); + let range_checker_chip = assert_sorted_chip.range_checker.as_ref(); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = - assert_sorted_chip.range_checker_gate.generate_trace(); - - let mut chips: Vec<&dyn AnyRap<_>> = - vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } - - run_simple_test_no_pis(chips, traces).expect("Verification failed"); + let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); + + run_simple_test_no_pis( + vec![&assert_sorted_chip, range_checker_chip], + vec![assert_sorted_chip_trace, range_checker_trace], + ) + .expect("Verification failed"); } // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at @@ -100,6 +87,8 @@ fn test_assert_sorted_chip_large_positive() { vec![875005, 767547, 196209, 887921], ]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let assert_sorted_chip = AssertSortedChip::new( bus_index, range_max, @@ -107,36 +96,18 @@ fn test_assert_sorted_chip_large_positive() { decomp, key_vec_len, requests.clone(), + range_checker.clone(), ); + let range_checker_chip = assert_sorted_chip.range_checker.as_ref(); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = - assert_sorted_chip.range_checker_gate.generate_trace(); - - let mut chips: Vec<&dyn AnyRap<_>> = - vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } - - run_simple_test_no_pis(chips, traces).expect("Verification failed"); + let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); + + run_simple_test_no_pis( + vec![&assert_sorted_chip, range_checker_chip], + vec![assert_sorted_chip_trace, range_checker_trace], + ) + .expect("Verification failed"); } // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, at least one limb @@ -158,6 +129,8 @@ fn test_assert_sorted_chip_largelimb_negative() { vec![128, 767, 196, 953], ]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let assert_sorted_chip = AssertSortedChip::new( bus_index, range_max, @@ -165,36 +138,17 @@ fn test_assert_sorted_chip_largelimb_negative() { decomp, key_vec_len, requests.clone(), + range_checker.clone(), ); + let range_checker_chip = assert_sorted_chip.range_checker.as_ref(); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = - assert_sorted_chip.range_checker_gate.generate_trace(); - - let mut chips: Vec<&dyn AnyRap<_>> = - vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } - - let result = run_simple_test_no_pis(chips, traces); + let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); + + let result = run_simple_test_no_pis( + vec![&assert_sorted_chip, range_checker_chip], + vec![assert_sorted_chip_trace, range_checker_trace], + ); assert_eq!( result, @@ -222,6 +176,8 @@ fn test_assert_sorted_chip_unsorted_negative() { vec![875005, 767547, 196209, 887921], ]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let assert_sorted_chip = AssertSortedChip::new( bus_index, range_max, @@ -229,40 +185,21 @@ fn test_assert_sorted_chip_unsorted_negative() { decomp, key_vec_len, requests.clone(), + range_checker.clone(), ); + let range_checker_chip = assert_sorted_chip.range_checker.as_ref(); let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let assert_sorted_range_chip_trace: DenseMatrix = - assert_sorted_chip.range_checker_gate.generate_trace(); - - let mut chips: Vec<&dyn AnyRap<_>> = - vec![&assert_sorted_chip, &assert_sorted_chip.range_checker_gate]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![assert_sorted_chip_trace, assert_sorted_range_chip_trace]; - - for is_less_than_chip in assert_sorted_chip - .is_less_than_tuple_chip - .is_less_than_chips - .iter() - { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } + let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; }); assert_eq!( - run_simple_test_no_pis(chips, traces), + run_simple_test_no_pis( + vec![&assert_sorted_chip, range_checker_chip], + vec![assert_sorted_chip_trace, range_checker_trace], + ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" ); diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index 4465e45577..ea145bdf8b 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -40,14 +40,14 @@ impl AssertSortedChip { for i in 0..num_limbs { let bits = (val >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); key_decomp_trace.push(F::from_canonical_u32(bits)); - self.range_checker_gate.add_count(bits); + self.range_checker.add_count(bits); } // the last sublimb should be of size self.limb_bits() % self.decomp() bits, // so we need to shift it to constrain this let bits = (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); if (bits << last_limb_shift) < *self.air.range_max() { - self.range_checker_gate.add_count(bits << last_limb_shift); + self.range_checker.add_count(bits << last_limb_shift); } key_decomp_trace.push(F::from_canonical_u32(bits << last_limb_shift)); } diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index f593ace38f..050be6535d 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -15,6 +15,10 @@ impl AirConfig for IsLessThanChip { type Cols = IsLessThanCols; } +impl AirConfig for IsLessThanAir { + type Cols = IsLessThanCols; +} + impl BaseAir for IsLessThanChip { fn width(&self) -> usize { IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()) @@ -41,6 +45,7 @@ impl SubAir for IsLessThanAir { type AuxView = IsLessThanAuxCols; // constrain that the result of x < y is given by less_than + // warning: send for range check must be included for the constraints to be sound fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { let x = io.x; let y = io.y; @@ -48,11 +53,8 @@ impl SubAir for IsLessThanAir { let local_aux = &aux; - // num_limbs is the number of limbs, not including the last shifted limb - let num_limbs = (self.limb_bits() + self.decomp() - 1) / self.decomp(); - - let lower_bits = local_aux.lower_bits; - let lower_bits_decomp = local_aux.lower_bits_decomp.clone(); + let lower = local_aux.lower; + let lower_decomp = local_aux.lower_decomp.clone(); // to range check the last limb of the decomposed lower_bits, we need to shift it to make sure it is in // the correct range @@ -63,30 +65,30 @@ impl SubAir for IsLessThanAir { y - x + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - AB::Expr::one(); // constrain that the lower_bits + less_than * 2^limb_bits is the correct intermediate sum - let check_val = - lower_bits + less_than * AB::Expr::from_canonical_u64(1 << self.limb_bits()); + let check_val = lower + less_than * AB::Expr::from_canonical_u64(1 << self.limb_bits()); builder.assert_eq(intermed_val, check_val); // constrain that the decomposition of lower_bits is correct - let lower_bits_from_decomp = lower_bits_decomp + // each limb will be range checked + let lower_from_decomp = lower_decomp .iter() .enumerate() - .take(num_limbs) + .take(*self.num_limbs()) .fold(AB::Expr::zero(), |acc, (i, &val)| { acc + val * AB::Expr::from_canonical_u64(1 << (i * self.decomp())) }); - builder.assert_eq(lower_bits_from_decomp, lower_bits); + builder.assert_eq(lower_from_decomp, lower); - let shifted_val = - lower_bits_decomp[num_limbs - 1] * AB::Expr::from_canonical_u64(1 << last_limb_shift); + let shifted_val = lower_decomp[*self.num_limbs() - 1] + * AB::Expr::from_canonical_u64(1 << last_limb_shift); // constrain that the shifted last limb is shifted correctly - builder.assert_eq(lower_bits_decomp[num_limbs], shifted_val); + // this shifted last limb will also be range checked + builder.assert_eq(lower_decomp[*self.num_limbs()], shifted_val); // constrain that less_than is a boolean - let is_bool = less_than * (AB::Expr::one() - less_than); - builder.assert_zero(is_bool); + builder.assert_bool(less_than); } } diff --git a/chips/src/is_less_than/chip.rs b/chips/src/is_less_than/chip.rs index 56d8d0a01f..37229487c6 100644 --- a/chips/src/is_less_than/chip.rs +++ b/chips/src/is_less_than/chip.rs @@ -1,6 +1,6 @@ use crate::sub_chip::SubAirWithInteractions; -use super::columns::IsLessThanCols; +use super::{columns::IsLessThanCols, IsLessThanAir}; use afs_stark_backend::interaction::{Chip, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; @@ -18,24 +18,19 @@ impl Chip for IsLessThanChip { *self.air.decomp(), ); - SubAirWithInteractions::sends(self, cols_numbered) + SubAirWithInteractions::sends(&self.air, cols_numbered) } } -impl SubAirWithInteractions for IsLessThanChip { +impl SubAirWithInteractions for IsLessThanAir { fn sends(&self, col_indices: IsLessThanCols) -> Vec> { - // num_limbs is the number of limbs, not including the last shifted limb - let num_limbs = (*self.air.limb_bits() + *self.air.decomp() - 1) / *self.air.decomp(); - let mut interactions = vec![]; // we range check the limbs of the lower_bits so that we know each element // of lower_bits has at most limb_bits bits - for i in 0..(num_limbs + 1) { + for i in 0..(*self.num_limbs() + 1) { interactions.push(Interaction { - fields: vec![VirtualPairCol::single_main( - col_indices.aux.lower_bits_decomp[i], - )], + fields: vec![VirtualPairCol::single_main(col_indices.aux.lower_decomp[i])], count: VirtualPairCol::constant(F::one()), argument_index: *self.bus_index(), }); diff --git a/chips/src/is_less_than/columns.rs b/chips/src/is_less_than/columns.rs index 89c5d6438c..eb375d3ac4 100644 --- a/chips/src/is_less_than/columns.rs +++ b/chips/src/is_less_than/columns.rs @@ -8,8 +8,10 @@ pub struct IsLessThanIOCols { } pub struct IsLessThanAuxCols { - pub lower_bits: T, - pub lower_bits_decomp: Vec, + pub lower: T, + // lower_decomp consists of lower decomposed into limbs of size decomp where we also shift + // the final limb and store it as the last element of lower decomp so we can range check + pub lower_decomp: Vec, } pub struct IsLessThanCols { @@ -29,16 +31,16 @@ impl IsLessThanCols { let less_than = slc[2].clone(); // the next element is the value of the lower num_limbs bits of the intermediate sum - let lower_bits = slc[3].clone(); + let lower = slc[3].clone(); // the next num_limbs + 1 elements are the decomposed limbs of the lower bits of the // intermediate sum - let lower_bits_decomp = slc[4..4 + num_limbs + 1].to_vec(); + let lower_decomp = slc[4..4 + num_limbs + 1].to_vec(); let io = IsLessThanIOCols { x, y, less_than }; let aux = IsLessThanAuxCols { - lower_bits, - lower_bits_decomp, + lower, + lower_decomp, }; Self { io, aux } @@ -49,9 +51,9 @@ impl IsLessThanCols { self.io.x.clone(), self.io.y.clone(), self.io.less_than.clone(), - self.aux.lower_bits.clone(), + self.aux.lower.clone(), ]; - flattened.extend(self.aux.lower_bits_decomp.iter().cloned()); + flattened.extend(self.aux.lower_decomp.iter().cloned()); flattened } @@ -61,9 +63,9 @@ impl IsLessThanCols { width += 2; // for the less_than indicator width += 1; - // for the lower_bits + // for the lower width += 1; - // for the decomposed lower_bits + // for the decomposed lower let num_limbs = (limb_bits + decomp - 1) / decomp; width += num_limbs + 1; diff --git a/chips/src/is_less_than/mod.rs b/chips/src/is_less_than/mod.rs index bdd79e8852..feb44f052c 100644 --- a/chips/src/is_less_than/mod.rs +++ b/chips/src/is_less_than/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::range_gate::RangeCheckerGateChip; use getset::Getters; @@ -11,12 +13,33 @@ pub mod trace; #[derive(Default, Getters)] pub struct IsLessThanAir { + // The bus index + #[getset(get = "pub")] + bus_index: usize, + // The maximum range for the range checker #[getset(get = "pub")] range_max: u32, + // The maximum number of bits for the numbers to compare #[getset(get = "pub")] limb_bits: usize, + // The number of bits to decompose each number into, for less than checking #[getset(get = "pub")] decomp: usize, + // num_limbs is the number of limbs we decompose each input into, not including the last shifted limb + #[getset(get)] + num_limbs: usize, +} + +impl IsLessThanAir { + pub fn new(bus_index: usize, range_max: u32, limb_bits: usize, decomp: usize) -> Self { + Self { + bus_index, + range_max, + limb_bits, + decomp, + num_limbs: (limb_bits + decomp - 1) / decomp, + } + } } /** @@ -26,25 +49,26 @@ pub struct IsLessThanAir { pub struct IsLessThanChip { pub air: IsLessThanAir, - #[getset(get = "pub")] - bus_index: usize, - - pub range_checker_gate: RangeCheckerGateChip, + pub range_checker: Arc, } impl IsLessThanChip { - pub fn new(bus_index: usize, range_max: u32, limb_bits: usize, decomp: usize) -> Self { + pub fn new( + bus_index: usize, + range_max: u32, + limb_bits: usize, + decomp: usize, + range_checker: Arc, + ) -> Self { let air = IsLessThanAir { + bus_index, range_max, limb_bits, decomp, + num_limbs: (limb_bits + decomp - 1) / decomp, }; - Self { - air, - bus_index, - range_checker_gate: RangeCheckerGateChip::new(bus_index, range_max), - } + Self { air, range_checker } } fn calc_less_than(&self, x: u32, y: u32) -> u32 { diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 607943967e..7c0582fce3 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use crate::range_gate::RangeCheckerGateChip; + use super::super::is_less_than::IsLessThanChip; use afs_stark_backend::prover::USE_DEBUG_BUILDER; @@ -14,12 +18,14 @@ fn test_is_less_than_chip_lt() { const DECOMP: usize = 8; const MAX: u32 = 1 << DECOMP; - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); let trace = chip.generate_trace(14321, 26883); - let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, &chip.range_checker_gate], + vec![&chip, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -32,12 +38,14 @@ fn test_is_less_than_chip_gt() { const DECOMP: usize = 8; const MAX: u32 = 1 << DECOMP; - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); let trace = chip.generate_trace(1, 0); - let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, &chip.range_checker_gate], + vec![&chip, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -50,12 +58,14 @@ fn test_is_less_than_chip_eq() { const DECOMP: usize = 8; const MAX: u32 = 1 << DECOMP; - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); let trace = chip.generate_trace(773, 773); - let range_trace: DenseMatrix = chip.range_checker_gate.generate_trace(); + let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, &chip.range_checker_gate], + vec![&chip, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -68,9 +78,11 @@ fn test_is_less_than_negative() { const DECOMP: usize = 8; const MAX: u32 = 1 << DECOMP; - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP); + let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + + let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); let mut trace = chip.generate_trace(446, 553); - let range_trace = chip.range_checker_gate.generate_trace(); + let range_trace = chip.range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); @@ -79,7 +91,7 @@ fn test_is_less_than_negative() { }); assert_eq!( run_simple_test_no_pis( - vec![&chip, &chip.range_checker_gate], + vec![&chip, chip.range_checker.as_ref()], vec![trace, range_trace], ), Err(VerificationError::OodEvaluationMismatch), diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs index 3409f25bcb..0086d70d01 100644 --- a/chips/src/is_less_than/trace.rs +++ b/chips/src/is_less_than/trace.rs @@ -26,8 +26,6 @@ impl LocalTraceInstructions for IsLessThanChip { let (x, y) = input; let less_than = self.calc_less_than(x, y); - // num_limbs is the number of limbs, not including the last shifted limb - let num_limbs = (self.air.limb_bits() + self.air.decomp() - 1) / self.air.decomp(); // to range check the last limb of the decomposed lower_bits, we need to shift it to make sure it is in // the correct range let last_limb_shift = @@ -35,24 +33,24 @@ impl LocalTraceInstructions for IsLessThanChip { // obtain the lower_bits let check_less_than = (1 << self.air.limb_bits()) + y - x - 1; - let lower_bits = F::from_canonical_u32(check_less_than & ((1 << self.air.limb_bits()) - 1)); - let lower_bits_u32 = check_less_than & ((1 << self.air.limb_bits()) - 1); + let lower = F::from_canonical_u32(check_less_than & ((1 << self.air.limb_bits()) - 1)); + let lower_u32 = check_less_than & ((1 << self.air.limb_bits()) - 1); // decompose lower_bits into limbs and range check - let mut lower_bits_decomp: Vec = vec![]; - for i in 0..num_limbs { - let bits = (lower_bits_u32 >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - lower_bits_decomp.push(F::from_canonical_u32(bits)); - self.range_checker_gate.add_count(bits); + let mut lower_decomp: Vec = vec![]; + for i in 0..*self.air.num_limbs() { + let bits = (lower_u32 >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); + lower_decomp.push(F::from_canonical_u32(bits)); + self.range_checker.add_count(bits); } // shift the last limb and range check - let bits = (lower_bits_u32 >> ((num_limbs - 1) * self.air.decomp())) + let bits = (lower_u32 >> ((self.air.num_limbs() - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); if (bits << last_limb_shift) < *self.air.range_max() { - self.range_checker_gate.add_count(bits << last_limb_shift); + self.range_checker.add_count(bits << last_limb_shift); } - lower_bits_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); + lower_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); let io = IsLessThanIOCols { x: F::from_canonical_u32(x), @@ -60,8 +58,8 @@ impl LocalTraceInstructions for IsLessThanChip { less_than: F::from_canonical_u32(less_than), }; let aux = IsLessThanAuxCols { - lower_bits, - lower_bits_decomp, + lower, + lower_decomp, }; IsLessThanCols { io, aux } diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index 0d56e3f6cc..55b16b6487 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::{borrow::Borrow, sync::Arc}; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, Field}; @@ -6,6 +6,7 @@ use p3_matrix::Matrix; use crate::{ is_less_than::{columns::IsLessThanCols, IsLessThanChip}, + range_gate::RangeCheckerGateChip, sub_chip::{AirConfig, SubAir}, }; @@ -23,7 +24,7 @@ impl BaseAir for IsLessThanTupleChip { IsLessThanTupleCols::::get_width( self.air.limb_bits().clone(), *self.air.decomp(), - *self.air.tuple_len(), + self.air.tuple_len(), ) } } @@ -39,7 +40,7 @@ impl Air for IsLessThanTupleChip { local, self.air.limb_bits().clone(), *self.air.decomp(), - *self.air.tuple_len(), + self.air.tuple_len(), ); SubAir::eval(&self.air, builder, local_cols.io, local_cols.aux); @@ -60,11 +61,17 @@ impl SubAir for IsLessThanTupleAir { let x_val = x[i]; let y_val = y[i]; + let range_checker_dummy = Arc::new(RangeCheckerGateChip::new( + *self.bus_index(), + *self.range_max(), + )); + let is_less_than_chip_dummy = IsLessThanChip::new( *self.bus_index(), *self.range_max(), self.limb_bits()[i], *self.decomp(), + range_checker_dummy, ); // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip @@ -94,7 +101,7 @@ impl SubAir for IsLessThanTupleAir { } // together, these constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] - for i in 0..*self.tuple_len() { + for i in 0..self.tuple_len() { let diff = aux.diff[i]; let is_equal = aux.is_zero[i]; let inverse = aux.inverses[i]; diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index 5eb6d30f1c..9cbd68a4d6 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -9,7 +9,7 @@ impl Chip for IsLessThanTupleChip { let num_cols = IsLessThanTupleCols::::get_width( self.air.limb_bits().clone(), *self.air.decomp(), - *self.air.tuple_len(), + self.air.tuple_len(), ); let all_cols = (0..num_cols).collect::>(); @@ -17,7 +17,7 @@ impl Chip for IsLessThanTupleChip { &all_cols, self.air.limb_bits().clone(), *self.air.decomp(), - *self.air.tuple_len(), + self.air.tuple_len(), ); SubAirWithInteractions::sends(self, cols_numbered) @@ -29,7 +29,7 @@ impl SubAirWithInteractions for IsLessThanTupleChip { // num_limbs is the number of limbs, not including the last shifted limb let mut interactions = vec![]; - for i in 0..*self.air.tuple_len() { + for i in 0..self.air.tuple_len() { let mut is_less_than_cols = vec![ col_indices.io.x[i], col_indices.io.y[i], @@ -46,7 +46,7 @@ impl SubAirWithInteractions for IsLessThanTupleChip { ); let curr_interactions = - SubAirWithInteractions::::sends(&self.is_less_than_chips[i], is_less_than_cols); + SubAirWithInteractions::::sends(&self.air.is_lt_airs[i], is_less_than_cols); interactions.extend(curr_interactions); } diff --git a/chips/src/is_less_than_tuple/mod.rs b/chips/src/is_less_than_tuple/mod.rs index 3f1c889f7a..cbc15f243c 100644 --- a/chips/src/is_less_than_tuple/mod.rs +++ b/chips/src/is_less_than_tuple/mod.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; + use getset::Getters; -use crate::is_less_than::IsLessThanChip; +use crate::{is_less_than::IsLessThanAir, range_gate::RangeCheckerGateChip}; #[cfg(test)] pub mod tests; @@ -17,11 +19,19 @@ pub struct IsLessThanTupleAir { #[getset(get = "pub")] range_max: u32, #[getset(get = "pub")] - limb_bits: Vec, - #[getset(get = "pub")] decomp: usize, #[getset(get = "pub")] - tuple_len: usize, + is_lt_airs: Vec, +} + +impl IsLessThanTupleAir { + pub fn tuple_len(&self) -> usize { + self.is_lt_airs.len() + } + + pub fn limb_bits(&self) -> Vec { + self.is_lt_airs.iter().map(|air| *air.limb_bits()).collect() + } } /** @@ -33,7 +43,7 @@ pub struct IsLessThanTupleAir { pub struct IsLessThanTupleChip { pub air: IsLessThanTupleAir, - pub is_less_than_chips: Vec, + pub range_checker: Arc, } impl IsLessThanTupleChip { @@ -42,25 +52,20 @@ impl IsLessThanTupleChip { range_max: u32, limb_bits: Vec, decomp: usize, - tuple_len: usize, + range_checker: Arc, ) -> Self { + let is_lt_airs = limb_bits + .iter() + .map(|&limb_bit| IsLessThanAir::new(bus_index, range_max, limb_bit, decomp)) + .collect::>(); + let air = IsLessThanTupleAir { bus_index, range_max, - limb_bits: limb_bits.clone(), decomp, - tuple_len, + is_lt_airs, }; - // create less_than_chips which will be used to compare individual tuple elements - let is_less_than_chips = limb_bits - .iter() - .map(|&limb_bit| IsLessThanChip::new(bus_index, range_max, limb_bit, decomp)) - .collect::>(); - - Self { - air, - is_less_than_chips, - } + Self { air, range_checker } } } diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index fa80af1761..623ef8245b 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -1,12 +1,13 @@ +use std::sync::Arc; + +use crate::range_gate::RangeCheckerGateChip; + use super::super::is_less_than_tuple::IsLessThanTupleChip; use afs_stark_backend::prover::USE_DEBUG_BUILDER; -use afs_stark_backend::rap::AnyRap; use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; -use p3_baby_bear::BabyBear; use p3_field::AbstractField; -use p3_matrix::dense::DenseMatrix; #[test] fn test_is_less_than_tuple_chip_lt() { @@ -14,26 +15,16 @@ fn test_is_less_than_tuple_chip_lt() { let limb_bits: Vec = vec![16, 8]; let decomp: usize = 8; let range_max: u32 = 1 << decomp; - let tuple_len: usize = 2; - - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); - let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); - - let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; - - for is_less_than_chip in chip.is_less_than_chips.iter() { - chips.push(&is_less_than_chip.range_checker_gate); - } - let mut traces = vec![trace]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - for is_less_than_chip in chip.is_less_than_chips.iter() { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); + let range_checker = chip.range_checker.as_ref(); + let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(chips, traces).expect("Verification failed"); + run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) + .expect("Verification failed"); } #[test] @@ -42,26 +33,16 @@ fn test_is_less_than_tuple_chip_gt() { let limb_bits: Vec = vec![16, 8]; let decomp: usize = 8; let range_max: u32 = 1 << decomp; - let tuple_len: usize = 2; - - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); - let trace = chip.generate_trace(vec![14321, 244], vec![26678, 233]); - - let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; - - for is_less_than_chip in chip.is_less_than_chips.iter() { - chips.push(&is_less_than_chip.range_checker_gate); - } - let mut traces = vec![trace]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - for is_less_than_chip in chip.is_less_than_chips.iter() { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); + let range_checker = chip.range_checker.as_ref(); + let trace = chip.generate_trace(vec![14321, 244], vec![26678, 233]); + let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(chips, traces).expect("Verification failed"); + run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) + .expect("Verification failed"); } #[test] @@ -70,26 +51,16 @@ fn test_is_less_than_tuple_chip_eq() { let limb_bits: Vec = vec![16, 8]; let decomp: usize = 8; let range_max: u32 = 1 << decomp; - let tuple_len: usize = 2; - - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); - let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); - let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - for is_less_than_chip in chip.is_less_than_chips.iter() { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![trace]; - - for is_less_than_chip in chip.is_less_than_chips.iter() { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); + let range_checker = chip.range_checker.as_ref(); + let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); + let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(chips, traces).expect("Verification failed"); + run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) + .expect("Verification failed"); } #[test] @@ -98,32 +69,21 @@ fn test_is_less_than_tuple_chip_negative() { let limb_bits: Vec = vec![16, 8]; let decomp: usize = 8; let range_max: u32 = 1 << decomp; - let tuple_len: usize = 2; - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, tuple_len); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + + let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); + let range_checker = chip.range_checker.as_ref(); let mut trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + let range_checker_trace = range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); - let mut chips: Vec<&dyn AnyRap<_>> = vec![&chip]; - - for is_less_than_chip in chip.is_less_than_chips.iter() { - chips.push(&is_less_than_chip.range_checker_gate); - } - - let mut traces = vec![trace]; - - for is_less_than_chip in chip.is_less_than_chips.iter() { - let range_trace: DenseMatrix = - is_less_than_chip.range_checker_gate.generate_trace(); - traces.push(range_trace); - } - USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; }); assert_eq!( - run_simple_test_no_pis(chips, traces), + run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" ); diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs index 6c594f5756..5e8030b89f 100644 --- a/chips/src/is_less_than_tuple/trace.rs +++ b/chips/src/is_less_than_tuple/trace.rs @@ -1,7 +1,7 @@ use p3_field::PrimeField64; use p3_matrix::dense::RowMajorMatrix; -use crate::sub_chip::LocalTraceInstructions; +use crate::{is_less_than::IsLessThanChip, sub_chip::LocalTraceInstructions}; use super::{ columns::{IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols}, @@ -13,7 +13,7 @@ impl IsLessThanTupleChip { let num_cols: usize = IsLessThanTupleCols::::get_width( self.air.limb_bits().clone(), *self.air.decomp(), - *self.air.tuple_len(), + self.air.tuple_len(), ); let row: Vec = self.generate_trace_row((x, y)).flatten(); @@ -37,10 +37,16 @@ impl LocalTraceInstructions for IsLessThanTupleChip { // use subchip to generate relevant columns for i in 0..x.len() { - let is_less_than_chip = &self.is_less_than_chips[i]; + let is_less_than_chip = IsLessThanChip::new( + *self.air.bus_index(), + *self.air.range_max(), + self.air.limb_bits()[i], + *self.air.decomp(), + self.range_checker.clone(), + ); let curr_less_than_row = - LocalTraceInstructions::generate_trace_row(is_less_than_chip, (x[i], y[i])) + LocalTraceInstructions::generate_trace_row(&is_less_than_chip, (x[i], y[i])) .flatten(); less_than.push(curr_less_than_row[2]); lower_bits.push(curr_less_than_row[3]); From 6ca339f22ff57bd1f5a9d9e6b9c2017934951397 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 5 Jun 2024 14:51:42 -0400 Subject: [PATCH 15/46] chore: refactor AssertSorted, IsEqual, IsLessThan, and IsLessThanTuple chips --- chips/src/assert_sorted/chip.rs | 2 +- chips/src/assert_sorted/columns.rs | 50 +++++++++------ chips/src/assert_sorted/trace.rs | 4 +- chips/src/is_equal/air.rs | 14 +++-- chips/src/is_equal/columns.rs | 21 +++++-- chips/src/is_equal/trace.rs | 2 +- chips/src/is_less_than/air.rs | 4 -- chips/src/is_less_than/mod.rs | 8 --- chips/src/is_less_than/tests/mod.rs | 48 +++++++-------- chips/src/is_less_than/trace.rs | 44 +++++++------ chips/src/is_less_than_tuple/air.rs | 68 ++++++++++---------- chips/src/is_less_than_tuple/chip.rs | 18 +++--- chips/src/is_less_than_tuple/columns.rs | 79 ++++++++++++++---------- chips/src/is_less_than_tuple/trace.rs | 82 ++++++++++++++++--------- 14 files changed, 251 insertions(+), 193 deletions(-) diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index 15ef662e84..ac6f4d9f12 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -39,7 +39,7 @@ impl Chip for AssertSortedChip { } let subchip_interactions = SubAirWithInteractions::::sends( - &self.is_less_than_tuple_chip, + &self.is_less_than_tuple_chip.air, cols_numbered.is_less_than_tuple_cols, ); diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index 92b44812c4..261609a361 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -1,7 +1,11 @@ use afs_derive::AlignedBorrow; -use crate::is_less_than_tuple::columns::{ - IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, +use crate::{ + is_equal::columns::IsEqualAuxCols, + is_less_than::columns::IsLessThanAuxCols, + is_less_than_tuple::columns::{ + IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, + }, }; // Since AssertSortedChip contains a LessThanChip subchip, a subset of the columns are those of the @@ -53,29 +57,24 @@ impl AssertSortedCols { // the next key_vec_len elements are the values of the lower bits of each intermediate sum // (i.e. 2^limb_bits[i] + y[i] - x[i] - 1) - let lower_bits = slc[curr_start_idx..curr_end_idx].to_vec(); + let lower_vec = slc[curr_start_idx..curr_end_idx].to_vec(); // the next elements are the decomposed lower bits - let mut lower_bits_decomp: Vec> = vec![]; + let mut lower_decomp_vec: Vec> = vec![]; for curr_limb_bits in limb_bits.iter() { let num_limbs = (curr_limb_bits + decomp - 1) / decomp; curr_start_idx = curr_end_idx; curr_end_idx += num_limbs + 1; - lower_bits_decomp.push(slc[curr_start_idx..curr_end_idx].to_vec()); + lower_decomp_vec.push(slc[curr_start_idx..curr_end_idx].to_vec()); } curr_start_idx = curr_end_idx; curr_end_idx += key_vec_len; - // the next key_vec_len elements are the difference between consecutive limbs of rows - let diff = slc[curr_start_idx..curr_end_idx].to_vec(); - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - // the next key_vec_len elements are the indicator whether the difference is zero; if difference is // zero then the two limbs must be equal - let is_zero = slc[curr_start_idx..curr_end_idx].to_vec(); + let is_equal = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += key_vec_len; @@ -83,6 +82,21 @@ impl AssertSortedCols { // note that this sum will always be nonzero so the inverse will exist let inverses = slc[curr_start_idx..curr_end_idx].to_vec(); + let mut less_than_cols: Vec> = vec![]; + for i in 0..key_vec_len { + let less_than_col = IsLessThanAuxCols { + lower: lower_vec[i].clone(), + lower_decomp: lower_decomp_vec[i].clone(), + }; + less_than_cols.push(less_than_col); + } + + let mut is_equal_cols: Vec> = vec![]; + for inv in inverses.iter() { + let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; + is_equal_cols.push(is_equal_col); + } + let io = IsLessThanTupleIOCols { x, y, @@ -90,11 +104,9 @@ impl AssertSortedCols { }; let aux = IsLessThanTupleAuxCols { less_than, - lower_bits, - lower_bits_decomp, - diff, - is_zero, - inverses, + less_than_cols, + is_equal, + is_equal_cols, }; let is_less_than_tuple_cols = IsLessThanTupleCols { io, aux }; @@ -125,17 +137,15 @@ impl AssertSortedCols { // for the less_than indicators width += key_vec_len; - // for the lower_bits + // for the lowers width += key_vec_len; - // for the decomposed lower_bits + // for the decomposed lowers for &limb_bit in limb_bits.iter() { let num_limbs = (limb_bit + decomp - 1) / decomp; width += num_limbs + 1; } - // for the difference between consecutive rows - width += key_vec_len; // for the indicator whether difference is zero width += key_vec_len; // for the y such that y * (i + x) = 1 diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index ea145bdf8b..4032a597f3 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -23,8 +23,8 @@ impl AssertSortedChip { }; let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( - &self.is_less_than_tuple_chip, - (key.clone(), next_key.clone()), + &self.is_less_than_tuple_chip.air, + (key.clone(), next_key.clone(), self.range_checker.clone()), ) .flatten(); diff --git a/chips/src/is_equal/air.rs b/chips/src/is_equal/air.rs index b6cca4ed1a..7f2d6834e5 100644 --- a/chips/src/is_equal/air.rs +++ b/chips/src/is_equal/air.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; -use super::columns::{IsEqualCols, IsEqualIOCols, NUM_COLS}; +use super::columns::{IsEqualAuxCols, IsEqualCols, IsEqualIOCols, NUM_COLS}; use super::IsEqualChip; use crate::sub_chip::{AirConfig, SubAir}; use afs_stark_backend::interaction::Chip; @@ -20,9 +20,11 @@ impl Air for IsEqualChip { let main = builder.main(); let local = main.row_slice(0); - let is_equal_cols: &IsEqualCols<_> = (*local).borrow(); + let is_equal_cols: &[AB::Var] = (*local).borrow(); - SubAir::::eval(self, builder, is_equal_cols.io, is_equal_cols.inv); + let is_equal_cols = IsEqualCols::from_slice(is_equal_cols); + + SubAir::::eval(self, builder, is_equal_cols.io, is_equal_cols.aux); } } @@ -35,10 +37,10 @@ impl Chip for IsEqualChip {} impl SubAir for IsEqualChip { type IoView = IsEqualIOCols; - type AuxView = AB::Var; + type AuxView = IsEqualAuxCols; - fn eval(&self, builder: &mut AB, io: Self::IoView, inv: Self::AuxView) { - builder.assert_eq((io.x - io.y) * inv + io.is_equal, AB::F::one()); + fn eval(&self, builder: &mut AB, io: Self::IoView, aux: Self::AuxView) { + builder.assert_eq((io.x - io.y) * aux.inv + io.is_equal, AB::F::one()); builder.assert_eq((io.x - io.y) * io.is_equal, AB::F::zero()); } } diff --git a/chips/src/is_equal/columns.rs b/chips/src/is_equal/columns.rs index 9a29ce9f3a..d9caeff6cc 100644 --- a/chips/src/is_equal/columns.rs +++ b/chips/src/is_equal/columns.rs @@ -4,9 +4,9 @@ pub const NUM_COLS: usize = 4; #[repr(C)] #[derive(AlignedBorrow)] -pub struct IsEqualCols { - pub io: IsEqualIOCols, - pub inv: F, +pub struct IsEqualCols { + pub io: IsEqualIOCols, + pub aux: IsEqualAuxCols, } #[derive(Clone, Copy)] @@ -16,14 +16,27 @@ pub struct IsEqualIOCols { pub is_equal: T, } +pub struct IsEqualAuxCols { + pub inv: T, +} + impl IsEqualCols { pub const fn new(x: T, y: T, is_equal: T, inv: T) -> IsEqualCols { IsEqualCols { io: IsEqualIOCols { x, y, is_equal }, - inv, + aux: IsEqualAuxCols { inv }, } } + pub fn from_slice(slc: &[T]) -> IsEqualCols { + let x = slc[0].clone(); + let y = slc[1].clone(); + let is_equal = slc[2].clone(); + let inv = slc[3].clone(); + + IsEqualCols::new(x, y, is_equal, inv) + } + pub fn get_width() -> usize { NUM_COLS } diff --git a/chips/src/is_equal/trace.rs b/chips/src/is_equal/trace.rs index 61fc79e1ce..179cce3bd4 100644 --- a/chips/src/is_equal/trace.rs +++ b/chips/src/is_equal/trace.rs @@ -17,7 +17,7 @@ impl IsEqualChip { is_equal_cols.io.x, is_equal_cols.io.y, is_equal_cols.io.is_equal, - is_equal_cols.inv, + is_equal_cols.aux.inv, ] }) .collect::>(); diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index 050be6535d..bca87a3b04 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -11,10 +11,6 @@ use super::{ IsLessThanAir, IsLessThanChip, }; -impl AirConfig for IsLessThanChip { - type Cols = IsLessThanCols; -} - impl AirConfig for IsLessThanAir { type Cols = IsLessThanCols; } diff --git a/chips/src/is_less_than/mod.rs b/chips/src/is_less_than/mod.rs index feb44f052c..f3e15eca5a 100644 --- a/chips/src/is_less_than/mod.rs +++ b/chips/src/is_less_than/mod.rs @@ -70,12 +70,4 @@ impl IsLessThanChip { Self { air, range_checker } } - - fn calc_less_than(&self, x: u32, y: u32) -> u32 { - if x < y { - 1 - } else { - 0 - } - } } diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 7c0582fce3..054f53b9d7 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -13,14 +13,14 @@ use p3_matrix::dense::DenseMatrix; #[test] fn test_is_less_than_chip_lt() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const MAX: u32 = 1 << DECOMP; + let bus_index: usize = 0; + let limb_bits: usize = 16; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); + let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let trace = chip.generate_trace(14321, 26883); let range_trace: DenseMatrix = chip.range_checker.generate_trace(); @@ -33,14 +33,14 @@ fn test_is_less_than_chip_lt() { #[test] fn test_is_less_than_chip_gt() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const MAX: u32 = 1 << DECOMP; + let bus_index: usize = 0; + let limb_bits: usize = 16; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); + let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let trace = chip.generate_trace(1, 0); let range_trace: DenseMatrix = chip.range_checker.generate_trace(); @@ -53,14 +53,14 @@ fn test_is_less_than_chip_gt() { #[test] fn test_is_less_than_chip_eq() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const MAX: u32 = 1 << DECOMP; + let bus_index: usize = 0; + let limb_bits: usize = 16; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); + let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let trace = chip.generate_trace(773, 773); let range_trace: DenseMatrix = chip.range_checker.generate_trace(); @@ -73,14 +73,14 @@ fn test_is_less_than_chip_eq() { #[test] fn test_is_less_than_negative() { - const BUS_INDEX: usize = 0; - const LIMB_BITS: usize = 16; - const DECOMP: usize = 8; - const MAX: u32 = 1 << DECOMP; + let bus_index: usize = 0; + let limb_bits: usize = 16; + let decomp: usize = 8; + let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(BUS_INDEX, MAX)); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - let chip = IsLessThanChip::new(BUS_INDEX, MAX, LIMB_BITS, DECOMP, range_checker); + let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let mut trace = chip.generate_trace(446, 553); let range_trace = chip.range_checker.generate_trace(); diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs index 0086d70d01..680764b3ab 100644 --- a/chips/src/is_less_than/trace.rs +++ b/chips/src/is_less_than/trace.rs @@ -1,11 +1,13 @@ +use std::sync::Arc; + use p3_field::PrimeField64; use p3_matrix::dense::RowMajorMatrix; -use crate::sub_chip::LocalTraceInstructions; +use crate::{range_gate::RangeCheckerGateChip, sub_chip::LocalTraceInstructions}; use super::{ columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, - IsLessThanChip, + IsLessThanAir, IsLessThanChip, }; impl IsLessThanChip { @@ -13,42 +15,44 @@ impl IsLessThanChip { let num_cols: usize = IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); - let row = self.generate_trace_row((x, y)).flatten(); + let row = self + .air + .generate_trace_row((x, y, self.range_checker.clone())) + .flatten(); RowMajorMatrix::new(row, num_cols) } } -impl LocalTraceInstructions for IsLessThanChip { - type LocalInput = (u32, u32); +impl LocalTraceInstructions for IsLessThanAir { + type LocalInput = (u32, u32, Arc); - fn generate_trace_row(&self, input: (u32, u32)) -> Self::Cols { - let (x, y) = input; - let less_than = self.calc_less_than(x, y); + fn generate_trace_row(&self, input: (u32, u32, Arc)) -> Self::Cols { + let (x, y, range_checker) = input; + let less_than = if x < y { 1 } else { 0 }; // to range check the last limb of the decomposed lower_bits, we need to shift it to make sure it is in // the correct range - let last_limb_shift = - (self.air.decomp() - (self.air.limb_bits() % self.air.decomp())) % self.air.decomp(); + let last_limb_shift = (self.decomp() - (self.limb_bits() % self.decomp())) % self.decomp(); // obtain the lower_bits - let check_less_than = (1 << self.air.limb_bits()) + y - x - 1; - let lower = F::from_canonical_u32(check_less_than & ((1 << self.air.limb_bits()) - 1)); - let lower_u32 = check_less_than & ((1 << self.air.limb_bits()) - 1); + let check_less_than = (1 << self.limb_bits()) + y - x - 1; + let lower = F::from_canonical_u32(check_less_than & ((1 << self.limb_bits()) - 1)); + let lower_u32 = check_less_than & ((1 << self.limb_bits()) - 1); // decompose lower_bits into limbs and range check let mut lower_decomp: Vec = vec![]; - for i in 0..*self.air.num_limbs() { - let bits = (lower_u32 >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); + for i in 0..*self.num_limbs() { + let bits = (lower_u32 >> (i * self.decomp())) & ((1 << self.decomp()) - 1); lower_decomp.push(F::from_canonical_u32(bits)); - self.range_checker.add_count(bits); + range_checker.add_count(bits); } // shift the last limb and range check - let bits = (lower_u32 >> ((self.air.num_limbs() - 1) * self.air.decomp())) - & ((1 << self.air.decomp()) - 1); - if (bits << last_limb_shift) < *self.air.range_max() { - self.range_checker.add_count(bits << last_limb_shift); + let bits = + (lower_u32 >> ((self.num_limbs() - 1) * self.decomp())) & ((1 << self.decomp()) - 1); + if (bits << last_limb_shift) < *self.range_max() { + range_checker.add_count(bits << last_limb_shift); } lower_decomp.push(F::from_canonical_u32(bits << last_limb_shift)); diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index 55b16b6487..fbf650bc33 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -5,7 +5,14 @@ use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; use crate::{ - is_less_than::{columns::IsLessThanCols, IsLessThanChip}, + is_equal::{ + columns::{IsEqualAuxCols, IsEqualCols, IsEqualIOCols}, + IsEqualChip, + }, + is_less_than::{ + columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, + IsLessThanChip, + }, range_gate::RangeCheckerGateChip, sub_chip::{AirConfig, SubAir}, }; @@ -15,7 +22,7 @@ use super::{ IsLessThanTupleAir, IsLessThanTupleChip, }; -impl AirConfig for IsLessThanTupleChip { +impl AirConfig for IsLessThanTupleAir { type Cols = IsLessThanTupleCols; } @@ -75,16 +82,17 @@ impl SubAir for IsLessThanTupleAir { ); // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip - let mut is_less_than_slice = vec![x_val, y_val]; - is_less_than_slice.push(aux.less_than[i]); - is_less_than_slice.push(aux.lower_bits[i]); - is_less_than_slice.extend_from_slice(&aux.lower_bits_decomp[i]); - - let is_less_than_cols = IsLessThanCols::::from_slice( - &is_less_than_slice, - self.limb_bits()[i], - *self.decomp(), - ); + let is_less_than_cols = IsLessThanCols { + io: IsLessThanIOCols { + x: x_val, + y: y_val, + less_than: aux.less_than[i], + }, + aux: IsLessThanAuxCols { + lower: aux.less_than_cols[i].lower, + lower_decomp: aux.less_than_cols[i].lower_decomp.clone(), + }, + }; SubAir::eval( &is_less_than_chip_dummy.air, @@ -94,24 +102,22 @@ impl SubAir for IsLessThanTupleAir { ); } - for i in 0..x.len() { - // constrain that diff is the difference between the two elements of consecutive rows - let diff = y[i] - x[i]; - builder.assert_eq(diff, aux.diff[i]); - } - // together, these constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] - for i in 0..self.tuple_len() { - let diff = aux.diff[i]; - let is_equal = aux.is_zero[i]; - let inverse = aux.inverses[i]; - - // check that diff * is_equal = 0 - builder.assert_zero(diff * is_equal); - // check that is_equal is boolean - builder.assert_zero(is_equal * (AB::Expr::one() - is_equal)); - // check that inverse * (diff + is_equal) = 1 - builder.assert_one(inverse * (diff + is_equal)); + for i in 0..x.len() { + let is_equal = aux.is_equal[i]; + let inv = aux.is_equal_cols[i].inv; + + let is_equal_chip = IsEqualChip {}; + let is_equal_cols = IsEqualCols { + io: IsEqualIOCols { + x: x[i], + y: y[i], + is_equal, + }, + aux: IsEqualAuxCols { inv }, + }; + + SubAir::eval(&is_equal_chip, builder, is_equal_cols.io, is_equal_cols.aux); } // to check whether one row is less than another, we can use the indicators to generate a boolean @@ -123,8 +129,8 @@ impl SubAir for IsLessThanTupleAir { for (i, &less_than_value) in less_than.iter().enumerate() { let mut curr_expr: AB::Expr = less_than_value.into(); - for &is_zero_value in &aux.is_zero[i + 1..] { - curr_expr *= is_zero_value.into(); + for &is_equal_value in &aux.is_equal[i + 1..] { + curr_expr *= is_equal_value.into(); } check_less_than += curr_expr; } diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index 9cbd68a4d6..05c57789bd 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -2,7 +2,7 @@ use crate::{is_less_than::columns::IsLessThanCols, sub_chip::SubAirWithInteracti use afs_stark_backend::interaction::{Chip, Interaction}; use p3_field::PrimeField64; -use super::{columns::IsLessThanTupleCols, IsLessThanTupleChip}; +use super::{columns::IsLessThanTupleCols, IsLessThanTupleAir, IsLessThanTupleChip}; impl Chip for IsLessThanTupleChip { fn sends(&self) -> Vec> { @@ -20,33 +20,33 @@ impl Chip for IsLessThanTupleChip { self.air.tuple_len(), ); - SubAirWithInteractions::sends(self, cols_numbered) + SubAirWithInteractions::sends(&self.air, cols_numbered) } } -impl SubAirWithInteractions for IsLessThanTupleChip { +impl SubAirWithInteractions for IsLessThanTupleAir { fn sends(&self, col_indices: IsLessThanTupleCols) -> Vec> { // num_limbs is the number of limbs, not including the last shifted limb let mut interactions = vec![]; - for i in 0..self.air.tuple_len() { + for i in 0..self.tuple_len() { let mut is_less_than_cols = vec![ col_indices.io.x[i], col_indices.io.y[i], col_indices.aux.less_than[i], - col_indices.aux.lower_bits[i], + col_indices.aux.less_than_cols[i].lower, ]; - is_less_than_cols.extend_from_slice(&col_indices.aux.lower_bits_decomp[i]); + is_less_than_cols.extend_from_slice(&col_indices.aux.less_than_cols[i].lower_decomp); let is_less_than_cols = IsLessThanCols::::from_slice( &is_less_than_cols, - self.air.limb_bits().clone()[i], - *self.air.decomp(), + self.limb_bits().clone()[i], + *self.decomp(), ); let curr_interactions = - SubAirWithInteractions::::sends(&self.air.is_lt_airs[i], is_less_than_cols); + SubAirWithInteractions::::sends(&self.is_lt_airs[i], is_less_than_cols); interactions.extend(curr_interactions); } diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs index 7078ad62f6..0ffaeae2b9 100644 --- a/chips/src/is_less_than_tuple/columns.rs +++ b/chips/src/is_less_than_tuple/columns.rs @@ -1,5 +1,7 @@ use afs_derive::AlignedBorrow; +use crate::{is_equal::columns::IsEqualAuxCols, is_less_than::columns::IsLessThanAuxCols}; + #[derive(Default, AlignedBorrow)] pub struct IsLessThanTupleIOCols { pub x: Vec, @@ -9,11 +11,9 @@ pub struct IsLessThanTupleIOCols { pub struct IsLessThanTupleAuxCols { pub less_than: Vec, - pub lower_bits: Vec, - pub lower_bits_decomp: Vec>, - pub diff: Vec, - pub is_zero: Vec, - pub inverses: Vec, + pub less_than_cols: Vec>, + pub is_equal: Vec, + pub is_equal_cols: Vec>, } pub struct IsLessThanTupleCols { @@ -26,12 +26,14 @@ impl IsLessThanTupleCols { let mut x: Vec = vec![]; let mut y: Vec = vec![]; + let mut lower_vec: Vec = vec![]; + let mut lower_decomp_vec: Vec> = vec![]; + let mut less_than_cols: Vec> = vec![]; + let mut less_than: Vec = vec![]; - let mut lower_bits: Vec = vec![]; - let mut lower_bits_decomp: Vec> = vec![]; - let mut diff: Vec = vec![]; - let mut is_zero: Vec = vec![]; + let mut is_equal: Vec = vec![]; let mut inverses: Vec = vec![]; + let mut is_equal_cols: Vec> = vec![]; let mut curr_start_idx = 0; let mut curr_end_idx = tuple_len; @@ -60,7 +62,7 @@ impl IsLessThanTupleCols { curr_end_idx += tuple_len; // get the lower bits for each 2^limb_bits[i] + y[i] - x[i] - 1 - lower_bits.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + lower_vec.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); // get the lower bits decompositions for &limb_bit in limb_bits.iter() { @@ -74,20 +76,14 @@ impl IsLessThanTupleCols { lower_bits_curr.push(slc[curr_start_idx + j].clone()); } - lower_bits_decomp.push(lower_bits_curr); + lower_decomp_vec.push(lower_bits_curr); } curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; - // get the differences y[i] - x[i] - diff.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); - - curr_start_idx = curr_end_idx; - curr_end_idx += tuple_len; - // get whether y[i] - x[i] == 0 - is_zero.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + is_equal.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; @@ -95,6 +91,20 @@ impl IsLessThanTupleCols { // get the inverses k such that k * (diff[i] + is_zero[i]) = 1 inverses.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + for i in 0..tuple_len { + let less_than_col = IsLessThanAuxCols { + lower: lower_vec[i].clone(), + lower_decomp: lower_decomp_vec[i].clone(), + }; + + less_than_cols.push(less_than_col); + } + + for inv in inverses.iter() { + let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; + is_equal_cols.push(is_equal_col); + } + IsLessThanTupleCols { io: IsLessThanTupleIOCols { x, @@ -103,11 +113,9 @@ impl IsLessThanTupleCols { }, aux: IsLessThanTupleAuxCols { less_than, - lower_bits, - lower_bits_decomp, - diff, - is_zero, - inverses, + less_than_cols, + is_equal, + is_equal_cols, }, } } @@ -118,13 +126,20 @@ impl IsLessThanTupleCols { flattened.extend_from_slice(&self.io.y); flattened.push(self.io.tuple_less_than.clone()); flattened.extend_from_slice(&self.aux.less_than); - flattened.extend_from_slice(&self.aux.lower_bits); - for i in 0..self.aux.lower_bits_decomp.len() { - flattened.extend_from_slice(&self.aux.lower_bits_decomp[i]); + + for i in 0..self.aux.less_than_cols.len() { + flattened.push(self.aux.less_than_cols[i].lower.clone()); + } + + for i in 0..self.aux.less_than_cols.len() { + flattened.extend_from_slice(&self.aux.less_than_cols[i].lower_decomp); + } + + flattened.extend_from_slice(&self.aux.is_equal); + + for i in 0..self.aux.is_equal_cols.len() { + flattened.push(self.aux.is_equal_cols[i].inv.clone()); } - flattened.extend_from_slice(&self.aux.diff); - flattened.extend_from_slice(&self.aux.is_zero); - flattened.extend_from_slice(&self.aux.inverses); flattened } @@ -137,17 +152,15 @@ impl IsLessThanTupleCols { width += 1; // for the less than indicator width += tuple_len; - // for the lower bits + // for the lowers width += tuple_len; - // for the lower bits decomposition + // for the decomposed lowers for &limb_bit in limb_bits.iter() { let num_limbs = (limb_bit + decomp - 1) / decomp; width += num_limbs + 1; } - // for the difference between consecutive rows - width += tuple_len; // for the indicator whether difference is zero width += tuple_len; // for the inverses k such that k * (diff[i] + is_zero[i]) = 1 diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs index 5e8030b89f..dfccba21dc 100644 --- a/chips/src/is_less_than_tuple/trace.rs +++ b/chips/src/is_less_than_tuple/trace.rs @@ -1,11 +1,18 @@ +use std::sync::Arc; + use p3_field::PrimeField64; use p3_matrix::dense::RowMajorMatrix; -use crate::{is_less_than::IsLessThanChip, sub_chip::LocalTraceInstructions}; +use crate::{ + is_equal::columns::IsEqualAuxCols, + is_less_than::{columns::IsLessThanAuxCols, IsLessThanChip}, + range_gate::RangeCheckerGateChip, + sub_chip::LocalTraceInstructions, +}; use super::{ columns::{IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols}, - IsLessThanTupleChip, + IsLessThanTupleAir, IsLessThanTupleChip, }; impl IsLessThanTupleChip { @@ -16,21 +23,24 @@ impl IsLessThanTupleChip { self.air.tuple_len(), ); - let row: Vec = self.generate_trace_row((x, y)).flatten(); + let row: Vec = self + .air + .generate_trace_row((x, y, self.range_checker.clone())) + .flatten(); RowMajorMatrix::new(row, num_cols) } } -impl LocalTraceInstructions for IsLessThanTupleChip { - type LocalInput = (Vec, Vec); +impl LocalTraceInstructions for IsLessThanTupleAir { + type LocalInput = (Vec, Vec, Arc); fn generate_trace_row(&self, input: Self::LocalInput) -> Self::Cols { - let (x, y) = input; + let (x, y, range_checker) = input; let mut less_than: Vec = vec![]; - let mut lower_bits: Vec = vec![]; - let mut lower_bits_decomp: Vec> = vec![]; + let mut lower_vec: Vec = vec![]; + let mut lower_decomp_vec: Vec> = vec![]; let mut valid = true; let mut tuple_less_than = F::zero(); @@ -38,19 +48,21 @@ impl LocalTraceInstructions for IsLessThanTupleChip { // use subchip to generate relevant columns for i in 0..x.len() { let is_less_than_chip = IsLessThanChip::new( - *self.air.bus_index(), - *self.air.range_max(), - self.air.limb_bits()[i], - *self.air.decomp(), - self.range_checker.clone(), + *self.bus_index(), + *self.range_max(), + self.limb_bits()[i], + *self.decomp(), + range_checker.clone(), ); - let curr_less_than_row = - LocalTraceInstructions::generate_trace_row(&is_less_than_chip, (x[i], y[i])) - .flatten(); + let curr_less_than_row = LocalTraceInstructions::generate_trace_row( + &is_less_than_chip.air, + (x[i], y[i], range_checker.clone()), + ) + .flatten(); less_than.push(curr_less_than_row[2]); - lower_bits.push(curr_less_than_row[3]); - lower_bits_decomp.push(curr_less_than_row[4..].to_vec()); + lower_vec.push(curr_less_than_row[3]); + lower_decomp_vec.push(curr_less_than_row[4..].to_vec()); } // compute whether the x < y @@ -62,10 +74,8 @@ impl LocalTraceInstructions for IsLessThanTupleChip { } } - // contains the difference between consecutive rows - let mut diff: Vec = vec![]; // contains indicator whether difference is zero - let mut is_zero: Vec = vec![]; + let mut is_equal: Vec = vec![]; // contains y such that y * (i + x) = 1 let mut inverses: Vec = vec![]; @@ -74,19 +84,33 @@ impl LocalTraceInstructions for IsLessThanTupleChip { let next_val = y[i]; // the difference between the two limbs - let curr_diff = F::from_canonical_u32(next_val) - F::from_canonical_u32(val); - diff.push(curr_diff); + let curr_diff = F::from_canonical_u32(val) - F::from_canonical_u32(next_val); // compute the equal indicator and inverses if next_val == val { - is_zero.push(F::one()); + is_equal.push(F::one()); inverses.push((curr_diff + F::one()).inverse()); } else { - is_zero.push(F::zero()); + is_equal.push(F::zero()); inverses.push(curr_diff.inverse()); } } + let mut less_than_cols: Vec> = vec![]; + for i in 0..x.len() { + let less_than_col = IsLessThanAuxCols { + lower: lower_vec[i], + lower_decomp: lower_decomp_vec[i].clone(), + }; + less_than_cols.push(less_than_col); + } + + let mut is_equal_cols: Vec> = vec![]; + for inverse in &inverses { + let is_equal_col = IsEqualAuxCols { inv: *inverse }; + is_equal_cols.push(is_equal_col); + } + let io = IsLessThanTupleIOCols { x: x.into_iter().map(F::from_canonical_u32).collect(), y: y.into_iter().map(F::from_canonical_u32).collect(), @@ -94,11 +118,9 @@ impl LocalTraceInstructions for IsLessThanTupleChip { }; let aux = IsLessThanTupleAuxCols { less_than, - lower_bits, - lower_bits_decomp, - diff, - is_zero, - inverses, + less_than_cols, + is_equal, + is_equal_cols, }; IsLessThanTupleCols { io, aux } From aef744ab7813a8a97c968cc18cdbdc4950e5fd54 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 5 Jun 2024 16:30:13 -0400 Subject: [PATCH 16/46] chore: address comments --- chips/src/assert_sorted/mod.rs | 6 ++++++ chips/src/is_less_than/mod.rs | 4 ++-- chips/src/is_less_than_tuple/air.rs | 11 ++++------ chips/src/is_less_than_tuple/chip.rs | 30 ++++++++++++++-------------- chips/src/is_less_than_tuple/mod.rs | 4 ++++ 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index a0ad52d973..e1232779a0 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -18,16 +18,22 @@ pub mod trace; #[derive(Default, Getters)] pub struct AssertedSortedAir { + // The bus index for sends to range chip #[getset(get = "pub")] bus_index: usize, + // The maximum range for the range checker #[getset(get = "pub")] range_max: u32, + // The limb_bits for each element of the keys #[getset(get = "pub")] limb_bits: Vec, + // The number of bits to decompose each number into, for less than checking #[getset(get = "pub")] decomp: usize, + // The number of elements in a key #[getset(get = "pub")] key_vec_len: usize, + // The keys to check for sortedness #[getset(get = "pub")] keys: Vec>, } diff --git a/chips/src/is_less_than/mod.rs b/chips/src/is_less_than/mod.rs index f3e15eca5a..06ef0c089c 100644 --- a/chips/src/is_less_than/mod.rs +++ b/chips/src/is_less_than/mod.rs @@ -11,9 +11,9 @@ pub mod chip; pub mod columns; pub mod trace; -#[derive(Default, Getters)] +#[derive(Default, Clone, Getters)] pub struct IsLessThanAir { - // The bus index + // The bus index for sends to range chip #[getset(get = "pub")] bus_index: usize, // The maximum range for the range checker diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index fbf650bc33..ebe64d9dfb 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -73,13 +73,10 @@ impl SubAir for IsLessThanTupleAir { *self.range_max(), )); - let is_less_than_chip_dummy = IsLessThanChip::new( - *self.bus_index(), - *self.range_max(), - self.limb_bits()[i], - *self.decomp(), - range_checker_dummy, - ); + let is_less_than_chip_dummy = IsLessThanChip { + air: self.is_lt_airs[i].clone(), + range_checker: range_checker_dummy, + }; // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip let is_less_than_cols = IsLessThanCols { diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index 05c57789bd..797947b9b4 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -1,4 +1,7 @@ -use crate::{is_less_than::columns::IsLessThanCols, sub_chip::SubAirWithInteractions}; +use crate::{ + is_less_than::columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, + sub_chip::SubAirWithInteractions, +}; use afs_stark_backend::interaction::{Chip, Interaction}; use p3_field::PrimeField64; @@ -30,20 +33,17 @@ impl SubAirWithInteractions for IsLessThanTupleAir { let mut interactions = vec![]; for i in 0..self.tuple_len() { - let mut is_less_than_cols = vec![ - col_indices.io.x[i], - col_indices.io.y[i], - col_indices.aux.less_than[i], - col_indices.aux.less_than_cols[i].lower, - ]; - - is_less_than_cols.extend_from_slice(&col_indices.aux.less_than_cols[i].lower_decomp); - - let is_less_than_cols = IsLessThanCols::::from_slice( - &is_less_than_cols, - self.limb_bits().clone()[i], - *self.decomp(), - ); + let is_less_than_cols = IsLessThanCols { + io: IsLessThanIOCols { + x: col_indices.io.x[i], + y: col_indices.io.y[i], + less_than: col_indices.aux.less_than[i], + }, + aux: IsLessThanAuxCols { + lower: col_indices.aux.less_than_cols[i].lower, + lower_decomp: col_indices.aux.less_than_cols[i].lower_decomp.clone(), + }, + }; let curr_interactions = SubAirWithInteractions::::sends(&self.is_lt_airs[i], is_less_than_cols); diff --git a/chips/src/is_less_than_tuple/mod.rs b/chips/src/is_less_than_tuple/mod.rs index cbc15f243c..16cf867995 100644 --- a/chips/src/is_less_than_tuple/mod.rs +++ b/chips/src/is_less_than_tuple/mod.rs @@ -14,12 +14,16 @@ pub mod trace; #[derive(Default, Getters)] pub struct IsLessThanTupleAir { + // The bus index for sends to range chip #[getset(get = "pub")] bus_index: usize, + // The maximum range for the range checker #[getset(get = "pub")] range_max: u32, + // The number of bits to decompose each number into, for less than checking #[getset(get = "pub")] decomp: usize, + // IsLessThanAirs for each tuple element #[getset(get = "pub")] is_lt_airs: Vec, } From 38d5decd282516ad80ca3ce22618f030021f3376 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:13:46 -0400 Subject: [PATCH 17/46] chore: eliminate high dim poly from IsLessThanTupleChip --- chips/src/assert_sorted/columns.rs | 30 ++++++++--- chips/src/is_less_than/air.rs | 12 ++--- chips/src/is_less_than/chip.rs | 15 ++---- chips/src/is_less_than/tests/mod.rs | 8 +-- chips/src/is_less_than_tuple/air.rs | 61 +++++++++++---------- chips/src/is_less_than_tuple/chip.rs | 22 ++++---- chips/src/is_less_than_tuple/columns.rs | 66 +++++++++++++++-------- chips/src/is_less_than_tuple/tests/mod.rs | 36 +++++++++---- chips/src/is_less_than_tuple/trace.rs | 40 +++++++++++--- 9 files changed, 189 insertions(+), 101 deletions(-) diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index 261609a361..3fd5e32c99 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -82,21 +82,34 @@ impl AssertSortedCols { // note that this sum will always be nonzero so the inverse will exist let inverses = slc[curr_start_idx..curr_end_idx].to_vec(); - let mut less_than_cols: Vec> = vec![]; + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + let mut less_than_aux: Vec> = vec![]; for i in 0..key_vec_len { let less_than_col = IsLessThanAuxCols { lower: lower_vec[i].clone(), lower_decomp: lower_decomp_vec[i].clone(), }; - less_than_cols.push(less_than_col); + less_than_aux.push(less_than_col); } - let mut is_equal_cols: Vec> = vec![]; + let mut is_equal_aux: Vec> = vec![]; for inv in inverses.iter() { let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; - is_equal_cols.push(is_equal_col); + is_equal_aux.push(is_equal_col); } + let mut is_equal_cumulative: Vec = vec![]; + let mut less_than_cumulative: Vec = vec![]; + + is_equal_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + less_than_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let io = IsLessThanTupleIOCols { x, y, @@ -104,9 +117,11 @@ impl AssertSortedCols { }; let aux = IsLessThanTupleAuxCols { less_than, - less_than_cols, + less_than_aux, is_equal, - is_equal_cols, + is_equal_aux, + is_equal_cumulative, + less_than_cumulative, }; let is_less_than_tuple_cols = IsLessThanTupleCols { io, aux }; @@ -151,6 +166,9 @@ impl AssertSortedCols { // for the y such that y * (i + x) = 1 width += key_vec_len; + // for the cumulative is_equal and less_than + width += 2 * key_vec_len; + width } } diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index bca87a3b04..8c86995ba7 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -8,20 +8,20 @@ use crate::sub_chip::{AirConfig, SubAir}; use super::{ columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, - IsLessThanAir, IsLessThanChip, + IsLessThanAir, }; impl AirConfig for IsLessThanAir { type Cols = IsLessThanCols; } -impl BaseAir for IsLessThanChip { +impl BaseAir for IsLessThanAir { fn width(&self) -> usize { - IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()) + IsLessThanCols::::get_width(*self.limb_bits(), *self.decomp()) } } -impl Air for IsLessThanChip { +impl Air for IsLessThanAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); @@ -29,9 +29,9 @@ impl Air for IsLessThanChip { let local: &[AB::Var] = (*local).borrow(); let local_cols = - IsLessThanCols::::from_slice(local, *self.air.limb_bits(), *self.air.decomp()); + IsLessThanCols::::from_slice(local, *self.limb_bits(), *self.decomp()); - SubAir::eval(&self.air, builder, local_cols.io, local_cols.aux); + SubAir::eval(self, builder, local_cols.io, local_cols.aux); } } diff --git a/chips/src/is_less_than/chip.rs b/chips/src/is_less_than/chip.rs index 37229487c6..ffe08f478b 100644 --- a/chips/src/is_less_than/chip.rs +++ b/chips/src/is_less_than/chip.rs @@ -5,20 +5,15 @@ use afs_stark_backend::interaction::{Chip, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; -use super::IsLessThanChip; - -impl Chip for IsLessThanChip { +impl Chip for IsLessThanAir { fn sends(&self) -> Vec> { - let num_cols = IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); + let num_cols = IsLessThanCols::::get_width(*self.limb_bits(), *self.decomp()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = IsLessThanCols::::from_slice( - &all_cols, - *self.air.limb_bits(), - *self.air.decomp(), - ); + let cols_numbered = + IsLessThanCols::::from_slice(&all_cols, *self.limb_bits(), *self.decomp()); - SubAirWithInteractions::sends(&self.air, cols_numbered) + SubAirWithInteractions::sends(self, cols_numbered) } } diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 054f53b9d7..f86a2b734a 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -25,7 +25,7 @@ fn test_is_less_than_chip_lt() { let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, chip.range_checker.as_ref()], + vec![&chip.air, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -45,7 +45,7 @@ fn test_is_less_than_chip_gt() { let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, chip.range_checker.as_ref()], + vec![&chip.air, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -65,7 +65,7 @@ fn test_is_less_than_chip_eq() { let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&chip, chip.range_checker.as_ref()], + vec![&chip.air, chip.range_checker.as_ref()], vec![trace, range_trace], ) .expect("Verification failed"); @@ -91,7 +91,7 @@ fn test_is_less_than_negative() { }); assert_eq!( run_simple_test_no_pis( - vec![&chip, chip.range_checker.as_ref()], + vec![&chip.air, chip.range_checker.as_ref()], vec![trace, range_trace], ), Err(VerificationError::OodEvaluationMismatch), diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index ebe64d9dfb..8129b73dcb 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, sync::Arc}; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; +use p3_field::Field; use p3_matrix::Matrix; use crate::{ @@ -19,24 +19,24 @@ use crate::{ use super::{ columns::{IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols}, - IsLessThanTupleAir, IsLessThanTupleChip, + IsLessThanTupleAir, }; impl AirConfig for IsLessThanTupleAir { type Cols = IsLessThanTupleCols; } -impl BaseAir for IsLessThanTupleChip { +impl BaseAir for IsLessThanTupleAir { fn width(&self) -> usize { IsLessThanTupleCols::::get_width( - self.air.limb_bits().clone(), - *self.air.decomp(), - self.air.tuple_len(), + self.limb_bits().clone(), + *self.decomp(), + self.tuple_len(), ) } } -impl Air for IsLessThanTupleChip { +impl Air for IsLessThanTupleAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); @@ -45,12 +45,12 @@ impl Air for IsLessThanTupleChip { let local_cols = IsLessThanTupleCols::::from_slice( local, - self.air.limb_bits().clone(), - *self.air.decomp(), - self.air.tuple_len(), + self.limb_bits().clone(), + *self.decomp(), + self.tuple_len(), ); - SubAir::eval(&self.air, builder, local_cols.io, local_cols.aux); + SubAir::eval(self, builder, local_cols.io, local_cols.aux); } } @@ -86,8 +86,8 @@ impl SubAir for IsLessThanTupleAir { less_than: aux.less_than[i], }, aux: IsLessThanAuxCols { - lower: aux.less_than_cols[i].lower, - lower_decomp: aux.less_than_cols[i].lower_decomp.clone(), + lower: aux.less_than_aux[i].lower, + lower_decomp: aux.less_than_aux[i].lower_decomp.clone(), }, }; @@ -102,7 +102,7 @@ impl SubAir for IsLessThanTupleAir { // together, these constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] for i in 0..x.len() { let is_equal = aux.is_equal[i]; - let inv = aux.is_equal_cols[i].inv; + let inv = aux.is_equal_aux[i].inv; let is_equal_chip = IsEqualChip {}; let is_equal_cols = IsEqualCols { @@ -117,19 +117,26 @@ impl SubAir for IsLessThanTupleAir { SubAir::eval(&is_equal_chip, builder, is_equal_cols.io, is_equal_cols.aux); } - // to check whether one row is less than another, we can use the indicators to generate a boolean - // expression; the idea is that, starting at the most significant limb, a row is less than the next - // if all the limbs more significant are equal and the current limb is less than the corresponding - // limb in the next row - let mut check_less_than: AB::Expr = AB::Expr::zero(); - let less_than = aux.less_than.clone(); - - for (i, &less_than_value) in less_than.iter().enumerate() { - let mut curr_expr: AB::Expr = less_than_value.into(); - for &is_equal_value in &aux.is_equal[i + 1..] { - curr_expr *= is_equal_value.into(); - } - check_less_than += curr_expr; + let is_equal_cumulative = aux.is_equal_cumulative.clone(); + let less_than_cumulative = aux.less_than_cumulative.clone(); + + builder.assert_eq(is_equal_cumulative[0], aux.is_equal[0]); + builder.assert_eq(less_than_cumulative[0], aux.less_than[0]); + for i in 1..x.len() { + builder.assert_eq( + is_equal_cumulative[i], + is_equal_cumulative[i - 1] * aux.is_equal[i], + ); + builder.assert_eq( + less_than_cumulative[i], + less_than_cumulative[i - 1] + aux.less_than[i] * is_equal_cumulative[i - 1], + ); + } + + let mut check_less_than: AB::Expr = less_than_cumulative[0].into(); + + for i in 1..x.len() { + check_less_than += less_than_cumulative[i] * is_equal_cumulative[i - 1]; } builder.assert_eq(io.tuple_less_than, check_less_than); diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index 797947b9b4..180f43051d 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -5,25 +5,25 @@ use crate::{ use afs_stark_backend::interaction::{Chip, Interaction}; use p3_field::PrimeField64; -use super::{columns::IsLessThanTupleCols, IsLessThanTupleAir, IsLessThanTupleChip}; +use super::{columns::IsLessThanTupleCols, IsLessThanTupleAir}; -impl Chip for IsLessThanTupleChip { +impl Chip for IsLessThanTupleAir { fn sends(&self) -> Vec> { let num_cols = IsLessThanTupleCols::::get_width( - self.air.limb_bits().clone(), - *self.air.decomp(), - self.air.tuple_len(), + self.limb_bits().clone(), + *self.decomp(), + self.tuple_len(), ); let all_cols = (0..num_cols).collect::>(); let cols_numbered = IsLessThanTupleCols::::from_slice( &all_cols, - self.air.limb_bits().clone(), - *self.air.decomp(), - self.air.tuple_len(), + self.limb_bits().clone(), + *self.decomp(), + self.tuple_len(), ); - SubAirWithInteractions::sends(&self.air, cols_numbered) + SubAirWithInteractions::sends(self, cols_numbered) } } @@ -40,8 +40,8 @@ impl SubAirWithInteractions for IsLessThanTupleAir { less_than: col_indices.aux.less_than[i], }, aux: IsLessThanAuxCols { - lower: col_indices.aux.less_than_cols[i].lower, - lower_decomp: col_indices.aux.less_than_cols[i].lower_decomp.clone(), + lower: col_indices.aux.less_than_aux[i].lower, + lower_decomp: col_indices.aux.less_than_aux[i].lower_decomp.clone(), }, }; diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs index 0ffaeae2b9..c63ba717b3 100644 --- a/chips/src/is_less_than_tuple/columns.rs +++ b/chips/src/is_less_than_tuple/columns.rs @@ -11,9 +11,12 @@ pub struct IsLessThanTupleIOCols { pub struct IsLessThanTupleAuxCols { pub less_than: Vec, - pub less_than_cols: Vec>, + pub less_than_aux: Vec>, pub is_equal: Vec, - pub is_equal_cols: Vec>, + pub is_equal_aux: Vec>, + + pub is_equal_cumulative: Vec, + pub less_than_cumulative: Vec, } pub struct IsLessThanTupleCols { @@ -28,12 +31,15 @@ impl IsLessThanTupleCols { let mut lower_vec: Vec = vec![]; let mut lower_decomp_vec: Vec> = vec![]; - let mut less_than_cols: Vec> = vec![]; + let mut less_than_aux: Vec> = vec![]; let mut less_than: Vec = vec![]; let mut is_equal: Vec = vec![]; let mut inverses: Vec = vec![]; - let mut is_equal_cols: Vec> = vec![]; + let mut is_equal_aux: Vec> = vec![]; + + let mut is_equal_cumulative: Vec = vec![]; + let mut less_than_cumulative: Vec = vec![]; let mut curr_start_idx = 0; let mut curr_end_idx = tuple_len; @@ -79,6 +85,15 @@ impl IsLessThanTupleCols { lower_decomp_vec.push(lower_bits_curr); } + for i in 0..tuple_len { + let less_than_col = IsLessThanAuxCols { + lower: lower_vec[i].clone(), + lower_decomp: lower_decomp_vec[i].clone(), + }; + + less_than_aux.push(less_than_col); + } + curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; @@ -91,20 +106,21 @@ impl IsLessThanTupleCols { // get the inverses k such that k * (diff[i] + is_zero[i]) = 1 inverses.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); - for i in 0..tuple_len { - let less_than_col = IsLessThanAuxCols { - lower: lower_vec[i].clone(), - lower_decomp: lower_decomp_vec[i].clone(), - }; - - less_than_cols.push(less_than_col); - } + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; for inv in inverses.iter() { let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; - is_equal_cols.push(is_equal_col); + is_equal_aux.push(is_equal_col); } + is_equal_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + + curr_start_idx = curr_end_idx; + curr_end_idx += tuple_len; + + less_than_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + IsLessThanTupleCols { io: IsLessThanTupleIOCols { x, @@ -113,9 +129,11 @@ impl IsLessThanTupleCols { }, aux: IsLessThanTupleAuxCols { less_than, - less_than_cols, + less_than_aux, is_equal, - is_equal_cols, + is_equal_aux, + is_equal_cumulative, + less_than_cumulative, }, } } @@ -127,20 +145,23 @@ impl IsLessThanTupleCols { flattened.push(self.io.tuple_less_than.clone()); flattened.extend_from_slice(&self.aux.less_than); - for i in 0..self.aux.less_than_cols.len() { - flattened.push(self.aux.less_than_cols[i].lower.clone()); + for i in 0..self.aux.less_than_aux.len() { + flattened.push(self.aux.less_than_aux[i].lower.clone()); } - for i in 0..self.aux.less_than_cols.len() { - flattened.extend_from_slice(&self.aux.less_than_cols[i].lower_decomp); + for i in 0..self.aux.less_than_aux.len() { + flattened.extend_from_slice(&self.aux.less_than_aux[i].lower_decomp); } flattened.extend_from_slice(&self.aux.is_equal); - for i in 0..self.aux.is_equal_cols.len() { - flattened.push(self.aux.is_equal_cols[i].inv.clone()); + for i in 0..self.aux.is_equal_aux.len() { + flattened.push(self.aux.is_equal_aux[i].inv.clone()); } + flattened.extend_from_slice(&self.aux.is_equal_cumulative); + flattened.extend_from_slice(&self.aux.less_than_cumulative); + flattened } @@ -166,6 +187,9 @@ impl IsLessThanTupleCols { // for the inverses k such that k * (diff[i] + is_zero[i]) = 1 width += tuple_len; + // for the cumulative is_equal and less_than + width += 2 * tuple_len; + width } } diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 623ef8245b..26b01d62f7 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -23,14 +23,19 @@ fn test_is_less_than_tuple_chip_lt() { let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) - .expect("Verification failed"); + println!("trace: {:?}", trace); + + run_simple_test_no_pis( + vec![&chip.air, range_checker], + vec![trace, range_checker_trace], + ) + .expect("Verification failed"); } #[test] fn test_is_less_than_tuple_chip_gt() { let bus_index: usize = 0; - let limb_bits: Vec = vec![16, 8]; + let limb_bits: Vec = vec![8, 16]; let decomp: usize = 8; let range_max: u32 = 1 << decomp; @@ -38,11 +43,16 @@ fn test_is_less_than_tuple_chip_gt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![14321, 244], vec![26678, 233]); + let trace = chip.generate_trace(vec![244, 14321], vec![233, 26678]); let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) - .expect("Verification failed"); + println!("trace: {:?}", trace); + + run_simple_test_no_pis( + vec![&chip.air, range_checker], + vec![trace, range_checker_trace], + ) + .expect("Verification failed"); } #[test] @@ -59,8 +69,13 @@ fn test_is_less_than_tuple_chip_eq() { let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); let range_checker_trace = range_checker.generate_trace(); - run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]) - .expect("Verification failed"); + println!("trace: {:?}", trace); + + run_simple_test_no_pis( + vec![&chip.air, range_checker], + vec![trace, range_checker_trace], + ) + .expect("Verification failed"); } #[test] @@ -83,7 +98,10 @@ fn test_is_less_than_tuple_chip_negative() { *debug.lock().unwrap() = false; }); assert_eq!( - run_simple_test_no_pis(vec![&chip, range_checker], vec![trace, range_checker_trace]), + run_simple_test_no_pis( + vec![&chip.air, range_checker], + vec![trace, range_checker_trace] + ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" ); diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs index dfccba21dc..286fdf588f 100644 --- a/chips/src/is_less_than_tuple/trace.rs +++ b/chips/src/is_less_than_tuple/trace.rs @@ -65,13 +65,37 @@ impl LocalTraceInstructions for IsLessThanTupleAir { lower_decomp_vec.push(curr_less_than_row[4..].to_vec()); } + let mut less_than_cumulative: Vec = vec![]; + + let mut transition_index = 0; + while transition_index < x.len() && x[transition_index] == y[transition_index] { + transition_index += 1; + } + + let is_equal_cumulative = std::iter::repeat(F::one()) + .take(transition_index) + .chain(std::iter::repeat(F::zero()).take(x.len() - transition_index)) + .collect::>(); + // compute whether the x < y - for i in (0..x.len()).rev() { + for i in 0..x.len() { + let mut less_than_curr = if i > 0 { + less_than_cumulative[i - 1] + } else { + F::zero() + }; + + if x[i] < y[i] && (i == 0 || is_equal_cumulative[i - 1] == F::one()) { + less_than_curr = F::one(); + } + if x[i] < y[i] && valid { tuple_less_than = F::one(); } else if x[i] > y[i] && valid { valid = false; } + + less_than_cumulative.push(less_than_curr); } // contains indicator whether difference is zero @@ -96,19 +120,19 @@ impl LocalTraceInstructions for IsLessThanTupleAir { } } - let mut less_than_cols: Vec> = vec![]; + let mut less_than_aux: Vec> = vec![]; for i in 0..x.len() { let less_than_col = IsLessThanAuxCols { lower: lower_vec[i], lower_decomp: lower_decomp_vec[i].clone(), }; - less_than_cols.push(less_than_col); + less_than_aux.push(less_than_col); } - let mut is_equal_cols: Vec> = vec![]; + let mut is_equal_aux: Vec> = vec![]; for inverse in &inverses { let is_equal_col = IsEqualAuxCols { inv: *inverse }; - is_equal_cols.push(is_equal_col); + is_equal_aux.push(is_equal_col); } let io = IsLessThanTupleIOCols { @@ -118,9 +142,11 @@ impl LocalTraceInstructions for IsLessThanTupleAir { }; let aux = IsLessThanTupleAuxCols { less_than, - less_than_cols, + less_than_aux, is_equal, - is_equal_cols, + is_equal_aux, + is_equal_cumulative, + less_than_cumulative, }; IsLessThanTupleCols { io, aux } From d76a94dab65c15d31021ab735b7b47881096fa29 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:33:49 -0400 Subject: [PATCH 18/46] chore: fix tests --- chips/src/assert_sorted/tests/mod.rs | 24 ++++++------- chips/src/is_less_than/tests/mod.rs | 44 ++--------------------- chips/src/is_less_than/trace.rs | 17 +++++---- chips/src/is_less_than_tuple/air.rs | 8 +---- chips/src/is_less_than_tuple/tests/mod.rs | 14 +++----- chips/src/is_less_than_tuple/trace.rs | 21 +++++++---- 6 files changed, 45 insertions(+), 83 deletions(-) diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index 8c260f30f1..418c59f0dc 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -81,10 +81,10 @@ fn test_assert_sorted_chip_large_positive() { let range_max: u32 = 1 << decomp; let requests = vec![ - vec![35867, 318434, 12786, 44832], - vec![704210, 369315, 42421, 487111], - vec![370183, 37202, 729789, 783571], - vec![875005, 767547, 196209, 887921], + vec![44832, 12786, 318434, 35867], + vec![487111, 42421, 369315, 704210], + vec![783571, 729789, 37202, 370183], + vec![887921, 196209, 767547, 875005], ]; let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); @@ -123,10 +123,10 @@ fn test_assert_sorted_chip_largelimb_negative() { // the first and second rows are not in sorted order let requests = vec![ - vec![223, 448, 15, 587], - vec![883, 168, 772, 673], - vec![57, 386, 1025, 694], - vec![128, 767, 196, 953], + vec![587, 15, 448, 223], + vec![673, 772, 168, 883], + vec![694, 1025, 386, 57], + vec![953, 196, 767, 128], ]; let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); @@ -170,10 +170,10 @@ fn test_assert_sorted_chip_unsorted_negative() { // the first and second rows are not in sorted order let requests = vec![ - vec![704210, 369315, 42421, 44832], - vec![35867, 318434, 12786, 44832], - vec![370183, 37202, 729789, 783571], - vec![875005, 767547, 196209, 887921], + vec![44832, 42421, 369315, 704210], + vec![44832, 12786, 318434, 35867], + vec![783571, 729789, 37202, 370183], + vec![887921, 196209, 767547, 875005], ]; let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index f86a2b734a..2c123a1dbe 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -21,47 +21,7 @@ fn test_is_less_than_chip_lt() { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let trace = chip.generate_trace(14321, 26883); - let range_trace: DenseMatrix = chip.range_checker.generate_trace(); - - run_simple_test_no_pis( - vec![&chip.air, chip.range_checker.as_ref()], - vec![trace, range_trace], - ) - .expect("Verification failed"); -} - -#[test] -fn test_is_less_than_chip_gt() { - let bus_index: usize = 0; - let limb_bits: usize = 16; - let decomp: usize = 8; - let range_max: u32 = 1 << decomp; - - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - - let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let trace = chip.generate_trace(1, 0); - let range_trace: DenseMatrix = chip.range_checker.generate_trace(); - - run_simple_test_no_pis( - vec![&chip.air, chip.range_checker.as_ref()], - vec![trace, range_trace], - ) - .expect("Verification failed"); -} - -#[test] -fn test_is_less_than_chip_eq() { - let bus_index: usize = 0; - let limb_bits: usize = 16; - let decomp: usize = 8; - let range_max: u32 = 1 << decomp; - - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - - let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let trace = chip.generate_trace(773, 773); + let trace = chip.generate_trace(vec![14321, 1, 773, 337], vec![26883, 0, 773, 456]); let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( @@ -81,7 +41,7 @@ fn test_is_less_than_negative() { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let mut trace = chip.generate_trace(446, 553); + let mut trace = chip.generate_trace(vec![446], vec![553]); let range_trace = chip.range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs index 680764b3ab..e18ed23468 100644 --- a/chips/src/is_less_than/trace.rs +++ b/chips/src/is_less_than/trace.rs @@ -11,16 +11,21 @@ use super::{ }; impl IsLessThanChip { - pub fn generate_trace(&self, x: u32, y: u32) -> RowMajorMatrix { + pub fn generate_trace(&self, x: Vec, y: Vec) -> RowMajorMatrix { let num_cols: usize = IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); - let row = self - .air - .generate_trace_row((x, y, self.range_checker.clone())) - .flatten(); + let mut rows = vec![]; - RowMajorMatrix::new(row, num_cols) + for i in 0..x.len() { + let row: Vec = self + .air + .generate_trace_row((x[i], y[i], self.range_checker.clone())) + .flatten(); + rows.extend(row); + } + + RowMajorMatrix::new(rows, num_cols) } } diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index 8129b73dcb..9d1d0c0b09 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -133,12 +133,6 @@ impl SubAir for IsLessThanTupleAir { ); } - let mut check_less_than: AB::Expr = less_than_cumulative[0].into(); - - for i in 1..x.len() { - check_less_than += less_than_cumulative[i] * is_equal_cumulative[i - 1]; - } - - builder.assert_eq(io.tuple_less_than, check_less_than); + builder.assert_eq(io.tuple_less_than, less_than_cumulative[x.len() - 1]); } } diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 26b01d62f7..2c45598497 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -20,11 +20,9 @@ fn test_is_less_than_tuple_chip_lt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + let trace = chip.generate_trace(vec![vec![14321, 123]], vec![vec![26678, 233]]); let range_checker_trace = range_checker.generate_trace(); - println!("trace: {:?}", trace); - run_simple_test_no_pis( vec![&chip.air, range_checker], vec![trace, range_checker_trace], @@ -43,11 +41,9 @@ fn test_is_less_than_tuple_chip_gt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![244, 14321], vec![233, 26678]); + let trace = chip.generate_trace(vec![vec![244, 14321]], vec![vec![233, 26678]]); let range_checker_trace = range_checker.generate_trace(); - println!("trace: {:?}", trace); - run_simple_test_no_pis( vec![&chip.air, range_checker], vec![trace, range_checker_trace], @@ -66,11 +62,9 @@ fn test_is_less_than_tuple_chip_eq() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![14321, 244], vec![14321, 244]); + let trace = chip.generate_trace(vec![vec![14321, 244]], vec![vec![14321, 244]]); let range_checker_trace = range_checker.generate_trace(); - println!("trace: {:?}", trace); - run_simple_test_no_pis( vec![&chip.air, range_checker], vec![trace, range_checker_trace], @@ -89,7 +83,7 @@ fn test_is_less_than_tuple_chip_negative() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let mut trace = chip.generate_trace(vec![14321, 123], vec![26678, 233]); + let mut trace = chip.generate_trace(vec![vec![14321, 123]], vec![vec![26678, 233]]); let range_checker_trace = range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs index 286fdf588f..f453e576da 100644 --- a/chips/src/is_less_than_tuple/trace.rs +++ b/chips/src/is_less_than_tuple/trace.rs @@ -16,19 +16,28 @@ use super::{ }; impl IsLessThanTupleChip { - pub fn generate_trace(&self, x: Vec, y: Vec) -> RowMajorMatrix { + pub fn generate_trace( + &self, + x: Vec>, + y: Vec>, + ) -> RowMajorMatrix { let num_cols: usize = IsLessThanTupleCols::::get_width( self.air.limb_bits().clone(), *self.air.decomp(), self.air.tuple_len(), ); - let row: Vec = self - .air - .generate_trace_row((x, y, self.range_checker.clone())) - .flatten(); + let mut rows: Vec = vec![]; + + for i in 0..x.len() { + let row: Vec = self + .air + .generate_trace_row((x[i].clone(), y[i].clone(), self.range_checker.clone())) + .flatten(); + rows.extend(row); + } - RowMajorMatrix::new(row, num_cols) + RowMajorMatrix::new(rows, num_cols) } } From eca7ef3cd7c98587a5655a2ab72a851fb51b83a3 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:07:53 -0400 Subject: [PATCH 19/46] chore: address comments for AssertSortedChip --- chips/src/assert_sorted/air.rs | 99 ++++++++-------------------- chips/src/assert_sorted/chip.rs | 46 +++++++------ chips/src/assert_sorted/columns.rs | 66 +++++-------------- chips/src/assert_sorted/mod.rs | 71 +++----------------- chips/src/assert_sorted/tests/mod.rs | 59 +---------------- chips/src/assert_sorted/trace.rs | 46 ++++--------- chips/src/is_less_than_tuple/air.rs | 20 +----- chips/src/is_less_than_tuple/mod.rs | 14 ++++ 8 files changed, 106 insertions(+), 315 deletions(-) diff --git a/chips/src/assert_sorted/air.rs b/chips/src/assert_sorted/air.rs index dc47aabd62..c4c4ce3a32 100644 --- a/chips/src/assert_sorted/air.rs +++ b/chips/src/assert_sorted/air.rs @@ -1,26 +1,26 @@ use std::borrow::Borrow; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; +use p3_field::Field; use p3_matrix::Matrix; -use crate::is_less_than_tuple::columns::IsLessThanTupleCols; +use crate::is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}; use crate::sub_chip::SubAir; use super::columns::AssertSortedCols; -use super::AssertSortedChip; +use super::AssertSortedAir; -impl BaseAir for AssertSortedChip { +impl BaseAir for AssertSortedAir { fn width(&self) -> usize { AssertSortedCols::::get_width( - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.is_less_than_tuple_air().limb_bits().clone(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().tuple_len(), ) } } -impl Air for AssertSortedChip { +impl Air for AssertSortedAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); @@ -31,85 +31,38 @@ impl Air for AssertSortedChip { let local_cols = AssertSortedCols::::from_slice( local, - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.is_less_than_tuple_air().limb_bits().clone(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().tuple_len(), ); let next_cols = AssertSortedCols::::from_slice( next, - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.is_less_than_tuple_air().limb_bits().clone(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().tuple_len(), ); - let key_len = *self.air.key_vec_len(); - - for i in 0..key_len { - let mut key_from_limbs: AB::Expr = AB::Expr::zero(); - - // num_limbs is the number of sublimbs the current limb should be decomposed into - let num_limbs = (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); - // to range check the last sublimb, we need to shift it - let last_limb_shift = (self.air.decomp() - - (self.air.limb_bits()[i] % self.air.decomp())) - % self.air.decomp(); - - // constrain that the decomposition is correct - for j in 0..num_limbs { - key_from_limbs += local_cols.keys_decomp[i][j] - * AB::Expr::from_canonical_u64(1 << (j * self.air.decomp())); - } - - // constrain that the shifted last sublimb is shifted correctly - let shifted_val = local_cols.keys_decomp[i][num_limbs - 1] - * AB::Expr::from_canonical_u64(1 << last_limb_shift); - - builder.assert_eq(local_cols.keys_decomp[i][num_limbs], shifted_val); - builder.assert_eq(key_from_limbs, local_cols.is_less_than_tuple_cols.io.x[i]); - - // constrain that the keys are consistent across rows - builder.when_transition().assert_eq( - local_cols.is_less_than_tuple_cols.io.y[i], - next_cols.is_less_than_tuple_cols.io.x[i], - ); - } - // constrain that the current key is less than the next builder .when_transition() - .assert_one(local_cols.is_less_than_tuple_cols.io.tuple_less_than); - - // generate IsLessThanTupleCols struct for current row and next row - let mut curr_start_idx = 0; - let mut curr_end_idx = 2 * key_len; - // get the current key and next key - let mut local_slice: Vec = local[curr_start_idx..curr_end_idx].to_vec(); - - // skip the key decomposition - for i in 0..key_len { - let num_limbs = (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); - curr_end_idx += num_limbs + 1; - } + .assert_one(local_cols.less_than_next_key); - // get the rest of the columns - curr_start_idx = curr_end_idx; - - local_slice.extend_from_slice(&local[curr_start_idx..]); - - let local_cols = IsLessThanTupleCols::::from_slice( - &local_slice, - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), - ); + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: local_cols.key, + y: next_cols.key, + tuple_less_than: local_cols.less_than_next_key, + }, + aux: local_cols.is_less_than_tuple_aux, + }; // constrain the indicator that we used to check whether the current key < next key is correct SubAir::eval( - &self.is_less_than_tuple_chip.air, + self.is_less_than_tuple_air(), &mut builder.when_transition(), - local_cols.io, - local_cols.aux, + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, ); } } diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index ac6f4d9f12..7057c15506 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -1,46 +1,44 @@ -use crate::sub_chip::SubAirWithInteractions; +use crate::{ + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + sub_chip::SubAirWithInteractions, +}; use super::columns::AssertSortedCols; use afs_stark_backend::interaction::{Chip, Interaction}; -use p3_air::VirtualPairCol; use p3_field::PrimeField64; -use super::AssertSortedChip; +use super::AssertSortedAir; -impl Chip for AssertSortedChip { +impl Chip for AssertSortedAir { fn sends(&self) -> Vec> { let num_cols = AssertSortedCols::::get_width( - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.is_less_than_tuple_air().limb_bits().clone(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().tuple_len(), ); let all_cols = (0..num_cols).collect::>(); let cols_numbered = AssertSortedCols::::from_slice( &all_cols, - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.is_less_than_tuple_air().limb_bits().clone(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().tuple_len(), ); let mut interactions: Vec> = vec![]; - // we will range check the decomposed limbs of the key - for i in 0..*self.air.key_vec_len() { - let num_limbs = (self.air.limb_bits()[i] + *self.air.decomp() - 1) / *self.air.decomp(); - // add 1 to account for the shifted last sublimb - for j in 0..(num_limbs + 1) { - interactions.push(Interaction { - fields: vec![VirtualPairCol::single_main(cols_numbered.keys_decomp[i][j])], - count: VirtualPairCol::constant(F::one()), - argument_index: self.range_checker.bus_index(), - }); - } - } + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: cols_numbered.key.clone(), + y: cols_numbered.key.clone(), + tuple_less_than: cols_numbered.less_than_next_key, + }, + aux: cols_numbered.is_less_than_tuple_aux, + }; let subchip_interactions = SubAirWithInteractions::::sends( - &self.is_less_than_tuple_chip.air, - cols_numbered.is_less_than_tuple_cols, + self.is_less_than_tuple_air(), + is_less_than_tuple_cols, ); interactions.extend(subchip_interactions); diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index 3fd5e32c99..297221635f 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -1,52 +1,32 @@ use afs_derive::AlignedBorrow; use crate::{ - is_equal::columns::IsEqualAuxCols, - is_less_than::columns::IsLessThanAuxCols, - is_less_than_tuple::columns::{ - IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, - }, + is_equal::columns::IsEqualAuxCols, is_less_than::columns::IsLessThanAuxCols, + is_less_than_tuple::columns::IsLessThanTupleAuxCols, }; // Since AssertSortedChip contains a LessThanChip subchip, a subset of the columns are those of the // LessThanChip #[derive(AlignedBorrow)] pub struct AssertSortedCols { - pub keys_decomp: Vec>, - pub is_less_than_tuple_cols: IsLessThanTupleCols, + pub key: Vec, + pub less_than_next_key: T, + pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, } impl AssertSortedCols { pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, key_vec_len: usize) -> Self { - // num_limbs is the number of sublimbs per limb, not including the shifted last sublimb let mut curr_start_idx = 0; let mut curr_end_idx = key_vec_len; // the first key_vec_len elements are the key itself - let x = slc[curr_start_idx..curr_end_idx].to_vec(); - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements are the next key (the following row) - let y = slc[curr_start_idx..curr_end_idx].to_vec(); - - // the next elements are the decomposed key (with each having an extra shifted last sublimb) - let mut keys_decomp: Vec> = vec![]; - - for curr_limb_bits in limb_bits.iter() { - let num_limbs = (curr_limb_bits + decomp - 1) / decomp; - - curr_start_idx = curr_end_idx; - curr_end_idx += num_limbs + 1; - - keys_decomp.push(slc[curr_start_idx..curr_end_idx].to_vec()); - } + let key = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += 1; // the next element is the indicator for whether the key is less than the next key - let tuple_less_than = slc[curr_start_idx].clone(); + let less_than_next_key = slc[curr_start_idx].clone(); curr_start_idx = curr_end_idx; curr_end_idx += key_vec_len; @@ -110,12 +90,7 @@ impl AssertSortedCols { less_than_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); - let io = IsLessThanTupleIOCols { - x, - y, - tuple_less_than, - }; - let aux = IsLessThanTupleAuxCols { + let is_less_than_tuple_aux = IsLessThanTupleAuxCols { less_than, less_than_aux, is_equal, @@ -124,11 +99,10 @@ impl AssertSortedCols { less_than_cumulative, }; - let is_less_than_tuple_cols = IsLessThanTupleCols { io, aux }; - Self { - keys_decomp, - is_less_than_tuple_cols, + key, + less_than_next_key, + is_less_than_tuple_aux, } } @@ -137,16 +111,11 @@ impl AssertSortedCols { // account for the sublimb itself, and another 1 to account for the shifted // last sublimb let mut width = 0; - // for the x and y keys - width += 2 * key_vec_len; - - // for the decomposed key - for &limb_bit in limb_bits.iter() { - let num_limbs = (limb_bit + decomp - 1) / decomp; - width += num_limbs + 1; - } + + // for the key itself + width += key_vec_len; - // for the tuple less than indicator + // for the less than next key indicator width += 1; // for the less_than indicators @@ -161,9 +130,10 @@ impl AssertSortedCols { width += num_limbs + 1; } - // for the indicator whether difference is zero + // for the is_equal indicators width += key_vec_len; - // for the y such that y * (i + x) = 1 + + // for the inverses width += key_vec_len; // for the cumulative is_equal and less_than diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index e1232779a0..2cd2ab5eff 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -1,13 +1,8 @@ use std::sync::Arc; -use crate::{is_less_than_tuple::IsLessThanTupleChip, range_gate::RangeCheckerGateChip}; +use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; use getset::Getters; -use afs_stark_backend::interaction::Interaction; -use columns::AssertSortedCols; -use p3_air::VirtualPairCol; -use p3_field::PrimeField64; - #[cfg(test)] pub mod tests; @@ -17,22 +12,9 @@ pub mod columns; pub mod trace; #[derive(Default, Getters)] -pub struct AssertedSortedAir { - // The bus index for sends to range chip - #[getset(get = "pub")] - bus_index: usize, - // The maximum range for the range checker - #[getset(get = "pub")] - range_max: u32, - // The limb_bits for each element of the keys +pub struct AssertSortedAir { #[getset(get = "pub")] - limb_bits: Vec, - // The number of bits to decompose each number into, for less than checking - #[getset(get = "pub")] - decomp: usize, - // The number of elements in a key - #[getset(get = "pub")] - key_vec_len: usize, + is_less_than_tuple_air: IsLessThanTupleAir, // The keys to check for sortedness #[getset(get = "pub")] keys: Vec>, @@ -52,8 +34,7 @@ pub struct AssertedSortedAir { */ #[derive(Default)] pub struct AssertSortedChip { - air: AssertedSortedAir, - is_less_than_tuple_chip: IsLessThanTupleChip, + air: AssertSortedAir, range_checker: Arc, } @@ -63,53 +44,17 @@ impl AssertSortedChip { range_max: u32, limb_bits: Vec, decomp: usize, - key_vec_len: usize, keys: Vec>, range_checker: Arc, ) -> Self { Self { - air: AssertedSortedAir { - bus_index, - range_max, - limb_bits: limb_bits.clone(), - decomp, - key_vec_len, + air: AssertSortedAir { + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, range_max, limb_bits, decomp, + ), keys, }, - is_less_than_tuple_chip: IsLessThanTupleChip::new( - bus_index, - range_max, - limb_bits, - decomp, - range_checker.clone(), - ), range_checker, } } - - pub fn sends_custom( - &self, - cols: &AssertSortedCols, - ) -> Vec> { - // num_limbs is the number of sublimbs per limb of key, not including the - // shifted last sublimb - let num_keys = *self.air.key_vec_len(); - - let mut interactions = vec![]; - - // we will range check the decomposed limbs of the key - for i in 0..num_keys { - let num_limbs = (self.air.limb_bits()[i] + *self.air.decomp() - 1) / *self.air.decomp(); - // add 1 to account for the shifted last sublimb - for j in 0..(num_limbs + 1) { - interactions.push(Interaction { - fields: vec![VirtualPairCol::single_main(cols.keys_decomp[i][j])], - count: VirtualPairCol::constant(F::one()), - argument_index: *self.air.bus_index(), - }); - } - } - - interactions - } } diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index 418c59f0dc..f0df26fbdb 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -40,7 +40,6 @@ fn test_assert_sorted_chip_small_positive() { let bus_index: usize = 0; let limb_bits: Vec = vec![16, 16]; let decomp: usize = 8; - let key_vec_len: usize = 2; let range_max: u32 = 1 << decomp; @@ -53,7 +52,6 @@ fn test_assert_sorted_chip_small_positive() { range_max, limb_bits, decomp, - key_vec_len, requests.clone(), range_checker.clone(), ); @@ -63,7 +61,7 @@ fn test_assert_sorted_chip_small_positive() { let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&assert_sorted_chip, range_checker_chip], + vec![&assert_sorted_chip.air, range_checker_chip], vec![assert_sorted_chip_trace, range_checker_trace], ) .expect("Verification failed"); @@ -76,7 +74,6 @@ fn test_assert_sorted_chip_large_positive() { let bus_index: usize = 0; let limb_bits: Vec = vec![30, 30, 30, 30]; let decomp: usize = 8; - let key_vec_len: usize = 4; let range_max: u32 = 1 << decomp; @@ -94,7 +91,6 @@ fn test_assert_sorted_chip_large_positive() { range_max, limb_bits, decomp, - key_vec_len, requests.clone(), range_checker.clone(), ); @@ -104,59 +100,12 @@ fn test_assert_sorted_chip_large_positive() { let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); run_simple_test_no_pis( - vec![&assert_sorted_chip, range_checker_chip], + vec![&assert_sorted_chip.air, range_checker_chip], vec![assert_sorted_chip_trace, range_checker_trace], ) .expect("Verification failed"); } -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, at least one limb -// has more than limb_bits bits, rows are sorted lexicographically -#[test] -fn test_assert_sorted_chip_largelimb_negative() { - let bus_index: usize = 0; - let limb_bits: Vec = vec![10, 10, 10, 10]; - let decomp: usize = 8; - let key_vec_len: usize = 4; - - let range_max: u32 = 1 << decomp; - - // the first and second rows are not in sorted order - let requests = vec![ - vec![587, 15, 448, 223], - vec![673, 772, 168, 883], - vec![694, 1025, 386, 57], - vec![953, 196, 767, 128], - ]; - - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - - let assert_sorted_chip = AssertSortedChip::new( - bus_index, - range_max, - limb_bits, - decomp, - key_vec_len, - requests.clone(), - range_checker.clone(), - ); - let range_checker_chip = assert_sorted_chip.range_checker.as_ref(); - - let assert_sorted_chip_trace: DenseMatrix = assert_sorted_chip.generate_trace(); - let range_checker_trace = assert_sorted_chip.range_checker.generate_trace(); - - let result = run_simple_test_no_pis( - vec![&assert_sorted_chip, range_checker_chip], - vec![assert_sorted_chip_trace, range_checker_trace], - ); - - assert_eq!( - result, - Err(VerificationError::NonZeroCumulativeSum), - "Expected verification to fail, but it passed" - ); -} - // covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at // most limb_bits bits, rows are not sorted lexicographically #[test] @@ -164,7 +113,6 @@ fn test_assert_sorted_chip_unsorted_negative() { let bus_index: usize = 0; let limb_bits: Vec = vec![30, 30, 30, 30]; let decomp: usize = 8; - let key_vec_len: usize = 4; let range_max: u32 = 1 << decomp; @@ -183,7 +131,6 @@ fn test_assert_sorted_chip_unsorted_negative() { range_max, limb_bits, decomp, - key_vec_len, requests.clone(), range_checker.clone(), ); @@ -197,7 +144,7 @@ fn test_assert_sorted_chip_unsorted_negative() { }); assert_eq!( run_simple_test_no_pis( - vec![&assert_sorted_chip, range_checker_chip], + vec![&assert_sorted_chip.air, range_checker_chip], vec![assert_sorted_chip_trace, range_checker_trace], ), Err(VerificationError::OodEvaluationMismatch), diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index 4032a597f3..5a3980a3e3 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -8,53 +8,31 @@ use super::{columns::AssertSortedCols, AssertSortedChip}; impl AssertSortedChip { pub fn generate_trace(&self) -> RowMajorMatrix { let num_cols: usize = AssertSortedCols::::get_width( - self.air.limb_bits().clone(), - *self.air.decomp(), - *self.air.key_vec_len(), + self.air.is_less_than_tuple_air().limb_bits().clone(), + *self.air.is_less_than_tuple_air().decomp(), + self.air.is_less_than_tuple_air().tuple_len(), ); let mut rows: Vec = vec![]; - for i in 0..*self.air.key_vec_len() { + for i in 0..self.air.is_less_than_tuple_air().tuple_len() { let key = self.air.keys()[i].clone(); - let next_key: Vec = if i == *self.air.key_vec_len() - 1 { - vec![0; *self.air.key_vec_len()] + let next_key: Vec = if i == self.air.is_less_than_tuple_air().tuple_len() - 1 { + vec![0; self.air.is_less_than_tuple_air().tuple_len()] } else { self.air.keys()[i + 1].clone() }; let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( - &self.is_less_than_tuple_chip.air, + self.air.is_less_than_tuple_air(), (key.clone(), next_key.clone(), self.range_checker.clone()), ) .flatten(); - let mut key_decomp_trace: Vec = vec![]; - // decompose each limb into sublimbs of size self.decomp() bits - for (i, &val) in key.iter().enumerate() { - let num_limbs = - (self.air.limb_bits()[i] + self.air.decomp() - 1) / self.air.decomp(); - let last_limb_shift = (self.air.decomp() - - (self.air.limb_bits()[i] % self.air.decomp())) - % self.air.decomp(); - - for i in 0..num_limbs { - let bits = (val >> (i * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - key_decomp_trace.push(F::from_canonical_u32(bits)); - self.range_checker.add_count(bits); - } - // the last sublimb should be of size self.limb_bits() % self.decomp() bits, - // so we need to shift it to constrain this - let bits = - (val >> ((num_limbs - 1) * self.air.decomp())) & ((1 << self.air.decomp()) - 1); - if (bits << last_limb_shift) < *self.air.range_max() { - self.range_checker.add_count(bits << last_limb_shift); - } - key_decomp_trace.push(F::from_canonical_u32(bits << last_limb_shift)); - } - - let mut row: Vec = is_less_than_tuple_trace[0..2 * *self.air.key_vec_len()].to_vec(); - row.extend_from_slice(&key_decomp_trace); - row.extend_from_slice(&is_less_than_tuple_trace[2 * *self.air.key_vec_len()..]); + let mut row: Vec = + is_less_than_tuple_trace[0..self.air.is_less_than_tuple_air().tuple_len()].to_vec(); + row.extend_from_slice( + &is_less_than_tuple_trace[2 * self.air.is_less_than_tuple_air().tuple_len()..], + ); rows.extend_from_slice(&row); } diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index 9d1d0c0b09..f2ab610fef 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::borrow::Borrow; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::Field; @@ -9,11 +9,7 @@ use crate::{ columns::{IsEqualAuxCols, IsEqualCols, IsEqualIOCols}, IsEqualChip, }, - is_less_than::{ - columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, - IsLessThanChip, - }, - range_gate::RangeCheckerGateChip, + is_less_than::columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols}, sub_chip::{AirConfig, SubAir}, }; @@ -68,16 +64,6 @@ impl SubAir for IsLessThanTupleAir { let x_val = x[i]; let y_val = y[i]; - let range_checker_dummy = Arc::new(RangeCheckerGateChip::new( - *self.bus_index(), - *self.range_max(), - )); - - let is_less_than_chip_dummy = IsLessThanChip { - air: self.is_lt_airs[i].clone(), - range_checker: range_checker_dummy, - }; - // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip let is_less_than_cols = IsLessThanCols { io: IsLessThanIOCols { @@ -92,7 +78,7 @@ impl SubAir for IsLessThanTupleAir { }; SubAir::eval( - &is_less_than_chip_dummy.air, + &self.is_lt_airs[i].clone(), builder, is_less_than_cols.io, is_less_than_cols.aux, diff --git a/chips/src/is_less_than_tuple/mod.rs b/chips/src/is_less_than_tuple/mod.rs index 16cf867995..847be4a110 100644 --- a/chips/src/is_less_than_tuple/mod.rs +++ b/chips/src/is_less_than_tuple/mod.rs @@ -29,6 +29,20 @@ pub struct IsLessThanTupleAir { } impl IsLessThanTupleAir { + pub fn new(bus_index: usize, range_max: u32, limb_bits: Vec, decomp: usize) -> Self { + let is_lt_airs = limb_bits + .iter() + .map(|&limb_bit| IsLessThanAir::new(bus_index, range_max, limb_bit, decomp)) + .collect::>(); + + Self { + bus_index, + range_max, + decomp, + is_lt_airs, + } + } + pub fn tuple_len(&self) -> usize { self.is_lt_airs.len() } From 4890798419cdbea1624a9abe49436f39b60da163 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:31:22 -0400 Subject: [PATCH 20/46] chore: cleanup AssertSorted --- chips/src/assert_sorted/chip.rs | 7 ++----- chips/src/assert_sorted/columns.rs | 27 +++++++++++++-------------- chips/src/assert_sorted/mod.rs | 14 ++++++-------- chips/src/assert_sorted/tests/mod.rs | 19 +++++++++---------- chips/src/assert_sorted/trace.rs | 6 ++++-- 5 files changed, 34 insertions(+), 39 deletions(-) diff --git a/chips/src/assert_sorted/chip.rs b/chips/src/assert_sorted/chip.rs index 7057c15506..a93827f83d 100644 --- a/chips/src/assert_sorted/chip.rs +++ b/chips/src/assert_sorted/chip.rs @@ -25,8 +25,7 @@ impl Chip for AssertSortedAir { self.is_less_than_tuple_air().tuple_len(), ); - let mut interactions: Vec> = vec![]; - + // here, y doesn't matter since we are only range checking the decompositions of x let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { x: cols_numbered.key.clone(), @@ -41,8 +40,6 @@ impl Chip for AssertSortedAir { is_less_than_tuple_cols, ); - interactions.extend(subchip_interactions); - - interactions + subchip_interactions } } diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index 297221635f..2d9e4c2b11 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -65,6 +65,19 @@ impl AssertSortedCols { curr_start_idx = curr_end_idx; curr_end_idx += key_vec_len; + // the next key_vec_len elements contain the cumulative is_equal indicators; is_equal_cumulative[i] + // indicates whether the first i elements of the key are equal to those of the next key + let is_equal_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); + + curr_start_idx = curr_end_idx; + curr_end_idx += key_vec_len; + + // the next key_vec_len elements contain the cumulative less than indicators; less_than_cumulative[i] + // indicates whether the first i elements of the key are lexicographically less than those of the next + // key + let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); + + // now, we construct the IsLessThanTupleAuxCols let mut less_than_aux: Vec> = vec![]; for i in 0..key_vec_len { let less_than_col = IsLessThanAuxCols { @@ -80,16 +93,6 @@ impl AssertSortedCols { is_equal_aux.push(is_equal_col); } - let mut is_equal_cumulative: Vec = vec![]; - let mut less_than_cumulative: Vec = vec![]; - - is_equal_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); - - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - less_than_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); - let is_less_than_tuple_aux = IsLessThanTupleAuxCols { less_than, less_than_aux, @@ -107,11 +110,7 @@ impl AssertSortedCols { } pub fn get_width(limb_bits: Vec, decomp: usize, key_vec_len: usize) -> usize { - // there are (limb_bits + decomp - 1) / decomp sublimbs per limb, we add 1 to - // account for the sublimb itself, and another 1 to account for the shifted - // last sublimb let mut width = 0; - // for the key itself width += key_vec_len; diff --git a/chips/src/assert_sorted/mod.rs b/chips/src/assert_sorted/mod.rs index 2cd2ab5eff..893ac319de 100644 --- a/chips/src/assert_sorted/mod.rs +++ b/chips/src/assert_sorted/mod.rs @@ -21,16 +21,14 @@ pub struct AssertSortedAir { } /** - * This Chip constrains that consecutive rows are sorted lexicographically. + * This chip constrains that consecutive rows are sorted lexicographically. * - * Each row consists of a key decomposed into limbs, and the chip constrains - * each limb has at most limb_bits bits, where limb_bits is at most 31. It - * does this by interacting with a RangeCheckerGateChip. Because the range checker - * gate can take MAX up to 2^20, we further decompose each limb into sublimbs - * of size decomp bits. + * Each row consists of a key decomposed into limbs. Each limb has its own max number of + * bits, given by the limb_bits array. The chip assumes that each limb is within its + * given max limb_bits. * - * The AssertSortedChip contains a LessThanChip subchip, which is used to constrain - * that the rows are sorted lexicographically. + * The AssertSortedChip uses the IsLessThanTupleChip as a subchip to check that the rows + * are sorted lexicographically. */ #[derive(Default)] pub struct AssertSortedChip { diff --git a/chips/src/assert_sorted/tests/mod.rs b/chips/src/assert_sorted/tests/mod.rs index f0df26fbdb..2a9b4a6923 100644 --- a/chips/src/assert_sorted/tests/mod.rs +++ b/chips/src/assert_sorted/tests/mod.rs @@ -25,16 +25,12 @@ use p3_matrix::dense::DenseMatrix; * partition on number of rows: * number of rows < 4 * number of rows >= 4 - * partition on size of each limb: - * each limb has at most limb_bits bits - * at least one limb has more than limb_bits bits * partition on row order: * rows are sorted lexicographically * rows are not sorted lexicographically */ -// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, each limb has at -// most limb_bits bits, rows are sorted lexicographically +// covers limb_bits < 20, key_vec_len < 4, limb_bits % decomp == 0, number of rows < 4, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_small_positive() { let bus_index: usize = 0; @@ -43,7 +39,12 @@ fn test_assert_sorted_chip_small_positive() { let range_max: u32 = 1 << decomp; - let requests = vec![vec![7784, 35423], vec![17558, 44832]]; + let requests = vec![ + vec![7784, 35423], + vec![17558, 44832], + vec![22843, 12786], + vec![32886, 24834], + ]; let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); @@ -67,8 +68,7 @@ fn test_assert_sorted_chip_small_positive() { .expect("Verification failed"); } -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at -// most limb_bits bits, rows are sorted lexicographically +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are sorted lexicographically #[test] fn test_assert_sorted_chip_large_positive() { let bus_index: usize = 0; @@ -106,8 +106,7 @@ fn test_assert_sorted_chip_large_positive() { .expect("Verification failed"); } -// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, each limb has at -// most limb_bits bits, rows are not sorted lexicographically +// covers limb_bits >= 20, key_vec_len >= 4, limb_bits % decomp != 0, number of rows >= 4, rows are not sorted lexicographically #[test] fn test_assert_sorted_chip_unsorted_negative() { let bus_index: usize = 0; diff --git a/chips/src/assert_sorted/trace.rs b/chips/src/assert_sorted/trace.rs index 5a3980a3e3..6ac56c9ce1 100644 --- a/chips/src/assert_sorted/trace.rs +++ b/chips/src/assert_sorted/trace.rs @@ -14,9 +14,9 @@ impl AssertSortedChip { ); let mut rows: Vec = vec![]; - for i in 0..self.air.is_less_than_tuple_air().tuple_len() { + for i in 0..self.air.keys().len() { let key = self.air.keys()[i].clone(); - let next_key: Vec = if i == self.air.is_less_than_tuple_air().tuple_len() - 1 { + let next_key: Vec = if i == self.air.keys().len() - 1 { vec![0; self.air.is_less_than_tuple_air().tuple_len()] } else { self.air.keys()[i + 1].clone() @@ -28,8 +28,10 @@ impl AssertSortedChip { ) .flatten(); + // the current key let mut row: Vec = is_less_than_tuple_trace[0..self.air.is_less_than_tuple_air().tuple_len()].to_vec(); + // the less than indicator and the auxiliary columns row.extend_from_slice( &is_less_than_tuple_trace[2 * self.air.is_less_than_tuple_air().tuple_len()..], ); From a086260d8320f5a6542b213866066ad704f9b198 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 12:12:45 -0400 Subject: [PATCH 21/46] chore: cleanup --- chips/src/is_less_than/air.rs | 2 + chips/src/is_less_than/mod.rs | 5 +- chips/src/is_less_than/tests/mod.rs | 4 +- chips/src/is_less_than/trace.rs | 7 +-- chips/src/is_less_than_tuple/air.rs | 13 +++-- chips/src/is_less_than_tuple/chip.rs | 4 +- chips/src/is_less_than_tuple/columns.rs | 64 ++++++++++------------- chips/src/is_less_than_tuple/mod.rs | 24 +++++---- chips/src/is_less_than_tuple/tests/mod.rs | 50 +++--------------- chips/src/is_less_than_tuple/trace.rs | 16 +++--- 10 files changed, 82 insertions(+), 107 deletions(-) diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index 8c86995ba7..00ad829cdf 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -61,6 +61,8 @@ impl SubAir for IsLessThanAir { y - x + AB::Expr::from_canonical_u64(1 << self.limb_bits()) - AB::Expr::one(); // constrain that the lower_bits + less_than * 2^limb_bits is the correct intermediate sum + // note that the intermediate value will be >= 2^limb_bits if and only if x < y, and check_val will therefore be + // the correct value if and only if less_than is the indicator for whether x < y let check_val = lower + less_than * AB::Expr::from_canonical_u64(1 << self.limb_bits()); builder.assert_eq(intermed_val, check_val); diff --git a/chips/src/is_less_than/mod.rs b/chips/src/is_less_than/mod.rs index 06ef0c089c..9feb959626 100644 --- a/chips/src/is_less_than/mod.rs +++ b/chips/src/is_less_than/mod.rs @@ -43,7 +43,10 @@ impl IsLessThanAir { } /** - * This chip computes whether one number is less than another. + * This chip checks whether one number is less than another. The two numbers have a max number of bits, + * given by limb_bits. The chip assumes that the two numbers are within limb_bits bits. The chip compares + * the numbers by decomposing them into limbs of size decomp bits, and interacts with a RangeCheckerGateChip + * to range check the decompositions. */ #[derive(Default, Getters)] pub struct IsLessThanChip { diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 2c123a1dbe..7b2b7ae18f 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -21,7 +21,7 @@ fn test_is_less_than_chip_lt() { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let trace = chip.generate_trace(vec![14321, 1, 773, 337], vec![26883, 0, 773, 456]); + let trace = chip.generate_trace(vec![(14321, 26883), (1, 0), (773, 773), (337, 456)]); let range_trace: DenseMatrix = chip.range_checker.generate_trace(); run_simple_test_no_pis( @@ -41,7 +41,7 @@ fn test_is_less_than_negative() { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); let chip = IsLessThanChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let mut trace = chip.generate_trace(vec![446], vec![553]); + let mut trace = chip.generate_trace(vec![(446, 553)]); let range_trace = chip.range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); diff --git a/chips/src/is_less_than/trace.rs b/chips/src/is_less_than/trace.rs index e18ed23468..e761080e1a 100644 --- a/chips/src/is_less_than/trace.rs +++ b/chips/src/is_less_than/trace.rs @@ -11,16 +11,17 @@ use super::{ }; impl IsLessThanChip { - pub fn generate_trace(&self, x: Vec, y: Vec) -> RowMajorMatrix { + pub fn generate_trace(&self, pairs: Vec<(u32, u32)>) -> RowMajorMatrix { let num_cols: usize = IsLessThanCols::::get_width(*self.air.limb_bits(), *self.air.decomp()); let mut rows = vec![]; - for i in 0..x.len() { + // generate a row for each pair of numbers to compare + for (x, y) in pairs { let row: Vec = self .air - .generate_trace_row((x[i], y[i], self.range_checker.clone())) + .generate_trace_row((x, y, self.range_checker.clone())) .flatten(); rows.extend(row); } diff --git a/chips/src/is_less_than_tuple/air.rs b/chips/src/is_less_than_tuple/air.rs index f2ab610fef..b2c4e96d06 100644 --- a/chips/src/is_less_than_tuple/air.rs +++ b/chips/src/is_less_than_tuple/air.rs @@ -60,11 +60,11 @@ impl SubAir for IsLessThanTupleAir { let x = io.x.clone(); let y = io.y.clone(); + // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip for each i for i in 0..x.len() { let x_val = x[i]; let y_val = y[i]; - // here we constrain that less_than[i] indicates whether x[i] < y[i] using the IsLessThan subchip let is_less_than_cols = IsLessThanCols { io: IsLessThanIOCols { x: x_val, @@ -78,14 +78,14 @@ impl SubAir for IsLessThanTupleAir { }; SubAir::eval( - &self.is_lt_airs[i].clone(), + &self.is_less_than_airs[i].clone(), builder, is_less_than_cols.io, is_less_than_cols.aux, ); } - // together, these constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] + // here, we constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i] for i in 0..x.len() { let is_equal = aux.is_equal[i]; let inv = aux.is_equal_aux[i].inv; @@ -103,22 +103,29 @@ impl SubAir for IsLessThanTupleAir { SubAir::eval(&is_equal_chip, builder, is_equal_cols.io, is_equal_cols.aux); } + // here, we constrain that is_equal_cumulative and less_than_cumulative are the correct values let is_equal_cumulative = aux.is_equal_cumulative.clone(); let less_than_cumulative = aux.less_than_cumulative.clone(); builder.assert_eq(is_equal_cumulative[0], aux.is_equal[0]); builder.assert_eq(less_than_cumulative[0], aux.less_than[0]); + for i in 1..x.len() { + // this constrains that is_equal_cumulative[i] indicates whether the first i elements of x and y are equal builder.assert_eq( is_equal_cumulative[i], is_equal_cumulative[i - 1] * aux.is_equal[i], ); + // this constrains that less_than_cumulative[i] indicates whether the first i elements of x are less than + // the first i elements of y, lexicographically + // note that less_than_cumulative[i - 1] and is_equal_cumulative[i - 1] are never both 1 builder.assert_eq( less_than_cumulative[i], less_than_cumulative[i - 1] + aux.less_than[i] * is_equal_cumulative[i - 1], ); } + // constrain that the tuple_less_than does indicate whether x < y, lexicographically builder.assert_eq(io.tuple_less_than, less_than_cumulative[x.len() - 1]); } } diff --git a/chips/src/is_less_than_tuple/chip.rs b/chips/src/is_less_than_tuple/chip.rs index 180f43051d..a8ba4ec098 100644 --- a/chips/src/is_less_than_tuple/chip.rs +++ b/chips/src/is_less_than_tuple/chip.rs @@ -29,9 +29,9 @@ impl Chip for IsLessThanTupleAir { impl SubAirWithInteractions for IsLessThanTupleAir { fn sends(&self, col_indices: IsLessThanTupleCols) -> Vec> { - // num_limbs is the number of limbs, not including the last shifted limb let mut interactions = vec![]; + // we need to get the interactions from the IsLessThan subchip for i in 0..self.tuple_len() { let is_less_than_cols = IsLessThanCols { io: IsLessThanIOCols { @@ -46,7 +46,7 @@ impl SubAirWithInteractions for IsLessThanTupleAir { }; let curr_interactions = - SubAirWithInteractions::::sends(&self.is_lt_airs[i], is_less_than_cols); + SubAirWithInteractions::::sends(&self.is_less_than_airs[i], is_less_than_cols); interactions.extend(curr_interactions); } diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs index c63ba717b3..4dd75e55ed 100644 --- a/chips/src/is_less_than_tuple/columns.rs +++ b/chips/src/is_less_than_tuple/columns.rs @@ -26,31 +26,16 @@ pub struct IsLessThanTupleCols { impl IsLessThanTupleCols { pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, tuple_len: usize) -> Self { - let mut x: Vec = vec![]; - let mut y: Vec = vec![]; - - let mut lower_vec: Vec = vec![]; - let mut lower_decomp_vec: Vec> = vec![]; - let mut less_than_aux: Vec> = vec![]; - - let mut less_than: Vec = vec![]; - let mut is_equal: Vec = vec![]; - let mut inverses: Vec = vec![]; - let mut is_equal_aux: Vec> = vec![]; - - let mut is_equal_cumulative: Vec = vec![]; - let mut less_than_cumulative: Vec = vec![]; - let mut curr_start_idx = 0; let mut curr_end_idx = tuple_len; // get the actual tuples, which are x and y - x.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let x = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; - y.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let y = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += 1; @@ -62,15 +47,17 @@ impl IsLessThanTupleCols { curr_end_idx += tuple_len; // get the indicators for whether x[i] < y[i] for all indices - less_than.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let less_than = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; // get the lower bits for each 2^limb_bits[i] + y[i] - x[i] - 1 - lower_vec.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let lower_vec = slc[curr_start_idx..curr_end_idx].to_vec(); // get the lower bits decompositions + let mut lower_decomp_vec: Vec> = vec![]; + for &limb_bit in limb_bits.iter() { let num_limbs = (limb_bit + decomp - 1) / decomp; curr_start_idx = curr_end_idx; @@ -85,41 +72,44 @@ impl IsLessThanTupleCols { lower_decomp_vec.push(lower_bits_curr); } - for i in 0..tuple_len { - let less_than_col = IsLessThanAuxCols { - lower: lower_vec[i].clone(), - lower_decomp: lower_decomp_vec[i].clone(), - }; - - less_than_aux.push(less_than_col); - } - curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; // get whether y[i] - x[i] == 0 - is_equal.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let is_equal = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; // get the inverses k such that k * (diff[i] + is_zero[i]) = 1 - inverses.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let inverses = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; - for inv in inverses.iter() { - let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; - is_equal_aux.push(is_equal_col); - } - - is_equal_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let is_equal_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; curr_end_idx += tuple_len; - less_than_cumulative.extend_from_slice(&slc[curr_start_idx..curr_end_idx]); + let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); + + // generate the less_than_aux and is_equal_aux columns + let mut less_than_aux: Vec> = vec![]; + for i in 0..tuple_len { + let less_than_col = IsLessThanAuxCols { + lower: lower_vec[i].clone(), + lower_decomp: lower_decomp_vec[i].clone(), + }; + + less_than_aux.push(less_than_col); + } + + let mut is_equal_aux: Vec> = vec![]; + for inv in inverses.iter() { + let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; + is_equal_aux.push(is_equal_col); + } IsLessThanTupleCols { io: IsLessThanTupleIOCols { diff --git a/chips/src/is_less_than_tuple/mod.rs b/chips/src/is_less_than_tuple/mod.rs index 847be4a110..aa51e6d645 100644 --- a/chips/src/is_less_than_tuple/mod.rs +++ b/chips/src/is_less_than_tuple/mod.rs @@ -25,12 +25,12 @@ pub struct IsLessThanTupleAir { decomp: usize, // IsLessThanAirs for each tuple element #[getset(get = "pub")] - is_lt_airs: Vec, + is_less_than_airs: Vec, } impl IsLessThanTupleAir { pub fn new(bus_index: usize, range_max: u32, limb_bits: Vec, decomp: usize) -> Self { - let is_lt_airs = limb_bits + let is_less_than_airs = limb_bits .iter() .map(|&limb_bit| IsLessThanAir::new(bus_index, range_max, limb_bit, decomp)) .collect::>(); @@ -39,23 +39,29 @@ impl IsLessThanTupleAir { bus_index, range_max, decomp, - is_lt_airs, + is_less_than_airs, } } pub fn tuple_len(&self) -> usize { - self.is_lt_airs.len() + self.is_less_than_airs.len() } pub fn limb_bits(&self) -> Vec { - self.is_lt_airs.iter().map(|air| *air.limb_bits()).collect() + self.is_less_than_airs + .iter() + .map(|air| *air.limb_bits()) + .collect() } } /** - * This Chip constrains that consecutive rows are sorted lexicographically. + * This chip computes whether one tuple is lexicographically less than another. Each element of the + * tuple has its own max number of bits, given by the limb_bits array. The chip assumes that each limb + * is within its given max limb_bits. * - * Each row consists of a key decomposed into limbs with at most limb_bits bits + * The IsLessThanTupleChip uses the IsLessThanChip as a subchip to check whether individual tuple elements + * are less than each other. */ #[derive(Default, Getters)] pub struct IsLessThanTupleChip { @@ -72,7 +78,7 @@ impl IsLessThanTupleChip { decomp: usize, range_checker: Arc, ) -> Self { - let is_lt_airs = limb_bits + let is_less_than_airs = limb_bits .iter() .map(|&limb_bit| IsLessThanAir::new(bus_index, range_max, limb_bit, decomp)) .collect::>(); @@ -81,7 +87,7 @@ impl IsLessThanTupleChip { bus_index, range_max, decomp, - is_lt_airs, + is_less_than_airs, }; Self { air, range_checker } diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 2c45598497..8fee17feb8 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -20,49 +20,13 @@ fn test_is_less_than_tuple_chip_lt() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![vec![14321, 123]], vec![vec![26678, 233]]); - let range_checker_trace = range_checker.generate_trace(); - - run_simple_test_no_pis( - vec![&chip.air, range_checker], - vec![trace, range_checker_trace], - ) - .expect("Verification failed"); -} - -#[test] -fn test_is_less_than_tuple_chip_gt() { - let bus_index: usize = 0; - let limb_bits: Vec = vec![8, 16]; - let decomp: usize = 8; - let range_max: u32 = 1 << decomp; - - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![vec![244, 14321]], vec![vec![233, 26678]]); - let range_checker_trace = range_checker.generate_trace(); - - run_simple_test_no_pis( - vec![&chip.air, range_checker], - vec![trace, range_checker_trace], - ) - .expect("Verification failed"); -} - -#[test] -fn test_is_less_than_tuple_chip_eq() { - let bus_index: usize = 0; - let limb_bits: Vec = vec![16, 8]; - let decomp: usize = 8; - let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - - let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); - let range_checker = chip.range_checker.as_ref(); - let trace = chip.generate_trace(vec![vec![14321, 244]], vec![vec![14321, 244]]); + let trace = chip.generate_trace(vec![ + (vec![14321, 123], vec![26678, 233]), + (vec![26678, 244], vec![14321, 233]), + (vec![14321, 244], vec![14321, 244]), + (vec![26678, 233], vec![14321, 244]), + ]); let range_checker_trace = range_checker.generate_trace(); run_simple_test_no_pis( @@ -83,7 +47,7 @@ fn test_is_less_than_tuple_chip_negative() { let chip = IsLessThanTupleChip::new(bus_index, range_max, limb_bits, decomp, range_checker); let range_checker = chip.range_checker.as_ref(); - let mut trace = chip.generate_trace(vec![vec![14321, 123]], vec![vec![26678, 233]]); + let mut trace = chip.generate_trace(vec![(vec![14321, 123], vec![26678, 233])]); let range_checker_trace = range_checker.generate_trace(); trace.values[2] = AbstractField::from_canonical_u64(0); diff --git a/chips/src/is_less_than_tuple/trace.rs b/chips/src/is_less_than_tuple/trace.rs index f453e576da..f3cdea1835 100644 --- a/chips/src/is_less_than_tuple/trace.rs +++ b/chips/src/is_less_than_tuple/trace.rs @@ -18,8 +18,7 @@ use super::{ impl IsLessThanTupleChip { pub fn generate_trace( &self, - x: Vec>, - y: Vec>, + tuple_pairs: Vec<(Vec, Vec)>, ) -> RowMajorMatrix { let num_cols: usize = IsLessThanTupleCols::::get_width( self.air.limb_bits().clone(), @@ -29,10 +28,11 @@ impl IsLessThanTupleChip { let mut rows: Vec = vec![]; - for i in 0..x.len() { + // for each tuple pair, generate the trace row + for (x, y) in tuple_pairs { let row: Vec = self .air - .generate_trace_row((x[i].clone(), y[i].clone(), self.range_checker.clone())) + .generate_trace_row((x.clone(), y.clone(), self.range_checker.clone())) .flatten(); rows.extend(row); } @@ -74,8 +74,7 @@ impl LocalTraceInstructions for IsLessThanTupleAir { lower_decomp_vec.push(curr_less_than_row[4..].to_vec()); } - let mut less_than_cumulative: Vec = vec![]; - + // compute is_equal_cumulative let mut transition_index = 0; while transition_index < x.len() && x[transition_index] == y[transition_index] { transition_index += 1; @@ -86,7 +85,9 @@ impl LocalTraceInstructions for IsLessThanTupleAir { .chain(std::iter::repeat(F::zero()).take(x.len() - transition_index)) .collect::>(); - // compute whether the x < y + let mut less_than_cumulative: Vec = vec![]; + + // compute less_than_cumulative for i in 0..x.len() { let mut less_than_curr = if i > 0 { less_than_cumulative[i - 1] @@ -129,6 +130,7 @@ impl LocalTraceInstructions for IsLessThanTupleAir { } } + // compute less_than_aux and is_equal_aux let mut less_than_aux: Vec> = vec![]; for i in 0..x.len() { let less_than_col = IsLessThanAuxCols { From 6896610837b241c3572978fbeeaf23431ad1b239 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 12:27:42 -0400 Subject: [PATCH 22/46] chore: include roundtrip flatten and from_slice tests --- chips/src/is_less_than/tests/mod.rs | 17 +++++++++++++++++ chips/src/is_less_than_tuple/tests/mod.rs | 23 +++++++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 7b2b7ae18f..8c6131ff40 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::range_gate::RangeCheckerGateChip; use super::super::is_less_than::IsLessThanChip; +use super::columns::IsLessThanCols; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; @@ -11,6 +12,22 @@ use p3_baby_bear::BabyBear; use p3_field::AbstractField; use p3_matrix::dense::DenseMatrix; +#[test] +fn test_flatten_fromslice_roundtrip() { + let limb_bits = 16; + let decomp = 8; + + let num_cols = IsLessThanCols::::get_width(limb_bits, decomp); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = IsLessThanCols::::from_slice(&all_cols, limb_bits, decomp); + let flattened = cols_numbered.flatten(); + + for (i, col) in flattened.iter().enumerate() { + assert_eq!(*col, all_cols[i]); + } +} + #[test] fn test_is_less_than_chip_lt() { let bus_index: usize = 0; diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 8fee17feb8..3ffb8a3467 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use crate::is_less_than_tuple::columns::IsLessThanTupleCols; use crate::range_gate::RangeCheckerGateChip; use super::super::is_less_than_tuple::IsLessThanTupleChip; @@ -10,10 +11,28 @@ use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; use p3_field::AbstractField; #[test] -fn test_is_less_than_tuple_chip_lt() { +fn test_flatten_fromslice_roundtrip() { + let limb_bits = vec![16, 8, 20, 20]; + let decomp = 8; + let tuple_len = 4; + + let num_cols = IsLessThanTupleCols::::get_width(limb_bits.clone(), decomp, tuple_len); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = + IsLessThanTupleCols::::from_slice(&all_cols, limb_bits.clone(), decomp, tuple_len); + let flattened = cols_numbered.flatten(); + + for (i, col) in flattened.iter().enumerate() { + assert_eq!(*col, all_cols[i]); + } +} + +#[test] +fn test_is_less_than_tuple_chip() { let bus_index: usize = 0; let limb_bits: Vec = vec![16, 8]; - let decomp: usize = 8; + let decomp: usize = 6; let range_max: u32 = 1 << decomp; let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); From 7c23b034d934e41921285fbc7c180a616eb50c96 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 12:41:28 -0400 Subject: [PATCH 23/46] feat: flatten and from_slice for IO and Aux columns --- chips/src/assert_sorted/columns.rs | 83 ++-------------- chips/src/is_less_than/air.rs | 3 +- chips/src/is_less_than/chip.rs | 3 +- chips/src/is_less_than/columns.rs | 63 ++++++------ chips/src/is_less_than/tests/mod.rs | 4 +- chips/src/is_less_than_tuple/columns.rs | 113 ++++++++++++---------- chips/src/is_less_than_tuple/tests/mod.rs | 2 + 7 files changed, 110 insertions(+), 161 deletions(-) diff --git a/chips/src/assert_sorted/columns.rs b/chips/src/assert_sorted/columns.rs index 2d9e4c2b11..7ff1f202a6 100644 --- a/chips/src/assert_sorted/columns.rs +++ b/chips/src/assert_sorted/columns.rs @@ -1,9 +1,6 @@ use afs_derive::AlignedBorrow; -use crate::{ - is_equal::columns::IsEqualAuxCols, is_less_than::columns::IsLessThanAuxCols, - is_less_than_tuple::columns::IsLessThanTupleAuxCols, -}; +use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; // Since AssertSortedChip contains a LessThanChip subchip, a subset of the columns are those of the // LessThanChip @@ -28,79 +25,13 @@ impl AssertSortedCols { // the next element is the indicator for whether the key is less than the next key let less_than_next_key = slc[curr_start_idx].clone(); curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - // the next key_vec_len elements are the indicators for the individual tuple element less thans - let less_than = slc[curr_start_idx..curr_end_idx].to_vec(); - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements are the values of the lower bits of each intermediate sum - // (i.e. 2^limb_bits[i] + y[i] - x[i] - 1) - let lower_vec = slc[curr_start_idx..curr_end_idx].to_vec(); - - // the next elements are the decomposed lower bits - let mut lower_decomp_vec: Vec> = vec![]; - for curr_limb_bits in limb_bits.iter() { - let num_limbs = (curr_limb_bits + decomp - 1) / decomp; - - curr_start_idx = curr_end_idx; - curr_end_idx += num_limbs + 1; - - lower_decomp_vec.push(slc[curr_start_idx..curr_end_idx].to_vec()); - } - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements are the indicator whether the difference is zero; if difference is - // zero then the two limbs must be equal - let is_equal = slc[curr_start_idx..curr_end_idx].to_vec(); - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements contain the inverses of the corresponding sum of diff and is_zero; - // note that this sum will always be nonzero so the inverse will exist - let inverses = slc[curr_start_idx..curr_end_idx].to_vec(); - - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements contain the cumulative is_equal indicators; is_equal_cumulative[i] - // indicates whether the first i elements of the key are equal to those of the next key - let is_equal_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); - - curr_start_idx = curr_end_idx; - curr_end_idx += key_vec_len; - - // the next key_vec_len elements contain the cumulative less than indicators; less_than_cumulative[i] - // indicates whether the first i elements of the key are lexicographically less than those of the next - // key - let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec(); - - // now, we construct the IsLessThanTupleAuxCols - let mut less_than_aux: Vec> = vec![]; - for i in 0..key_vec_len { - let less_than_col = IsLessThanAuxCols { - lower: lower_vec[i].clone(), - lower_decomp: lower_decomp_vec[i].clone(), - }; - less_than_aux.push(less_than_col); - } - - let mut is_equal_aux: Vec> = vec![]; - for inv in inverses.iter() { - let is_equal_col = IsEqualAuxCols { inv: inv.clone() }; - is_equal_aux.push(is_equal_col); - } - - let is_less_than_tuple_aux = IsLessThanTupleAuxCols { - less_than, - less_than_aux, - is_equal, - is_equal_aux, - is_equal_cumulative, - less_than_cumulative, - }; + let is_less_than_tuple_aux = IsLessThanTupleAuxCols::from_slice( + &slc[curr_start_idx..], + limb_bits, + decomp, + key_vec_len, + ); Self { key, diff --git a/chips/src/is_less_than/air.rs b/chips/src/is_less_than/air.rs index 00ad829cdf..dce3c52291 100644 --- a/chips/src/is_less_than/air.rs +++ b/chips/src/is_less_than/air.rs @@ -28,8 +28,7 @@ impl Air for IsLessThanAir { let local = main.row_slice(0); let local: &[AB::Var] = (*local).borrow(); - let local_cols = - IsLessThanCols::::from_slice(local, *self.limb_bits(), *self.decomp()); + let local_cols = IsLessThanCols::::from_slice(local); SubAir::eval(self, builder, local_cols.io, local_cols.aux); } diff --git a/chips/src/is_less_than/chip.rs b/chips/src/is_less_than/chip.rs index ffe08f478b..d1d4211d1f 100644 --- a/chips/src/is_less_than/chip.rs +++ b/chips/src/is_less_than/chip.rs @@ -10,8 +10,7 @@ impl Chip for IsLessThanAir { let num_cols = IsLessThanCols::::get_width(*self.limb_bits(), *self.decomp()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = - IsLessThanCols::::from_slice(&all_cols, *self.limb_bits(), *self.decomp()); + let cols_numbered = IsLessThanCols::::from_slice(&all_cols); SubAirWithInteractions::sends(self, cols_numbered) } diff --git a/chips/src/is_less_than/columns.rs b/chips/src/is_less_than/columns.rs index eb375d3ac4..e33a187e56 100644 --- a/chips/src/is_less_than/columns.rs +++ b/chips/src/is_less_than/columns.rs @@ -7,6 +7,20 @@ pub struct IsLessThanIOCols { pub less_than: T, } +impl IsLessThanIOCols { + pub fn from_slice(slc: &[T]) -> Self { + Self { + x: slc[0].clone(), + y: slc[1].clone(), + less_than: slc[2].clone(), + } + } + + pub fn flatten(&self) -> Vec { + vec![self.x.clone(), self.y.clone(), self.less_than.clone()] + } +} + pub struct IsLessThanAuxCols { pub lower: T, // lower_decomp consists of lower decomposed into limbs of size decomp where we also shift @@ -14,46 +28,37 @@ pub struct IsLessThanAuxCols { pub lower_decomp: Vec, } +impl IsLessThanAuxCols { + pub fn from_slice(slc: &[T]) -> Self { + Self { + lower: slc[0].clone(), + lower_decomp: slc[1..].to_vec(), + } + } + + pub fn flatten(&self) -> Vec { + let mut flattened = vec![self.lower.clone()]; + flattened.extend(self.lower_decomp.iter().cloned()); + flattened + } +} + pub struct IsLessThanCols { pub io: IsLessThanIOCols, pub aux: IsLessThanAuxCols, } impl IsLessThanCols { - pub fn from_slice(slc: &[T], limb_bits: usize, decomp: usize) -> Self { - // num_limbs is the number of limbs, not including the last shifted limb - let num_limbs = (limb_bits + decomp - 1) / decomp; - - // the first and second elements are x and y, respectively - let x = slc[0].clone(); - let y = slc[1].clone(); - // the third element is the less_than indicator - let less_than = slc[2].clone(); - - // the next element is the value of the lower num_limbs bits of the intermediate sum - let lower = slc[3].clone(); - - // the next num_limbs + 1 elements are the decomposed limbs of the lower bits of the - // intermediate sum - let lower_decomp = slc[4..4 + num_limbs + 1].to_vec(); - - let io = IsLessThanIOCols { x, y, less_than }; - let aux = IsLessThanAuxCols { - lower, - lower_decomp, - }; + pub fn from_slice(slc: &[T]) -> Self { + let io = IsLessThanIOCols::from_slice(&slc[..3]); + let aux = IsLessThanAuxCols::from_slice(&slc[3..]); Self { io, aux } } pub fn flatten(&self) -> Vec { - let mut flattened = vec![ - self.io.x.clone(), - self.io.y.clone(), - self.io.less_than.clone(), - self.aux.lower.clone(), - ]; - flattened.extend(self.aux.lower_decomp.iter().cloned()); + let mut flattened = self.io.flatten(); + flattened.extend(self.aux.flatten()); flattened } diff --git a/chips/src/is_less_than/tests/mod.rs b/chips/src/is_less_than/tests/mod.rs index 8c6131ff40..e13f3172f8 100644 --- a/chips/src/is_less_than/tests/mod.rs +++ b/chips/src/is_less_than/tests/mod.rs @@ -20,12 +20,14 @@ fn test_flatten_fromslice_roundtrip() { let num_cols = IsLessThanCols::::get_width(limb_bits, decomp); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = IsLessThanCols::::from_slice(&all_cols, limb_bits, decomp); + let cols_numbered = IsLessThanCols::::from_slice(&all_cols); let flattened = cols_numbered.flatten(); for (i, col) in flattened.iter().enumerate() { assert_eq!(*col, all_cols[i]); } + + assert_eq!(num_cols, flattened.len()); } #[test] diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs index 4dd75e55ed..341d6fdb72 100644 --- a/chips/src/is_less_than_tuple/columns.rs +++ b/chips/src/is_less_than_tuple/columns.rs @@ -9,6 +9,24 @@ pub struct IsLessThanTupleIOCols { pub tuple_less_than: T, } +impl IsLessThanTupleIOCols { + pub fn from_slice(slc: &[T], tuple_len: usize) -> Self { + Self { + x: slc[0..tuple_len].to_vec(), + y: slc[tuple_len..2 * tuple_len].to_vec(), + tuple_less_than: slc[2 * tuple_len].clone(), + } + } + + pub fn flatten(&self) -> Vec { + let mut flattened = vec![]; + flattened.extend_from_slice(&self.x); + flattened.extend_from_slice(&self.y); + flattened.push(self.tuple_less_than.clone()); + flattened + } +} + pub struct IsLessThanTupleAuxCols { pub less_than: Vec, pub less_than_aux: Vec>, @@ -19,34 +37,11 @@ pub struct IsLessThanTupleAuxCols { pub less_than_cumulative: Vec, } -pub struct IsLessThanTupleCols { - pub io: IsLessThanTupleIOCols, - pub aux: IsLessThanTupleAuxCols, -} - -impl IsLessThanTupleCols { +impl IsLessThanTupleAuxCols { pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, tuple_len: usize) -> Self { let mut curr_start_idx = 0; let mut curr_end_idx = tuple_len; - // get the actual tuples, which are x and y - let x = slc[curr_start_idx..curr_end_idx].to_vec(); - - curr_start_idx = curr_end_idx; - curr_end_idx += tuple_len; - - let y = slc[curr_start_idx..curr_end_idx].to_vec(); - - curr_start_idx = curr_end_idx; - curr_end_idx += 1; - - // get the indicator for whether x < y, lexicographically - let tuple_less_than = slc[curr_start_idx].clone(); - - curr_start_idx = curr_end_idx; - curr_end_idx += tuple_len; - - // get the indicators for whether x[i] < y[i] for all indices let less_than = slc[curr_start_idx..curr_end_idx].to_vec(); curr_start_idx = curr_end_idx; @@ -111,47 +106,63 @@ impl IsLessThanTupleCols { is_equal_aux.push(is_equal_col); } - IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x, - y, - tuple_less_than, - }, - aux: IsLessThanTupleAuxCols { - less_than, - less_than_aux, - is_equal, - is_equal_aux, - is_equal_cumulative, - less_than_cumulative, - }, + Self { + less_than, + less_than_aux, + is_equal, + is_equal_aux, + is_equal_cumulative, + less_than_cumulative, } } pub fn flatten(&self) -> Vec { let mut flattened = vec![]; - flattened.extend_from_slice(&self.io.x); - flattened.extend_from_slice(&self.io.y); - flattened.push(self.io.tuple_less_than.clone()); - flattened.extend_from_slice(&self.aux.less_than); - for i in 0..self.aux.less_than_aux.len() { - flattened.push(self.aux.less_than_aux[i].lower.clone()); + flattened.extend_from_slice(&self.less_than); + + for i in 0..self.less_than_aux.len() { + flattened.push(self.less_than_aux[i].lower.clone()); } - for i in 0..self.aux.less_than_aux.len() { - flattened.extend_from_slice(&self.aux.less_than_aux[i].lower_decomp); + for i in 0..self.less_than_aux.len() { + flattened.extend_from_slice(&self.less_than_aux[i].lower_decomp); } - flattened.extend_from_slice(&self.aux.is_equal); + flattened.extend_from_slice(&self.is_equal); - for i in 0..self.aux.is_equal_aux.len() { - flattened.push(self.aux.is_equal_aux[i].inv.clone()); + for i in 0..self.is_equal_aux.len() { + flattened.push(self.is_equal_aux[i].inv.clone()); } - flattened.extend_from_slice(&self.aux.is_equal_cumulative); - flattened.extend_from_slice(&self.aux.less_than_cumulative); + flattened.extend_from_slice(&self.is_equal_cumulative); + flattened.extend_from_slice(&self.less_than_cumulative); + + flattened + } +} + +pub struct IsLessThanTupleCols { + pub io: IsLessThanTupleIOCols, + pub aux: IsLessThanTupleAuxCols, +} + +impl IsLessThanTupleCols { + pub fn from_slice(slc: &[T], limb_bits: Vec, decomp: usize, tuple_len: usize) -> Self { + let io = IsLessThanTupleIOCols::from_slice(&slc[..2 * tuple_len + 1], tuple_len); + let aux = IsLessThanTupleAuxCols::from_slice( + &slc[2 * tuple_len + 1..], + limb_bits, + decomp, + tuple_len, + ); + + Self { io, aux } + } + pub fn flatten(&self) -> Vec { + let mut flattened = self.io.flatten(); + flattened.extend(self.aux.flatten()); flattened } diff --git a/chips/src/is_less_than_tuple/tests/mod.rs b/chips/src/is_less_than_tuple/tests/mod.rs index 3ffb8a3467..949481563f 100644 --- a/chips/src/is_less_than_tuple/tests/mod.rs +++ b/chips/src/is_less_than_tuple/tests/mod.rs @@ -26,6 +26,8 @@ fn test_flatten_fromslice_roundtrip() { for (i, col) in flattened.iter().enumerate() { assert_eq!(*col, all_cols[i]); } + + assert_eq!(num_cols, flattened.len()); } #[test] From 1bf59eb6c18233531218f7a4e3d708790e596dfa Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Thu, 6 Jun 2024 16:39:29 -0400 Subject: [PATCH 24/46] create files --- chips/src/lib.rs | 1 + chips/src/predicate/air.rs | 0 chips/src/predicate/chip.rs | 1 + chips/src/predicate/columns.rs | 7 +++++++ chips/src/predicate/mod.rs | 18 ++++++++++++++++++ chips/src/predicate/trace.rs | 1 + 6 files changed, 28 insertions(+) create mode 100644 chips/src/predicate/air.rs create mode 100644 chips/src/predicate/chip.rs create mode 100644 chips/src/predicate/columns.rs create mode 100644 chips/src/predicate/mod.rs create mode 100644 chips/src/predicate/trace.rs diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 7f315c20c3..2a47d58265 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -2,6 +2,7 @@ pub mod keccak_permute; pub mod merkle_proof; pub mod page_controller; pub mod page_read; +pub mod predicate; pub mod range; pub mod range_gate; pub mod sub_chip; diff --git a/chips/src/predicate/air.rs b/chips/src/predicate/air.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/chips/src/predicate/chip.rs b/chips/src/predicate/chip.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/chips/src/predicate/chip.rs @@ -0,0 +1 @@ + diff --git a/chips/src/predicate/columns.rs b/chips/src/predicate/columns.rs new file mode 100644 index 0000000000..8e6c98278c --- /dev/null +++ b/chips/src/predicate/columns.rs @@ -0,0 +1,7 @@ +use super::Comp; + +pub struct PredicateIOCols { + pub x: T, + pub y: T, + pub cmp: Comp, +} diff --git a/chips/src/predicate/mod.rs b/chips/src/predicate/mod.rs new file mode 100644 index 0000000000..25f4966c83 --- /dev/null +++ b/chips/src/predicate/mod.rs @@ -0,0 +1,18 @@ +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +pub enum Comp { + Lt, + Lte, + Eq, + Gte, + Gt, +} + +pub struct PredicateAir {} + +pub struct PredicateChip { + pub air: PredicateAir, +} diff --git a/chips/src/predicate/trace.rs b/chips/src/predicate/trace.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/chips/src/predicate/trace.rs @@ -0,0 +1 @@ + From 5b713a884a7337b52c8168f4ee5b130dca8bb6e4 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:12:30 -0400 Subject: [PATCH 25/46] chore: begin PageIndexScanChip --- chips/src/is_less_than/columns.rs | 28 +++++---- chips/src/is_less_than_tuple/columns.rs | 52 +++++++++-------- chips/src/lib.rs | 2 +- chips/src/page_index_scan/air.rs | 57 +++++++++++++++++++ .../{predicate => page_index_scan}/chip.rs | 0 chips/src/page_index_scan/columns.rs | 49 ++++++++++++++++ chips/src/page_index_scan/mod.rs | 30 ++++++++++ .../{predicate => page_index_scan}/trace.rs | 0 chips/src/predicate/air.rs | 0 chips/src/predicate/columns.rs | 7 --- chips/src/predicate/mod.rs | 18 ------ 11 files changed, 180 insertions(+), 63 deletions(-) create mode 100644 chips/src/page_index_scan/air.rs rename chips/src/{predicate => page_index_scan}/chip.rs (100%) create mode 100644 chips/src/page_index_scan/columns.rs create mode 100644 chips/src/page_index_scan/mod.rs rename chips/src/{predicate => page_index_scan}/trace.rs (100%) delete mode 100644 chips/src/predicate/air.rs delete mode 100644 chips/src/predicate/columns.rs delete mode 100644 chips/src/predicate/mod.rs diff --git a/chips/src/is_less_than/columns.rs b/chips/src/is_less_than/columns.rs index e33a187e56..d04fb9c24b 100644 --- a/chips/src/is_less_than/columns.rs +++ b/chips/src/is_less_than/columns.rs @@ -19,6 +19,10 @@ impl IsLessThanIOCols { pub fn flatten(&self) -> Vec { vec![self.x.clone(), self.y.clone(), self.less_than.clone()] } + + pub fn get_width() -> usize { + 3 + } } pub struct IsLessThanAuxCols { @@ -41,6 +45,17 @@ impl IsLessThanAuxCols { flattened.extend(self.lower_decomp.iter().cloned()); flattened } + + pub fn get_width(limb_bits: usize, decomp: usize) -> usize { + let mut width = 0; + // for the lower + width += 1; + // for the decomposed lower + let num_limbs = (limb_bits + decomp - 1) / decomp; + width += num_limbs + 1; + + width + } } pub struct IsLessThanCols { @@ -63,17 +78,6 @@ impl IsLessThanCols { } pub fn get_width(limb_bits: usize, decomp: usize) -> usize { - let mut width = 0; - // for the x and y - width += 2; - // for the less_than indicator - width += 1; - // for the lower - width += 1; - // for the decomposed lower - let num_limbs = (limb_bits + decomp - 1) / decomp; - width += num_limbs + 1; - - width + IsLessThanIOCols::::get_width() + IsLessThanAuxCols::::get_width(limb_bits, decomp) } } diff --git a/chips/src/is_less_than_tuple/columns.rs b/chips/src/is_less_than_tuple/columns.rs index 341d6fdb72..8e2fe897b9 100644 --- a/chips/src/is_less_than_tuple/columns.rs +++ b/chips/src/is_less_than_tuple/columns.rs @@ -25,6 +25,10 @@ impl IsLessThanTupleIOCols { flattened.push(self.tuple_less_than.clone()); flattened } + + pub fn get_width(tuple_len: usize) -> usize { + tuple_len + tuple_len + 1 + } } pub struct IsLessThanTupleAuxCols { @@ -140,6 +144,27 @@ impl IsLessThanTupleAuxCols { flattened } + + pub fn get_width(limb_bits: Vec, decomp: usize, tuple_len: usize) -> usize { + let mut width = 0; + // for the less than indicator + width += tuple_len; + // for the lowers + width += tuple_len; + // for the decomposed lowers + for &limb_bit in limb_bits.iter() { + let num_limbs = (limb_bit + decomp - 1) / decomp; + width += num_limbs + 1; + } + // for the indicator whether difference is zero + width += tuple_len; + // for the inverses k such that k * (diff[i] + is_zero[i]) = 1 + width += tuple_len; + // for the cumulative is_equal and less_than + width += 2 * tuple_len; + + width + } } pub struct IsLessThanTupleCols { @@ -167,30 +192,7 @@ impl IsLessThanTupleCols { } pub fn get_width(limb_bits: Vec, decomp: usize, tuple_len: usize) -> usize { - let mut width = 0; - // for the x and y tuples - width += 2 * tuple_len; - // for the tuple less than indicator - width += 1; - // for the less than indicator - width += tuple_len; - // for the lowers - width += tuple_len; - - // for the decomposed lowers - for &limb_bit in limb_bits.iter() { - let num_limbs = (limb_bit + decomp - 1) / decomp; - width += num_limbs + 1; - } - - // for the indicator whether difference is zero - width += tuple_len; - // for the inverses k such that k * (diff[i] + is_zero[i]) = 1 - width += tuple_len; - - // for the cumulative is_equal and less_than - width += 2 * tuple_len; - - width + IsLessThanTupleIOCols::::get_width(tuple_len) + + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, tuple_len) } } diff --git a/chips/src/lib.rs b/chips/src/lib.rs index 1db7cf52f5..bc62810702 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -7,8 +7,8 @@ pub mod is_zero; pub mod keccak_permute; pub mod merkle_proof; pub mod page_controller; +pub mod page_index_scan; pub mod page_read; -pub mod predicate; pub mod range; pub mod range_gate; pub mod sub_chip; diff --git a/chips/src/page_index_scan/air.rs b/chips/src/page_index_scan/air.rs new file mode 100644 index 0000000000..bd0500cc55 --- /dev/null +++ b/chips/src/page_index_scan/air.rs @@ -0,0 +1,57 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::Field; +use p3_matrix::Matrix; + +use crate::{ + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + sub_chip::SubAir, +}; + +use super::{columns::PageIndexScanCols, PageIndexScanAir}; + +impl BaseAir for PageIndexScanAir { + fn width(&self) -> usize { + PageIndexScanCols::::get_width( + self.idx_len, + self.data_len, + self.limb_bits.clone(), + self.decomp, + ) + } +} + +impl Air for PageIndexScanAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let local = main.row_slice(0); + let local: &[AB::Var] = (*local).borrow(); + + let local_cols = PageIndexScanCols::::from_slice( + local, + self.idx_len, + self.data_len, + self.decomp, + self.limb_bits.clone(), + ); + + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: local_cols.idx, + y: local_cols.x, + tuple_less_than: local_cols.satisfies_pred, + }, + aux: local_cols.is_less_than_tuple_aux, + }; + + // constrain the indicator that we used to check wheter key < x is correct + SubAir::eval( + &self.is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + } +} diff --git a/chips/src/predicate/chip.rs b/chips/src/page_index_scan/chip.rs similarity index 100% rename from chips/src/predicate/chip.rs rename to chips/src/page_index_scan/chip.rs diff --git a/chips/src/page_index_scan/columns.rs b/chips/src/page_index_scan/columns.rs new file mode 100644 index 0000000000..dd7b648004 --- /dev/null +++ b/chips/src/page_index_scan/columns.rs @@ -0,0 +1,49 @@ +use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; + +pub struct PageIndexScanCols { + pub is_alloc: T, + pub idx: Vec, + pub data: Vec, + + pub x: Vec, + + pub satisfies_pred: T, + pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, +} + +impl PageIndexScanCols { + pub fn from_slice( + slc: &[T], + idx_len: usize, + data_len: usize, + decomp: usize, + limb_bits: Vec, + ) -> Self { + Self { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[2 * idx_len + data_len + 2..], + limb_bits, + decomp, + idx_len, + ), + } + } + + pub fn get_width( + idx_len: usize, + data_len: usize, + limb_bits: Vec, + decomp: usize, + ) -> usize { + 1 + idx_len + + data_len + + idx_len + + 1 + + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, idx_len) + } +} diff --git a/chips/src/page_index_scan/mod.rs b/chips/src/page_index_scan/mod.rs new file mode 100644 index 0000000000..0e2433d228 --- /dev/null +++ b/chips/src/page_index_scan/mod.rs @@ -0,0 +1,30 @@ +use crate::is_less_than_tuple::IsLessThanTupleAir; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +// pub enum Comp { +// Lt, +// Lte, +// Eq, +// Gte, +// Gt, +// } + +pub struct PageIndexScanAir { + pub bus_index: usize, + pub idx_len: usize, + pub data_len: usize, + + pub limb_bits: Vec, + pub decomp: usize, + + is_less_than_tuple_air: IsLessThanTupleAir, + // pub cmp: Comp, +} + +pub struct PageIndexScanChip { + pub air: PageIndexScanAir, +} diff --git a/chips/src/predicate/trace.rs b/chips/src/page_index_scan/trace.rs similarity index 100% rename from chips/src/predicate/trace.rs rename to chips/src/page_index_scan/trace.rs diff --git a/chips/src/predicate/air.rs b/chips/src/predicate/air.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/chips/src/predicate/columns.rs b/chips/src/predicate/columns.rs deleted file mode 100644 index 8e6c98278c..0000000000 --- a/chips/src/predicate/columns.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::Comp; - -pub struct PredicateIOCols { - pub x: T, - pub y: T, - pub cmp: Comp, -} diff --git a/chips/src/predicate/mod.rs b/chips/src/predicate/mod.rs deleted file mode 100644 index 25f4966c83..0000000000 --- a/chips/src/predicate/mod.rs +++ /dev/null @@ -1,18 +0,0 @@ -pub mod air; -pub mod chip; -pub mod columns; -pub mod trace; - -pub enum Comp { - Lt, - Lte, - Eq, - Gte, - Gt, -} - -pub struct PredicateAir {} - -pub struct PredicateChip { - pub air: PredicateAir, -} From daad3de9a26afe6719e7d513c66cdde87cc0548d Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:00:19 -0400 Subject: [PATCH 26/46] feat: prototype --- chips/src/lib.rs | 2 +- chips/src/page_index_scan/chip.rs | 1 - chips/src/page_index_scan/mod.rs | 30 --------- chips/src/page_index_scan/trace.rs | 1 - chips/src/single_page_index_scan/mod.rs | 5 ++ .../page_index_scan/air.rs | 20 ++++-- .../page_index_scan/chip.rs | 65 +++++++++++++++++++ .../page_index_scan/columns.rs | 5 +- .../page_index_scan/mod.rs | 65 +++++++++++++++++++ .../page_index_scan/trace.rs | 58 +++++++++++++++++ .../page_index_scan_verify/air.rs | 31 +++++++++ .../page_index_scan_verify/chip.rs | 35 ++++++++++ .../page_index_scan_verify/columns.rs | 19 ++++++ .../page_index_scan_verify/mod.rs | 42 ++++++++++++ .../page_index_scan_verify/trace.rs | 33 ++++++++++ chips/src/single_page_index_scan/tests.rs | 61 +++++++++++++++++ 16 files changed, 434 insertions(+), 39 deletions(-) delete mode 100644 chips/src/page_index_scan/chip.rs delete mode 100644 chips/src/page_index_scan/mod.rs delete mode 100644 chips/src/page_index_scan/trace.rs create mode 100644 chips/src/single_page_index_scan/mod.rs rename chips/src/{ => single_page_index_scan}/page_index_scan/air.rs (73%) create mode 100644 chips/src/single_page_index_scan/page_index_scan/chip.rs rename chips/src/{ => single_page_index_scan}/page_index_scan/columns.rs (89%) create mode 100644 chips/src/single_page_index_scan/page_index_scan/mod.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan/trace.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/air.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/chip.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/columns.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/mod.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/trace.rs create mode 100644 chips/src/single_page_index_scan/tests.rs diff --git a/chips/src/lib.rs b/chips/src/lib.rs index bc62810702..e664196e3a 100644 --- a/chips/src/lib.rs +++ b/chips/src/lib.rs @@ -7,10 +7,10 @@ pub mod is_zero; pub mod keccak_permute; pub mod merkle_proof; pub mod page_controller; -pub mod page_index_scan; pub mod page_read; pub mod range; pub mod range_gate; +pub mod single_page_index_scan; pub mod sub_chip; pub mod sum; mod utils; diff --git a/chips/src/page_index_scan/chip.rs b/chips/src/page_index_scan/chip.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/chips/src/page_index_scan/chip.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/chips/src/page_index_scan/mod.rs b/chips/src/page_index_scan/mod.rs deleted file mode 100644 index 0e2433d228..0000000000 --- a/chips/src/page_index_scan/mod.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::is_less_than_tuple::IsLessThanTupleAir; - -pub mod air; -pub mod chip; -pub mod columns; -pub mod trace; - -// pub enum Comp { -// Lt, -// Lte, -// Eq, -// Gte, -// Gt, -// } - -pub struct PageIndexScanAir { - pub bus_index: usize, - pub idx_len: usize, - pub data_len: usize, - - pub limb_bits: Vec, - pub decomp: usize, - - is_less_than_tuple_air: IsLessThanTupleAir, - // pub cmp: Comp, -} - -pub struct PageIndexScanChip { - pub air: PageIndexScanAir, -} diff --git a/chips/src/page_index_scan/trace.rs b/chips/src/page_index_scan/trace.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/chips/src/page_index_scan/trace.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/chips/src/single_page_index_scan/mod.rs b/chips/src/single_page_index_scan/mod.rs new file mode 100644 index 0000000000..742f37163b --- /dev/null +++ b/chips/src/single_page_index_scan/mod.rs @@ -0,0 +1,5 @@ +pub mod page_index_scan; +pub mod page_index_scan_verify; + +#[cfg(test)] +pub mod tests; diff --git a/chips/src/page_index_scan/air.rs b/chips/src/single_page_index_scan/page_index_scan/air.rs similarity index 73% rename from chips/src/page_index_scan/air.rs rename to chips/src/single_page_index_scan/page_index_scan/air.rs index bd0500cc55..82d5859d7e 100644 --- a/chips/src/page_index_scan/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan/air.rs @@ -6,18 +6,22 @@ use p3_matrix::Matrix; use crate::{ is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, - sub_chip::SubAir, + sub_chip::{AirConfig, SubAir}, }; use super::{columns::PageIndexScanCols, PageIndexScanAir}; +impl AirConfig for PageIndexScanAir { + type Cols = PageIndexScanCols; +} + impl BaseAir for PageIndexScanAir { fn width(&self) -> usize { PageIndexScanCols::::get_width( self.idx_len, self.data_len, - self.limb_bits.clone(), - self.decomp, + self.is_less_than_tuple_air.limb_bits().clone(), + *self.is_less_than_tuple_air.decomp(), ) } } @@ -33,8 +37,8 @@ impl Air for PageIndexScanAir { local, self.idx_len, self.data_len, - self.decomp, - self.limb_bits.clone(), + *self.is_less_than_tuple_air.decomp(), + self.is_less_than_tuple_air.limb_bits().clone(), ); let is_less_than_tuple_cols = IsLessThanTupleCols { @@ -46,6 +50,12 @@ impl Air for PageIndexScanAir { aux: local_cols.is_less_than_tuple_aux, }; + builder.assert_eq( + local_cols.is_alloc * local_cols.satisfies_pred, + local_cols.send_row, + ); + builder.assert_bool(local_cols.send_row); + // constrain the indicator that we used to check wheter key < x is correct SubAir::eval( &self.is_less_than_tuple_air, diff --git a/chips/src/single_page_index_scan/page_index_scan/chip.rs b/chips/src/single_page_index_scan/page_index_scan/chip.rs new file mode 100644 index 0000000000..d2ac1c3721 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan/chip.rs @@ -0,0 +1,65 @@ +use crate::{ + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + sub_chip::SubAirWithInteractions, +}; + +use super::columns::PageIndexScanCols; +use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +use super::PageIndexScanAir; + +impl Chip for PageIndexScanAir { + fn sends(&self) -> Vec> { + let num_cols = PageIndexScanCols::::get_width( + *self.idx_len(), + *self.data_len(), + self.is_less_than_tuple_air().limb_bits(), + *self.is_less_than_tuple_air().decomp(), + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanCols::::from_slice( + &all_cols, + *self.idx_len(), + *self.data_len(), + *self.is_less_than_tuple_air().decomp(), + self.is_less_than_tuple_air().limb_bits(), + ); + + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: cols_numbered.idx.clone(), + y: cols_numbered.x.clone(), + tuple_less_than: cols_numbered.satisfies_pred, + }, + aux: cols_numbered.is_less_than_tuple_aux, + }; + + let mut cols = vec![]; + cols.push(cols_numbered.is_alloc); + cols.extend(cols_numbered.idx); + cols.extend(cols_numbered.data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + let mut interactions = vec![Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(cols_numbered.send_row), + argument_index: *self.bus_index(), + }]; + + let mut subchip_interactions = SubAirWithInteractions::::sends( + self.is_less_than_tuple_air(), + is_less_than_tuple_cols, + ); + + interactions.append(&mut subchip_interactions); + + interactions + } +} diff --git a/chips/src/page_index_scan/columns.rs b/chips/src/single_page_index_scan/page_index_scan/columns.rs similarity index 89% rename from chips/src/page_index_scan/columns.rs rename to chips/src/single_page_index_scan/page_index_scan/columns.rs index dd7b648004..eb3422ec21 100644 --- a/chips/src/page_index_scan/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan/columns.rs @@ -8,6 +8,7 @@ pub struct PageIndexScanCols { pub x: Vec, pub satisfies_pred: T, + pub send_row: T, pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, } @@ -25,8 +26,9 @@ impl PageIndexScanCols { data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), + send_row: slc[2 * idx_len + data_len + 2].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( - &slc[2 * idx_len + data_len + 2..], + &slc[2 * idx_len + data_len + 3..], limb_bits, decomp, idx_len, @@ -44,6 +46,7 @@ impl PageIndexScanCols { + data_len + idx_len + 1 + + 1 + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, idx_len) } } diff --git a/chips/src/single_page_index_scan/page_index_scan/mod.rs b/chips/src/single_page_index_scan/page_index_scan/mod.rs new file mode 100644 index 0000000000..bcbeeab0c9 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan/mod.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; + +use getset::Getters; + +use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +#[derive(Default)] +pub enum Comp { + #[default] + Lt, + Lte, + Eq, + Gte, + Gt, +} + +#[derive(Default, Getters)] +pub struct PageIndexScanAir { + #[getset(get = "pub")] + pub bus_index: usize, + #[getset(get = "pub")] + pub idx_len: usize, + #[getset(get = "pub")] + pub data_len: usize, + + #[getset(get = "pub")] + is_less_than_tuple_air: IsLessThanTupleAir, +} + +pub struct PageIndexScanChip { + pub air: PageIndexScanAir, + pub range_checker: Arc, +} + +impl PageIndexScanChip { + pub fn new( + bus_index: usize, + idx_len: usize, + data_len: usize, + range_max: u32, + limb_bits: Vec, + decomp: usize, + range_checker: Arc, + ) -> Self { + Self { + air: PageIndexScanAir { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, + range_max, + limb_bits.clone(), + decomp, + ), + }, + range_checker, + } + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan/trace.rs b/chips/src/single_page_index_scan/page_index_scan/trace.rs new file mode 100644 index 0000000000..068633cd7a --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan/trace.rs @@ -0,0 +1,58 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use crate::sub_chip::LocalTraceInstructions; + +use super::{columns::PageIndexScanCols, PageIndexScanChip}; + +impl PageIndexScanChip { + pub fn generate_trace( + &self, + page: Vec>, + x: Vec, + ) -> RowMajorMatrix { + let num_cols: usize = PageIndexScanCols::::get_width( + self.air.idx_len, + self.air.data_len, + self.air.is_less_than_tuple_air.limb_bits(), + *self.air.is_less_than_tuple_air.decomp(), + ); + + let mut rows: Vec = vec![]; + + for page_row in &page { + let mut row: Vec = vec![]; + + let is_alloc = F::from_canonical_u32(page_row[0]); + row.push(is_alloc); + + let idx = page_row[1..1 + self.air.idx_len].to_vec(); + let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(idx_trace); + + let data = + page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); + let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(data_trace); + + let x_trace: Vec = x.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(x_trace); + + let is_less_than_tuple_trace: Vec = LocalTraceInstructions::generate_trace_row( + &self.air.is_less_than_tuple_air, + (idx.clone(), x.clone(), self.range_checker.clone()), + ) + .flatten(); + + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; + row.push(send_row); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + + rows.extend_from_slice(&row); + } + + RowMajorMatrix::new(rows, num_cols) + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/air.rs b/chips/src/single_page_index_scan/page_index_scan_verify/air.rs new file mode 100644 index 0000000000..2174705fb9 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_verify/air.rs @@ -0,0 +1,31 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::Field; +use p3_matrix::Matrix; + +use crate::sub_chip::AirConfig; + +use super::{columns::PageIndexScanVerifyCols, PageIndexScanVerifyAir}; + +impl AirConfig for PageIndexScanVerifyAir { + type Cols = PageIndexScanVerifyCols; +} + +impl BaseAir for PageIndexScanVerifyAir { + fn width(&self) -> usize { + PageIndexScanVerifyCols::::get_width(self.idx_len, self.data_len) + } +} + +impl Air for PageIndexScanVerifyAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let local = main.row_slice(0); + let local: &[AB::Var] = (*local).borrow(); + + let _local_cols = + PageIndexScanVerifyCols::::from_slice(local, self.idx_len, self.data_len); + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/chip.rs b/chips/src/single_page_index_scan/page_index_scan_verify/chip.rs new file mode 100644 index 0000000000..f7d3f9269d --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_verify/chip.rs @@ -0,0 +1,35 @@ +use super::columns::PageIndexScanVerifyCols; +use afs_stark_backend::interaction::{Chip, Interaction}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +use super::PageIndexScanVerifyAir; + +impl Chip for PageIndexScanVerifyAir { + fn receives(&self) -> Vec> { + let num_cols = PageIndexScanVerifyCols::::get_width(*self.idx_len(), *self.data_len()); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanVerifyCols::::from_slice( + &all_cols, + *self.idx_len(), + *self.data_len(), + ); + + let mut cols = vec![]; + cols.push(cols_numbered.is_alloc); + cols.extend(cols_numbered.idx); + cols.extend(cols_numbered.data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + vec![Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(cols_numbered.is_alloc), + argument_index: *self.bus_index(), + }] + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs b/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs new file mode 100644 index 0000000000..82946e3352 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs @@ -0,0 +1,19 @@ +pub struct PageIndexScanVerifyCols { + pub is_alloc: T, + pub idx: Vec, + pub data: Vec, +} + +impl PageIndexScanVerifyCols { + pub fn from_slice(slc: &[T], idx_len: usize, data_len: usize) -> Self { + Self { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + } + } + + pub fn get_width(idx_len: usize, data_len: usize) -> usize { + 1 + idx_len + data_len + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs b/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs new file mode 100644 index 0000000000..77476c2e61 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs @@ -0,0 +1,42 @@ +use getset::Getters; + +pub mod air; +pub mod chip; +pub mod columns; +pub mod trace; + +#[derive(Default)] +pub enum Comp { + #[default] + Lt, + Lte, + Eq, + Gte, + Gt, +} + +#[derive(Default, Getters)] +pub struct PageIndexScanVerifyAir { + #[getset(get = "pub")] + pub bus_index: usize, + #[getset(get = "pub")] + pub idx_len: usize, + #[getset(get = "pub")] + pub data_len: usize, +} + +pub struct PageIndexScanVerifyChip { + pub air: PageIndexScanVerifyAir, +} + +impl PageIndexScanVerifyChip { + pub fn new(bus_index: usize, idx_len: usize, data_len: usize) -> Self { + Self { + air: PageIndexScanVerifyAir { + bus_index, + idx_len, + data_len, + }, + } + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs b/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs new file mode 100644 index 0000000000..beb60a81cb --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs @@ -0,0 +1,33 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use super::{columns::PageIndexScanVerifyCols, PageIndexScanVerifyChip}; + +impl PageIndexScanVerifyChip { + pub fn generate_trace(&self, page: Vec>) -> RowMajorMatrix { + let num_cols: usize = + PageIndexScanVerifyCols::::get_width(self.air.idx_len, self.air.data_len); + + let mut rows: Vec = vec![]; + + for page_row in &page { + let mut row: Vec = vec![]; + + let is_alloc = F::from_canonical_u32(page_row[0]); + row.push(is_alloc); + + let idx = page_row[1..1 + self.air.idx_len].to_vec(); + let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(idx_trace); + + let data = + page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); + let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(data_trace); + + rows.extend_from_slice(&row); + } + + RowMajorMatrix::new(rows, num_cols) + } +} diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs new file mode 100644 index 0000000000..7df114f07d --- /dev/null +++ b/chips/src/single_page_index_scan/tests.rs @@ -0,0 +1,61 @@ +use std::sync::Arc; + +use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; + +use super::{page_index_scan::PageIndexScanChip, page_index_scan_verify::PageIndexScanVerifyChip}; +use crate::range_gate::RangeCheckerGateChip; + +#[test] +fn test_single_page_index_scan() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + + let page_index_scan_chip = PageIndexScanChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits, + decomp, + range_checker, + ); + let page_index_scan_verify_chip = PageIndexScanVerifyChip::new(bus_index, idx_len, data_len); + let range_checker = page_index_scan_chip.range_checker.as_ref(); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let page_indexed: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![0, 0, 0, 0, 0, 0], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_index_scan_chip_trace = page_index_scan_chip.generate_trace(page.clone(), x); + let page_index_scan_verify_chip_trace = + page_index_scan_verify_chip.generate_trace(page_indexed.clone()); + let range_checker_trace = page_index_scan_chip.range_checker.generate_trace(); + + run_simple_test_no_pis( + vec![ + &page_index_scan_chip.air, + &page_index_scan_verify_chip.air, + range_checker, + ], + vec![ + page_index_scan_chip_trace, + page_index_scan_verify_chip_trace, + range_checker_trace, + ], + ) + .expect("Verification failed"); +} From 0a8882c001908346b80d2cdd1e7669f76e0cdd32 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 10 Jun 2024 10:42:51 -0400 Subject: [PATCH 27/46] feat: rename Chip to AirBridge --- .../page_index_scan/{chip.rs => bridge.rs} | 12 +++++------- .../single_page_index_scan/page_index_scan/mod.rs | 2 +- .../page_index_scan_verify/{chip.rs => bridge.rs} | 4 ++-- .../page_index_scan_verify/mod.rs | 2 +- chips/src/single_page_index_scan/tests.rs | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) rename chips/src/single_page_index_scan/page_index_scan/{chip.rs => bridge.rs} (84%) rename chips/src/single_page_index_scan/page_index_scan_verify/{chip.rs => bridge.rs} (88%) diff --git a/chips/src/single_page_index_scan/page_index_scan/chip.rs b/chips/src/single_page_index_scan/page_index_scan/bridge.rs similarity index 84% rename from chips/src/single_page_index_scan/page_index_scan/chip.rs rename to chips/src/single_page_index_scan/page_index_scan/bridge.rs index d2ac1c3721..cb9da5396d 100644 --- a/chips/src/single_page_index_scan/page_index_scan/chip.rs +++ b/chips/src/single_page_index_scan/page_index_scan/bridge.rs @@ -1,16 +1,16 @@ use crate::{ is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, - sub_chip::SubAirWithInteractions, + sub_chip::SubAirBridge, }; use super::columns::PageIndexScanCols; -use afs_stark_backend::interaction::{Chip, Interaction}; +use afs_stark_backend::interaction::{AirBridge, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; use super::PageIndexScanAir; -impl Chip for PageIndexScanAir { +impl AirBridge for PageIndexScanAir { fn sends(&self) -> Vec> { let num_cols = PageIndexScanCols::::get_width( *self.idx_len(), @@ -53,10 +53,8 @@ impl Chip for PageIndexScanAir { argument_index: *self.bus_index(), }]; - let mut subchip_interactions = SubAirWithInteractions::::sends( - self.is_less_than_tuple_air(), - is_less_than_tuple_cols, - ); + let mut subchip_interactions = + SubAirBridge::::sends(self.is_less_than_tuple_air(), is_less_than_tuple_cols); interactions.append(&mut subchip_interactions); diff --git a/chips/src/single_page_index_scan/page_index_scan/mod.rs b/chips/src/single_page_index_scan/page_index_scan/mod.rs index bcbeeab0c9..02507cb2e7 100644 --- a/chips/src/single_page_index_scan/page_index_scan/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan/mod.rs @@ -5,7 +5,7 @@ use getset::Getters; use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; pub mod air; -pub mod chip; +pub mod bridge; pub mod columns; pub mod trace; diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/chip.rs b/chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs similarity index 88% rename from chips/src/single_page_index_scan/page_index_scan_verify/chip.rs rename to chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs index f7d3f9269d..48000e6d85 100644 --- a/chips/src/single_page_index_scan/page_index_scan_verify/chip.rs +++ b/chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs @@ -1,11 +1,11 @@ use super::columns::PageIndexScanVerifyCols; -use afs_stark_backend::interaction::{Chip, Interaction}; +use afs_stark_backend::interaction::{AirBridge, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; use super::PageIndexScanVerifyAir; -impl Chip for PageIndexScanVerifyAir { +impl AirBridge for PageIndexScanVerifyAir { fn receives(&self) -> Vec> { let num_cols = PageIndexScanVerifyCols::::get_width(*self.idx_len(), *self.data_len()); let all_cols = (0..num_cols).collect::>(); diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs b/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs index 77476c2e61..df42af28d9 100644 --- a/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs @@ -1,7 +1,7 @@ use getset::Getters; pub mod air; -pub mod chip; +pub mod bridge; pub mod columns; pub mod trace; diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 7df114f07d..a24049ae78 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -49,7 +49,7 @@ fn test_single_page_index_scan() { vec![ &page_index_scan_chip.air, &page_index_scan_verify_chip.air, - range_checker, + &range_checker.air, ], vec![ page_index_scan_chip_trace, From 24a8771f234e2df04e995d704cdd03b87b498987 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 10 Jun 2024 12:40:25 -0400 Subject: [PATCH 28/46] feat: single page index scan chip for less than predicate --- chips/src/single_page_index_scan/mod.rs | 4 +- .../air.rs | 18 +- .../bridge.rs | 30 +-- .../columns.rs | 6 +- .../mod.rs | 15 +- .../trace.rs | 8 +- .../page_index_scan_output/air.rs | 86 +++++++++ .../page_index_scan_output/bridge.rs | 77 ++++++++ .../page_index_scan_output/columns.rs | 45 +++++ .../page_index_scan_output/mod.rs | 59 ++++++ .../page_index_scan_output/trace.rs | 58 ++++++ .../page_index_scan_verify/air.rs | 31 ---- .../page_index_scan_verify/bridge.rs | 35 ---- .../page_index_scan_verify/columns.rs | 19 -- .../page_index_scan_verify/mod.rs | 42 ----- .../page_index_scan_verify/trace.rs | 33 ---- chips/src/single_page_index_scan/tests.rs | 173 ++++++++++++++++-- 17 files changed, 525 insertions(+), 214 deletions(-) rename chips/src/single_page_index_scan/{page_index_scan => page_index_scan_input}/air.rs (75%) rename chips/src/single_page_index_scan/{page_index_scan => page_index_scan_input}/bridge.rs (64%) rename chips/src/single_page_index_scan/{page_index_scan => page_index_scan_input}/columns.rs (94%) rename chips/src/single_page_index_scan/{page_index_scan => page_index_scan_input}/mod.rs (80%) rename chips/src/single_page_index_scan/{page_index_scan => page_index_scan_input}/trace.rs (88%) create mode 100644 chips/src/single_page_index_scan/page_index_scan_output/air.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_output/bridge.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_output/columns.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_output/mod.rs create mode 100644 chips/src/single_page_index_scan/page_index_scan_output/trace.rs delete mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/air.rs delete mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs delete mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/columns.rs delete mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/mod.rs delete mode 100644 chips/src/single_page_index_scan/page_index_scan_verify/trace.rs diff --git a/chips/src/single_page_index_scan/mod.rs b/chips/src/single_page_index_scan/mod.rs index 742f37163b..6c3cba6c48 100644 --- a/chips/src/single_page_index_scan/mod.rs +++ b/chips/src/single_page_index_scan/mod.rs @@ -1,5 +1,5 @@ -pub mod page_index_scan; -pub mod page_index_scan_verify; +pub mod page_index_scan_input; +pub mod page_index_scan_output; #[cfg(test)] pub mod tests; diff --git a/chips/src/single_page_index_scan/page_index_scan/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs similarity index 75% rename from chips/src/single_page_index_scan/page_index_scan/air.rs rename to chips/src/single_page_index_scan/page_index_scan_input/air.rs index 82d5859d7e..822ed026d4 100644 --- a/chips/src/single_page_index_scan/page_index_scan/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -9,36 +9,36 @@ use crate::{ sub_chip::{AirConfig, SubAir}, }; -use super::{columns::PageIndexScanCols, PageIndexScanAir}; +use super::{columns::PageIndexScanInputCols, PageIndexScanInputAir}; -impl AirConfig for PageIndexScanAir { - type Cols = PageIndexScanCols; +impl AirConfig for PageIndexScanInputAir { + type Cols = PageIndexScanInputCols; } -impl BaseAir for PageIndexScanAir { +impl BaseAir for PageIndexScanInputAir { fn width(&self) -> usize { - PageIndexScanCols::::get_width( + PageIndexScanInputCols::::get_width( self.idx_len, self.data_len, self.is_less_than_tuple_air.limb_bits().clone(), - *self.is_less_than_tuple_air.decomp(), + self.is_less_than_tuple_air.decomp(), ) } } -impl Air for PageIndexScanAir { +impl Air for PageIndexScanInputAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); let local: &[AB::Var] = (*local).borrow(); - let local_cols = PageIndexScanCols::::from_slice( + let local_cols = PageIndexScanInputCols::::from_slice( local, self.idx_len, self.data_len, - *self.is_less_than_tuple_air.decomp(), self.is_less_than_tuple_air.limb_bits().clone(), + self.is_less_than_tuple_air.decomp(), ); let is_less_than_tuple_cols = IsLessThanTupleCols { diff --git a/chips/src/single_page_index_scan/page_index_scan/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs similarity index 64% rename from chips/src/single_page_index_scan/page_index_scan/bridge.rs rename to chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index cb9da5396d..1f9f0ceab0 100644 --- a/chips/src/single_page_index_scan/page_index_scan/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -3,29 +3,29 @@ use crate::{ sub_chip::SubAirBridge, }; -use super::columns::PageIndexScanCols; +use super::columns::PageIndexScanInputCols; use afs_stark_backend::interaction::{AirBridge, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; -use super::PageIndexScanAir; +use super::PageIndexScanInputAir; -impl AirBridge for PageIndexScanAir { +impl AirBridge for PageIndexScanInputAir { fn sends(&self) -> Vec> { - let num_cols = PageIndexScanCols::::get_width( - *self.idx_len(), - *self.data_len(), - self.is_less_than_tuple_air().limb_bits(), - *self.is_less_than_tuple_air().decomp(), + let num_cols = PageIndexScanInputCols::::get_width( + self.idx_len, + self.data_len, + self.is_less_than_tuple_air.limb_bits(), + self.is_less_than_tuple_air.decomp(), ); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = PageIndexScanCols::::from_slice( + let cols_numbered = PageIndexScanInputCols::::from_slice( &all_cols, - *self.idx_len(), - *self.data_len(), - *self.is_less_than_tuple_air().decomp(), - self.is_less_than_tuple_air().limb_bits(), + self.idx_len, + self.data_len, + self.is_less_than_tuple_air.limb_bits(), + self.is_less_than_tuple_air.decomp(), ); let is_less_than_tuple_cols = IsLessThanTupleCols { @@ -50,11 +50,11 @@ impl AirBridge for PageIndexScanAir { let mut interactions = vec![Interaction { fields: virtual_cols, count: VirtualPairCol::single_main(cols_numbered.send_row), - argument_index: *self.bus_index(), + argument_index: self.bus_index, }]; let mut subchip_interactions = - SubAirBridge::::sends(self.is_less_than_tuple_air(), is_less_than_tuple_cols); + SubAirBridge::::sends(&self.is_less_than_tuple_air, is_less_than_tuple_cols); interactions.append(&mut subchip_interactions); diff --git a/chips/src/single_page_index_scan/page_index_scan/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs similarity index 94% rename from chips/src/single_page_index_scan/page_index_scan/columns.rs rename to chips/src/single_page_index_scan/page_index_scan_input/columns.rs index eb3422ec21..431be45e23 100644 --- a/chips/src/single_page_index_scan/page_index_scan/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,6 +1,6 @@ use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; -pub struct PageIndexScanCols { +pub struct PageIndexScanInputCols { pub is_alloc: T, pub idx: Vec, pub data: Vec, @@ -12,13 +12,13 @@ pub struct PageIndexScanCols { pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, } -impl PageIndexScanCols { +impl PageIndexScanInputCols { pub fn from_slice( slc: &[T], idx_len: usize, data_len: usize, - decomp: usize, limb_bits: Vec, + decomp: usize, ) -> Self { Self { is_alloc: slc[0].clone(), diff --git a/chips/src/single_page_index_scan/page_index_scan/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs similarity index 80% rename from chips/src/single_page_index_scan/page_index_scan/mod.rs rename to chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 02507cb2e7..b086e2cee7 100644 --- a/chips/src/single_page_index_scan/page_index_scan/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -20,24 +20,21 @@ pub enum Comp { } #[derive(Default, Getters)] -pub struct PageIndexScanAir { - #[getset(get = "pub")] +pub struct PageIndexScanInputAir { pub bus_index: usize, - #[getset(get = "pub")] pub idx_len: usize, - #[getset(get = "pub")] pub data_len: usize, - #[getset(get = "pub")] + #[getset(skip)] is_less_than_tuple_air: IsLessThanTupleAir, } -pub struct PageIndexScanChip { - pub air: PageIndexScanAir, +pub struct PageIndexScanInputChip { + pub air: PageIndexScanInputAir, pub range_checker: Arc, } -impl PageIndexScanChip { +impl PageIndexScanInputChip { pub fn new( bus_index: usize, idx_len: usize, @@ -48,7 +45,7 @@ impl PageIndexScanChip { range_checker: Arc, ) -> Self { Self { - air: PageIndexScanAir { + air: PageIndexScanInputAir { bus_index, idx_len, data_len, diff --git a/chips/src/single_page_index_scan/page_index_scan/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs similarity index 88% rename from chips/src/single_page_index_scan/page_index_scan/trace.rs rename to chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 068633cd7a..66626f1544 100644 --- a/chips/src/single_page_index_scan/page_index_scan/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -3,19 +3,19 @@ use p3_matrix::dense::RowMajorMatrix; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::PageIndexScanCols, PageIndexScanChip}; +use super::{columns::PageIndexScanInputCols, PageIndexScanInputChip}; -impl PageIndexScanChip { +impl PageIndexScanInputChip { pub fn generate_trace( &self, page: Vec>, x: Vec, ) -> RowMajorMatrix { - let num_cols: usize = PageIndexScanCols::::get_width( + let num_cols: usize = PageIndexScanInputCols::::get_width( self.air.idx_len, self.air.data_len, self.air.is_less_than_tuple_air.limb_bits(), - *self.air.is_less_than_tuple_air.decomp(), + self.air.is_less_than_tuple_air.decomp(), ); let mut rows: Vec = vec![]; diff --git a/chips/src/single_page_index_scan/page_index_scan_output/air.rs b/chips/src/single_page_index_scan/page_index_scan_output/air.rs new file mode 100644 index 0000000000..47acea6b82 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_output/air.rs @@ -0,0 +1,86 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use crate::{ + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + sub_chip::{AirConfig, SubAir}, +}; + +use super::{columns::PageIndexScanOutputCols, PageIndexScanOutputAir}; + +impl AirConfig for PageIndexScanOutputAir { + type Cols = PageIndexScanOutputCols; +} + +impl BaseAir for PageIndexScanOutputAir { + fn width(&self) -> usize { + PageIndexScanOutputCols::::get_width( + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ) + } +} + +impl Air for PageIndexScanOutputAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + // get the current row and the next row + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local: &[AB::Var] = (*local).borrow(); + let next: &[AB::Var] = (*next).borrow(); + + let local_cols = PageIndexScanOutputCols::::from_slice( + local, + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + let next_cols = PageIndexScanOutputCols::::from_slice( + next, + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + + // check that is_alloc is a bool + builder.when_transition().assert_bool(local_cols.is_alloc); + // if local_cols.is_alloc is 1, then next_cols.is_alloc can be 0 or 1 + builder + .when_transition() + .assert_bool(local_cols.is_alloc * next_cols.is_alloc); + // if local_cols.is_alloc is 0, then next_cols.is_alloc must be 0 + builder + .when_transition() + .assert_zero((AB::Expr::one() - local_cols.is_alloc) * next_cols.is_alloc); + + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: local_cols.idx, + y: next_cols.idx, + tuple_less_than: local_cols.less_than_next_idx, + }, + aux: local_cols.is_less_than_tuple_aux, + }; + + // constrain the indicator that we used to check whether the current key < next key is correct + SubAir::eval( + self.is_less_than_tuple_air(), + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + + // if the next row exists, then the current row index must be less than the next row index + builder + .when_transition() + .assert_zero(next_cols.is_alloc * (AB::Expr::one() - local_cols.less_than_next_idx)); + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs new file mode 100644 index 0000000000..66c65be145 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs @@ -0,0 +1,77 @@ +use crate::{ + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + sub_chip::SubAirBridge, +}; + +use super::columns::PageIndexScanOutputCols; +use afs_stark_backend::interaction::{AirBridge, Interaction}; +use p3_air::VirtualPairCol; +use p3_field::PrimeField64; + +use super::PageIndexScanOutputAir; + +impl AirBridge for PageIndexScanOutputAir { + fn receives(&self) -> Vec> { + let num_cols = PageIndexScanOutputCols::::get_width( + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanOutputCols::::from_slice( + &all_cols, + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + + let mut cols = vec![]; + cols.push(cols_numbered.is_alloc); + cols.extend(cols_numbered.idx.clone()); + cols.extend(cols_numbered.data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + vec![Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(cols_numbered.is_alloc), + argument_index: self.bus_index, + }] + } + + fn sends(&self) -> Vec> { + let num_cols = PageIndexScanOutputCols::::get_width( + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanOutputCols::::from_slice( + &all_cols, + self.idx_len, + self.data_len, + self.is_less_than_tuple_air().limb_bits().clone(), + self.is_less_than_tuple_air().decomp(), + ); + + // range check the decompositions of x within aux columns; here the io doesn't matter + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: cols_numbered.idx.clone(), + y: cols_numbered.idx.clone(), + tuple_less_than: cols_numbered.less_than_next_idx, + }, + aux: cols_numbered.is_less_than_tuple_aux, + }; + + SubAirBridge::::sends(&self.is_less_than_tuple_air, is_less_than_tuple_cols) + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs new file mode 100644 index 0000000000..4742764b0a --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs @@ -0,0 +1,45 @@ +use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; + +pub struct PageIndexScanOutputCols { + pub is_alloc: T, + pub idx: Vec, + pub data: Vec, + + pub less_than_next_idx: T, + pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, +} + +impl PageIndexScanOutputCols { + pub fn from_slice( + slc: &[T], + idx_len: usize, + data_len: usize, + limb_bits: Vec, + decomp: usize, + ) -> Self { + Self { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + less_than_next_idx: slc[idx_len + data_len + 1].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[idx_len + data_len + 2..], + limb_bits, + decomp, + idx_len, + ), + } + } + + pub fn get_width( + idx_len: usize, + data_len: usize, + limb_bits: Vec, + decomp: usize, + ) -> usize { + 1 + idx_len + + data_len + + 1 + + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, idx_len) + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs new file mode 100644 index 0000000000..7a38288e98 --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -0,0 +1,59 @@ +use std::sync::Arc; + +use getset::Getters; + +use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; + +pub mod air; +pub mod bridge; +pub mod columns; +pub mod trace; + +#[derive(Default)] +pub enum Comp { + #[default] + Lt, + Lte, + Eq, + Gte, + Gt, +} + +#[derive(Default, Getters)] +pub struct PageIndexScanOutputAir { + pub bus_index: usize, + pub idx_len: usize, + pub data_len: usize, + + #[getset(get = "pub")] + is_less_than_tuple_air: IsLessThanTupleAir, +} + +pub struct PageIndexScanOutputChip { + pub air: PageIndexScanOutputAir, + pub range_checker: Arc, +} + +impl PageIndexScanOutputChip { + pub fn new( + bus_index: usize, + idx_len: usize, + data_len: usize, + range_max: u32, + limb_bits: Vec, + decomp: usize, + range_checker: Arc, + ) -> Self { + Self { + air: PageIndexScanOutputAir { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, range_max, limb_bits, decomp, + ), + }, + range_checker, + } + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs new file mode 100644 index 0000000000..5c41364d1b --- /dev/null +++ b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs @@ -0,0 +1,58 @@ +use p3_field::PrimeField64; +use p3_matrix::dense::RowMajorMatrix; + +use crate::sub_chip::LocalTraceInstructions; + +use super::{columns::PageIndexScanOutputCols, PageIndexScanOutputChip}; + +impl PageIndexScanOutputChip { + pub fn generate_trace(&self, page: Vec>) -> RowMajorMatrix { + let num_cols: usize = PageIndexScanOutputCols::::get_width( + self.air.idx_len, + self.air.data_len, + self.air.is_less_than_tuple_air().limb_bits().clone(), + self.air.is_less_than_tuple_air().decomp(), + ); + + let mut rows: Vec = vec![]; + + for i in 0..page.len() { + let page_row = page[i].clone(); + let next_page: Vec = if i == page.len() - 1 { + vec![0; 1 + self.air.idx_len + self.air.data_len] + } else { + page[i + 1].clone() + }; + + let mut row: Vec = vec![]; + + let is_alloc = F::from_canonical_u32(page_row[0]); + row.push(is_alloc); + + let idx = page_row[1..1 + self.air.idx_len].to_vec(); + let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(idx_trace); + + let data = + page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); + let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); + row.extend(data_trace); + + let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( + self.air.is_less_than_tuple_air(), + ( + page_row[1..1 + self.air.idx_len].to_vec(), + next_page[1..1 + self.air.idx_len].to_vec(), + self.range_checker.clone(), + ), + ) + .flatten(); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len..]); + + rows.extend_from_slice(&row); + } + + RowMajorMatrix::new(rows, num_cols) + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/air.rs b/chips/src/single_page_index_scan/page_index_scan_verify/air.rs deleted file mode 100644 index 2174705fb9..0000000000 --- a/chips/src/single_page_index_scan/page_index_scan_verify/air.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::borrow::Borrow; - -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::Field; -use p3_matrix::Matrix; - -use crate::sub_chip::AirConfig; - -use super::{columns::PageIndexScanVerifyCols, PageIndexScanVerifyAir}; - -impl AirConfig for PageIndexScanVerifyAir { - type Cols = PageIndexScanVerifyCols; -} - -impl BaseAir for PageIndexScanVerifyAir { - fn width(&self) -> usize { - PageIndexScanVerifyCols::::get_width(self.idx_len, self.data_len) - } -} - -impl Air for PageIndexScanVerifyAir { - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - - let local = main.row_slice(0); - let local: &[AB::Var] = (*local).borrow(); - - let _local_cols = - PageIndexScanVerifyCols::::from_slice(local, self.idx_len, self.data_len); - } -} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs deleted file mode 100644 index 48000e6d85..0000000000 --- a/chips/src/single_page_index_scan/page_index_scan_verify/bridge.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::columns::PageIndexScanVerifyCols; -use afs_stark_backend::interaction::{AirBridge, Interaction}; -use p3_air::VirtualPairCol; -use p3_field::PrimeField64; - -use super::PageIndexScanVerifyAir; - -impl AirBridge for PageIndexScanVerifyAir { - fn receives(&self) -> Vec> { - let num_cols = PageIndexScanVerifyCols::::get_width(*self.idx_len(), *self.data_len()); - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanVerifyCols::::from_slice( - &all_cols, - *self.idx_len(), - *self.data_len(), - ); - - let mut cols = vec![]; - cols.push(cols_numbered.is_alloc); - cols.extend(cols_numbered.idx); - cols.extend(cols_numbered.data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - vec![Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(cols_numbered.is_alloc), - argument_index: *self.bus_index(), - }] - } -} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs b/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs deleted file mode 100644 index 82946e3352..0000000000 --- a/chips/src/single_page_index_scan/page_index_scan_verify/columns.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub struct PageIndexScanVerifyCols { - pub is_alloc: T, - pub idx: Vec, - pub data: Vec, -} - -impl PageIndexScanVerifyCols { - pub fn from_slice(slc: &[T], idx_len: usize, data_len: usize) -> Self { - Self { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - } - } - - pub fn get_width(idx_len: usize, data_len: usize) -> usize { - 1 + idx_len + data_len - } -} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs b/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs deleted file mode 100644 index df42af28d9..0000000000 --- a/chips/src/single_page_index_scan/page_index_scan_verify/mod.rs +++ /dev/null @@ -1,42 +0,0 @@ -use getset::Getters; - -pub mod air; -pub mod bridge; -pub mod columns; -pub mod trace; - -#[derive(Default)] -pub enum Comp { - #[default] - Lt, - Lte, - Eq, - Gte, - Gt, -} - -#[derive(Default, Getters)] -pub struct PageIndexScanVerifyAir { - #[getset(get = "pub")] - pub bus_index: usize, - #[getset(get = "pub")] - pub idx_len: usize, - #[getset(get = "pub")] - pub data_len: usize, -} - -pub struct PageIndexScanVerifyChip { - pub air: PageIndexScanVerifyAir, -} - -impl PageIndexScanVerifyChip { - pub fn new(bus_index: usize, idx_len: usize, data_len: usize) -> Self { - Self { - air: PageIndexScanVerifyAir { - bus_index, - idx_len, - data_len, - }, - } - } -} diff --git a/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs b/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs deleted file mode 100644 index beb60a81cb..0000000000 --- a/chips/src/single_page_index_scan/page_index_scan_verify/trace.rs +++ /dev/null @@ -1,33 +0,0 @@ -use p3_field::PrimeField64; -use p3_matrix::dense::RowMajorMatrix; - -use super::{columns::PageIndexScanVerifyCols, PageIndexScanVerifyChip}; - -impl PageIndexScanVerifyChip { - pub fn generate_trace(&self, page: Vec>) -> RowMajorMatrix { - let num_cols: usize = - PageIndexScanVerifyCols::::get_width(self.air.idx_len, self.air.data_len); - - let mut rows: Vec = vec![]; - - for page_row in &page { - let mut row: Vec = vec![]; - - let is_alloc = F::from_canonical_u32(page_row[0]); - row.push(is_alloc); - - let idx = page_row[1..1 + self.air.idx_len].to_vec(); - let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(idx_trace); - - let data = - page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); - let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(data_trace); - - rows.extend_from_slice(&row); - } - - RowMajorMatrix::new(rows, num_cols) - } -} diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index a24049ae78..a6b7f358ca 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -1,8 +1,11 @@ use std::sync::Arc; +use afs_stark_backend::{prover::USE_DEBUG_BUILDER, verifier::VerificationError}; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; -use super::{page_index_scan::PageIndexScanChip, page_index_scan_verify::PageIndexScanVerifyChip}; +use super::{ + page_index_scan_input::PageIndexScanInputChip, page_index_scan_output::PageIndexScanOutputChip, +}; use crate::range_gate::RangeCheckerGateChip; #[test] @@ -16,17 +19,25 @@ fn test_single_page_index_scan() { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - let page_index_scan_chip = PageIndexScanChip::new( + let page_index_scan_input_chip = PageIndexScanInputChip::new( bus_index, idx_len, data_len, range_max, - limb_bits, + limb_bits.clone(), decomp, - range_checker, + range_checker.clone(), ); - let page_index_scan_verify_chip = PageIndexScanVerifyChip::new(bus_index, idx_len, data_len); - let range_checker = page_index_scan_chip.range_checker.as_ref(); + let page_index_scan_output_chip = PageIndexScanOutputChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + range_checker.clone(), + ); + let range_checker_chip = range_checker.as_ref(); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], @@ -40,16 +51,16 @@ fn test_single_page_index_scan() { let x: Vec = vec![2177, 5880]; - let page_index_scan_chip_trace = page_index_scan_chip.generate_trace(page.clone(), x); + let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); let page_index_scan_verify_chip_trace = - page_index_scan_verify_chip.generate_trace(page_indexed.clone()); - let range_checker_trace = page_index_scan_chip.range_checker.generate_trace(); + page_index_scan_output_chip.generate_trace(page_indexed.clone()); + let range_checker_trace = range_checker_chip.generate_trace(); run_simple_test_no_pis( vec![ - &page_index_scan_chip.air, - &page_index_scan_verify_chip.air, - &range_checker.air, + &page_index_scan_input_chip.air, + &page_index_scan_output_chip.air, + &range_checker_chip.air, ], vec![ page_index_scan_chip_trace, @@ -59,3 +70,141 @@ fn test_single_page_index_scan() { ) .expect("Verification failed"); } + +#[test] +fn test_single_page_index_scan_wrong_order() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + + let page_index_scan_input_chip = PageIndexScanInputChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + range_checker.clone(), + ); + let page_index_scan_output_chip = PageIndexScanOutputChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + range_checker.clone(), + ); + let range_checker_chip = range_checker.as_ref(); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let page_indexed: Vec> = vec![ + vec![0, 0, 0, 0, 0, 0], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); + let page_index_scan_verify_chip_trace = + page_index_scan_output_chip.generate_trace(page_indexed.clone()); + let range_checker_trace = range_checker_chip.generate_trace(); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis( + vec![ + &page_index_scan_input_chip.air, + &page_index_scan_output_chip.air, + &range_checker_chip.air, + ], + vec![ + page_index_scan_chip_trace, + page_index_scan_verify_chip_trace, + range_checker_trace, + ], + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} + +#[test] +fn test_single_page_index_scan_unsorted() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + + let page_index_scan_input_chip = PageIndexScanInputChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + range_checker.clone(), + ); + let page_index_scan_output_chip = PageIndexScanOutputChip::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + range_checker.clone(), + ); + let range_checker_chip = range_checker.as_ref(); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let page_indexed: Vec> = vec![ + vec![1, 2883, 7769, 51171, 3989, 12770], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + + let x: Vec = vec![2883, 7770]; + + let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); + let page_index_scan_verify_chip_trace = + page_index_scan_output_chip.generate_trace(page_indexed.clone()); + let range_checker_trace = range_checker_chip.generate_trace(); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + run_simple_test_no_pis( + vec![ + &page_index_scan_input_chip.air, + &page_index_scan_output_chip.air, + &range_checker_chip.air, + ], + vec![ + page_index_scan_chip_trace, + page_index_scan_verify_chip_trace, + range_checker_trace, + ], + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} From d250e99f8b13247a79baee28d07c528690d0dba6 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:21:49 -0400 Subject: [PATCH 29/46] feat: partitioned main --- chips/src/range_gate/mod.rs | 16 +- chips/src/range_gate/trace.rs | 4 +- chips/src/single_page_index_scan/mod.rs | 1 + .../page_controller/mod.rs | 173 +++++++ .../page_index_scan_input/air.rs | 23 +- .../page_index_scan_input/columns.rs | 8 +- .../page_index_scan_input/mod.rs | 28 +- .../page_index_scan_input/trace.rs | 112 ++-- .../page_index_scan_output/air.rs | 32 +- .../page_index_scan_output/columns.rs | 8 +- .../page_index_scan_output/mod.rs | 28 +- .../page_index_scan_output/trace.rs | 55 +- chips/src/single_page_index_scan/tests.rs | 484 ++++++++++++------ 13 files changed, 729 insertions(+), 243 deletions(-) create mode 100644 chips/src/single_page_index_scan/page_controller/mod.rs diff --git a/chips/src/range_gate/mod.rs b/chips/src/range_gate/mod.rs index a59c575175..42d4185fb9 100644 --- a/chips/src/range_gate/mod.rs +++ b/chips/src/range_gate/mod.rs @@ -8,7 +8,7 @@ pub mod trace; #[derive(Default)] pub struct RangeCheckerGateAir { bus_index: usize, - _range_max: u32, + range_max: u32, } /// This chip gets requests to verify that a number is in the range @@ -30,12 +30,24 @@ impl RangeCheckerGateChip { Self { air: RangeCheckerGateAir { bus_index, - _range_max: range_max, + range_max, }, count, } } + pub fn bus_index(&self) -> usize { + self.air.bus_index + } + + pub fn range_max(&self) -> u32 { + self.air.range_max + } + + pub fn air_width(&self) -> usize { + 2 + } + pub fn add_count(&self, val: u32) { let val_atomic = &self.count[val as usize]; val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed); diff --git a/chips/src/range_gate/trace.rs b/chips/src/range_gate/trace.rs index 94c8b77694..52f7e95c1d 100644 --- a/chips/src/range_gate/trace.rs +++ b/chips/src/range_gate/trace.rs @@ -1,10 +1,10 @@ -use p3_field::PrimeField64; +use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use super::{columns::NUM_RANGE_GATE_COLS, RangeCheckerGateChip}; impl RangeCheckerGateChip { - pub fn generate_trace(&self) -> RowMajorMatrix { + pub fn generate_trace(&self) -> RowMajorMatrix { let rows = self .count .iter() diff --git a/chips/src/single_page_index_scan/mod.rs b/chips/src/single_page_index_scan/mod.rs index 6c3cba6c48..62d219af4c 100644 --- a/chips/src/single_page_index_scan/mod.rs +++ b/chips/src/single_page_index_scan/mod.rs @@ -1,3 +1,4 @@ +pub mod page_controller; pub mod page_index_scan_input; pub mod page_index_scan_output; diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs new file mode 100644 index 0000000000..024df7c6bd --- /dev/null +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -0,0 +1,173 @@ +use std::sync::Arc; + +use afs_stark_backend::{ + config::Com, + prover::trace::{ProverTraceData, TraceCommitter}, +}; +use p3_field::{AbstractField, PrimeField, PrimeField64}; +use p3_matrix::dense::DenseMatrix; +use p3_matrix::Matrix; +use p3_uni_stark::{StarkGenericConfig, Val}; + +use crate::range_gate::RangeCheckerGateChip; + +use super::{ + page_index_scan_input::PageIndexScanInputChip, page_index_scan_output::PageIndexScanOutputChip, +}; + +pub struct PageController +where + Val: AbstractField + PrimeField64, +{ + pub input_chip: PageIndexScanInputChip, + pub output_chip: PageIndexScanOutputChip, + + input_chip_trace: Option>>, + input_chip_aux_trace: Option>>, + output_chip_trace: Option>>, + output_chip_aux_trace: Option>>, + + input_commitment: Option>, + + pub range_checker: Arc, +} + +impl PageController +where + Val: AbstractField + PrimeField64, +{ + pub fn new( + bus_index: usize, + idx_len: usize, + data_len: usize, + range_max: u32, + idx_limb_bits: Vec, + idx_decomp: usize, + ) -> Self { + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, 1 << idx_decomp)); + Self { + input_chip: PageIndexScanInputChip::new( + bus_index, + idx_len, + data_len, + range_max, + idx_limb_bits.clone(), + idx_decomp, + range_checker.clone(), + ), + output_chip: PageIndexScanOutputChip::new( + bus_index, + idx_len, + data_len, + range_max, + idx_limb_bits.clone(), + idx_decomp, + range_checker.clone(), + ), + input_chip_trace: None, + input_chip_aux_trace: None, + output_chip_trace: None, + output_chip_aux_trace: None, + input_commitment: None, + range_checker, + } + } + + pub fn input_chip_trace(&self) -> DenseMatrix> { + self.input_chip_trace.clone().unwrap() + } + + pub fn input_chip_aux_trace(&self) -> DenseMatrix> { + self.input_chip_aux_trace.clone().unwrap() + } + + pub fn output_chip_trace(&self) -> DenseMatrix> { + self.output_chip_trace.clone().unwrap() + } + + pub fn output_chip_aux_trace(&self) -> DenseMatrix> { + self.output_chip_aux_trace.clone().unwrap() + } + + pub fn range_checker_trace(&self) -> DenseMatrix> + where + Val: PrimeField, + { + self.range_checker.generate_trace() + } + + pub fn update_range_checker(&mut self, idx_decomp: usize) { + self.range_checker = Arc::new(RangeCheckerGateChip::new( + self.range_checker.bus_index(), + 1 << idx_decomp, + )); + } + + #[allow(clippy::too_many_arguments)] + pub fn load_page( + &mut self, + page: Vec>, + x: Vec, + idx_len: usize, + data_len: usize, + idx_limb_bits: Vec, + idx_decomp: usize, + trace_committer: &mut TraceCommitter, + ) -> (Vec>>, Vec>) + where + Val: PrimeField, + { + // idx_decomp can't change between different pages since range_checker depends on it + assert!(1 << idx_decomp == self.range_checker.range_max()); + + assert!(!page.is_empty()); + + let bus_index = self.input_chip.air.bus_index; + + self.input_chip = PageIndexScanInputChip::new( + bus_index, + idx_len, + data_len, + self.range_checker.range_max(), + idx_limb_bits.clone(), + idx_decomp, + self.range_checker.clone(), + ); + self.input_chip_trace = Some(self.input_chip.gen_page_trace::(page.clone())); + self.input_chip_aux_trace = + Some(self.input_chip.gen_aux_trace::(page.clone(), x.clone())); + + self.output_chip = PageIndexScanOutputChip::new( + bus_index, + idx_len, + data_len, + self.range_checker.range_max(), + idx_limb_bits.clone(), + idx_decomp, + self.range_checker.clone(), + ); + + let page_result = self.input_chip.gen_output(page.clone(), x.clone()); + + println!("page_result: {:?}", page_result); + + self.output_chip_trace = Some(self.output_chip.gen_page_trace::(page_result.clone())); + self.output_chip_aux_trace = + Some(self.output_chip.gen_aux_trace::(page_result.clone())); + + let prover_data = + vec![trace_committer.commit(vec![self.input_chip_trace.clone().unwrap()])]; + + self.input_commitment = Some(prover_data[0].commit.clone()); + + tracing::debug!( + "heights of all traces: {} {} {} {}", + self.input_chip_trace.as_ref().unwrap().height(), + self.input_chip_aux_trace.as_ref().unwrap().height(), + self.output_chip_trace.as_ref().unwrap().height(), + self.output_chip_aux_trace.as_ref().unwrap().height() + ); + + (vec![self.input_chip_trace.clone().unwrap()], prover_data) + } +} diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 822ed026d4..ed74d79124 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -1,6 +1,5 @@ -use std::borrow::Borrow; - -use p3_air::{Air, AirBuilder, BaseAir}; +use afs_stark_backend::air_builders::PartitionedAirBuilder; +use p3_air::{Air, BaseAir}; use p3_field::Field; use p3_matrix::Matrix; @@ -26,12 +25,22 @@ impl BaseAir for PageIndexScanInputAir { } } -impl Air for PageIndexScanInputAir { +impl Air for PageIndexScanInputAir +where + AB::M: Clone, +{ fn eval(&self, builder: &mut AB) { - let main = builder.main(); + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); - let local = main.row_slice(0); - let local: &[AB::Var] = (*local).borrow(); + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); let local_cols = PageIndexScanInputCols::::from_slice( local, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 431be45e23..e81d849f62 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -17,7 +17,7 @@ impl PageIndexScanInputCols { slc: &[T], idx_len: usize, data_len: usize, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, ) -> Self { Self { @@ -29,7 +29,7 @@ impl PageIndexScanInputCols { send_row: slc[2 * idx_len + data_len + 2].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 3..], - limb_bits, + idx_limb_bits, decomp, idx_len, ), @@ -39,7 +39,7 @@ impl PageIndexScanInputCols { pub fn get_width( idx_len: usize, data_len: usize, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, ) -> usize { 1 + idx_len @@ -47,6 +47,6 @@ impl PageIndexScanInputCols { + idx_len + 1 + 1 - + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, idx_len) + + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index b086e2cee7..525963f81b 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -2,7 +2,10 @@ use std::sync::Arc; use getset::Getters; -use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; +use crate::{ + is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, + range_gate::RangeCheckerGateChip, +}; pub mod air; pub mod bridge; @@ -40,7 +43,7 @@ impl PageIndexScanInputChip { idx_len: usize, data_len: usize, range_max: u32, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, range_checker: Arc, ) -> Self { @@ -52,11 +55,30 @@ impl PageIndexScanInputChip { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, range_max, - limb_bits.clone(), + idx_limb_bits.clone(), decomp, ), }, range_checker, } } + + pub fn page_width(&self) -> usize { + 1 + self.air.idx_len + self.air.data_len + } + + pub fn aux_width(&self) -> usize { + self.air.idx_len + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width( + self.air.is_less_than_tuple_air.limb_bits(), + self.air.is_less_than_tuple_air.decomp(), + self.air.idx_len, + ) + } + + pub fn air_width(&self) -> usize { + self.page_width() + self.aux_width() + } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 66626f1544..5ad88f8ee7 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -1,48 +1,59 @@ -use p3_field::PrimeField64; +use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::{StarkGenericConfig, Val}; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::PageIndexScanInputCols, PageIndexScanInputChip}; +use super::PageIndexScanInputChip; impl PageIndexScanInputChip { - pub fn generate_trace( + pub fn gen_page_trace( &self, page: Vec>, - x: Vec, - ) -> RowMajorMatrix { - let num_cols: usize = PageIndexScanInputCols::::get_width( - self.air.idx_len, - self.air.data_len, - self.air.is_less_than_tuple_air.limb_bits(), - self.air.is_less_than_tuple_air.decomp(), - ); + ) -> RowMajorMatrix> + where + Val: AbstractField, + { + RowMajorMatrix::new( + page.into_iter() + .flat_map(|row| { + row.into_iter() + .map(Val::::from_wrapped_u32) + .collect::>>() + }) + .collect(), + self.page_width(), + ) + } - let mut rows: Vec = vec![]; + pub fn gen_aux_trace( + &self, + page: Vec>, + x: Vec, + ) -> RowMajorMatrix> + where + Val: AbstractField + PrimeField64, + { + let mut rows: Vec> = vec![]; for page_row in &page { - let mut row: Vec = vec![]; - - let is_alloc = F::from_canonical_u32(page_row[0]); - row.push(is_alloc); + let mut row: Vec> = vec![]; + let is_alloc = Val::::from_canonical_u32(page_row[0]); let idx = page_row[1..1 + self.air.idx_len].to_vec(); - let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(idx_trace); - - let data = - page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); - let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(data_trace); - let x_trace: Vec = x.iter().map(|x| F::from_canonical_u32(*x)).collect(); + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); row.extend(x_trace); - let is_less_than_tuple_trace: Vec = LocalTraceInstructions::generate_trace_row( - &self.air.is_less_than_tuple_air, - (idx.clone(), x.clone(), self.range_checker.clone()), - ) - .flatten(); + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + &self.air.is_less_than_tuple_air, + (idx.clone(), x.clone(), self.range_checker.clone()), + ) + .flatten(); row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; @@ -53,6 +64,47 @@ impl PageIndexScanInputChip { rows.extend_from_slice(&row); } - RowMajorMatrix::new(rows, num_cols) + RowMajorMatrix::new(rows, self.aux_width()) + } + + pub fn gen_output(&self, page: Vec>, x: Vec) -> Vec> { + let mut output: Vec> = vec![]; + + for page_row in &page { + let is_alloc = page_row[0]; + let idx = page_row[1..1 + self.air.idx_len].to_vec(); + let data = page_row[1 + self.air.idx_len..].to_vec(); + + let mut less_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Less => { + less_than = true; + break; + } + Ordering::Greater => { + break; + } + Ordering::Equal => {} + } + } + + if less_than { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } + + let num_remaining = page.len() - output.len(); + + output.extend((0..num_remaining).map(|_| vec![0; self.page_width()])); + + output } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/air.rs b/chips/src/single_page_index_scan/page_index_scan_output/air.rs index 47acea6b82..543401914d 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/air.rs @@ -1,5 +1,6 @@ use std::borrow::Borrow; +use afs_stark_backend::air_builders::PartitionedAirBuilder; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; @@ -26,14 +27,35 @@ impl BaseAir for PageIndexScanOutputAir { } } -impl Air for PageIndexScanOutputAir { +impl Air for PageIndexScanOutputAir +where + AB::M: Clone, +{ fn eval(&self, builder: &mut AB) { - let main = builder.main(); + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); // get the current row and the next row - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local: &[AB::Var] = (*local).borrow(); - let next: &[AB::Var] = (*next).borrow(); + let (local_page, next_page) = (page_main.row_slice(0), page_main.row_slice(1)); + let local_page: &[AB::Var] = (*local_page).borrow(); + let next_page: &[AB::Var] = (*next_page).borrow(); + + let (local_aux, next_aux) = (aux_main.row_slice(0), aux_main.row_slice(1)); + let local_aux: &[AB::Var] = (*local_aux).borrow(); + let next_aux: &[AB::Var] = (*next_aux).borrow(); + + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + let next_vec = next_page + .iter() + .chain(next_aux.iter()) + .cloned() + .collect::>(); + let next = next_vec.as_slice(); let local_cols = PageIndexScanOutputCols::::from_slice( local, diff --git a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs index 4742764b0a..317f5d41b0 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs @@ -14,7 +14,7 @@ impl PageIndexScanOutputCols { slc: &[T], idx_len: usize, data_len: usize, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, ) -> Self { Self { @@ -24,7 +24,7 @@ impl PageIndexScanOutputCols { less_than_next_idx: slc[idx_len + data_len + 1].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[idx_len + data_len + 2..], - limb_bits, + idx_limb_bits, decomp, idx_len, ), @@ -34,12 +34,12 @@ impl PageIndexScanOutputCols { pub fn get_width( idx_len: usize, data_len: usize, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, ) -> usize { 1 + idx_len + data_len + 1 - + IsLessThanTupleAuxCols::::get_width(limb_bits, decomp, idx_len) + + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs index 7a38288e98..59215b7ca6 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -2,7 +2,10 @@ use std::sync::Arc; use getset::Getters; -use crate::{is_less_than_tuple::IsLessThanTupleAir, range_gate::RangeCheckerGateChip}; +use crate::{ + is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, + range_gate::RangeCheckerGateChip, +}; pub mod air; pub mod bridge; @@ -40,7 +43,7 @@ impl PageIndexScanOutputChip { idx_len: usize, data_len: usize, range_max: u32, - limb_bits: Vec, + idx_limb_bits: Vec, decomp: usize, range_checker: Arc, ) -> Self { @@ -50,10 +53,29 @@ impl PageIndexScanOutputChip { idx_len, data_len, is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, range_max, limb_bits, decomp, + bus_index, + range_max, + idx_limb_bits, + decomp, ), }, range_checker, } } + + pub fn page_width(&self) -> usize { + 1 + self.air.idx_len + self.air.data_len + } + + pub fn aux_width(&self) -> usize { + 1 + IsLessThanTupleAuxCols::::get_width( + self.air.is_less_than_tuple_air().limb_bits(), + self.air.is_less_than_tuple_air().decomp(), + self.air.idx_len, + ) + } + + pub fn air_width(&self) -> usize { + self.page_width() + self.aux_width() + } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs index 5c41364d1b..7457ae204a 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs @@ -1,20 +1,39 @@ -use p3_field::PrimeField64; +use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::{StarkGenericConfig, Val}; use crate::sub_chip::LocalTraceInstructions; -use super::{columns::PageIndexScanOutputCols, PageIndexScanOutputChip}; +use super::PageIndexScanOutputChip; impl PageIndexScanOutputChip { - pub fn generate_trace(&self, page: Vec>) -> RowMajorMatrix { - let num_cols: usize = PageIndexScanOutputCols::::get_width( - self.air.idx_len, - self.air.data_len, - self.air.is_less_than_tuple_air().limb_bits().clone(), - self.air.is_less_than_tuple_air().decomp(), - ); + pub fn gen_page_trace( + &self, + page: Vec>, + ) -> RowMajorMatrix> + where + Val: AbstractField, + { + RowMajorMatrix::new( + page.into_iter() + .flat_map(|row| { + row.into_iter() + .map(Val::::from_wrapped_u32) + .collect::>>() + }) + .collect(), + self.page_width(), + ) + } - let mut rows: Vec = vec![]; + pub fn gen_aux_trace( + &self, + page: Vec>, + ) -> RowMajorMatrix> + where + Val: AbstractField + PrimeField64, + { + let mut rows: Vec> = vec![]; for i in 0..page.len() { let page_row = page[i].clone(); @@ -24,19 +43,7 @@ impl PageIndexScanOutputChip { page[i + 1].clone() }; - let mut row: Vec = vec![]; - - let is_alloc = F::from_canonical_u32(page_row[0]); - row.push(is_alloc); - - let idx = page_row[1..1 + self.air.idx_len].to_vec(); - let idx_trace: Vec = idx.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(idx_trace); - - let data = - page_row[1 + self.air.idx_len..1 + self.air.idx_len + self.air.data_len].to_vec(); - let data_trace: Vec = data.iter().map(|x| F::from_canonical_u32(*x)).collect(); - row.extend(data_trace); + let mut row: Vec> = vec![]; let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( self.air.is_less_than_tuple_air(), @@ -53,6 +60,6 @@ impl PageIndexScanOutputChip { rows.extend_from_slice(&row); } - RowMajorMatrix::new(rows, num_cols) + RowMajorMatrix::new(rows, self.aux_width()) } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index a6b7f358ca..8052944b2a 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -1,78 +1,98 @@ -use std::sync::Arc; - -use afs_stark_backend::{prover::USE_DEBUG_BUILDER, verifier::VerificationError}; -use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; - -use super::{ - page_index_scan_input::PageIndexScanInputChip, page_index_scan_output::PageIndexScanOutputChip, +use afs_stark_backend::{ + keygen::{types::MultiStarkPartialProvingKey, MultiStarkKeygenBuilder}, + prover::{trace::TraceCommitmentBuilder, MultiTraceStarkProver}, + verifier::VerificationError, +}; +use afs_test_utils::{ + config::{ + self, + baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + }, + engine::StarkEngine, }; -use crate::range_gate::RangeCheckerGateChip; -#[test] -fn test_single_page_index_scan() { - let bus_index: usize = 0; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; - let range_max: u32 = 1 << decomp; +use super::page_controller::PageController; - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); +#[allow(clippy::too_many_arguments)] +fn index_scan_test( + engine: &BabyBearPoseidon2Engine, + page: Vec>, + x: Vec, + idx_len: usize, + data_len: usize, + idx_limb_bits: Vec, + idx_decomp: usize, + page_controller: &mut PageController, + trace_builder: &mut TraceCommitmentBuilder, + partial_pk: &MultiStarkPartialProvingKey, +) -> Result<(), VerificationError> { + let page_height = page.len(); + assert!(page_height > 0); - let page_index_scan_input_chip = PageIndexScanInputChip::new( - bus_index, - idx_len, - data_len, - range_max, - limb_bits.clone(), - decomp, - range_checker.clone(), - ); - let page_index_scan_output_chip = PageIndexScanOutputChip::new( - bus_index, + let (page_traces, mut prover_data) = page_controller.load_page( + page.clone(), + x, idx_len, data_len, - range_max, - limb_bits.clone(), - decomp, - range_checker.clone(), + idx_limb_bits, + idx_decomp, + &mut trace_builder.committer, ); - let range_checker_chip = range_checker.as_ref(); - let page: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], - ]; + let input_chip_aux_trace = page_controller.input_chip_aux_trace(); + let output_chip_trace = page_controller.output_chip_trace(); + let output_chip_aux_trace = page_controller.output_chip_aux_trace(); + let range_checker_trace = page_controller.range_checker_trace(); - let page_indexed: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], - vec![0, 0, 0, 0, 0, 0], - ]; + // Clearing the range_checker counts + page_controller.update_range_checker(idx_decomp); - let x: Vec = vec![2177, 5880]; + trace_builder.clear(); + + trace_builder.load_cached_trace(page_traces[0].clone(), prover_data.remove(0)); + trace_builder.load_trace(input_chip_aux_trace); + trace_builder.load_trace(output_chip_trace); + trace_builder.load_trace(output_chip_aux_trace); + trace_builder.load_trace(range_checker_trace); - let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); - let page_index_scan_verify_chip_trace = - page_index_scan_output_chip.generate_trace(page_indexed.clone()); - let range_checker_trace = range_checker_chip.generate_trace(); + trace_builder.commit_current(); - run_simple_test_no_pis( + let partial_vk = partial_pk.partial_vk(); + + let main_trace_data = trace_builder.view( + &partial_vk, vec![ - &page_index_scan_input_chip.air, - &page_index_scan_output_chip.air, - &range_checker_chip.air, + &page_controller.input_chip.air, + &page_controller.output_chip.air, + &page_controller.range_checker.air, ], + ); + + let pis = vec![vec![]; partial_vk.per_air.len()]; + + let prover = engine.prover(); + let verifier = engine.verifier(); + + let mut challenger = engine.new_challenger(); + let proof = prover.prove(&mut challenger, partial_pk, main_trace_data, &pis); + + let mut challenger = engine.new_challenger(); + + verifier.verify( + &mut challenger, + partial_vk, vec![ - page_index_scan_chip_trace, - page_index_scan_verify_chip_trace, - range_checker_trace, + &page_controller.input_chip.air, + &page_controller.output_chip.air, + &page_controller.range_checker.air, ], + proof, + &pis, ) - .expect("Verification failed"); } #[test] -fn test_single_page_index_scan_wrong_order() { +fn test_single_page_index_scan() { let bus_index: usize = 0; let idx_len: usize = 2; let data_len: usize = 3; @@ -80,131 +100,277 @@ fn test_single_page_index_scan_wrong_order() { let limb_bits: Vec = vec![16, 16]; let range_max: u32 = 1 << decomp; - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; - let page_index_scan_input_chip = PageIndexScanInputChip::new( - bus_index, - idx_len, - data_len, - range_max, - limb_bits.clone(), - decomp, - range_checker.clone(), - ); - let page_index_scan_output_chip = PageIndexScanOutputChip::new( + let mut page_controller: PageController = PageController::new( bus_index, idx_len, data_len, range_max, limb_bits.clone(), decomp, - range_checker.clone(), ); - let range_checker_chip = range_checker.as_ref(); - let page: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], - ]; + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); - let page_indexed: Vec> = vec![ - vec![0, 0, 0, 0, 0, 0], - vec![1, 443, 376, 22278, 13998, 58327], - ]; + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); - let x: Vec = vec![2177, 5880]; + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_ptr = keygen_builder.add_main_matrix(page_width); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); - let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); - let page_index_scan_verify_chip_trace = - page_index_scan_output_chip.generate_trace(page_indexed.clone()); - let range_checker_trace = range_checker_chip.generate_trace(); - - USE_DEBUG_BUILDER.with(|debug| { - *debug.lock().unwrap() = false; - }); - assert_eq!( - run_simple_test_no_pis( - vec![ - &page_index_scan_input_chip.air, - &page_index_scan_output_chip.air, - &range_checker_chip.air, - ], - vec![ - page_index_scan_chip_trace, - page_index_scan_verify_chip_trace, - range_checker_trace, - ], - ), - Err(VerificationError::OodEvaluationMismatch), - "Expected verification to fail, but it passed" + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + 0, + vec![input_page_ptr, input_page_aux_ptr], ); -} - -#[test] -fn test_single_page_index_scan_unsorted() { - let bus_index: usize = 0; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; - let range_max: u32 = 1 << decomp; - - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + dbg!(output_page_aux_ptr); - let page_index_scan_input_chip = PageIndexScanInputChip::new( - bus_index, - idx_len, - data_len, - range_max, - limb_bits.clone(), - decomp, - range_checker.clone(), + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], ); - let page_index_scan_output_chip = PageIndexScanOutputChip::new( - bus_index, - idx_len, - data_len, - range_max, - limb_bits.clone(), - decomp, - range_checker.clone(), + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], ); - let range_checker_chip = range_checker.as_ref(); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page_indexed: Vec> = vec![ - vec![1, 2883, 7769, 51171, 3989, 12770], - vec![1, 443, 376, 22278, 13998, 58327], - ]; + let x: Vec = vec![2177, 5880]; - let x: Vec = vec![2883, 7770]; - - let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); - let page_index_scan_verify_chip_trace = - page_index_scan_output_chip.generate_trace(page_indexed.clone()); - let range_checker_trace = range_checker_chip.generate_trace(); - - USE_DEBUG_BUILDER.with(|debug| { - *debug.lock().unwrap() = false; - }); - assert_eq!( - run_simple_test_no_pis( - vec![ - &page_index_scan_input_chip.air, - &page_index_scan_output_chip.air, - &range_checker_chip.air, - ], - vec![ - page_index_scan_chip_trace, - page_index_scan_verify_chip_trace, - range_checker_trace, - ], - ), - Err(VerificationError::OodEvaluationMismatch), - "Expected verification to fail, but it passed" - ); + index_scan_test( + &engine, + page, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ) + .expect("Verification failed"); } + +// #[test] +// fn test_single_page_index_scan() { +// let bus_index: usize = 0; +// let idx_len: usize = 2; +// let data_len: usize = 3; +// let decomp: usize = 8; +// let limb_bits: Vec = vec![16, 16]; +// let range_max: u32 = 1 << decomp; + +// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + +// let page_index_scan_input_chip = PageIndexScanInputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let page_index_scan_output_chip = PageIndexScanOutputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let range_checker_chip = range_checker.as_ref(); + +// let page: Vec> = vec![ +// vec![1, 443, 376, 22278, 13998, 58327], +// vec![1, 2883, 7769, 51171, 3989, 12770], +// ]; + +// let page_indexed: Vec> = vec![ +// vec![1, 443, 376, 22278, 13998, 58327], +// vec![0, 0, 0, 0, 0, 0], +// ]; + +// let x: Vec = vec![2177, 5880]; + +// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); +// let page_index_scan_verify_chip_trace = +// page_index_scan_output_chip.generate_trace(page_indexed.clone()); +// let range_checker_trace = range_checker_chip.generate_trace(); + +// run_simple_test_no_pis( +// vec![ +// &page_index_scan_input_chip.air, +// &page_index_scan_output_chip.air, +// &range_checker_chip.air, +// ], +// vec![ +// page_index_scan_chip_trace, +// page_index_scan_verify_chip_trace, +// range_checker_trace, +// ], +// ) +// .expect("Verification failed"); +// } + +// #[test] +// fn test_single_page_index_scan_wrong_order() { +// let bus_index: usize = 0; +// let idx_len: usize = 2; +// let data_len: usize = 3; +// let decomp: usize = 8; +// let limb_bits: Vec = vec![16, 16]; +// let range_max: u32 = 1 << decomp; + +// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + +// let page_index_scan_input_chip = PageIndexScanInputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let page_index_scan_output_chip = PageIndexScanOutputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let range_checker_chip = range_checker.as_ref(); + +// let page: Vec> = vec![ +// vec![1, 443, 376, 22278, 13998, 58327], +// vec![1, 2883, 7769, 51171, 3989, 12770], +// ]; + +// let page_indexed: Vec> = vec![ +// vec![0, 0, 0, 0, 0, 0], +// vec![1, 443, 376, 22278, 13998, 58327], +// ]; + +// let x: Vec = vec![2177, 5880]; + +// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); +// let page_index_scan_verify_chip_trace = +// page_index_scan_output_chip.generate_trace(page_indexed.clone()); +// let range_checker_trace = range_checker_chip.generate_trace(); + +// USE_DEBUG_BUILDER.with(|debug| { +// *debug.lock().unwrap() = false; +// }); +// assert_eq!( +// run_simple_test_no_pis( +// vec![ +// &page_index_scan_input_chip.air, +// &page_index_scan_output_chip.air, +// &range_checker_chip.air, +// ], +// vec![ +// page_index_scan_chip_trace, +// page_index_scan_verify_chip_trace, +// range_checker_trace, +// ], +// ), +// Err(VerificationError::OodEvaluationMismatch), +// "Expected verification to fail, but it passed" +// ); +// } + +// #[test] +// fn test_single_page_index_scan_unsorted() { +// let bus_index: usize = 0; +// let idx_len: usize = 2; +// let data_len: usize = 3; +// let decomp: usize = 8; +// let limb_bits: Vec = vec![16, 16]; +// let range_max: u32 = 1 << decomp; + +// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + +// let page_index_scan_input_chip = PageIndexScanInputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let page_index_scan_output_chip = PageIndexScanOutputChip::new( +// bus_index, +// idx_len, +// data_len, +// range_max, +// limb_bits.clone(), +// decomp, +// range_checker.clone(), +// ); +// let range_checker_chip = range_checker.as_ref(); + +// let page: Vec> = vec![ +// vec![1, 443, 376, 22278, 13998, 58327], +// vec![1, 2883, 7769, 51171, 3989, 12770], +// ]; + +// let page_indexed: Vec> = vec![ +// vec![1, 2883, 7769, 51171, 3989, 12770], +// vec![1, 443, 376, 22278, 13998, 58327], +// ]; + +// let x: Vec = vec![2883, 7770]; + +// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); +// let page_index_scan_verify_chip_trace = +// page_index_scan_output_chip.generate_trace(page_indexed.clone()); +// let range_checker_trace = range_checker_chip.generate_trace(); + +// USE_DEBUG_BUILDER.with(|debug| { +// *debug.lock().unwrap() = false; +// }); +// assert_eq!( +// run_simple_test_no_pis( +// vec![ +// &page_index_scan_input_chip.air, +// &page_index_scan_output_chip.air, +// &range_checker_chip.air, +// ], +// vec![ +// page_index_scan_chip_trace, +// page_index_scan_verify_chip_trace, +// range_checker_trace, +// ], +// ), +// Err(VerificationError::OodEvaluationMismatch), +// "Expected verification to fail, but it passed" +// ); +// } From b7c479ed87c376d49815e64ac012e477c9f1513d Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 10 Jun 2024 19:19:13 -0400 Subject: [PATCH 30/46] chore: cleanup index scan for less than predicate --- .../page_controller/mod.rs | 96 +++-- chips/src/single_page_index_scan/tests.rs | 400 +++++++++--------- 2 files changed, 267 insertions(+), 229 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 024df7c6bd..33b5af095c 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -6,7 +6,6 @@ use afs_stark_backend::{ }; use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_matrix::dense::DenseMatrix; -use p3_matrix::Matrix; use p3_uni_stark::{StarkGenericConfig, Val}; use crate::range_gate::RangeCheckerGateChip; @@ -28,6 +27,7 @@ where output_chip_aux_trace: Option>>, input_commitment: Option>, + output_commitment: Option>, pub range_checker: Arc, } @@ -69,6 +69,7 @@ where output_chip_trace: None, output_chip_aux_trace: None, input_commitment: None, + output_commitment: None, range_checker, } } @@ -103,10 +104,58 @@ where )); } + pub fn gen_output( + &self, + page: Vec>, + x: Vec, + idx_len: usize, + page_width: usize, + ) -> Vec> { + let mut output: Vec> = vec![]; + + for page_row in &page { + let is_alloc = page_row[0]; + let idx = page_row[1..1 + idx_len].to_vec(); + let data = page_row[1 + idx_len..].to_vec(); + + let mut less_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Less => { + less_than = true; + break; + } + Ordering::Greater => { + break; + } + Ordering::Equal => {} + } + } + + if less_than { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } + + let num_remaining = page.len() - output.len(); + + output.extend((0..num_remaining).map(|_| vec![0; page_width])); + + output + } + #[allow(clippy::too_many_arguments)] pub fn load_page( &mut self, - page: Vec>, + page_input: Vec>, + page_output: Vec>, x: Vec, idx_len: usize, data_len: usize, @@ -120,7 +169,7 @@ where // idx_decomp can't change between different pages since range_checker depends on it assert!(1 << idx_decomp == self.range_checker.range_max()); - assert!(!page.is_empty()); + assert!(!page_input.is_empty()); let bus_index = self.input_chip.air.bus_index; @@ -133,9 +182,11 @@ where idx_decomp, self.range_checker.clone(), ); - self.input_chip_trace = Some(self.input_chip.gen_page_trace::(page.clone())); - self.input_chip_aux_trace = - Some(self.input_chip.gen_aux_trace::(page.clone(), x.clone())); + self.input_chip_trace = Some(self.input_chip.gen_page_trace::(page_input.clone())); + self.input_chip_aux_trace = Some( + self.input_chip + .gen_aux_trace::(page_input.clone(), x.clone()), + ); self.output_chip = PageIndexScanOutputChip::new( bus_index, @@ -147,27 +198,24 @@ where self.range_checker.clone(), ); - let page_result = self.input_chip.gen_output(page.clone(), x.clone()); - - println!("page_result: {:?}", page_result); - - self.output_chip_trace = Some(self.output_chip.gen_page_trace::(page_result.clone())); + self.output_chip_trace = Some(self.output_chip.gen_page_trace::(page_output.clone())); self.output_chip_aux_trace = - Some(self.output_chip.gen_aux_trace::(page_result.clone())); + Some(self.output_chip.gen_aux_trace::(page_output.clone())); - let prover_data = - vec![trace_committer.commit(vec![self.input_chip_trace.clone().unwrap()])]; + let prover_data = vec![ + trace_committer.commit(vec![self.input_chip_trace.clone().unwrap()]), + trace_committer.commit(vec![self.output_chip_trace.clone().unwrap()]), + ]; self.input_commitment = Some(prover_data[0].commit.clone()); - - tracing::debug!( - "heights of all traces: {} {} {} {}", - self.input_chip_trace.as_ref().unwrap().height(), - self.input_chip_aux_trace.as_ref().unwrap().height(), - self.output_chip_trace.as_ref().unwrap().height(), - self.output_chip_aux_trace.as_ref().unwrap().height() - ); - - (vec![self.input_chip_trace.clone().unwrap()], prover_data) + self.output_commitment = Some(prover_data[1].commit.clone()); + + ( + vec![ + self.input_chip_trace.clone().unwrap(), + self.output_chip_trace.clone().unwrap(), + ], + prover_data, + ) } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 8052944b2a..88ce17dad7 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -1,6 +1,6 @@ use afs_stark_backend::{ keygen::{types::MultiStarkPartialProvingKey, MultiStarkKeygenBuilder}, - prover::{trace::TraceCommitmentBuilder, MultiTraceStarkProver}, + prover::{trace::TraceCommitmentBuilder, MultiTraceStarkProver, USE_DEBUG_BUILDER}, verifier::VerificationError, }; use afs_test_utils::{ @@ -17,6 +17,7 @@ use super::page_controller::PageController; fn index_scan_test( engine: &BabyBearPoseidon2Engine, page: Vec>, + page_output: Vec>, x: Vec, idx_len: usize, data_len: usize, @@ -31,6 +32,7 @@ fn index_scan_test( let (page_traces, mut prover_data) = page_controller.load_page( page.clone(), + page_output.clone(), x, idx_len, data_len, @@ -40,7 +42,6 @@ fn index_scan_test( ); let input_chip_aux_trace = page_controller.input_chip_aux_trace(); - let output_chip_trace = page_controller.output_chip_trace(); let output_chip_aux_trace = page_controller.output_chip_aux_trace(); let range_checker_trace = page_controller.range_checker_trace(); @@ -50,8 +51,8 @@ fn index_scan_test( trace_builder.clear(); trace_builder.load_cached_trace(page_traces[0].clone(), prover_data.remove(0)); + trace_builder.load_cached_trace(page_traces[1].clone(), prover_data.remove(0)); trace_builder.load_trace(input_chip_aux_trace); - trace_builder.load_trace(output_chip_trace); trace_builder.load_trace(output_chip_aux_trace); trace_builder.load_trace(range_checker_trace); @@ -118,8 +119,8 @@ fn test_single_page_index_scan() { let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_ptr = keygen_builder.add_main_matrix(page_width); let output_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); let range_checker_ptr = @@ -131,7 +132,6 @@ fn test_single_page_index_scan() { 0, vec![input_page_ptr, input_page_aux_ptr], ); - dbg!(output_page_aux_ptr); keygen_builder.add_partitioned_air( &page_controller.output_chip.air, @@ -159,9 +159,12 @@ fn test_single_page_index_scan() { let x: Vec = vec![2177, 5880]; + let page_output = page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width); + index_scan_test( &engine, page, + page_output, x, idx_len, data_len, @@ -174,203 +177,190 @@ fn test_single_page_index_scan() { .expect("Verification failed"); } -// #[test] -// fn test_single_page_index_scan() { -// let bus_index: usize = 0; -// let idx_len: usize = 2; -// let data_len: usize = 3; -// let decomp: usize = 8; -// let limb_bits: Vec = vec![16, 16]; -// let range_max: u32 = 1 << decomp; - -// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - -// let page_index_scan_input_chip = PageIndexScanInputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let page_index_scan_output_chip = PageIndexScanOutputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let range_checker_chip = range_checker.as_ref(); - -// let page: Vec> = vec![ -// vec![1, 443, 376, 22278, 13998, 58327], -// vec![1, 2883, 7769, 51171, 3989, 12770], -// ]; - -// let page_indexed: Vec> = vec![ -// vec![1, 443, 376, 22278, 13998, 58327], -// vec![0, 0, 0, 0, 0, 0], -// ]; - -// let x: Vec = vec![2177, 5880]; - -// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); -// let page_index_scan_verify_chip_trace = -// page_index_scan_output_chip.generate_trace(page_indexed.clone()); -// let range_checker_trace = range_checker_chip.generate_trace(); - -// run_simple_test_no_pis( -// vec![ -// &page_index_scan_input_chip.air, -// &page_index_scan_output_chip.air, -// &range_checker_chip.air, -// ], -// vec![ -// page_index_scan_chip_trace, -// page_index_scan_verify_chip_trace, -// range_checker_trace, -// ], -// ) -// .expect("Verification failed"); -// } - -// #[test] -// fn test_single_page_index_scan_wrong_order() { -// let bus_index: usize = 0; -// let idx_len: usize = 2; -// let data_len: usize = 3; -// let decomp: usize = 8; -// let limb_bits: Vec = vec![16, 16]; -// let range_max: u32 = 1 << decomp; - -// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - -// let page_index_scan_input_chip = PageIndexScanInputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let page_index_scan_output_chip = PageIndexScanOutputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let range_checker_chip = range_checker.as_ref(); - -// let page: Vec> = vec![ -// vec![1, 443, 376, 22278, 13998, 58327], -// vec![1, 2883, 7769, 51171, 3989, 12770], -// ]; - -// let page_indexed: Vec> = vec![ -// vec![0, 0, 0, 0, 0, 0], -// vec![1, 443, 376, 22278, 13998, 58327], -// ]; - -// let x: Vec = vec![2177, 5880]; - -// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); -// let page_index_scan_verify_chip_trace = -// page_index_scan_output_chip.generate_trace(page_indexed.clone()); -// let range_checker_trace = range_checker_chip.generate_trace(); - -// USE_DEBUG_BUILDER.with(|debug| { -// *debug.lock().unwrap() = false; -// }); -// assert_eq!( -// run_simple_test_no_pis( -// vec![ -// &page_index_scan_input_chip.air, -// &page_index_scan_output_chip.air, -// &range_checker_chip.air, -// ], -// vec![ -// page_index_scan_chip_trace, -// page_index_scan_verify_chip_trace, -// range_checker_trace, -// ], -// ), -// Err(VerificationError::OodEvaluationMismatch), -// "Expected verification to fail, but it passed" -// ); -// } - -// #[test] -// fn test_single_page_index_scan_unsorted() { -// let bus_index: usize = 0; -// let idx_len: usize = 2; -// let data_len: usize = 3; -// let decomp: usize = 8; -// let limb_bits: Vec = vec![16, 16]; -// let range_max: u32 = 1 << decomp; - -// let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); - -// let page_index_scan_input_chip = PageIndexScanInputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let page_index_scan_output_chip = PageIndexScanOutputChip::new( -// bus_index, -// idx_len, -// data_len, -// range_max, -// limb_bits.clone(), -// decomp, -// range_checker.clone(), -// ); -// let range_checker_chip = range_checker.as_ref(); - -// let page: Vec> = vec![ -// vec![1, 443, 376, 22278, 13998, 58327], -// vec![1, 2883, 7769, 51171, 3989, 12770], -// ]; - -// let page_indexed: Vec> = vec![ -// vec![1, 2883, 7769, 51171, 3989, 12770], -// vec![1, 443, 376, 22278, 13998, 58327], -// ]; - -// let x: Vec = vec![2883, 7770]; - -// let page_index_scan_chip_trace = page_index_scan_input_chip.generate_trace(page.clone(), x); -// let page_index_scan_verify_chip_trace = -// page_index_scan_output_chip.generate_trace(page_indexed.clone()); -// let range_checker_trace = range_checker_chip.generate_trace(); - -// USE_DEBUG_BUILDER.with(|debug| { -// *debug.lock().unwrap() = false; -// }); -// assert_eq!( -// run_simple_test_no_pis( -// vec![ -// &page_index_scan_input_chip.air, -// &page_index_scan_output_chip.air, -// &range_checker_chip.air, -// ], -// vec![ -// page_index_scan_chip_trace, -// page_index_scan_verify_chip_trace, -// range_checker_trace, -// ], -// ), -// Err(VerificationError::OodEvaluationMismatch), -// "Expected verification to fail, but it passed" -// ); -// } +#[test] +fn test_single_page_index_scan_wrong_order() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + 0, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_output = vec![ + vec![0, 0, 0, 0, 0, 0], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} + +#[test] +fn test_single_page_index_scan_unsorted() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + 0, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 2883, 7769, 51171, 3989, 12770], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_output = vec![ + vec![0, 0, 0, 0, 0, 0], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ), + Err(VerificationError::OodEvaluationMismatch), + "Expected verification to fail, but it passed" + ); +} From 04a31dd3974c17e9c7c20b65d107afff79adea92 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 11:30:02 -0400 Subject: [PATCH 31/46] feat: x as public value --- .../page_index_scan_input/air.rs | 16 +++++-- .../page_index_scan_input/bridge.rs | 2 + .../page_index_scan_input/mod.rs | 8 ++++ .../page_index_scan_input/trace.rs | 43 +------------------ .../page_index_scan_output/air.rs | 2 +- .../page_index_scan_output/bridge.rs | 2 + .../page_index_scan_output/mod.rs | 8 ++++ .../page_index_scan_output/trace.rs | 2 + chips/src/single_page_index_scan/tests.rs | 16 ++++--- 9 files changed, 49 insertions(+), 50 deletions(-) diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index ed74d79124..dd47124cb6 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -1,5 +1,5 @@ use afs_stark_backend::air_builders::PartitionedAirBuilder; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, AirBuilderWithPublicValues, BaseAir}; use p3_field::Field; use p3_matrix::Matrix; @@ -25,7 +25,7 @@ impl BaseAir for PageIndexScanInputAir { } } -impl Air for PageIndexScanInputAir +impl Air for PageIndexScanInputAir where AB::M: Clone, { @@ -33,6 +33,10 @@ where let page_main = &builder.partitioned_main()[0].clone(); let aux_main = &builder.partitioned_main()[1].clone(); + // get the public value x + let pis = builder.public_values(); + let x = pis[..self.idx_len].to_vec(); + let local_page = page_main.row_slice(0); let local_aux = aux_main.row_slice(0); let local_vec = local_page @@ -53,12 +57,18 @@ where let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { x: local_cols.idx, - y: local_cols.x, + y: local_cols.x.clone(), tuple_less_than: local_cols.satisfies_pred, }, aux: local_cols.is_less_than_tuple_aux, }; + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in local_cols.x.iter().zip(x.iter()) { + builder.assert_eq(local_x, pub_x); + } + + // constrain that we send the row iff the row is allocated and satisfies the predicate builder.assert_eq( local_cols.is_alloc * local_cols.satisfies_pred, local_cols.send_row, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 1f9f0ceab0..2e3374311b 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -37,6 +37,7 @@ impl AirBridge for PageIndexScanInputAir { aux: cols_numbered.is_less_than_tuple_aux, }; + // construct the row to send let mut cols = vec![]; cols.push(cols_numbered.is_alloc); cols.extend(cols_numbered.idx); @@ -47,6 +48,7 @@ impl AirBridge for PageIndexScanInputAir { .map(|col| VirtualPairCol::single_main(*col)) .collect::>(); + // sends with count given by send_row indicator let mut interactions = vec![Interaction { fields: virtual_cols, count: VirtualPairCol::single_main(cols_numbered.send_row), diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 525963f81b..1108273936 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -24,14 +24,22 @@ pub enum Comp { #[derive(Default, Getters)] pub struct PageIndexScanInputAir { + /// The bus index pub bus_index: usize, + /// The length of each index in the page table pub idx_len: usize, + /// The length of each data entry in the page table pub data_len: usize, #[getset(skip)] is_less_than_tuple_air: IsLessThanTupleAir, } +/// Given a fixed predicate of the form index OP x, where OP is one of {<, <=, =, >=, >} +/// and x is a private input, the PageIndexScanInputChip implements a chip such that the chip: +/// +/// 1. Has public value x +/// 2. Sends all rows of the page that match the predicate index OP x where x is the public value pub struct PageIndexScanInputChip { pub air: PageIndexScanInputAir, pub range_checker: Arc, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 5ad88f8ee7..e5d8f2ac5a 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -7,6 +7,7 @@ use crate::sub_chip::LocalTraceInstructions; use super::PageIndexScanInputChip; impl PageIndexScanInputChip { + /// Generate the trace for the page table pub fn gen_page_trace( &self, page: Vec>, @@ -26,6 +27,7 @@ impl PageIndexScanInputChip { ) } + /// Generate the trace for the auxiliary columns pub fn gen_aux_trace( &self, page: Vec>, @@ -66,45 +68,4 @@ impl PageIndexScanInputChip { RowMajorMatrix::new(rows, self.aux_width()) } - - pub fn gen_output(&self, page: Vec>, x: Vec) -> Vec> { - let mut output: Vec> = vec![]; - - for page_row in &page { - let is_alloc = page_row[0]; - let idx = page_row[1..1 + self.air.idx_len].to_vec(); - let data = page_row[1 + self.air.idx_len..].to_vec(); - - let mut less_than = false; - for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { - use std::cmp::Ordering; - match idx_val.cmp(&x_val) { - Ordering::Less => { - less_than = true; - break; - } - Ordering::Greater => { - break; - } - Ordering::Equal => {} - } - } - - if less_than { - output.push( - vec![is_alloc] - .into_iter() - .chain(idx.iter().cloned()) - .chain(data.iter().cloned()) - .collect(), - ); - } - } - - let num_remaining = page.len() - output.len(); - - output.extend((0..num_remaining).map(|_| vec![0; self.page_width()])); - - output - } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/air.rs b/chips/src/single_page_index_scan/page_index_scan_output/air.rs index 543401914d..ab0390122e 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/air.rs @@ -100,7 +100,7 @@ where is_less_than_tuple_cols.aux, ); - // if the next row exists, then the current row index must be less than the next row index + // if the next row is allocated, then the current row index must be less than the next row index builder .when_transition() .assert_zero(next_cols.is_alloc * (AB::Expr::one() - local_cols.less_than_next_idx)); diff --git a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs index 66c65be145..8ede9a08b4 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs @@ -11,6 +11,7 @@ use p3_field::PrimeField64; use super::PageIndexScanOutputAir; impl AirBridge for PageIndexScanOutputAir { + // we receive the rows that satisfy the predicate fn receives(&self) -> Vec> { let num_cols = PageIndexScanOutputCols::::get_width( self.idx_len, @@ -45,6 +46,7 @@ impl AirBridge for PageIndexScanOutputAir { }] } + // we send range checks that are from the IsLessThanTuple subchip fn sends(&self) -> Vec> { let num_cols = PageIndexScanOutputCols::::get_width( self.idx_len, diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs index 59215b7ca6..df68aa6177 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -24,14 +24,22 @@ pub enum Comp { #[derive(Default, Getters)] pub struct PageIndexScanOutputAir { + /// The bus index for sends to range chip pub bus_index: usize, + /// The length of each index in the page table pub idx_len: usize, + /// The length of each data entry in the page table pub data_len: usize, #[getset(get = "pub")] is_less_than_tuple_air: IsLessThanTupleAir, } +/// This chip receives rows from the PageIndexScanInputChip and constrains that: +/// +/// 1. All allocated rows are before unallocated rows +/// 2. The allocated rows are sorted in ascending order by index +/// 3. The allocated rows of the new page are exactly the result of the index scan (via interactions) pub struct PageIndexScanOutputChip { pub air: PageIndexScanOutputAir, pub range_checker: Arc, diff --git a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs index 7457ae204a..9ad2d1d6d4 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs @@ -7,6 +7,7 @@ use crate::sub_chip::LocalTraceInstructions; use super::PageIndexScanOutputChip; impl PageIndexScanOutputChip { + /// Generate the trace for the page table pub fn gen_page_trace( &self, page: Vec>, @@ -26,6 +27,7 @@ impl PageIndexScanOutputChip { ) } + /// Generate the trace for the auxiliary columns pub fn gen_aux_trace( &self, page: Vec>, diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 88ce17dad7..854105cc7c 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -10,6 +10,8 @@ use afs_test_utils::{ }, engine::StarkEngine, }; +use p3_baby_bear::BabyBear; +use p3_field::AbstractField; use super::page_controller::PageController; @@ -33,7 +35,7 @@ fn index_scan_test( let (page_traces, mut prover_data) = page_controller.load_page( page.clone(), page_output.clone(), - x, + x.clone(), idx_len, data_len, idx_limb_bits, @@ -69,7 +71,11 @@ fn index_scan_test( ], ); - let pis = vec![vec![]; partial_vk.per_air.len()]; + let pis = vec![ + x.iter().map(|x| BabyBear::from_canonical_u32(*x)).collect(), + vec![], + vec![], + ]; let prover = engine.prover(); let verifier = engine.verifier(); @@ -129,7 +135,7 @@ fn test_single_page_index_scan() { keygen_builder.add_partitioned_air( &page_controller.input_chip.air, page_height, - 0, + idx_len, vec![input_page_ptr, input_page_aux_ptr], ); @@ -214,7 +220,7 @@ fn test_single_page_index_scan_wrong_order() { keygen_builder.add_partitioned_air( &page_controller.input_chip.air, page_height, - 0, + idx_len, vec![input_page_ptr, input_page_aux_ptr], ); @@ -308,7 +314,7 @@ fn test_single_page_index_scan_unsorted() { keygen_builder.add_partitioned_air( &page_controller.input_chip.air, page_height, - 0, + idx_len, vec![input_page_ptr, input_page_aux_ptr], ); From 567923bbc5ea917ace163c11645618516dd5ae05 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:35:10 -0400 Subject: [PATCH 32/46] feat: page index scan with comparator enum --- .../page_controller/mod.rs | 15 +- .../page_index_scan_input/air.rs | 131 +++++++++++------- .../page_index_scan_input/bridge.rs | 110 +++++++++------ .../page_index_scan_input/columns.rs | 66 +++++---- .../page_index_scan_input/mod.rs | 65 +++++---- .../page_index_scan_input/trace.rs | 44 +++--- chips/src/single_page_index_scan/tests.rs | 11 +- 7 files changed, 271 insertions(+), 171 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 33b5af095c..b050a3113e 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -8,10 +8,14 @@ use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_matrix::dense::DenseMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::range_gate::RangeCheckerGateChip; +use crate::{ + range_gate::RangeCheckerGateChip, + single_page_index_scan::page_index_scan_input::PageIndexScanInputAir, +}; use super::{ - page_index_scan_input::PageIndexScanInputChip, page_index_scan_output::PageIndexScanOutputChip, + page_index_scan_input::{Comp, PageIndexScanInputChip}, + page_index_scan_output::PageIndexScanOutputChip, }; pub struct PageController @@ -43,6 +47,7 @@ where range_max: u32, idx_limb_bits: Vec, idx_decomp: usize, + cmp: Comp, ) -> Self { let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, 1 << idx_decomp)); Self { @@ -54,6 +59,7 @@ where idx_limb_bits.clone(), idx_decomp, range_checker.clone(), + cmp, ), output_chip: PageIndexScanOutputChip::new( bus_index, @@ -171,7 +177,9 @@ where assert!(!page_input.is_empty()); - let bus_index = self.input_chip.air.bus_index; + let bus_index = match self.input_chip.air { + PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, + }; self.input_chip = PageIndexScanInputChip::new( bus_index, @@ -181,6 +189,7 @@ where idx_limb_bits.clone(), idx_decomp, self.range_checker.clone(), + self.input_chip.cmp.clone(), ); self.input_chip_trace = Some(self.input_chip.gen_page_trace::(page_input.clone())); self.input_chip_aux_trace = Some( diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index dd47124cb6..006d879bb2 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -8,7 +8,7 @@ use crate::{ sub_chip::{AirConfig, SubAir}, }; -use super::{columns::PageIndexScanInputCols, PageIndexScanInputAir}; +use super::{columns::PageIndexScanInputCols, Comp, PageIndexScanInputAir}; impl AirConfig for PageIndexScanInputAir { type Cols = PageIndexScanInputCols; @@ -16,12 +16,20 @@ impl AirConfig for PageIndexScanInputAir { impl BaseAir for PageIndexScanInputAir { fn width(&self) -> usize { - PageIndexScanInputCols::::get_width( - self.idx_len, - self.data_len, - self.is_less_than_tuple_air.limb_bits().clone(), - self.is_less_than_tuple_air.decomp(), - ) + match &self { + PageIndexScanInputAir::Lt { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Lt, + ), + } } } @@ -30,57 +38,76 @@ where AB::M: Clone, { fn eval(&self, builder: &mut AB) { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); + match &self { + PageIndexScanInputAir::Lt { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => { + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); - // get the public value x - let pis = builder.public_values(); - let x = pis[..self.idx_len].to_vec(); + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..*idx_len].to_vec(); - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); - let local_cols = PageIndexScanInputCols::::from_slice( - local, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air.limb_bits().clone(), - self.is_less_than_tuple_air.decomp(), - ); + let local_cols = PageIndexScanInputCols::::from_slice( + local, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Lt, + ); - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.idx, - y: local_cols.x.clone(), - tuple_less_than: local_cols.satisfies_pred, - }, - aux: local_cols.is_less_than_tuple_aux, - }; + match local_cols { + PageIndexScanInputCols::Lt { + is_alloc, + idx, + x, + satisfies_pred, + send_row, + is_less_than_tuple_aux, + .. + } => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: idx, + y: x.clone(), + tuple_less_than: satisfies_pred, + }, + aux: is_less_than_tuple_aux, + }; - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in local_cols.x.iter().zip(x.iter()) { - builder.assert_eq(local_x, pub_x); - } + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq( - local_cols.is_alloc * local_cols.satisfies_pred, - local_cols.send_row, - ); - builder.assert_bool(local_cols.send_row); + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq(is_alloc * satisfies_pred, send_row); + builder.assert_bool(send_row); - // constrain the indicator that we used to check wheter key < x is correct - SubAir::eval( - &self.is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); + // constrain the indicator that we used to check wheter key < x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + } + } + } + } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 2e3374311b..e4966010da 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -3,7 +3,7 @@ use crate::{ sub_chip::SubAirBridge, }; -use super::columns::PageIndexScanInputCols; +use super::{columns::PageIndexScanInputCols, Comp}; use afs_stark_backend::interaction::{AirBridge, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; @@ -12,54 +12,80 @@ use super::PageIndexScanInputAir; impl AirBridge for PageIndexScanInputAir { fn sends(&self) -> Vec> { - let num_cols = PageIndexScanInputCols::::get_width( - self.idx_len, - self.data_len, - self.is_less_than_tuple_air.limb_bits(), - self.is_less_than_tuple_air.decomp(), - ); - let all_cols = (0..num_cols).collect::>(); + match &self { + PageIndexScanInputAir::Lt { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => { + let num_cols = PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Lt, + ); + let all_cols = (0..num_cols).collect::>(); - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air.limb_bits(), - self.is_less_than_tuple_air.decomp(), - ); + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Lt, + ); - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: cols_numbered.idx.clone(), - y: cols_numbered.x.clone(), - tuple_less_than: cols_numbered.satisfies_pred, - }, - aux: cols_numbered.is_less_than_tuple_aux, - }; + let mut interactions: Vec> = vec![]; - // construct the row to send - let mut cols = vec![]; - cols.push(cols_numbered.is_alloc); - cols.extend(cols_numbered.idx); - cols.extend(cols_numbered.data); + match cols_numbered { + PageIndexScanInputCols::Lt { + is_alloc, + idx, + data, + send_row, + is_less_than_tuple_aux, + .. + } => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: idx.clone(), + y: data.clone(), + tuple_less_than: send_row, + }, + aux: is_less_than_tuple_aux, + }; - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); + // construct the row to send + let mut cols = vec![]; + cols.push(is_alloc); + cols.extend(idx); + cols.extend(data); - // sends with count given by send_row indicator - let mut interactions = vec![Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(cols_numbered.send_row), - argument_index: self.bus_index, - }]; + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); - let mut subchip_interactions = - SubAirBridge::::sends(&self.is_less_than_tuple_air, is_less_than_tuple_cols); + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(send_row), + argument_index: *bus_index, + }); - interactions.append(&mut subchip_interactions); + let mut subchip_interactions = SubAirBridge::::sends( + is_less_than_tuple_air, + is_less_than_tuple_cols, + ); - interactions + interactions.append(&mut subchip_interactions); + } + } + + interactions + } + } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index e81d849f62..cbef9fe323 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,15 +1,17 @@ use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; -pub struct PageIndexScanInputCols { - pub is_alloc: T, - pub idx: Vec, - pub data: Vec, +use super::Comp; - pub x: Vec, - - pub satisfies_pred: T, - pub send_row: T, - pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, +pub enum PageIndexScanInputCols { + Lt { + is_alloc: T, + idx: Vec, + data: Vec, + x: Vec, + satisfies_pred: T, + send_row: T, + is_less_than_tuple_aux: IsLessThanTupleAuxCols, + }, } impl PageIndexScanInputCols { @@ -19,20 +21,23 @@ impl PageIndexScanInputCols { data_len: usize, idx_limb_bits: Vec, decomp: usize, + cmp: Comp, ) -> Self { - Self { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), - send_row: slc[2 * idx_len + data_len + 2].clone(), - is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( - &slc[2 * idx_len + data_len + 3..], - idx_limb_bits, - decomp, - idx_len, - ), + match cmp { + Comp::Lt => Self::Lt { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), + send_row: slc[2 * idx_len + data_len + 2].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[2 * idx_len + data_len + 3..], + idx_limb_bits, + decomp, + idx_len, + ), + }, } } @@ -41,12 +46,17 @@ impl PageIndexScanInputCols { data_len: usize, idx_limb_bits: Vec, decomp: usize, + cmp: Comp, ) -> usize { - 1 + idx_len - + data_len - + idx_len - + 1 - + 1 - + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) + match cmp { + Comp::Lt => { + 1 + idx_len + + data_len + + idx_len + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) + } + } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 1108273936..42654825bb 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use getset::Getters; - use crate::{ is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, range_gate::RangeCheckerGateChip, @@ -12,27 +10,23 @@ pub mod bridge; pub mod columns; pub mod trace; -#[derive(Default)] +#[derive(Default, Clone)] pub enum Comp { #[default] Lt, - Lte, - Eq, - Gte, - Gt, } -#[derive(Default, Getters)] -pub struct PageIndexScanInputAir { - /// The bus index - pub bus_index: usize, - /// The length of each index in the page table - pub idx_len: usize, - /// The length of each data entry in the page table - pub data_len: usize, +pub enum PageIndexScanInputAir { + Lt { + /// The bus index + bus_index: usize, + /// The length of each index in the page table + idx_len: usize, + /// The length of each data entry in the page table + data_len: usize, - #[getset(skip)] - is_less_than_tuple_air: IsLessThanTupleAir, + is_less_than_tuple_air: IsLessThanTupleAir, + }, } /// Given a fixed predicate of the form index OP x, where OP is one of {<, <=, =, >=, >} @@ -43,9 +37,11 @@ pub struct PageIndexScanInputAir { pub struct PageIndexScanInputChip { pub air: PageIndexScanInputAir, pub range_checker: Arc, + pub cmp: Comp, } impl PageIndexScanInputChip { + #[allow(clippy::too_many_arguments)] pub fn new( bus_index: usize, idx_len: usize, @@ -54,9 +50,10 @@ impl PageIndexScanInputChip { idx_limb_bits: Vec, decomp: usize, range_checker: Arc, + cmp: Comp, ) -> Self { Self { - air: PageIndexScanInputAir { + air: PageIndexScanInputAir::Lt { bus_index, idx_len, data_len, @@ -68,22 +65,36 @@ impl PageIndexScanInputChip { ), }, range_checker, + cmp, } } pub fn page_width(&self) -> usize { - 1 + self.air.idx_len + self.air.data_len + match &self.air { + PageIndexScanInputAir::Lt { + idx_len, data_len, .. + } => 1 + idx_len + data_len, + } } pub fn aux_width(&self) -> usize { - self.air.idx_len - + 1 - + 1 - + IsLessThanTupleAuxCols::::get_width( - self.air.is_less_than_tuple_air.limb_bits(), - self.air.is_less_than_tuple_air.decomp(), - self.air.idx_len, - ) + match &self.air { + PageIndexScanInputAir::Lt { + bus_index: _, + idx_len, + data_len: _, + is_less_than_tuple_air, + } => { + idx_len + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + *idx_len, + ) + } + } } pub fn air_width(&self) -> usize { diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index e5d8f2ac5a..d450ded017 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -4,7 +4,7 @@ use p3_uni_stark::{StarkGenericConfig, Val}; use crate::sub_chip::LocalTraceInstructions; -use super::PageIndexScanInputChip; +use super::{PageIndexScanInputAir, PageIndexScanInputChip}; impl PageIndexScanInputChip { /// Generate the trace for the page table @@ -41,27 +41,35 @@ impl PageIndexScanInputChip { for page_row in &page { let mut row: Vec> = vec![]; - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + self.air.idx_len].to_vec(); + match &self.air { + PageIndexScanInputAir::Lt { + idx_len, + is_less_than_tuple_air, + .. + } => { + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + *idx_len].to_vec(); - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - &self.air.is_less_than_tuple_air, - (idx.clone(), x.clone(), self.range_checker.clone()), - ) - .flatten(); + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (idx.clone(), x.clone(), self.range_checker.clone()), + ) + .flatten(); - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; - row.push(send_row); + row.push(is_less_than_tuple_trace[2 * *idx_len]); + let send_row = is_less_than_tuple_trace[2 * *idx_len] * is_alloc; + row.push(send_row); - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + } + } rows.extend_from_slice(&row); } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 854105cc7c..1989ce73ae 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -13,7 +13,7 @@ use afs_test_utils::{ use p3_baby_bear::BabyBear; use p3_field::AbstractField; -use super::page_controller::PageController; +use super::{page_controller::PageController, page_index_scan_input::Comp}; #[allow(clippy::too_many_arguments)] fn index_scan_test( @@ -111,6 +111,8 @@ fn test_single_page_index_scan() { let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; + let cmp = Comp::Lt; + let mut page_controller: PageController = PageController::new( bus_index, idx_len, @@ -118,6 +120,7 @@ fn test_single_page_index_scan() { range_max, limb_bits.clone(), decomp, + cmp, ); let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); @@ -196,6 +199,8 @@ fn test_single_page_index_scan_wrong_order() { let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; + let cmp = Comp::Lt; + let mut page_controller: PageController = PageController::new( bus_index, idx_len, @@ -203,6 +208,7 @@ fn test_single_page_index_scan_wrong_order() { range_max, limb_bits.clone(), decomp, + cmp, ); let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); @@ -290,6 +296,8 @@ fn test_single_page_index_scan_unsorted() { let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; + let cmp = Comp::Lt; + let mut page_controller: PageController = PageController::new( bus_index, idx_len, @@ -297,6 +305,7 @@ fn test_single_page_index_scan_unsorted() { range_max, limb_bits.clone(), decomp, + cmp, ); let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); From cffc568e5b7c37bfe917168e1a603346213e21bb Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:45:23 -0400 Subject: [PATCH 33/46] feat: page index scan for greater than predicate --- .../page_controller/mod.rs | 70 ++++++++++---- .../page_index_scan_input/air.rs | 93 ++++++++++++++++++ .../page_index_scan_input/bridge.rs | 84 ++++++++++++++++ .../page_index_scan_input/columns.rs | 32 +++++++ .../page_index_scan_input/mod.rs | 70 +++++++++++--- .../page_index_scan_input/trace.rs | 28 ++++++ chips/src/single_page_index_scan/tests.rs | 95 ++++++++++++++++++- 7 files changed, 436 insertions(+), 36 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index b050a3113e..16f15d1af4 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -116,6 +116,7 @@ where x: Vec, idx_len: usize, page_width: usize, + cmp: Comp, ) -> Vec> { let mut output: Vec> = vec![]; @@ -124,29 +125,57 @@ where let idx = page_row[1..1 + idx_len].to_vec(); let data = page_row[1 + idx_len..].to_vec(); - let mut less_than = false; - for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { - use std::cmp::Ordering; - match idx_val.cmp(&x_val) { - Ordering::Less => { - less_than = true; - break; + match cmp { + Comp::Lt => { + let mut less_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Less => { + less_than = true; + break; + } + Ordering::Greater => { + break; + } + Ordering::Equal => {} + } } - Ordering::Greater => { - break; + if less_than { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } + Comp::Gt => { + let mut greater_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Greater => { + greater_than = true; + break; + } + Ordering::Less => { + break; + } + Ordering::Equal => {} + } + } + if greater_than { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); } - Ordering::Equal => {} } - } - - if less_than { - output.push( - vec![is_alloc] - .into_iter() - .chain(idx.iter().cloned()) - .chain(data.iter().cloned()) - .collect(), - ); } } @@ -179,6 +208,7 @@ where let bus_index = match self.input_chip.air { PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, + PageIndexScanInputAir::Gt { bus_index, .. } => bus_index, }; self.input_chip = PageIndexScanInputChip::new( diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 006d879bb2..5369ba60f5 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -29,6 +29,18 @@ impl BaseAir for PageIndexScanInputAir { is_less_than_tuple_air.decomp(), Comp::Lt, ), + PageIndexScanInputAir::Gt { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ), } } } @@ -80,6 +92,7 @@ where is_less_than_tuple_aux, .. } => { + // here, we are checking if idx < x let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { x: idx, @@ -98,6 +111,86 @@ where builder.assert_eq(is_alloc * satisfies_pred, send_row); builder.assert_bool(send_row); + // constrain the indicator that we used to check wheter key < x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + ); + } + } + } + PageIndexScanInputAir::Gt { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => { + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); + + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..*idx_len].to_vec(); + + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + + let local_cols = PageIndexScanInputCols::::from_slice( + local, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ); + + match local_cols { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + ); + } + PageIndexScanInputCols::Gt { + is_alloc, + idx, + x, + satisfies_pred, + send_row, + is_less_than_tuple_aux, + .. + } => { + // here, we are checking if idx > x + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: x.clone(), + y: idx, + tuple_less_than: satisfies_pred, + }, + aux: is_less_than_tuple_aux, + }; + + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } + + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq(is_alloc * satisfies_pred, send_row); + builder.assert_bool(send_row); + // constrain the indicator that we used to check wheter key < x is correct SubAir::eval( is_less_than_tuple_air, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index e4966010da..2e8da8055e 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -80,6 +80,90 @@ impl AirBridge for PageIndexScanInputAir { is_less_than_tuple_cols, ); + interactions.append(&mut subchip_interactions); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + ); + } + } + + interactions + } + + PageIndexScanInputAir::Gt { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => { + let num_cols = PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ); + + let mut interactions: Vec> = vec![]; + + match cols_numbered { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + ); + } + PageIndexScanInputCols::Gt { + is_alloc, + idx, + data, + send_row, + is_less_than_tuple_aux, + .. + } => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: idx.clone(), + y: data.clone(), + tuple_less_than: send_row, + }, + aux: is_less_than_tuple_aux, + }; + + // construct the row to send + let mut cols = vec![]; + cols.push(is_alloc); + cols.extend(idx); + cols.extend(data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(send_row), + argument_index: *bus_index, + }); + + let mut subchip_interactions = SubAirBridge::::sends( + is_less_than_tuple_air, + is_less_than_tuple_cols, + ); + interactions.append(&mut subchip_interactions); } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index cbef9fe323..228748cc42 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -12,6 +12,16 @@ pub enum PageIndexScanInputCols { send_row: T, is_less_than_tuple_aux: IsLessThanTupleAuxCols, }, + + Gt { + is_alloc: T, + idx: Vec, + data: Vec, + x: Vec, + satisfies_pred: T, + send_row: T, + is_less_than_tuple_aux: IsLessThanTupleAuxCols, + }, } impl PageIndexScanInputCols { @@ -38,6 +48,20 @@ impl PageIndexScanInputCols { idx_len, ), }, + Comp::Gt => Self::Gt { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), + send_row: slc[2 * idx_len + data_len + 2].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[2 * idx_len + data_len + 3..], + idx_limb_bits, + decomp, + idx_len, + ), + }, } } @@ -57,6 +81,14 @@ impl PageIndexScanInputCols { + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } + Comp::Gt => { + 1 + idx_len + + data_len + + idx_len + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) + } } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 42654825bb..b6072d3bed 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -14,6 +14,7 @@ pub mod trace; pub enum Comp { #[default] Lt, + Gt, } pub enum PageIndexScanInputAir { @@ -25,6 +26,16 @@ pub enum PageIndexScanInputAir { /// The length of each data entry in the page table data_len: usize, + is_less_than_tuple_air: IsLessThanTupleAir, + }, + Gt { + /// The bus index + bus_index: usize, + /// The length of each index in the page table + idx_len: usize, + /// The length of each data entry in the page table + data_len: usize, + is_less_than_tuple_air: IsLessThanTupleAir, }, } @@ -52,20 +63,37 @@ impl PageIndexScanInputChip { range_checker: Arc, cmp: Comp, ) -> Self { - Self { - air: PageIndexScanInputAir::Lt { - bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( + match cmp { + Comp::Lt => Self { + air: PageIndexScanInputAir::Lt { bus_index, - range_max, - idx_limb_bits.clone(), - decomp, - ), + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, + range_max, + idx_limb_bits.clone(), + decomp, + ), + }, + range_checker, + cmp, + }, + Comp::Gt => Self { + air: PageIndexScanInputAir::Gt { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, + range_max, + idx_limb_bits.clone(), + decomp, + ), + }, + range_checker, + cmp, }, - range_checker, - cmp, } } @@ -74,6 +102,9 @@ impl PageIndexScanInputChip { PageIndexScanInputAir::Lt { idx_len, data_len, .. } => 1 + idx_len + data_len, + PageIndexScanInputAir::Gt { + idx_len, data_len, .. + } => 1 + idx_len + data_len, } } @@ -94,6 +125,21 @@ impl PageIndexScanInputChip { *idx_len, ) } + PageIndexScanInputAir::Gt { + bus_index: _, + idx_len, + data_len: _, + is_less_than_tuple_air, + } => { + idx_len + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + *idx_len, + ) + } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index d450ded017..fd8aec2dfb 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -67,6 +67,34 @@ impl PageIndexScanInputChip { let send_row = is_less_than_tuple_trace[2 * *idx_len] * is_alloc; row.push(send_row); + row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + } + PageIndexScanInputAir::Gt { + idx_len, + is_less_than_tuple_air, + .. + } => { + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + *idx_len].to_vec(); + + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); + + // we want to check if idx > x + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (x.clone(), idx.clone(), self.range_checker.clone()), + ) + .flatten(); + + row.push(is_less_than_tuple_trace[2 * *idx_len]); + let send_row = is_less_than_tuple_trace[2 * *idx_len] * is_alloc; + row.push(send_row); + row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 1989ce73ae..1bc9a56187 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -99,7 +99,7 @@ fn index_scan_test( } #[test] -fn test_single_page_index_scan() { +fn test_single_page_index_scan_lt() { let bus_index: usize = 0; let idx_len: usize = 2; let data_len: usize = 3; @@ -111,7 +111,92 @@ fn test_single_page_index_scan() { let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; - let cmp = Comp::Lt; + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + Comp::Lt, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + idx_len, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_output = + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Lt); + + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ) + .expect("Verification failed"); +} + +#[test] +fn test_single_page_index_scan_gt() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( bus_index, @@ -120,7 +205,7 @@ fn test_single_page_index_scan() { range_max, limb_bits.clone(), decomp, - cmp, + Comp::Gt, ); let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); @@ -168,7 +253,9 @@ fn test_single_page_index_scan() { let x: Vec = vec![2177, 5880]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width); + let page_output = + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gt); + println!("page_output: {:?}", page_output); index_scan_test( &engine, From 62fe907db9670a94e4a11b31845603e14a683212 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:21:35 -0400 Subject: [PATCH 34/46] feat: page index scan for equal to predicates --- chips/src/is_equal_vec/mod.rs | 4 + .../page_controller/mod.rs | 19 ++++ .../page_index_scan_input/air.rs | 97 ++++++++++++++++++- .../page_index_scan_input/bridge.rs | 95 ++++++++++++++++-- .../page_index_scan_input/columns.rs | 27 +++++- .../page_index_scan_input/mod.rs | 31 ++++++ .../page_index_scan_input/trace.rs | 36 +++++++ chips/src/single_page_index_scan/tests.rs | 88 ++++++++++++++++- 8 files changed, 386 insertions(+), 11 deletions(-) diff --git a/chips/src/is_equal_vec/mod.rs b/chips/src/is_equal_vec/mod.rs index 9b2f123754..e6b2850f4f 100644 --- a/chips/src/is_equal_vec/mod.rs +++ b/chips/src/is_equal_vec/mod.rs @@ -11,6 +11,10 @@ pub struct IsEqualVecAir { } impl IsEqualVecAir { + pub fn new(vec_len: usize) -> Self { + Self { vec_len } + } + pub fn request(&self, x: &[F], y: &[F]) -> bool { x == y } diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 16f15d1af4..9cbc854de0 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -151,6 +151,24 @@ where ); } } + Comp::Eq => { + let mut eq = true; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + if idx_val != x_val { + eq = false; + break; + } + } + if eq { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } Comp::Gt => { let mut greater_than = false; for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { @@ -208,6 +226,7 @@ where let bus_index = match self.input_chip.air { PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, + PageIndexScanInputAir::Eq { bus_index, .. } => bus_index, PageIndexScanInputAir::Gt { bus_index, .. } => bus_index, }; diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 5369ba60f5..abf28e9627 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -4,6 +4,7 @@ use p3_field::Field; use p3_matrix::Matrix; use crate::{ + is_equal_vec::columns::{IsEqualVecCols, IsEqualVecIOCols}, is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, sub_chip::{AirConfig, SubAir}, }; @@ -29,6 +30,10 @@ impl BaseAir for PageIndexScanInputAir { is_less_than_tuple_air.decomp(), Comp::Lt, ), + // there is no idx_limb_bits or decomp, so we supply an empty vec and 0, respectively + PageIndexScanInputAir::Eq { + idx_len, data_len, .. + } => PageIndexScanInputCols::::get_width(*idx_len, *data_len, vec![], 0, Comp::Eq), PageIndexScanInputAir::Gt { idx_len, data_len, @@ -119,6 +124,11 @@ where is_less_than_tuple_cols.aux, ); } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" @@ -126,6 +136,86 @@ where } } } + PageIndexScanInputAir::Eq { + idx_len, + data_len, + is_equal_vec_air, + .. + } => { + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); + + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..*idx_len].to_vec(); + + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + + let local_cols = PageIndexScanInputCols::::from_slice( + local, + *idx_len, + *data_len, + vec![], + 0, + Comp::Eq, + ); + + match local_cols { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Eq { + is_alloc, + idx, + x, + satisfies_pred, + send_row, + is_equal_vec_aux, + .. + } => { + // here, we are checking if idx = x + let is_equal_vec_cols = IsEqualVecCols { + io: IsEqualVecIOCols { + x: idx, + y: x.clone(), + prod: satisfies_pred, + }, + aux: is_equal_vec_aux, + }; + + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } + + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq(is_alloc * satisfies_pred, send_row); + builder.assert_bool(send_row); + + // constrain the indicator that we used to check wheter key = x is correct + SubAir::eval( + is_equal_vec_air, + builder, + is_equal_vec_cols.io, + is_equal_vec_cols.aux, + ); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" + ); + } + } + } PageIndexScanInputAir::Gt { idx_len, data_len, @@ -160,7 +250,12 @@ where match local_cols { PageIndexScanInputCols::Lt { .. } => { panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" ); } PageIndexScanInputCols::Gt { diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 2e8da8055e..8b7e9a886b 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -18,7 +18,6 @@ impl AirBridge for PageIndexScanInputAir { idx_len, data_len, is_less_than_tuple_air, - .. } => { let num_cols = PageIndexScanInputCols::::get_width( *idx_len, @@ -45,6 +44,8 @@ impl AirBridge for PageIndexScanInputAir { is_alloc, idx, data, + x, + satisfies_pred, send_row, is_less_than_tuple_aux, .. @@ -52,8 +53,8 @@ impl AirBridge for PageIndexScanInputAir { let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { x: idx.clone(), - y: data.clone(), - tuple_less_than: send_row, + y: x.clone(), + tuple_less_than: satisfies_pred, }, aux: is_less_than_tuple_aux, }; @@ -82,6 +83,11 @@ impl AirBridge for PageIndexScanInputAir { interactions.append(&mut subchip_interactions); } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" @@ -97,7 +103,6 @@ impl AirBridge for PageIndexScanInputAir { idx_len, data_len, is_less_than_tuple_air, - .. } => { let num_cols = PageIndexScanInputCols::::get_width( *idx_len, @@ -122,22 +127,29 @@ impl AirBridge for PageIndexScanInputAir { match cols_numbered { PageIndexScanInputCols::Lt { .. } => { panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" ); } PageIndexScanInputCols::Gt { is_alloc, idx, data, + x, send_row, + satisfies_pred, is_less_than_tuple_aux, .. } => { let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { - x: idx.clone(), - y: data.clone(), - tuple_less_than: send_row, + x: x.clone(), + y: idx.clone(), + tuple_less_than: satisfies_pred, }, aux: is_less_than_tuple_aux, }; @@ -168,6 +180,73 @@ impl AirBridge for PageIndexScanInputAir { } } + interactions + } + PageIndexScanInputAir::Eq { + bus_index, + idx_len, + data_len, + .. + } => { + // There is no limb_bits or decomp for IsEqualVec, so we can just pass in an empty vec and 0, respectively + let num_cols = PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + vec![], + 0, + Comp::Eq, + ); + + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + *idx_len, + *data_len, + vec![], + 0, + Comp::Eq, + ); + + let mut interactions: Vec> = vec![]; + + match cols_numbered { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Eq { + is_alloc, + idx, + data, + send_row, + .. + } => { + // construct the row to send + let mut cols = vec![]; + cols.push(is_alloc); + cols.extend(idx); + cols.extend(data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(send_row), + argument_index: *bus_index, + }); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" + ); + } + } + interactions } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 228748cc42..d4b77c8644 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,4 +1,6 @@ -use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; +use crate::{ + is_equal_vec::columns::IsEqualVecAuxCols, is_less_than_tuple::columns::IsLessThanTupleAuxCols, +}; use super::Comp; @@ -13,6 +15,16 @@ pub enum PageIndexScanInputCols { is_less_than_tuple_aux: IsLessThanTupleAuxCols, }, + Eq { + is_alloc: T, + idx: Vec, + data: Vec, + x: Vec, + satisfies_pred: T, + send_row: T, + is_equal_vec_aux: IsEqualVecAuxCols, + }, + Gt { is_alloc: T, idx: Vec, @@ -48,6 +60,18 @@ impl PageIndexScanInputCols { idx_len, ), }, + Comp::Eq => Self::Eq { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), + send_row: slc[2 * idx_len + data_len + 2].clone(), + is_equal_vec_aux: IsEqualVecAuxCols { + prods: slc[2 * idx_len + data_len + 3..3 * idx_len + data_len + 3].to_vec(), + invs: slc[3 * idx_len + data_len + 3..].to_vec(), + }, + }, Comp::Gt => Self::Gt { is_alloc: slc[0].clone(), idx: slc[1..idx_len + 1].to_vec(), @@ -81,6 +105,7 @@ impl PageIndexScanInputCols { + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } + Comp::Eq => 1 + idx_len + data_len + idx_len + 1 + 1 + 2 * idx_len, Comp::Gt => { 1 + idx_len + data_len diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index b6072d3bed..509af70dc6 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use crate::{ + is_equal_vec::IsEqualVecAir, is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, range_gate::RangeCheckerGateChip, }; @@ -14,6 +15,7 @@ pub mod trace; pub enum Comp { #[default] Lt, + Eq, Gt, } @@ -28,6 +30,16 @@ pub enum PageIndexScanInputAir { is_less_than_tuple_air: IsLessThanTupleAir, }, + Eq { + /// The bus index + bus_index: usize, + /// The length of each index in the page table + idx_len: usize, + /// The length of each data entry in the page table + data_len: usize, + + is_equal_vec_air: IsEqualVecAir, + }, Gt { /// The bus index bus_index: usize, @@ -79,6 +91,16 @@ impl PageIndexScanInputChip { range_checker, cmp, }, + Comp::Eq => Self { + air: PageIndexScanInputAir::Eq { + bus_index, + idx_len, + data_len, + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }, + range_checker, + cmp, + }, Comp::Gt => Self { air: PageIndexScanInputAir::Gt { bus_index, @@ -102,6 +124,9 @@ impl PageIndexScanInputChip { PageIndexScanInputAir::Lt { idx_len, data_len, .. } => 1 + idx_len + data_len, + PageIndexScanInputAir::Eq { + idx_len, data_len, .. + } => 1 + idx_len + data_len, PageIndexScanInputAir::Gt { idx_len, data_len, .. } => 1 + idx_len + data_len, @@ -125,6 +150,12 @@ impl PageIndexScanInputChip { *idx_len, ) } + PageIndexScanInputAir::Eq { + bus_index: _, + idx_len, + data_len: _, + is_equal_vec_air: _, + } => idx_len + 1 + 1 + 2 * idx_len, PageIndexScanInputAir::Gt { bus_index: _, idx_len, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index fd8aec2dfb..41e9be4474 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -69,6 +69,42 @@ impl PageIndexScanInputChip { row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); } + PageIndexScanInputAir::Eq { + idx_len, + is_equal_vec_air, + .. + } => { + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + *idx_len].to_vec(); + + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); + + let is_equal_vec_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_equal_vec_air, + ( + idx.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + x.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + ), + ) + .flatten(); + + row.push(is_equal_vec_trace[3 * *idx_len - 1]); + let send_row = is_equal_vec_trace[3 * *idx_len - 1] * is_alloc; + row.push(send_row); + + row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); + } PageIndexScanInputAir::Gt { idx_len, is_less_than_tuple_air, diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 1bc9a56187..07cae30155 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -255,7 +255,93 @@ fn test_single_page_index_scan_gt() { let page_output = page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gt); - println!("page_output: {:?}", page_output); + + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ) + .expect("Verification failed"); +} + +#[test] +fn test_single_page_index_scan_eq() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + Comp::Eq, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + idx_len, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let x: Vec = vec![2883, 7769]; + + let page_output = + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Eq); index_scan_test( &engine, From e65102498db477dfd2e924b2dfd507cafd4074dc Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:41:23 -0400 Subject: [PATCH 35/46] feat: page index scan for less than or equal to predicate --- .../page_controller/mod.rs | 35 +++++ .../page_index_scan_input/air.rs | 140 +++++++++++++++++- .../page_index_scan_input/bridge.rs | 139 ++++++++++++++--- .../page_index_scan_input/columns.rs | 52 +++++++ .../page_index_scan_input/mod.rs | 59 ++++++-- .../page_index_scan_input/trace.rs | 48 ++++++ chips/src/single_page_index_scan/tests.rs | 97 +++++++++++- 7 files changed, 535 insertions(+), 35 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 9cbc854de0..79077f5c12 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -151,6 +151,40 @@ where ); } } + Comp::Lte => { + let mut less_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Less => { + less_than = true; + break; + } + Ordering::Greater => { + break; + } + Ordering::Equal => {} + } + } + + let mut eq = true; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + if idx_val != x_val { + eq = false; + break; + } + } + + if less_than || eq { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } Comp::Eq => { let mut eq = true; for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { @@ -226,6 +260,7 @@ where let bus_index = match self.input_chip.air { PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, + PageIndexScanInputAir::Lte { bus_index, .. } => bus_index, PageIndexScanInputAir::Eq { bus_index, .. } => bus_index, PageIndexScanInputAir::Gt { bus_index, .. } => bus_index, }; diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index abf28e9627..0b5b050420 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -30,6 +30,18 @@ impl BaseAir for PageIndexScanInputAir { is_less_than_tuple_air.decomp(), Comp::Lt, ), + PageIndexScanInputAir::Lte { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Lte, + ), // there is no idx_limb_bits or decomp, so we supply an empty vec and 0, respectively PageIndexScanInputAir::Eq { idx_len, data_len, .. @@ -124,6 +136,11 @@ where is_less_than_tuple_cols.aux, ); } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Lte" + ); + } PageIndexScanInputCols::Eq { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" @@ -136,6 +153,117 @@ where } } } + PageIndexScanInputAir::Lte { + idx_len, + data_len, + is_less_than_tuple_air, + is_equal_vec_air, + .. + } => { + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); + + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..*idx_len].to_vec(); + + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + + let local_cols = PageIndexScanInputCols::::from_slice( + local, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Lte, + ); + + match local_cols { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Lte { + is_alloc, + idx, + x, + less_than_x, + eq_to_x, + satisfies_pred, + send_row, + is_less_than_tuple_aux, + is_equal_vec_aux, + .. + } => { + // here, we are checking if idx <= x + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: idx.clone(), + y: x.clone(), + tuple_less_than: less_than_x, + }, + aux: is_less_than_tuple_aux, + }; + + // constrain the indicator that we used to check wheter key < x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + + // here, we are checking if idx = x + let is_equal_vec_cols = IsEqualVecCols { + io: IsEqualVecIOCols { + x: idx.clone(), + y: x.clone(), + prod: eq_to_x, + }, + aux: is_equal_vec_aux, + }; + + // constrain the indicator that we used to check wheter key = x is correct + SubAir::eval( + is_equal_vec_air, + builder, + is_equal_vec_cols.io, + is_equal_vec_cols.aux, + ); + + // constrain that it satisfies predicate if either less than or equal, and that satisfies is bool + builder.assert_eq(less_than_x + eq_to_x, satisfies_pred); + builder.assert_bool(satisfies_pred); + + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } + + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq(is_alloc * satisfies_pred, send_row); + builder.assert_bool(send_row); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" + ); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" + ); + } + } + } PageIndexScanInputAir::Eq { idx_len, data_len, @@ -170,7 +298,12 @@ where match local_cols { PageIndexScanInputCols::Lt { .. } => { panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lte" ); } PageIndexScanInputCols::Eq { @@ -253,6 +386,11 @@ where "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" ); } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lte" + ); + } PageIndexScanInputCols::Eq { .. } => { panic!( "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 8b7e9a886b..852f9335d4 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -12,6 +12,8 @@ use super::PageIndexScanInputAir; impl AirBridge for PageIndexScanInputAir { fn sends(&self) -> Vec> { + let mut interactions: Vec> = vec![]; + match &self { PageIndexScanInputAir::Lt { bus_index, @@ -37,8 +39,6 @@ impl AirBridge for PageIndexScanInputAir { Comp::Lt, ); - let mut interactions: Vec> = vec![]; - match cols_numbered { PageIndexScanInputCols::Lt { is_alloc, @@ -83,6 +83,11 @@ impl AirBridge for PageIndexScanInputAir { interactions.append(&mut subchip_interactions); } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Lte" + ); + } PageIndexScanInputCols::Eq { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" @@ -97,20 +102,21 @@ impl AirBridge for PageIndexScanInputAir { interactions } - - PageIndexScanInputAir::Gt { + PageIndexScanInputAir::Lte { bus_index, idx_len, data_len, is_less_than_tuple_air, + .. } => { let num_cols = PageIndexScanInputCols::::get_width( *idx_len, *data_len, is_less_than_tuple_air.limb_bits(), is_less_than_tuple_air.decomp(), - Comp::Gt, + Comp::Lte, ); + let all_cols = (0..num_cols).collect::>(); let cols_numbered = PageIndexScanInputCols::::from_slice( @@ -119,36 +125,29 @@ impl AirBridge for PageIndexScanInputAir { *data_len, is_less_than_tuple_air.limb_bits(), is_less_than_tuple_air.decomp(), - Comp::Gt, + Comp::Lte, ); - let mut interactions: Vec> = vec![]; - match cols_numbered { PageIndexScanInputCols::Lt { .. } => { panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Lt" ); } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gt { + PageIndexScanInputCols::Lte { is_alloc, idx, data, x, - send_row, satisfies_pred, + send_row, is_less_than_tuple_aux, .. } => { let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { - x: x.clone(), - y: idx.clone(), + x: idx.clone(), + y: x.clone(), tuple_less_than: satisfies_pred, }, aux: is_less_than_tuple_aux, @@ -178,6 +177,16 @@ impl AirBridge for PageIndexScanInputAir { interactions.append(&mut subchip_interactions); } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" + ); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" + ); + } } interactions @@ -208,14 +217,17 @@ impl AirBridge for PageIndexScanInputAir { Comp::Eq, ); - let mut interactions: Vec> = vec![]; - match cols_numbered { PageIndexScanInputCols::Lt { .. } => { panic!( "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lt" ); } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lte" + ); + } PageIndexScanInputCols::Eq { is_alloc, idx, @@ -247,6 +259,93 @@ impl AirBridge for PageIndexScanInputAir { } } + interactions + } + PageIndexScanInputAir::Gt { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air, + } => { + let num_cols = PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gt, + ); + + match cols_numbered { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lte" + ); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" + ); + } + PageIndexScanInputCols::Gt { + is_alloc, + idx, + data, + x, + send_row, + satisfies_pred, + is_less_than_tuple_aux, + .. + } => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: x.clone(), + y: idx.clone(), + tuple_less_than: satisfies_pred, + }, + aux: is_less_than_tuple_aux, + }; + + // construct the row to send + let mut cols = vec![]; + cols.push(is_alloc); + cols.extend(idx); + cols.extend(data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(send_row), + argument_index: *bus_index, + }); + + let mut subchip_interactions = SubAirBridge::::sends( + is_less_than_tuple_air, + is_less_than_tuple_cols, + ); + + interactions.append(&mut subchip_interactions); + } + } + interactions } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index d4b77c8644..91a62d8348 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -15,6 +15,19 @@ pub enum PageIndexScanInputCols { is_less_than_tuple_aux: IsLessThanTupleAuxCols, }, + Lte { + is_alloc: T, + idx: Vec, + data: Vec, + x: Vec, + less_than_x: T, + eq_to_x: T, + satisfies_pred: T, + send_row: T, + is_less_than_tuple_aux: IsLessThanTupleAuxCols, + is_equal_vec_aux: IsEqualVecAuxCols, + }, + Eq { is_alloc: T, idx: Vec, @@ -60,6 +73,34 @@ impl PageIndexScanInputCols { idx_len, ), }, + Comp::Lte => { + let less_than_tuple_aux_width = + IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); + Self::Lte { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + less_than_x: slc[2 * idx_len + data_len + 1].clone(), + eq_to_x: slc[2 * idx_len + data_len + 2].clone(), + satisfies_pred: slc[2 * idx_len + data_len + 3].clone(), + send_row: slc[2 * idx_len + data_len + 4].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[2 * idx_len + data_len + 5 + ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], + idx_limb_bits, + decomp, + idx_len, + ), + is_equal_vec_aux: IsEqualVecAuxCols { + prods: slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width + ..3 * idx_len + data_len + 5 + less_than_tuple_aux_width] + .to_vec(), + invs: slc[3 * idx_len + data_len + 5 + less_than_tuple_aux_width..] + .to_vec(), + }, + } + } Comp::Eq => Self::Eq { is_alloc: slc[0].clone(), idx: slc[1..idx_len + 1].to_vec(), @@ -105,6 +146,17 @@ impl PageIndexScanInputCols { + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } + Comp::Lte => { + 1 + idx_len + + data_len + + idx_len + + 1 + + 1 + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) + + 2 * idx_len + } Comp::Eq => 1 + idx_len + data_len + idx_len + 1 + 1 + 2 * idx_len, Comp::Gt => { 1 + idx_len diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 509af70dc6..4e3a421ad9 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -15,6 +15,7 @@ pub mod trace; pub enum Comp { #[default] Lt, + Lte, Eq, Gt, } @@ -30,6 +31,17 @@ pub enum PageIndexScanInputAir { is_less_than_tuple_air: IsLessThanTupleAir, }, + Lte { + /// The bus index + bus_index: usize, + /// The length of each index in the page table + idx_len: usize, + /// The length of each data entry in the page table + data_len: usize, + + is_less_than_tuple_air: IsLessThanTupleAir, + is_equal_vec_air: IsEqualVecAir, + }, Eq { /// The bus index bus_index: usize, @@ -91,6 +103,22 @@ impl PageIndexScanInputChip { range_checker, cmp, }, + Comp::Lte => Self { + air: PageIndexScanInputAir::Lte { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, + range_max, + idx_limb_bits.clone(), + decomp, + ), + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }, + range_checker, + cmp, + }, Comp::Eq => Self { air: PageIndexScanInputAir::Eq { bus_index, @@ -124,6 +152,9 @@ impl PageIndexScanInputChip { PageIndexScanInputAir::Lt { idx_len, data_len, .. } => 1 + idx_len + data_len, + PageIndexScanInputAir::Lte { + idx_len, data_len, .. + } => 1 + idx_len + data_len, PageIndexScanInputAir::Eq { idx_len, data_len, .. } => 1 + idx_len + data_len, @@ -136,10 +167,9 @@ impl PageIndexScanInputChip { pub fn aux_width(&self) -> usize { match &self.air { PageIndexScanInputAir::Lt { - bus_index: _, idx_len, - data_len: _, is_less_than_tuple_air, + .. } => { idx_len + 1 @@ -150,17 +180,28 @@ impl PageIndexScanInputChip { *idx_len, ) } - PageIndexScanInputAir::Eq { - bus_index: _, + PageIndexScanInputAir::Lte { idx_len, - data_len: _, - is_equal_vec_air: _, - } => idx_len + 1 + 1 + 2 * idx_len, + is_less_than_tuple_air, + .. + } => { + idx_len + + 1 + + 1 + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + *idx_len, + ) + + 2 * idx_len + } + PageIndexScanInputAir::Eq { idx_len, .. } => idx_len + 1 + 1 + 2 * idx_len, PageIndexScanInputAir::Gt { - bus_index: _, idx_len, - data_len: _, is_less_than_tuple_air, + .. } => { idx_len + 1 diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 41e9be4474..ee095d230a 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -69,6 +69,54 @@ impl PageIndexScanInputChip { row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); } + PageIndexScanInputAir::Lte { + idx_len, + is_less_than_tuple_air, + is_equal_vec_air, + .. + } => { + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + *idx_len].to_vec(); + + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); + + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (idx.clone(), x.clone(), self.range_checker.clone()), + ) + .flatten(); + + let is_equal_vec_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_equal_vec_air, + ( + idx.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + x.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + ), + ) + .flatten(); + + row.push(is_less_than_tuple_trace[2 * *idx_len]); + row.push(is_equal_vec_trace[3 * *idx_len - 1]); + let satisfies_pred = is_less_than_tuple_trace[2 * *idx_len] + + is_equal_vec_trace[3 * *idx_len - 1]; + row.push(satisfies_pred); + row.push(satisfies_pred * is_alloc); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); + } PageIndexScanInputAir::Eq { idx_len, is_equal_vec_air, diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 07cae30155..c3d1ae505d 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -186,7 +186,7 @@ fn test_single_page_index_scan_lt() { } #[test] -fn test_single_page_index_scan_gt() { +fn test_single_page_index_scan_lte() { let bus_index: usize = 0; let idx_len: usize = 2; let data_len: usize = 3; @@ -205,7 +205,7 @@ fn test_single_page_index_scan_gt() { range_max, limb_bits.clone(), decomp, - Comp::Gt, + Comp::Lte, ); let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); @@ -248,13 +248,13 @@ fn test_single_page_index_scan_gt() { let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], + vec![1, 2177, 5880, 51171, 3989, 12770], ]; let x: Vec = vec![2177, 5880]; let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gt); + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Lte); index_scan_test( &engine, @@ -338,7 +338,7 @@ fn test_single_page_index_scan_eq() { vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let x: Vec = vec![2883, 7769]; + let x: Vec = vec![443, 376]; let page_output = page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Eq); @@ -359,6 +359,93 @@ fn test_single_page_index_scan_eq() { .expect("Verification failed"); } +#[test] +fn test_single_page_index_scan_gt() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + Comp::Gt, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + idx_len, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_output = + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gt); + + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ) + .expect("Verification failed"); +} + #[test] fn test_single_page_index_scan_wrong_order() { let bus_index: usize = 0; From bf78f3424ffa338a9198a0ae4b75b3495306cf92 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 17:09:43 -0400 Subject: [PATCH 36/46] feat: page index scan for greater than or equal to predicate --- .../page_controller/mod.rs | 35 +++++ .../page_index_scan_input/air.rs | 148 ++++++++++++++++++ .../page_index_scan_input/bridge.rs | 114 ++++++++++++++ .../page_index_scan_input/columns.rs | 53 +++++-- .../page_index_scan_input/mod.rs | 48 ++++++ .../page_index_scan_input/trace.rs | 48 ++++++ .../page_index_scan_output/mod.rs | 10 -- chips/src/single_page_index_scan/tests.rs | 89 ++++++++++- 8 files changed, 524 insertions(+), 21 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 79077f5c12..c88088d81e 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -203,6 +203,40 @@ where ); } } + Comp::Gte => { + let mut greater_than = false; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + use std::cmp::Ordering; + match idx_val.cmp(&x_val) { + Ordering::Greater => { + greater_than = true; + break; + } + Ordering::Less => { + break; + } + Ordering::Equal => {} + } + } + + let mut eq = true; + for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { + if idx_val != x_val { + eq = false; + break; + } + } + + if greater_than || eq { + output.push( + vec![is_alloc] + .into_iter() + .chain(idx.iter().cloned()) + .chain(data.iter().cloned()) + .collect(), + ); + } + } Comp::Gt => { let mut greater_than = false; for (&idx_val, &x_val) in idx.iter().zip(x.iter()) { @@ -262,6 +296,7 @@ where PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, PageIndexScanInputAir::Lte { bus_index, .. } => bus_index, PageIndexScanInputAir::Eq { bus_index, .. } => bus_index, + PageIndexScanInputAir::Gte { bus_index, .. } => bus_index, PageIndexScanInputAir::Gt { bus_index, .. } => bus_index, }; diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 0b5b050420..f81145406f 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -46,6 +46,18 @@ impl BaseAir for PageIndexScanInputAir { PageIndexScanInputAir::Eq { idx_len, data_len, .. } => PageIndexScanInputCols::::get_width(*idx_len, *data_len, vec![], 0, Comp::Eq), + PageIndexScanInputAir::Gte { + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Gte, + ), PageIndexScanInputAir::Gt { idx_len, data_len, @@ -146,6 +158,11 @@ where "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" @@ -257,6 +274,11 @@ where "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" @@ -342,6 +364,11 @@ where is_equal_vec_cols.aux, ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" @@ -349,6 +376,122 @@ where } } } + PageIndexScanInputAir::Gte { + idx_len, + data_len, + is_less_than_tuple_air, + is_equal_vec_air, + .. + } => { + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); + + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..*idx_len].to_vec(); + + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + + let local_cols = PageIndexScanInputCols::::from_slice( + local, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits().clone(), + is_less_than_tuple_air.decomp(), + Comp::Gte, + ); + + match local_cols { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lte" + ); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Eq" + ); + } + PageIndexScanInputCols::Gte { + is_alloc, + idx, + x, + greater_than_x, + eq_to_x, + satisfies_pred, + send_row, + is_less_than_tuple_aux, + is_equal_vec_aux, + .. + } => { + // here, we are checking if idx <= x + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: x.clone(), + y: idx.clone(), + tuple_less_than: greater_than_x, + }, + aux: is_less_than_tuple_aux, + }; + + // constrain the indicator that we used to check wheter key < x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); + + // here, we are checking if idx = x + let is_equal_vec_cols = IsEqualVecCols { + io: IsEqualVecIOCols { + x: idx.clone(), + y: x.clone(), + prod: eq_to_x, + }, + aux: is_equal_vec_aux, + }; + + // constrain the indicator that we used to check wheter key = x is correct + SubAir::eval( + is_equal_vec_air, + builder, + is_equal_vec_cols.io, + is_equal_vec_cols.aux, + ); + + // constrain that it satisfies predicate if either less than or equal, and that satisfies is bool + builder.assert_eq(greater_than_x + eq_to_x, satisfies_pred); + builder.assert_bool(satisfies_pred); + + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } + + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq(is_alloc * satisfies_pred, send_row); + builder.assert_bool(send_row); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Gt" + ); + } + } + } PageIndexScanInputAir::Gt { idx_len, data_len, @@ -396,6 +539,11 @@ where "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { is_alloc, idx, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 852f9335d4..7d91976829 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -93,6 +93,11 @@ impl AirBridge for PageIndexScanInputAir { "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" @@ -182,6 +187,11 @@ impl AirBridge for PageIndexScanInputAir { "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" @@ -252,6 +262,11 @@ impl AirBridge for PageIndexScanInputAir { argument_index: *bus_index, }); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { .. } => { panic!( "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" @@ -261,6 +276,100 @@ impl AirBridge for PageIndexScanInputAir { interactions } + PageIndexScanInputAir::Gte { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air, + .. + } => { + let num_cols = PageIndexScanInputCols::::get_width( + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gte, + ); + + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + *idx_len, + *data_len, + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + Comp::Gte, + ); + + match cols_numbered { + PageIndexScanInputCols::Lt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lt" + ); + } + PageIndexScanInputCols::Lte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lte" + ); + } + PageIndexScanInputCols::Eq { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Eq" + ); + } + PageIndexScanInputCols::Gte { + is_alloc, + idx, + data, + x, + satisfies_pred, + send_row, + is_less_than_tuple_aux, + .. + } => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: x.clone(), + y: idx.clone(), + tuple_less_than: satisfies_pred, + }, + aux: is_less_than_tuple_aux, + }; + + // construct the row to send + let mut cols = vec![]; + cols.push(is_alloc); + cols.extend(idx); + cols.extend(data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(send_row), + argument_index: *bus_index, + }); + + let mut subchip_interactions = SubAirBridge::::sends( + is_less_than_tuple_air, + is_less_than_tuple_cols, + ); + + interactions.append(&mut subchip_interactions); + } + PageIndexScanInputCols::Gt { .. } => { + panic!( + "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Gt" + ); + } + } + + interactions + } PageIndexScanInputAir::Gt { bus_index, idx_len, @@ -301,6 +410,11 @@ impl AirBridge for PageIndexScanInputAir { "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" ); } + PageIndexScanInputCols::Gte { .. } => { + panic!( + "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Gte" + ); + } PageIndexScanInputCols::Gt { is_alloc, idx, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 91a62d8348..3f0dd696a1 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -38,6 +38,19 @@ pub enum PageIndexScanInputCols { is_equal_vec_aux: IsEqualVecAuxCols, }, + Gte { + is_alloc: T, + idx: Vec, + data: Vec, + x: Vec, + greater_than_x: T, + eq_to_x: T, + satisfies_pred: T, + send_row: T, + is_less_than_tuple_aux: IsLessThanTupleAuxCols, + is_equal_vec_aux: IsEqualVecAuxCols, + }, + Gt { is_alloc: T, idx: Vec, @@ -113,6 +126,34 @@ impl PageIndexScanInputCols { invs: slc[3 * idx_len + data_len + 3..].to_vec(), }, }, + Comp::Gte => { + let less_than_tuple_aux_width = + IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); + Self::Gte { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), + greater_than_x: slc[2 * idx_len + data_len + 1].clone(), + eq_to_x: slc[2 * idx_len + data_len + 2].clone(), + satisfies_pred: slc[2 * idx_len + data_len + 3].clone(), + send_row: slc[2 * idx_len + data_len + 4].clone(), + is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( + &slc[2 * idx_len + data_len + 5 + ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], + idx_limb_bits, + decomp, + idx_len, + ), + is_equal_vec_aux: IsEqualVecAuxCols { + prods: slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width + ..3 * idx_len + data_len + 5 + less_than_tuple_aux_width] + .to_vec(), + invs: slc[3 * idx_len + data_len + 5 + less_than_tuple_aux_width..] + .to_vec(), + }, + } + } Comp::Gt => Self::Gt { is_alloc: slc[0].clone(), idx: slc[1..idx_len + 1].to_vec(), @@ -138,7 +179,7 @@ impl PageIndexScanInputCols { cmp: Comp, ) -> usize { match cmp { - Comp::Lt => { + Comp::Lt | Comp::Gt => { 1 + idx_len + data_len + idx_len @@ -146,7 +187,7 @@ impl PageIndexScanInputCols { + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) } - Comp::Lte => { + Comp::Lte | Comp::Gte => { 1 + idx_len + data_len + idx_len @@ -158,14 +199,6 @@ impl PageIndexScanInputCols { + 2 * idx_len } Comp::Eq => 1 + idx_len + data_len + idx_len + 1 + 1 + 2 * idx_len, - Comp::Gt => { - 1 + idx_len - + data_len - + idx_len - + 1 - + 1 - + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) - } } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 4e3a421ad9..a9b55084bf 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -17,6 +17,7 @@ pub enum Comp { Lt, Lte, Eq, + Gte, Gt, } @@ -52,6 +53,17 @@ pub enum PageIndexScanInputAir { is_equal_vec_air: IsEqualVecAir, }, + Gte { + /// The bus index + bus_index: usize, + /// The length of each index in the page table + idx_len: usize, + /// The length of each data entry in the page table + data_len: usize, + + is_less_than_tuple_air: IsLessThanTupleAir, + is_equal_vec_air: IsEqualVecAir, + }, Gt { /// The bus index bus_index: usize, @@ -129,6 +141,22 @@ impl PageIndexScanInputChip { range_checker, cmp, }, + Comp::Gte => Self { + air: PageIndexScanInputAir::Gte { + bus_index, + idx_len, + data_len, + is_less_than_tuple_air: IsLessThanTupleAir::new( + bus_index, + range_max, + idx_limb_bits.clone(), + decomp, + ), + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }, + range_checker, + cmp, + }, Comp::Gt => Self { air: PageIndexScanInputAir::Gt { bus_index, @@ -158,6 +186,9 @@ impl PageIndexScanInputChip { PageIndexScanInputAir::Eq { idx_len, data_len, .. } => 1 + idx_len + data_len, + PageIndexScanInputAir::Gte { + idx_len, data_len, .. + } => 1 + idx_len + data_len, PageIndexScanInputAir::Gt { idx_len, data_len, .. } => 1 + idx_len + data_len, @@ -198,6 +229,23 @@ impl PageIndexScanInputChip { + 2 * idx_len } PageIndexScanInputAir::Eq { idx_len, .. } => idx_len + 1 + 1 + 2 * idx_len, + PageIndexScanInputAir::Gte { + idx_len, + is_less_than_tuple_air, + .. + } => { + idx_len + + 1 + + 1 + + 1 + + 1 + + IsLessThanTupleAuxCols::::get_width( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + *idx_len, + ) + + 2 * idx_len + } PageIndexScanInputAir::Gt { idx_len, is_less_than_tuple_air, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index ee095d230a..01f4c054d0 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -153,6 +153,54 @@ impl PageIndexScanInputChip { row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); } + PageIndexScanInputAir::Gte { + idx_len, + is_less_than_tuple_air, + is_equal_vec_air, + .. + } => { + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + *idx_len].to_vec(); + + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); + + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (x.clone(), idx.clone(), self.range_checker.clone()), + ) + .flatten(); + + let is_equal_vec_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_equal_vec_air, + ( + idx.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + x.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + ), + ) + .flatten(); + + row.push(is_less_than_tuple_trace[2 * *idx_len]); + row.push(is_equal_vec_trace[3 * *idx_len - 1]); + let satisfies_pred = is_less_than_tuple_trace[2 * *idx_len] + + is_equal_vec_trace[3 * *idx_len - 1]; + row.push(satisfies_pred); + row.push(satisfies_pred * is_alloc); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); + } PageIndexScanInputAir::Gt { idx_len, is_less_than_tuple_air, diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs index df68aa6177..0986ac6ab8 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -12,16 +12,6 @@ pub mod bridge; pub mod columns; pub mod trace; -#[derive(Default)] -pub enum Comp { - #[default] - Lt, - Lte, - Eq, - Gte, - Gt, -} - #[derive(Default, Getters)] pub struct PageIndexScanOutputAir { /// The bus index for sends to range chip diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index c3d1ae505d..b3e46e694c 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -359,6 +359,93 @@ fn test_single_page_index_scan_eq() { .expect("Verification failed"); } +#[test] +fn test_single_page_index_scan_gte() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + Comp::Gte, + ); + + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + + let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); + let output_page_aux_ptr = + keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); + let range_checker_ptr = + keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &page_controller.input_chip.air, + page_height, + idx_len, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &page_controller.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + + let partial_pk = keygen_builder.generate_partial_pk(); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + let page: Vec> = vec![ + vec![1, 2177, 5880, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + + let x: Vec = vec![2177, 5880]; + + let page_output = + page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gte); + + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ) + .expect("Verification failed"); +} + #[test] fn test_single_page_index_scan_gt() { let bus_index: usize = 0; @@ -421,7 +508,7 @@ fn test_single_page_index_scan_gt() { let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); let page: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2203, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; From a13547d6b306d039011c3a7862e49a6c6c2bb1c3 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 11 Jun 2024 17:41:27 -0400 Subject: [PATCH 37/46] chore: cleanup some branches --- .../page_index_scan_input/air.rs | 34 ++++--------- .../page_index_scan_input/mod.rs | 51 ++++++------------- 2 files changed, 26 insertions(+), 59 deletions(-) diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index f81145406f..09ac9873fa 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -23,14 +23,8 @@ impl BaseAir for PageIndexScanInputAir { data_len, is_less_than_tuple_air, .. - } => PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Lt, - ), - PageIndexScanInputAir::Lte { + } + | PageIndexScanInputAir::Gt { idx_len, data_len, is_less_than_tuple_air, @@ -40,25 +34,15 @@ impl BaseAir for PageIndexScanInputAir { *data_len, is_less_than_tuple_air.limb_bits().clone(), is_less_than_tuple_air.decomp(), - Comp::Lte, + Comp::Lt, ), - // there is no idx_limb_bits or decomp, so we supply an empty vec and 0, respectively - PageIndexScanInputAir::Eq { - idx_len, data_len, .. - } => PageIndexScanInputCols::::get_width(*idx_len, *data_len, vec![], 0, Comp::Eq), - PageIndexScanInputAir::Gte { + PageIndexScanInputAir::Lte { idx_len, data_len, is_less_than_tuple_air, .. - } => PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Gte, - ), - PageIndexScanInputAir::Gt { + } + | PageIndexScanInputAir::Gte { idx_len, data_len, is_less_than_tuple_air, @@ -68,8 +52,12 @@ impl BaseAir for PageIndexScanInputAir { *data_len, is_less_than_tuple_air.limb_bits().clone(), is_less_than_tuple_air.decomp(), - Comp::Gt, + Comp::Lte, ), + // there is no idx_limb_bits or decomp, so we supply an empty vec and 0, respectively + PageIndexScanInputAir::Eq { + idx_len, data_len, .. + } => PageIndexScanInputCols::::get_width(*idx_len, *data_len, vec![], 0, Comp::Eq), } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index a9b55084bf..e5ef562cca 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -179,17 +179,17 @@ impl PageIndexScanInputChip { match &self.air { PageIndexScanInputAir::Lt { idx_len, data_len, .. - } => 1 + idx_len + data_len, - PageIndexScanInputAir::Lte { + } + | PageIndexScanInputAir::Lte { idx_len, data_len, .. - } => 1 + idx_len + data_len, - PageIndexScanInputAir::Eq { + } + | PageIndexScanInputAir::Eq { idx_len, data_len, .. - } => 1 + idx_len + data_len, - PageIndexScanInputAir::Gte { + } + | PageIndexScanInputAir::Gte { idx_len, data_len, .. - } => 1 + idx_len + data_len, - PageIndexScanInputAir::Gt { + } + | PageIndexScanInputAir::Gt { idx_len, data_len, .. } => 1 + idx_len + data_len, } @@ -201,24 +201,13 @@ impl PageIndexScanInputChip { idx_len, is_less_than_tuple_air, .. - } => { - idx_len - + 1 - + 1 - + IsLessThanTupleAuxCols::::get_width( - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - *idx_len, - ) } - PageIndexScanInputAir::Lte { + | PageIndexScanInputAir::Gt { idx_len, is_less_than_tuple_air, .. } => { idx_len - + 1 - + 1 + 1 + 1 + IsLessThanTupleAuxCols::::get_width( @@ -226,32 +215,20 @@ impl PageIndexScanInputChip { is_less_than_tuple_air.decomp(), *idx_len, ) - + 2 * idx_len } - PageIndexScanInputAir::Eq { idx_len, .. } => idx_len + 1 + 1 + 2 * idx_len, - PageIndexScanInputAir::Gte { + PageIndexScanInputAir::Lte { idx_len, is_less_than_tuple_air, .. - } => { - idx_len - + 1 - + 1 - + 1 - + 1 - + IsLessThanTupleAuxCols::::get_width( - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - *idx_len, - ) - + 2 * idx_len } - PageIndexScanInputAir::Gt { + | PageIndexScanInputAir::Gte { idx_len, is_less_than_tuple_air, .. } => { idx_len + + 1 + + 1 + 1 + 1 + IsLessThanTupleAuxCols::::get_width( @@ -259,7 +236,9 @@ impl PageIndexScanInputChip { is_less_than_tuple_air.decomp(), *idx_len, ) + + 2 * idx_len } + PageIndexScanInputAir::Eq { idx_len, .. } => idx_len + 1 + 1 + 2 * idx_len, } } From 50e150be2ebb999ec5f4126cfaa0c3990dae4952 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 12 Jun 2024 15:37:10 -0400 Subject: [PATCH 38/46] chore: refactor code to reduce repetition --- chips/src/is_equal_vec/columns.rs | 12 + .../page_controller/mod.rs | 13 +- .../page_index_scan_input/air.rs | 719 +++++++----------- .../page_index_scan_input/bridge.rs | 585 ++++---------- .../page_index_scan_input/columns.rs | 201 +++-- .../page_index_scan_input/mod.rs | 240 +++--- .../page_index_scan_input/trace.rs | 154 ++-- 7 files changed, 652 insertions(+), 1272 deletions(-) diff --git a/chips/src/is_equal_vec/columns.rs b/chips/src/is_equal_vec/columns.rs index 314a2272f6..9875d44317 100644 --- a/chips/src/is_equal_vec/columns.rs +++ b/chips/src/is_equal_vec/columns.rs @@ -11,6 +11,18 @@ pub struct IsEqualVecAuxCols { pub invs: Vec, } +impl IsEqualVecAuxCols { + pub fn flatten(&self) -> Vec { + self.prods.iter().chain(self.invs.iter()).cloned().collect() + } + + pub fn from_slice(slc: &[T], vec_len: usize) -> Self { + let prods = slc[0..vec_len].to_vec(); + let invs = slc[vec_len..2 * vec_len].to_vec(); + Self { prods, invs } + } +} + #[derive(Default)] pub struct IsEqualVecCols { pub io: IsEqualVecIOCols, diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index c88088d81e..9679b0584d 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -8,10 +8,7 @@ use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_matrix::dense::DenseMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::{ - range_gate::RangeCheckerGateChip, - single_page_index_scan::page_index_scan_input::PageIndexScanInputAir, -}; +use crate::range_gate::RangeCheckerGateChip; use super::{ page_index_scan_input::{Comp, PageIndexScanInputChip}, @@ -292,13 +289,7 @@ where assert!(!page_input.is_empty()); - let bus_index = match self.input_chip.air { - PageIndexScanInputAir::Lt { bus_index, .. } => bus_index, - PageIndexScanInputAir::Lte { bus_index, .. } => bus_index, - PageIndexScanInputAir::Eq { bus_index, .. } => bus_index, - PageIndexScanInputAir::Gte { bus_index, .. } => bus_index, - PageIndexScanInputAir::Gt { bus_index, .. } => bus_index, - }; + let bus_index = self.input_chip.air.bus_index; self.input_chip = PageIndexScanInputChip::new( bus_index, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 09ac9873fa..50c963e494 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -4,12 +4,21 @@ use p3_field::Field; use p3_matrix::Matrix; use crate::{ - is_equal_vec::columns::{IsEqualVecCols, IsEqualVecIOCols}, - is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + is_equal_vec::columns::{IsEqualVecAuxCols, IsEqualVecCols, IsEqualVecIOCols}, + is_less_than_tuple::columns::{ + IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, + }, sub_chip::{AirConfig, SubAir}, }; -use super::{columns::PageIndexScanInputCols, Comp, PageIndexScanInputAir}; +use super::{ + columns::{ + EqCompAuxCols, NonStrictCompAuxCols, PageIndexScanInputAuxCols, PageIndexScanInputCols, + StrictCompAuxCols, + }, + Comp, EqCompAir, NonStrictCompAir, PageIndexScanInputAir, PageIndexScanInputAirVariants, + StrictCompAir, +}; impl AirConfig for PageIndexScanInputAir { type Cols = PageIndexScanInputCols; @@ -17,47 +26,44 @@ impl AirConfig for PageIndexScanInputAir { impl BaseAir for PageIndexScanInputAir { fn width(&self) -> usize { - match &self { - PageIndexScanInputAir::Lt { - idx_len, - data_len, + match &self.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. - } - | PageIndexScanInputAir::Gt { - idx_len, - data_len, + }) + | PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, .. - } => PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, + }) => PageIndexScanInputCols::::get_width( + self.idx_len, + self.data_len, is_less_than_tuple_air.limb_bits().clone(), is_less_than_tuple_air.decomp(), Comp::Lt, ), - PageIndexScanInputAir::Lte { - idx_len, - data_len, + PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, .. - } - | PageIndexScanInputAir::Gte { - idx_len, - data_len, + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, .. - } => PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, + }) => PageIndexScanInputCols::::get_width( + self.idx_len, + self.data_len, is_less_than_tuple_air.limb_bits().clone(), is_less_than_tuple_air.decomp(), Comp::Lte, ), - // there is no idx_limb_bits or decomp, so we supply an empty vec and 0, respectively - PageIndexScanInputAir::Eq { - idx_len, data_len, .. - } => PageIndexScanInputCols::::get_width(*idx_len, *data_len, vec![], 0, Comp::Eq), + PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => { + PageIndexScanInputCols::::get_width( + self.idx_len, + self.data_len, + vec![], + 0, + Comp::Eq, + ) + } } } } @@ -67,158 +73,162 @@ where AB::M: Clone, { fn eval(&self, builder: &mut AB) { - match &self { - PageIndexScanInputAir::Lt { - idx_len, - data_len, + let page_main = &builder.partitioned_main()[0].clone(); + let aux_main = &builder.partitioned_main()[1].clone(); + + // get the public value x + let pis = builder.public_values(); + let public_x = pis[..self.idx_len].to_vec(); + + let local_page = page_main.row_slice(0); + let local_aux = aux_main.row_slice(0); + let local_vec = local_page + .iter() + .chain(local_aux.iter()) + .cloned() + .collect::>(); + let local = local_vec.as_slice(); + + let (idx_limb_bits, decomp) = match &self.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. - } => { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the public value x - let pis = builder.public_values(); - let public_x = pis[..*idx_len].to_vec(); - - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - - let local_cols = PageIndexScanInputCols::::from_slice( - local, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Lt, + }) + | PageIndexScanInputAirVariants::Gt(StrictCompAir { + is_less_than_tuple_air, + .. + }) + | PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_less_than_tuple_air, + .. + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air, + .. + }) => ( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + ), + PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => (vec![], 0), + }; + + let cmp = match &self.subair { + PageIndexScanInputAirVariants::Lt(..) => Comp::Lt, + PageIndexScanInputAirVariants::Gt(..) => Comp::Gt, + PageIndexScanInputAirVariants::Lte(..) => Comp::Lte, + PageIndexScanInputAirVariants::Gte(..) => Comp::Gte, + PageIndexScanInputAirVariants::Eq(..) => Comp::Eq, + }; + + let local_cols = PageIndexScanInputCols::::from_slice( + local, + self.idx_len, + self.data_len, + idx_limb_bits.clone(), + decomp, + cmp, + ); + + // constrain that the public value x is the same as the column x + for (&local_x, &pub_x) in local_cols.x.iter().zip(public_x.iter()) { + builder.assert_eq(local_x, pub_x); + } + // constrain that we send the row iff the row is allocated and satisfies the predicate + builder.assert_eq( + local_cols.page_cols.is_alloc * local_cols.satisfies_pred, + local_cols.send_row, + ); + // constrain that satisfies_pred and send_row are boolean indicators + builder.assert_bool(local_cols.satisfies_pred); + builder.assert_bool(local_cols.send_row); + + let is_less_than_tuple_aux_flattened = match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Gt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) => is_less_than_tuple_aux.flatten(), + PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => vec![], + }; + + let is_equal_vec_aux_flattened = match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Eq(EqCompAuxCols { + is_equal_vec_aux, .. + }) + | PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + is_equal_vec_aux, .. + }) + | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + is_equal_vec_aux, .. + }) => is_equal_vec_aux.flatten(), + _ => vec![], + }; + + match &self.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { + is_less_than_tuple_air, + .. + }) => { + // here, we are checking if idx < x + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + tuple_less_than: local_cols.satisfies_pred, + }, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits.clone(), + decomp, + self.idx_len, + ), + }; + + // constrain the indicator that we used to check whether key < x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, ); - - match local_cols { - PageIndexScanInputCols::Lt { - is_alloc, - idx, - x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - .. - } => { - // here, we are checking if idx < x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: idx, - y: x.clone(), - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { - builder.assert_eq(local_x, pub_x); - } - - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq(is_alloc * satisfies_pred, send_row); - builder.assert_bool(send_row); - - // constrain the indicator that we used to check wheter key < x is correct - SubAir::eval( - is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" - ); - } - } } - PageIndexScanInputAir::Lte { - idx_len, - data_len, + PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, is_equal_vec_air, - .. - } => { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the public value x - let pis = builder.public_values(); - let public_x = pis[..*idx_len].to_vec(); - - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - - let local_cols = PageIndexScanInputCols::::from_slice( - local, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Lte, - ); - - match local_cols { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { - is_alloc, - idx, - x, - less_than_x, - eq_to_x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - is_equal_vec_aux, + }) => { + match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + satisfies_strict, + satisfies_eq, .. - } => { - // here, we are checking if idx <= x + }) => { + // here, we are checking if idx < x let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { - x: idx.clone(), - y: x.clone(), - tuple_less_than: less_than_x, + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + tuple_less_than: *satisfies_strict, }, - aux: is_less_than_tuple_aux, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits, + decomp, + self.idx_len, + ), }; - // constrain the indicator that we used to check wheter key < x is correct + // constrain the indicator that we used to check whether idx < x is correct SubAir::eval( is_less_than_tuple_air, &mut builder.when_transition(), @@ -229,14 +239,17 @@ where // here, we are checking if idx = x let is_equal_vec_cols = IsEqualVecCols { io: IsEqualVecIOCols { - x: idx.clone(), - y: x.clone(), - prod: eq_to_x, + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + prod: *satisfies_eq, }, - aux: is_equal_vec_aux, + aux: IsEqualVecAuxCols::from_slice( + &is_equal_vec_aux_flattened, + self.idx_len, + ), }; - // constrain the indicator that we used to check wheter key = x is correct + // constrain the indicator that we used to check whether idx = x is correct SubAir::eval( is_equal_vec_air, builder, @@ -244,197 +257,60 @@ where is_equal_vec_cols.aux, ); - // constrain that it satisfies predicate if either less than or equal, and that satisfies is bool - builder.assert_eq(less_than_x + eq_to_x, satisfies_pred); - builder.assert_bool(satisfies_pred); - - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { - builder.assert_eq(local_x, pub_x); - } - - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq(is_alloc * satisfies_pred, send_row); - builder.assert_bool(send_row); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" + // constrain that satisfies_pred indicates whether idx <= x + builder.assert_eq( + *satisfies_strict + *satisfies_eq, + local_cols.satisfies_pred, ); } + _ => panic!("Unexpected aux cols"), } } - PageIndexScanInputAir::Eq { - idx_len, - data_len, - is_equal_vec_air, - .. - } => { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the public value x - let pis = builder.public_values(); - let public_x = pis[..*idx_len].to_vec(); - - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - - let local_cols = PageIndexScanInputCols::::from_slice( - local, - *idx_len, - *data_len, - vec![], - 0, - Comp::Eq, + PageIndexScanInputAirVariants::Eq(EqCompAir { is_equal_vec_air }) => { + // here, we are checking if idx = x + let is_equal_vec_cols = IsEqualVecCols { + io: IsEqualVecIOCols { + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + prod: local_cols.satisfies_pred, + }, + aux: IsEqualVecAuxCols::from_slice(&is_equal_vec_aux_flattened, self.idx_len), + }; + + // constrain the indicator that we used to check whether idx = x is correct + SubAir::eval( + is_equal_vec_air, + builder, + is_equal_vec_cols.io, + is_equal_vec_cols.aux, ); - - match local_cols { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { - is_alloc, - idx, - x, - satisfies_pred, - send_row, - is_equal_vec_aux, - .. - } => { - // here, we are checking if idx = x - let is_equal_vec_cols = IsEqualVecCols { - io: IsEqualVecIOCols { - x: idx, - y: x.clone(), - prod: satisfies_pred, - }, - aux: is_equal_vec_aux, - }; - - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { - builder.assert_eq(local_x, pub_x); - } - - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq(is_alloc * satisfies_pred, send_row); - builder.assert_bool(send_row); - - // constrain the indicator that we used to check wheter key = x is correct - SubAir::eval( - is_equal_vec_air, - builder, - is_equal_vec_cols.io, - is_equal_vec_cols.aux, - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" - ); - } - } } - PageIndexScanInputAir::Gte { - idx_len, - data_len, + PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, is_equal_vec_air, - .. - } => { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the public value x - let pis = builder.public_values(); - let public_x = pis[..*idx_len].to_vec(); - - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - - let local_cols = PageIndexScanInputCols::::from_slice( - local, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Gte, - ); - - match local_cols { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { - is_alloc, - idx, - x, - greater_than_x, - eq_to_x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - is_equal_vec_aux, + }) => { + match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + satisfies_strict, + satisfies_eq, .. - } => { - // here, we are checking if idx <= x + }) => { + // here, we are checking if idx > x let is_less_than_tuple_cols = IsLessThanTupleCols { io: IsLessThanTupleIOCols { - x: x.clone(), - y: idx.clone(), - tuple_less_than: greater_than_x, + x: local_cols.x.clone(), + y: local_cols.page_cols.idx.clone(), + tuple_less_than: *satisfies_strict, }, - aux: is_less_than_tuple_aux, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits, + decomp, + self.idx_len, + ), }; - // constrain the indicator that we used to check wheter key < x is correct + // constrain the indicator that we used to check whether idx > x is correct SubAir::eval( is_less_than_tuple_air, &mut builder.when_transition(), @@ -445,14 +321,17 @@ where // here, we are checking if idx = x let is_equal_vec_cols = IsEqualVecCols { io: IsEqualVecIOCols { - x: idx.clone(), - y: x.clone(), - prod: eq_to_x, + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + prod: *satisfies_eq, }, - aux: is_equal_vec_aux, + aux: IsEqualVecAuxCols::from_slice( + &is_equal_vec_aux_flattened, + self.idx_len, + ), }; - // constrain the indicator that we used to check wheter key = x is correct + // constrain the indicator that we used to check whether idx = x is correct SubAir::eval( is_equal_vec_air, builder, @@ -460,115 +339,41 @@ where is_equal_vec_cols.aux, ); - // constrain that it satisfies predicate if either less than or equal, and that satisfies is bool - builder.assert_eq(greater_than_x + eq_to_x, satisfies_pred); - builder.assert_bool(satisfies_pred); - - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { - builder.assert_eq(local_x, pub_x); - } - - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq(is_alloc * satisfies_pred, send_row); - builder.assert_bool(send_row); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Gt" + builder.assert_eq( + *satisfies_strict + *satisfies_eq, + local_cols.satisfies_pred, ); + builder.assert_bool(local_cols.satisfies_pred); } + _ => panic!("Unexpected aux cols"), } } - PageIndexScanInputAir::Gt { - idx_len, - data_len, + PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, .. - } => { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the public value x - let pis = builder.public_values(); - let public_x = pis[..*idx_len].to_vec(); - - let local_page = page_main.row_slice(0); - let local_aux = aux_main.row_slice(0); - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - - let local_cols = PageIndexScanInputCols::::from_slice( - local, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits().clone(), - is_less_than_tuple_air.decomp(), - Comp::Gt, + }) => { + // here, we are checking if idx > x + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: local_cols.x.clone(), + y: local_cols.page_cols.idx.clone(), + tuple_less_than: local_cols.satisfies_pred, + }, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits, + decomp, + self.idx_len, + ), + }; + + // constrain the indicator that we used to check whether idx > x is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, ); - - match local_cols { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { - is_alloc, - idx, - x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - .. - } => { - // here, we are checking if idx > x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: x.clone(), - y: idx, - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // constrain that the public value x is the same as the column x - for (&local_x, &pub_x) in x.iter().zip(public_x.iter()) { - builder.assert_eq(local_x, pub_x); - } - - // constrain that we send the row iff the row is allocated and satisfies the predicate - builder.assert_eq(is_alloc * satisfies_pred, send_row); - builder.assert_bool(send_row); - - // constrain the indicator that we used to check wheter key < x is correct - SubAir::eval( - is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); - } - } } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 7d91976829..f1e87a02df 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -1,9 +1,17 @@ use crate::{ - is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, + is_less_than_tuple::columns::{ + IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, + }, sub_chip::SubAirBridge, }; -use super::{columns::PageIndexScanInputCols, Comp}; +use super::{ + columns::{ + EqCompAuxCols, NonStrictCompAuxCols, PageIndexScanInputAuxCols, PageIndexScanInputCols, + StrictCompAuxCols, + }, + Comp, EqCompAir, NonStrictCompAir, PageIndexScanInputAirVariants, StrictCompAir, +}; use afs_stark_backend::interaction::{AirBridge, Interaction}; use p3_air::VirtualPairCol; use p3_field::PrimeField64; @@ -14,454 +22,151 @@ impl AirBridge for PageIndexScanInputAir { fn sends(&self) -> Vec> { let mut interactions: Vec> = vec![]; - match &self { - PageIndexScanInputAir::Lt { - bus_index, - idx_len, - data_len, + let (idx_limb_bits, decomp) = match &self.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, - } => { - let num_cols = PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Lt, - ); - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Lt, - ); - - match cols_numbered { - PageIndexScanInputCols::Lt { - is_alloc, - idx, - data, - x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - .. - } => { - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: idx.clone(), - y: x.clone(), - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // construct the row to send - let mut cols = vec![]; - cols.push(is_alloc); - cols.extend(idx); - cols.extend(data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - interactions.push(Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(send_row), - argument_index: *bus_index, - }); - - let mut subchip_interactions = SubAirBridge::::sends( - is_less_than_tuple_air, - is_less_than_tuple_cols, - ); - - interactions.append(&mut subchip_interactions); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lt, got PageIndexScanInputCols::Gt" - ); - } - } - - interactions - } - PageIndexScanInputAir::Lte { - bus_index, - idx_len, - data_len, + .. + }) + | PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, .. - } => { - let num_cols = PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Lte, - ); - - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Lte, - ); - - match cols_numbered { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { - is_alloc, - idx, - data, - x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - .. - } => { - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: idx.clone(), - y: x.clone(), - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // construct the row to send - let mut cols = vec![]; - cols.push(is_alloc); - cols.extend(idx); - cols.extend(data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - interactions.push(Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(send_row), - argument_index: *bus_index, - }); - - let mut subchip_interactions = SubAirBridge::::sends( - is_less_than_tuple_air, - is_less_than_tuple_cols, - ); - - interactions.append(&mut subchip_interactions); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Lte, got PageIndexScanInputCols::Gt" - ); - } - } - - interactions - } - PageIndexScanInputAir::Eq { - bus_index, - idx_len, - data_len, + }) + | PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_less_than_tuple_air, .. - } => { - // There is no limb_bits or decomp for IsEqualVec, so we can just pass in an empty vec and 0, respectively - let num_cols = PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - vec![], - 0, - Comp::Eq, - ); - - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - *idx_len, - *data_len, - vec![], - 0, - Comp::Eq, - ); - - match cols_numbered { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { - is_alloc, - idx, - data, - send_row, - .. - } => { - // construct the row to send - let mut cols = vec![]; - cols.push(is_alloc); - cols.extend(idx); - cols.extend(data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - interactions.push(Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(send_row), - argument_index: *bus_index, - }); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Eq, got PageIndexScanInputCols::Gt" - ); - } - } - - interactions - } - PageIndexScanInputAir::Gte { - bus_index, - idx_len, - data_len, + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, .. - } => { - let num_cols = PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Gte, - ); - - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Gte, - ); - - match cols_numbered { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { - is_alloc, - idx, - data, - x, - satisfies_pred, - send_row, - is_less_than_tuple_aux, - .. - } => { - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: x.clone(), - y: idx.clone(), - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // construct the row to send - let mut cols = vec![]; - cols.push(is_alloc); - cols.extend(idx); - cols.extend(data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - interactions.push(Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(send_row), - argument_index: *bus_index, - }); - - let mut subchip_interactions = SubAirBridge::::sends( - is_less_than_tuple_air, - is_less_than_tuple_cols, - ); - - interactions.append(&mut subchip_interactions); - } - PageIndexScanInputCols::Gt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gte, got PageIndexScanInputCols::Gt" - ); - } - } + }) => ( + is_less_than_tuple_air.limb_bits(), + is_less_than_tuple_air.decomp(), + ), + PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => (vec![], 0), + }; + + let cmp = match &self.subair { + PageIndexScanInputAirVariants::Lt(..) => Comp::Lt, + PageIndexScanInputAirVariants::Gt(..) => Comp::Gt, + PageIndexScanInputAirVariants::Lte(..) => Comp::Lte, + PageIndexScanInputAirVariants::Gte(..) => Comp::Gte, + PageIndexScanInputAirVariants::Eq(..) => Comp::Eq, + }; + + let num_cols = PageIndexScanInputCols::::get_width( + self.idx_len, + self.data_len, + idx_limb_bits.clone(), + decomp, + cmp.clone(), + ); + let all_cols = (0..num_cols).collect::>(); + + let cols_numbered = PageIndexScanInputCols::::from_slice( + &all_cols, + self.idx_len, + self.data_len, + idx_limb_bits.clone(), + decomp, + cmp.clone(), + ); + + // construct the row to send + let mut cols = vec![]; + cols.push(cols_numbered.page_cols.is_alloc); + cols.extend(cols_numbered.page_cols.idx.clone()); + cols.extend(cols_numbered.page_cols.data); + + let virtual_cols = cols + .iter() + .map(|col| VirtualPairCol::single_main(*col)) + .collect::>(); + + interactions.push(Interaction { + fields: virtual_cols, + count: VirtualPairCol::single_main(cols_numbered.send_row), + argument_index: self.bus_index, + }); + + let (is_less_than_tuple_aux_flattened, strict_comp_ind) = match cols_numbered.aux_cols { + PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Gt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) => ( + is_less_than_tuple_aux.flatten(), + cols_numbered.satisfies_pred, + ), + PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + satisfies_strict, + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + satisfies_strict, + is_less_than_tuple_aux, + .. + }) => (is_less_than_tuple_aux.flatten(), satisfies_strict), + PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => (vec![], 0), + }; - interactions + let mut subchip_interactions = match &self.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { + is_less_than_tuple_air, + .. + }) + | PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_less_than_tuple_air, + .. + }) => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: cols_numbered.page_cols.idx.clone(), + y: cols_numbered.x.clone(), + tuple_less_than: strict_comp_ind, + }, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits, + decomp, + self.idx_len, + ), + }; + + SubAirBridge::::sends(is_less_than_tuple_air, is_less_than_tuple_cols) } - PageIndexScanInputAir::Gt { - bus_index, - idx_len, - data_len, + PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, - } => { - let num_cols = PageIndexScanInputCols::::get_width( - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Gt, - ); - let all_cols = (0..num_cols).collect::>(); - - let cols_numbered = PageIndexScanInputCols::::from_slice( - &all_cols, - *idx_len, - *data_len, - is_less_than_tuple_air.limb_bits(), - is_less_than_tuple_air.decomp(), - Comp::Gt, - ); - - match cols_numbered { - PageIndexScanInputCols::Lt { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lt" - ); - } - PageIndexScanInputCols::Lte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Lte" - ); - } - PageIndexScanInputCols::Eq { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Eq" - ); - } - PageIndexScanInputCols::Gte { .. } => { - panic!( - "expected PageIndexScanInputCols::Gt, got PageIndexScanInputCols::Gte" - ); - } - PageIndexScanInputCols::Gt { - is_alloc, - idx, - data, - x, - send_row, - satisfies_pred, - is_less_than_tuple_aux, - .. - } => { - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: x.clone(), - y: idx.clone(), - tuple_less_than: satisfies_pred, - }, - aux: is_less_than_tuple_aux, - }; - - // construct the row to send - let mut cols = vec![]; - cols.push(is_alloc); - cols.extend(idx); - cols.extend(data); - - let virtual_cols = cols - .iter() - .map(|col| VirtualPairCol::single_main(*col)) - .collect::>(); - - interactions.push(Interaction { - fields: virtual_cols, - count: VirtualPairCol::single_main(send_row), - argument_index: *bus_index, - }); - - let mut subchip_interactions = SubAirBridge::::sends( - is_less_than_tuple_air, - is_less_than_tuple_cols, - ); + .. + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air, + .. + }) => { + let is_less_than_tuple_cols = IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + x: cols_numbered.x.clone(), + y: cols_numbered.page_cols.idx.clone(), + tuple_less_than: strict_comp_ind, + }, + aux: IsLessThanTupleAuxCols::from_slice( + &is_less_than_tuple_aux_flattened, + idx_limb_bits, + decomp, + self.idx_len, + ), + }; + + SubAirBridge::::sends(is_less_than_tuple_air, is_less_than_tuple_cols) + } + PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => vec![], + }; - interactions.append(&mut subchip_interactions); - } - } + interactions.append(&mut subchip_interactions); - interactions - } - } + interactions } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 3f0dd696a1..2f532ee6a0 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -4,62 +4,51 @@ use crate::{ use super::Comp; -pub enum PageIndexScanInputCols { - Lt { - is_alloc: T, - idx: Vec, - data: Vec, - x: Vec, - satisfies_pred: T, - send_row: T, - is_less_than_tuple_aux: IsLessThanTupleAuxCols, - }, +pub struct PageCols { + pub is_alloc: T, // indicates if row is allocated + pub idx: Vec, + pub data: Vec, +} - Lte { - is_alloc: T, - idx: Vec, - data: Vec, - x: Vec, - less_than_x: T, - eq_to_x: T, - satisfies_pred: T, - send_row: T, - is_less_than_tuple_aux: IsLessThanTupleAuxCols, - is_equal_vec_aux: IsEqualVecAuxCols, - }, +impl PageCols { + pub fn from_slice(cols: &[T], idx_len: usize, data_len: usize) -> PageCols { + PageCols { + is_alloc: cols[0].clone(), + idx: cols[1..idx_len + 1].to_vec(), + data: cols[idx_len + 1..idx_len + data_len + 1].to_vec(), + } + } +} - Eq { - is_alloc: T, - idx: Vec, - data: Vec, - x: Vec, - satisfies_pred: T, - send_row: T, - is_equal_vec_aux: IsEqualVecAuxCols, - }, +pub struct StrictCompAuxCols { + pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, +} - Gte { - is_alloc: T, - idx: Vec, - data: Vec, - x: Vec, - greater_than_x: T, - eq_to_x: T, - satisfies_pred: T, - send_row: T, - is_less_than_tuple_aux: IsLessThanTupleAuxCols, - is_equal_vec_aux: IsEqualVecAuxCols, - }, +pub struct NonStrictCompAuxCols { + pub satisfies_strict: T, + pub satisfies_eq: T, + pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, + pub is_equal_vec_aux: IsEqualVecAuxCols, +} - Gt { - is_alloc: T, - idx: Vec, - data: Vec, - x: Vec, - satisfies_pred: T, - send_row: T, - is_less_than_tuple_aux: IsLessThanTupleAuxCols, - }, +pub struct EqCompAuxCols { + pub is_equal_vec_aux: IsEqualVecAuxCols, +} + +pub enum PageIndexScanInputAuxCols { + Lt(StrictCompAuxCols), + Lte(NonStrictCompAuxCols), + Eq(EqCompAuxCols), + Gte(NonStrictCompAuxCols), + Gt(StrictCompAuxCols), +} + +pub struct PageIndexScanInputCols { + pub page_cols: PageCols, + pub x: Vec, + pub satisfies_pred: T, + pub send_row: T, + pub aux_cols: PageIndexScanInputAuxCols, } impl PageIndexScanInputCols { @@ -71,33 +60,31 @@ impl PageIndexScanInputCols { decomp: usize, cmp: Comp, ) -> Self { - match cmp { - Comp::Lt => Self::Lt { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), - send_row: slc[2 * idx_len + data_len + 2].clone(), + let page_cols = PageCols { + is_alloc: slc[0].clone(), + idx: slc[1..idx_len + 1].to_vec(), + data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), + }; + + let x = slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(); + let satisfies_pred = slc[2 * idx_len + data_len + 1].clone(); + let send_row = slc[2 * idx_len + data_len + 2].clone(); + + let aux_cols = match cmp { + Comp::Lt => PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 3..], idx_limb_bits, decomp, idx_len, ), - }, + }), Comp::Lte => { let less_than_tuple_aux_width = IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); - Self::Lte { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - less_than_x: slc[2 * idx_len + data_len + 1].clone(), - eq_to_x: slc[2 * idx_len + data_len + 2].clone(), - satisfies_pred: slc[2 * idx_len + data_len + 3].clone(), - send_row: slc[2 * idx_len + data_len + 4].clone(), + PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + satisfies_strict: slc[2 * idx_len + data_len + 3].clone(), + satisfies_eq: slc[2 * idx_len + data_len + 4].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 5 ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], @@ -105,39 +92,24 @@ impl PageIndexScanInputCols { decomp, idx_len, ), - is_equal_vec_aux: IsEqualVecAuxCols { - prods: slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width - ..3 * idx_len + data_len + 5 + less_than_tuple_aux_width] - .to_vec(), - invs: slc[3 * idx_len + data_len + 5 + less_than_tuple_aux_width..] - .to_vec(), - }, - } + is_equal_vec_aux: IsEqualVecAuxCols::from_slice( + &slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width..], + idx_len, + ), + }) } - Comp::Eq => Self::Eq { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), - send_row: slc[2 * idx_len + data_len + 2].clone(), - is_equal_vec_aux: IsEqualVecAuxCols { - prods: slc[2 * idx_len + data_len + 3..3 * idx_len + data_len + 3].to_vec(), - invs: slc[3 * idx_len + data_len + 3..].to_vec(), - }, - }, + Comp::Eq => PageIndexScanInputAuxCols::Eq(EqCompAuxCols { + is_equal_vec_aux: IsEqualVecAuxCols::from_slice( + &slc[2 * idx_len + data_len + 3..], + idx_len, + ), + }), Comp::Gte => { let less_than_tuple_aux_width = IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); - Self::Gte { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - greater_than_x: slc[2 * idx_len + data_len + 1].clone(), - eq_to_x: slc[2 * idx_len + data_len + 2].clone(), - satisfies_pred: slc[2 * idx_len + data_len + 3].clone(), - send_row: slc[2 * idx_len + data_len + 4].clone(), + PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + satisfies_strict: slc[2 * idx_len + data_len + 3].clone(), + satisfies_eq: slc[2 * idx_len + data_len + 4].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 5 ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], @@ -145,29 +117,28 @@ impl PageIndexScanInputCols { decomp, idx_len, ), - is_equal_vec_aux: IsEqualVecAuxCols { - prods: slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width - ..3 * idx_len + data_len + 5 + less_than_tuple_aux_width] - .to_vec(), - invs: slc[3 * idx_len + data_len + 5 + less_than_tuple_aux_width..] - .to_vec(), - }, - } + is_equal_vec_aux: IsEqualVecAuxCols::from_slice( + &slc[2 * idx_len + data_len + 5 + less_than_tuple_aux_width..], + idx_len, + ), + }) } - Comp::Gt => Self::Gt { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - x: slc[idx_len + data_len + 1..2 * idx_len + data_len + 1].to_vec(), - satisfies_pred: slc[2 * idx_len + data_len + 1].clone(), - send_row: slc[2 * idx_len + data_len + 2].clone(), + Comp::Gt => PageIndexScanInputAuxCols::Gt(StrictCompAuxCols { is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 3..], idx_limb_bits, decomp, idx_len, ), - }, + }), + }; + + Self { + page_cols, + x, + satisfies_pred, + send_row, + aux_cols, } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index e5ef562cca..8d3b3d5463 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -21,59 +21,33 @@ pub enum Comp { Gt, } -pub enum PageIndexScanInputAir { - Lt { - /// The bus index - bus_index: usize, - /// The length of each index in the page table - idx_len: usize, - /// The length of each data entry in the page table - data_len: usize, +pub struct StrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir, +} - is_less_than_tuple_air: IsLessThanTupleAir, - }, - Lte { - /// The bus index - bus_index: usize, - /// The length of each index in the page table - idx_len: usize, - /// The length of each data entry in the page table - data_len: usize, +pub struct NonStrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir, + is_equal_vec_air: IsEqualVecAir, +} - is_less_than_tuple_air: IsLessThanTupleAir, - is_equal_vec_air: IsEqualVecAir, - }, - Eq { - /// The bus index - bus_index: usize, - /// The length of each index in the page table - idx_len: usize, - /// The length of each data entry in the page table - data_len: usize, +pub struct EqCompAir { + is_equal_vec_air: IsEqualVecAir, +} - is_equal_vec_air: IsEqualVecAir, - }, - Gte { - /// The bus index - bus_index: usize, - /// The length of each index in the page table - idx_len: usize, - /// The length of each data entry in the page table - data_len: usize, +pub enum PageIndexScanInputAirVariants { + Lt(StrictCompAir), + Lte(NonStrictCompAir), + Eq(EqCompAir), + Gte(NonStrictCompAir), + Gt(StrictCompAir), +} - is_less_than_tuple_air: IsLessThanTupleAir, - is_equal_vec_air: IsEqualVecAir, - }, - Gt { - /// The bus index - bus_index: usize, - /// The length of each index in the page table - idx_len: usize, - /// The length of each data entry in the page table - data_len: usize, +pub struct PageIndexScanInputAir { + pub bus_index: usize, + pub idx_len: usize, + pub data_len: usize, - is_less_than_tuple_air: IsLessThanTupleAir, - }, + subair: PageIndexScanInputAirVariants, } /// Given a fixed predicate of the form index OP x, where OP is one of {<, <=, =, >=, >} @@ -99,134 +73,92 @@ impl PageIndexScanInputChip { range_checker: Arc, cmp: Comp, ) -> Self { - match cmp { - Comp::Lt => Self { - air: PageIndexScanInputAir::Lt { + let subair = match cmp { + Comp::Lt => PageIndexScanInputAirVariants::Lt(StrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - range_max, - idx_limb_bits.clone(), - decomp, - ), - }, - range_checker, - cmp, - }, - Comp::Lte => Self { - air: PageIndexScanInputAir::Lte { + range_max, + idx_limb_bits.clone(), + decomp, + ), + }), + Comp::Lte => PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - range_max, - idx_limb_bits.clone(), - decomp, - ), - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }, - range_checker, - cmp, - }, - Comp::Eq => Self { - air: PageIndexScanInputAir::Eq { + range_max, + idx_limb_bits.clone(), + decomp, + ), + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }), + Comp::Eq => PageIndexScanInputAirVariants::Eq(EqCompAir { + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }), + Comp::Gte => PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - idx_len, - data_len, - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }, - range_checker, - cmp, - }, - Comp::Gte => Self { - air: PageIndexScanInputAir::Gte { + range_max, + idx_limb_bits.clone(), + decomp, + ), + is_equal_vec_air: IsEqualVecAir::new(idx_len), + }), + Comp::Gt => PageIndexScanInputAirVariants::Gt(StrictCompAir { + is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - range_max, - idx_limb_bits.clone(), - decomp, - ), - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }, - range_checker, - cmp, - }, - Comp::Gt => Self { - air: PageIndexScanInputAir::Gt { - bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - range_max, - idx_limb_bits.clone(), - decomp, - ), - }, - range_checker, - cmp, - }, + range_max, + idx_limb_bits.clone(), + decomp, + ), + }), + }; + + let air = PageIndexScanInputAir { + bus_index, + idx_len, + data_len, + subair, + }; + + Self { + air, + range_checker, + cmp, } } pub fn page_width(&self) -> usize { - match &self.air { - PageIndexScanInputAir::Lt { - idx_len, data_len, .. - } - | PageIndexScanInputAir::Lte { - idx_len, data_len, .. - } - | PageIndexScanInputAir::Eq { - idx_len, data_len, .. - } - | PageIndexScanInputAir::Gte { - idx_len, data_len, .. - } - | PageIndexScanInputAir::Gt { - idx_len, data_len, .. - } => 1 + idx_len + data_len, - } + 1 + self.air.idx_len + self.air.data_len } pub fn aux_width(&self) -> usize { - match &self.air { - PageIndexScanInputAir::Lt { - idx_len, + match &self.air.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. - } - | PageIndexScanInputAir::Gt { - idx_len, + }) + | PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, .. - } => { - idx_len + }) => { + self.air.idx_len + 1 + 1 + IsLessThanTupleAuxCols::::get_width( is_less_than_tuple_air.limb_bits(), is_less_than_tuple_air.decomp(), - *idx_len, + self.air.idx_len, ) } - PageIndexScanInputAir::Lte { - idx_len, + PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, .. - } - | PageIndexScanInputAir::Gte { - idx_len, + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, .. - } => { - idx_len + }) => { + self.air.idx_len + 1 + 1 + 1 @@ -234,11 +166,13 @@ impl PageIndexScanInputChip { + IsLessThanTupleAuxCols::::get_width( is_less_than_tuple_air.limb_bits(), is_less_than_tuple_air.decomp(), - *idx_len, + self.air.idx_len, ) - + 2 * idx_len + + 2 * self.air.idx_len + } + PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => { + self.air.idx_len + 1 + 1 + 2 * self.air.idx_len } - PageIndexScanInputAir::Eq { idx_len, .. } => idx_len + 1 + 1 + 2 * idx_len, } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 01f4c054d0..42d74179db 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -4,7 +4,10 @@ use p3_uni_stark::{StarkGenericConfig, Val}; use crate::sub_chip::LocalTraceInstructions; -use super::{PageIndexScanInputAir, PageIndexScanInputChip}; +use super::{ + EqCompAir, NonStrictCompAir, PageIndexScanInputAirVariants, PageIndexScanInputChip, + StrictCompAir, +}; impl PageIndexScanInputChip { /// Generate the trace for the page table @@ -41,21 +44,20 @@ impl PageIndexScanInputChip { for page_row in &page { let mut row: Vec> = vec![]; - match &self.air { - PageIndexScanInputAir::Lt { - idx_len, - is_less_than_tuple_air, - .. - } => { - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + *idx_len].to_vec(); + let is_alloc = Val::::from_canonical_u32(page_row[0]); + let idx = page_row[1..1 + self.air.idx_len].to_vec(); - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); + let x_trace: Vec> = x + .iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(); + row.extend(x_trace); + match &self.air.subair { + PageIndexScanInputAirVariants::Lt(StrictCompAir { + is_less_than_tuple_air, + .. + }) => { let is_less_than_tuple_trace: Vec> = LocalTraceInstructions::generate_trace_row( is_less_than_tuple_air, @@ -63,27 +65,17 @@ impl PageIndexScanInputChip { ) .flatten(); - row.push(is_less_than_tuple_trace[2 * *idx_len]); - let send_row = is_less_than_tuple_trace[2 * *idx_len] * is_alloc; + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; row.push(send_row); - row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); } - PageIndexScanInputAir::Lte { - idx_len, + PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, is_equal_vec_air, .. - } => { - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + *idx_len].to_vec(); - - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); - + }) => { let is_less_than_tuple_trace: Vec> = LocalTraceInstructions::generate_trace_row( is_less_than_tuple_air, @@ -107,30 +99,20 @@ impl PageIndexScanInputChip { ) .flatten(); - row.push(is_less_than_tuple_trace[2 * *idx_len]); - row.push(is_equal_vec_trace[3 * *idx_len - 1]); - let satisfies_pred = is_less_than_tuple_trace[2 * *idx_len] - + is_equal_vec_trace[3 * *idx_len - 1]; + let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] + + is_equal_vec_trace[3 * self.air.idx_len - 1]; row.push(satisfies_pred); row.push(satisfies_pred * is_alloc); - row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); - row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); - } - PageIndexScanInputAir::Eq { - idx_len, - is_equal_vec_air, - .. - } => { - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + *idx_len].to_vec(); - - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); + } + PageIndexScanInputAirVariants::Eq(EqCompAir { + is_equal_vec_air, .. + }) => { let is_equal_vec_trace: Vec> = LocalTraceInstructions::generate_trace_row( is_equal_vec_air, @@ -147,27 +129,34 @@ impl PageIndexScanInputChip { ) .flatten(); - row.push(is_equal_vec_trace[3 * *idx_len - 1]); - let send_row = is_equal_vec_trace[3 * *idx_len - 1] * is_alloc; + row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); + let send_row = is_equal_vec_trace[3 * self.air.idx_len - 1] * is_alloc; row.push(send_row); - row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); + row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); } - PageIndexScanInputAir::Gte { - idx_len, + PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, - is_equal_vec_air, .. - } => { - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + *idx_len].to_vec(); + }) => { + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (x.clone(), idx.clone(), self.range_checker.clone()), + ) + .flatten(); - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; + row.push(send_row); + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + } + PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air, + is_equal_vec_air, + .. + }) => { let is_less_than_tuple_trace: Vec> = LocalTraceInstructions::generate_trace_row( is_less_than_tuple_air, @@ -191,43 +180,16 @@ impl PageIndexScanInputChip { ) .flatten(); - row.push(is_less_than_tuple_trace[2 * *idx_len]); - row.push(is_equal_vec_trace[3 * *idx_len - 1]); - let satisfies_pred = is_less_than_tuple_trace[2 * *idx_len] - + is_equal_vec_trace[3 * *idx_len - 1]; + let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] + + is_equal_vec_trace[3 * self.air.idx_len - 1]; row.push(satisfies_pred); row.push(satisfies_pred * is_alloc); - row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); - row.extend_from_slice(&is_equal_vec_trace[2 * *idx_len..]); - } - PageIndexScanInputAir::Gt { - idx_len, - is_less_than_tuple_air, - .. - } => { - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + *idx_len].to_vec(); - - let x_trace: Vec> = x - .iter() - .map(|x| Val::::from_canonical_u32(*x)) - .collect(); - row.extend(x_trace); - - // we want to check if idx > x - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (x.clone(), idx.clone(), self.range_checker.clone()), - ) - .flatten(); - - row.push(is_less_than_tuple_trace[2 * *idx_len]); - let send_row = is_less_than_tuple_trace[2 * *idx_len] * is_alloc; - row.push(send_row); + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); - row.extend_from_slice(&is_less_than_tuple_trace[2 * *idx_len + 1..]); + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); } } From 99d1f9847aa7208d5fadfcb80d1cfe0b4afbb76f Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:06:05 -0400 Subject: [PATCH 39/46] chore: cleanup PageIndexScanInputChip --- chips/src/is_equal_vec/columns.rs | 23 ++++++++++ .../page_index_scan_input/air.rs | 35 +++++++++------- .../page_index_scan_input/bridge.rs | 16 ++++--- .../page_index_scan_input/columns.rs | 21 ++++++---- .../page_index_scan_input/mod.rs | 18 ++++---- .../page_index_scan_input/trace.rs | 42 +++++++++++-------- 6 files changed, 99 insertions(+), 56 deletions(-) diff --git a/chips/src/is_equal_vec/columns.rs b/chips/src/is_equal_vec/columns.rs index 9875d44317..98dc16fab9 100644 --- a/chips/src/is_equal_vec/columns.rs +++ b/chips/src/is_equal_vec/columns.rs @@ -5,6 +5,25 @@ pub struct IsEqualVecIOCols { pub prod: T, } +impl IsEqualVecIOCols { + pub fn flatten(&self) -> Vec { + let mut res: Vec = self.x.iter().chain(self.y.iter()).cloned().collect(); + res.push(self.prod.clone()); + res + } + + pub fn from_slice(slc: &[T], vec_len: usize) -> Self { + let x = slc[0..vec_len].to_vec(); + let y = slc[vec_len..2 * vec_len].to_vec(); + let prod = slc[2 * vec_len].clone(); + Self { x, y, prod } + } + + pub fn get_width(vec_len: usize) -> usize { + vec_len + vec_len + 1 + } +} + #[derive(Default)] pub struct IsEqualVecAuxCols { pub prods: Vec, @@ -21,6 +40,10 @@ impl IsEqualVecAuxCols { let invs = slc[vec_len..2 * vec_len].to_vec(); Self { prods, invs } } + + pub fn get_width(vec_len: usize) -> usize { + 2 * vec_len + } } #[derive(Default)] diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 50c963e494..8b55aafd86 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -26,7 +26,7 @@ impl AirConfig for PageIndexScanInputAir { impl BaseAir for PageIndexScanInputAir { fn width(&self) -> usize { - match &self.subair { + match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -56,6 +56,7 @@ impl BaseAir for PageIndexScanInputAir { Comp::Lte, ), PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => { + // since get_width doesn't use idx_limb_bits and decomp for when comparator is =, we can pass in dummy values PageIndexScanInputCols::::get_width( self.idx_len, self.data_len, @@ -89,7 +90,8 @@ where .collect::>(); let local = local_vec.as_slice(); - let (idx_limb_bits, decomp) = match &self.subair { + // when comparator is = we can use dummy values for idx_limb_bits and decomp + let (idx_limb_bits, decomp) = match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -112,7 +114,7 @@ where PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => (vec![], 0), }; - let cmp = match &self.subair { + let cmp = match &self.variant_air { PageIndexScanInputAirVariants::Lt(..) => Comp::Lt, PageIndexScanInputAirVariants::Gt(..) => Comp::Gt, PageIndexScanInputAirVariants::Lte(..) => Comp::Lte, @@ -142,6 +144,7 @@ where builder.assert_bool(local_cols.satisfies_pred); builder.assert_bool(local_cols.send_row); + // generate flattened aux columns for IsLessThanTuple and IsEqualVec let is_less_than_tuple_aux_flattened = match &local_cols.aux_cols { PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { is_less_than_tuple_aux, @@ -175,7 +178,7 @@ where _ => vec![], }; - match &self.subair { + match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -195,7 +198,7 @@ where ), }; - // constrain the indicator that we used to check whether key < x is correct + // constrain the indicator that we used to check whether idx < x is correct SubAir::eval( is_less_than_tuple_air, &mut builder.when_transition(), @@ -209,8 +212,8 @@ where }) => { match &local_cols.aux_cols { PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { - satisfies_strict, - satisfies_eq, + satisfies_strict_comp, + satisfies_eq_comp, .. }) => { // here, we are checking if idx < x @@ -218,7 +221,7 @@ where io: IsLessThanTupleIOCols { x: local_cols.page_cols.idx.clone(), y: local_cols.x.clone(), - tuple_less_than: *satisfies_strict, + tuple_less_than: *satisfies_strict_comp, }, aux: IsLessThanTupleAuxCols::from_slice( &is_less_than_tuple_aux_flattened, @@ -241,7 +244,7 @@ where io: IsEqualVecIOCols { x: local_cols.page_cols.idx.clone(), y: local_cols.x.clone(), - prod: *satisfies_eq, + prod: *satisfies_eq_comp, }, aux: IsEqualVecAuxCols::from_slice( &is_equal_vec_aux_flattened, @@ -259,7 +262,7 @@ where // constrain that satisfies_pred indicates whether idx <= x builder.assert_eq( - *satisfies_strict + *satisfies_eq, + *satisfies_strict_comp + *satisfies_eq_comp, local_cols.satisfies_pred, ); } @@ -291,8 +294,8 @@ where }) => { match &local_cols.aux_cols { PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { - satisfies_strict, - satisfies_eq, + satisfies_strict_comp, + satisfies_eq_comp, .. }) => { // here, we are checking if idx > x @@ -300,7 +303,7 @@ where io: IsLessThanTupleIOCols { x: local_cols.x.clone(), y: local_cols.page_cols.idx.clone(), - tuple_less_than: *satisfies_strict, + tuple_less_than: *satisfies_strict_comp, }, aux: IsLessThanTupleAuxCols::from_slice( &is_less_than_tuple_aux_flattened, @@ -323,7 +326,7 @@ where io: IsEqualVecIOCols { x: local_cols.page_cols.idx.clone(), y: local_cols.x.clone(), - prod: *satisfies_eq, + prod: *satisfies_eq_comp, }, aux: IsEqualVecAuxCols::from_slice( &is_equal_vec_aux_flattened, @@ -339,11 +342,11 @@ where is_equal_vec_cols.aux, ); + // constrain that satisfies_pred indicates whether idx >= x builder.assert_eq( - *satisfies_strict + *satisfies_eq, + *satisfies_strict_comp + *satisfies_eq_comp, local_cols.satisfies_pred, ); - builder.assert_bool(local_cols.satisfies_pred); } _ => panic!("Unexpected aux cols"), } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index f1e87a02df..8ef5718a25 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -22,7 +22,8 @@ impl AirBridge for PageIndexScanInputAir { fn sends(&self) -> Vec> { let mut interactions: Vec> = vec![]; - let (idx_limb_bits, decomp) = match &self.subair { + // when comparator is = we can use dummy values for idx_limb_bits and decomp + let (idx_limb_bits, decomp) = match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -45,7 +46,7 @@ impl AirBridge for PageIndexScanInputAir { PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => (vec![], 0), }; - let cmp = match &self.subair { + let cmp = match &self.variant_air { PageIndexScanInputAirVariants::Lt(..) => Comp::Lt, PageIndexScanInputAirVariants::Gt(..) => Comp::Gt, PageIndexScanInputAirVariants::Lte(..) => Comp::Lte, @@ -88,6 +89,8 @@ impl AirBridge for PageIndexScanInputAir { argument_index: self.bus_index, }); + // here, we generate the flattened aux columns for IsLessThanTuple, and get the indicator associated with the strict comparison + // when the comparator is =, we can just generate dummy values let (is_less_than_tuple_aux_flattened, strict_comp_ind) = match cols_numbered.aux_cols { PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { is_less_than_tuple_aux, @@ -101,19 +104,20 @@ impl AirBridge for PageIndexScanInputAir { cols_numbered.satisfies_pred, ), PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { - satisfies_strict, + satisfies_strict_comp, is_less_than_tuple_aux, .. }) | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { - satisfies_strict, + satisfies_strict_comp, is_less_than_tuple_aux, .. - }) => (is_less_than_tuple_aux.flatten(), satisfies_strict), + }) => (is_less_than_tuple_aux.flatten(), satisfies_strict_comp), PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => (vec![], 0), }; - let mut subchip_interactions = match &self.subair { + // get interactions from IsLessThanTuple subchip + let mut subchip_interactions = match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 2f532ee6a0..61785b7977 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,5 +1,6 @@ use crate::{ - is_equal_vec::columns::IsEqualVecAuxCols, is_less_than_tuple::columns::IsLessThanTupleAuxCols, + is_equal_vec::columns::{IsEqualVecAuxCols, IsEqualVecIOCols}, + is_less_than_tuple::columns::IsLessThanTupleAuxCols, }; use super::Comp; @@ -25,8 +26,8 @@ pub struct StrictCompAuxCols { } pub struct NonStrictCompAuxCols { - pub satisfies_strict: T, - pub satisfies_eq: T, + pub satisfies_strict_comp: T, + pub satisfies_eq_comp: T, pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, pub is_equal_vec_aux: IsEqualVecAuxCols, } @@ -83,8 +84,8 @@ impl PageIndexScanInputCols { let less_than_tuple_aux_width = IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { - satisfies_strict: slc[2 * idx_len + data_len + 3].clone(), - satisfies_eq: slc[2 * idx_len + data_len + 4].clone(), + satisfies_strict_comp: slc[2 * idx_len + data_len + 3].clone(), + satisfies_eq_comp: slc[2 * idx_len + data_len + 4].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 5 ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], @@ -108,8 +109,8 @@ impl PageIndexScanInputCols { let less_than_tuple_aux_width = IsLessThanTupleAuxCols::::get_width(idx_limb_bits.clone(), decomp, idx_len); PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { - satisfies_strict: slc[2 * idx_len + data_len + 3].clone(), - satisfies_eq: slc[2 * idx_len + data_len + 4].clone(), + satisfies_strict_comp: slc[2 * idx_len + data_len + 3].clone(), + satisfies_eq_comp: slc[2 * idx_len + data_len + 4].clone(), is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( &slc[2 * idx_len + data_len + 5 ..2 * idx_len + data_len + 5 + less_than_tuple_aux_width], @@ -167,9 +168,11 @@ impl PageIndexScanInputCols { + 1 + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) - + 2 * idx_len + + IsEqualVecIOCols::::get_width(idx_len) + } + Comp::Eq => { + 1 + idx_len + data_len + idx_len + 1 + 1 + IsEqualVecIOCols::::get_width(idx_len) } - Comp::Eq => 1 + idx_len + data_len + idx_len + 1 + 1 + 2 * idx_len, } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 8d3b3d5463..d4f3e3a681 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - is_equal_vec::IsEqualVecAir, + is_equal_vec::{columns::IsEqualVecIOCols, IsEqualVecAir}, is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, range_gate::RangeCheckerGateChip, }; @@ -47,7 +47,7 @@ pub struct PageIndexScanInputAir { pub idx_len: usize, pub data_len: usize, - subair: PageIndexScanInputAirVariants, + variant_air: PageIndexScanInputAirVariants, } /// Given a fixed predicate of the form index OP x, where OP is one of {<, <=, =, >=, >} @@ -73,7 +73,7 @@ impl PageIndexScanInputChip { range_checker: Arc, cmp: Comp, ) -> Self { - let subair = match cmp { + let variant_air = match cmp { Comp::Lt => PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, @@ -117,7 +117,7 @@ impl PageIndexScanInputChip { bus_index, idx_len, data_len, - subair, + variant_air, }; Self { @@ -132,7 +132,7 @@ impl PageIndexScanInputChip { } pub fn aux_width(&self) -> usize { - match &self.air.subair { + match &self.air.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -141,6 +141,7 @@ impl PageIndexScanInputChip { is_less_than_tuple_air, .. }) => { + // x, satisfies_pred, send_row, is_less_than_tuple_aux_cols self.air.idx_len + 1 + 1 @@ -158,6 +159,8 @@ impl PageIndexScanInputChip { is_less_than_tuple_air, .. }) => { + // x, satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, + // is_less_than_tuple_aux_cols, is_equal_vec_aux_cols self.air.idx_len + 1 + 1 @@ -168,10 +171,11 @@ impl PageIndexScanInputChip { is_less_than_tuple_air.decomp(), self.air.idx_len, ) - + 2 * self.air.idx_len + + IsEqualVecIOCols::::get_width(self.air.idx_len) } PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => { - self.air.idx_len + 1 + 1 + 2 * self.air.idx_len + // x, satisfies_pred, send_row, is_equal_vec_aux_cols + self.air.idx_len + 1 + 1 + IsEqualVecIOCols::::get_width(self.air.idx_len) } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 42d74179db..3817ab7e96 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -47,13 +47,14 @@ impl PageIndexScanInputChip { let is_alloc = Val::::from_canonical_u32(page_row[0]); let idx = page_row[1..1 + self.air.idx_len].to_vec(); + // first, get the values for x let x_trace: Vec> = x .iter() .map(|x| Val::::from_canonical_u32(*x)) .collect(); row.extend(x_trace); - match &self.air.subair { + match &self.air.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. @@ -65,6 +66,7 @@ impl PageIndexScanInputChip { ) .flatten(); + // satisfies_pred, send_row, is_less_than_tuple_aux_cols row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; row.push(send_row); @@ -99,6 +101,7 @@ impl PageIndexScanInputChip { ) .flatten(); + // satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, is_less_than_tuple_aux_cols, is_equal_vec_aux_cols let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] + is_equal_vec_trace[3 * self.air.idx_len - 1]; row.push(satisfies_pred); @@ -129,29 +132,13 @@ impl PageIndexScanInputChip { ) .flatten(); + // satisfies_pred, send_row, is_equal_vec_aux_cols row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); let send_row = is_equal_vec_trace[3 * self.air.idx_len - 1] * is_alloc; row.push(send_row); row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); } - PageIndexScanInputAirVariants::Gt(StrictCompAir { - is_less_than_tuple_air, - .. - }) => { - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (x.clone(), idx.clone(), self.range_checker.clone()), - ) - .flatten(); - - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; - row.push(send_row); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); - } PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, is_equal_vec_air, @@ -180,6 +167,7 @@ impl PageIndexScanInputChip { ) .flatten(); + // satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, is_less_than_tuple_aux_cols, is_equal_vec_aux_cols let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] + is_equal_vec_trace[3 * self.air.idx_len - 1]; row.push(satisfies_pred); @@ -191,6 +179,24 @@ impl PageIndexScanInputChip { row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); } + PageIndexScanInputAirVariants::Gt(StrictCompAir { + is_less_than_tuple_air, + .. + }) => { + let is_less_than_tuple_trace: Vec> = + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (x.clone(), idx.clone(), self.range_checker.clone()), + ) + .flatten(); + + // satisfies_pred, send_row, is_less_than_tuple_aux_cols + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; + row.push(send_row); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + } } rows.extend_from_slice(&row); From aeb206ab50339e4fde8ca559f0e206a596088d01 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:10:45 -0400 Subject: [PATCH 40/46] chore: fix test --- chips/src/is_equal_vec/columns.rs | 2 +- .../page_index_scan_input/columns.rs | 12 ++++++++---- .../page_index_scan_input/mod.rs | 6 +++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/chips/src/is_equal_vec/columns.rs b/chips/src/is_equal_vec/columns.rs index 98dc16fab9..662304ba8d 100644 --- a/chips/src/is_equal_vec/columns.rs +++ b/chips/src/is_equal_vec/columns.rs @@ -42,7 +42,7 @@ impl IsEqualVecAuxCols { } pub fn get_width(vec_len: usize) -> usize { - 2 * vec_len + vec_len + vec_len } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 61785b7977..27fa587014 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,6 +1,5 @@ use crate::{ - is_equal_vec::columns::{IsEqualVecAuxCols, IsEqualVecIOCols}, - is_less_than_tuple::columns::IsLessThanTupleAuxCols, + is_equal_vec::columns::IsEqualVecAuxCols, is_less_than_tuple::columns::IsLessThanTupleAuxCols, }; use super::Comp; @@ -168,10 +167,15 @@ impl PageIndexScanInputCols { + 1 + 1 + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) - + IsEqualVecIOCols::::get_width(idx_len) + + IsEqualVecAuxCols::::get_width(idx_len) } Comp::Eq => { - 1 + idx_len + data_len + idx_len + 1 + 1 + IsEqualVecIOCols::::get_width(idx_len) + 1 + idx_len + + data_len + + idx_len + + 1 + + 1 + + IsEqualVecAuxCols::::get_width(idx_len) } } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index d4f3e3a681..512a7fb624 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - is_equal_vec::{columns::IsEqualVecIOCols, IsEqualVecAir}, + is_equal_vec::{columns::IsEqualVecAuxCols, IsEqualVecAir}, is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, range_gate::RangeCheckerGateChip, }; @@ -171,11 +171,11 @@ impl PageIndexScanInputChip { is_less_than_tuple_air.decomp(), self.air.idx_len, ) - + IsEqualVecIOCols::::get_width(self.air.idx_len) + + IsEqualVecAuxCols::::get_width(self.air.idx_len) } PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => { // x, satisfies_pred, send_row, is_equal_vec_aux_cols - self.air.idx_len + 1 + 1 + IsEqualVecIOCols::::get_width(self.air.idx_len) + self.air.idx_len + 1 + 1 + IsEqualVecAuxCols::::get_width(self.air.idx_len) } } } From 9d72619361d6da16038a30824ddbf02935823c80 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:30:59 -0400 Subject: [PATCH 41/46] merge IsEqualVec columns --- chips/src/is_equal_vec/columns.rs | 41 ++++++++++++++----------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/chips/src/is_equal_vec/columns.rs b/chips/src/is_equal_vec/columns.rs index 3617407305..662304ba8d 100644 --- a/chips/src/is_equal_vec/columns.rs +++ b/chips/src/is_equal_vec/columns.rs @@ -6,24 +6,17 @@ pub struct IsEqualVecIOCols { } impl IsEqualVecIOCols { - pub fn new(x: Vec, y: Vec, prod: T) -> Self { - Self { x, y, prod } - } - pub fn flatten(&self) -> Vec { let mut res: Vec = self.x.iter().chain(self.y.iter()).cloned().collect(); res.push(self.prod.clone()); res } - // Note that the slice this function takes is of an unusual - // slc should be a whole row of the trace pub fn from_slice(slc: &[T], vec_len: usize) -> Self { - Self { - x: slc[0..vec_len].to_vec(), - y: slc[vec_len..2 * vec_len].to_vec(), - prod: slc[3 * vec_len - 1].clone(), - } + let x = slc[0..vec_len].to_vec(); + let y = slc[vec_len..2 * vec_len].to_vec(); + let prod = slc[2 * vec_len].clone(); + Self { x, y, prod } } pub fn get_width(vec_len: usize) -> usize { @@ -31,26 +24,21 @@ impl IsEqualVecIOCols { } } -#[derive(Default, Debug, Clone)] +#[derive(Default)] pub struct IsEqualVecAuxCols { pub prods: Vec, pub invs: Vec, } impl IsEqualVecAuxCols { - pub fn new(prods: Vec, invs: Vec) -> Self { - Self { prods, invs } - } - pub fn flatten(&self) -> Vec { self.prods.iter().chain(self.invs.iter()).cloned().collect() } pub fn from_slice(slc: &[T], vec_len: usize) -> Self { - Self { - prods: slc[0..vec_len].to_vec(), - invs: slc[vec_len..2 * vec_len].to_vec(), - } + let prods = slc[0..vec_len].to_vec(); + let invs = slc[vec_len..2 * vec_len].to_vec(); + Self { prods, invs } } pub fn get_width(vec_len: usize) -> usize { @@ -77,9 +65,15 @@ impl IsEqualVecCols { } pub fn from_slice(slc: &[T], vec_len: usize) -> Self { + let x = slc[0..vec_len].to_vec(); + let y = slc[vec_len..2 * vec_len].to_vec(); + let prod = slc[3 * vec_len - 1].clone(); + let prods = slc[2 * vec_len..3 * vec_len].to_vec(); + let invs = slc[3 * vec_len..4 * vec_len].to_vec(); + Self { - io: IsEqualVecIOCols::from_slice(slc, vec_len), - aux: IsEqualVecAuxCols::from_slice(&slc[2 * vec_len..], vec_len), + io: IsEqualVecIOCols { x, y, prod }, + aux: IsEqualVecAuxCols { prods, invs }, } } @@ -88,7 +82,8 @@ impl IsEqualVecCols { .x .iter() .chain(self.io.y.iter()) - .chain(self.aux.flatten().iter()) + .chain(self.aux.prods.iter()) + .chain(self.aux.invs.iter()) .cloned() .collect() } From 778979392ec413c539ed9cf5fc99a15fbc038551 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:37:08 -0400 Subject: [PATCH 42/46] remove range_max from PageIndexScan chips --- chips/src/is_equal_vec/columns.rs | 2 +- chips/src/single_page_index_scan/page_controller/mod.rs | 6 +----- .../single_page_index_scan/page_index_scan_input/mod.rs | 5 ----- .../single_page_index_scan/page_index_scan_output/mod.rs | 8 +------- chips/src/single_page_index_scan/tests.rs | 2 ++ 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/chips/src/is_equal_vec/columns.rs b/chips/src/is_equal_vec/columns.rs index 662304ba8d..8cf26c1a4a 100644 --- a/chips/src/is_equal_vec/columns.rs +++ b/chips/src/is_equal_vec/columns.rs @@ -24,7 +24,7 @@ impl IsEqualVecIOCols { } } -#[derive(Default)] +#[derive(Default, Debug, Clone)] pub struct IsEqualVecAuxCols { pub prods: Vec, pub invs: Vec, diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 9679b0584d..672d806aad 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -46,13 +46,12 @@ where idx_decomp: usize, cmp: Comp, ) -> Self { - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, 1 << idx_decomp)); + let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); Self { input_chip: PageIndexScanInputChip::new( bus_index, idx_len, data_len, - range_max, idx_limb_bits.clone(), idx_decomp, range_checker.clone(), @@ -62,7 +61,6 @@ where bus_index, idx_len, data_len, - range_max, idx_limb_bits.clone(), idx_decomp, range_checker.clone(), @@ -295,7 +293,6 @@ where bus_index, idx_len, data_len, - self.range_checker.range_max(), idx_limb_bits.clone(), idx_decomp, self.range_checker.clone(), @@ -311,7 +308,6 @@ where bus_index, idx_len, data_len, - self.range_checker.range_max(), idx_limb_bits.clone(), idx_decomp, self.range_checker.clone(), diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 512a7fb624..9ecad9ce4b 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -67,7 +67,6 @@ impl PageIndexScanInputChip { bus_index: usize, idx_len: usize, data_len: usize, - range_max: u32, idx_limb_bits: Vec, decomp: usize, range_checker: Arc, @@ -77,7 +76,6 @@ impl PageIndexScanInputChip { Comp::Lt => PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - range_max, idx_limb_bits.clone(), decomp, ), @@ -85,7 +83,6 @@ impl PageIndexScanInputChip { Comp::Lte => PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - range_max, idx_limb_bits.clone(), decomp, ), @@ -97,7 +94,6 @@ impl PageIndexScanInputChip { Comp::Gte => PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - range_max, idx_limb_bits.clone(), decomp, ), @@ -106,7 +102,6 @@ impl PageIndexScanInputChip { Comp::Gt => PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air: IsLessThanTupleAir::new( bus_index, - range_max, idx_limb_bits.clone(), decomp, ), diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs index 0986ac6ab8..0d8af2e405 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -40,7 +40,6 @@ impl PageIndexScanOutputChip { bus_index: usize, idx_len: usize, data_len: usize, - range_max: u32, idx_limb_bits: Vec, decomp: usize, range_checker: Arc, @@ -50,12 +49,7 @@ impl PageIndexScanOutputChip { bus_index, idx_len, data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - range_max, - idx_limb_bits, - decomp, - ), + is_less_than_tuple_air: IsLessThanTupleAir::new(bus_index, idx_limb_bits, decomp), }, range_checker, } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index b3e46e694c..85a3883404 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -430,6 +430,8 @@ fn test_single_page_index_scan_gte() { let page_output = page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gte); + println!("{:?}", page_output); + index_scan_test( &engine, page, From d830f5f244f34b53ad9e94c090256d73639cd304 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:03:50 -0400 Subject: [PATCH 43/46] chore: address comments --- .../page_controller/mod.rs | 42 +- .../page_index_scan_input/air.rs | 307 +++++--------- .../page_index_scan_input/bridge.rs | 24 +- .../page_index_scan_input/columns.rs | 19 +- .../page_index_scan_input/mod.rs | 86 ++-- .../page_index_scan_input/trace.rs | 287 ++++++------- .../page_index_scan_output/trace.rs | 33 +- chips/src/single_page_index_scan/tests.rs | 391 ++++++------------ 8 files changed, 463 insertions(+), 726 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 672d806aad..8f25f74b3f 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -8,7 +8,7 @@ use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_matrix::dense::DenseMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::range_gate::RangeCheckerGateChip; +use crate::{common::page::Page, range_gate::RangeCheckerGateChip}; use super::{ page_index_scan_input::{Comp, PageIndexScanInputChip}, @@ -105,20 +105,13 @@ where )); } - pub fn gen_output( - &self, - page: Vec>, - x: Vec, - idx_len: usize, - page_width: usize, - cmp: Comp, - ) -> Vec> { + pub fn gen_output(&self, page: Page, x: Vec, page_width: usize, cmp: Comp) -> Page { let mut output: Vec> = vec![]; - for page_row in &page { - let is_alloc = page_row[0]; - let idx = page_row[1..1 + idx_len].to_vec(); - let data = page_row[1 + idx_len..].to_vec(); + for page_row in &page.rows { + let is_alloc = page_row.is_alloc; + let idx = page_row.idx.clone(); + let data = page_row.data.clone(); match cmp { Comp::Lt => { @@ -260,18 +253,18 @@ where } } - let num_remaining = page.len() - output.len(); + let num_remaining = page.rows.len() - output.len(); output.extend((0..num_remaining).map(|_| vec![0; page_width])); - output + Page::from_2d_vec(&output, page.rows[0].idx.len(), page.rows[0].data.len()) } #[allow(clippy::too_many_arguments)] pub fn load_page( &mut self, - page_input: Vec>, - page_output: Vec>, + page_input: Page, + page_output: Page, x: Vec, idx_len: usize, data_len: usize, @@ -285,7 +278,7 @@ where // idx_decomp can't change between different pages since range_checker depends on it assert!(1 << idx_decomp == self.range_checker.range_max()); - assert!(!page_input.is_empty()); + assert!(!page_input.rows.is_empty()); let bus_index = self.input_chip.air.bus_index; @@ -298,11 +291,9 @@ where self.range_checker.clone(), self.input_chip.cmp.clone(), ); - self.input_chip_trace = Some(self.input_chip.gen_page_trace::(page_input.clone())); - self.input_chip_aux_trace = Some( - self.input_chip - .gen_aux_trace::(page_input.clone(), x.clone()), - ); + self.input_chip_trace = Some(self.input_chip.gen_page_trace::(&page_input)); + self.input_chip_aux_trace = + Some(self.input_chip.gen_aux_trace::(&page_input, x.clone())); self.output_chip = PageIndexScanOutputChip::new( bus_index, @@ -313,9 +304,8 @@ where self.range_checker.clone(), ); - self.output_chip_trace = Some(self.output_chip.gen_page_trace::(page_output.clone())); - self.output_chip_aux_trace = - Some(self.output_chip.gen_aux_trace::(page_output.clone())); + self.output_chip_trace = Some(self.output_chip.gen_page_trace::(&page_output)); + self.output_chip_aux_trace = Some(self.output_chip.gen_aux_trace::(&page_output)); let prover_data = vec![ trace_committer.commit(vec![self.input_chip_trace.clone().unwrap()]), diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 8b55aafd86..9e75f4e7ab 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -4,10 +4,8 @@ use p3_field::Field; use p3_matrix::Matrix; use crate::{ - is_equal_vec::columns::{IsEqualVecAuxCols, IsEqualVecCols, IsEqualVecIOCols}, - is_less_than_tuple::columns::{ - IsLessThanTupleAuxCols, IsLessThanTupleCols, IsLessThanTupleIOCols, - }, + is_equal_vec::columns::{IsEqualVecCols, IsEqualVecIOCols}, + is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, sub_chip::{AirConfig, SubAir}, }; @@ -90,7 +88,7 @@ where .collect::>(); let local = local_vec.as_slice(); - // when comparator is = we can use dummy values for idx_limb_bits and decomp + // get the idx_limb_bits and decomp, which will be used to generate local_cols let (idx_limb_bits, decomp) = match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, @@ -114,6 +112,7 @@ where PageIndexScanInputAirVariants::Eq(EqCompAir { .. }) => (vec![], 0), }; + // get the comparator let cmp = match &self.variant_air { PageIndexScanInputAirVariants::Lt(..) => Comp::Lt, PageIndexScanInputAirVariants::Gt(..) => Comp::Gt, @@ -144,28 +143,67 @@ where builder.assert_bool(local_cols.satisfies_pred); builder.assert_bool(local_cols.send_row); - // generate flattened aux columns for IsLessThanTuple and IsEqualVec - let is_less_than_tuple_aux_flattened = match &local_cols.aux_cols { - PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { - is_less_than_tuple_aux, - .. - }) - | PageIndexScanInputAuxCols::Gt(StrictCompAuxCols { - is_less_than_tuple_aux, - .. - }) - | PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { - is_less_than_tuple_aux, - .. - }) - | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { - is_less_than_tuple_aux, - .. - }) => is_less_than_tuple_aux.flatten(), - PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => vec![], - }; + // get the indicators for strict and equal comparisons + let (strict_comp_ind, equal_comp_ind): (Option, Option) = + match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Lt(..) | PageIndexScanInputAuxCols::Gt(..) => { + (Some(local_cols.satisfies_pred), None) + } + PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + satisfies_strict_comp, + satisfies_eq_comp, + .. + }) + | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + satisfies_strict_comp, + satisfies_eq_comp, + .. + }) => (Some(*satisfies_strict_comp), Some(*satisfies_eq_comp)), + PageIndexScanInputAuxCols::Eq(..) => (None, Some(local_cols.satisfies_pred)), + }; - let is_equal_vec_aux_flattened = match &local_cols.aux_cols { + // generate aux columns for IsLessThanTuple + let is_less_than_tuple_cols: Option> = + match &local_cols.aux_cols { + PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) => Some(IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + // idx < x + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + // use the strict_comp_ind + tuple_less_than: strict_comp_ind.unwrap(), + }, + aux: is_less_than_tuple_aux.clone(), + }), + PageIndexScanInputAuxCols::Gt(StrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) + | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { + is_less_than_tuple_aux, + .. + }) => Some(IsLessThanTupleCols { + io: IsLessThanTupleIOCols { + // idx > x + x: local_cols.x.clone(), + y: local_cols.page_cols.idx.clone(), + // use the strict_comp_ind + tuple_less_than: strict_comp_ind.unwrap(), + }, + aux: is_less_than_tuple_aux.clone(), + }), + PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => None, + }; + + // generate aux columns for IsEqualVec + let is_equal_vec_cols: Option> = match &local_cols.aux_cols { PageIndexScanInputAuxCols::Eq(EqCompAuxCols { is_equal_vec_aux, .. }) @@ -174,31 +212,34 @@ where }) | PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { is_equal_vec_aux, .. - }) => is_equal_vec_aux.flatten(), - _ => vec![], + }) => { + let is_equal_vec_cols = IsEqualVecCols { + io: IsEqualVecIOCols { + x: local_cols.page_cols.idx.clone(), + y: local_cols.x.clone(), + // use the equal_comp_ind + prod: equal_comp_ind.unwrap(), + }, + aux: is_equal_vec_aux.clone(), + }; + Some(is_equal_vec_cols) + } + _ => None, }; + // constrain that satisfies pred is correct match &self.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. + }) + | PageIndexScanInputAirVariants::Gt(StrictCompAir { + is_less_than_tuple_air, + .. }) => { - // here, we are checking if idx < x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.page_cols.idx.clone(), - y: local_cols.x.clone(), - tuple_less_than: local_cols.satisfies_pred, - }, - aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, - idx_limb_bits.clone(), - decomp, - self.idx_len, - ), - }; + let is_less_than_tuple_cols = is_less_than_tuple_cols.unwrap(); - // constrain the indicator that we used to check whether idx < x is correct + // constrain the indicator that we used to check the strict comp is correct SubAir::eval( is_less_than_tuple_air, &mut builder.when_transition(), @@ -209,173 +250,45 @@ where PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, is_equal_vec_air, + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air, + is_equal_vec_air, }) => { - match &local_cols.aux_cols { - PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { - satisfies_strict_comp, - satisfies_eq_comp, - .. - }) => { - // here, we are checking if idx < x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.page_cols.idx.clone(), - y: local_cols.x.clone(), - tuple_less_than: *satisfies_strict_comp, - }, - aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, - idx_limb_bits, - decomp, - self.idx_len, - ), - }; - - // constrain the indicator that we used to check whether idx < x is correct - SubAir::eval( - is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); - - // here, we are checking if idx = x - let is_equal_vec_cols = IsEqualVecCols { - io: IsEqualVecIOCols { - x: local_cols.page_cols.idx.clone(), - y: local_cols.x.clone(), - prod: *satisfies_eq_comp, - }, - aux: IsEqualVecAuxCols::from_slice( - &is_equal_vec_aux_flattened, - self.idx_len, - ), - }; + let is_less_than_tuple_cols = is_less_than_tuple_cols.unwrap(); + let is_equal_vec_cols = is_equal_vec_cols.unwrap(); - // constrain the indicator that we used to check whether idx = x is correct - SubAir::eval( - is_equal_vec_air, - builder, - is_equal_vec_cols.io, - is_equal_vec_cols.aux, - ); - - // constrain that satisfies_pred indicates whether idx <= x - builder.assert_eq( - *satisfies_strict_comp + *satisfies_eq_comp, - local_cols.satisfies_pred, - ); - } - _ => panic!("Unexpected aux cols"), - } - } - PageIndexScanInputAirVariants::Eq(EqCompAir { is_equal_vec_air }) => { - // here, we are checking if idx = x - let is_equal_vec_cols = IsEqualVecCols { - io: IsEqualVecIOCols { - x: local_cols.page_cols.idx.clone(), - y: local_cols.x.clone(), - prod: local_cols.satisfies_pred, - }, - aux: IsEqualVecAuxCols::from_slice(&is_equal_vec_aux_flattened, self.idx_len), - }; + // constrain the indicator that we used to check the strict comp is correct + SubAir::eval( + is_less_than_tuple_air, + &mut builder.when_transition(), + is_less_than_tuple_cols.io, + is_less_than_tuple_cols.aux, + ); - // constrain the indicator that we used to check whether idx = x is correct + // constrain the indicator that we used to check for equality is correct SubAir::eval( is_equal_vec_air, builder, is_equal_vec_cols.io, is_equal_vec_cols.aux, ); - } - PageIndexScanInputAirVariants::Gte(NonStrictCompAir { - is_less_than_tuple_air, - is_equal_vec_air, - }) => { - match &local_cols.aux_cols { - PageIndexScanInputAuxCols::Gte(NonStrictCompAuxCols { - satisfies_strict_comp, - satisfies_eq_comp, - .. - }) => { - // here, we are checking if idx > x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.x.clone(), - y: local_cols.page_cols.idx.clone(), - tuple_less_than: *satisfies_strict_comp, - }, - aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, - idx_limb_bits, - decomp, - self.idx_len, - ), - }; - - // constrain the indicator that we used to check whether idx > x is correct - SubAir::eval( - is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); - // here, we are checking if idx = x - let is_equal_vec_cols = IsEqualVecCols { - io: IsEqualVecIOCols { - x: local_cols.page_cols.idx.clone(), - y: local_cols.x.clone(), - prod: *satisfies_eq_comp, - }, - aux: IsEqualVecAuxCols::from_slice( - &is_equal_vec_aux_flattened, - self.idx_len, - ), - }; - - // constrain the indicator that we used to check whether idx = x is correct - SubAir::eval( - is_equal_vec_air, - builder, - is_equal_vec_cols.io, - is_equal_vec_cols.aux, - ); - - // constrain that satisfies_pred indicates whether idx >= x - builder.assert_eq( - *satisfies_strict_comp + *satisfies_eq_comp, - local_cols.satisfies_pred, - ); - } - _ => panic!("Unexpected aux cols"), - } + // constrain that satisfies_pred indicates the nonstrict comparison + builder.assert_eq( + strict_comp_ind.unwrap() + equal_comp_ind.unwrap(), + local_cols.satisfies_pred, + ); } - PageIndexScanInputAirVariants::Gt(StrictCompAir { - is_less_than_tuple_air, - .. - }) => { - // here, we are checking if idx > x - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.x.clone(), - y: local_cols.page_cols.idx.clone(), - tuple_less_than: local_cols.satisfies_pred, - }, - aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, - idx_limb_bits, - decomp, - self.idx_len, - ), - }; + PageIndexScanInputAirVariants::Eq(EqCompAir { is_equal_vec_air }) => { + let is_equal_vec_cols = is_equal_vec_cols.unwrap(); - // constrain the indicator that we used to check whether idx > x is correct + // constrain the indicator that we used to check whether idx = x is correct SubAir::eval( - is_less_than_tuple_air, - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, + is_equal_vec_air, + builder, + is_equal_vec_cols.io, + is_equal_vec_cols.aux, ); } } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 8ef5718a25..06d90e8e8f 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -91,7 +91,10 @@ impl AirBridge for PageIndexScanInputAir { // here, we generate the flattened aux columns for IsLessThanTuple, and get the indicator associated with the strict comparison // when the comparator is =, we can just generate dummy values - let (is_less_than_tuple_aux_flattened, strict_comp_ind) = match cols_numbered.aux_cols { + let (is_less_than_tuple_aux_flattened, strict_comp_ind): ( + Option>, + Option, + ) = match cols_numbered.aux_cols { PageIndexScanInputAuxCols::Lt(StrictCompAuxCols { is_less_than_tuple_aux, .. @@ -100,8 +103,8 @@ impl AirBridge for PageIndexScanInputAir { is_less_than_tuple_aux, .. }) => ( - is_less_than_tuple_aux.flatten(), - cols_numbered.satisfies_pred, + Some(is_less_than_tuple_aux.flatten()), + Some(cols_numbered.satisfies_pred), ), PageIndexScanInputAuxCols::Lte(NonStrictCompAuxCols { satisfies_strict_comp, @@ -112,8 +115,11 @@ impl AirBridge for PageIndexScanInputAir { satisfies_strict_comp, is_less_than_tuple_aux, .. - }) => (is_less_than_tuple_aux.flatten(), satisfies_strict_comp), - PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => (vec![], 0), + }) => ( + Some(is_less_than_tuple_aux.flatten()), + Some(satisfies_strict_comp), + ), + PageIndexScanInputAuxCols::Eq(EqCompAuxCols { .. }) => (None, None), }; // get interactions from IsLessThanTuple subchip @@ -130,10 +136,10 @@ impl AirBridge for PageIndexScanInputAir { io: IsLessThanTupleIOCols { x: cols_numbered.page_cols.idx.clone(), y: cols_numbered.x.clone(), - tuple_less_than: strict_comp_ind, + tuple_less_than: strict_comp_ind.unwrap(), }, aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, + &is_less_than_tuple_aux_flattened.unwrap(), idx_limb_bits, decomp, self.idx_len, @@ -154,10 +160,10 @@ impl AirBridge for PageIndexScanInputAir { io: IsLessThanTupleIOCols { x: cols_numbered.x.clone(), y: cols_numbered.page_cols.idx.clone(), - tuple_less_than: strict_comp_ind, + tuple_less_than: strict_comp_ind.unwrap(), }, aux: IsLessThanTupleAuxCols::from_slice( - &is_less_than_tuple_aux_flattened, + &is_less_than_tuple_aux_flattened.unwrap(), idx_limb_bits, decomp, self.idx_len, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs index 27fa587014..b60f6e45e8 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/columns.rs @@ -1,25 +1,10 @@ use crate::{ - is_equal_vec::columns::IsEqualVecAuxCols, is_less_than_tuple::columns::IsLessThanTupleAuxCols, + common::page_cols::PageCols, is_equal_vec::columns::IsEqualVecAuxCols, + is_less_than_tuple::columns::IsLessThanTupleAuxCols, }; use super::Comp; -pub struct PageCols { - pub is_alloc: T, // indicates if row is allocated - pub idx: Vec, - pub data: Vec, -} - -impl PageCols { - pub fn from_slice(cols: &[T], idx_len: usize, data_len: usize) -> PageCols { - PageCols { - is_alloc: cols[0].clone(), - idx: cols[1..idx_len + 1].to_vec(), - data: cols[idx_len + 1..idx_len + data_len + 1].to_vec(), - } - } -} - pub struct StrictCompAuxCols { pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, } diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index 9ecad9ce4b..e6fa7592b5 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -50,10 +50,50 @@ pub struct PageIndexScanInputAir { variant_air: PageIndexScanInputAirVariants, } +impl PageIndexScanInputAir { + pub fn new( + bus_index: usize, + idx_len: usize, + data_len: usize, + idx_limb_bits: Vec, + decomp: usize, + cmp: Comp, + ) -> Self { + let is_less_than_tuple_air = + IsLessThanTupleAir::new(bus_index, idx_limb_bits.clone(), decomp); + let is_equal_vec_air = IsEqualVecAir::new(idx_len); + + let variant_air = match cmp { + Comp::Lt => PageIndexScanInputAirVariants::Lt(StrictCompAir { + is_less_than_tuple_air, + }), + Comp::Lte => PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_less_than_tuple_air, + is_equal_vec_air, + }), + Comp::Eq => PageIndexScanInputAirVariants::Eq(EqCompAir { is_equal_vec_air }), + Comp::Gte => PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_less_than_tuple_air, + is_equal_vec_air, + }), + Comp::Gt => PageIndexScanInputAirVariants::Gt(StrictCompAir { + is_less_than_tuple_air, + }), + }; + + Self { + bus_index, + idx_len, + data_len, + variant_air, + } + } +} + /// Given a fixed predicate of the form index OP x, where OP is one of {<, <=, =, >=, >} /// and x is a private input, the PageIndexScanInputChip implements a chip such that the chip: /// -/// 1. Has public value x +/// 1. Has public value x and OP given by cmp (Lt, Lte, Eq, Gte, or Gt) /// 2. Sends all rows of the page that match the predicate index OP x where x is the public value pub struct PageIndexScanInputChip { pub air: PageIndexScanInputAir, @@ -72,48 +112,14 @@ impl PageIndexScanInputChip { range_checker: Arc, cmp: Comp, ) -> Self { - let variant_air = match cmp { - Comp::Lt => PageIndexScanInputAirVariants::Lt(StrictCompAir { - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - idx_limb_bits.clone(), - decomp, - ), - }), - Comp::Lte => PageIndexScanInputAirVariants::Lte(NonStrictCompAir { - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - idx_limb_bits.clone(), - decomp, - ), - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }), - Comp::Eq => PageIndexScanInputAirVariants::Eq(EqCompAir { - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }), - Comp::Gte => PageIndexScanInputAirVariants::Gte(NonStrictCompAir { - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - idx_limb_bits.clone(), - decomp, - ), - is_equal_vec_air: IsEqualVecAir::new(idx_len), - }), - Comp::Gt => PageIndexScanInputAirVariants::Gt(StrictCompAir { - is_less_than_tuple_air: IsLessThanTupleAir::new( - bus_index, - idx_limb_bits.clone(), - decomp, - ), - }), - }; - - let air = PageIndexScanInputAir { + let air = PageIndexScanInputAir::new( bus_index, idx_len, data_len, - variant_air, - }; + idx_limb_bits, + decomp, + cmp.clone(), + ); Self { air, diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 3817ab7e96..8a7b44c124 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -2,7 +2,7 @@ use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::sub_chip::LocalTraceInstructions; +use crate::{common::page::Page, sub_chip::LocalTraceInstructions}; use super::{ EqCompAir, NonStrictCompAir, PageIndexScanInputAirVariants, PageIndexScanInputChip, @@ -11,29 +11,74 @@ use super::{ impl PageIndexScanInputChip { /// Generate the trace for the page table - pub fn gen_page_trace( - &self, - page: Vec>, - ) -> RowMajorMatrix> + pub fn gen_page_trace(&self, page: &Page) -> RowMajorMatrix> where - Val: AbstractField, + Val: AbstractField + PrimeField64, { - RowMajorMatrix::new( - page.into_iter() - .flat_map(|row| { - row.into_iter() - .map(Val::::from_wrapped_u32) - .collect::>>() - }) - .collect(), - self.page_width(), - ) + page.gen_trace() + } + + /// Helper function to handle trace generation with an IsLessThanTupleAir + fn handle_is_less_than_tuple( + &self, + is_less_than_tuple_trace: Vec>, + is_alloc: Val, + row: &mut Vec>, + ) where + Val: AbstractField + PrimeField64, + { + // satisfies_pred, send_row, is_less_than_tuple_aux_cols + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; + row.push(send_row); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + } + + /// Helper function to handle trace generation with an IsEqualVecAir + fn handle_is_equal_vec( + &self, + is_equal_vec_trace: Vec>, + is_alloc: Val, + row: &mut Vec>, + ) where + Val: AbstractField + PrimeField64, + { + // satisfies_pred, send_row, is_equal_vec_aux_cols + row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); + let send_row = is_equal_vec_trace[3 * self.air.idx_len - 1] * is_alloc; + row.push(send_row); + + row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); + } + + /// Helper function to handle trace generation with an IsLessThanTupleAir and an IsEqualVecAir + fn handle_both_airs( + &self, + is_less_than_tuple_trace: Vec>, + is_equal_vec_trace: Vec>, + is_alloc: Val, + row: &mut Vec>, + ) where + Val: AbstractField + PrimeField64, + { + // satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, is_less_than_tuple_aux_cols, is_equal_vec_aux_cols + let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] + + is_equal_vec_trace[3 * self.air.idx_len - 1]; + row.push(satisfies_pred); + row.push(satisfies_pred * is_alloc); + + row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); + row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); + + row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); } /// Generate the trace for the auxiliary columns pub fn gen_aux_trace( &self, - page: Vec>, + page: &Page, x: Vec, ) -> RowMajorMatrix> where @@ -41,11 +86,11 @@ impl PageIndexScanInputChip { { let mut rows: Vec> = vec![]; - for page_row in &page { + for page_row in &page.rows { let mut row: Vec> = vec![]; - let is_alloc = Val::::from_canonical_u32(page_row[0]); - let idx = page_row[1..1 + self.air.idx_len].to_vec(); + let is_alloc = Val::::from_canonical_u32(page_row.is_alloc); + let idx = page_row.idx.clone(); // first, get the values for x let x_trace: Vec> = x @@ -54,148 +99,84 @@ impl PageIndexScanInputChip { .collect(); row.extend(x_trace); - match &self.air.variant_air { + let is_less_than_tuple_trace: Option>> = match &self.air.variant_air { PageIndexScanInputAirVariants::Lt(StrictCompAir { is_less_than_tuple_air, .. - }) => { - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (idx.clone(), x.clone(), self.range_checker.clone()), - ) - .flatten(); - - // satisfies_pred, send_row, is_less_than_tuple_aux_cols - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; - row.push(send_row); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); - } - PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + }) + | PageIndexScanInputAirVariants::Lte(NonStrictCompAir { is_less_than_tuple_air, - is_equal_vec_air, .. - }) => { - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (idx.clone(), x.clone(), self.range_checker.clone()), - ) - .flatten(); - - let is_equal_vec_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_equal_vec_air, - ( - idx.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - x.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - ), - ) - .flatten(); - - // satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, is_less_than_tuple_aux_cols, is_equal_vec_aux_cols - let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] - + is_equal_vec_trace[3 * self.air.idx_len - 1]; - row.push(satisfies_pred); - row.push(satisfies_pred * is_alloc); - - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); - row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); - } - PageIndexScanInputAirVariants::Eq(EqCompAir { - is_equal_vec_air, .. - }) => { - let is_equal_vec_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_equal_vec_air, - ( - idx.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - x.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - ), - ) - .flatten(); - - // satisfies_pred, send_row, is_equal_vec_aux_cols - row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); - let send_row = is_equal_vec_trace[3 * self.air.idx_len - 1] * is_alloc; - row.push(send_row); - - row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); - } - PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + }) => Some( + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (idx.clone(), x.clone(), self.range_checker.clone()), + ) + .flatten(), + ), + PageIndexScanInputAirVariants::Gt(StrictCompAir { is_less_than_tuple_air, - is_equal_vec_air, .. - }) => { - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (x.clone(), idx.clone(), self.range_checker.clone()), - ) - .flatten(); - - let is_equal_vec_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_equal_vec_air, - ( - idx.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - x.clone() - .into_iter() - .map(Val::::from_canonical_u32) - .collect(), - ), - ) - .flatten(); - - // satisfies_pred, send_row, satisfies_strict_comp, satisfies_eq_comp, is_less_than_tuple_aux_cols, is_equal_vec_aux_cols - let satisfies_pred = is_less_than_tuple_trace[2 * self.air.idx_len] - + is_equal_vec_trace[3 * self.air.idx_len - 1]; - row.push(satisfies_pred); - row.push(satisfies_pred * is_alloc); - - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - row.push(is_equal_vec_trace[3 * self.air.idx_len - 1]); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); - row.extend_from_slice(&is_equal_vec_trace[2 * self.air.idx_len..]); - } - PageIndexScanInputAirVariants::Gt(StrictCompAir { + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_less_than_tuple_air, .. - }) => { - let is_less_than_tuple_trace: Vec> = - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (x.clone(), idx.clone(), self.range_checker.clone()), - ) - .flatten(); - - // satisfies_pred, send_row, is_less_than_tuple_aux_cols - row.push(is_less_than_tuple_trace[2 * self.air.idx_len]); - let send_row = is_less_than_tuple_trace[2 * self.air.idx_len] * is_alloc; - row.push(send_row); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len + 1..]); + }) => Some( + LocalTraceInstructions::generate_trace_row( + is_less_than_tuple_air, + (x.clone(), idx.clone(), self.range_checker.clone()), + ) + .flatten(), + ), + _ => None, + }; + + let is_equal_vec_trace: Option>> = match &self.air.variant_air { + PageIndexScanInputAirVariants::Lte(NonStrictCompAir { + is_equal_vec_air, .. + }) + | PageIndexScanInputAirVariants::Eq(EqCompAir { + is_equal_vec_air, .. + }) + | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { + is_equal_vec_air, .. + }) => Some( + LocalTraceInstructions::generate_trace_row( + is_equal_vec_air, + ( + idx.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + x.clone() + .into_iter() + .map(Val::::from_canonical_u32) + .collect(), + ), + ) + .flatten(), + ), + _ => None, + }; + + match &self.air.variant_air { + PageIndexScanInputAirVariants::Lt(..) | PageIndexScanInputAirVariants::Gt(..) => { + self.handle_is_less_than_tuple::( + is_less_than_tuple_trace.unwrap(), + is_alloc, + &mut row, + ); + } + PageIndexScanInputAirVariants::Lte(..) | PageIndexScanInputAirVariants::Gte(..) => { + self.handle_both_airs::( + is_less_than_tuple_trace.unwrap(), + is_equal_vec_trace.unwrap(), + is_alloc, + &mut row, + ); + } + PageIndexScanInputAirVariants::Eq(..) => { + self.handle_is_equal_vec::(is_equal_vec_trace.unwrap(), is_alloc, &mut row); } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs index 9ad2d1d6d4..eff503b8c5 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs @@ -2,47 +2,32 @@ use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::sub_chip::LocalTraceInstructions; +use crate::{common::page::Page, sub_chip::LocalTraceInstructions}; use super::PageIndexScanOutputChip; impl PageIndexScanOutputChip { /// Generate the trace for the page table - pub fn gen_page_trace( - &self, - page: Vec>, - ) -> RowMajorMatrix> + pub fn gen_page_trace(&self, page: &Page) -> RowMajorMatrix> where - Val: AbstractField, + Val: AbstractField + PrimeField64, { - RowMajorMatrix::new( - page.into_iter() - .flat_map(|row| { - row.into_iter() - .map(Val::::from_wrapped_u32) - .collect::>>() - }) - .collect(), - self.page_width(), - ) + page.gen_trace() } /// Generate the trace for the auxiliary columns - pub fn gen_aux_trace( - &self, - page: Vec>, - ) -> RowMajorMatrix> + pub fn gen_aux_trace(&self, page: &Page) -> RowMajorMatrix> where Val: AbstractField + PrimeField64, { let mut rows: Vec> = vec![]; - for i in 0..page.len() { + for i in 0..page.rows.len() { let page_row = page[i].clone(); - let next_page: Vec = if i == page.len() - 1 { + let next_page: Vec = if i == page.rows.len() - 1 { vec![0; 1 + self.air.idx_len + self.air.data_len] } else { - page[i + 1].clone() + page.rows[i + 1].to_vec() }; let mut row: Vec> = vec![]; @@ -50,7 +35,7 @@ impl PageIndexScanOutputChip { let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( self.air.is_less_than_tuple_air(), ( - page_row[1..1 + self.air.idx_len].to_vec(), + page_row.to_vec()[1..1 + self.air.idx_len].to_vec(), next_page[1..1 + self.air.idx_len].to_vec(), self.range_checker.clone(), ), diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 85a3883404..9ef15ccf7f 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -13,13 +13,15 @@ use afs_test_utils::{ use p3_baby_bear::BabyBear; use p3_field::AbstractField; +use crate::common::page::Page; + use super::{page_controller::PageController, page_index_scan_input::Comp}; #[allow(clippy::too_many_arguments)] fn index_scan_test( engine: &BabyBearPoseidon2Engine, - page: Vec>, - page_output: Vec>, + page: Page, + page_output: Page, x: Vec, idx_len: usize, data_len: usize, @@ -29,7 +31,7 @@ fn index_scan_test( trace_builder: &mut TraceCommitmentBuilder, partial_pk: &MultiStarkPartialProvingKey, ) -> Result<(), VerificationError> { - let page_height = page.len(); + let page_height = page.rows.len(); assert!(page_height > 0); let (page_traces, mut prover_data) = page_controller.load_page( @@ -98,29 +100,17 @@ fn index_scan_test( ) } -#[test] -fn test_single_page_index_scan_lt() { - let bus_index: usize = 0; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_height = 1 << log_page_height; - let page_width = 1 + idx_len + data_len; - - let mut page_controller: PageController = PageController::new( - bus_index, - idx_len, - data_len, - range_max, - limb_bits.clone(), - decomp, - Comp::Lt, - ); - +fn generate_pk( + page_controller: &mut PageController, + log_page_height: usize, + page_width: usize, + page_height: usize, + idx_len: usize, + decomp: usize, +) -> ( + BabyBearPoseidon2Engine, + MultiStarkPartialProvingKey, +) { let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); @@ -156,18 +146,53 @@ fn test_single_page_index_scan_lt() { let partial_pk = keygen_builder.generate_partial_pk(); - let prover = MultiTraceStarkProver::new(&engine.config); - let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + (engine, partial_pk) +} + +#[test] +fn test_single_page_index_scan_lt() { + let bus_index: usize = 0; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: Vec = vec![16, 16]; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let mut page_controller: PageController = PageController::new( + bus_index, + idx_len, + data_len, + range_max, + limb_bits.clone(), + decomp, + Comp::Lt, + ); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); let x: Vec = vec![2177, 5880]; - let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Lt); + let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lt); + + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, + page_height, + idx_len, + decomp, + ); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); index_scan_test( &engine, @@ -208,54 +233,28 @@ fn test_single_page_index_scan_lte() { Comp::Lte, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2177, 5880, 51171, 3989, 12770], + ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + let x: Vec = vec![2177, 5880]; - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lte); - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, page_height, idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], + decomp, ); - let partial_pk = keygen_builder.generate_partial_pk(); - let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], - vec![1, 2177, 5880, 51171, 3989, 12770], - ]; - - let x: Vec = vec![2177, 5880]; - - let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Lte); - index_scan_test( &engine, page, @@ -295,54 +294,28 @@ fn test_single_page_index_scan_eq() { Comp::Eq, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let page: Vec> = vec![ + vec![1, 443, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + let x: Vec = vec![443, 376]; - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Eq); - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, page_height, idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], + decomp, ); - let partial_pk = keygen_builder.generate_partial_pk(); - let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ - vec![1, 443, 376, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], - ]; - - let x: Vec = vec![443, 376]; - - let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Eq); - index_scan_test( &engine, page, @@ -382,56 +355,28 @@ fn test_single_page_index_scan_gte() { Comp::Gte, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let page: Vec> = vec![ + vec![1, 2177, 5880, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + let x: Vec = vec![2177, 5880]; - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gte); - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, page_height, idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], + decomp, ); - let partial_pk = keygen_builder.generate_partial_pk(); - let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ - vec![1, 2177, 5880, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], - ]; - - let x: Vec = vec![2177, 5880]; - - let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gte); - - println!("{:?}", page_output); - index_scan_test( &engine, page, @@ -471,54 +416,28 @@ fn test_single_page_index_scan_gt() { Comp::Gt, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let page: Vec> = vec![ + vec![1, 2203, 376, 22278, 13998, 58327], + vec![1, 2883, 7769, 51171, 3989, 12770], + ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + let x: Vec = vec![2177, 5880]; - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); + let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gt); - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, page_height, idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], + decomp, ); - let partial_pk = keygen_builder.generate_partial_pk(); - let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ - vec![1, 2203, 376, 22278, 13998, 58327], - vec![1, 2883, 7769, 51171, 3989, 12770], - ]; - - let x: Vec = vec![2177, 5880]; - - let page_output = - page_controller.gen_output(page.clone(), x.clone(), idx_len, page_width, Comp::Gt); - index_scan_test( &engine, page, @@ -560,48 +479,11 @@ fn test_single_page_index_scan_wrong_order() { cmp, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); - - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); - - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); - - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, - page_height, - idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], - ); - - let partial_pk = keygen_builder.generate_partial_pk(); - - let prover = MultiTraceStarkProver::new(&engine.config); - let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); let x: Vec = vec![2177, 5880]; @@ -609,6 +491,19 @@ fn test_single_page_index_scan_wrong_order() { vec![0, 0, 0, 0, 0, 0], vec![1, 443, 376, 22278, 13998, 58327], ]; + let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, + page_height, + idx_len, + decomp, + ); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; @@ -657,48 +552,11 @@ fn test_single_page_index_scan_unsorted() { cmp, ); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); - - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); - - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); - - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, - page_height, - idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], - ); - - let partial_pk = keygen_builder.generate_partial_pk(); - - let prover = MultiTraceStarkProver::new(&engine.config); - let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - let page: Vec> = vec![ vec![1, 2883, 7769, 51171, 3989, 12770], vec![1, 443, 376, 22278, 13998, 58327], ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); let x: Vec = vec![2177, 5880]; @@ -706,6 +564,19 @@ fn test_single_page_index_scan_unsorted() { vec![0, 0, 0, 0, 0, 0], vec![1, 443, 376, 22278, 13998, 58327], ]; + let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, + page_height, + idx_len, + decomp, + ); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); USE_DEBUG_BUILDER.with(|debug| { *debug.lock().unwrap() = false; From 6751b3e090255eb46ad54d0dbe11d77af7abd678 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:48:09 -0400 Subject: [PATCH 44/46] feat: use FinalPage --- .../page_controller/mod.rs | 24 +-- .../page_index_scan_input/bridge.rs | 2 +- .../page_index_scan_input/mod.rs | 18 ++- .../page_index_scan_output/air.rs | 93 +---------- .../page_index_scan_output/bridge.rs | 59 ++----- .../page_index_scan_output/columns.rs | 52 +++---- .../page_index_scan_output/mod.rs | 42 ++--- .../page_index_scan_output/trace.rs | 33 +--- chips/src/single_page_index_scan/tests.rs | 147 ++++++++++++++---- 9 files changed, 202 insertions(+), 268 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index 8f25f74b3f..d5f81e737d 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -37,31 +37,33 @@ impl PageController where Val: AbstractField + PrimeField64, { + #[allow(clippy::too_many_arguments)] pub fn new( - bus_index: usize, + page_bus_index: usize, + range_bus_index: usize, idx_len: usize, data_len: usize, range_max: u32, - idx_limb_bits: Vec, + idx_limb_bits: usize, idx_decomp: usize, cmp: Comp, ) -> Self { - let range_checker = Arc::new(RangeCheckerGateChip::new(bus_index, range_max)); + let range_checker = Arc::new(RangeCheckerGateChip::new(range_bus_index, range_max)); Self { input_chip: PageIndexScanInputChip::new( - bus_index, + page_bus_index, idx_len, data_len, - idx_limb_bits.clone(), + idx_limb_bits, idx_decomp, range_checker.clone(), cmp, ), output_chip: PageIndexScanOutputChip::new( - bus_index, + page_bus_index, idx_len, data_len, - idx_limb_bits.clone(), + idx_limb_bits, idx_decomp, range_checker.clone(), ), @@ -268,7 +270,7 @@ where x: Vec, idx_len: usize, data_len: usize, - idx_limb_bits: Vec, + idx_limb_bits: usize, idx_decomp: usize, trace_committer: &mut TraceCommitter, ) -> (Vec>>, Vec>) @@ -280,13 +282,13 @@ where assert!(!page_input.rows.is_empty()); - let bus_index = self.input_chip.air.bus_index; + let bus_index = self.input_chip.air.page_bus_index; self.input_chip = PageIndexScanInputChip::new( bus_index, idx_len, data_len, - idx_limb_bits.clone(), + idx_limb_bits, idx_decomp, self.range_checker.clone(), self.input_chip.cmp.clone(), @@ -299,7 +301,7 @@ where bus_index, idx_len, data_len, - idx_limb_bits.clone(), + idx_limb_bits, idx_decomp, self.range_checker.clone(), ); diff --git a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs index 06d90e8e8f..d71ca9fb2e 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/bridge.rs @@ -86,7 +86,7 @@ impl AirBridge for PageIndexScanInputAir { interactions.push(Interaction { fields: virtual_cols, count: VirtualPairCol::single_main(cols_numbered.send_row), - argument_index: self.bus_index, + argument_index: self.page_bus_index, }); // here, we generate the flattened aux columns for IsLessThanTuple, and get the indicator associated with the strict comparison diff --git a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs index e6fa7592b5..a69f3953e3 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/mod.rs @@ -43,7 +43,7 @@ pub enum PageIndexScanInputAirVariants { } pub struct PageIndexScanInputAir { - pub bus_index: usize, + pub page_bus_index: usize, pub idx_len: usize, pub data_len: usize, @@ -52,15 +52,16 @@ pub struct PageIndexScanInputAir { impl PageIndexScanInputAir { pub fn new( - bus_index: usize, + page_bus_index: usize, + range_bus_index: usize, idx_len: usize, data_len: usize, - idx_limb_bits: Vec, + idx_limb_bits: usize, decomp: usize, cmp: Comp, ) -> Self { let is_less_than_tuple_air = - IsLessThanTupleAir::new(bus_index, idx_limb_bits.clone(), decomp); + IsLessThanTupleAir::new(range_bus_index, vec![idx_limb_bits; idx_len], decomp); let is_equal_vec_air = IsEqualVecAir::new(idx_len); let variant_air = match cmp { @@ -82,7 +83,7 @@ impl PageIndexScanInputAir { }; Self { - bus_index, + page_bus_index, idx_len, data_len, variant_air, @@ -104,16 +105,17 @@ pub struct PageIndexScanInputChip { impl PageIndexScanInputChip { #[allow(clippy::too_many_arguments)] pub fn new( - bus_index: usize, + page_bus_index: usize, idx_len: usize, data_len: usize, - idx_limb_bits: Vec, + idx_limb_bits: usize, decomp: usize, range_checker: Arc, cmp: Comp, ) -> Self { let air = PageIndexScanInputAir::new( - bus_index, + page_bus_index, + range_checker.bus_index(), idx_len, data_len, idx_limb_bits, diff --git a/chips/src/single_page_index_scan/page_index_scan_output/air.rs b/chips/src/single_page_index_scan/page_index_scan_output/air.rs index ab0390122e..3e0b60fb10 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/air.rs @@ -1,14 +1,8 @@ -use std::borrow::Borrow; - use afs_stark_backend::air_builders::PartitionedAirBuilder; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; +use p3_air::{Air, BaseAir}; +use p3_field::Field; -use crate::{ - is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, - sub_chip::{AirConfig, SubAir}, -}; +use crate::sub_chip::AirConfig; use super::{columns::PageIndexScanOutputCols, PageIndexScanOutputAir}; @@ -18,12 +12,7 @@ impl AirConfig for PageIndexScanOutputAir { impl BaseAir for PageIndexScanOutputAir { fn width(&self) -> usize { - PageIndexScanOutputCols::::get_width( - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ) + PageIndexScanOutputCols::::get_width(self.final_page_air.clone()) } } @@ -32,77 +21,7 @@ where AB::M: Clone, { fn eval(&self, builder: &mut AB) { - let page_main = &builder.partitioned_main()[0].clone(); - let aux_main = &builder.partitioned_main()[1].clone(); - - // get the current row and the next row - let (local_page, next_page) = (page_main.row_slice(0), page_main.row_slice(1)); - let local_page: &[AB::Var] = (*local_page).borrow(); - let next_page: &[AB::Var] = (*next_page).borrow(); - - let (local_aux, next_aux) = (aux_main.row_slice(0), aux_main.row_slice(1)); - let local_aux: &[AB::Var] = (*local_aux).borrow(); - let next_aux: &[AB::Var] = (*next_aux).borrow(); - - let local_vec = local_page - .iter() - .chain(local_aux.iter()) - .cloned() - .collect::>(); - let local = local_vec.as_slice(); - let next_vec = next_page - .iter() - .chain(next_aux.iter()) - .cloned() - .collect::>(); - let next = next_vec.as_slice(); - - let local_cols = PageIndexScanOutputCols::::from_slice( - local, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); - let next_cols = PageIndexScanOutputCols::::from_slice( - next, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); - - // check that is_alloc is a bool - builder.when_transition().assert_bool(local_cols.is_alloc); - // if local_cols.is_alloc is 1, then next_cols.is_alloc can be 0 or 1 - builder - .when_transition() - .assert_bool(local_cols.is_alloc * next_cols.is_alloc); - // if local_cols.is_alloc is 0, then next_cols.is_alloc must be 0 - builder - .when_transition() - .assert_zero((AB::Expr::one() - local_cols.is_alloc) * next_cols.is_alloc); - - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: local_cols.idx, - y: next_cols.idx, - tuple_less_than: local_cols.less_than_next_idx, - }, - aux: local_cols.is_less_than_tuple_aux, - }; - - // constrain the indicator that we used to check whether the current key < next key is correct - SubAir::eval( - self.is_less_than_tuple_air(), - &mut builder.when_transition(), - is_less_than_tuple_cols.io, - is_less_than_tuple_cols.aux, - ); - - // if the next row is allocated, then the current row index must be less than the next row index - builder - .when_transition() - .assert_zero(next_cols.is_alloc * (AB::Expr::one() - local_cols.less_than_next_idx)); + // Making sure the page is in the proper format + Air::eval(&self.final_page_air, builder); } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs index 8ede9a08b4..a8515263d3 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs @@ -1,7 +1,4 @@ -use crate::{ - is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupleIOCols}, - sub_chip::SubAirBridge, -}; +use crate::sub_chip::SubAirBridge; use super::columns::PageIndexScanOutputCols; use afs_stark_backend::interaction::{AirBridge, Interaction}; @@ -13,26 +10,16 @@ use super::PageIndexScanOutputAir; impl AirBridge for PageIndexScanOutputAir { // we receive the rows that satisfy the predicate fn receives(&self) -> Vec> { - let num_cols = PageIndexScanOutputCols::::get_width( - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); + let num_cols = PageIndexScanOutputCols::::get_width(self.final_page_air.clone()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = PageIndexScanOutputCols::::from_slice( - &all_cols, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); + let cols_numbered = + PageIndexScanOutputCols::::from_slice(&all_cols, self.final_page_air.clone()); let mut cols = vec![]; - cols.push(cols_numbered.is_alloc); - cols.extend(cols_numbered.idx.clone()); - cols.extend(cols_numbered.data); + cols.push(cols_numbered.final_page_cols.page_cols.is_alloc); + cols.extend(cols_numbered.final_page_cols.page_cols.idx.clone()); + cols.extend(cols_numbered.final_page_cols.page_cols.data); let virtual_cols = cols .iter() @@ -41,39 +28,19 @@ impl AirBridge for PageIndexScanOutputAir { vec![Interaction { fields: virtual_cols, - count: VirtualPairCol::single_main(cols_numbered.is_alloc), - argument_index: self.bus_index, + count: VirtualPairCol::single_main(cols_numbered.final_page_cols.page_cols.is_alloc), + argument_index: self.page_bus_index, }] } // we send range checks that are from the IsLessThanTuple subchip fn sends(&self) -> Vec> { - let num_cols = PageIndexScanOutputCols::::get_width( - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); + let num_cols = PageIndexScanOutputCols::::get_width(self.final_page_air.clone()); let all_cols = (0..num_cols).collect::>(); - let cols_numbered = PageIndexScanOutputCols::::from_slice( - &all_cols, - self.idx_len, - self.data_len, - self.is_less_than_tuple_air().limb_bits().clone(), - self.is_less_than_tuple_air().decomp(), - ); + let my_final_page_cols = + PageIndexScanOutputCols::::from_slice(&all_cols, self.final_page_air.clone()); - // range check the decompositions of x within aux columns; here the io doesn't matter - let is_less_than_tuple_cols = IsLessThanTupleCols { - io: IsLessThanTupleIOCols { - x: cols_numbered.idx.clone(), - y: cols_numbered.idx.clone(), - tuple_less_than: cols_numbered.less_than_next_idx, - }, - aux: cols_numbered.is_less_than_tuple_aux, - }; - - SubAirBridge::::sends(&self.is_less_than_tuple_air, is_less_than_tuple_cols) + SubAirBridge::sends(&self.final_page_air, my_final_page_cols.final_page_cols) } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs index 317f5d41b0..7a21f29e20 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs @@ -1,45 +1,33 @@ -use crate::is_less_than_tuple::columns::IsLessThanTupleAuxCols; +use crate::{ + final_page::{columns::FinalPageCols, FinalPageAir}, + is_less_than_tuple::columns::IsLessThanTupleAuxCols, +}; pub struct PageIndexScanOutputCols { - pub is_alloc: T, - pub idx: Vec, - pub data: Vec, - - pub less_than_next_idx: T, - pub is_less_than_tuple_aux: IsLessThanTupleAuxCols, + pub final_page_cols: FinalPageCols, } impl PageIndexScanOutputCols { - pub fn from_slice( - slc: &[T], - idx_len: usize, - data_len: usize, - idx_limb_bits: Vec, - decomp: usize, - ) -> Self { + pub fn from_slice(slc: &[T], final_page_air: FinalPageAir) -> Self { Self { - is_alloc: slc[0].clone(), - idx: slc[1..idx_len + 1].to_vec(), - data: slc[idx_len + 1..idx_len + data_len + 1].to_vec(), - less_than_next_idx: slc[idx_len + data_len + 1].clone(), - is_less_than_tuple_aux: IsLessThanTupleAuxCols::from_slice( - &slc[idx_len + data_len + 2..], - idx_limb_bits, - decomp, - idx_len, + final_page_cols: FinalPageCols::from_slice( + slc, + final_page_air.idx_len, + final_page_air.data_len, + final_page_air.idx_limb_bits, + final_page_air.idx_decomp, ), } } - pub fn get_width( - idx_len: usize, - data_len: usize, - idx_limb_bits: Vec, - decomp: usize, - ) -> usize { - 1 + idx_len - + data_len + pub fn get_width(final_page_air: FinalPageAir) -> usize { + 1 + final_page_air.idx_len + + final_page_air.data_len + 1 - + IsLessThanTupleAuxCols::::get_width(idx_limb_bits, decomp, idx_len) + + IsLessThanTupleAuxCols::::get_width( + vec![final_page_air.idx_limb_bits; final_page_air.idx_len], + final_page_air.idx_decomp, + final_page_air.idx_len, + ) } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs index 0d8af2e405..5b6a321346 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/mod.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/mod.rs @@ -2,27 +2,19 @@ use std::sync::Arc; use getset::Getters; -use crate::{ - is_less_than_tuple::{columns::IsLessThanTupleAuxCols, IsLessThanTupleAir}, - range_gate::RangeCheckerGateChip, -}; +use crate::{final_page::FinalPageAir, range_gate::RangeCheckerGateChip}; pub mod air; pub mod bridge; pub mod columns; pub mod trace; -#[derive(Default, Getters)] +#[derive(Getters)] pub struct PageIndexScanOutputAir { - /// The bus index for sends to range chip - pub bus_index: usize, - /// The length of each index in the page table - pub idx_len: usize, - /// The length of each data entry in the page table - pub data_len: usize, + /// The bus index for page row receives + pub page_bus_index: usize, - #[getset(get = "pub")] - is_less_than_tuple_air: IsLessThanTupleAir, + pub final_page_air: FinalPageAir, } /// This chip receives rows from the PageIndexScanInputChip and constrains that: @@ -37,34 +29,34 @@ pub struct PageIndexScanOutputChip { impl PageIndexScanOutputChip { pub fn new( - bus_index: usize, + page_bus_index: usize, idx_len: usize, data_len: usize, - idx_limb_bits: Vec, + idx_limb_bits: usize, decomp: usize, range_checker: Arc, ) -> Self { Self { air: PageIndexScanOutputAir { - bus_index, - idx_len, - data_len, - is_less_than_tuple_air: IsLessThanTupleAir::new(bus_index, idx_limb_bits, decomp), + page_bus_index, + final_page_air: FinalPageAir::new( + range_checker.bus_index(), + idx_len, + data_len, + idx_limb_bits, + decomp, + ), }, range_checker, } } pub fn page_width(&self) -> usize { - 1 + self.air.idx_len + self.air.data_len + 1 + self.air.final_page_air.idx_len + self.air.final_page_air.data_len } pub fn aux_width(&self) -> usize { - 1 + IsLessThanTupleAuxCols::::get_width( - self.air.is_less_than_tuple_air().limb_bits(), - self.air.is_less_than_tuple_air().decomp(), - self.air.idx_len, - ) + self.air.final_page_air.aux_width() } pub fn air_width(&self) -> usize { diff --git a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs index eff503b8c5..7a676cf597 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/trace.rs @@ -2,7 +2,7 @@ use p3_field::{AbstractField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; -use crate::{common::page::Page, sub_chip::LocalTraceInstructions}; +use crate::common::page::Page; use super::PageIndexScanOutputChip; @@ -20,33 +20,8 @@ impl PageIndexScanOutputChip { where Val: AbstractField + PrimeField64, { - let mut rows: Vec> = vec![]; - - for i in 0..page.rows.len() { - let page_row = page[i].clone(); - let next_page: Vec = if i == page.rows.len() - 1 { - vec![0; 1 + self.air.idx_len + self.air.data_len] - } else { - page.rows[i + 1].to_vec() - }; - - let mut row: Vec> = vec![]; - - let is_less_than_tuple_trace = LocalTraceInstructions::generate_trace_row( - self.air.is_less_than_tuple_air(), - ( - page_row.to_vec()[1..1 + self.air.idx_len].to_vec(), - next_page[1..1 + self.air.idx_len].to_vec(), - self.range_checker.clone(), - ), - ) - .flatten(); - - row.extend_from_slice(&is_less_than_tuple_trace[2 * self.air.idx_len..]); - - rows.extend_from_slice(&row); - } - - RowMajorMatrix::new(rows, self.aux_width()) + self.air + .final_page_air + .gen_aux_trace::(page, self.range_checker.clone()) } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 9ef15ccf7f..a0651e2e34 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -25,7 +25,7 @@ fn index_scan_test( x: Vec, idx_len: usize, data_len: usize, - idx_limb_bits: Vec, + idx_limb_bits: usize, idx_decomp: usize, page_controller: &mut PageController, trace_builder: &mut TraceCommitmentBuilder, @@ -151,11 +151,12 @@ fn generate_pk( #[test] fn test_single_page_index_scan_lt() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -163,11 +164,12 @@ fn test_single_page_index_scan_lt() { let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, Comp::Lt, ); @@ -212,11 +214,12 @@ fn test_single_page_index_scan_lt() { #[test] fn test_single_page_index_scan_lte() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -224,11 +227,12 @@ fn test_single_page_index_scan_lte() { let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, Comp::Lte, ); @@ -273,11 +277,12 @@ fn test_single_page_index_scan_lte() { #[test] fn test_single_page_index_scan_eq() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -285,11 +290,12 @@ fn test_single_page_index_scan_eq() { let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, Comp::Eq, ); @@ -334,11 +340,12 @@ fn test_single_page_index_scan_eq() { #[test] fn test_single_page_index_scan_gte() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -346,11 +353,12 @@ fn test_single_page_index_scan_gte() { let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, Comp::Gte, ); @@ -395,11 +403,12 @@ fn test_single_page_index_scan_gte() { #[test] fn test_single_page_index_scan_gt() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -407,11 +416,12 @@ fn test_single_page_index_scan_gt() { let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, Comp::Gt, ); @@ -456,11 +466,12 @@ fn test_single_page_index_scan_gt() { #[test] fn test_single_page_index_scan_wrong_order() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -470,11 +481,12 @@ fn test_single_page_index_scan_wrong_order() { let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, cmp, ); @@ -529,11 +541,12 @@ fn test_single_page_index_scan_wrong_order() { #[test] fn test_single_page_index_scan_unsorted() { - let bus_index: usize = 0; + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; let idx_len: usize = 2; let data_len: usize = 3; let decomp: usize = 8; - let limb_bits: Vec = vec![16, 16]; + let limb_bits: usize = 16; let range_max: u32 = 1 << decomp; let log_page_height = 1; @@ -543,11 +556,12 @@ fn test_single_page_index_scan_unsorted() { let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - bus_index, + page_bus_index, + range_bus_index, idx_len, data_len, range_max, - limb_bits.clone(), + limb_bits, decomp, cmp, ); @@ -599,3 +613,78 @@ fn test_single_page_index_scan_unsorted() { "Expected verification to fail, but it passed" ); } + +#[test] +fn test_single_page_index_scan_wrong_answer() { + let page_bus_index: usize = 0; + let range_bus_index: usize = 1; + let idx_len: usize = 2; + let data_len: usize = 3; + let decomp: usize = 8; + let limb_bits: usize = 16; + let range_max: u32 = 1 << decomp; + + let log_page_height = 1; + let page_height = 1 << log_page_height; + let page_width = 1 + idx_len + data_len; + + let cmp = Comp::Lt; + + let mut page_controller: PageController = PageController::new( + page_bus_index, + range_bus_index, + idx_len, + data_len, + range_max, + limb_bits, + decomp, + cmp, + ); + + let page: Vec> = vec![ + vec![1, 2883, 7769, 51171, 3989, 12770], + vec![1, 443, 376, 22278, 13998, 58327], + ]; + let page = Page::from_2d_vec(&page, idx_len, data_len); + + let x: Vec = vec![2177, 5880]; + + let page_output = vec![ + vec![1, 2883, 7769, 51171, 3989, 12770], + vec![0, 0, 0, 0, 0, 0], + ]; + let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + + let (engine, partial_pk) = generate_pk( + &mut page_controller, + log_page_height, + page_width, + page_height, + idx_len, + decomp, + ); + + let prover = MultiTraceStarkProver::new(&engine.config); + let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); + + USE_DEBUG_BUILDER.with(|debug| { + *debug.lock().unwrap() = false; + }); + assert_eq!( + index_scan_test( + &engine, + page, + page_output, + x, + idx_len, + data_len, + limb_bits, + decomp, + &mut page_controller, + &mut trace_builder, + &partial_pk, + ), + Err(VerificationError::NonZeroCumulativeSum), + "Expected verification to fail, but it passed" + ); +} From d269ad72cfc287e0d2ffb660bb60b542fc9ebdff Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:57:13 -0400 Subject: [PATCH 45/46] chore: refactor index_scan_test --- .../page_controller/mod.rs | 168 +++++++++++++- chips/src/single_page_index_scan/tests.rs | 211 ++---------------- 2 files changed, 179 insertions(+), 200 deletions(-) diff --git a/chips/src/single_page_index_scan/page_controller/mod.rs b/chips/src/single_page_index_scan/page_controller/mod.rs index d5f81e737d..55a5be9804 100644 --- a/chips/src/single_page_index_scan/page_controller/mod.rs +++ b/chips/src/single_page_index_scan/page_controller/mod.rs @@ -1,12 +1,21 @@ use std::sync::Arc; use afs_stark_backend::{ - config::Com, - prover::trace::{ProverTraceData, TraceCommitter}, + config::{Com, PcsProof, PcsProverData}, + keygen::{ + types::{MultiStarkPartialProvingKey, MultiStarkPartialVerifyingKey}, + MultiStarkKeygenBuilder, + }, + prover::{ + trace::{ProverTraceData, TraceCommitmentBuilder, TraceCommitter}, + types::Proof, + }, + verifier::VerificationError, }; +use afs_test_utils::engine::StarkEngine; use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_matrix::dense::DenseMatrix; -use p3_uni_stark::{StarkGenericConfig, Val}; +use p3_uni_stark::{Domain, StarkGenericConfig, Val}; use crate::{common::page::Page, range_gate::RangeCheckerGateChip}; @@ -30,6 +39,9 @@ where input_commitment: Option>, output_commitment: Option>, + page_traces: Vec>>, + prover_data: Vec>, + pub range_checker: Arc, } @@ -73,6 +85,8 @@ where output_chip_aux_trace: None, input_commitment: None, output_commitment: None, + page_traces: vec![], + prover_data: vec![], range_checker, } } @@ -262,6 +276,139 @@ where Page::from_2d_vec(&output, page.rows[0].idx.len(), page.rows[0].data.len()) } + pub fn set_up_keygen_builder( + &self, + keygen_builder: &mut MultiStarkKeygenBuilder, + page_width: usize, + page_height: usize, + idx_len: usize, + decomp: usize, + ) { + let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); + let input_page_aux_ptr = keygen_builder.add_main_matrix(self.input_chip.aux_width()); + let output_page_aux_ptr = keygen_builder.add_main_matrix(self.output_chip.aux_width()); + let range_checker_ptr = keygen_builder.add_main_matrix(self.range_checker.air_width()); + + keygen_builder.add_partitioned_air( + &self.input_chip.air, + page_height, + idx_len, + vec![input_page_ptr, input_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &self.output_chip.air, + page_height, + 0, + vec![output_page_ptr, output_page_aux_ptr], + ); + + keygen_builder.add_partitioned_air( + &self.range_checker.air, + 1 << decomp, + 0, + vec![range_checker_ptr], + ); + } + + pub fn prove( + &mut self, + engine: &dyn StarkEngine, + partial_pk: &MultiStarkPartialProvingKey, + trace_builder: &mut TraceCommitmentBuilder, + x: Vec, + idx_decomp: usize, + ) -> Proof + where + Val: PrimeField, + Domain: Send + Sync, + SC::Pcs: Sync, + Domain: Send + Sync, + PcsProverData: Send + Sync, + Com: Send + Sync, + SC::Challenge: Send + Sync, + PcsProof: Send + Sync, + { + let page_traces = self.page_traces.clone(); + + let input_chip_aux_trace = self.input_chip_aux_trace(); + let output_chip_aux_trace = self.output_chip_aux_trace(); + let range_checker_trace = self.range_checker_trace(); + + // Clearing the range_checker counts + self.update_range_checker(idx_decomp); + + trace_builder.clear(); + + trace_builder.load_cached_trace(page_traces[0].clone(), self.prover_data.remove(0)); + trace_builder.load_cached_trace(page_traces[1].clone(), self.prover_data.remove(0)); + trace_builder.load_trace(input_chip_aux_trace); + trace_builder.load_trace(output_chip_aux_trace); + trace_builder.load_trace(range_checker_trace); + + trace_builder.commit_current(); + + let partial_vk = partial_pk.partial_vk(); + + let main_trace_data = trace_builder.view( + &partial_vk, + vec![ + &self.input_chip.air, + &self.output_chip.air, + &self.range_checker.air, + ], + ); + + let pis = vec![ + x.iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(), + vec![], + vec![], + ]; + + let prover = engine.prover(); + let mut challenger = engine.new_challenger(); + + prover.prove(&mut challenger, partial_pk, main_trace_data, &pis) + } + + /// This function takes a proof (returned by the prove function) and verifies it + pub fn verify( + &self, + engine: &dyn StarkEngine, + partial_vk: MultiStarkPartialVerifyingKey, + proof: Proof, + x: Vec, + ) -> Result<(), VerificationError> + where + Val: PrimeField, + { + let verifier = engine.verifier(); + + let pis = vec![ + x.iter() + .map(|x| Val::::from_canonical_u32(*x)) + .collect(), + vec![], + vec![], + ]; + + let mut challenger = engine.new_challenger(); + verifier.verify( + &mut challenger, + partial_vk, + vec![ + &self.input_chip.air, + &self.output_chip.air, + &self.range_checker.air, + ], + proof, + &pis, + ) + } + #[allow(clippy::too_many_arguments)] pub fn load_page( &mut self, @@ -273,8 +420,7 @@ where idx_limb_bits: usize, idx_decomp: usize, trace_committer: &mut TraceCommitter, - ) -> (Vec>>, Vec>) - where + ) where Val: PrimeField, { // idx_decomp can't change between different pages since range_checker depends on it @@ -317,12 +463,10 @@ where self.input_commitment = Some(prover_data[0].commit.clone()); self.output_commitment = Some(prover_data[1].commit.clone()); - ( - vec![ - self.input_chip_trace.clone().unwrap(), - self.output_chip_trace.clone().unwrap(), - ], - prover_data, - ) + self.page_traces = vec![ + self.input_chip_trace.clone().unwrap(), + self.output_chip_trace.clone().unwrap(), + ]; + self.prover_data = prover_data; } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index a0651e2e34..7d464849ce 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -1,17 +1,12 @@ use afs_stark_backend::{ - keygen::{types::MultiStarkPartialProvingKey, MultiStarkKeygenBuilder}, + keygen::MultiStarkKeygenBuilder, prover::{trace::TraceCommitmentBuilder, MultiTraceStarkProver, USE_DEBUG_BUILDER}, verifier::VerificationError, }; -use afs_test_utils::{ - config::{ - self, - baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, - }, - engine::StarkEngine, +use afs_test_utils::config::{ + self, + baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, }; -use p3_baby_bear::BabyBear; -use p3_field::AbstractField; use crate::common::page::Page; @@ -29,12 +24,11 @@ fn index_scan_test( idx_decomp: usize, page_controller: &mut PageController, trace_builder: &mut TraceCommitmentBuilder, - partial_pk: &MultiStarkPartialProvingKey, ) -> Result<(), VerificationError> { let page_height = page.rows.len(); assert!(page_height > 0); - let (page_traces, mut prover_data) = page_controller.load_page( + page_controller.load_page( page.clone(), page_output.clone(), x.clone(), @@ -45,108 +39,24 @@ fn index_scan_test( &mut trace_builder.committer, ); - let input_chip_aux_trace = page_controller.input_chip_aux_trace(); - let output_chip_aux_trace = page_controller.output_chip_aux_trace(); - let range_checker_trace = page_controller.range_checker_trace(); - - // Clearing the range_checker counts - page_controller.update_range_checker(idx_decomp); - - trace_builder.clear(); - - trace_builder.load_cached_trace(page_traces[0].clone(), prover_data.remove(0)); - trace_builder.load_cached_trace(page_traces[1].clone(), prover_data.remove(0)); - trace_builder.load_trace(input_chip_aux_trace); - trace_builder.load_trace(output_chip_aux_trace); - trace_builder.load_trace(range_checker_trace); - - trace_builder.commit_current(); - - let partial_vk = partial_pk.partial_vk(); - - let main_trace_data = trace_builder.view( - &partial_vk, - vec![ - &page_controller.input_chip.air, - &page_controller.output_chip.air, - &page_controller.range_checker.air, - ], - ); - - let pis = vec![ - x.iter().map(|x| BabyBear::from_canonical_u32(*x)).collect(), - vec![], - vec![], - ]; - - let prover = engine.prover(); - let verifier = engine.verifier(); - - let mut challenger = engine.new_challenger(); - let proof = prover.prove(&mut challenger, partial_pk, main_trace_data, &pis); - - let mut challenger = engine.new_challenger(); - - verifier.verify( - &mut challenger, - partial_vk, - vec![ - &page_controller.input_chip.air, - &page_controller.output_chip.air, - &page_controller.range_checker.air, - ], - proof, - &pis, - ) -} - -fn generate_pk( - page_controller: &mut PageController, - log_page_height: usize, - page_width: usize, - page_height: usize, - idx_len: usize, - decomp: usize, -) -> ( - BabyBearPoseidon2Engine, - MultiStarkPartialProvingKey, -) { - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); - let mut keygen_builder = MultiStarkKeygenBuilder::new(&engine.config); + let page_width = 1 + idx_len + data_len; + let page_height = page.rows.len(); - let input_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let output_page_ptr = keygen_builder.add_cached_main_matrix(page_width); - let input_page_aux_ptr = keygen_builder.add_main_matrix(page_controller.input_chip.aux_width()); - let output_page_aux_ptr = - keygen_builder.add_main_matrix(page_controller.output_chip.aux_width()); - let range_checker_ptr = - keygen_builder.add_main_matrix(page_controller.range_checker.air_width()); - - keygen_builder.add_partitioned_air( - &page_controller.input_chip.air, + page_controller.set_up_keygen_builder( + &mut keygen_builder, + page_width, page_height, idx_len, - vec![input_page_ptr, input_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.output_chip.air, - page_height, - 0, - vec![output_page_ptr, output_page_aux_ptr], - ); - - keygen_builder.add_partitioned_air( - &page_controller.range_checker.air, - 1 << decomp, - 0, - vec![range_checker_ptr], + idx_decomp, ); let partial_pk = keygen_builder.generate_partial_pk(); - (engine, partial_pk) + let proof = page_controller.prove(engine, &partial_pk, trace_builder, x.clone(), idx_decomp); + let partial_vk = partial_pk.partial_vk(); + + page_controller.verify(engine, partial_vk, proof, x.clone()) } #[test] @@ -160,7 +70,6 @@ fn test_single_page_index_scan_lt() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( @@ -184,14 +93,7 @@ fn test_single_page_index_scan_lt() { let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lt); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -207,7 +109,6 @@ fn test_single_page_index_scan_lt() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ) .expect("Verification failed"); } @@ -223,7 +124,6 @@ fn test_single_page_index_scan_lte() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( @@ -247,14 +147,7 @@ fn test_single_page_index_scan_lte() { let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lte); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -270,7 +163,6 @@ fn test_single_page_index_scan_lte() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ) .expect("Verification failed"); } @@ -286,7 +178,6 @@ fn test_single_page_index_scan_eq() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( @@ -310,14 +201,7 @@ fn test_single_page_index_scan_eq() { let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Eq); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -333,7 +217,6 @@ fn test_single_page_index_scan_eq() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ) .expect("Verification failed"); } @@ -349,7 +232,6 @@ fn test_single_page_index_scan_gte() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( @@ -373,14 +255,7 @@ fn test_single_page_index_scan_gte() { let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gte); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -396,7 +271,6 @@ fn test_single_page_index_scan_gte() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ) .expect("Verification failed"); } @@ -412,7 +286,6 @@ fn test_single_page_index_scan_gt() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; let page_width = 1 + idx_len + data_len; let mut page_controller: PageController = PageController::new( @@ -436,14 +309,7 @@ fn test_single_page_index_scan_gt() { let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gt); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -459,7 +325,6 @@ fn test_single_page_index_scan_gt() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ) .expect("Verification failed"); } @@ -475,8 +340,6 @@ fn test_single_page_index_scan_wrong_order() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; - let page_width = 1 + idx_len + data_len; let cmp = Comp::Lt; @@ -505,14 +368,7 @@ fn test_single_page_index_scan_wrong_order() { ]; let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -532,7 +388,6 @@ fn test_single_page_index_scan_wrong_order() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" @@ -550,8 +405,6 @@ fn test_single_page_index_scan_unsorted() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; - let page_width = 1 + idx_len + data_len; let cmp = Comp::Lt; @@ -580,14 +433,7 @@ fn test_single_page_index_scan_unsorted() { ]; let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -607,7 +453,6 @@ fn test_single_page_index_scan_unsorted() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" @@ -625,8 +470,6 @@ fn test_single_page_index_scan_wrong_answer() { let range_max: u32 = 1 << decomp; let log_page_height = 1; - let page_height = 1 << log_page_height; - let page_width = 1 + idx_len + data_len; let cmp = Comp::Lt; @@ -655,14 +498,7 @@ fn test_single_page_index_scan_wrong_answer() { ]; let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); - let (engine, partial_pk) = generate_pk( - &mut page_controller, - log_page_height, - page_width, - page_height, - idx_len, - decomp, - ); + let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -682,7 +518,6 @@ fn test_single_page_index_scan_wrong_answer() { decomp, &mut page_controller, &mut trace_builder, - &partial_pk, ), Err(VerificationError::NonZeroCumulativeSum), "Expected verification to fail, but it passed" From 6d3f039c249b5484d1ea5671147a64a440560f34 Mon Sep 17 00:00:00 2001 From: bfan <76703988+bfan05@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:16:09 -0400 Subject: [PATCH 46/46] chore: address comments --- .../page_index_scan_input/air.rs | 4 +- .../page_index_scan_input/trace.rs | 18 +- .../page_index_scan_output/air.rs | 2 +- .../page_index_scan_output/bridge.rs | 8 +- .../page_index_scan_output/columns.rs | 18 +- chips/src/single_page_index_scan/tests.rs | 330 +++++++----------- 6 files changed, 151 insertions(+), 229 deletions(-) diff --git a/chips/src/single_page_index_scan/page_index_scan_input/air.rs b/chips/src/single_page_index_scan/page_index_scan_input/air.rs index 9e75f4e7ab..5ba8021620 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/air.rs @@ -242,7 +242,7 @@ where // constrain the indicator that we used to check the strict comp is correct SubAir::eval( is_less_than_tuple_air, - &mut builder.when_transition(), + builder, is_less_than_tuple_cols.io, is_less_than_tuple_cols.aux, ); @@ -261,7 +261,7 @@ where // constrain the indicator that we used to check the strict comp is correct SubAir::eval( is_less_than_tuple_air, - &mut builder.when_transition(), + builder, is_less_than_tuple_cols.io, is_less_than_tuple_cols.aux, ); diff --git a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs index 8a7b44c124..80980b2448 100644 --- a/chips/src/single_page_index_scan/page_index_scan_input/trace.rs +++ b/chips/src/single_page_index_scan/page_index_scan_input/trace.rs @@ -122,11 +122,9 @@ impl PageIndexScanInputChip { is_less_than_tuple_air, .. }) => Some( - LocalTraceInstructions::generate_trace_row( - is_less_than_tuple_air, - (x.clone(), idx.clone(), self.range_checker.clone()), - ) - .flatten(), + is_less_than_tuple_air + .generate_trace_row((x.clone(), idx.clone(), self.range_checker.clone())) + .flatten(), ), _ => None, }; @@ -141,9 +139,8 @@ impl PageIndexScanInputChip { | PageIndexScanInputAirVariants::Gte(NonStrictCompAir { is_equal_vec_air, .. }) => Some( - LocalTraceInstructions::generate_trace_row( - is_equal_vec_air, - ( + is_equal_vec_air + .generate_trace_row(( idx.clone() .into_iter() .map(Val::::from_canonical_u32) @@ -152,9 +149,8 @@ impl PageIndexScanInputChip { .into_iter() .map(Val::::from_canonical_u32) .collect(), - ), - ) - .flatten(), + )) + .flatten(), ), _ => None, }; diff --git a/chips/src/single_page_index_scan/page_index_scan_output/air.rs b/chips/src/single_page_index_scan/page_index_scan_output/air.rs index 3e0b60fb10..5d006027f7 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/air.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/air.rs @@ -12,7 +12,7 @@ impl AirConfig for PageIndexScanOutputAir { impl BaseAir for PageIndexScanOutputAir { fn width(&self) -> usize { - PageIndexScanOutputCols::::get_width(self.final_page_air.clone()) + PageIndexScanOutputCols::::get_width(&self.final_page_air) } } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs index a8515263d3..078d46fdd1 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/bridge.rs @@ -10,11 +10,11 @@ use super::PageIndexScanOutputAir; impl AirBridge for PageIndexScanOutputAir { // we receive the rows that satisfy the predicate fn receives(&self) -> Vec> { - let num_cols = PageIndexScanOutputCols::::get_width(self.final_page_air.clone()); + let num_cols = PageIndexScanOutputCols::::get_width(&self.final_page_air); let all_cols = (0..num_cols).collect::>(); let cols_numbered = - PageIndexScanOutputCols::::from_slice(&all_cols, self.final_page_air.clone()); + PageIndexScanOutputCols::::from_slice(&all_cols, &self.final_page_air); let mut cols = vec![]; cols.push(cols_numbered.final_page_cols.page_cols.is_alloc); @@ -35,11 +35,11 @@ impl AirBridge for PageIndexScanOutputAir { // we send range checks that are from the IsLessThanTuple subchip fn sends(&self) -> Vec> { - let num_cols = PageIndexScanOutputCols::::get_width(self.final_page_air.clone()); + let num_cols = PageIndexScanOutputCols::::get_width(&self.final_page_air); let all_cols = (0..num_cols).collect::>(); let my_final_page_cols = - PageIndexScanOutputCols::::from_slice(&all_cols, self.final_page_air.clone()); + PageIndexScanOutputCols::::from_slice(&all_cols, &self.final_page_air); SubAirBridge::sends(&self.final_page_air, my_final_page_cols.final_page_cols) } diff --git a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs index 7a21f29e20..77435ec359 100644 --- a/chips/src/single_page_index_scan/page_index_scan_output/columns.rs +++ b/chips/src/single_page_index_scan/page_index_scan_output/columns.rs @@ -1,14 +1,11 @@ -use crate::{ - final_page::{columns::FinalPageCols, FinalPageAir}, - is_less_than_tuple::columns::IsLessThanTupleAuxCols, -}; +use crate::final_page::{columns::FinalPageCols, FinalPageAir}; pub struct PageIndexScanOutputCols { pub final_page_cols: FinalPageCols, } impl PageIndexScanOutputCols { - pub fn from_slice(slc: &[T], final_page_air: FinalPageAir) -> Self { + pub fn from_slice(slc: &[T], final_page_air: &FinalPageAir) -> Self { Self { final_page_cols: FinalPageCols::from_slice( slc, @@ -20,14 +17,7 @@ impl PageIndexScanOutputCols { } } - pub fn get_width(final_page_air: FinalPageAir) -> usize { - 1 + final_page_air.idx_len - + final_page_air.data_len - + 1 - + IsLessThanTupleAuxCols::::get_width( - vec![final_page_air.idx_limb_bits; final_page_air.idx_len], - final_page_air.idx_decomp, - final_page_air.idx_len, - ) + pub fn get_width(final_page_air: &FinalPageAir) -> usize { + final_page_air.air_width() } } diff --git a/chips/src/single_page_index_scan/tests.rs b/chips/src/single_page_index_scan/tests.rs index 7d464849ce..72862228cd 100644 --- a/chips/src/single_page_index_scan/tests.rs +++ b/chips/src/single_page_index_scan/tests.rs @@ -12,6 +12,17 @@ use crate::common::page::Page; use super::{page_controller::PageController, page_index_scan_input::Comp}; +const PAGE_BUS_INDEX: usize = 0; +const RANGE_BUS_INDEX: usize = 1; +const IDX_LEN: usize = 2; +const DATA_LEN: usize = 3; +const DECOMP: usize = 8; +const LIMB_BITS: usize = 16; +const RANGE_MAX: u32 = 1 << DECOMP; + +const LOG_PAGE_HEIGHT: usize = 1; +const PAGE_WIDTH: usize = 1 + IDX_LEN + DATA_LEN; + #[allow(clippy::too_many_arguments)] fn index_scan_test( engine: &BabyBearPoseidon2Engine, @@ -61,39 +72,30 @@ fn index_scan_test( #[test] fn test_single_page_index_scan_lt() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_width = 1 + idx_len + data_len; + let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, - Comp::Lt, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, + cmp.clone(), ); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lt); + let page_output = page_controller.gen_output(page.clone(), x.clone(), PAGE_WIDTH, cmp); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -103,10 +105,10 @@ fn test_single_page_index_scan_lt() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ) @@ -115,39 +117,30 @@ fn test_single_page_index_scan_lt() { #[test] fn test_single_page_index_scan_lte() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_width = 1 + idx_len + data_len; + let cmp = Comp::Lte; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, - Comp::Lte, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, + cmp.clone(), ); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2177, 5880, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Lte); + let page_output = page_controller.gen_output(page.clone(), x.clone(), PAGE_WIDTH, cmp); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -157,10 +150,10 @@ fn test_single_page_index_scan_lte() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ) @@ -169,39 +162,30 @@ fn test_single_page_index_scan_lte() { #[test] fn test_single_page_index_scan_eq() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_width = 1 + idx_len + data_len; + let cmp = Comp::Eq; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, - Comp::Eq, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, + cmp.clone(), ); let page: Vec> = vec![ vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![443, 376]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Eq); + let page_output = page_controller.gen_output(page.clone(), x.clone(), PAGE_WIDTH, cmp); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -211,10 +195,10 @@ fn test_single_page_index_scan_eq() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ) @@ -223,39 +207,30 @@ fn test_single_page_index_scan_eq() { #[test] fn test_single_page_index_scan_gte() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_width = 1 + idx_len + data_len; + let cmp = Comp::Gte; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, - Comp::Gte, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, + cmp.clone(), ); let page: Vec> = vec![ vec![1, 2177, 5880, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gte); + let page_output = page_controller.gen_output(page.clone(), x.clone(), PAGE_WIDTH, cmp); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -265,10 +240,10 @@ fn test_single_page_index_scan_gte() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ) @@ -277,39 +252,30 @@ fn test_single_page_index_scan_gte() { #[test] fn test_single_page_index_scan_gt() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let page_width = 1 + idx_len + data_len; + let cmp = Comp::Gt; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, - Comp::Gt, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, + cmp.clone(), ); let page: Vec> = vec![ vec![1, 2203, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; - let page_output = page_controller.gen_output(page.clone(), x.clone(), page_width, Comp::Gt); + let page_output = page_controller.gen_output(page.clone(), x.clone(), PAGE_WIDTH, cmp); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -319,10 +285,10 @@ fn test_single_page_index_scan_gt() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ) @@ -331,26 +297,16 @@ fn test_single_page_index_scan_gt() { #[test] fn test_single_page_index_scan_wrong_order() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, cmp, ); @@ -358,7 +314,7 @@ fn test_single_page_index_scan_wrong_order() { vec![1, 443, 376, 22278, 13998, 58327], vec![1, 2883, 7769, 51171, 3989, 12770], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; @@ -366,9 +322,9 @@ fn test_single_page_index_scan_wrong_order() { vec![0, 0, 0, 0, 0, 0], vec![1, 443, 376, 22278, 13998, 58327], ]; - let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + let page_output = Page::from_2d_vec(&page_output, IDX_LEN, DATA_LEN); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -382,10 +338,10 @@ fn test_single_page_index_scan_wrong_order() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ), @@ -396,26 +352,16 @@ fn test_single_page_index_scan_wrong_order() { #[test] fn test_single_page_index_scan_unsorted() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, cmp, ); @@ -423,7 +369,7 @@ fn test_single_page_index_scan_unsorted() { vec![1, 2883, 7769, 51171, 3989, 12770], vec![1, 443, 376, 22278, 13998, 58327], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; @@ -431,9 +377,9 @@ fn test_single_page_index_scan_unsorted() { vec![0, 0, 0, 0, 0, 0], vec![1, 443, 376, 22278, 13998, 58327], ]; - let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + let page_output = Page::from_2d_vec(&page_output, IDX_LEN, DATA_LEN); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -447,10 +393,10 @@ fn test_single_page_index_scan_unsorted() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ), @@ -461,26 +407,16 @@ fn test_single_page_index_scan_unsorted() { #[test] fn test_single_page_index_scan_wrong_answer() { - let page_bus_index: usize = 0; - let range_bus_index: usize = 1; - let idx_len: usize = 2; - let data_len: usize = 3; - let decomp: usize = 8; - let limb_bits: usize = 16; - let range_max: u32 = 1 << decomp; - - let log_page_height = 1; - let cmp = Comp::Lt; let mut page_controller: PageController = PageController::new( - page_bus_index, - range_bus_index, - idx_len, - data_len, - range_max, - limb_bits, - decomp, + PAGE_BUS_INDEX, + RANGE_BUS_INDEX, + IDX_LEN, + DATA_LEN, + RANGE_MAX, + LIMB_BITS, + DECOMP, cmp, ); @@ -488,7 +424,7 @@ fn test_single_page_index_scan_wrong_answer() { vec![1, 2883, 7769, 51171, 3989, 12770], vec![1, 443, 376, 22278, 13998, 58327], ]; - let page = Page::from_2d_vec(&page, idx_len, data_len); + let page = Page::from_2d_vec(&page, IDX_LEN, DATA_LEN); let x: Vec = vec![2177, 5880]; @@ -496,9 +432,9 @@ fn test_single_page_index_scan_wrong_answer() { vec![1, 2883, 7769, 51171, 3989, 12770], vec![0, 0, 0, 0, 0, 0], ]; - let page_output = Page::from_2d_vec(&page_output, idx_len, data_len); + let page_output = Page::from_2d_vec(&page_output, IDX_LEN, DATA_LEN); - let engine = config::baby_bear_poseidon2::default_engine(log_page_height.max(decomp)); + let engine = config::baby_bear_poseidon2::default_engine(LOG_PAGE_HEIGHT.max(DECOMP)); let prover = MultiTraceStarkProver::new(&engine.config); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); @@ -512,10 +448,10 @@ fn test_single_page_index_scan_wrong_answer() { page, page_output, x, - idx_len, - data_len, - limb_bits, - decomp, + IDX_LEN, + DATA_LEN, + LIMB_BITS, + DECOMP, &mut page_controller, &mut trace_builder, ),