diff --git a/plonkish_backend/src/backend.rs b/plonkish_backend/src/backend.rs
index db879ee..6478e55 100644
--- a/plonkish_backend/src/backend.rs
+++ b/plonkish_backend/src/backend.rs
@@ -4,19 +4,24 @@ use crate::{
         arithmetic::Field,
         expression::Expression,
         transcript::{TranscriptRead, TranscriptWrite},
-        Deserialize, DeserializeOwned, Itertools, Serialize,
+        Itertools,
     },
     Error,
 };
 use rand::RngCore;
 use std::{collections::BTreeSet, fmt::Debug, iter};
 
+use self::lookup::lasso::DecomposableTable;
+
 pub mod hyperplonk;
+pub mod lookup;
 
 pub trait PlonkishBackend<F: Field>: Clone + Debug {
     type Pcs: PolynomialCommitmentScheme<F>;
-    type ProverParam: Clone + Debug + Serialize + DeserializeOwned;
-    type VerifierParam: Clone + Debug + Serialize + DeserializeOwned;
+    // FIXME: Add `Serialize + DeserializeOwned` back later; currently removed as a
+    // shortcut to skip implementing those traits on the Lasso-related types
+    type ProverParam: Clone + Debug;
+    type VerifierParam: Clone + Debug;
 
     fn setup(
         circuit_info: &PlonkishCircuitInfo<F>,
@@ -43,7 +48,7 @@ pub trait PlonkishBackend<F: Field>: Clone + Debug {
     ) -> Result<(), Error>;
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug)]
 pub struct PlonkishCircuitInfo<F> {
     /// 2^k is the size of the circuit
     pub k: usize,
@@ -64,6 +69,8 @@ pub struct PlonkishCircuitInfo<F> {
     /// which contains vector of tuples representing the input and table
     /// respectively.
     pub lookups: Vec<Vec<(Expression<F>, Expression<F>)>>,
+    /// Represents the Lasso lookup argument, as a tuple of the index expression,
+    /// the output expression and the table info
+    pub lasso_lookup: Option<(Expression<F>, Expression<F>, Box<dyn DecomposableTable<F>>)>,
     /// Each item inside outer vector represents a closed permutation cycle,
     /// which contains vector of tuples representing the polynomial index and
     /// row respectively.
diff --git a/plonkish_backend/src/backend/hyperplonk.rs b/plonkish_backend/src/backend/hyperplonk.rs
index e94d29f..1aec679 100644
--- a/plonkish_backend/src/backend/hyperplonk.rs
+++ b/plonkish_backend/src/backend/hyperplonk.rs
@@ -2,29 +2,30 @@ use crate::{
     backend::{
         hyperplonk::{
             preprocessor::{batch_size, compose, permutation_polys},
-            prover::{
-                instance_polys, lookup_compressed_polys, lookup_h_polys, lookup_m_polys,
-                permutation_z_polys, prove_zero_check,
-            },
-            verifier::verify_zero_check,
+            prover::{instance_polys, permutation_z_polys, prove_zero_check},
+            verifier::{verify_zero_check, zero_check_opening_points_len},
         },
         PlonkishBackend, PlonkishCircuit, PlonkishCircuitInfo, WitnessEncoding,
     },
     pcs::PolynomialCommitmentScheme,
     poly::multilinear::MultilinearPolynomial,
     util::{
-        arithmetic::{powers, BooleanHypercube, PrimeField},
+        arithmetic::{BooleanHypercube, PrimeField},
         end_timer,
         expression::Expression,
         start_timer,
         transcript::{TranscriptRead, TranscriptWrite},
-        Deserialize, DeserializeOwned, Itertools, Serialize,
+        DeserializeOwned, Itertools, Serialize,
     },
     Error,
 };
 use rand::RngCore;
 use std::{fmt::Debug, hash::Hash, iter, marker::PhantomData};
 
+use self::{prover::prove_lasso_lookup, verifier::verify_lasso_lookup};
+
+use super::lookup::lasso::DecomposableTable;
+
 pub(crate) mod preprocessor;
 pub(crate) mod prover;
 pub(crate) mod verifier;
@@ -35,7 +36,7 @@ pub mod util;
 #[derive(Clone, Debug)]
 pub struct HyperPlonk<Pcs>(PhantomData<Pcs>);
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug)]
 pub struct HyperPlonkProverParam<F, Pcs>
 where
     F: PrimeField,
     Pcs: PolynomialCommitmentScheme<F>,
 {
     pub(crate) pcs: Pcs::ProverParam,
     pub(crate) num_instances: Vec<usize>,
     pub(crate) num_witness_polys: Vec<usize>,
     pub(crate) num_challenges: Vec<usize>,
-    pub(crate) lookups: Vec<Vec<(Expression<F>, Expression<F>)>>,
+    /// (index
expression, output expression, table info) + pub(crate) lasso_lookup: Option<(Expression, Expression, Box>)>, + /// offset of polynomials related to Lasso lookup in batch opening + /// Lasso polynomials are tracked separately since Lasso invokes separate sumcheck + pub(crate) lookup_polys_offset: usize, + /// offset of points at which polynomials related to Lasso lookup opened in batch opening + pub(crate) lookup_points_offset: usize, pub(crate) num_permutation_z_polys: usize, pub(crate) num_vars: usize, pub(crate) expression: Expression, @@ -55,7 +62,7 @@ where pub(crate) permutation_comms: Vec, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug)] pub struct HyperPlonkVerifierParam where F: PrimeField, @@ -65,7 +72,9 @@ where pub(crate) num_instances: Vec, pub(crate) num_witness_polys: Vec, pub(crate) num_challenges: Vec, - pub(crate) num_lookups: usize, + pub(crate) lasso_table: Option>>, + pub(crate) lookup_polys_offset: usize, + pub(crate) lookup_points_offset: usize, pub(crate) num_permutation_z_polys: usize, pub(crate) num_vars: usize, pub(crate) expression: Expression, @@ -124,12 +133,26 @@ where // Compose `VirtualPolynomialInfo` let (num_permutation_z_polys, expression) = compose(circuit_info); + let lookup_polys_offset = circuit_info.num_instances.len() + + preprocess_polys.len() + + circuit_info.num_witness_polys.iter().sum::() + + permutation_polys.len() + + num_permutation_z_polys; + let lookup_points_offset = + zero_check_opening_points_len(&expression, circuit_info.num_instances.len()); + let lasso_table = circuit_info + .lasso_lookup + .is_some() + .then(|| circuit_info.lasso_lookup.as_ref().unwrap().2.clone()); + let vp = HyperPlonkVerifierParam { pcs: pcs_vp, num_instances: circuit_info.num_instances.clone(), num_witness_polys: circuit_info.num_witness_polys.clone(), num_challenges: circuit_info.num_challenges.clone(), - num_lookups: circuit_info.lookups.len(), + lasso_table, + lookup_polys_offset, + lookup_points_offset, num_permutation_z_polys, num_vars, expression: expression.clone(), @@ -145,7 +168,9 @@ where num_instances: circuit_info.num_instances.clone(), num_witness_polys: circuit_info.num_witness_polys.clone(), num_challenges: circuit_info.num_challenges.clone(), - lookups: circuit_info.lookups.clone(), + lasso_lookup: circuit_info.lasso_lookup.clone(), + lookup_polys_offset, + lookup_points_offset, num_permutation_z_polys, num_vars, expression, @@ -208,31 +233,20 @@ where .chain(witness_polys.iter()) .collect_vec(); - // Round n - - let beta = transcript.squeeze_challenge(); - - let timer = start_timer(|| format!("lookup_compressed_polys-{}", pp.lookups.len())); - let lookup_compressed_polys = { - let max_lookup_width = pp.lookups.iter().map(Vec::len).max().unwrap_or_default(); - let betas = powers(beta).take(max_lookup_width).collect_vec(); - lookup_compressed_polys(&pp.lookups, &polys, &challenges, &betas) + let mut lookup_opening_points = vec![]; + let mut lookup_opening_evals = vec![]; + let (lookup_polys, lookup_comms, lasso_challenges) = prove_lasso_lookup( + pp, + &polys, + &mut lookup_opening_points, + &mut lookup_opening_evals, + transcript, + )?; + let [beta, gamma] = if pp.lasso_lookup.is_some() { + lasso_challenges.try_into().unwrap() + } else { + transcript.squeeze_challenges(2).try_into().unwrap() }; - end_timer(timer); - - let timer = start_timer(|| format!("lookup_m_polys-{}", pp.lookups.len())); - let lookup_m_polys = lookup_m_polys(&lookup_compressed_polys)?; - end_timer(timer); - - let lookup_m_comms = 
Pcs::batch_commit_and_write(&pp.pcs, &lookup_m_polys, transcript)?; - - // Round n+1 - - let gamma = transcript.squeeze_challenge(); - - let timer = start_timer(|| format!("lookup_h_polys-{}", pp.lookups.len())); - let lookup_h_polys = lookup_h_polys(&lookup_compressed_polys, &lookup_m_polys, &gamma); - end_timer(timer); let timer = start_timer(|| format!("permutation_z_polys-{}", pp.permutation_polys.len())); let permutation_z_polys = permutation_z_polys( @@ -244,12 +258,8 @@ where ); end_timer(timer); - let lookup_h_permutation_z_polys = iter::empty() - .chain(lookup_h_polys.iter()) - .chain(permutation_z_polys.iter()) - .collect_vec(); - let lookup_h_permutation_z_comms = - Pcs::batch_commit_and_write(&pp.pcs, lookup_h_permutation_z_polys.clone(), transcript)?; + let permutation_z_comms = + Pcs::batch_commit_and_write(&pp.pcs, permutation_z_polys.iter(), transcript)?; // Round n+2 @@ -259,8 +269,7 @@ where let polys = iter::empty() .chain(polys) .chain(pp.permutation_polys.iter().map(|(_, poly)| poly)) - .chain(lookup_m_polys.iter()) - .chain(lookup_h_permutation_z_polys) + .chain(permutation_z_polys.iter()) .collect_vec(); challenges.extend([beta, gamma, alpha]); let (points, evals) = prove_zero_check( @@ -273,20 +282,27 @@ where )?; // PCS open - + let polys = iter::empty().chain(polys).chain(lookup_polys.iter()); let dummy_comm = Pcs::Commitment::default(); let comms = iter::empty() .chain(iter::repeat(&dummy_comm).take(pp.num_instances.len())) .chain(&pp.preprocess_comms) .chain(&witness_comms) .chain(&pp.permutation_comms) - .chain(&lookup_m_comms) - .chain(&lookup_h_permutation_z_comms) + .chain(&permutation_z_comms) + .chain(lookup_comms.iter()) + .collect_vec(); + let points = iter::empty() + .chain(points) + .chain(lookup_opening_points) + .collect_vec(); + let evals = iter::empty() + .chain(evals) + .chain(lookup_opening_evals) .collect_vec(); let timer = start_timer(|| format!("pcs_batch_open-{}", evals.len())); Pcs::batch_open(&pp.pcs, polys, comms, &points, &evals, transcript)?; end_timer(timer); - Ok(()) } @@ -305,7 +321,8 @@ where // Round 0..n - let mut witness_comms = Vec::with_capacity(vp.num_witness_polys.iter().sum()); + let num_witness_polys = vp.num_witness_polys.iter().sum(); + let mut witness_comms = Vec::with_capacity(num_witness_polys); let mut challenges = Vec::with_capacity(vp.num_challenges.iter().sum::() + 4); for (num_polys, num_challenges) in vp.num_witness_polys.iter().zip_eq(vp.num_challenges.iter()) @@ -314,21 +331,22 @@ where challenges.extend(transcript.squeeze_challenges(*num_challenges)); } - // Round n - - let beta = transcript.squeeze_challenge(); - - let lookup_m_comms = Pcs::read_commitments(&vp.pcs, vp.num_lookups, transcript)?; - - // Round n+1 - - let gamma = transcript.squeeze_challenge(); - - let lookup_h_permutation_z_comms = Pcs::read_commitments( - &vp.pcs, - vp.num_lookups + vp.num_permutation_z_polys, + let mut lookup_opening_points = vec![]; + let mut lookup_opening_evals = vec![]; + let (lookup_comms, lasso_challenges) = verify_lasso_lookup::( + vp, + &mut lookup_opening_points, + &mut lookup_opening_evals, transcript, )?; + let [beta, gamma] = if vp.lasso_table.is_some() { + lasso_challenges.try_into().unwrap() + } else { + transcript.squeeze_challenges(2).try_into().unwrap() + }; + + let permutation_z_comms = + Pcs::read_commitments(&vp.pcs, vp.num_permutation_z_polys, transcript)?; // Round n+2 @@ -346,15 +364,22 @@ where )?; // PCS verify - let dummy_comm = Pcs::Commitment::default(); let comms = iter::empty() 
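+        // Note: the commitment order below must mirror the polynomial indexing
+        // assumed by `lookup_polys_offset`: instances (dummy), preprocess,
+        // witness, permutation, permutation_z, then the Lasso lookup commitments.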
.chain(iter::repeat(&dummy_comm).take(vp.num_instances.len())) .chain(&vp.preprocess_comms) .chain(&witness_comms) .chain(vp.permutation_comms.iter().map(|(_, comm)| comm)) - .chain(&lookup_m_comms) - .chain(&lookup_h_permutation_z_comms) + .chain(&permutation_z_comms) + .chain(lookup_comms.iter()) + .collect_vec(); + let points = iter::empty() + .chain(points) + .chain(lookup_opening_points) + .collect_vec(); + let evals = iter::empty() + .chain(evals) + .chain(lookup_opening_evals) .collect_vec(); Pcs::batch_verify(&vp.pcs, comms, &points, &evals, transcript)?; diff --git a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs index c23c9b0..299d1d2 100644 --- a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs +++ b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs @@ -8,16 +8,17 @@ use crate::{ Itertools, }, }; -use std::{array, borrow::Cow, iter, mem}; +use std::{array, iter, mem}; pub(super) fn batch_size(circuit_info: &PlonkishCircuitInfo) -> usize { - let num_lookups = circuit_info.lookups.len(); let num_permutation_polys = circuit_info.permutation_polys().len(); chain![ [circuit_info.preprocess_polys.len() + circuit_info.permutation_polys().len()], circuit_info.num_witness_polys.clone(), - [num_lookups], - [num_lookups + div_ceil(num_permutation_polys, max_degree(circuit_info, None) - 1)], + [div_ceil( + num_permutation_polys, + max_degree(circuit_info) - 1 + )], ] .sum() } @@ -29,85 +30,32 @@ pub(super) fn compose( let [beta, gamma, alpha] = &array::from_fn(|idx| Expression::::Challenge(challenge_offset + idx)); - let (lookup_constraints, lookup_zero_checks) = lookup_constraints(circuit_info, beta, gamma); - - let max_degree = max_degree(circuit_info, Some(&lookup_constraints)); - let (num_permutation_z_polys, permutation_constraints) = permutation_constraints( - circuit_info, - max_degree, - beta, - gamma, - 2 * circuit_info.lookups.len(), - ); + let max_degree = max_degree(circuit_info); + let (num_permutation_z_polys, permutation_constraints) = + permutation_constraints(circuit_info, max_degree, beta, gamma, 0); let expression = { let constraints = iter::empty() .chain(circuit_info.constraints.iter()) - .chain(lookup_constraints.iter()) .chain(permutation_constraints.iter()) .collect_vec(); let eq = Expression::eq_xy(0); let zero_check_on_every_row = Expression::distribute_powers(constraints, alpha) * eq; - Expression::distribute_powers( - iter::empty() - .chain(lookup_zero_checks.iter()) - .chain(Some(&zero_check_on_every_row)), - alpha, - ) + Expression::distribute_powers(iter::empty().chain(Some(&zero_check_on_every_row)), alpha) }; (num_permutation_z_polys, expression) } -pub(super) fn max_degree( - circuit_info: &PlonkishCircuitInfo, - lookup_constraints: Option<&[Expression]>, -) -> usize { - let lookup_constraints = lookup_constraints.map(Cow::Borrowed).unwrap_or_else(|| { - let dummy_challenge = Expression::zero(); - Cow::Owned(self::lookup_constraints(circuit_info, &dummy_challenge, &dummy_challenge).0) - }); +pub(super) fn max_degree(circuit_info: &PlonkishCircuitInfo) -> usize { iter::empty() .chain(circuit_info.constraints.iter().map(Expression::degree)) - .chain(lookup_constraints.iter().map(Expression::degree)) .chain(circuit_info.max_degree) .chain(Some(2)) .max() .unwrap() } -pub(super) fn lookup_constraints( - circuit_info: &PlonkishCircuitInfo, - beta: &Expression, - gamma: &Expression, -) -> (Vec>, Vec>) { - let m_offset = circuit_info.num_poly() + circuit_info.permutation_polys().len(); 
- let h_offset = m_offset + circuit_info.lookups.len(); - let constraints = circuit_info - .lookups - .iter() - .zip(m_offset..) - .zip(h_offset..) - .flat_map(|((lookup, m), h)| { - let [m, h] = &[m, h] - .map(|poly| Query::new(poly, Rotation::cur())) - .map(Expression::::Polynomial); - let (inputs, tables) = lookup - .iter() - .map(|(input, table)| (input, table)) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let input = &Expression::distribute_powers(inputs, beta); - let table = &Expression::distribute_powers(tables, beta); - [h * (input + gamma) * (table + gamma) - (table + gamma) + m * (input + gamma)] - }) - .collect_vec(); - let sum_check = (h_offset..) - .take(circuit_info.lookups.len()) - .map(|h| Query::new(h, Rotation::cur()).into()) - .collect_vec(); - (constraints, sum_check) -} - pub(crate) fn permutation_constraints( circuit_info: &PlonkishCircuitInfo, max_degree: usize, diff --git a/plonkish_backend/src/backend/hyperplonk/prover.rs b/plonkish_backend/src/backend/hyperplonk/prover.rs index 19ef148..e7f36bc 100644 --- a/plonkish_backend/src/backend/hyperplonk/prover.rs +++ b/plonkish_backend/src/backend/hyperplonk/prover.rs @@ -4,30 +4,29 @@ use crate::{ verifier::{pcs_query, point_offset, points}, HyperPlonk, }, + lookup::lasso::prover::LassoProver, WitnessEncoding, }, - pcs::Evaluation, + pcs::{Evaluation, PolynomialCommitmentScheme}, piop::sum_check::{ classic::{ClassicSumCheck, EvaluationsProver}, SumCheck, VirtualPolynomial, }, - poly::{multilinear::MultilinearPolynomial, Polynomial}, + poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{div_ceil, steps_by, sum, BatchInvert, BooleanHypercube, PrimeField}, + arithmetic::{div_ceil, steps_by, BatchInvert, BooleanHypercube, PrimeField}, end_timer, - expression::{CommonPolynomial, Expression, Rotation}, - parallel::{num_threads, par_map_collect, parallelize, parallelize_iter}, + expression::{Expression, Rotation}, + parallel::{par_map_collect, parallelize}, start_timer, - transcript::FieldTranscriptWrite, + transcript::{FieldTranscriptWrite, TranscriptWrite}, Itertools, }, Error, }; -use std::{ - collections::{HashMap, HashSet}, - hash::Hash, - iter, -}; +use std::iter; + +use super::HyperPlonkProverParam; pub(crate) fn instance_polys<'a, F: PrimeField>( num_vars: usize, @@ -47,208 +46,6 @@ pub(crate) fn instance_polys<'a, F: PrimeField>( .collect() } -pub(crate) fn lookup_compressed_polys( - lookups: &[Vec<(Expression, Expression)>], - polys: &[&MultilinearPolynomial], - challenges: &[F], - betas: &[F], -) -> Vec<[MultilinearPolynomial; 2]> { - if lookups.is_empty() { - return Default::default(); - } - - let num_vars = polys[0].num_vars(); - let expression = lookups - .iter() - .flat_map(|lookup| lookup.iter().map(|(input, table)| (input + table))) - .sum::>(); - let lagranges = { - let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); - expression - .used_langrange() - .into_iter() - .map(|i| (i, bh[i.rem_euclid(1 << num_vars) as usize])) - .collect::>() - }; - lookups - .iter() - .map(|lookup| lookup_compressed_poly(lookup, &lagranges, polys, challenges, betas)) - .collect() -} - -pub(super) fn lookup_compressed_poly( - lookup: &[(Expression, Expression)], - lagranges: &HashSet<(i32, usize)>, - polys: &[&MultilinearPolynomial], - challenges: &[F], - betas: &[F], -) -> [MultilinearPolynomial; 2] { - let num_vars = polys[0].num_vars(); - let bh = BooleanHypercube::new(num_vars); - let compress = |expressions: &[&Expression]| { - betas - .iter() - .copied() - .zip(expressions.iter().map(|expression| { - let mut 
compressed = vec![F::ZERO; 1 << num_vars]; - parallelize(&mut compressed, |(compressed, start)| { - for (b, compressed) in (start..).zip(compressed) { - *compressed = expression.evaluate( - &|constant| constant, - &|common_poly| match common_poly { - CommonPolynomial::Identity => F::from(b as u64), - CommonPolynomial::Lagrange(i) => { - if lagranges.contains(&(i, b)) { - F::ONE - } else { - F::ZERO - } - } - CommonPolynomial::EqXY(_) => unreachable!(), - }, - &|query| polys[query.poly()][bh.rotate(b, query.rotation())], - &|challenge| challenges[challenge], - &|value| -value, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, - &|value, scalar| value * &scalar, - ); - } - }); - MultilinearPolynomial::new(compressed) - })) - .sum::>() - }; - - let (inputs, tables) = lookup - .iter() - .map(|(input, table)| (input, table)) - .unzip::<_, _, Vec<_>, Vec<_>>(); - - let timer = start_timer(|| "compressed_input_poly"); - let compressed_input_poly = compress(&inputs); - end_timer(timer); - - let timer = start_timer(|| "compressed_table_poly"); - let compressed_table_poly = compress(&tables); - end_timer(timer); - - [compressed_input_poly, compressed_table_poly] -} - -pub(crate) fn lookup_m_polys( - compressed_polys: &[[MultilinearPolynomial; 2]], -) -> Result>, Error> { - compressed_polys.iter().map(lookup_m_poly).try_collect() -} - -pub(super) fn lookup_m_poly( - compressed_polys: &[MultilinearPolynomial; 2], -) -> Result, Error> { - let [input, table] = compressed_polys; - - let counts = { - let indice_map = table.iter().zip(0..).collect::>(); - - let chunk_size = div_ceil(input.evals().len(), num_threads()); - let num_chunks = div_ceil(input.evals().len(), chunk_size); - let mut counts = vec![HashMap::new(); num_chunks]; - let mut valids = vec![true; num_chunks]; - parallelize_iter( - counts - .iter_mut() - .zip(valids.iter_mut()) - .zip((0..).step_by(chunk_size)), - |((count, valid), start)| { - for input in input[start..].iter().take(chunk_size) { - if let Some(idx) = indice_map.get(input) { - count - .entry(*idx) - .and_modify(|count| *count += 1) - .or_insert(1); - } else { - *valid = false; - break; - } - } - }, - ); - if valids.iter().any(|valid| !valid) { - return Err(Error::InvalidSnark("Invalid lookup input".to_string())); - } - counts - }; - - let mut m = vec![0; 1 << input.num_vars()]; - for (idx, count) in counts.into_iter().flatten() { - m[idx] += count; - } - let m = par_map_collect(m, |count| match count { - 0 => F::ZERO, - 1 => F::ONE, - count => F::from(count), - }); - Ok(MultilinearPolynomial::new(m)) -} - -pub(super) fn lookup_h_polys( - compressed_polys: &[[MultilinearPolynomial; 2]], - m_polys: &[MultilinearPolynomial], - gamma: &F, -) -> Vec> { - compressed_polys - .iter() - .zip(m_polys.iter()) - .map(|(compressed_polys, m_poly)| lookup_h_poly(compressed_polys, m_poly, gamma)) - .collect() -} - -pub(super) fn lookup_h_poly( - compressed_polys: &[MultilinearPolynomial; 2], - m_poly: &MultilinearPolynomial, - gamma: &F, -) -> MultilinearPolynomial { - let [input, table] = compressed_polys; - let mut h_input = vec![F::ZERO; 1 << input.num_vars()]; - let mut h_table = vec![F::ZERO; 1 << input.num_vars()]; - - parallelize(&mut h_input, |(h_input, start)| { - for (h_input, input) in h_input.iter_mut().zip(input[start..].iter()) { - *h_input = *gamma + input; - } - }); - parallelize(&mut h_table, |(h_table, start)| { - for (h_table, table) in h_table.iter_mut().zip(table[start..].iter()) { - *h_table = *gamma + table; - } - }); - - let chunk_size = div_ceil(2 * h_input.len(), 
num_threads()); - parallelize_iter( - iter::empty() - .chain(h_input.chunks_mut(chunk_size)) - .chain(h_table.chunks_mut(chunk_size)), - |h| { - h.iter_mut().batch_invert(); - }, - ); - - parallelize(&mut h_input, |(h_input, start)| { - for (h_input, (h_table, m)) in h_input - .iter_mut() - .zip(h_table[start..].iter().zip(m_poly[start..].iter())) - { - *h_input -= *h_table * m; - } - }); - - if cfg!(feature = "sanity-check") { - assert_eq!(sum::(&h_input), F::ZERO); - } - - MultilinearPolynomial::new(h_input) -} - pub(crate) fn permutation_z_polys( num_chunks: usize, permutation_polys: &[(usize, MultilinearPolynomial)], @@ -407,3 +204,89 @@ pub(crate) fn prove_sum_check( Ok((points(&pcs_query, &x), evals)) } + +pub(super) fn prove_lasso_lookup< + F: PrimeField, + Pcs: PolynomialCommitmentScheme>, +>( + pp: &HyperPlonkProverParam, + polys: &[&MultilinearPolynomial], + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + transcript: &mut impl TranscriptWrite, +) -> Result<(Vec>, Vec, Vec), Error> { + if pp.lasso_lookup.is_none() { + return Ok((vec![], vec![], vec![])); + } + let lasso_lookup = pp.lasso_lookup.as_ref().unwrap(); + let (lookup, table) = ((&lasso_lookup.0, &lasso_lookup.1), &lasso_lookup.2); + let (lookup_index_poly, lookup_output_poly) = LassoProver::::lookup_poly(&lookup, &polys); + + let num_vars = lookup_output_poly.num_vars(); + + // get subtable_polys + let subtable_polys = table.subtable_polys(); + let subtable_polys = subtable_polys.iter().collect_vec(); + let subtable_polys = subtable_polys.as_slice(); + + let (lookup_polys, lookup_comms) = LassoProver::::commit( + &pp.pcs, + pp.lookup_polys_offset, + &table, + subtable_polys, + lookup_output_poly, + &lookup_index_poly, + transcript, + )?; + + // Round n + // squeeze `r` + let r = transcript.squeeze_challenges(num_vars); + + let (lookup_output_poly, dims, read_ts_polys, final_cts_polys, e_polys) = ( + &lookup_polys[0][0], + &lookup_polys[1], + &lookup_polys[2], + &lookup_polys[3], + &lookup_polys[4], + ); + // Lasso Sumcheck + LassoProver::::prove_sum_check( + pp.lookup_points_offset, + lookup_opening_points, + lookup_opening_evals, + &table, + lookup_output_poly, + &e_polys.iter().collect_vec(), + &r, + num_vars, + transcript, + )?; + + // squeeze memory checking challenges -> we will reuse beta, gamma for memory checking of Lasso + // Round n+1 + let [beta, gamma] = transcript.squeeze_challenges(2).try_into().unwrap(); + + // memory_checking + LassoProver::::memory_checking( + pp.lookup_points_offset, + lookup_opening_points, + lookup_opening_evals, + table, + subtable_polys, + dims, + read_ts_polys, + final_cts_polys, + e_polys, + &beta, + &gamma, + transcript, + )?; + + let lookup_polys = lookup_polys + .into_iter() + .flat_map(|lookup_polys| lookup_polys.into_iter().map(|poly| poly.poly).collect_vec()) + .collect_vec(); + let lookup_comms = lookup_comms.concat(); + Ok((lookup_polys, lookup_comms, vec![beta, gamma])) +} diff --git a/plonkish_backend/src/backend/hyperplonk/util.rs b/plonkish_backend/src/backend/hyperplonk/util.rs index 30965e5..eea9bee 100644 --- a/plonkish_backend/src/backend/hyperplonk/util.rs +++ b/plonkish_backend/src/backend/hyperplonk/util.rs @@ -2,17 +2,14 @@ use crate::{ backend::{ hyperplonk::{ preprocessor::{compose, permutation_polys}, - prover::{ - instance_polys, lookup_compressed_polys, lookup_h_polys, lookup_m_polys, - permutation_z_polys, - }, + prover::{instance_polys, permutation_z_polys}, }, mock::MockCircuit, PlonkishCircuit, PlonkishCircuitInfo, }, 
poly::{multilinear::MultilinearPolynomial, Polynomial}, util::{ - arithmetic::{powers, BooleanHypercube, PrimeField}, + arithmetic::{BooleanHypercube, PrimeField}, expression::{Expression, Query, Rotation}, test::{rand_array, rand_idx, rand_vec}, Itertools, @@ -42,7 +39,8 @@ pub fn vanilla_plonk_circuit_info( num_witness_polys: vec![3], num_challenges: vec![0], constraints: vec![q_l * w_l + q_r * w_r + q_m * w_l * w_r + q_o * w_o + q_c + pi], - lookups: Vec::new(), + lookups: vec![], + lasso_lookup: None, permutations, max_degree: Some(4), } @@ -80,6 +78,7 @@ pub fn vanilla_plonk_with_lookup_circuit_info( (q_lookup * w_r, t_r.clone()), (q_lookup * w_o, t_o.clone()), ]], + lasso_lookup: None, permutations, max_degree: Some(4), } @@ -338,17 +337,6 @@ pub fn rand_vanilla_plonk_with_lookup_assignment( let challenges: [_; 3] = rand_array(&mut witness_rng); let [beta, gamma, _] = challenges; - let (lookup_compressed_polys, lookup_m_polys) = { - let PlonkishCircuitInfo { lookups, .. } = - vanilla_plonk_with_lookup_circuit_info(0, 0, Default::default(), Vec::new()); - let betas = powers(beta).take(3).collect_vec(); - let lookup_compressed_polys = - lookup_compressed_polys(&lookups, &polys.iter().collect_vec(), &[], &betas); - let lookup_m_polys = lookup_m_polys(&lookup_compressed_polys).unwrap(); - (lookup_compressed_polys, lookup_m_polys) - }; - let lookup_h_polys = lookup_h_polys(&lookup_compressed_polys, &lookup_m_polys, &gamma); - let permutation_polys = permutation_polys(num_vars, &[10, 11, 12], &permutations); let permutation_z_polys = permutation_z_polys( 1, @@ -365,8 +353,6 @@ pub fn rand_vanilla_plonk_with_lookup_assignment( iter::empty() .chain(polys) .chain(permutation_polys) - .chain(lookup_m_polys) - .chain(lookup_h_polys) .chain(permutation_z_polys) .collect_vec(), challenges.to_vec(), diff --git a/plonkish_backend/src/backend/hyperplonk/verifier.rs b/plonkish_backend/src/backend/hyperplonk/verifier.rs index dcc602c..95d57f9 100644 --- a/plonkish_backend/src/backend/hyperplonk/verifier.rs +++ b/plonkish_backend/src/backend/hyperplonk/verifier.rs @@ -1,20 +1,23 @@ use crate::{ - pcs::Evaluation, + backend::lookup::lasso::verifier::LassoVerifier, + pcs::{Evaluation, PolynomialCommitmentScheme}, piop::sum_check::{ classic::{ClassicSumCheck, EvaluationsProver}, evaluate, lagrange_eval, SumCheck, }, - poly::multilinear::{rotation_eval, rotation_eval_points}, + poly::multilinear::{rotation_eval, rotation_eval_points, MultilinearPolynomial}, util::{ arithmetic::{inner_product, BooleanHypercube, PrimeField}, expression::{Expression, Query, Rotation}, - transcript::FieldTranscriptRead, + transcript::{FieldTranscriptRead, TranscriptRead}, Itertools, }, Error, }; use std::collections::{BTreeSet, HashMap}; +use super::HyperPlonkVerifierParam; + #[allow(clippy::type_complexity)] pub(super) fn verify_zero_check( num_vars: usize, @@ -180,3 +183,67 @@ pub(crate) fn point_offset(pcs_query: &BTreeSet) -> HashMap( + expression: &Expression, + num_instance_poly: usize, +) -> usize { + let pcs_query = pcs_query(expression, num_instance_poly); + pcs_query + .iter() + .map(Query::rotation) + .collect::>() + .into_iter() + .map(|rotation| 1 << rotation.distance()) + .sum() +} + +pub(super) fn verify_lasso_lookup< + F: PrimeField, + Pcs: PolynomialCommitmentScheme>, +>( + vp: &HyperPlonkVerifierParam, + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + transcript: &mut impl TranscriptRead, +) -> Result<(Vec, Vec), Error> { + if vp.lasso_table.is_none() { + return Ok((vec![], vec![])); + 
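+        // When the circuit declares no Lasso table there is nothing to read
+        // from the transcript; the caller squeezes beta/gamma directly instead.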
}
+    let lookup_table = vp.lasso_table.as_ref().unwrap();
+
+    let lookup_comms =
+        LassoVerifier::<F, Pcs>::read_commitments(&vp.pcs, lookup_table, transcript)?;
+
+    // Round n
+    let r = transcript.squeeze_challenges(vp.num_vars);
+
+    LassoVerifier::<F, Pcs>::verify_sum_check(
+        lookup_table,
+        vp.num_vars,
+        vp.lookup_polys_offset,
+        vp.lookup_points_offset,
+        lookup_opening_points,
+        lookup_opening_evals,
+        &r,
+        transcript,
+    )?;
+
+    // Round n+1
+    let [beta, gamma] = transcript.squeeze_challenges(2).try_into().unwrap();
+
+    // memory checking
+    LassoVerifier::<F, Pcs>::memory_checking(
+        vp.num_vars,
+        vp.lookup_polys_offset,
+        vp.lookup_points_offset,
+        lookup_opening_points,
+        lookup_opening_evals,
+        lookup_table,
+        &beta,
+        &gamma,
+        transcript,
+    )?;
+
+    Ok((lookup_comms, vec![beta, gamma]))
+}
diff --git a/plonkish_backend/src/backend/lookup/lasso.rs b/plonkish_backend/src/backend/lookup/lasso.rs
new file mode 100644
index 0000000..518e943
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso.rs
@@ -0,0 +1,70 @@
+use std::{fmt::Debug, marker::PhantomData};
+
+use halo2_curves::ff::{Field, PrimeField};
+
+use crate::{
+    pcs::PolynomialCommitmentScheme,
+    poly::multilinear::{MultilinearPolynomial, MultilinearPolynomialTerms},
+    util::expression::Expression,
+};
+
+pub mod memory_checking;
+pub mod prover;
+pub mod test;
+pub mod verifier;
+
+pub trait Subtable<F: PrimeField> {
+    fn evaluate(point: &[F]) -> F;
+}
+
+/// Information about a decomposable table that the backend prover and
+/// verifier can query
+pub trait DecomposableTable<F: PrimeField>: Debug + Sync + DecomposableTableClone<F> {
+    fn num_memories(&self) -> usize;
+
+    /// Returns the multilinear extension polynomial of each subtable
+    fn subtable_polys(&self) -> Vec<MultilinearPolynomial<F>>;
+    fn subtable_polys_terms(&self) -> Vec<MultilinearPolynomialTerms<F>>;
+
+    fn combine_lookup_expressions(&self, expressions: Vec<Expression<F>>) -> Expression<F>;
+
+    /// The `g` function that computes T[r] = g(T_1[r_1], ..., T_k[r_1], T_{k+1}[r_2], ..., T_{\alpha}[r_c])
+    fn combine_lookups(&self, operands: &[F]) -> F;
+
+    /// Returns the bit size of each chunk.
+    /// Chunks may have different bit sizes.
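+    /// e.g., a hypothetical table with a 64-bit index split into four 16-bit
+    /// chunks would return `vec![16, 16, 16, 16]`.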
+    fn chunk_bits(&self) -> Vec<usize>;
+
+    /// Returns the subtable indices of each lookup.
+    /// The length of `index_bits` equals the actual bit length of the table index.
+    fn subtable_indices(&self, index_bits: Vec<bool>) -> Vec<Vec<bool>>;
+
+    fn memory_to_subtable_index(&self, memory_index: usize) -> usize;
+
+    fn memory_to_chunk_index(&self, memory_index: usize) -> usize;
+}
+
+pub trait DecomposableTableClone<F> {
+    fn clone_box(&self) -> Box<dyn DecomposableTable<F>>;
+}
+
+impl<T, F: PrimeField> DecomposableTableClone<F> for T
+where
+    T: DecomposableTable<F> + Clone + 'static,
+{
+    fn clone_box(&self) -> Box<dyn DecomposableTable<F>> {
+        Box::new(self.clone())
+    }
+}
+
+impl<F> Clone for Box<dyn DecomposableTable<F>> {
+    fn clone(&self) -> Self {
+        self.clone_box()
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct GeneralizedLasso<F: Field, Pcs: PolynomialCommitmentScheme<F>>(
+    PhantomData<F>,
+    PhantomData<Pcs>,
+);
diff --git a/plonkish_backend/src/backend/lookup/lasso/memory_checking/mod.rs b/plonkish_backend/src/backend/lookup/lasso/memory_checking/mod.rs
new file mode 100644
index 0000000..c73fe0e
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/memory_checking/mod.rs
@@ -0,0 +1,31 @@
+pub mod prover;
+pub mod verifier;
+
+use halo2_curves::ff::PrimeField;
+pub use prover::MemoryCheckingProver;
+
+use crate::poly::multilinear::MultilinearPolynomial;
+
+#[derive(Clone, Debug)]
+struct MemoryGKR<F: PrimeField> {
+    init: MultilinearPolynomial<F>,
+    read: MultilinearPolynomial<F>,
+    write: MultilinearPolynomial<F>,
+    final_read: MultilinearPolynomial<F>,
+}
+
+impl<F: PrimeField> MemoryGKR<F> {
+    pub fn new(
+        init: MultilinearPolynomial<F>,
+        read: MultilinearPolynomial<F>,
+        write: MultilinearPolynomial<F>,
+        final_read: MultilinearPolynomial<F>,
+    ) -> Self {
+        Self {
+            init,
+            read,
+            write,
+            final_read,
+        }
+    }
+}
diff --git a/plonkish_backend/src/backend/lookup/lasso/memory_checking/prover.rs b/plonkish_backend/src/backend/lookup/lasso/memory_checking/prover.rs
new file mode 100644
index 0000000..d9b6327
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/memory_checking/prover.rs
@@ -0,0 +1,212 @@
+use std::iter;
+
+use halo2_curves::ff::PrimeField;
+use itertools::{chain, Itertools};
+use rayon::prelude::{IntoParallelIterator, ParallelIterator};
+
+use crate::{
+    backend::lookup::lasso::prover::Chunk, pcs::Evaluation, piop::gkr::prove_grand_product,
+    poly::multilinear::MultilinearPolynomial, util::transcript::FieldTranscriptWrite, Error,
+};
+
+use super::MemoryGKR;
+
+pub struct MemoryCheckingProver<'a, F: PrimeField> {
+    /// offset of this MemoryCheckingProver instance's opening points
+    points_offset: usize,
+    /// chunks with the same bit size
+    chunks: Vec<Chunk<'a, F>>,
+    /// GKR initial polynomials for each memory
+    memories: Vec<MemoryGKR<F>>,
+}
+
+impl<'a, F: PrimeField> MemoryCheckingProver<'a, F> {
+    // T_1[dim_1(x)], ..., T_k[dim_1(x)],
+    // ...
+    // T_{\alpha-k+1}[dim_c(x)], ..., T_{\alpha}[dim_c(x)]
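+    //
+    // Each memory tuple (address a, value v, timestamp t) is fingerprinted as
+    // hash(a, v, t) = a + v * gamma + t * gamma^2 - tau (the `hash` closure
+    // below); multiset equality of {init, writes} versus {reads, final reads}
+    // is then proven via grand products over these fingerprints.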
+ // T_{\alpha-k+1}[dim_c(x)], ..., T_{\alpha}[dim_c(x)] + pub fn new(points_offset: usize, chunks: Vec>, tau: &F, gamma: &F) -> Self { + let num_reads = chunks[0].num_reads(); + let memory_size = 1 << chunks[0].chunk_bits(); + + let hash = |a: &F, v: &F, t: &F| -> F { *a + *v * gamma + *t * gamma.square() - tau }; + + let memories_gkr: Vec> = (0..chunks.len()) + .into_par_iter() + .flat_map(|i| { + let chunk = &chunks[i]; + let chunk_polys = chunk.chunk_polys().collect_vec(); + let (dim, read_ts_poly, final_cts_poly) = + (chunk_polys[0], chunk_polys[1], chunk_polys[2]); + chunk + .memories() + .map(|memory| { + let memory_polys = memory.polys().collect_vec(); + let (subtable_poly, e_poly) = (memory_polys[0], memory_polys[1]); + let mut init = vec![]; + let mut read = vec![]; + let mut write = vec![]; + let mut final_read = vec![]; + (0..memory_size).for_each(|i| { + init.push(hash(&F::from(i as u64), &subtable_poly[i], &F::ZERO)); + final_read.push(hash( + &F::from(i as u64), + &subtable_poly[i], + &final_cts_poly[i], + )); + }); + (0..num_reads).for_each(|i| { + read.push(hash(&dim[i], &e_poly[i], &read_ts_poly[i])); + write.push(hash(&dim[i], &e_poly[i], &(read_ts_poly[i] + F::ONE))); + }); + MemoryGKR::new( + MultilinearPolynomial::new(init), + MultilinearPolynomial::new(read), + MultilinearPolynomial::new(write), + MultilinearPolynomial::new(final_read), + ) + }) + .collect_vec() + }) + .collect(); + + Self { + points_offset, + chunks, + memories: memories_gkr, + } + } + + fn inits(&self) -> impl Iterator> { + self.memories.iter().map(|memory| &memory.init) + } + + fn reads(&self) -> impl Iterator> { + self.memories.iter().map(|memory| &memory.read) + } + + fn writes(&self) -> impl Iterator> { + self.memories.iter().map(|memory| &memory.write) + } + + fn final_reads(&self) -> impl Iterator> { + self.memories.iter().map(|memory| &memory.final_read) + } + + fn iter( + &self, + ) -> impl Iterator< + Item = ( + &MultilinearPolynomial, + &MultilinearPolynomial, + &MultilinearPolynomial, + &MultilinearPolynomial, + ), + > { + self.memories.iter().map(|memory| { + ( + &memory.init, + &memory.read, + &memory.write, + &memory.final_read, + ) + }) + } + + pub fn claimed_v_0s(&self) -> impl IntoIterator>> { + let (claimed_read_0s, claimed_write_0s, claimed_init_0s, claimed_final_read_0s) = self + .iter() + .map(|(init, read, write, final_read)| { + let claimed_init_0 = init.iter().product(); + let claimed_read_0 = read.iter().product(); + let claimed_write_0 = write.iter().product(); + let claimed_final_read_0 = final_read.iter().product(); + + // sanity check + debug_assert_eq!( + claimed_init_0 * claimed_write_0, + claimed_read_0 * claimed_final_read_0, + "Multiset hashes don't match", + ); + ( + Some(claimed_read_0), + Some(claimed_write_0), + Some(claimed_init_0), + Some(claimed_final_read_0), + ) + }) + .multiunzip::<(Vec<_>, Vec<_>, Vec<_>, Vec<_>)>(); + chain!([ + claimed_read_0s, + claimed_write_0s, + claimed_init_0s, + claimed_final_read_0s + ]) + } + + pub fn prove( + &mut self, + points_offset: usize, + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + transcript: &mut impl FieldTranscriptWrite, + ) -> Result<(), Error> { + let (_, x) = prove_grand_product( + iter::repeat(None).take(self.memories.len() * 2), + chain!(self.reads(), self.writes()), + transcript, + )?; + + let (_, y) = prove_grand_product( + iter::repeat(None).take(self.memories.len() * 2), + chain!(self.inits(), self.final_reads()), + transcript, + )?; + + assert_eq!( + points_offset + 
lookup_opening_points.len(), + self.points_offset + ); + let x_offset = points_offset + lookup_opening_points.len(); + let y_offset = x_offset + 1; + let (dim_xs, read_ts_poly_xs, final_cts_poly_ys, e_poly_xs) = self + .chunks + .iter() + .map(|chunk| { + let chunk_poly_evals = chunk.chunk_poly_evals(&x, &y); + let e_poly_xs = chunk.e_poly_evals(&x); + transcript.write_field_elements(&chunk_poly_evals).unwrap(); + transcript.write_field_elements(&e_poly_xs).unwrap(); + + ( + Evaluation::new(chunk.dim.offset, x_offset, chunk_poly_evals[0]), + Evaluation::new(chunk.read_ts_poly.offset, x_offset, chunk_poly_evals[1]), + Evaluation::new(chunk.final_cts_poly.offset, y_offset, chunk_poly_evals[2]), + chunk + .memories() + .enumerate() + .map(|(i, memory)| { + Evaluation::new(memory.e_poly.offset, x_offset, e_poly_xs[i]) + }) + .collect_vec(), + ) + }) + .multiunzip::<( + Vec>, + Vec>, + Vec>, + Vec>>, + )>(); + + lookup_opening_points.extend_from_slice(&[x, y]); + let opening_evals = chain!( + dim_xs, + read_ts_poly_xs, + final_cts_poly_ys, + e_poly_xs.concat() + ) + .collect_vec(); + lookup_opening_evals.extend_from_slice(&opening_evals); + + Ok(()) + } +} diff --git a/plonkish_backend/src/backend/lookup/lasso/memory_checking/verifier.rs b/plonkish_backend/src/backend/lookup/lasso/memory_checking/verifier.rs new file mode 100644 index 0000000..0c9bd35 --- /dev/null +++ b/plonkish_backend/src/backend/lookup/lasso/memory_checking/verifier.rs @@ -0,0 +1,237 @@ +use std::{iter, marker::PhantomData}; + +use halo2_curves::ff::PrimeField; +use itertools::{chain, Itertools}; + +use crate::{ + pcs::Evaluation, + piop::gkr::verify_grand_product, + poly::multilinear::MultilinearPolynomialTerms, + util::{arithmetic::inner_product, transcript::FieldTranscriptRead}, + Error, +}; + +#[derive(Clone, Debug)] +pub(in crate::backend::lookup::lasso) struct Chunk { + chunk_index: usize, + chunk_bits: usize, + pub(crate) memory: Vec>, +} + +impl Chunk { + pub fn chunk_polys_index(&self, offset: usize, num_chunks: usize) -> Vec { + let dim_poly_index = offset + 1 + self.chunk_index; + let read_ts_poly_index = offset + 1 + num_chunks + self.chunk_index; + let final_cts_poly_index = offset + 1 + 2 * num_chunks + self.chunk_index; + vec![dim_poly_index, read_ts_poly_index, final_cts_poly_index] + } + + pub fn new(chunk_index: usize, chunk_bits: usize, memory: Memory) -> Self { + Self { + chunk_index, + chunk_bits, + memory: vec![memory], + } + } + + pub fn num_memories(&self) -> usize { + self.memory.len() + } + + pub fn chunk_bits(&self) -> usize { + self.chunk_bits + } + + pub fn add_memory(&mut self, memory: Memory) { + self.memory.push(memory); + } + + pub fn memory_indices(&self) -> Vec { + self.memory + .iter() + .map(|memory| memory.memory_index) + .collect_vec() + } + + /// check the following relations: + /// - $read(x) == hash(dim(x), E(x), read_ts(x))$ + /// - $write(x) == hash(dim(x), E(x), read_ts(x) + 1)$ + /// - $init(y) == hash(y, T(y), 0)$ + /// - $final_read(y) == hash(y, T(y), final_cts(x))$ + pub fn verify_memories( + &self, + read_xs: &[F], + write_xs: &[F], + init_ys: &[F], + final_read_ys: &[F], + y: &[F], + hash: impl Fn(&F, &F, &F) -> F, + transcript: &mut impl FieldTranscriptRead, + ) -> Result<(F, F, F, Vec), Error> { + let [dim_x, read_ts_poly_x, final_cts_poly_y] = + transcript.read_field_elements(3)?.try_into().unwrap(); + let e_poly_xs = transcript.read_field_elements(self.num_memories())?; + let id_poly_y = inner_product( + iter::successors(Some(F::ONE), |power_of_two| 
Some(power_of_two.double())) + .take(y.len()) + .collect_vec() + .iter(), + y, + ); + self.memory.iter().enumerate().for_each(|(i, memory)| { + assert_eq!(read_xs[i], hash(&dim_x, &e_poly_xs[i], &read_ts_poly_x)); + assert_eq!( + write_xs[i], + hash(&dim_x, &e_poly_xs[i], &(read_ts_poly_x + F::ONE)) + ); + let subtable_poly_y = memory.subtable_poly.evaluate(y); + assert_eq!(init_ys[i], hash(&id_poly_y, &subtable_poly_y, &F::ZERO)); + assert_eq!( + final_read_ys[i], + hash(&id_poly_y, &subtable_poly_y, &final_cts_poly_y) + ); + }); + Ok((dim_x, read_ts_poly_x, final_cts_poly_y, e_poly_xs)) + } +} + +#[derive(Clone, Debug)] +pub(in crate::backend::lookup::lasso) struct Memory { + memory_index: usize, + subtable_poly: MultilinearPolynomialTerms, +} + +impl Memory { + pub fn new(memory_index: usize, subtable_poly: MultilinearPolynomialTerms) -> Self { + Self { + memory_index, + subtable_poly, + } + } +} + +#[derive(Clone, Debug)] +pub(in crate::backend::lookup::lasso) struct MemoryCheckingVerifier { + /// chunks with the same bits size + chunks: Vec>, + _marker: PhantomData, +} + +impl<'a, F: PrimeField> MemoryCheckingVerifier { + pub fn new(chunks: Vec>) -> Self { + Self { + chunks, + _marker: PhantomData, + } + } + + pub fn verify( + &self, + num_chunks: usize, + num_reads: usize, + polys_offset: usize, + points_offset: usize, + gamma: &F, + tau: &F, + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + transcript: &mut impl FieldTranscriptRead, + ) -> Result<(), Error> { + let num_memories: usize = self.chunks.iter().map(|chunk| chunk.num_memories()).sum(); + let memory_bits = self.chunks[0].chunk_bits(); + let (read_write_xs, x) = verify_grand_product( + num_reads, + iter::repeat(None).take(2 * num_memories), + transcript, + )?; + let (read_xs, write_xs) = read_write_xs.split_at(num_memories); + + let (init_final_read_ys, y) = verify_grand_product( + memory_bits, + iter::repeat(None).take(2 * num_memories), + transcript, + )?; + let (init_ys, final_read_ys) = init_final_read_ys.split_at(num_memories); + + let hash = |a: &F, v: &F, t: &F| -> F { *a + *v * gamma + *t * gamma.square() - tau }; + let mut offset = 0; + let (dim_xs, read_ts_poly_xs, final_cts_poly_ys, e_poly_xs) = self + .chunks + .iter() + .map(|chunk| { + let num_memories = chunk.num_memories(); + let result = chunk.verify_memories( + &read_xs[offset..offset + num_memories], + &write_xs[offset..offset + num_memories], + &init_ys[offset..offset + num_memories], + &final_read_ys[offset..offset + num_memories], + &y, + hash, + transcript, + ); + offset += num_memories; + result + }) + .collect::)>, Error>>()? 
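+            // unzip the per-chunk results into parallel vectors of
+            // dim/read_ts/final_cts/e-polynomial evaluations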
+ .into_iter() + .multiunzip::<(Vec<_>, Vec<_>, Vec<_>, Vec>)>(); + + self.opening_evals( + num_chunks, + polys_offset, + points_offset, + &lookup_opening_points, + lookup_opening_evals, + &dim_xs, + &read_ts_poly_xs, + &final_cts_poly_ys, + &e_poly_xs.concat(), + ); + lookup_opening_points.extend_from_slice(&[x, y]); + + Ok(()) + } + + fn opening_evals( + &self, + num_chunks: usize, + polys_offset: usize, + points_offset: usize, + lookup_opening_points: &Vec>, + lookup_opening_evals: &mut Vec>, + dim_xs: &[F], + read_ts_poly_xs: &[F], + final_cts_poly_ys: &[F], + e_poly_xs: &[F], + ) { + let x_offset = points_offset + lookup_opening_points.len(); + let y_offset = x_offset + 1; + let (dim_xs, read_ts_poly_xs, final_cts_poly_ys) = self + .chunks + .iter() + .enumerate() + .map(|(i, chunk)| { + let chunk_polys_index = chunk.chunk_polys_index(polys_offset, num_chunks); + ( + Evaluation::new(chunk_polys_index[0], x_offset, dim_xs[i]), + Evaluation::new(chunk_polys_index[1], x_offset, read_ts_poly_xs[i]), + Evaluation::new(chunk_polys_index[2], y_offset, final_cts_poly_ys[i]), + ) + }) + .multiunzip::<(Vec>, Vec>, Vec>)>(); + + let e_poly_offset = polys_offset + 1 + 3 * num_chunks; + let e_poly_xs = self + .chunks + .iter() + .flat_map(|chunk| chunk.memory_indices()) + .zip(e_poly_xs) + .map(|(memory_index, &e_poly_x)| { + Evaluation::new(e_poly_offset + memory_index, x_offset, e_poly_x) + }) + .collect_vec(); + lookup_opening_evals.extend_from_slice( + &chain!(dim_xs, read_ts_poly_xs, final_cts_poly_ys, e_poly_xs).collect_vec(), + ); + } +} diff --git a/plonkish_backend/src/backend/lookup/lasso/prover/mod.rs b/plonkish_backend/src/backend/lookup/lasso/prover/mod.rs new file mode 100644 index 0000000..54c7482 --- /dev/null +++ b/plonkish_backend/src/backend/lookup/lasso/prover/mod.rs @@ -0,0 +1,475 @@ +use std::{ + collections::{HashMap, HashSet}, + marker::PhantomData, +}; + +use halo2_curves::ff::{Field, PrimeField}; +use itertools::{chain, Itertools}; + +use crate::{ + pcs::{CommitmentChunk, Evaluation, PolynomialCommitmentScheme}, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::BooleanHypercube, + expression::{CommonPolynomial, Expression}, + impl_index, + parallel::parallelize, + transcript::{FieldTranscriptWrite, TranscriptWrite}, + }, + Error, +}; + +use super::{memory_checking::MemoryCheckingProver, DecomposableTable}; + +mod surge; + +pub use surge::Surge; + +#[derive(Clone, Debug)] +pub struct Poly { + /// polynomial offset in batch opening + pub(crate) offset: usize, + pub(crate) poly: MultilinearPolynomial, +} + +impl_index!(Poly, poly); + +impl Poly { + pub fn num_vars(&self) -> usize { + self.poly.num_vars() + } + + pub fn evaluate(&self, x: &[F]) -> F { + self.poly.evaluate(x) + } +} + +#[derive(Clone, Debug)] +pub struct Chunk<'a, F: PrimeField> { + pub(super) chunk_index: usize, + pub(super) dim: &'a Poly, + pub(super) read_ts_poly: &'a Poly, + pub(super) final_cts_poly: &'a Poly, + pub(super) memories: Vec>, +} + +impl<'a, F: PrimeField> Chunk<'a, F> { + fn new( + chunk_index: usize, + dim: &'a Poly, + read_ts_poly: &'a Poly, + final_cts_poly: &'a Poly, + memory: Memory<'a, F>, + ) -> Self { + // sanity check + assert_eq!(dim.num_vars(), read_ts_poly.num_vars()); + + Self { + chunk_index, + dim, + read_ts_poly, + final_cts_poly, + memories: vec![memory], + } + } + + pub fn chunk_index(&self) -> usize { + self.chunk_index + } + + pub fn chunk_bits(&self) -> usize { + self.final_cts_poly.num_vars() + } + + pub fn num_reads(&self) -> usize { + 1 << 
self.dim.num_vars() + } + + pub fn chunk_polys(&self) -> impl Iterator> { + chain!([ + &self.dim.poly, + &self.read_ts_poly.poly, + &self.final_cts_poly.poly + ]) + } + + pub fn chunk_poly_evals(&self, x: &[F], y: &[F]) -> Vec { + vec![ + self.dim.evaluate(x), + self.read_ts_poly.evaluate(x), + self.final_cts_poly.evaluate(y), + ] + } + + pub fn e_poly_evals(&self, x: &[F]) -> Vec { + self.memories + .iter() + .map(|memory| memory.e_poly.evaluate(x)) + .collect_vec() + } + + pub(super) fn memories(&self) -> impl Iterator> { + self.memories.iter() + } + + pub(super) fn add_memory(&mut self, memory: Memory<'a, F>) { + // sanity check + let chunk_bits = self.chunk_bits(); + let num_reads = self.num_reads(); + assert_eq!(chunk_bits, memory.subtable_poly.num_vars()); + assert_eq!(num_reads, memory.e_poly.num_vars()); + + self.memories.push(memory); + } +} + +#[derive(Clone, Debug)] +pub(super) struct Memory<'a, F: PrimeField> { + subtable_poly: &'a MultilinearPolynomial, + pub(crate) e_poly: &'a Poly, +} + +impl<'a, F: PrimeField> Memory<'a, F> { + fn new(subtable_poly: &'a MultilinearPolynomial, e_poly: &'a Poly) -> Self { + Self { + subtable_poly, + e_poly, + } + } + + pub fn polys(&'a self) -> impl Iterator> { + chain!([&self.subtable_poly, &self.e_poly.poly]) + } +} + +pub struct LassoProver< + F: Field + PrimeField, + Pcs: PolynomialCommitmentScheme>, +>(PhantomData, PhantomData); + +impl< + F: Field + PrimeField, + Pcs: PolynomialCommitmentScheme>, + > LassoProver +{ + pub fn lookup_poly( + lookup: &(&Expression, &Expression), + polys: &[&MultilinearPolynomial], + ) -> (MultilinearPolynomial, MultilinearPolynomial) { + let num_vars = polys[0].num_vars(); + let expression = lookup.0 + lookup.1; + let lagranges = { + let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); + expression + .used_langrange() + .into_iter() + .map(|i| (i, bh[i.rem_euclid(1 << num_vars) as usize])) + .collect::>() + }; + let bh = BooleanHypercube::new(num_vars); + + let evaluate = |expression: &Expression| { + let mut evals = vec![F::ZERO; 1 << num_vars]; + parallelize(&mut evals, |(evals, start)| { + for (b, eval) in (start..).zip(evals) { + *eval = expression.evaluate( + &|constant| constant, + &|common_poly| match common_poly { + CommonPolynomial::Identity => F::from(b as u64), + CommonPolynomial::Lagrange(i) => { + if lagranges.contains(&(i, b)) { + F::ONE + } else { + F::ZERO + } + } + CommonPolynomial::EqXY(_) => unreachable!(), + }, + &|query| polys[query.poly()][bh.rotate(b, query.rotation())], + &|_| unreachable!(), + &|value| -value, + &|lhs, rhs| lhs + &rhs, + &|lhs, rhs| lhs * &rhs, + &|value, scalar| value * &scalar, + ); + } + }); + MultilinearPolynomial::new(evals) + }; + + let (input, index) = lookup; + (evaluate(input), evaluate(index)) + } +} + +impl< + F: Field + PrimeField, + Pcs: PolynomialCommitmentScheme>, + > LassoProver +{ + fn e_polys( + table: &Box>, + subtable_polys: &[&MultilinearPolynomial], + indices: &Vec<&[usize]>, + ) -> Vec> { + let num_chunks = table.chunk_bits().len(); + let num_memories = table.num_memories(); + assert_eq!(indices.len(), num_chunks); + let num_reads = indices[0].len(); + (0..num_memories) + .map(|i| { + let mut e_poly = Vec::with_capacity(num_reads); + let subtable_poly = subtable_polys[table.memory_to_subtable_index(i)]; + let index = indices[table.memory_to_chunk_index(i)]; + (0..num_reads).for_each(|j| { + e_poly.push(subtable_poly[index[j]]); + }); + MultilinearPolynomial::new(e_poly) + }) + .collect_vec() + } + + fn chunks<'a>( + table: &Box>, + 
subtable_polys: &'a [&MultilinearPolynomial], + dims: &'a [Poly], + read_ts_polys: &'a [Poly], + final_cts_polys: &'a [Poly], + e_polys: &'a [Poly], + ) -> Vec> { + // key: chunk index, value: chunk + let mut chunk_map: HashMap> = HashMap::new(); + + let num_memories = table.num_memories(); + let memories = (0..num_memories).map(|memory_index| { + let subtable_poly = subtable_polys[table.memory_to_subtable_index(memory_index)]; + Memory::new(subtable_poly, &e_polys[memory_index]) + }); + memories.enumerate().for_each(|(memory_index, memory)| { + let chunk_index = table.memory_to_chunk_index(memory_index); + if chunk_map.get(&chunk_index).is_some() { + chunk_map.entry(chunk_index).and_modify(|chunk| { + chunk.add_memory(memory); + }); + } else { + let dim = &dims[chunk_index]; + let read_ts_poly = &read_ts_polys[chunk_index]; + let final_cts_poly = &final_cts_polys[chunk_index]; + chunk_map.insert( + chunk_index, + Chunk::new(chunk_index, dim, read_ts_poly, final_cts_poly, memory), + ); + } + }); + + // sanity check + { + let num_chunks = table.chunk_bits().len(); + assert_eq!(chunk_map.len(), num_chunks); + } + + let mut chunks = chunk_map.into_iter().collect_vec(); + chunks.sort_by_key(|(chunk_index, _)| *chunk_index); + chunks.into_iter().map(|(_, chunk)| chunk).collect_vec() + } + + pub fn prove_sum_check( + points_offset: usize, + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + table: &Box>, + lookup_output_poly: &Poly, + e_polys: &[&Poly], + r: &[F], + num_vars: usize, + transcript: &mut impl TranscriptWrite, F>, + ) -> Result<(), Error> { + Surge::::prove_sum_check( + table, + lookup_output_poly, + &e_polys, + r, + num_vars, + points_offset, + lookup_opening_points, + lookup_opening_evals, + transcript, + ) + } + + fn prepare_memory_checking<'a>( + points_offset: usize, + table: &Box>, + subtable_polys: &'a [&MultilinearPolynomial], + dims: &'a [Poly], + read_ts_polys: &'a [Poly], + final_cts_polys: &'a [Poly], + e_polys: &'a [Poly], + gamma: &F, + tau: &F, + ) -> Vec> { + let chunks = Self::chunks( + table, + subtable_polys, + dims, + read_ts_polys, + final_cts_polys, + e_polys, + ); + let chunk_bits = table.chunk_bits(); + // key: chunk bits, value: chunks + let mut chunk_map: HashMap>> = HashMap::new(); + + chunks.iter().enumerate().for_each(|(chunk_index, chunk)| { + let chunk_bits = chunk_bits[chunk_index]; + if let Some(_) = chunk_map.get(&chunk_bits) { + chunk_map.entry(chunk_bits).and_modify(|chunks| { + chunks.push(chunk.clone()); + }); + } else { + chunk_map.insert(chunk_bits, vec![chunk.clone()]); + } + }); + + chunk_map + .into_iter() + .enumerate() + .map(|(index, (_, chunks))| { + let points_offset = points_offset + 2 + 2 * index; + MemoryCheckingProver::new(points_offset, chunks, tau, gamma) + }) + .collect_vec() + } + + pub fn memory_checking<'a>( + points_offset: usize, + lookup_opening_points: &mut Vec>, + lookup_opening_evals: &mut Vec>, + table: &Box>, + subtable_polys: &'a [&MultilinearPolynomial], + dims: &'a [Poly], + read_ts_polys: &'a [Poly], + final_cts_polys: &'a [Poly], + e_polys: &'a [Poly], + gamma: &F, + tau: &F, + transcript: &mut impl FieldTranscriptWrite, + ) -> Result<(), Error> { + let mut memory_checking = LassoProver::::prepare_memory_checking( + points_offset, + &table, + &subtable_polys, + &dims, + &read_ts_polys, + &final_cts_polys, + &e_polys, + &gamma, + &tau, + ); + + memory_checking + .iter_mut() + .map(|memory_checking| { + memory_checking.prove( + points_offset, + lookup_opening_points, + lookup_opening_evals, + 
transcript, + ) + }) + .collect::, Error>>()?; + Ok(()) + } + + pub fn commit( + pp: &Pcs::ProverParam, + lookup_polys_offset: usize, + table: &Box>, + subtable_polys: &[&MultilinearPolynomial], + lookup_output_poly: MultilinearPolynomial, + lookup_index_poly: &MultilinearPolynomial, + transcript: &mut impl TranscriptWrite, + ) -> Result<(Vec>>, Vec>), Error> { + let num_chunks = table.chunk_bits().len(); + + // commit to lookup_output_poly + let lookup_output_comm = Pcs::commit_and_write(&pp, &lookup_output_poly, transcript)?; + + // get surge and dims + let mut surge = Surge::::new(); + + // commit to dims + let dims = surge.commit(&table, lookup_index_poly); + let dim_comms = Pcs::batch_commit_and_write(pp, &dims, transcript)?; + + // get e_polys & read_ts_polys & final_cts_polys + let e_polys = { + let indices = surge.indices(); + LassoProver::::e_polys(&table, subtable_polys, &indices) + }; + let (read_ts_polys, final_cts_polys) = surge.counter_polys(&table); + + // commit to read_ts_polys & final_cts_polys & e_polys + let read_ts_comms = Pcs::batch_commit_and_write(&pp, &read_ts_polys, transcript)?; + let final_cts_comms = Pcs::batch_commit_and_write(&pp, &final_cts_polys, transcript)?; + let e_comms = Pcs::batch_commit_and_write(&pp, e_polys.as_slice(), transcript)?; + + let lookup_output_poly = Poly { + offset: lookup_polys_offset, + poly: lookup_output_poly, + }; + + let dims = dims + .into_iter() + .enumerate() + .map(|(chunk_index, dim)| Poly { + offset: lookup_polys_offset + 1 + chunk_index, + poly: dim, + }) + .collect_vec(); + + let read_ts_polys = read_ts_polys + .into_iter() + .enumerate() + .map(|(chunk_index, read_ts_poly)| Poly { + offset: lookup_polys_offset + 1 + num_chunks + chunk_index, + poly: read_ts_poly, + }) + .collect_vec(); + + let final_cts_polys = final_cts_polys + .into_iter() + .enumerate() + .map(|(chunk_index, final_cts_poly)| Poly { + offset: lookup_polys_offset + 1 + 2 * num_chunks + chunk_index, + poly: final_cts_poly, + }) + .collect_vec(); + + let e_polys = e_polys + .into_iter() + .enumerate() + .map(|(memory_index, e_poly)| Poly { + offset: lookup_polys_offset + 1 + 3 * num_chunks + memory_index, + poly: e_poly, + }) + .collect_vec(); + + Ok(( + vec![ + vec![lookup_output_poly], + dims, + read_ts_polys, + final_cts_polys, + e_polys, + ], + vec![ + vec![lookup_output_comm], + dim_comms, + read_ts_comms, + final_cts_comms, + e_comms, + ], + )) + } +} diff --git a/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs b/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs new file mode 100644 index 0000000..d23b1b5 --- /dev/null +++ b/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs @@ -0,0 +1,225 @@ +use std::{collections::BTreeSet, iter::repeat, marker::PhantomData}; + +use halo2_curves::ff::{Field, PrimeField}; +use itertools::Itertools; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + +use crate::{ + backend::lookup::lasso::DecomposableTable, + pcs::{CommitmentChunk, Evaluation, PolynomialCommitmentScheme}, + piop::sum_check::{ + classic::{ClassicSumCheck, EvaluationsProver}, + SumCheck as _, VirtualPolynomial, + }, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::{fe_to_bits_le, usize_from_bits_le}, + expression::{Expression, Query, Rotation}, + transcript::TranscriptWrite, + }, + Error, +}; + +use super::Poly; + +type SumCheck = ClassicSumCheck>; + +pub struct Surge< + F: Field + PrimeField, + Pcs: PolynomialCommitmentScheme>, +> { + lookup_indices: Vec>, + _marker: PhantomData, + _marker2: 
diff --git a/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs b/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs
new file mode 100644
index 0000000..d23b1b5
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/prover/surge.rs
@@ -0,0 +1,225 @@
+use std::{iter::repeat, marker::PhantomData};
+
+use halo2_curves::ff::{Field, PrimeField};
+use itertools::Itertools;
+use rayon::prelude::{IntoParallelIterator, ParallelIterator};
+
+use crate::{
+    backend::lookup::lasso::DecomposableTable,
+    pcs::{CommitmentChunk, Evaluation, PolynomialCommitmentScheme},
+    piop::sum_check::{
+        classic::{ClassicSumCheck, EvaluationsProver},
+        SumCheck as _, VirtualPolynomial,
+    },
+    poly::multilinear::MultilinearPolynomial,
+    util::{
+        arithmetic::{fe_to_bits_le, usize_from_bits_le},
+        expression::{Expression, Query, Rotation},
+        transcript::TranscriptWrite,
+    },
+    Error,
+};
+
+use super::Poly;
+
+type SumCheck<F> = ClassicSumCheck<EvaluationsProver<F>>;
+
+pub struct Surge<
+    F: Field + PrimeField,
+    Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
+> {
+    lookup_indices: Vec<Vec<usize>>,
+    _marker: PhantomData<F>,
+    _marker2: PhantomData<Pcs>,
+}
+
+impl<
+        F: Field + PrimeField,
+        Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
+    > Surge<F, Pcs>
+{
+    pub fn new() -> Self {
+        Self {
+            lookup_indices: vec![vec![]],
+            _marker: PhantomData,
+            _marker2: PhantomData,
+        }
+    }
+
+    pub fn indices(&self) -> Vec<&[usize]> {
+        self.lookup_indices
+            .iter()
+            .map(|lookup_indices| lookup_indices.as_slice())
+            .collect_vec()
+    }
+
+    /// computes dim_1, ..., dim_c where c == table.chunk_bits().len()
+    pub fn commit(
+        &mut self,
+        table: &Box<dyn DecomposableTable<F>>,
+        index_poly: &MultilinearPolynomial<F>,
+    ) -> Vec<MultilinearPolynomial<F>> {
+        let num_rows: usize = 1 << index_poly.num_vars();
+        let num_chunks = table.chunk_bits().len();
+        // get indices of non-zero columns of all rows where each index is chunked
+        let indices = (0..num_rows)
+            .map(|i| {
+                let mut index_bits = fe_to_bits_le(index_poly[i]);
+                index_bits.truncate(table.chunk_bits().iter().sum());
+                assert_eq!(
+                    usize_from_bits_le(&fe_to_bits_le(index_poly[i])),
+                    usize_from_bits_le(&index_bits)
+                );
+
+                let mut chunked_index = repeat(0).take(num_chunks).collect_vec();
+                let chunked_index_bits = table.subtable_indices(index_bits);
+                chunked_index
+                    .iter_mut()
+                    .zip(chunked_index_bits)
+                    .for_each(|(chunked_index, index_bits)| {
+                        *chunked_index = usize_from_bits_le(&index_bits);
+                    });
+                chunked_index
+            })
+            .collect_vec();
+        let mut dims = Vec::with_capacity(num_chunks);
+        self.lookup_indices.resize(num_chunks, vec![]);
+        self.lookup_indices
+            .iter_mut()
+            .enumerate()
+            .for_each(|(i, lookup_indices)| {
+                let indices = indices
+                    .iter()
+                    .map(|indices| {
+                        lookup_indices.push(indices[i]);
+                        indices[i]
+                    })
+                    .collect_vec();
+                dims.push(MultilinearPolynomial::from_usize(indices));
+            });
+
+        dims
+    }
+
+    pub fn counter_polys(
+        &self,
+        table: &Box<dyn DecomposableTable<F>>,
+    ) -> (Vec<MultilinearPolynomial<F>>, Vec<MultilinearPolynomial<F>>) {
+        let num_chunks = table.chunk_bits().len();
+        let mut read_ts_polys = Vec::with_capacity(num_chunks);
+        let mut final_cts_polys = Vec::with_capacity(num_chunks);
+        let chunk_bits = table.chunk_bits();
+        self.lookup_indices
+            .iter()
+            .enumerate()
+            .for_each(|(i, lookup_indices)| {
+                let num_reads = lookup_indices.len();
+                let memory_size = 1 << chunk_bits[i];
+                let mut final_timestamps = vec![0usize; memory_size];
+                let mut read_timestamps = vec![0usize; num_reads];
+                (0..num_reads).for_each(|i| {
+                    let memory_address = lookup_indices[i];
+                    let ts = final_timestamps[memory_address];
+                    read_timestamps[i] = ts;
+                    let write_timestamp = ts + 1;
+                    final_timestamps[memory_address] = write_timestamp;
+                });
+                read_ts_polys.push(MultilinearPolynomial::from_usize(read_timestamps));
+                final_cts_polys.push(MultilinearPolynomial::from_usize(final_timestamps));
+            });
+
+        (read_ts_polys, final_cts_polys)
+    }
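A worked example of the counter construction in counter_polys (plain integers, assumed helper name): reading addresses [0, 1, 0, 0] of a 4-entry memory gives read_ts = [0, 0, 1, 2] and final_cts = [3, 1, 0, 0], since each read returns the current counter and then increments it.

fn counters(addresses: &[usize], memory_size: usize) -> (Vec<usize>, Vec<usize>) {
    let mut final_cts = vec![0usize; memory_size];
    let read_ts = addresses
        .iter()
        .map(|&addr| {
            let ts = final_cts[addr];
            final_cts[addr] = ts + 1;
            ts
        })
        .collect();
    (read_ts, final_cts)
}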
+
+    pub fn prove_sum_check(
+        table: &Box<dyn DecomposableTable<F>>,
+        lookup_output_poly: &Poly<F>,
+        e_polys: &[&Poly<F>],
+        r: &[F],
+        num_vars: usize,
+        points_offset: usize,
+        lookup_opening_points: &mut Vec<Vec<F>>,
+        lookup_opening_evals: &mut Vec<Evaluation<F>>,
+        transcript: &mut impl TranscriptWrite<CommitmentChunk<F, Pcs>, F>,
+    ) -> Result<(), Error> {
+        let claimed_sum = Self::sum_check_claim(r, table, e_polys);
+        assert_eq!(claimed_sum, lookup_output_poly.evaluate(r));
+
+        transcript.write_field_element(&claimed_sum)?;
+
+        let expression = Self::sum_check_expression(table);
+
+        // run the sum-check protocol
+        let (x, evals) = SumCheck::prove(
+            &(),
+            num_vars,
+            VirtualPolynomial::new(
+                &expression,
+                e_polys.iter().map(|e_poly| &e_poly.poly),
+                &[],
+                &[r.to_vec()],
+            ),
+            claimed_sum,
+            transcript,
+        )?;
+
+        lookup_opening_points.extend_from_slice(&[r.to_vec(), x]);
+        let evals = expression
+            .used_query()
+            .into_iter()
+            .map(|query| {
+                transcript
+                    .write_field_element(&evals[query.poly()])
+                    .unwrap();
+                Evaluation::new(
+                    e_polys[query.poly()].offset,
+                    points_offset + 1,
+                    evals[query.poly()],
+                )
+            })
+            .chain([Evaluation::new(
+                lookup_output_poly.offset,
+                points_offset,
+                claimed_sum,
+            )])
+            .collect_vec();
+        lookup_opening_evals.extend_from_slice(&evals);
+
+        Ok(())
+    }
+
+    pub fn sum_check_claim(
+        r: &[F],
+        table: &Box<dyn DecomposableTable<F>>,
+        e_polys: &[&Poly<F>],
+    ) -> F {
+        let num_memories = table.num_memories();
+        assert_eq!(e_polys.len(), num_memories);
+        let num_vars = e_polys[0].num_vars();
+        let bh_size = 1 << num_vars;
+        let eq = MultilinearPolynomial::eq_xy(r);
+        // \sum_{k \in \{0, 1\}^{\log m}} (\tilde{eq}(r, k) * g(E_1(k), ..., E_{\alpha}(k)))
+        (0..bh_size)
+            .into_par_iter()
+            .map(|k| {
+                let operands = e_polys.iter().map(|e_poly| e_poly[k]).collect_vec();
+                eq[k] * table.combine_lookups(&operands)
+            })
+            .sum()
+    }
+
+    // (\tilde{eq}(r, k) * g(E_1(k), ..., E_{\alpha}(k)))
+    pub fn sum_check_expression(table: &Box<dyn DecomposableTable<F>>) -> Expression<F> {
+        let num_memories = table.num_memories();
+        let exprs = table.combine_lookup_expressions(
+            (0..num_memories)
+                .map(|idx| Expression::Polynomial(Query::new(idx, Rotation::cur())))
+                .collect_vec(),
+        );
+        let eq_xy = Expression::<F>::eq_xy(0);
+        eq_xy * exprs
+    }
+}
diff --git a/plonkish_backend/src/backend/lookup/lasso/test/and.rs b/plonkish_backend/src/backend/lookup/lasso/test/and.rs
new file mode 100644
index 0000000..325d9b8
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/test/and.rs
@@ -0,0 +1,266 @@
+use std::{iter, marker::PhantomData};
+
+use halo2_curves::ff::PrimeField;
+use itertools::{izip, Itertools};
+
+use crate::{
+    backend::lookup::lasso::DecomposableTable,
+    poly::multilinear::{MultilinearPolynomial, MultilinearPolynomialTerms, PolyExpr::*},
+    util::{
+        arithmetic::{inner_product, split_bits, split_by_chunk_bits},
+        expression::Expression,
+    },
+};
+
+#[derive(Clone, Debug)]
+pub struct AndTable<F>(PhantomData<F>);
+
+impl<F> AndTable<F> {
+    pub fn new() -> Self {
+        Self(PhantomData)
+    }
+}
+
+/// T[X || Y] = T_1[X_1 || Y_1] + T_2[X_2 || Y_2] * 2^8 + ... + T_8[X_8 || Y_8] * 2^56
+impl<F: PrimeField> DecomposableTable<F> for AndTable<F> {
+    fn num_memories(&self) -> usize {
+        8
+    }
+
+    fn subtable_polys(&self) -> Vec<MultilinearPolynomial<F>> {
+        let memory_size = 1 << 16;
+        let mut evals = vec![];
+        (0..memory_size).for_each(|i| {
+            let (lhs, rhs) = split_bits(i, 8);
+            let result = F::from((lhs & rhs) as u64);
+            evals.push(result)
+        });
+        vec![MultilinearPolynomial::new(evals)]
+    }
+
+    fn subtable_polys_terms(&self) -> Vec<MultilinearPolynomialTerms<F>> {
+        // bit i of x AND y contributes 2^i * x_i * y_i
+        let init = Prod(vec![Var(0), Var(8)]);
+        let mut terms = vec![init];
+        (1..8).for_each(|i| {
+            let coeff = Const(F::from(1 << i));
+            let x = Var(i);
+            let y = Var(i + 8);
+            let term = Prod(vec![coeff, x, y]);
+            terms.push(term);
+        });
+        vec![MultilinearPolynomialTerms::new(16, Sum(terms))]
+    }
+
+    fn chunk_bits(&self) -> Vec<usize> {
+        vec![16; 8]
+    }
+
+    fn subtable_indices(&self, index_bits: Vec<bool>) -> Vec<Vec<bool>> {
+        assert!(index_bits.len() % 2 == 0);
+        let chunk_bits = self
+            .chunk_bits()
+            .iter()
+            .map(|chunk_bits| chunk_bits / 2)
+            .collect_vec();
+        let (lhs, rhs) = index_bits.split_at(index_bits.len() / 2);
+        izip!(
+            split_by_chunk_bits(lhs, &chunk_bits),
+            split_by_chunk_bits(rhs, &chunk_bits)
+        )
+        .map(|(chunked_lhs_bits, chunked_rhs_bits)| {
+            iter::empty()
+                .chain(chunked_lhs_bits)
+                .chain(chunked_rhs_bits)
+                .collect_vec()
+        })
+        .collect_vec()
+    }
+
+    fn combine_lookup_expressions(&self, expressions: Vec<Expression<F>>) -> Expression<F> {
+        Expression::DistributePowers(expressions, Box::new(Expression::Constant(F::from(1 << 8))))
+    }
+
+    fn combine_lookups(&self, operands: &[F]) -> F {
+        let weight = F::from(1 << 8);
+        inner_product(
+            operands,
+            iter::successors(Some(F::ONE), |power_of_weight| {
+                Some(*power_of_weight * weight)
+            })
+            .take(operands.len())
+            .collect_vec()
+            .iter(),
+        )
+    }
+
+    fn memory_to_chunk_index(&self, memory_index: usize) -> usize {
+        memory_index
+    }
+
+    fn memory_to_subtable_index(&self, _memory_index: usize) -> usize {
+        0
+    }
+}
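Recombining the per-chunk AND results with radix 2^8 reproduces the full 64-bit AND, which is what combine_lookups computes over field elements. A plain-integer sketch of the same identity (not part of the diff):

fn combine_and_chunks(chunk_ands: [u64; 8]) -> u64 {
    // chunk i holds the AND of the i-th 8-bit limbs, weighted by 2^(8 * i)
    chunk_ands
        .iter()
        .enumerate()
        .fold(0, |acc, (i, &chunk)| acc | (chunk << (8 * i)))
}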
+
+#[cfg(test)]
+mod test {
+    use std::array;
+
+    use super::AndTable;
+    use crate::{
+        backend::{
+            hyperplonk::{prover::instance_polys, util::Permutation, HyperPlonk},
+            lookup::lasso::DecomposableTable,
+            mock::MockCircuit,
+            test::run_plonkish_backend,
+            PlonkishCircuit, PlonkishCircuitInfo,
+        },
+        pcs::{
+            multilinear::{
+                Gemini, MultilinearBrakedown, MultilinearHyrax, MultilinearIpa, MultilinearKzg,
+                Zeromorph,
+            },
+            univariate::UnivariateKzg,
+        },
+        poly::Polynomial,
+        util::{
+            arithmetic::{fe_to_bits_le, usize_from_bits_le},
+            code::BrakedownSpec6,
+            expression::{Expression, Query, Rotation},
+            hash::Keccak256,
+            test::{rand_idx, rand_vec, seeded_std_rng},
+            transcript::Keccak256Transcript,
+        },
+    };
+    use halo2_curves::{
+        bn256::{self, Bn256},
+        ff::PrimeField,
+        grumpkin,
+    };
+    use num_integer::Integer;
+    use rand::RngCore;
+
+    fn rand_lasso_lookup_circuit<F: PrimeField>(
+        num_vars: usize,
+        table: Box<dyn DecomposableTable<F>>,
+        mut preprocess_rng: impl RngCore,
+        mut witness_rng: impl RngCore,
+    ) -> (PlonkishCircuitInfo<F>, impl PlonkishCircuit<F>) {
+        let size = 1 << num_vars;
+        let mut polys = [(); 5].map(|_| vec![F::ZERO; size]);
+
+        let instances = rand_vec(num_vars, &mut witness_rng);
+        polys[0] = instance_polys(num_vars, [&instances])[0].evals().to_vec();
+
+        let mut permutation = Permutation::default();
+        for poly in [2, 3, 4] {
+            permutation.copy((poly, 1), (poly, 1));
+        }
+        for idx in 0..size - 1 {
+            let use_copy = preprocess_rng.next_u32().is_even() && idx > 1;
+            let [w_l, w_r, w_o] = if use_copy {
+                let [l_copy_idx, r_copy_idx] = [(); 2].map(|_| {
+                    (
+                        rand_idx(2..5, &mut preprocess_rng),
+                        rand_idx(1..idx, &mut preprocess_rng),
+                    )
+                });
+                permutation.copy(l_copy_idx, (2, idx));
+                permutation.copy(r_copy_idx, (3, idx));
+                let w_l = polys[l_copy_idx.0][l_copy_idx.1];
+                let w_r = polys[r_copy_idx.0][r_copy_idx.1];
+                let w_o = usize_from_bits_le(&fe_to_bits_le(w_l))
+                    & usize_from_bits_le(&fe_to_bits_le(w_r));
+                let w_o = F::from(w_o as u64);
+                [w_l, w_r, w_o]
+            } else {
+                let [w_l, w_r] = [(); 2].map(|_| witness_rng.next_u64());
+                let w_o = w_l & w_r;
+                [F::from(w_l), F::from(w_r), F::from(w_o)]
+            };
+
+            let q_and = F::ONE;
+            let values = vec![(1, q_and), (2, w_l), (3, w_r), (4, w_o)];
+            for (poly, value) in values {
+                polys[poly][idx] = value;
+            }
+        }
+        let [_, q_and, w_l, w_r, w_o] = polys;
+        let circuit_info = lasso_lookup_circuit_info(
+            num_vars,
+            instances.len(),
+            [q_and],
+            table,
+            permutation.into_cycles(),
+        );
+        (
+            circuit_info,
+            MockCircuit::new(vec![instances], vec![w_l, w_r, w_o]),
+        )
+    }
+
+    fn lasso_lookup_circuit_info<F: PrimeField>(
+        num_vars: usize,
+        num_instances: usize,
+        preprocess_polys: [Vec<F>; 1],
+        table: Box<dyn DecomposableTable<F>>,
+        permutations: Vec<Vec<(usize, usize)>>,
+    ) -> PlonkishCircuitInfo<F> {
+        let [_, q_and, w_l, w_r, w_o] = &array::from_fn(|poly| Query::new(poly, Rotation::cur()))
+            .map(Expression::<F>::Polynomial);
+        // pack the two 64-bit operands into a single lookup index
+        let lasso_lookup_indices = Expression::DistributePowers(
+            vec![w_l.clone(), w_r.clone()],
+            Box::new(Expression::Constant(F::from_u128(1 << 64))),
+        );
+        let lasso_lookup_output = w_o.clone();
+        let chunk_bits = table.chunk_bits();
+        let num_vars = chunk_bits.iter().chain([&num_vars]).max().unwrap();
+        PlonkishCircuitInfo {
+            k: *num_vars,
+            num_instances: vec![num_instances],
+            preprocess_polys: preprocess_polys.to_vec(),
+            num_witness_polys: vec![3],
+            num_challenges: vec![0],
+            constraints: vec![],
+            lookups: vec![vec![]],
+            lasso_lookup: Some((q_and * lasso_lookup_indices, q_and * lasso_lookup_output, table)),
+            permutations,
+            max_degree: Some(4),
+        }
+    }
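The index expression above packs both operands into one lookup index; assuming DistributePowers([a, b], c) evaluates to a + c * b in this codebase, the packed value is w_l + 2^64 * w_r, which subtable_indices later splits back into per-chunk (x_i, y_i) pairs. An integer sketch of the packing:

fn pack_index(w_l: u64, w_r: u64) -> u128 {
    (w_l as u128) | ((w_r as u128) << 64)
}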
+
+    macro_rules! test {
+        ($name:ident, $f:ty, $pcs:ty, $num_vars_range:expr) => {
+            paste::paste! {
+                #[test]
+                fn [<$name _hyperplonk_vanilla_plonk_with_lasso_lookup>]() {
+                    run_plonkish_backend::<_, HyperPlonk<$pcs>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| {
+                        let table = Box::new(AndTable::<$f>::new());
+                        rand_lasso_lookup_circuit(num_vars, table, seeded_std_rng(), seeded_std_rng())
+                    });
+                }
+            }
+        };
+        ($name:ident, $f:ty, $pcs:ty) => {
+            test!($name, $f, $pcs, 16..17);
+        };
+    }
+
+    test!(brakedown, bn256::Fr, MultilinearBrakedown<bn256::Fr, Keccak256, BrakedownSpec6>);
+    test!(
+        hyrax,
+        grumpkin::Fr,
+        MultilinearHyrax<grumpkin::G1Affine>,
+        5..16
+    );
+    test!(ipa, grumpkin::Fr, MultilinearIpa<grumpkin::G1Affine>);
+    test!(kzg, bn256::Fr, MultilinearKzg<Bn256>);
+    test!(gemini_kzg, bn256::Fr, Gemini<UnivariateKzg<Bn256>>);
+    test!(zeromorph_kzg, bn256::Fr, Zeromorph<UnivariateKzg<Bn256>>);
+}
diff --git a/plonkish_backend/src/backend/lookup/lasso/test/mod.rs b/plonkish_backend/src/backend/lookup/lasso/test/mod.rs
new file mode 100644
index 0000000..c5ab6df
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/test/mod.rs
@@ -0,0 +1,2 @@
+pub mod and;
+pub mod range;
diff --git a/plonkish_backend/src/backend/lookup/lasso/test/range.rs b/plonkish_backend/src/backend/lookup/lasso/test/range.rs
new file mode 100644
index 0000000..9cda40c
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/test/range.rs
@@ -0,0 +1,288 @@
+use std::{iter, marker::PhantomData};
+
+use halo2_curves::ff::PrimeField;
+use itertools::Itertools;
+
+use crate::{
+    backend::lookup::lasso::DecomposableTable,
+    poly::multilinear::{MultilinearPolynomial, MultilinearPolynomialTerms, PolyExpr::*},
+    util::{
+        arithmetic::{div_ceil, inner_product},
+        expression::Expression,
+    },
+};
+
+#[derive(Clone, Debug)]
+pub struct RangeTable<F, const NUM_BITS: usize, const LIMB_BITS: usize>(PhantomData<F>);
+
+impl<F, const NUM_BITS: usize, const LIMB_BITS: usize> RangeTable<F, NUM_BITS, LIMB_BITS> {
+    pub fn new() -> Self {
+        Self(PhantomData)
+    }
+}
+
+impl<F: PrimeField, const NUM_BITS: usize, const LIMB_BITS: usize> DecomposableTable<F>
+    for RangeTable<F, NUM_BITS, LIMB_BITS>
+{
+    fn chunk_bits(&self) -> Vec<usize> {
+        let remainder_bits = if NUM_BITS % LIMB_BITS != 0 {
+            vec![NUM_BITS % LIMB_BITS]
+        } else {
+            vec![]
+        };
+        iter::repeat(LIMB_BITS)
+            .take(NUM_BITS / LIMB_BITS)
+            .chain(remainder_bits)
+            .collect_vec()
+    }
+
+    fn combine_lookup_expressions(&self, expressions: Vec<Expression<F>>) -> Expression<F> {
+        Expression::DistributePowers(
+            expressions,
+            Box::new(Expression::Constant(F::from(1 << LIMB_BITS))),
+        )
+    }
+
+    fn combine_lookups(&self, operands: &[F]) -> F {
+        let weight = F::from(1 << LIMB_BITS);
+        inner_product(
+            operands,
+            iter::successors(Some(F::ONE), |power_of_weight| {
+                Some(*power_of_weight * weight)
+            })
+            .take(operands.len())
+            .collect_vec()
+            .iter(),
+        )
+    }
+
+    fn num_memories(&self) -> usize {
+        div_ceil(NUM_BITS, LIMB_BITS)
+    }
+
+    fn subtable_indices(&self, index_bits: Vec<bool>) -> Vec<Vec<bool>> {
+        index_bits.chunks(LIMB_BITS).map(Vec::from).collect_vec()
+    }
+
+    fn subtable_polys(&self) -> Vec<MultilinearPolynomial<F>> {
+        let mut evals = vec![];
+        (0..1 << LIMB_BITS).for_each(|i| evals.push(F::from(i)));
+        let limb_subtable_poly = MultilinearPolynomial::new(evals);
+        if NUM_BITS % LIMB_BITS != 0 {
+            let remainder = NUM_BITS % LIMB_BITS;
+            let mut evals = vec![];
+            (0..1 << remainder).for_each(|i| {
+                evals.push(F::from(i));
+            });
+            let rem_subtable_poly = MultilinearPolynomial::new(evals);
+            vec![limb_subtable_poly, rem_subtable_poly]
+        } else {
+            vec![limb_subtable_poly]
+        }
+    }
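An integer sketch of the chunk layout chunk_bits returns (assumed helper, mirrors the code above): NUM_BITS = 36 and LIMB_BITS = 16 give [16, 16, 4], i.e. two full limbs plus a remainder chunk backed by the second subtable.

fn range_chunk_bits(num_bits: usize, limb_bits: usize) -> Vec<usize> {
    let mut bits = vec![limb_bits; num_bits / limb_bits];
    if num_bits % limb_bits != 0 {
        bits.push(num_bits % limb_bits);
    }
    bits
}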
+
+    fn subtable_polys_terms(&self) -> Vec<MultilinearPolynomialTerms<F>> {
+        let limb_init = Var(0);
+        let mut limb_terms = vec![limb_init];
+        (1..LIMB_BITS).for_each(|i| {
+            let coeff = Pow(Box::new(Const(F::from(2))), i as u32);
+            let x = Var(i);
+            let term = Prod(vec![coeff, x]);
+            limb_terms.push(term);
+        });
+        let limb_subtable_poly = MultilinearPolynomialTerms::new(LIMB_BITS, Sum(limb_terms));
+        if NUM_BITS % LIMB_BITS == 0 {
+            vec![limb_subtable_poly]
+        } else {
+            let remainder = NUM_BITS % LIMB_BITS;
+            let rem_init = Var(0);
+            let mut rem_terms = vec![rem_init];
+            (1..remainder).for_each(|i| {
+                let coeff = Pow(Box::new(Const(F::from(2))), i as u32);
+                let x = Var(i);
+                let term = Prod(vec![coeff, x]);
+                rem_terms.push(term);
+            });
+            vec![
+                limb_subtable_poly,
+                MultilinearPolynomialTerms::new(remainder, Sum(rem_terms)),
+            ]
+        }
+    }
+
+    fn memory_to_chunk_index(&self, memory_index: usize) -> usize {
+        memory_index
+    }
+
+    fn memory_to_subtable_index(&self, memory_index: usize) -> usize {
+        // the last memory uses the remainder subtable when NUM_BITS % LIMB_BITS != 0
+        if NUM_BITS % LIMB_BITS != 0 && memory_index == NUM_BITS / LIMB_BITS {
+            1
+        } else {
+            0
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::array;
+
+    use super::RangeTable;
+    use crate::{
+        backend::{
+            hyperplonk::{prover::instance_polys, util::Permutation, HyperPlonk},
+            lookup::lasso::DecomposableTable,
+            mock::MockCircuit,
+            test::run_plonkish_backend,
+            PlonkishCircuit, PlonkishCircuitInfo,
+        },
+        pcs::{
+            multilinear::{
+                Gemini, MultilinearBrakedown, MultilinearHyrax, MultilinearIpa, MultilinearKzg,
+                Zeromorph,
+            },
+            univariate::UnivariateKzg,
+        },
+        poly::Polynomial,
+        util::{
+            code::BrakedownSpec6,
+            expression::{Expression, Query, Rotation},
+            hash::Keccak256,
+            test::{rand_idx, rand_vec, seeded_std_rng},
+            transcript::Keccak256Transcript,
+        },
+    };
+    use halo2_curves::{
+        bn256::{self, Bn256},
+        ff::PrimeField,
+        grumpkin,
+    };
+    use num_integer::Integer;
+    use rand::RngCore;
+
+    fn rand_vanilla_plonk_with_lasso_lookup_circuit<F: PrimeField>(
+        num_vars: usize,
+        table: Box<dyn DecomposableTable<F>>,
+        mut preprocess_rng: impl RngCore,
+        mut witness_rng: impl RngCore,
+    ) -> (PlonkishCircuitInfo<F>, impl PlonkishCircuit<F>) {
+        let size = 1 << num_vars;
+        let mut polys = [(); 10].map(|_| vec![F::ZERO; size]);
+
+        let instances = rand_vec(num_vars, &mut witness_rng);
+        polys[0] = instance_polys(num_vars, [&instances])[0].evals().to_vec();
+
+        let mut permutation = Permutation::default();
+        for poly in [7, 8, 9] {
+            permutation.copy((poly, 1), (poly, 1));
+        }
+        for idx in 0..size - 1 {
+            let w_l = if preprocess_rng.next_u32().is_even() && idx > 1 {
+                let l_copy_idx = (7, rand_idx(1..idx, &mut preprocess_rng));
+                permutation.copy(l_copy_idx, (7, idx));
+                polys[l_copy_idx.0][l_copy_idx.1]
+            } else {
+                // a random u64 squared stays within the 128-bit range table
+                F::from(witness_rng.next_u64()).square()
+            };
+            let w_r = F::from(witness_rng.next_u64());
+            let q_c = F::random(&mut preprocess_rng);
+            let q_range = F::ONE;
+            let values = if preprocess_rng.next_u32().is_even() {
+                vec![
+                    (1, F::ONE),
+                    (2, F::ONE),
+                    (4, -F::ONE),
+                    (5, q_c),
+                    (6, q_range),
+                    (7, w_l),
+                    (8, w_r),
+                    (9, w_l + w_r + q_c + polys[0][idx]),
+                ]
+            } else {
+                vec![
+                    (3, F::ONE),
+                    (4, -F::ONE),
+                    (5, q_c),
+                    (6, q_range),
+                    (7, w_l),
+                    (8, w_r),
+                    (9, w_l * w_r + q_c + polys[0][idx]),
+                ]
+            };
+            for (poly, value) in values {
+                polys[poly][idx] = value;
+            }
+        }
+        let [_, q_l, q_r, q_m, q_o, q_c, q_range, w_l, w_r, w_o] = polys;
+        let circuit_info = vanilla_plonk_with_lasso_lookup_circuit_info(
+            num_vars,
+            instances.len(),
+            [q_l, q_r, q_m, q_o, q_c, q_range],
+            table,
+            permutation.into_cycles(),
+        );
+        (
+            circuit_info,
+            MockCircuit::new(vec![instances], vec![w_l, w_r, w_o]),
+        )
+    }
+
+    fn vanilla_plonk_with_lasso_lookup_circuit_info<F: PrimeField>(
+        num_vars: usize,
+        num_instances: usize,
+        preprocess_polys: [Vec<F>; 6],
+        table: Box<dyn DecomposableTable<F>>,
+        permutations: Vec<Vec<(usize, usize)>>,
+    ) -> PlonkishCircuitInfo<F> {
+        let [pi, q_l, q_r, q_m, q_o, q_c, q_range, w_l, w_r, w_o] =
+            &array::from_fn(|poly| Query::new(poly, Rotation::cur()))
+                .map(Expression::<F>::Polynomial);
+        let lasso_lookup_indices = w_l.clone();
+        let lasso_lookup_output = w_l.clone();
+        let chunk_bits = table.chunk_bits();
+        let num_vars = chunk_bits.iter().chain([&num_vars]).max().unwrap();
+        PlonkishCircuitInfo {
+            k: *num_vars,
+            num_instances: vec![num_instances],
+            preprocess_polys: preprocess_polys.to_vec(),
+            num_witness_polys: vec![3],
+            num_challenges: vec![0],
+            constraints: vec![q_l * w_l + q_r * w_r + q_m * w_l * w_r + q_o * w_o + q_c + pi],
+            lookups: vec![vec![]],
+            // (index expression, output expression, table), matching the field's layout
+            lasso_lookup: Some((
+                q_range * lasso_lookup_indices,
+                q_range * lasso_lookup_output,
+                table,
+            )),
+            permutations,
+            max_degree: Some(4),
+        }
+    }
+
+    macro_rules! test {
+        ($name:ident, $f:ty, $pcs:ty, $num_vars_range:expr) => {
+            paste::paste! {
+                #[test]
+                fn [<$name _hyperplonk_vanilla_plonk_with_lasso_lookup>]() {
+                    run_plonkish_backend::<_, HyperPlonk<$pcs>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| {
+                        let table = Box::new(RangeTable::<$f, 128, 16>::new());
+                        rand_vanilla_plonk_with_lasso_lookup_circuit(num_vars, table, seeded_std_rng(), seeded_std_rng())
+                    });
+                }
+            }
+        };
+        ($name:ident, $f:ty, $pcs:ty) => {
+            test!($name, $f, $pcs, 16..17);
+        };
+    }
+
+    test!(brakedown, bn256::Fr, MultilinearBrakedown<bn256::Fr, Keccak256, BrakedownSpec6>);
+    test!(
+        hyrax,
+        grumpkin::Fr,
+        MultilinearHyrax<grumpkin::G1Affine>,
+        5..16
+    );
+    test!(ipa, grumpkin::Fr, MultilinearIpa<grumpkin::G1Affine>);
+    test!(kzg, bn256::Fr, MultilinearKzg<Bn256>);
+    test!(gemini_kzg, bn256::Fr, Gemini<UnivariateKzg<Bn256>>);
+    test!(zeromorph_kzg, bn256::Fr, Zeromorph<UnivariateKzg<Bn256>>);
+}
diff --git a/plonkish_backend/src/backend/lookup/lasso/verifier/mod.rs b/plonkish_backend/src/backend/lookup/lasso/verifier/mod.rs
new file mode 100644
index 0000000..c8b2824
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/lasso/verifier/mod.rs
@@ -0,0 +1,192 @@
+use std::{collections::HashMap, iter, marker::PhantomData};
+
+use halo2_curves::ff::{Field, PrimeField};
+use itertools::Itertools;
+
+use crate::{
+    pcs::{CommitmentChunk, Evaluation, PolynomialCommitmentScheme},
+    piop::sum_check::{
+        classic::{ClassicSumCheck, EvaluationsProver},
+        evaluate, SumCheck,
+    },
+    poly::multilinear::MultilinearPolynomial,
+    util::transcript::{FieldTranscriptRead, TranscriptRead},
+    Error,
+};
+
+use super::{
+    memory_checking::verifier::{Chunk, Memory, MemoryCheckingVerifier},
+    prover::Surge,
+    DecomposableTable,
+};
+
+pub struct LassoVerifier<
+    F: Field + PrimeField,
+    Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
+>(PhantomData<F>, PhantomData<Pcs>);
+
+impl<
+        F: Field + PrimeField,
+        Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
+    > LassoVerifier<F, Pcs>
+{
+    pub fn read_commitments(
+        vp: &Pcs::VerifierParam,
+        table: &Box<dyn DecomposableTable<F>>,
+        transcript: &mut impl TranscriptRead<CommitmentChunk<F, Pcs>, F>,
+    ) -> Result<Vec<Pcs::Commitment>, Error> {
+        // read output_comm, dim_comms
+        let num_chunks = table.chunk_bits().len();
+        let num_memories = table.num_memories();
+        let output_comm = Pcs::read_commitment(vp, transcript)?;
+        let dim_comms = Pcs::read_commitments(vp, num_chunks, transcript)?;
+
+        // read read_ts_comms & final_cts_comms & e_comms
+        let read_ts_comms = Pcs::read_commitments(vp, num_chunks, transcript)?;
+        let final_cts_comms = Pcs::read_commitments(vp, num_chunks, transcript)?;
+        let e_comms = Pcs::read_commitments(vp, num_memories, transcript)?;
+        Ok(iter::empty()
+            .chain(vec![output_comm])
+            .chain(dim_comms)
+            .chain(read_ts_comms)
+            .chain(final_cts_comms)
+            .chain(e_comms)
+            .collect_vec())
+    }
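read_commitments mirrors LassoProver::commit, so the number of commitments it consumes is fixed by the table shape. A sketch of the count (helper name assumed):

fn num_lasso_comms(num_chunks: usize, num_memories: usize) -> usize {
    // 1 lookup output, plus dims/read_ts/final_cts per chunk, plus one E_i per memory
    1 + 3 * num_chunks + num_memories
}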
+
+    pub fn verify_sum_check(
+        table: &Box<dyn DecomposableTable<F>>,
+        num_vars: usize,
+        polys_offset: usize,
+        points_offset: usize,
+        lookup_opening_points: &mut Vec<Vec<F>>,
+        lookup_opening_evals: &mut Vec<Evaluation<F>>,
+        r: &[F],
+        transcript: &mut impl FieldTranscriptRead<F>,
+    ) -> Result<(), Error> {
+        let expression = Surge::<F, Pcs>::sum_check_expression(table);
+        let claim = transcript.read_field_element()?;
+        let (x_eval, x) = ClassicSumCheck::<EvaluationsProver<F>>::verify(
+            &(),
+            num_vars,
+            expression.degree(),
+            claim,
+            transcript,
+        )?;
+        lookup_opening_points.extend_from_slice(&[r.to_vec(), x.clone()]);
+
+        let pcs_query = expression.used_query();
+        let evals = pcs_query
+            .into_iter()
+            .map(|query| {
+                let value = transcript.read_field_element().unwrap();
+                (query, value)
+            })
+            .collect();
+        if evaluate(&expression, num_vars, &evals, &[], &[r], &x) != x_eval {
+            return Err(Error::InvalidSnark(
+                "Unmatched between Lasso sum_check output and query evaluation".to_string(),
+            ));
+        }
+        let e_polys_offset = polys_offset + 1 + table.chunk_bits().len() * 3;
+        let evals = evals
+            .into_iter()
+            .sorted_by(|a, b| Ord::cmp(&a.0, &b.0))
+            .map(|(query, value)| {
+                Evaluation::new(e_polys_offset + query.poly(), points_offset + 1, value)
+            })
+            .chain([Evaluation::new(polys_offset, points_offset, claim)])
+            .collect_vec();
+        lookup_opening_evals.extend_from_slice(&evals);
+        Ok(())
+    }
+
+    fn chunks(table: &Box<dyn DecomposableTable<F>>) -> Vec<Chunk<F>> {
+        let num_memories = table.num_memories();
+        let chunk_bits = table.chunk_bits();
+        let subtable_polys = table.subtable_polys_terms();
+        // key: chunk index, value: chunk
+        let mut chunk_map: HashMap<usize, Chunk<F>> = HashMap::new();
+        (0..num_memories).for_each(|memory_index| {
+            let chunk_index = table.memory_to_chunk_index(memory_index);
+            let chunk_bits = chunk_bits[chunk_index];
+            let subtable_poly = &subtable_polys[table.memory_to_subtable_index(memory_index)];
+            let memory = Memory::new(memory_index, subtable_poly.clone());
+            if let Some(chunk) = chunk_map.get_mut(&chunk_index) {
+                chunk.add_memory(memory);
+            } else {
+                chunk_map.insert(chunk_index, Chunk::new(chunk_index, chunk_bits, memory));
+            }
+        });
+
+        // sanity check
+        {
+            let num_chunks = table.chunk_bits().len();
+            assert_eq!(chunk_map.len(), num_chunks);
+        }
+
+        let mut chunks = chunk_map.into_iter().collect_vec();
+        chunks.sort_by_key(|(chunk_index, _)| *chunk_index);
+        chunks.into_iter().map(|(_, chunk)| chunk).collect_vec()
+    }
+
+    fn prepare_memory_checking(
+        table: &Box<dyn DecomposableTable<F>>,
+    ) -> Vec<MemoryCheckingVerifier<F>> {
+        let chunks = Self::chunks(table);
+        let chunk_bits = table.chunk_bits();
+        // key: chunk_bits, value: chunks
+        let mut chunk_map = HashMap::<usize, Vec<Chunk<F>>>::new();
+        chunks
+            .into_iter()
+            .enumerate()
+            .for_each(|(chunk_index, chunk)| {
+                let chunk_bits = chunk_bits[chunk_index];
+                chunk_map
+                    .entry(chunk_bits)
+                    .or_insert_with(Vec::new)
+                    .push(chunk);
+            });
+        chunk_map
+            .into_iter()
+            .map(|(_, chunks)| MemoryCheckingVerifier::new(chunks))
+            .collect_vec()
+    }
+
+    pub fn memory_checking(
+        num_reads: usize,
+        polys_offset: usize,
+        points_offset: usize,
+        lookup_opening_points: &mut Vec<Vec<F>>,
+        lookup_opening_evals: &mut Vec<Evaluation<F>>,
+        table: &Box<dyn DecomposableTable<F>>,
+        gamma: &F,
+        tau: &F,
+        transcript: &mut impl FieldTranscriptRead<F>,
+    ) -> Result<(), Error> {
+        let memory_checking = Self::prepare_memory_checking(table);
+        memory_checking
+            .iter()
+            .map(|memory_checking| {
+                memory_checking.verify(
+                    table.chunk_bits().len(),
+                    num_reads,
+                    polys_offset,
+                    points_offset,
+                    gamma,
+                    tau,
+                    lookup_opening_points,
+                    lookup_opening_evals,
+                    transcript,
+                )
+            })
+            .collect::<Result<Vec<()>, Error>>()?;
+        Ok(())
+    }
+}
diff --git a/plonkish_backend/src/backend/lookup/mod.rs b/plonkish_backend/src/backend/lookup/mod.rs
new file mode 100644
index 0000000..8c4791e
--- /dev/null
+++ b/plonkish_backend/src/backend/lookup/mod.rs
@@ -0,0 +1,49 @@
+use std::fmt::Debug;
+
+use halo2_curves::ff::Field;
+
+use crate::{
+    pcs::{CommitmentChunk, PolynomialCommitmentScheme},
+    poly::multilinear::MultilinearPolynomial,
+    util::{expression::Expression, transcript::TranscriptWrite},
+    Error,
+};
+
+pub mod lasso;
+
+pub struct MVLookupStrategyOutput<
+    F: Field,
+    Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
+> {
+    polys: Vec<Vec<MultilinearPolynomial<F>>>,
+    comms: Vec<Vec<Pcs::Commitment>>,
+}
+
+impl<F: Field, Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>>
+    MVLookupStrategyOutput<F, Pcs>
+{
+    pub fn polys(&self) -> Vec<MultilinearPolynomial<F>> {
+        self.polys.concat()
+    }
+
+    pub fn comms(&self) -> Vec<Pcs::Commitment> {
+        self.comms.concat()
+    }
+}
+
+pub trait MVLookupStrategy<F: Field>: Clone + Debug {
+    type Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>;
+
+    fn preprocess(
+        lookups: &[Vec<(Expression<F>, Expression<F>)>],
+        polys: &[&MultilinearPolynomial<F>],
+        challenges: &mut Vec<F>,
+    ) -> Result<Vec<[MultilinearPolynomial<F>; 2]>, Error>;
+
+    fn commit(
+        pp: &<Self::Pcs as PolynomialCommitmentScheme<F>>::ProverParam,
+        lookup_polys: &[[MultilinearPolynomial<F>; 2]],
+        challenges: &mut Vec<F>,
+        transcript: &mut impl TranscriptWrite<CommitmentChunk<F, Self::Pcs>, F>,
+    ) -> Result<MVLookupStrategyOutput<F, Self::Pcs>, Error>;
+}
diff --git a/plonkish_backend/src/frontend/halo2.rs b/plonkish_backend/src/frontend/halo2.rs
index 9bae2ad..ad3325a 100644
--- a/plonkish_backend/src/frontend/halo2.rs
+++ b/plonkish_backend/src/frontend/halo2.rs
@@ -21,6 +21,7 @@ use std::{
 
 #[cfg(any(test, feature = "benchmark"))]
 pub mod circuit;
+
 #[cfg(test)]
 mod test;
 
@@ -155,6 +156,7 @@
             num_challenges: num_by_phase(&cs.challenge_phase()),
             constraints,
             lookups,
+            lasso_lookup: None,
             permutations,
             max_degree: Some(cs.degree()),
         })
diff --git a/plonkish_backend/src/piop/gkr.rs b/plonkish_backend/src/piop/gkr.rs
index b26907c..1e7a661 100644
--- a/plonkish_backend/src/piop/gkr.rs
+++ b/plonkish_backend/src/piop/gkr.rs
@@ -1,3 +1,26 @@
 mod fractional_sum_check;
+mod grand_product;
+
+use std::collections::HashMap;
 
 pub use fractional_sum_check::{prove_fractional_sum_check, verify_fractional_sum_check};
+pub use grand_product::{prove_grand_product, verify_grand_product};
+use halo2_curves::ff::PrimeField;
+use itertools::izip;
+
+use crate::{
+    util::expression::{Query, Rotation},
+    Error,
+};
+
+fn eval_by_query<F: PrimeField>(evals: &[F]) -> HashMap<Query, F> {
+    izip!(
+        (0..).map(|idx| Query::new(idx, Rotation::cur())),
+        evals.iter().cloned()
+    )
+    .collect()
+}
+
+fn err_unmatched_sum_check_output() -> Error {
+    Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string())
+}
diff --git a/plonkish_backend/src/piop/gkr/fractional_sum_check.rs b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs
index 5e16213..1e87224 100644
--- a/plonkish_backend/src/piop/gkr/fractional_sum_check.rs
+++ b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs
@@ -4,9 +4,12 @@
 //! [PH23]: https://eprint.iacr.org/2023/1284.pdf
 
 use crate::{
-    piop::sum_check::{
-        classic::{ClassicSumCheck, EvaluationsProver},
-        evaluate, SumCheck as _, VirtualPolynomial,
+    piop::{
+        gkr::{err_unmatched_sum_check_output, eval_by_query},
+        sum_check::{
+            classic::{ClassicSumCheck, EvaluationsProver},
+            evaluate, SumCheck as _, VirtualPolynomial,
+        },
     },
     poly::{multilinear::MultilinearPolynomial, Polynomial},
     util::{
@@ -20,7 +23,7 @@ use crate::{
     },
     Error,
 };
-use std::{array, collections::HashMap, iter};
+use std::{array, iter};
 
 type SumCheck<F> = ClassicSumCheck<EvaluationsProver<F>>;
 
@@ -295,18 +298,6 @@ fn layer_down_claim<F: PrimeField>(evals: &[F], mu: F) -> (Vec<F>, Vec<F>) {
         .unzip()
 }
 
-fn eval_by_query<F: PrimeField>(evals: &[F]) -> HashMap<Query, F> {
-    izip!(
-        (0..).map(|idx| Query::new(idx, Rotation::cur())),
-        evals.iter().cloned()
-    )
-    .collect()
-}
-
-fn err_unmatched_sum_check_output() -> Error {
-    Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string())
-}
-
 #[cfg(test)]
 mod test {
     use crate::{
diff --git a/plonkish_backend/src/piop/gkr/grand_product.rs b/plonkish_backend/src/piop/gkr/grand_product.rs
new file mode 100644
index 0000000..04893c2
--- /dev/null
+++ b/plonkish_backend/src/piop/gkr/grand_product.rs
@@ -0,0 +1,312 @@
+use std::{array, iter};
+
+use halo2_curves::ff::PrimeField;
+use itertools::{chain, izip, Itertools};
+
+use crate::{
+    piop::{
+        gkr::{err_unmatched_sum_check_output, eval_by_query},
+        sum_check::{
+            classic::{ClassicSumCheck, EvaluationsProver},
+            evaluate, SumCheck as _, VirtualPolynomial,
+        },
+    },
+    poly::{multilinear::MultilinearPolynomial, Polynomial},
+    util::{
+        arithmetic::{div_ceil, inner_product, powers},
+        expression::{Expression, Query, Rotation},
+        parallel::{num_threads, parallelize_iter},
+        transcript::{FieldTranscriptRead, FieldTranscriptWrite},
+    },
+    Error,
+};
+
+type SumCheck<F> = ClassicSumCheck<EvaluationsProver<F>>;
+
+struct Layer<F> {
+    v_l: MultilinearPolynomial<F>,
+    v_r: MultilinearPolynomial<F>,
+}
+
+impl<F: PrimeField> From<[Vec<F>; 2]> for Layer<F> {
+    fn from(values: [Vec<F>; 2]) -> Self {
+        let [v_l, v_r] = values.map(MultilinearPolynomial::new);
+        Self { v_l, v_r }
+    }
+}
+
+impl<F: PrimeField> Layer<F> {
+    fn bottom(v: &&MultilinearPolynomial<F>) -> Self {
+        let mid = v.evals().len() >> 1;
+        [&v[..mid], &v[mid..]].map(ToOwned::to_owned).into()
+    }
+
+    fn num_vars(&self) -> usize {
+        self.v_l.num_vars()
+    }
+
+    fn polys(&self) -> [&MultilinearPolynomial<F>; 2] {
+        [&self.v_l, &self.v_r]
+    }
+
+    fn poly_chunks(&self, chunk_size: usize) -> impl Iterator<Item = (&[F], &[F])> {
+        let [v_l, v_r] = self.polys().map(|poly| poly.evals().chunks(chunk_size));
+        izip!(v_l, v_r)
+    }
+
+    fn up(&self) -> Self {
+        assert!(self.num_vars() != 0);
+
+        let len = 1 << self.num_vars();
+        let chunk_size = div_ceil(len, num_threads()).next_power_of_two();
+
+        let mut outputs: [_; 2] = array::from_fn(|_| vec![F::ZERO; len >> 1]);
+
+        // multiply the left and right halves pointwise, in parallel over chunks
+        parallelize_iter(
+            izip!(
+                outputs.iter_mut().flat_map(|v_up| v_up.chunks_mut(chunk_size)),
+                self.poly_chunks(chunk_size),
+            ),
+            |(v_up, (v_l, v_r))| {
+                izip!(v_up, v_l, v_r).for_each(|(v_up, v_l, v_r)| {
+                    *v_up = *v_l * *v_r;
+                })
+            },
+        );
+
+        outputs.into()
+    }
+}
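A plain-integer sketch of the product tree Layer builds (not part of the diff): each up() step multiplies the left and right halves pointwise, so log2(n) steps reduce n leaves to the grand product.

fn grand_product_layers(values: &[u64]) -> u64 {
    let mut layer = values.to_vec();
    while layer.len() > 1 {
        let mid = layer.len() / 2;
        // pair element i of the left half with element i of the right half
        layer = (0..mid).map(|i| layer[i] * layer[mid + i]).collect();
    }
    layer[0]
}

For [a, b, c, d] this yields [a * c, b * d] and then [a * b * c * d], matching Layer::bottom followed by repeated Layer::up.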
+
+pub fn prove_grand_product<'a, F: PrimeField>(
+    claimed_v_0s: impl IntoIterator<Item = Option<F>>,
+    vs: impl IntoIterator<Item = &'a MultilinearPolynomial<F>>,
+    transcript: &mut impl FieldTranscriptWrite<F>,
+) -> Result<(Vec<F>, Vec<F>), Error> {
+    let claimed_v_0s = claimed_v_0s.into_iter().collect_vec();
+    let vs = vs.into_iter().collect_vec();
+    let num_batching = claimed_v_0s.len();
+
+    assert!(num_batching != 0);
+    assert_eq!(num_batching, vs.len());
+    for poly in &vs {
+        assert_eq!(poly.num_vars(), vs[0].num_vars());
+    }
+
+    let bottom_layers = vs.iter().map(Layer::bottom).collect_vec();
+    let layers = iter::successors(bottom_layers.into(), |layers| {
+        (layers[0].num_vars() > 0).then(|| layers.iter().map(Layer::up).collect())
+    })
+    .collect_vec();
+
+    let claimed_v_0s = {
+        let v_0s = chain![layers.last().unwrap()]
+            .map(|layer| {
+                let [v_l, v_r] = layer.polys().map(|poly| poly[0]);
+                v_l * v_r
+            })
+            .collect_vec();
+
+        let mut hash_to_transcript = |claimed: Vec<_>, computed: Vec<_>| {
+            izip!(claimed, computed)
+                .map(|(claimed, computed)| match claimed {
+                    Some(claimed) => {
+                        if cfg!(feature = "sanity-check") {
+                            assert_eq!(claimed, computed)
+                        }
+                        transcript.common_field_element(&computed).map(|_| computed)
+                    }
+                    None => transcript.write_field_element(&computed).map(|_| computed),
+                })
+                .try_collect::<_, Vec<_>, _>()
+        };
+
+        hash_to_transcript(claimed_v_0s, v_0s)?
+    };
+
+    let expression = sum_check_expression(num_batching);
+
+    let (v_xs, x) =
+        layers
+            .iter()
+            .rev()
+            .fold(Ok((claimed_v_0s, Vec::new())), |result, layers| {
+                let (claimed_v_ys, y) = result?;
+
+                let num_vars = layers[0].num_vars();
+                let polys = layers.iter().flat_map(|layer| layer.polys());
+
+                let (mut x, evals) = if num_vars == 0 {
+                    (vec![], polys.map(|poly| poly[0]).collect_vec())
+                } else {
+                    let gamma = transcript.squeeze_challenge();
+
+                    let (x, evals) = {
+                        let claim = sum_check_claim(&claimed_v_ys, gamma);
+                        SumCheck::prove(
+                            &(),
+                            num_vars,
+                            VirtualPolynomial::new(&expression, polys, &[gamma], &[y]),
+                            claim,
+                            transcript,
+                        )?
+                    };
+
+                    (x, evals)
+                };
+
+                transcript.write_field_elements(&evals)?;
+
+                let mu = transcript.squeeze_challenge();
+
+                let v_xs = layer_down_claim(&evals, mu);
+                x.push(mu);
+
+                Ok((v_xs, x))
+            })?;
+
+    if cfg!(feature = "sanity-check") {
+        izip!(vs, &v_xs).for_each(|(poly, eval)| assert_eq!(poly.evaluate(&x), *eval));
+    }
+
+    Ok((v_xs, x))
+}
+
+pub fn verify_grand_product<F: PrimeField>(
+    num_vars: usize,
+    claimed_v_0s: impl IntoIterator<Item = Option<F>>,
+    transcript: &mut impl FieldTranscriptRead<F>,
+) -> Result<(Vec<F>, Vec<F>), Error> {
+    let claimed_v_0s = claimed_v_0s.into_iter().collect_vec();
+    let num_batching = claimed_v_0s.len();
+
+    assert!(num_batching != 0);
+    let claimed_v_0s = {
+        claimed_v_0s
+            .into_iter()
+            .map(|claimed| match claimed {
+                Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed),
+                None => transcript.read_field_element(),
+            })
+            .try_collect::<_, Vec<_>, _>()?
+    };
+
+    let expression = sum_check_expression(num_batching);
+
+    let (v_xs, x) = (0..num_vars).fold(Ok((claimed_v_0s, Vec::new())), |result, num_vars| {
+        let (claimed_v_ys, y) = result?;
+
+        let (mut x, evals) = if num_vars == 0 {
+            let evals = transcript.read_field_elements(2 * num_batching)?;
+
+            for (claimed_v, (&v_l, &v_r)) in izip!(claimed_v_ys, evals.iter().tuples()) {
+                if claimed_v != v_l * v_r {
+                    return Err(err_unmatched_sum_check_output());
+                }
+            }
+
+            (Vec::new(), evals)
+        } else {
+            let gamma = transcript.squeeze_challenge();
+
+            let (x_eval, x) = {
+                let claim = sum_check_claim(&claimed_v_ys, gamma);
+                SumCheck::verify(&(), num_vars, expression.degree(), claim, transcript)?
+            };
+
+            let evals = transcript.read_field_elements(2 * num_batching)?;
+
+            let eval_by_query = eval_by_query(&evals);
+            if x_eval != evaluate(&expression, num_vars, &eval_by_query, &[gamma], &[&y], &x) {
+                return Err(err_unmatched_sum_check_output());
+            }
+
+            (x, evals)
+        };
+
+        let mu = transcript.squeeze_challenge();
+
+        let v_xs = layer_down_claim(&evals, mu);
+        x.push(mu);
+
+        Ok((v_xs, x))
+    })?;
+
+    Ok((v_xs, x))
+}
+
+fn sum_check_expression<F: PrimeField>(num_batching: usize) -> Expression<F> {
+    let exprs = &(0..2 * num_batching)
+        .map(|idx| Expression::<F>::Polynomial(Query::new(idx, Rotation::cur())))
+        .tuples()
+        .map(|(ref v_l, ref v_r)| v_l * v_r)
+        .collect_vec();
+    let eq_xy = &Expression::eq_xy(0);
+    let gamma = &Expression::Challenge(0);
+    Expression::distribute_powers(exprs, gamma) * eq_xy
+}
+
+fn sum_check_claim<F: PrimeField>(claimed_v_ys: &[F], gamma: F) -> F {
+    inner_product(
+        claimed_v_ys,
+        &powers(gamma).take(claimed_v_ys.len()).collect_vec(),
+    )
+}
+
+fn layer_down_claim<F: PrimeField>(evals: &[F], mu: F) -> Vec<F> {
+    evals
+        .iter()
+        .tuples()
+        .map(|(&v_l, &v_r)| v_l + mu * (v_r - v_l))
+        .collect_vec()
+}
+
+#[cfg(test)]
+mod tests {
+    use std::iter;
+
+    use itertools::{chain, Itertools};
+
+    use crate::{
+        piop::gkr::{prove_grand_product, verify_grand_product},
+        poly::multilinear::MultilinearPolynomial,
+        util::{
+            izip_eq,
+            test::{rand_vec, seeded_std_rng},
+            transcript::{InMemoryTranscript, Keccak256Transcript},
+        },
+    };
+    use halo2_curves::bn256::Fr;
+
+    #[test]
+    fn grand_product_test() {
+        let num_batching = 4;
+        for num_vars in 1..16 {
+            let mut rng = seeded_std_rng();
+
+            let vs = iter::repeat_with(|| rand_vec(1 << num_vars, &mut rng))
+                .map(MultilinearPolynomial::new)
+                .take(num_batching)
+                .collect_vec();
+            let v_0s = vec![None; num_batching];
+
+            let proof = {
+                let mut transcript = Keccak256Transcript::new(());
+                prove_grand_product::<Fr>(v_0s.to_vec(), vs.iter(), &mut transcript).unwrap();
+                transcript.into_proof()
+            };
+
+            let result = {
+                let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice());
+                verify_grand_product::<Fr>(num_vars, v_0s.to_vec(), &mut transcript)
+            };
+            assert_eq!(result.as_ref().map(|_| ()), Ok(()));
+
+            let (v_xs, x) = result.unwrap();
+            for (poly, eval) in izip_eq!(chain![vs], chain![v_xs]) {
+                assert_eq!(poly.evaluate(&x), eval);
+            }
+        }
+    }
+}
diff --git a/plonkish_backend/src/poly/multilinear.rs b/plonkish_backend/src/poly/multilinear.rs
index df3dc2a..a7e3052 100644
--- a/plonkish_backend/src/poly/multilinear.rs
+++ b/plonkish_backend/src/poly/multilinear.rs
@@ -1,15 +1,17 @@
 use crate::{
     poly::Polynomial,
     util::{
-        arithmetic::{div_ceil, usize_from_bits_le, BooleanHypercube, Field},
+        arithmetic::{div_ceil, fe_from_le_bytes, usize_from_bits_le, BooleanHypercube, Field},
         expression::Rotation,
         impl_index,
         parallel::{num_threads, parallelize, parallelize_iter},
        BitIndex, Deserialize, Itertools, Serialize,
     },
 };
+use halo2_curves::ff::PrimeField;
 use num_integer::Integer;
 use rand::RngCore;
+use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
 use std::{
     borrow::Cow,
     iter::{self, Sum},
@@ -17,6 +19,56 @@ use std::{
     ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
 };
 
+/// Multilinear polynomials represented as symbolic terms rather than as
+/// evaluations over the boolean hypercube
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct MultilinearPolynomialTerms<F> {
+    num_vars: usize,
+    expression: PolyExpr<F>,
+}
+
+impl<F> MultilinearPolynomialTerms<F> {
+    pub fn new(num_vars: usize, expression: PolyExpr<F>) -> Self {
+        Self {
+            num_vars,
+            expression,
+        }
+    }
+}
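A usage sketch of the term representation (hypothetical snippet; evaluate is defined in the next hunk):

// Represents x_0 + 2 * x_1 over two variables:
// let poly = MultilinearPolynomialTerms::new(2, Sum(vec![Var(0), Prod(vec![Const(F::from(2)), Var(1)])]));
// assert_eq!(poly.evaluate(&[F::ONE, F::ONE]), F::from(3));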
+
+impl<F: Field> MultilinearPolynomialTerms<F> {
+    pub fn evaluate(&self, x: &[F]) -> F {
+        assert_eq!(x.len(), self.num_vars);
+        self.expression.evaluate(x)
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub enum PolyExpr<F> {
+    Const(F),
+    Var(usize),
+    Sum(Vec<PolyExpr<F>>),
+    Prod(Vec<PolyExpr<F>>),
+    Pow(Box<PolyExpr<F>>, u32),
+}
+
+impl<F: Field> PolyExpr<F> {
+    fn evaluate(&self, x: &[F]) -> F {
+        match self {
+            PolyExpr::Const(c) => c.clone(),
+            PolyExpr::Var(i) => x[*i],
+            PolyExpr::Sum(v) => v
+                .par_iter()
+                .map(|t| t.evaluate(x))
+                .reduce(|| F::ZERO, |acc, f| acc + f),
+            PolyExpr::Prod(v) => v
+                .par_iter()
+                .map(|t| t.evaluate(x))
+                .reduce(|| F::ONE, |acc, f| acc * f),
+            PolyExpr::Pow(inner, e) => inner.evaluate(x).pow([*e as u64]),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct MultilinearPolynomial<F> {
     evals: Vec<F>,
@@ -62,6 +114,16 @@ impl<F: Field> MultilinearPolynomial<F> {
     }
 }
 
+impl<F: PrimeField> MultilinearPolynomial<F> {
+    pub fn from_usize(evals: Vec<usize>) -> Self {
+        let evals = evals
+            .iter()
+            .map(|eval| fe_from_le_bytes(eval.to_le_bytes()))
+            .collect_vec();
+        Self::new(evals)
+    }
+}
+
 impl<F: Field> Polynomial<F> for MultilinearPolynomial<F> {
     type Point = Vec<F>;
 
diff --git a/plonkish_backend/src/util/arithmetic.rs b/plonkish_backend/src/util/arithmetic.rs
index 1a44679..100ea8e 100644
--- a/plonkish_backend/src/util/arithmetic.rs
+++ b/plonkish_backend/src/util/arithmetic.rs
@@ -186,6 +186,23 @@ pub fn usize_from_bits_le(bits: &[bool]) -> usize {
         .fold(0, |int, bit| (int << 1) + (*bit as usize))
 }
 
+pub fn fe_to_bits_le<F: PrimeField>(fe: F) -> Vec<bool> {
+    let repr = fe.to_repr();
+    repr.as_ref()
+        .iter()
+        .flat_map(|byte| (0..8).map(move |i| (byte >> i) & 1 == 1))
+        .collect_vec()
+}
+
 pub fn div_rem(dividend: usize, divisor: usize) -> (usize, usize) {
     Integer::div_rem(&dividend, &divisor)
 }
@@ -194,6 +211,27 @@ pub fn div_ceil(dividend: usize, divisor: usize) -> usize {
     Integer::div_ceil(&dividend, &divisor)
 }
 
+/// Splits `item` into its high and low `num_bits`-bit chunks, returned as
+/// (high, low).
+pub fn split_bits(item: usize, num_bits: usize) -> (usize, usize) {
+    let max_value = (1 << num_bits) - 1;
+
+    let low_chunk = item & max_value;
+    let high_chunk = (item >> num_bits) & max_value;
+
+    (high_chunk, low_chunk)
+}
+
+/// Splits a little-endian bit vector into consecutive chunks of the given
+/// bit-widths.
+pub fn split_by_chunk_bits(bits: &[bool], chunk_bits: &[usize]) -> Vec<Vec<bool>> {
+    let mut offset = 0;
+    chunk_bits
+        .iter()
+        .map(|chunk_bits| {
+            let chunked = bits[offset..offset + chunk_bits].to_vec();
+            offset += chunk_bits;
+            chunked
+        })
+        .collect_vec()
+}
+
 #[cfg(test)]
 mod test {
     use crate::util::arithmetic;
diff --git a/rust-toolchain b/rust-toolchain
index 77c582d..274ca3d 100644
--- a/rust-toolchain
+++ b/rust-toolchain
@@ -1 +1 @@
-1.67.0
\ No newline at end of file
+nightly-2023-09-22
\ No newline at end of file
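Worked examples for the bit helpers added to util/arithmetic.rs (a sketch, not part of the diff):

fn bit_helper_examples() {
    // split_bits returns (high, low)
    assert_eq!(split_bits(0b1011_0110, 4), (0b1011, 0b0110));
    // split_by_chunk_bits consumes little-endian bits chunk by chunk
    assert_eq!(
        split_by_chunk_bits(&[true, false, true, true], &[1, 3]),
        vec![vec![true], vec![false, true, true]]
    );
}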