Skip to content

Commit

Permalink
Fill in all prove/verify subroutines (except r1cs)
Browse files Browse the repository at this point in the history
  • Loading branch information
moodlezoup committed Dec 4, 2023
1 parent 3e43f88 commit 4d32731
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 106 deletions.
10 changes: 5 additions & 5 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use crate::jolt::vm::instruction_lookups::InstructionLookupsProof;
use crate::jolt::vm::rv32i_vm::{RV32IJoltVM, RV32I};
use crate::jolt::vm::Jolt;
use crate::lasso::surge::Surge;
use crate::utils::math::Math;
use crate::utils::{math::Math, random::RandomTape};
use crate::{jolt::instruction::xor::XORInstruction, utils::gen_random_point};
use ark_curve25519::{EdwardsProjective, Fr};
use ark_std::{test_rng};
use ark_std::test_rng;
use merlin::Transcript;
use rand_chacha::rand_core::RngCore;

Expand Down Expand Up @@ -152,12 +152,12 @@ fn rv32i_lookup_benchmarks() -> Vec<(tracing::Span, Box<dyn FnOnce()>)> {
println!("Running {:?}", ops.len());

let work = Box::new(|| {
let r: Vec<Fr> = gen_random_point::<Fr>(ops.len().log_2());
let mut prover_transcript = Transcript::new(b"example");
let mut random_tape = RandomTape::new(b"test_tape");
let proof: InstructionLookupsProof<Fr, EdwardsProjective> =
RV32IJoltVM::prove_instruction_lookups(ops, r.clone(), &mut prover_transcript);
RV32IJoltVM::prove_instruction_lookups(ops, &mut prover_transcript, &mut random_tape);
let mut verifier_transcript = Transcript::new(b"example");
assert!(RV32IJoltVM::verify_instruction_lookups(proof, r, &mut verifier_transcript).is_ok());
assert!(RV32IJoltVM::verify_instruction_lookups(proof, &mut verifier_transcript).is_ok());
});
vec![(tracing::info_span!("RV32IM"), work)]
}
Expand Down
31 changes: 24 additions & 7 deletions jolt-core/src/jolt/trace/rv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use crate::jolt::instruction::sltu::SLTUInstruction;
use crate::jolt::instruction::sra::SRAInstruction;
use crate::jolt::instruction::srl::SRLInstruction;
use crate::jolt::instruction::xor::XORInstruction;
use crate::jolt::instruction::{add::ADDInstruction, sub::SUBInstruction};
use crate::jolt::instruction::JoltInstruction;
use crate::jolt::instruction::{add::ADDInstruction, sub::SUBInstruction};
use crate::jolt::vm::{pc::ELFRow, rv32i_vm::RV32I};
use common::{constants::REGISTER_COUNT, RV32InstructionFormat, RV32IM};

Expand Down Expand Up @@ -1070,19 +1070,32 @@ mod tests {
.into_iter()
.map(|common| RVTraceRow::from_common(common))
.collect();
let _: Vec<ELFRow> = converted_trace
.iter()
.map(|row| row.to_pc_trace())
.collect();

let mut num_errors = 0;
for row in &converted_trace {
if let Err(e) = row.validate() {
// if row.opcode != RV32IM::SLLI {
println!("Validation error: {} \n{:#?}\n\n", e, row);
// }
num_errors += 1;
}
}
println!("Total errors: {num_errors}");
}

#[test]
fn load_bytecode() {
use common::path::JoltPaths;
use common::serializable::Serializable;

let bytecode_location = JoltPaths::bytecode_path("fibonacci");
let instructions = Vec::<common::ELFInstruction>::deserialize_from_file(&bytecode_location)
.expect("deserialization failed");
let _: Vec<ELFRow> = instructions.iter().map(|x| ELFRow::from(x)).collect();
}

#[test]
fn fib_e2e() {
use crate::jolt::vm::rv32i_vm::RV32I;
Expand Down Expand Up @@ -1115,13 +1128,18 @@ mod tests {
.into_iter()
.flat_map(|row| row.to_jolt_instructions())
.collect();
let r: Vec<Fr> = gen_random_point::<Fr>(lookup_ops.len().log_2());
let mut prover_transcript = Transcript::new(b"example");
let mut random_tape = RandomTape::new(b"test_tape");

let proof: InstructionLookupsProof<Fr, EdwardsProjective> =
RV32IJoltVM::prove_instruction_lookups(lookup_ops, r.clone(), &mut prover_transcript);
RV32IJoltVM::prove_instruction_lookups(
lookup_ops,
&mut prover_transcript,
&mut random_tape,
);
let mut verifier_transcript = Transcript::new(b"example");
assert!(
RV32IJoltVM::verify_instruction_lookups(proof, r, &mut verifier_transcript).is_ok()
RV32IJoltVM::verify_instruction_lookups(proof, &mut verifier_transcript).is_ok()
);

// Prove memory
Expand All @@ -1141,7 +1159,6 @@ mod tests {
let batched_polys = rw_memory.batch();
let commitments = ReadWriteMemory::commit(&batched_polys);

let mut random_tape = RandomTape::new(b"test_tape");
let proof = rw_memory.prove_memory_checking(
&rw_memory,
&batched_polys,
Expand Down
26 changes: 18 additions & 8 deletions jolt-core/src/jolt/vm/instruction_lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,11 @@ where

pub fn prove_lookups(
&self,
r: Vec<F>,
transcript: &mut Transcript,
random_tape: &mut RandomTape<G>,
) -> InstructionLookupsProof<F, G> {
<Transcript as ProofTranscript<G>>::append_protocol_name(transcript, Self::protocol_name());

let polynomials = self.polynomialize();
let batched_polys = polynomials.batch();
let commitment = InstructionPolynomials::commit(&batched_polys);
Expand All @@ -790,7 +791,13 @@ where
.E_commitment
.append_to_transcript(b"comm_poly_row_col_ops_val", transcript);

let eq = EqPolynomial::new(r.to_vec());
let r_eq = <Transcript as ProofTranscript<G>>::challenge_vector(
transcript,
b"Jolt instruction lookups",
self.ops.len().log_2(),
);

let eq = EqPolynomial::new(r_eq.to_vec());
let sumcheck_claim = Self::compute_sumcheck_claim(&self.ops, &polynomials.E_polys, &eq);

<Transcript as ProofTranscript<G>>::append_scalar(
Expand All @@ -799,7 +806,7 @@ where
&sumcheck_claim,
);

let mut eq_poly = DensePolynomial::new(EqPolynomial::new(r).evals());
let mut eq_poly = DensePolynomial::new(EqPolynomial::new(r_eq).evals());
let num_rounds = self.ops.len().log_2();

// TODO: compartmentalize all primary sumcheck logic
Expand All @@ -814,16 +821,14 @@ where
transcript,
);

let mut random_tape = RandomTape::new(b"proof");

// Create a single opening proof for the flag_evals and memory_evals
let sumcheck_openings = PrimarySumcheckOpenings::prove_openings(
&batched_polys,
&commitment,
&r_primary_sumcheck,
(E_evals, flag_evals),
transcript,
&mut random_tape,
random_tape,
);

let primary_sumcheck = PrimarySumcheck {
Expand All @@ -838,7 +843,7 @@ where
&batched_polys,
&commitment,
transcript,
&mut random_tape,
random_tape,
);

InstructionLookupsProof {
Expand All @@ -850,7 +855,6 @@ where

pub fn verify(
proof: InstructionLookupsProof<F, G>,
r_eq: &[G::ScalarField],
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
<Transcript as ProofTranscript<G>>::append_protocol_name(transcript, Self::protocol_name());
Expand All @@ -860,6 +864,12 @@ where
.E_commitment
.append_to_transcript(b"comm_poly_row_col_ops_val", transcript);

let r_eq = <Transcript as ProofTranscript<G>>::challenge_vector(
transcript,
b"Jolt instruction lookups",
proof.primary_sumcheck.num_rounds,
);

<Transcript as ProofTranscript<G>>::append_scalar(
transcript,
b"claim_eval_scalar_product",
Expand Down
142 changes: 83 additions & 59 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use merlin::Transcript;
use std::any::TypeId;
use strum::{EnumCount, IntoEnumIterator};

use crate::lasso::{
memory_checking::MemoryCheckingProver,
surge::Surge,
use crate::{
lasso::{
memory_checking::{MemoryCheckingProof, MemoryCheckingProver, MemoryCheckingVerifier},
surge::{Surge, SurgeProof},
},
utils::math::Math,
};

use crate::jolt::{
Expand All @@ -17,7 +20,10 @@ use crate::poly::structured_poly::BatchablePolynomials;
use crate::utils::{errors::ProofVerifyError, random::RandomTape};

use self::instruction_lookups::{InstructionLookups, InstructionLookupsProof};
use self::read_write_memory::{MemoryCommitment, MemoryOp, ReadWriteMemory};
use self::pc::{ELFRow, PCInitFinalOpenings, PCPolys, PCReadWriteOpenings, ProgramCommitment};
use self::read_write_memory::{
MemoryCommitment, MemoryInitFinalOpenings, MemoryOp, MemoryReadWriteOpenings, ReadWriteMemory,
};

pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, const M: usize> {
type InstructionSet: JoltInstruction + Opcode + IntoEnumIterator + EnumCount;
Expand All @@ -35,77 +41,75 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co

fn prove_instruction_lookups(
ops: Vec<Self::InstructionSet>,
r: Vec<F>,
transcript: &mut Transcript,
random_tape: &mut RandomTape<G>,
) -> InstructionLookupsProof<F, G> {
let instruction_lookups =
InstructionLookups::<F, G, Self::InstructionSet, Self::Subtables, C, M>::new(ops);
instruction_lookups.prove_lookups(r, transcript)
instruction_lookups.prove_lookups(transcript, random_tape)
}

fn verify_instruction_lookups(
proof: InstructionLookupsProof<F, G>,
r: Vec<F>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
InstructionLookups::<F, G, Self::InstructionSet, Self::Subtables, C, M>::verify(
proof, &r, transcript,
proof, transcript,
)
}

fn prove_program_code(
program_code: &[u64],
access_sequence: &[usize],
code_size: usize,
contiguous_reads_per_access: usize,
r_mem_check: &(F, F),
mut program: Vec<ELFRow>,
mut trace: Vec<ELFRow>,
transcript: &mut Transcript,
random_tape: &mut RandomTape<G>,
) -> (
MemoryCheckingProof<G, PCPolys<F, G>, PCReadWriteOpenings<F, G>, PCInitFinalOpenings<F, G>>,
ProgramCommitment<G>,
) {
// let (gamma, tau) = r_mem_check;
// let hash_func = |a: &F, v: &F, t: &F| -> F { *t * gamma.square() + *v * *gamma + *a - tau };

// let m: usize = (access_sequence.len() * contiguous_reads_per_access).next_power_of_two();
// // TODO(moodlezoup): resize access_sequence?

// let mut read_addrs: Vec<usize> = Vec::with_capacity(m);
// let mut final_cts: Vec<usize> = vec![0; code_size];
// let mut read_cts: Vec<usize> = Vec::with_capacity(m);
// let mut read_values: Vec<u64> = Vec::with_capacity(m);

// for (j, code_address) in access_sequence.iter().enumerate() {
// debug_assert!(code_address + contiguous_reads_per_access <= code_size);
// debug_assert!(code_address % contiguous_reads_per_access == 0);

// for offset in 0..contiguous_reads_per_access {
// let addr = code_address + offset;
// let counter = final_cts[addr];
// read_addrs.push(addr);
// read_values.push(program_code[addr]);
// read_cts.push(counter);
// final_cts[addr] = counter + 1;
// }
// }

// let E_poly: DensePolynomial<F> = DensePolynomial::from_u64(&read_values); // v_ops
// let dim: DensePolynomial<F> = DensePolynomial::from_usize(access_sequence); // a_ops
// let read_cts: DensePolynomial<F> = DensePolynomial::from_usize(&read_cts); // t_read
// let final_cts: DensePolynomial<F> = DensePolynomial::from_usize(&final_cts); // t_final
// let init_values: DensePolynomial<F> = DensePolynomial::from_u64(program_code); // v_mem

// let polys = PCPolys::new(dim, E_poly, init_values, read_cts, final_cts, 0);
// let (gens, commitments) = polys.commit::<G>();

todo!("decide how to represent nested proofs, gens, commitments");
// MemoryCheckingProof::<G, PCFingerprintProof<G>>::prove(
// &polys,
// r_fingerprints,
// &gens,
// &mut transcript,
// &mut random_tape,
// )
let polys: PCPolys<F, G> = PCPolys::new_program(program, trace);
let batched_polys = polys.batch();
let commitments = PCPolys::commit(&batched_polys);

(
polys.prove_memory_checking(
&polys,
&batched_polys,
&commitments,
transcript,
random_tape,
),
commitments,
)
}

fn verify_program_code(
proof: MemoryCheckingProof<
G,
PCPolys<F, G>,
PCReadWriteOpenings<F, G>,
PCInitFinalOpenings<F, G>,
>,
commitment: ProgramCommitment<G>,
transcript: &mut Transcript,
) -> Result<(), ProofVerifyError> {
PCPolys::verify_memory_checking(proof, &commitment, transcript)
}

fn prove_memory(memory_trace: Vec<MemoryOp>, memory_size: usize, transcript: &mut Transcript) {
fn prove_memory(
memory_trace: Vec<MemoryOp>,
memory_size: usize,
transcript: &mut Transcript,
random_tape: &mut RandomTape<G>,
) -> (
MemoryCheckingProof<
G,
ReadWriteMemory<F, G>,
MemoryReadWriteOpenings<F, G>,
MemoryInitFinalOpenings<F, G>,
>,
SurgeProof<F, G>,
) {
const MAX_TRACE_SIZE: usize = 1 << 22;
// TODO: Support longer traces
assert!(memory_trace.len() <= MAX_TRACE_SIZE);
Expand All @@ -116,13 +120,12 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
let batched_polys = memory.batch();
let commitments: MemoryCommitment<G> = ReadWriteMemory::commit(&batched_polys);

let mut random_tape = RandomTape::new(b"proof");
memory.prove_memory_checking(
let memory_checking_proof = memory.prove_memory_checking(
&memory,
&batched_polys,
&commitments,
transcript,
&mut random_tape,
random_tape,
);

let timestamp_validity_lookups: Vec<SLTUInstruction> = read_timestamps
Expand All @@ -134,6 +137,27 @@ pub trait Jolt<F: PrimeField, G: CurveGroup<ScalarField = F>, const C: usize, co
let timestamp_validity_proof =
<Surge<F, G, SLTUInstruction, 2, MAX_TRACE_SIZE>>::new(timestamp_validity_lookups)
.prove(transcript);

(memory_checking_proof, timestamp_validity_proof)
}

fn verify_memory(
memory_checking_proof: MemoryCheckingProof<
G,
ReadWriteMemory<F, G>,
MemoryReadWriteOpenings<F, G>,
MemoryInitFinalOpenings<F, G>,
>,
commitment: MemoryCommitment<G>,
transcript: &mut Transcript,
timestamp_validity_proof: SurgeProof<F, G>,
) -> Result<(), ProofVerifyError> {
const MAX_TRACE_SIZE: usize = 1 << 22;
ReadWriteMemory::verify_memory_checking(memory_checking_proof, &commitment, transcript)?;
<Surge<F, G, SLTUInstruction, 2, MAX_TRACE_SIZE>>::verify(
timestamp_validity_proof,
transcript,
)
}

fn prove_r1cs() {
Expand Down
Loading

0 comments on commit 4d32731

Please sign in to comment.