Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff committed Sep 11, 2024
1 parent 80b1adc commit 8f48ac2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
1 change: 1 addition & 0 deletions monty-31/src/poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use p3_symmetric::Permutation;
use crate::{monty_reduce, FieldParameters, MontyField31, MontyParameters};

/// Everything needed to compute multiplication by a WIDTH x WIDTH diffusion matrix whose monty form is 1 + Diag(vec).
///
/// vec is assumed to be of the form [-2, ...] with all entries after the first being small powers of 2.
pub trait DiffusionMatrixParameters<FP: FieldParameters, const WIDTH: usize>: Clone + Sync {
// Most of the time, ArrayLike will be [u8; WIDTH - 1].
Expand Down
17 changes: 12 additions & 5 deletions uni-stark/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ impl<'a, SC: StarkGenericConfig> State<'a, SC> {
}

pub(crate) fn run_stage(mut self, stage: Stage<SC>) -> Self {

#[cfg(debug_assertions)]
let trace = stage.trace.clone();

Expand Down Expand Up @@ -265,17 +264,25 @@ pub struct Stage<SC: StarkGenericConfig> {
}

pub struct CallbackResult<T> {
// the trace for this stage
pub(crate) trace: RowMajorMatrix<T>,
// the values of the public inputs of this stage
pub(crate) public_values: Vec<T>,
// todo: return shared challenges
// the values of the challenges drawn at the previous stage
pub(crate) challenges: Vec<T>,
}

impl<T> CallbackResult<T> {
pub fn new(trace: RowMajorMatrix<T>, public_values: Vec<T>) -> Self {
Self { trace, public_values }
pub fn new(trace: RowMajorMatrix<T>, public_values: Vec<T>, challenges: Vec<T>) -> Self {
Self {
trace,
public_values,
challenges,
}
}
}

pub trait NextStageTraceCallback<SC: StarkGenericConfig> {
fn get_next_stage(&self, trace_stage: u32, challenges: &[Val<SC>]) -> CallbackResult<Val<SC>>;
/// Computes the stage number `trace_stage` based on `challenges` drawn at the end of stage `trace_stage - 1`
fn compute_stage(&self, stage: u32, challenges: &[Val<SC>]) -> CallbackResult<Val<SC>>;
}
39 changes: 24 additions & 15 deletions uni-stark/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::borrow::ToOwned;
use alloc::vec::Vec;
use core::iter;

Expand All @@ -22,7 +23,7 @@ use crate::{
struct Panic;

impl<SC: StarkGenericConfig> NextStageTraceCallback<SC> for Panic {
fn get_next_stage(&self, _: u32, _: &[Val<SC>]) -> CallbackResult<Val<SC>> {
fn compute_stage(&self, _: u32, _: &[Val<SC>]) -> CallbackResult<Val<SC>> {
unreachable!()
}
}
Expand Down Expand Up @@ -69,7 +70,7 @@ pub fn prove_with_key<
challenger: &mut SC::Challenger,
stage_0_trace: RowMajorMatrix<Val<SC>>,
next_stage_trace_callback: Option<&T>,
stage_0_public_values: &Vec<Val<SC>>,
stage_0_public_values: &[Val<SC>],
) -> Proof<SC>
where
SC: StarkGenericConfig,
Expand All @@ -95,24 +96,29 @@ where
let mut stage = Stage {
trace: stage_0_trace,
challenge_count: air.challenge_count(0),
public_values: stage_0_public_values.clone(),
public_values: stage_0_public_values.to_owned(),
};

assert!(stage_count >= 1);
// for all stages except the last one, generate the next stage based on the witgen callback
for stage_id in 0..stage_count - 1 {
// generate all stages starting from the second one based on the witgen callback
for stage_id in 1..stage_count {
state = state.run_stage(stage);
let last_processed_stage = state.processed_stages.last().unwrap();
// get the challenges drawn at the end of the previous stage
let local_challenges = &state.processed_stages.last().unwrap().challenge_values;
let CallbackResult {
trace,
public_values,
challenges,
} = next_stage_trace_callback
.as_ref()
.expect("witgen callback expected in the presence of challenges")
.get_next_stage(stage_id as u32, &last_processed_stage.challenge_values);
.compute_stage(stage_id as u32, local_challenges);
// replace the challenges of the last stage with the ones received
state.processed_stages.last_mut().unwrap().challenge_values = challenges;
// go to the next stage
stage = Stage {
trace,
challenge_count: air.challenge_count(stage_id as u32 + 1),
challenge_count: air.challenge_count(stage_id as u32),
public_values,
};
}
Expand All @@ -135,8 +141,16 @@ where
&air.preprocessed_trace()
.unwrap_or(RowMajorMatrix::new(Default::default(), 0)),
state.processed_stages.iter().map(|s| &s.trace).collect(),
&state.processed_stages.iter().map(|s| &s.public_values).collect(),
state.processed_stages.iter().map(|s| &s.challenge_values).collect(),
&state
.processed_stages
.iter()
.map(|s| &s.public_values)
.collect(),
state
.processed_stages
.iter()
.map(|s| &s.challenge_values)
.collect(),
);

finish(proving_key, air, state)
Expand Down Expand Up @@ -271,11 +285,6 @@ where
sels.inv_zeroifier.push(Val::<SC>::default());
}

let challenges: Vec<Vec<_>> = challenges
.into_iter()
.map(|s| s.into_iter().map(|v| v.into()).collect())
.collect();

(0..quotient_size)
.into_par_iter()
.step_by(PackedVal::<SC>::WIDTH)
Expand Down

0 comments on commit 8f48ac2

Please sign in to comment.