Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functionality inside Trace to obtain relation columns used in lookups #2334

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion msm/src/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use folding::expressions::FoldingColumnTrait;
use kimchi::circuits::expr::{CacheId, FormattedOutput};

/// Describe a generic indexed variable X_{i}.
#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash, Ord, PartialOrd)]
pub enum Column {
/// Columns related to the relation encoded in the circuit
Relation(usize),
Expand Down
27 changes: 24 additions & 3 deletions o1vm/src/keccak/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,12 @@ where
self.lookup_syscall_hash(step);
}

/// When in Absorb mode, reads Lookups containing the 136 bytes of the block of the preimage
/// When in Absorb mode:
/// Reads Lookups containing the 136 bytes of the block of the preimage
/// - if is_absorb, adds 136 lookups
/// - otherwise, adds 0 lookups
// TODO: optimize this by using a single lookup reusing PadSuffix
/// Uses the following columns:
/// - HashIndex, BlockIndex, SpongeBytes(0..136)
fn lookup_syscall_preimage(&mut self, step: Steps) {
for i in 0..RATE_IN_BYTES {
self.read_syscall(
Expand All @@ -481,9 +483,12 @@ where
}
}

/// When in Squeeze mode, writes a Lookup containing the 31byte output of the hash (excludes the MSB)
/// When in Squeeze mode:
/// Writes a Lookup containing the 31-byte output of the hash (excludes the MSB)
/// - if is_squeeze, adds 1 lookup
/// - otherwise, adds 0 lookups
/// Uses the following columns:
/// - HashIndex, SpongeBytes(1..32)
fn lookup_syscall_hash(&mut self, step: Steps) {
let bytes31 = (1..32).fold(Self::zero(), |acc, i| {
acc * Self::two_pow(8) + self.sponge_byte(i)
Expand All @@ -496,6 +501,10 @@ where
/// - if is_root, only adds 1 lookup
/// - if is_squeeze, only adds 1 lookup
/// - otherwise, adds 2 lookups
/// When not in Root step, uses the following columns:
/// - HashIndex, StepIndex, Input(0..100)
/// When not in Squeeze step, uses the following columns:
/// - HashIndex, StepIndex, Output(0..100)
fn lookup_steps(&mut self, step: Steps) {
// (if not a root) Output of previous step is input of current step
self.add_lookup(
Expand All @@ -512,6 +521,10 @@ where
/// Adds the 601 lookups required for the sponge
/// - 600 lookups if is_sponge()
/// - 1 extra lookup if is_pad()
/// When in Pad step, uses the following columns:
/// - PadLength, TwoToPad, PadSuffix(0..5)
/// When in Sponge step, uses the following columns:
/// - SpongeBytes(0..200), SpongeShifts(0..400)
fn lookups_sponge(&mut self, step: Steps) {
// PADDING LOOKUPS
// Power of two corresponds to 2^pad_length
Expand Down Expand Up @@ -546,6 +559,8 @@ where
}

/// Adds the 120 lookups required for Theta in the round
/// When in Round step, uses the following columns:
/// - ThetaRemainderC(i), ThetaDenseRotC(i), ThetaExpandRotC(i), ThetaDenseC(i), ThetaShiftsC(i)
fn lookups_round_theta(&mut self, step: Steps) {
for q in 0..QUARTERS {
for x in 0..DIM {
Expand All @@ -572,6 +587,8 @@ where
}

/// Adds the 700 lookups required for PiRho in the round
/// When in Round step, uses the following columns:
/// - PiRhoRemainderE(i), PiRhoQuotientE(i), PiRhoDenseRotE(i), PiRhoExpandRotE(i), PiRhoDenseE(i), PiRhoShiftsE(i)
fn lookups_round_pirho(&mut self, step: Steps) {
for q in 0..QUARTERS {
for x in 0..DIM {
Expand Down Expand Up @@ -601,6 +618,8 @@ where
}

/// Adds the 800 lookups required for Chi in the round
/// When in Round step, uses the following columns:
/// - ChiShiftsB(i), ChiShiftsSum(i)
fn lookups_round_chi(&mut self, step: Steps) {
let shifts_b = self.vec_shifts_b();
let shifts_sum = self.vec_shifts_sum();
Expand All @@ -612,6 +631,8 @@ where
}

/// Adds the 1 lookup required for Iota in the round
/// When in Round step, uses the following columns:
/// - RoundNumber, RoundConstants(0..4)
fn lookups_round_iota(&mut self, step: Steps) {
// Check round constants correspond with the current round
let round_constants = self.round_constants();
Expand Down
165 changes: 165 additions & 0 deletions o1vm/src/keccak/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,171 @@ fn test_regression_number_of_lookups_and_constraints_and_degree() {
}
}

#[test]
fn test_delayed_columns() {
    use crate::keccak::{KeccakColumn::*, DIM, QUARTERS};
    use kimchi::circuits::polynomials::keccak::constants::{SHIFTS, SHIFTS_LEN, STATE_LEN};
    use kimchi_msm::columns::ColumnIndexer;
    use std::collections::HashSet;

    let mut rng = o1_utils::tests::make_test_rng(None);
    let domain_size = 1 << 6;

    // Full trace with all possible steps of Keccak
    let keccak_trace = create_trace_all_steps(domain_size, &mut rng);

    // Columns read by the preimage-bytes lookup (Absorb mode).
    let absorb_cols = [
        vec![HashIndex, BlockIndex],
        (0..RATE_IN_BYTES).map(SpongeBytes).collect(),
    ]
    .concat();

    // Columns read by the hash-output lookup (Squeeze mode, MSB excluded).
    let squeeze_cols = [vec![HashIndex], (1..32).map(SpongeBytes).collect()].concat();

    // Columns of the inter-step lookup when the step is not Root.
    let not_root_cols = [
        vec![HashIndex, StepIndex],
        (0..STATE_LEN).map(Input).collect(),
    ]
    .concat();

    // Columns of the inter-step lookup when the step is not Squeeze.
    let not_squeeze_cols = [
        vec![HashIndex, StepIndex],
        (0..STATE_LEN).map(Output).collect(),
    ]
    .concat();

    // Columns of the padding lookups (Pad steps only).
    let pad_cols = [vec![PadLength, TwoToPad], (0..5).map(PadSuffix).collect()].concat();

    // Columns of the sponge lookups (all Sponge steps).
    let sponge_cols = [
        (0..200).map(SpongeBytes).collect::<Vec<_>>(),
        (0..SHIFTS_LEN).map(SpongeShifts).collect(),
    ]
    .concat();

    // Theta lookups
    let mut round_cols = (0..QUARTERS * DIM).map(ThetaRemainderC).collect::<Vec<_>>();
    round_cols.extend((0..QUARTERS * DIM).map(ThetaDenseRotC));
    round_cols.extend((0..QUARTERS * DIM).map(ThetaExpandRotC));
    round_cols.extend((0..QUARTERS * DIM).map(ThetaDenseC));
    round_cols.extend((0..QUARTERS * DIM * SHIFTS).map(ThetaShiftsC));
    // PiRho lookups
    round_cols.extend((0..QUARTERS * DIM * DIM).map(PiRhoRemainderE));
    round_cols.extend((0..QUARTERS * DIM * DIM).map(PiRhoQuotientE));
    round_cols.extend((0..QUARTERS * DIM * DIM).map(PiRhoDenseRotE));
    round_cols.extend((0..QUARTERS * DIM * DIM).map(PiRhoExpandRotE));
    round_cols.extend((0..QUARTERS * DIM * DIM).map(PiRhoDenseE));
    round_cols.extend((0..QUARTERS * DIM * DIM * SHIFTS).map(PiRhoShiftsE));
    // Chi lookups
    round_cols.extend((0..SHIFTS_LEN).map(ChiShiftsB));
    round_cols.extend((0..SHIFTS_LEN).map(ChiShiftsSum));
    // Iota lookups
    round_cols.push(RoundNumber);
    round_cols.extend((0..QUARTERS).map(RoundConstants));

    // Expected delayed-column sets per step selector.
    let first: HashSet<KeccakColumn> = HashSet::from_iter(
        [
            absorb_cols.clone(),
            not_squeeze_cols.clone(),
            sponge_cols.clone(),
        ]
        .concat(),
    );

    let middle: HashSet<KeccakColumn> = HashSet::from_iter(
        [
            absorb_cols.clone(),
            not_root_cols.clone(),
            not_squeeze_cols.clone(),
            sponge_cols.clone(),
        ]
        .concat(),
    );

    let last: HashSet<KeccakColumn> = HashSet::from_iter(
        [
            absorb_cols.clone(),
            not_root_cols.clone(),
            not_squeeze_cols.clone(),
            pad_cols.clone(),
            sponge_cols.clone(),
        ]
        .concat(),
    );

    let only: HashSet<KeccakColumn> = HashSet::from_iter(
        [
            absorb_cols.clone(),
            not_squeeze_cols.clone(),
            pad_cols.clone(),
            sponge_cols.clone(),
        ]
        .concat(),
    );

    let squeeze: HashSet<KeccakColumn> = HashSet::from_iter(
        [
            squeeze_cols.clone(),
            not_root_cols.clone(),
            sponge_cols.clone(),
        ]
        .concat(),
    );

    let round: HashSet<KeccakColumn> =
        HashSet::from_iter([round_cols, not_squeeze_cols, not_root_cols].concat());

    // Helper: the trace of `step` must delay exactly the columns in `expected`.
    // Checks membership of every expected column, then equality of cardinality
    // so no extra column can be delayed either.
    let check_delayed = |step, expected: &HashSet<KeccakColumn>, label: &str| {
        let delayed = &keccak_trace[step].delayed_columns;
        for col in expected.iter() {
            assert!(
                delayed[&col.to_column()],
                "{label}: an expected column is not delayed"
            );
        }
        assert_eq!(
            expected.len(),
            delayed.len(),
            "{label}: delayed column count mismatch"
        );
    };

    check_delayed(Sponge(Absorb(First)), &first, "First");
    check_delayed(Sponge(Absorb(Middle)), &middle, "Middle");
    check_delayed(Sponge(Absorb(Last)), &last, "Last");
    check_delayed(Sponge(Absorb(Only)), &only, "Only");
    check_delayed(Sponge(Squeeze), &squeeze, "Squeeze");
    check_delayed(Round(0), &round, "Round");
}

#[test]
fn test_keccak_witness_satisfies_lookups() {
let mut rng = o1_utils::tests::make_test_rng(None);
Expand Down
7 changes: 5 additions & 2 deletions o1vm/src/keccak/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ impl Tracer<N_ZKVM_KECCAK_REL_COLS, KeccakConfig, KeccakEnv<ScalarField<KeccakCo
) -> Self {
// Make sure we are using the same round number to refer to round steps
let step = standardize(selector);
Self {
let mut trace = Self {
domain_size,
witness: Witness {
cols: Box::new(std::array::from_fn(|_| Vec::with_capacity(domain_size))),
},
constraints: KeccakEnv::constraints_of(step),
lookups: KeccakEnv::lookups_of(step),
}
delayed_columns: BTreeMap::new(),
};
trace.set_delayed_columns();
trace
}

fn push_row(
Expand Down
3 changes: 3 additions & 0 deletions o1vm/src/mips/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ mod unit {
}

mod folding {
use std::collections::BTreeMap;

use super::{
unit::{dummy_env, write_instruction},
Fp,
Expand Down Expand Up @@ -618,6 +620,7 @@ mod folding {
witness: witness_one.clone(),
constraints: constraints.clone(),
lookups: vec![],
delayed_columns: BTreeMap::new(),
};

impl FoldingConfig for MIPSFoldingConfig {
Expand Down
5 changes: 4 additions & 1 deletion o1vm/src/mips/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ impl
) -> Self {
interpret_instruction(env, instr);

let trace = Self {
let mut trace = Self {
domain_size,
witness: Witness {
cols: Box::new(std::array::from_fn(|_| Vec::with_capacity(domain_size))),
},
constraints: env.constraints.clone(),
lookups: env.lookups.clone(),
delayed_columns: BTreeMap::new(),
};
trace.set_delayed_columns();

env.scratch_state_idx = 0; // Reset the scratch state index for the next instruction
env.constraints = vec![]; // Clear the constraints for the next instruction
env.lookups = vec![]; // Clear the lookups for the next instruction
Expand Down
2 changes: 1 addition & 1 deletion o1vm/src/ramlookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct RAMLookup<T, ID: LookupTableID> {
pub(crate) mode: LookupMode,
/// The number of times that this lookup value should be added to / subtracted from the lookup accumulator.
pub(crate) magnitude: T,
/// The columns containing the content of this lookup
/// The variables containing the content of this lookup
pub(crate) value: Vec<T>,
}

Expand Down
39 changes: 37 additions & 2 deletions o1vm/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use ark_ff::{One, Zero};
use ark_poly::{Evaluations, Radix2EvaluationDomain as D};
use folding::{expressions::FoldingCompatibleExpr, Alphas, FoldingConfig};
use itertools::Itertools;
use kimchi::circuits::expr::ChallengeTerm;
use kimchi::circuits::expr::{ChallengeTerm, ExprInner::*, Operations::*};
use kimchi_msm::{columns::Column, witness::Witness};
use mina_poseidon::sponge::FqSponge;
use poly_commitment::{commitment::absorb_commitment, PolyComm, SRS as _};
Expand All @@ -33,6 +33,41 @@ pub struct Trace<const N: usize, C: FoldingConfig> {
pub witness: Witness<N, Vec<ScalarField<C>>>,
pub constraints: Vec<E<ScalarField<C>>>,
pub lookups: Vec<Lookup<E<ScalarField<C>>>>,
/// Generic columns involved in the lookups for this step used in delayed argument
// NOTE: could be precomputed as they are known per step ahead of time
pub delayed_columns: BTreeMap<Column, bool>,
}

impl<const N: usize, C: FoldingConfig> Trace<N, C> {
pub(crate) fn set_delayed_columns(&mut self) {
let mut delayed_columns: BTreeMap<Column, bool> = BTreeMap::new();
self.lookups.iter().for_each(|lookup| {
lookup.value.iter().for_each(|op| {
let columns = Self::columns_in_expr(op);
columns.iter().for_each(|column| {
delayed_columns.insert(*column, true);
});
});
});
self.delayed_columns = delayed_columns;
}

fn columns_in_expr(op: &E<ScalarField<C>>) -> Vec<Column> {
match op {
Atom(x) => match x {
Cell(x) => vec![x.col],
_ => vec![],
},
Pow(x, _) => Self::columns_in_expr(x),
Add(x, y) | Mul(x, y) | Sub(x, y) => {
let mut columns = Self::columns_in_expr(x);
columns.extend(Self::columns_in_expr(y));
columns
}
Double(x) | Square(x) => Self::columns_in_expr(x),
Cache(_, _) | IfFeature(_, _, _) => vec![],
}
}
}

/// Struct representing a circuit execution trace which is decomposable in
Expand Down Expand Up @@ -277,7 +312,7 @@ pub trait DecomposableTracer<Env> {
/// Generic implementation of the [Tracer] trait for the [DecomposedTrace] struct.
/// It requires the [DecomposedTrace] to implement the [DecomposableTracer] trait,
/// and the [Trace] struct to implement the [Tracer] trait with Selector set to (),
/// and the `C::Selector` to implement the [Indexer] trait.
/// and `usize` to implement the [From] trait with `C::Selector`.
impl<const N: usize, const N_REL: usize, C: FoldingConfig, Env> Tracer<N_REL, C, Env>
for DecomposedTrace<N, C>
where
Expand Down