Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify AggregationChildTarget structs and reduce nb of args in method #188

Merged
merged 3 commits into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 44 additions & 98 deletions evm_arithmetization/src/fixed_recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ where
C: GenericConfig<D, F = F>,
{
pub circuit: CircuitData<F, C, D>,
lhs: SegmentAggregationChildTarget<D>,
rhs: SegmentAggregationChildTarget<D>,
lhs: AggregationChildTarget<D>,
rhs: AggregationChildTarget<D>,
public_values: PublicValuesTarget,
cyclic_vk: VerifierCircuitTarget,
}
Expand Down Expand Up @@ -219,8 +219,8 @@ where
let circuit = buffer.read_circuit_data(gate_serializer, generator_serializer)?;
let cyclic_vk = buffer.read_target_verifier_circuit()?;
let public_values = PublicValuesTarget::from_buffer(buffer)?;
let lhs = SegmentAggregationChildTarget::from_buffer(buffer)?;
let rhs = SegmentAggregationChildTarget::from_buffer(buffer)?;
let lhs = AggregationChildTarget::from_buffer(buffer)?;
let rhs = AggregationChildTarget::from_buffer(buffer)?;
Ok(Self {
circuit,
lhs,
Expand All @@ -232,28 +232,28 @@ where
}

#[derive(Eq, PartialEq, Debug)]
struct SegmentAggregationChildTarget<const D: usize> {
struct AggregationChildTarget<const D: usize> {
is_agg: BoolTarget,
agg_proof: ProofWithPublicInputsTarget<D>,
segment_proof: ProofWithPublicInputsTarget<D>,
proof: ProofWithPublicInputsTarget<D>,
}

impl<const D: usize> SegmentAggregationChildTarget<D> {
impl<const D: usize> AggregationChildTarget<D> {
fn to_buffer(&self, buffer: &mut Vec<u8>) -> IoResult<()> {
buffer.write_target_bool(self.is_agg)?;
buffer.write_target_proof_with_public_inputs(&self.agg_proof)?;
buffer.write_target_proof_with_public_inputs(&self.segment_proof)?;
buffer.write_target_proof_with_public_inputs(&self.proof)?;
Ok(())
}

fn from_buffer(buffer: &mut Buffer) -> IoResult<Self> {
let is_agg = buffer.read_target_bool()?;
let agg_proof = buffer.read_target_proof_with_public_inputs()?;
let segment_proof = buffer.read_target_proof_with_public_inputs()?;
let proof = buffer.read_target_proof_with_public_inputs()?;
Ok(Self {
is_agg,
agg_proof,
segment_proof,
proof,
})
}

Expand All @@ -267,7 +267,7 @@ impl<const D: usize> SegmentAggregationChildTarget<D> {
let agg_pv =
PublicValuesTarget::from_public_inputs(&self.agg_proof.public_inputs, len_mem_cap);
let segment_pv =
PublicValuesTarget::from_public_inputs(&self.segment_proof.public_inputs, len_mem_cap);
PublicValuesTarget::from_public_inputs(&self.proof.public_inputs, len_mem_cap);
PublicValuesTarget::select(builder, self.is_agg, agg_pv, segment_pv)
}
}
Expand All @@ -282,8 +282,8 @@ where
C: GenericConfig<D, F = F>,
{
pub circuit: CircuitData<F, C, D>,
lhs: TxnAggregationChildTarget<D>,
rhs: TxnAggregationChildTarget<D>,
lhs: AggregationChildTarget<D>,
rhs: AggregationChildTarget<D>,
public_values: PublicValuesTarget,
cyclic_vk: VerifierCircuitTarget,
}
Expand Down Expand Up @@ -315,8 +315,8 @@ where
let circuit = buffer.read_circuit_data(gate_serializer, generator_serializer)?;
let cyclic_vk = buffer.read_target_verifier_circuit()?;
let public_values = PublicValuesTarget::from_buffer(buffer)?;
let lhs = TxnAggregationChildTarget::from_buffer(buffer)?;
let rhs = TxnAggregationChildTarget::from_buffer(buffer)?;
let lhs = AggregationChildTarget::from_buffer(buffer)?;
let rhs = AggregationChildTarget::from_buffer(buffer)?;
Ok(Self {
circuit,
lhs,
Expand All @@ -327,49 +327,6 @@ where
}
}

#[derive(Eq, PartialEq, Debug)]
struct TxnAggregationChildTarget<const D: usize> {
is_agg: BoolTarget,
txn_agg_proof: ProofWithPublicInputsTarget<D>,
segment_agg_proof: ProofWithPublicInputsTarget<D>,
}

impl<const D: usize> TxnAggregationChildTarget<D> {
fn to_buffer(&self, buffer: &mut Vec<u8>) -> IoResult<()> {
buffer.write_target_bool(self.is_agg)?;
buffer.write_target_proof_with_public_inputs(&self.txn_agg_proof)?;
buffer.write_target_proof_with_public_inputs(&self.segment_agg_proof)?;
Ok(())
}

fn from_buffer(buffer: &mut Buffer) -> IoResult<Self> {
let is_agg = buffer.read_target_bool()?;
let txn_agg_proof = buffer.read_target_proof_with_public_inputs()?;
let segment_agg_proof = buffer.read_target_proof_with_public_inputs()?;
Ok(Self {
is_agg,
txn_agg_proof,
segment_agg_proof,
})
}

// `len_mem_cap` is the length of the Merkle
// caps for `MemBefore` and `MemAfter`.
fn public_values<F: RichField + Extendable<D>>(
&self,
builder: &mut CircuitBuilder<F, D>,
len_mem_cap: usize,
) -> PublicValuesTarget {
let txn_agg_pv =
PublicValuesTarget::from_public_inputs(&self.txn_agg_proof.public_inputs, len_mem_cap);
let segment_agg_pv = PublicValuesTarget::from_public_inputs(
&self.segment_agg_proof.public_inputs,
len_mem_cap,
);
PublicValuesTarget::select(builder, self.is_agg, txn_agg_pv, segment_agg_pv)
}
}

/// Data for the block circuit, which is used to generate a final block proof,
/// and compress it with an optional parent proof if present.
#[derive(Eq, PartialEq, Debug)]
Expand Down Expand Up @@ -1220,51 +1177,45 @@ where
fn add_segment_agg_child(
builder: &mut CircuitBuilder<F, D>,
root: &RootCircuitData<F, C, D>,
) -> SegmentAggregationChildTarget<D> {
) -> AggregationChildTarget<D> {
let common = &root.circuit.common;
let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only);
let is_agg = builder.add_virtual_bool_target_safe();
let agg_proof = builder.add_virtual_proof_with_pis(common);
let segment_proof = builder.add_virtual_proof_with_pis(common);
let proof = builder.add_virtual_proof_with_pis(common);
builder
.conditionally_verify_cyclic_proof::<C>(
is_agg,
&agg_proof,
&segment_proof,
&root_vk,
common,
)
.conditionally_verify_cyclic_proof::<C>(is_agg, &agg_proof, &proof, &root_vk, common)
.expect("Failed to build cyclic recursion circuit");
SegmentAggregationChildTarget {
AggregationChildTarget {
is_agg,
agg_proof,
segment_proof,
proof,
}
}

fn add_txn_agg_child(
builder: &mut CircuitBuilder<F, D>,
segment_agg: &SegmentAggregationCircuitData<F, C, D>,
) -> TxnAggregationChildTarget<D> {
) -> AggregationChildTarget<D> {
let common = &segment_agg.circuit.common;
let inner_segment_agg_vk =
builder.constant_verifier_data(&segment_agg.circuit.verifier_only);
let is_agg = builder.add_virtual_bool_target_safe();
let txn_agg_proof = builder.add_virtual_proof_with_pis(common);
let segment_agg_proof = builder.add_virtual_proof_with_pis(common);
let agg_proof = builder.add_virtual_proof_with_pis(common);
let proof = builder.add_virtual_proof_with_pis(common);
builder
.conditionally_verify_cyclic_proof::<C>(
is_agg,
&txn_agg_proof,
&segment_agg_proof,
&agg_proof,
&proof,
&inner_segment_agg_vk,
common,
)
.expect("Failed to build cyclic recursion circuit");
TxnAggregationChildTarget {
AggregationChildTarget {
is_agg,
txn_agg_proof,
segment_agg_proof,
agg_proof,
proof,
}
}

Expand Down Expand Up @@ -1658,22 +1609,18 @@ where
let mut agg_inputs = PartialWitness::new();

Self::set_dummy_if_necessary(
self.segment_aggregation.lhs.is_agg,
&self.segment_aggregation.lhs,
lhs_is_agg,
&self.segment_aggregation.circuit,
&mut agg_inputs,
&self.segment_aggregation.lhs.segment_proof,
&self.segment_aggregation.lhs.agg_proof,
lhs_proof,
);

Self::set_dummy_if_necessary(
self.segment_aggregation.rhs.is_agg,
&self.segment_aggregation.rhs,
rhs_is_agg,
&self.segment_aggregation.circuit,
&mut agg_inputs,
&self.segment_aggregation.rhs.segment_proof,
&self.segment_aggregation.rhs.agg_proof,
rhs_proof,
);

Expand Down Expand Up @@ -1741,8 +1688,8 @@ where
/// one will generate a proof of
/// validity for both the transaction range covered by the previous proof
/// and the current transaction.
/// - `agg_segment_proof`: the final aggregation proof containing all
/// segments within the current transaction.
/// - `agg_proof`: the final aggregation proof containing all segments
/// within the current transaction.
/// - `public_values`: the public values associated to the aggregation
/// proof.
///
Expand All @@ -1763,22 +1710,18 @@ where
let mut txn_inputs = PartialWitness::new();

Self::set_dummy_if_necessary(
self.txn_aggregation.lhs.is_agg,
&self.txn_aggregation.lhs,
lhs_is_agg,
&self.txn_aggregation.circuit,
&mut txn_inputs,
&self.txn_aggregation.lhs.segment_agg_proof,
&self.txn_aggregation.lhs.txn_agg_proof,
lhs_proof,
);

Self::set_dummy_if_necessary(
self.txn_aggregation.rhs.is_agg,
&self.txn_aggregation.rhs,
rhs_is_agg,
&self.txn_aggregation.circuit,
&mut txn_inputs,
&self.txn_aggregation.rhs.segment_agg_proof,
&self.txn_aggregation.rhs.txn_agg_proof,
rhs_proof,
);

Expand Down Expand Up @@ -1865,21 +1808,24 @@ where
/// If the lhs is not an aggregation, we set the cyclic vk to a dummy value,
/// so that it corresponds to the aggregation cyclic vk.
fn set_dummy_if_necessary(
is_agg_target: BoolTarget,
agg_child: &AggregationChildTarget<D>,
is_agg: bool,
circuit: &CircuitData<F, C, D>,
agg_inputs: &mut PartialWitness<F>,
segment_proof_target: &ProofWithPublicInputsTarget<D>,
agg_proof_target: &ProofWithPublicInputsTarget<D>,
proof: &ProofWithPublicInputs<F, C, D>,
) {
agg_inputs.set_bool_target(is_agg_target, is_agg);
agg_inputs.set_bool_target(agg_child.is_agg, is_agg);
if is_agg {
agg_inputs.set_proof_with_pis_target(agg_proof_target, proof);
agg_inputs.set_proof_with_pis_target(&agg_child.agg_proof, proof);
} else {
Self::set_dummy_proof_with_cyclic_vk_pis(circuit, agg_inputs, agg_proof_target, proof)
Self::set_dummy_proof_with_cyclic_vk_pis(
circuit,
agg_inputs,
&agg_child.agg_proof,
proof,
)
}
agg_inputs.set_proof_with_pis_target(segment_proof_target, proof);
agg_inputs.set_proof_with_pis_target(&agg_child.proof, proof);
}

/// Create a final block proof, once all transactions of a given block have
Expand Down