Skip to content

Commit

Permalink
zal: store engine at the prover level
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Feb 23, 2024
1 parent f6e717f commit 9509f4e
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 63 deletions.
83 changes: 31 additions & 52 deletions halo2_backend/src/plonk/prover.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use group::Curve;
use halo2_middleware::ff::{Field, FromUniformBytes, WithSmallOrderMulGroup};
use halo2_middleware::zal::{
impls::{PlonkEngine, PlonkEngineConfig},
impls::{PlonkEngine, PlonkEngineConfig, H2cEngine},
traits::MsmAccel,
};
use rand_core::RngCore;
Expand Down Expand Up @@ -49,7 +49,8 @@ pub struct ProverV2Single<
E: EncodedChallenge<Scheme::Curve>,
R: RngCore,
T: TranscriptWrite<Scheme::Curve, E>,
>(ProverV2<'a, 'params, Scheme, P, E, R, T>);
M: MsmAccel<Scheme::Curve>,
>(ProverV2<'a, 'params, Scheme, P, E, R, T, M>);

impl<
'a,
Expand All @@ -59,11 +60,12 @@ impl<
E: EncodedChallenge<Scheme::Curve>,
R: RngCore,
T: TranscriptWrite<Scheme::Curve, E>,
> ProverV2Single<'a, 'params, Scheme, P, E, R, T>
M: MsmAccel<Scheme::Curve>,
> ProverV2Single<'a, 'params, Scheme, P, E, R, T, M>
{
/// Create a new prover object
pub fn new_with_engine<M: MsmAccel<Scheme::Curve>>(
engine: &PlonkEngine<Scheme::Curve, M>,
pub fn new_with_engine(
engine: PlonkEngine<Scheme::Curve, M>,
params: &'params Scheme::ParamsProver,
pk: &'a ProvingKey<Scheme::Curve>,
// TODO: If this was a vector the usage would be simpler
Expand Down Expand Up @@ -93,46 +95,32 @@ impl<
instance: &[&[Scheme::Scalar]],
rng: R,
transcript: &'a mut T,
) -> Result<Self, Error>
) -> Result<ProverV2Single<'a, 'params, Scheme, P, E, R, T, H2cEngine>, Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
let engine = PlonkEngineConfig::build_default();
Self::new_with_engine(&engine, params, pk, instance, rng, transcript)
ProverV2Single::new_with_engine(engine, params, pk, instance, rng, transcript)
}

/// Commit the `witness` at `phase` and return the challenges after `phase`.
pub fn commit_phase<M: MsmAccel<Scheme::Curve>>(
pub fn commit_phase(
&mut self,
engine: &PlonkEngine<Scheme::Curve, M>,
phase: u8,
witness: Vec<Option<Vec<Scheme::Scalar>>>,
) -> Result<HashMap<usize, Scheme::Scalar>, Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
self.0.commit_phase(engine, phase, vec![witness])
self.0.commit_phase(phase, vec![witness])
}

/// Finalizes the proof creation.
pub fn create_proof(self) -> Result<(), Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
let engine = PlonkEngineConfig::build_default();
self.create_proof_with_engine(&engine)
}

/// Finalizes the proof creation.
/// TODO: change to "ZalEngine" which will contain MsmAccel and FftAccel trait accelerators
pub fn create_proof_with_engine<M: MsmAccel<Scheme::Curve>>(
self,
engine: &PlonkEngine<Scheme::Curve, M>,
) -> Result<(), Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
self.0.create_proof_with_engine(engine)
self.0.create_proof()
}
}

Expand All @@ -147,7 +135,9 @@ pub struct ProverV2<
E: EncodedChallenge<Scheme::Curve>,
R: RngCore,
T: TranscriptWrite<Scheme::Curve, E>,
M: MsmAccel<Scheme::Curve>,
> {
engine: PlonkEngine<Scheme::Curve, M>,
// Circuit and setup fields
params: &'params Scheme::ParamsProver,
pk: &'a ProvingKey<Scheme::Curve>,
Expand All @@ -171,11 +161,12 @@ impl<
E: EncodedChallenge<Scheme::Curve>,
R: RngCore,
T: TranscriptWrite<Scheme::Curve, E>,
> ProverV2<'a, 'params, Scheme, P, E, R, T>
M: MsmAccel<Scheme::Curve>,
> ProverV2<'a, 'params, Scheme, P, E, R, T, M>
{
/// Create a new prover object
pub fn new_with_engine<M: MsmAccel<Scheme::Curve>>(
engine: &PlonkEngine<Scheme::Curve, M>,
pub fn new_with_engine(
engine: PlonkEngine<Scheme::Curve, M>,
params: &'params Scheme::ParamsProver,
pk: &'a ProvingKey<Scheme::Curve>,
// TODO: If this was a vector the usage would be simpler.
Expand Down Expand Up @@ -274,6 +265,7 @@ impl<
let challenges = HashMap::<usize, Scheme::Scalar>::with_capacity(meta.num_challenges);

Ok(ProverV2 {
engine,
params,
pk,
phases,
Expand All @@ -289,9 +281,8 @@ impl<

/// Commit the `witness` at `phase` and return the challenges after `phase`.
#[allow(clippy::type_complexity)]
pub fn commit_phase<M: MsmAccel<Scheme::Curve>>(
pub fn commit_phase(
&mut self,
engine: &PlonkEngine<Scheme::Curve, M>,
phase: u8,
witness: Vec<Vec<Option<Vec<Scheme::Scalar>>>>,
) -> Result<HashMap<usize, Scheme::Scalar>, Error>
Expand Down Expand Up @@ -409,7 +400,7 @@ impl<
let advice_commitments_projective: Vec<_> = advice_values
.iter()
.zip(blinds.iter())
.map(|(poly, blind)| params.commit_lagrange(&engine.msm_backend, poly, *blind))
.map(|(poly, blind)| params.commit_lagrange(&self.engine.msm_backend, poly, *blind))
.collect();
let mut advice_commitments =
vec![Scheme::Curve::identity(); advice_commitments_projective.len()];
Expand Down Expand Up @@ -455,10 +446,7 @@ impl<
}

/// Finalizes the proof creation.
pub fn create_proof_with_engine<M: MsmAccel<Scheme::Curve>>(
mut self,
engine: &PlonkEngine<Scheme::Curve, M>,
) -> Result<(), Error>
pub fn create_proof(mut self) -> Result<(), Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
Expand Down Expand Up @@ -490,7 +478,7 @@ impl<
.iter()
.map(|lookup| {
lookup_commit_permuted(
engine,
&self.engine,
lookup,
pk,
params,
Expand Down Expand Up @@ -527,7 +515,7 @@ impl<
.zip(advice.iter())
.map(|(instance, advice)| {
permutation_commit(
engine,
&self.engine,
&meta.permutation,
params,
pk,
Expand All @@ -551,7 +539,7 @@ impl<
.into_iter()
.map(|lookup| {
lookup.commit_product(
engine,
&self.engine,
pk,
params,
beta,
Expand All @@ -573,7 +561,7 @@ impl<
.iter()
.map(|shuffle| {
shuffle_commit_product(
engine,
&self.engine,
shuffle,
pk,
params,
Expand All @@ -594,7 +582,7 @@ impl<

// Commit to the vanishing argument's random polynomial for blinding h(x_3)
let vanishing = vanishing::Argument::commit(
&engine.msm_backend,
&self.engine.msm_backend,
params,
domain,
&mut rng,
Expand Down Expand Up @@ -646,7 +634,7 @@ impl<

// Construct the vanishing argument's h(X) commitments
let vanishing =
vanishing.construct(engine, params, domain, h_poly, &mut rng, self.transcript)?;
vanishing.construct(&self.engine, params, domain, h_poly, &mut rng, self.transcript)?;

let x: ChallengeX<_> = self.transcript.squeeze_challenge_scalar();
let xn = x.pow([params.n()]);
Expand Down Expand Up @@ -786,7 +774,7 @@ impl<

let prover = P::new(params);
prover
.create_proof_with_engine(&engine.msm_backend, rng, self.transcript, instances)
.create_proof_with_engine(&self.engine.msm_backend, rng, self.transcript, instances)
.map_err(|_| Error::ConstraintSystemFailure)?;

Ok(())
Expand All @@ -801,20 +789,11 @@ impl<
instances: &[&[&[Scheme::Scalar]]],
rng: R,
transcript: &'a mut T,
) -> Result<Self, Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
let engine = PlonkEngineConfig::build_default();
Self::new_with_engine(&engine, params, pk, instances, rng, transcript)
}

/// Finalizes the proof creation.
pub fn create_proof(self) -> Result<(), Error>
) -> Result<ProverV2<'a, 'params, Scheme, P, E, R, T, H2cEngine>, Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
let engine = PlonkEngineConfig::build_default();
self.create_proof_with_engine(&engine)
ProverV2::new_with_engine(engine, params, pk, instances, rng, transcript)
}
}
1 change: 1 addition & 0 deletions halo2_middleware/src/zal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ pub mod impls {

// Backend-agnostic engine objects
// ---------------------------------------------------
#[derive(Debug)]
pub struct PlonkEngine<C: CurveAffine, MsmEngine: MsmAccel<C>> {
pub msm_backend: MsmEngine,
_marker: PhantomData<C>, // compiler complains about unused C otherwise
Expand Down
10 changes: 5 additions & 5 deletions halo2_proofs/src/plonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn create_proof_with_engine<
ConcreteCircuit: Circuit<Scheme::Scalar>,
M: MsmAccel<Scheme::Curve>,
>(
engine: &PlonkEngine<Scheme::Curve, M>,
engine: PlonkEngine<Scheme::Curve, M>,
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
circuits: &[ConcreteCircuit],
Expand All @@ -46,7 +46,7 @@ where
.enumerate()
.map(|(i, circuit)| WitnessCalculator::new(params.k(), circuit, &config, &cs, instances[i]))
.collect();
let mut prover = ProverV2::<Scheme, P, _, _, _>::new_with_engine(
let mut prover = ProverV2::<Scheme, P, _, _, _, _>::new_with_engine(
engine, params, pk, instances, rng, transcript,
)?;
let mut challenges = HashMap::new();
Expand All @@ -56,9 +56,9 @@ where
for witness_calc in witness_calcs.iter_mut() {
witnesses.push(witness_calc.calc(phase.0, &challenges)?);
}
challenges = prover.commit_phase(engine, phase.0, witnesses).unwrap();
challenges = prover.commit_phase(phase.0, witnesses).unwrap();
}
prover.create_proof_with_engine(engine)
prover.create_proof()
}

/// This creates a proof for the provided `circuit` when given the public
Expand Down Expand Up @@ -86,7 +86,7 @@ where
{
let engine = PlonkEngineConfig::build_default();
create_proof_with_engine::<Scheme, P, _, _, _, _, _>(
&engine, params, pk, circuits, instances, rng, transcript,
engine, params, pk, circuits, instances, rng, transcript,
)
}

Expand Down
8 changes: 4 additions & 4 deletions halo2_proofs/tests/frontend_backend_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ fn test_mycircuit_full_split() {
let mut witness_calc = WitnessCalculator::new(k, &circuit, &config, &cs, instances_slice);
let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]);
let mut prover =
ProverV2Single::<KZGCommitmentScheme<Bn256>, ProverSHPLONK<'_, Bn256>, _, _, _>::new_with_engine(
&engine,
ProverV2Single::<KZGCommitmentScheme<Bn256>, ProverSHPLONK<'_, Bn256>, _, _, _, _>::new_with_engine(
engine,
&params,
&pk,
instances_slice,
Expand All @@ -597,9 +597,9 @@ fn test_mycircuit_full_split() {
for phase in 0..cs.phases().count() {
println!("phase {phase}");
let witness = witness_calc.calc(phase as u8, &challenges).unwrap();
challenges = prover.commit_phase(&engine, phase as u8, witness).unwrap();
challenges = prover.commit_phase(phase as u8, witness).unwrap();
}
prover.create_proof_with_engine(&engine).unwrap();
prover.create_proof().unwrap();
let proof = transcript.finalize();
println!("Prove: {:?}", start.elapsed());

Expand Down
4 changes: 2 additions & 2 deletions halo2_proofs/tests/plonk_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ fn plonk_api() {
T: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
M: MsmAccel<Scheme::Curve>,
>(
engine: &PlonkEngine<Scheme::Curve, M>,
engine: PlonkEngine<Scheme::Curve, M>,
rng: R,
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
Expand Down Expand Up @@ -521,7 +521,7 @@ fn plonk_api() {
Scheme::Scalar: Ord + WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
let engine = PlonkEngineConfig::build_default();
create_proof_with_engine::<Scheme, P, _, _, T, _>(&engine, rng, params, pk)
create_proof_with_engine::<Scheme, P, _, _, T, _>(engine, rng, params, pk)
}

fn verify_proof<
Expand Down

0 comments on commit 9509f4e

Please sign in to comment.