Skip to content
This repository has been archived by the owner on Oct 1, 2024. It is now read-only.

Commit

Permalink
implement multi-stage publics
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff committed Sep 10, 2024
1 parent b89e07c commit cc341b8
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 95 deletions.
11 changes: 10 additions & 1 deletion air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,16 @@ pub trait AirBuilder: Sized {
pub trait AirBuilderWithPublicValues: AirBuilder {
type PublicVar: Into<Self::Expr> + Copy;

fn public_values(&self) -> &[Self::PublicVar];
fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] {
match stage {
0 => self.public_values(),
_ => unimplemented!(),
}
}

fn public_values(&self) -> &[Self::PublicVar] {
self.stage_public_values(0)
}
}

pub trait PairBuilder: AirBuilder {
Expand Down
12 changes: 4 additions & 8 deletions dft/benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"fft::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down Expand Up @@ -75,11 +75,7 @@ where
Dft: TwoAdicSubgroupDft<Complex<Mersenne31>>,
Standard: Distribution<Mersenne31>,
{
let mut group = c.benchmark_group(&format!(
"m31_fft::<{}, {}>",
type_name::<Dft>(),
BATCH_SIZE
));
let mut group = c.benchmark_group(format!("m31_fft::<{}, {}>", type_name::<Dft>(), BATCH_SIZE));
group.sample_size(10);

let mut rng = thread_rng();
Expand All @@ -102,7 +98,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"ifft::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down Expand Up @@ -131,7 +127,7 @@ where
Dft: TwoAdicSubgroupDft<F>,
Standard: Distribution<F>,
{
let mut group = c.benchmark_group(&format!(
let mut group = c.benchmark_group(format!(
"coset_lde::<{}, {}, {}>",
type_name::<F>(),
type_name::<Dft>(),
Expand Down
14 changes: 6 additions & 8 deletions uni-stark/src/check_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ use p3_matrix::stack::VerticalPair;
use p3_matrix::Matrix;
use tracing::instrument;

// use crate::VerificationError;

#[instrument(name = "check constraints", skip_all)]
pub(crate) fn check_constraints<F, A>(
air: &A,
preprocessed: &RowMajorMatrix<F>,
stages: Vec<&RowMajorMatrix<F>>,
public_values: Vec<&Vec<F>>,
public_values: &Vec<&Vec<F>>,
challenges: Vec<&Vec<F>>,
) where
F: Field,
Expand Down Expand Up @@ -57,7 +55,7 @@ pub(crate) fn check_constraints<F, A>(
challenges: challenges.clone(),
preprocessed,
stages,
public_values: public_values.clone(),
public_values,
is_first_row: F::from_bool(i == 0),
is_last_row: F::from_bool(i == height - 1),
is_transition: F::from_bool(i != height - 1),
Expand All @@ -75,7 +73,7 @@ pub struct DebugConstraintBuilder<'a, F: Field> {
preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
challenges: Vec<&'a Vec<F>>,
stages: Vec<VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>>,
public_values: Vec<&'a Vec<F>>,
public_values: &'a [&'a Vec<F>],
is_first_row: F,
is_last_row: F,
is_transition: F,
Expand Down Expand Up @@ -133,8 +131,8 @@ where
impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> {
type PublicVar = Self::F;

fn public_values(&self) -> &[Self::F] {
self.public_values[0]
fn stage_public_values(&self, stage: usize) -> &[Self::F] {
self.public_values[stage]
}
}

Expand All @@ -150,6 +148,6 @@ impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> {
}

fn challenges(&self, stage: usize) -> &[Self::Expr] {
&self.challenges[stage]
self.challenges[stage]
}
}
10 changes: 5 additions & 5 deletions uni-stark/src/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
pub challenges: Vec<Vec<PackedVal<SC>>>,
pub stages: Vec<RowMajorMatrix<PackedVal<SC>>>,
pub preprocessed: RowMajorMatrix<PackedVal<SC>>,
pub public_values: Vec<&'a Vec<Val<SC>>>,
pub public_values: &'a Vec<Vec<Val<SC>>>,
pub is_first_row: PackedVal<SC>,
pub is_last_row: PackedVal<SC>,
pub is_transition: PackedVal<SC>,
Expand Down Expand Up @@ -71,8 +71,8 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> {
impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> {
type PublicVar = Self::F;

fn public_values(&self) -> &[Self::F] {
self.public_values[0]
fn stage_public_values(&self, stage: usize) -> &[Self::F] {
&self.public_values[stage]
}
}

Expand Down Expand Up @@ -128,8 +128,8 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC>
impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> {
type PublicVar = Self::F;

fn public_values(&self) -> &[Self::F] {
self.public_values[0]
fn stage_public_values(&self, stage: usize) -> &[Self::F] {
self.public_values[stage]
}
}

Expand Down
42 changes: 21 additions & 21 deletions uni-stark/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ pub struct Proof<SC: StarkGenericConfig> {
pub(crate) opened_values: OpenedValues<SC::Challenge>,
pub(crate) opening_proof: PcsProof<SC>,
pub(crate) degree_bits: usize,
pub(crate) challenge_counts: Vec<usize>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Commitments<Com> {
pub(crate) stages: Vec<Com>, // we need to fix this
pub(crate) stages: Vec<Com>,
pub(crate) quotient_chunks: Com,
}

Expand All @@ -63,21 +64,19 @@ pub struct ProcessedStage<SC: StarkGenericConfig> {
pub(crate) commitment: Com<SC>,
pub(crate) prover_data: PcsProverData<SC>,
pub(crate) challenge_values: Vec<Val<SC>>,
pub(crate) public_values: Vec<Val<SC>>,
}

/// Updating with each new trace in every stage
pub struct State<'a, SC: StarkGenericConfig> {
// todo: let each stage also set public values
pub(crate) public_values: &'a Vec<Val<SC>>,
pub(crate) processed_stages: Vec<ProcessedStage<SC>>,
pub(crate) challenger: &'a mut SC::Challenger,
pcs: &'a <SC>::Pcs,
trace_domain: Domain<SC>,
log_degree: usize,
}

pub struct QuotientInputs<'a, SC: StarkGenericConfig> {
pub public_values: Vec<&'a Vec<Val<SC>>>,
pub struct QuotientInputs<SC: StarkGenericConfig> {
pub trace_domain: Domain<SC>,
pub quotient_domain: Domain<SC>,
pub alpha: SC::Challenge,
Expand All @@ -89,10 +88,8 @@ impl<'a, SC: StarkGenericConfig> State<'a, SC> {
trace_domain: Domain<SC>,
challenger: &'a mut <SC as StarkGenericConfig>::Challenger,
log_degree: usize,
public_values: &'a Vec<Val<SC>>,
) -> Self {
Self {
public_values,
processed_stages: Default::default(),
challenger,
pcs,
Expand All @@ -105,21 +102,15 @@ impl<'a, SC: StarkGenericConfig> State<'a, SC> {
self.log_degree
}

/// Observe a commitment.
pub(crate) fn observe_commit(&mut self, trace_commit: Com<SC>) {
self.challenger.observe(trace_commit);
}

/// Get inputs to quotient calculation
pub(crate) fn quotient_inputs(&mut self, log_quotient_degree: usize) -> QuotientInputs<'a, SC> {
pub(crate) fn quotient_inputs(&mut self, log_quotient_degree: usize) -> QuotientInputs<SC> {
let alpha: SC::Challenge = self.challenger.sample_ext_element();

let quotient_domain = self
.trace_domain
.create_disjoint_domain(1 << (self.log_degree + log_quotient_degree));

QuotientInputs {
public_values: vec![&self.public_values],
trace_domain: self.trace_domain,
quotient_domain,
alpha,
Expand Down Expand Up @@ -155,7 +146,7 @@ impl<'a, SC: StarkGenericConfig> State<'a, SC> {
.commit(izip!(qc_domains, quotient_chunks).collect_vec())
});

self.observe_commit(quotient_commit.clone());
self.challenger.observe(quotient_commit.clone());

(quotient_commit, quotient_data)
}
Expand Down Expand Up @@ -233,16 +224,21 @@ impl<'a, SC: StarkGenericConfig> State<'a, SC> {
}

pub(crate) fn run_stage(mut self, stage: Stage<SC>) -> Self {
// commit to the trace for this stage
let (commitment, prover_data) = info_span!("commit to trace data")
.in_scope(|| self.pcs.commit(vec![(self.trace_domain, stage.trace)]));

self.observe_commit(commitment.clone());
self.challenger.observe(commitment.clone());

let challenge_values = (0..stage.challenge_count)
.map(|_| self.challenger.sample())
.collect();

// observe the public inputs for this stage
self.challenger.observe_slice(&stage.public_values);

self.processed_stages.push(ProcessedStage {
public_values: stage.public_values,
prover_data,
commitment,
challenge_values,
Expand All @@ -256,12 +252,16 @@ pub struct Stage<SC: StarkGenericConfig> {
pub(crate) trace: RowMajorMatrix<Val<SC>>,
// the number of challenges to be drawn at the end of this stage
pub(crate) challenge_count: usize,
// the public values for this stage
pub(crate) public_values: Vec<Val<SC>>,
}

pub struct CallbackResult<T> {
pub(crate) trace: RowMajorMatrix<T>,
pub(crate) public_values: Vec<T>,
// todo: return shared challenges
}

pub trait NextStageTraceCallback<SC: StarkGenericConfig> {
fn get_next_stage_trace(
&self,
trace_stage: u32,
challenges: &[Val<SC>],
) -> RowMajorMatrix<Val<SC>>;
fn get_next_stage(&self, trace_stage: u32, challenges: &[Val<SC>]) -> CallbackResult<Val<SC>>;
}
Loading

0 comments on commit cc341b8

Please sign in to comment.