Skip to content
This repository has been archived by the owner on Jul 5, 2024. It is now read-only.

Commit

Permalink
twist root circuit api to support multi-proof
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Dec 11, 2023
1 parent 6ab26d6 commit 0edf828
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 62 deletions.
16 changes: 11 additions & 5 deletions integration-tests/src/integration_test_circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use zkevm_circuits::{
pi_circuit::TestPiCircuit,
root_circuit::{
compile, Config, EvmTranscript, NativeLoader, PoseidonTranscript, RootCircuit, Shplonk,
UserChallenge,
SnarkWitness, UserChallenge,
},
state_circuit::TestStateCircuit,
super_circuit::SuperCircuit,
Expand Down Expand Up @@ -319,8 +319,11 @@ impl<C: SubCircuit<Fr> + Circuit<Fr>> IntegrationTest<C> {
let circuit = RootCircuit::<Bn256, Shplonk<_>>::new(
&params,
&protocol,
Value::unknown(),
Value::unknown(),
vec![SnarkWitness::new(
&protocol,
Value::unknown(),
Value::unknown(),
)],
None,
)
.unwrap();
Expand Down Expand Up @@ -458,8 +461,11 @@ impl<C: SubCircuit<Fr> + Circuit<Fr>> IntegrationTest<C> {
let root_circuit = RootCircuit::<Bn256, Shplonk<_>>::new(
&params,
&protocol,
Value::known(&instance),
Value::known(&proof),
vec![SnarkWitness::new(
&protocol,
Value::known(&instance),
Value::known(&proof),
)],
Some(user_challenge),
)
.unwrap();
Expand Down
91 changes: 54 additions & 37 deletions zkevm-circuits/src/root_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ pub struct UserChallenge {
#[derive(Clone)]
pub struct RootCircuit<'a, M: MultiMillerLoop, As> {
svk: KzgSvk<M>,
snark: SnarkWitness<'a, M::G1Affine>,
protocol: &'a PlonkProtocol<M::G1Affine>,
snark_witnesses: Vec<SnarkWitness<'a, M::G1Affine>>,
instance: Vec<M::Scalar>,
user_challenges: Option<&'a UserChallenge>,
_marker: PhantomData<As>,
}

Expand All @@ -139,70 +141,79 @@ where
{
/// Create a `RootCircuit` with accumulator computed given a `SuperCircuit`
/// proof and its instance. Returns `None` if given proof is invalid.
/// TODO support multiple snark proof aggregation
pub fn new(
params: &ParamsKZG<M>,
super_circuit_protocol: &'a PlonkProtocol<M::G1Affine>,
super_circuit_instances: Value<&'a Vec<Vec<M::Scalar>>>,
super_circuit_proof: Value<&'a [u8]>,
protocol: &'a PlonkProtocol<M::G1Affine>,
snark_witnesses: Vec<SnarkWitness<'a, M::G1Affine>>,
user_challenges: Option<&'a UserChallenge>,
) -> Result<Self, snark_verifier::Error> {
let num_instances = super_circuit_protocol.num_instance.iter().sum::<usize>();
let num_raw_instances = protocol.num_instance.iter().sum::<usize>();

// compute real instance value
let (flatten_super_circuit_instances, accumulator_limbs) = {
let (flatten_first_chunk_instances, accumulator_limbs) = {
let (mut instance, mut accumulator_limbs) = (
vec![M::Scalar::ZERO; num_instances],
vec![M::Scalar::ZERO; num_raw_instances],
Ok(vec![M::Scalar::ZERO; 4 * LIMBS]),
);
super_circuit_instances
.as_ref()
.zip(super_circuit_proof.as_ref())
.map(|(super_circuit_instances, super_circuit_proof)| {
let snark = Snark::new(
super_circuit_protocol,
super_circuit_instances,
super_circuit_proof,
);
accumulator_limbs = aggregate::<M, As>(params, [snark])
.map(|accumulator_limbs| accumulator_limbs.to_vec());
instance = super_circuit_instances
.iter()
.flatten()
.cloned()
.collect_vec()
// compute aggregate_limbs
snark_witnesses
.iter()
.fold(Value::known(vec![]), |acc_snark, snark_witness| {
snark_witness
.instances
.as_ref()
.zip(snark_witness.proof.as_ref())
.map(|(super_circuit_instances, super_circuit_proof)| {
Snark::new(protocol, super_circuit_instances, super_circuit_proof)
})
.zip(acc_snark)
.map(|(snark, mut acc_snark)| {
acc_snark.push(snark);
acc_snark
})
})
.map(|snarks| {
if !snarks.is_empty() {
accumulator_limbs = aggregate::<M, As>(params, snarks)
.map(|accumulator_limbs| accumulator_limbs.to_vec());
}
});

// retrieve first instance
if let Some(snark_witness) = snark_witnesses.first() {
snark_witness
.instances
.map(|instances| instance = instances.iter().flatten().cloned().collect_vec());
}

(instance, accumulator_limbs?)
};

debug_assert_eq!(flatten_super_circuit_instances.len(), num_instances);
debug_assert_eq!(flatten_first_chunk_instances.len(), num_raw_instances);
let mut flatten_instance =
exposed_instances(&SuperCircuitInstance::new(flatten_super_circuit_instances));
exposed_instances(&SuperCircuitInstance::new(flatten_first_chunk_instances));
flatten_instance.extend(accumulator_limbs);

Ok(Self {
svk: KzgSvk::<M>::new(params.get_g()[0]),
snark: SnarkWitness::new(
super_circuit_protocol,
super_circuit_instances,
user_challenges,
super_circuit_proof,
),
protocol,
snark_witnesses,
instance: flatten_instance,
user_challenges,
_marker: PhantomData,
})
}

/// Returns accumulator indices in instance columns, which will be in
/// the last `4 * LIMBS` rows of instance column in `MainGate`.
pub fn accumulator_indices(&self) -> Vec<(usize, usize)> {
let offset = self.snark.protocol().num_instance.iter().sum::<usize>();
let offset = self.protocol.num_instance.iter().sum::<usize>();
(offset..).map(|idx| (0, idx)).take(4 * LIMBS).collect()
}

/// Returns number of instance
pub fn num_instance(&self) -> Vec<usize> {
vec![self.snark.protocol().num_instance.iter().sum::<usize>() + 4 * LIMBS]
vec![self.instance.len()]
}

/// Returns instance
Expand Down Expand Up @@ -234,7 +245,13 @@ where
fn without_witnesses(&self) -> Self {
Self {
svk: self.svk,
snark: self.snark.without_witnesses(),
protocol: self.protocol,
snark_witnesses: self
.snark_witnesses
.iter()
.map(|snark_witness| snark_witness.without_witnesses())
.collect_vec(),
user_challenges: self.user_challenges,
instance: vec![M::Scalar::ZERO; self.instance.len()],
_marker: PhantomData,
}
Expand All @@ -257,13 +274,13 @@ where
config.named_column_in_region(&mut region);
let ctx = RegionCtx::new(region, 0);
let (loaded_instances, accumulator_limbs, loader, proofs) =
config.aggregate::<M, As>(ctx, &key.clone(), [self.snark])?;
config.aggregate::<M, As>(ctx, &key.clone(), &self.snark_witnesses)?;

// aggregate user challenge for rwtable permutation challenge
let (alpha, gamma) = {
let mut challenges = config.aggregate_user_challenges::<M, As>(
loader.clone(),
self.snark.user_challenges(),
self.user_challenges,
proofs,
)?;
(challenges.remove(0), challenges.remove(0))
Expand Down
23 changes: 7 additions & 16 deletions zkevm-circuits/src/root_circuit/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,32 +101,31 @@ impl<'a, C: CurveAffine> From<Snark<'a, C>> for SnarkWitness<'a, C> {
protocol: snark.protocol,
instances: Value::known(snark.instances),
proof: Value::known(snark.proof),
user_challenges: None,
}
}
}

/// SnarkWitness
#[derive(Clone, Copy)]
pub struct SnarkWitness<'a, C: CurveAffine> {
protocol: &'a PlonkProtocol<C>,
instances: Value<&'a Vec<Vec<C::Scalar>>>,
user_challenges: Option<&'a UserChallenge>,
proof: Value<&'a [u8]>,
/// protocol
pub protocol: &'a PlonkProtocol<C>,
/// instance
pub instances: Value<&'a Vec<Vec<C::Scalar>>>,
/// proof
pub proof: Value<&'a [u8]>,
}

impl<'a, C: CurveAffine> SnarkWitness<'a, C> {
/// Construct `SnarkWitness` with each field.
pub fn new(
protocol: &'a PlonkProtocol<C>,
instances: Value<&'a Vec<Vec<C::Scalar>>>,
user_challenges: Option<&'a UserChallenge>,
proof: Value<&'a [u8]>,
) -> Self {
Self {
protocol,
instances,
user_challenges,
proof,
}
}
Expand All @@ -136,7 +135,6 @@ impl<'a, C: CurveAffine> SnarkWitness<'a, C> {
SnarkWitness {
protocol: self.protocol,
instances: Value::unknown(),
user_challenges: self.user_challenges,
proof: Value::unknown(),
}
}
Expand All @@ -151,11 +149,6 @@ impl<'a, C: CurveAffine> SnarkWitness<'a, C> {
self.proof
}

/// Returns user_challenges as option.
pub fn user_challenges(&self) -> Option<&'a UserChallenge> {
self.user_challenges
}

fn loaded_instances<'b>(
&self,
loader: &Rc<Halo2Loader<'b, C>>,
Expand Down Expand Up @@ -298,7 +291,7 @@ impl AggregationConfig {
&self,
ctx: RegionCtx<'c, M::Scalar>,
svk: &KzgSvk<M>,
snarks: impl IntoIterator<Item = SnarkWitness<'a, M::G1Affine>>,
snarks: &[SnarkWitness<'a, M::G1Affine>],
) -> Result<
(
Vec<Vec<Scalar<'c, M::G1Affine, EccChip<M::G1Affine>>>>,
Expand Down Expand Up @@ -329,8 +322,6 @@ impl AggregationConfig {
type PoseidonTranscript<'a, C, S> =
transcript::halo2::PoseidonTranscript<C, Rc<Halo2Loader<'a, C>>, S, T, RATE, R_F, R_P>;

let snarks = snarks.into_iter().collect_vec();

let mut plonk_svp = vec![];
// Verify the cheap part and get accumulator (left-hand and right-hand side of
// pairing) of individual proof.
Expand Down
2 changes: 1 addition & 1 deletion zkevm-circuits/src/root_circuit/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ where
config.named_column_in_region(&mut region);
let ctx = RegionCtx::new(region, 0);
let (instances, accumulator_limbs, _, _) =
config.aggregate::<M, As>(ctx, &self.svk, self.snarks.clone())?;
config.aggregate::<M, As>(ctx, &self.svk, &self.snarks)?;
let instances = instances
.iter()
.flat_map(|instances| {
Expand Down
11 changes: 8 additions & 3 deletions zkevm-circuits/src/root_circuit/test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::{
root_circuit::{compile, Config, Gwc, PoseidonTranscript, RootCircuit, UserChallenge},
root_circuit::{
compile, Config, Gwc, PoseidonTranscript, RootCircuit, SnarkWitness, UserChallenge,
},
super_circuit::{test::block_1tx, SuperCircuit},
};
use bus_mapping::circuit_input_builder::FixedCParams;
Expand Down Expand Up @@ -75,8 +77,11 @@ fn test_root_circuit() {
let root_circuit = RootCircuit::<Bn256, Gwc<_>>::new(
&params,
&protocol,
Value::known(&instance),
Value::known(&proof),
vec![SnarkWitness::new(
&protocol,
Value::known(&instance),
Value::known(&proof),
)],
Some(&user_challenge),
)
.unwrap();
Expand Down

0 comments on commit 0edf828

Please sign in to comment.