Skip to content
This repository has been archived by the owner on Oct 1, 2024. It is now read-only.

Commit

Permalink
keep a single list of public inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff committed Sep 16, 2024
1 parent cc1bc7d commit c53ede3
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 100 deletions.
8 changes: 4 additions & 4 deletions dft/benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"fft::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down Expand Up @@ -75,7 +75,7 @@ where
Dft: TwoAdicSubgroupDft<Complex<Mersenne31>>,
Standard: Distribution<Mersenne31>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"m31_fft::<{}, {}>",
type_name::<Dft>(),
BATCH_SIZE
Expand All @@ -102,7 +102,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"ifft::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down Expand Up @@ -131,7 +131,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"coset_lde::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down
1 change: 1 addition & 0 deletions monty-31/src/poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use p3_symmetric::Permutation;
use crate::{monty_reduce, FieldParameters, MontyField31, MontyParameters};

/// Everything needed to compute multiplication by a WIDTH x WIDTH diffusion matrix whose monty form is 1 + Diag(vec).
///
/// vec is assumed to be of the form [-2, ...] with all entries after the first being small powers of 2.
pub trait DiffusionMatrixParameters<FP: FieldParameters, const WIDTH: usize>: Clone + Sync {
// Most of the time, ArrayLike will be [u8; WIDTH - 1].
Expand Down
10 changes: 3 additions & 7 deletions uni-stark/src/check_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub(crate) fn check_constraints<F, A>(
air: &A,
preprocessed: &RowMajorMatrix<F>,
stages: Vec<&RowMajorMatrix<F>>,
public_values: &Vec<&Vec<F>>,
public_values: &Vec<F>,
challenges: Vec<&Vec<F>>,
) where
F: Field,
Expand Down Expand Up @@ -75,7 +75,7 @@ pub struct DebugConstraintBuilder<'a, F: Field> {
preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
challenges: Vec<&'a Vec<F>>,
stages: Vec<VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>>,
public_values: &'a [&'a Vec<F>],
public_values: &'a [F],
is_first_row: F,
is_last_row: F,
is_transition: F,
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F>
type PublicVar = Self::F;

fn public_values(&self) -> &[Self::PublicVar] {
self.stage_public_values(0)
self.public_values
}
}

Expand All @@ -147,10 +147,6 @@ impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> {
impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> {
type Challenge = Self::Expr;

fn stage_public_values(&self, stage: usize) -> &[Self::F] {
self.public_values[stage]
}

fn stage_trace(&self, stage: usize) -> Self::M {
self.stages[stage]
}
Expand Down
14 changes: 4 additions & 10 deletions uni-stark/src/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
pub challenges: Vec<Vec<Val<SC>>>,
pub stages: Vec<RowMajorMatrix<PackedVal<SC>>>,
pub preprocessed: RowMajorMatrix<PackedVal<SC>>,
pub public_values: &'a Vec<Vec<Val<SC>>>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: PackedVal<SC>,
pub is_last_row: PackedVal<SC>,
pub is_transition: PackedVal<SC>,
Expand All @@ -28,7 +28,7 @@ pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> {
pub challenges: Vec<Vec<Val<SC>>>,
pub stages: Vec<ViewPair<'a, SC::Challenge>>,
pub preprocessed: ViewPair<'a, SC::Challenge>,
pub public_values: Vec<&'a Vec<Val<SC>>>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: SC::Challenge,
pub is_last_row: SC::Challenge,
pub is_transition: SC::Challenge,
Expand Down Expand Up @@ -73,7 +73,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraint
type PublicVar = Val<SC>;

fn public_values(&self) -> &[Self::PublicVar] {
self.stage_public_values(0)
self.public_values
}
}

Expand All @@ -87,9 +87,6 @@ impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for ProverConstraintFolder
fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] {
&self.challenges[stage]
}
fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] {
&self.public_values[stage]
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> {
Expand Down Expand Up @@ -135,7 +132,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstrai
type PublicVar = Val<SC>;

fn public_values(&self) -> &[Self::PublicVar] {
self.stage_public_values(0)
self.public_values
}
}

Expand All @@ -149,9 +146,6 @@ impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for VerifierConstraintFold
fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] {
&self.challenges[stage]
}
fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] {
self.public_values[stage]
}
}

impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> {
Expand Down
1 change: 0 additions & 1 deletion uni-stark/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ pub struct ProcessedStage<SC: StarkGenericConfig> {
pub(crate) commitment: Com<SC>,
pub(crate) prover_data: PcsProverData<SC>,
pub(crate) challenge_values: Vec<Val<SC>>,
pub(crate) public_values: Vec<Val<SC>>,
#[cfg(debug_assertions)]
pub(crate) trace: RowMajorMatrix<Val<SC>>,
}
82 changes: 48 additions & 34 deletions uni-stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,29 @@ pub fn prove<
air: &A,
challenger: &mut SC::Challenger,
main_trace: RowMajorMatrix<Val<SC>>,
#[allow(clippy::ptr_arg)]
// we do not use `&[Val<SC>]` in order to keep the same API
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: MultiStageAir<SymbolicAirBuilder<Val<SC>>>
+ for<'a> MultiStageAir<ProverConstraintFolder<'a, SC>>,
{
let public_values = public_values
.iter()
.enumerate()
.map(|(index, value)| (index, *value))
.collect();

prove_with_key(
config,
None,
air,
challenger,
main_trace,
&UnusedCallback,
public_values,
&public_values,
)
}

Expand All @@ -72,9 +80,7 @@ pub fn prove_with_key<
challenger: &mut SC::Challenger,
stage_0_trace: RowMajorMatrix<Val<SC>>,
next_stage_trace_callback: &C,
#[allow(clippy::ptr_arg)]
// we do not use `&[Val<SC>]` in order to keep the same API
stage_0_public_values: &Vec<Val<SC>>,
stage_0_public_values: &Vec<(usize, Val<SC>)>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
Expand All @@ -85,10 +91,6 @@ where
let degree = stage_0_trace.height();
let log_degree = log2_strict_usize(degree);

let log_quotient_degree =
get_log_quotient_degree::<Val<SC>, A>(air, &[stage_0_public_values.len()]);
let quotient_degree = 1 << log_quotient_degree;

let stage_count = <A as MultiStageAir<SymbolicAirBuilder<_>>>::stage_count(air);

let pcs = config.pcs();
Expand All @@ -103,6 +105,7 @@ where
};

let mut state: ProverState<SC> = ProverState::new(pcs, trace_domain, challenger);
state.add_public_values(stage_0_public_values.iter().cloned());
let mut stage = Stage {
trace: stage_0_trace,
challenge_count: <A as MultiStageAir<SymbolicAirBuilder<_>>>::stage_challenge_count(air, 0),
Expand Down Expand Up @@ -146,33 +149,29 @@ where
// sanity check that we processed as many stages as expected
assert_eq!(state.processed_stages.len(), stage_count);

let public_values: Vec<_> = state
.public_values
.into_iter()
.map(|v| v.unwrap())
.collect();

// with the witness complete, check the constraints
#[cfg(debug_assertions)]
crate::check_constraints::check_constraints(
air,
&air.preprocessed_trace()
.unwrap_or(RowMajorMatrix::new(Default::default(), 0)),
state.processed_stages.iter().map(|s| &s.trace).collect(),
&state
.processed_stages
.iter()
.map(|s| &s.public_values)
.collect(),
&public_values,
state
.processed_stages
.iter()
.map(|s| &s.challenge_values)
.collect(),
);

let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(
air,
&state
.processed_stages
.iter()
.map(|s| s.public_values.len())
.collect::<Vec<_>>(),
);
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, public_values.len());
let quotient_degree = 1 << log_quotient_degree;

let challenger = &mut state.challenger;

Expand All @@ -197,12 +196,6 @@ where
.map(|stage| stage.challenge_values.clone())
.collect();

let public_values = state
.processed_stages
.iter()
.map(|stage| stage.public_values.clone())
.collect();

let quotient_values = quotient_values(
air,
&public_values,
Expand Down Expand Up @@ -312,7 +305,7 @@ where
#[instrument(name = "compute quotient polynomial", skip_all)]
fn quotient_values<'a, SC, A, Mat>(
air: &A,
public_values: &'a Vec<Vec<Val<SC>>>,
public_values: &'a Vec<Val<SC>>,
trace_domain: Domain<SC>,
quotient_domain: Domain<SC>,
preprocessed_on_quotient_domain: Option<Mat>,
Expand Down Expand Up @@ -416,6 +409,7 @@ pub struct ProverState<'a, SC: StarkGenericConfig> {
pub(crate) challenger: &'a mut SC::Challenger,
pub(crate) pcs: &'a <SC>::Pcs,
pub(crate) trace_domain: Domain<SC>,
pub(crate) public_values: Vec<Option<Val<SC>>>,
}

impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> {
Expand All @@ -429,6 +423,24 @@ impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> {
challenger,
pcs,
trace_domain,
public_values: Default::default(),
}
}

pub(crate) fn add_public_values(
&mut self,
public_values: impl IntoIterator<Item = (usize, Val<SC>)>,
) {
for (index, value) in public_values {
if self.public_values.len() <= index + 1 {
self.public_values.resize(index + 1, None);
}
match self.public_values[index] {
Some(_) => panic!("public value at index {index} is already set"),
None => {
self.public_values[index] = Some(value);
}
}
}
}

Expand All @@ -446,11 +458,9 @@ impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> {
.map(|_| self.challenger.sample())
.collect();

// observe the public inputs for this stage
self.challenger.observe_slice(&stage.public_values);
self.add_public_values(stage.public_values);

self.processed_stages.push(ProcessedStage {
public_values: stage.public_values,
prover_data,
commitment,
challenge_values,
Expand All @@ -467,20 +477,24 @@ pub struct Stage<SC: StarkGenericConfig> {
/// the number of challenges to be drawn at the end of this stage
pub(crate) challenge_count: usize,
/// the public values for this stage
pub(crate) public_values: Vec<Val<SC>>,
pub(crate) public_values: Vec<(usize, Val<SC>)>,
}

pub struct CallbackResult<T> {
/// the trace for this stage
pub(crate) trace: RowMajorMatrix<T>,
/// the values of the public inputs of this stage
pub(crate) public_values: Vec<T>,
pub(crate) public_values: Vec<(usize, T)>,
/// the values of the challenges drawn at the previous stage
pub(crate) challenges: Vec<T>,
}

impl<T> CallbackResult<T> {
pub fn new(trace: RowMajorMatrix<T>, public_values: Vec<T>, challenges: Vec<T>) -> Self {
pub fn new(
trace: RowMajorMatrix<T>,
public_values: Vec<(usize, T)>,
challenges: Vec<T>,
) -> Self {
Self {
trace,
public_values,
Expand Down
Loading

0 comments on commit c53ede3

Please sign in to comment.