From 664a10885cb9f5068a4f7ccad91091e27710b0a0 Mon Sep 17 00:00:00 2001 From: erabinov Date: Wed, 30 Oct 2024 11:55:10 -0700 Subject: [PATCH 1/2] types working, e2e working --- crates/prover/src/lib.rs | 116 +++++++++++------- crates/prover/src/shapes.rs | 8 ++ .../recursion/circuit/src/machine/compress.rs | 2 +- crates/recursion/core/src/shape.rs | 114 ++--------------- 4 files changed, 86 insertions(+), 154 deletions(-) diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index be4d816264..fa393ec16b 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -20,7 +20,7 @@ pub mod verify; use std::{ borrow::Borrow, - collections::BTreeMap, + collections::{BTreeMap, HashMap}, env, num::NonZeroUsize, path::Path, @@ -34,6 +34,7 @@ use std::{ use lru::LruCache; +use shapes::SP1ProofShape; use tracing::instrument; use p3_baby_bear::BabyBear; @@ -135,8 +136,7 @@ pub struct SP1Prover { pub recursion_cache_misses: AtomicUsize, - pub compress_programs: - Mutex>>>, + pub compress_programs: HashMap>>, pub compress_cache_misses: AtomicUsize, @@ -188,14 +188,6 @@ impl SP1Prover { ) .expect("PROVER_CORE_CACHE_SIZE must be a non-zero usize"); - let compress_cache_size = NonZeroUsize::new( - env::var("PROVER_COMPRESS_CACHE_SIZE") - .unwrap_or_else(|_| CORE_CACHE_SIZE.to_string()) - .parse() - .unwrap_or(COMPRESS_CACHE_SIZE), - ) - .expect("PROVER_COMPRESS_CACHE_SIZE must be a non-zero usize"); - let core_shape_config = env::var("FIX_CORE_SHAPES") .map(|v| v.eq_ignore_ascii_case("true")) .unwrap_or(true) @@ -220,6 +212,41 @@ impl SP1Prover { let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().copied().collect()); + let mut compress_programs = HashMap::new(); + if let Some(config) = &recursion_shape_config { + SP1ProofShape::generate_compress_shapes(config, 2).for_each(|shape| { + let compress_shape = SP1CompressWithVkeyShape { + compress_shape: SP1CompressShape { proof_shapes: shape }, + merkle_tree_height: merkle_tree.height, + }; + let input = SP1CompressWithVKeyWitnessValues::dummy( + compress_prover.machine(), + &compress_shape, + ); + let mut builder = Builder::::default(); + // read the input. + let input = input.read(&mut builder); + // Verify the proof. + SP1CompressWithVKeyVerifier::verify( + &mut builder, + compress_prover.machine(), + input, + vk_verification, + PublicValuesOutputDigest::Reduce, + ); + let operations = builder.into_operations(); + + // Compile the program. + let compiler_span = tracing::debug_span!("compile compress program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + config.fix_shape(&mut program); + let program = Arc::new(program); + compiler_span.exit(); + compress_programs.insert(compress_shape, program); + }); + } + Self { core_prover, compress_prover, @@ -227,7 +254,7 @@ impl SP1Prover { wrap_prover, recursion_programs: Mutex::new(LruCache::new(core_cache_size)), recursion_cache_misses: AtomicUsize::new(0), - compress_programs: Mutex::new(LruCache::new(compress_cache_size)), + compress_programs, compress_cache_misses: AtomicUsize::new(0), vk_root: root, vk_merkle_tree: merkle_tree, @@ -355,40 +382,37 @@ impl SP1Prover { &self, input: &SP1CompressWithVKeyWitnessValues, ) -> Arc> { - let mut cache = self.compress_programs.lock().unwrap_or_else(|e| e.into_inner()); - cache - .get_or_insert(input.shape(), || { - let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed); - tracing::debug!("compress cache miss, misses: {}", misses); - // Get the operations. - let builder_span = tracing::debug_span!("build compress program").entered(); - let mut builder = Builder::::default(); - - // read the input. - let input = input.read(&mut builder); - // Verify the proof. - SP1CompressWithVKeyVerifier::verify( - &mut builder, - self.compress_prover.machine(), - input, - self.vk_verification, - PublicValuesOutputDigest::Reduce, - ); - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile compress program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - if let Some(recursion_shape_config) = &self.recursion_shape_config { - recursion_shape_config.fix_shape(&mut program); - } - let program = Arc::new(program); - compiler_span.exit(); - program - }) - .clone() + self.compress_programs.get(&input.shape()).map(Clone::clone).unwrap_or_else(|| { + let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed); + tracing::debug!("compress cache miss, misses: {}", misses); + // Get the operations. + let builder_span = tracing::debug_span!("build compress program").entered(); + let mut builder = Builder::::default(); + + // read the input. + let input = input.read(&mut builder); + // Verify the proof. + SP1CompressWithVKeyVerifier::verify( + &mut builder, + self.compress_prover.machine(), + input, + self.vk_verification, + PublicValuesOutputDigest::Reduce, + ); + let operations = builder.into_operations(); + builder_span.exit(); + + // Compile the program. + let compiler_span = tracing::debug_span!("compile compress program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + if let Some(recursion_shape_config) = &self.recursion_shape_config { + recursion_shape_config.fix_shape(&mut program); + } + let program = Arc::new(program); + compiler_span.exit(); + program + }) } pub fn shrink_program( diff --git a/crates/prover/src/shapes.rs b/crates/prover/src/shapes.rs index b7adddc0e5..c7b55dcaac 100644 --- a/crates/prover/src/shapes.rs +++ b/crates/prover/src/shapes.rs @@ -231,6 +231,14 @@ impl SP1ProofShape { ) } + pub fn generate_compress_shapes( + recursion_shape_config: &RecursionShapeConfig>, + reduce_batch_size: usize, + ) -> impl Iterator> + '_ { + (1..=reduce_batch_size) + .flat_map(|batch_size| recursion_shape_config.get_all_shape_combinations(batch_size)) + } + pub fn dummy_vk_map<'a>( core_shape_config: &'a CoreShapeConfig, recursion_shape_config: &'a RecursionShapeConfig>, diff --git a/crates/recursion/circuit/src/machine/compress.rs b/crates/recursion/circuit/src/machine/compress.rs index fe99eb43c2..536844e875 100644 --- a/crates/recursion/circuit/src/machine/compress.rs +++ b/crates/recursion/circuit/src/machine/compress.rs @@ -73,7 +73,7 @@ pub struct SP1CompressWitnessValues { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SP1CompressShape { - proof_shapes: Vec, + pub proof_shapes: Vec, } impl SP1CompressVerifier diff --git a/crates/recursion/core/src/shape.rs b/crates/recursion/core/src/shape.rs index 64756d3fed..b2a57f08a2 100644 --- a/crates/recursion/core/src/shape.rs +++ b/crates/recursion/core/src/shape.rs @@ -27,7 +27,7 @@ pub struct RecursionShape { } pub struct RecursionShapeConfig { - allowed_shapes: Vec>, + pub allowed_shapes: Vec>, _marker: PhantomData<(F, A)>, } @@ -101,122 +101,22 @@ impl, const DEGREE: usize> Default // Specify allowed shapes. let allowed_shapes = [ [ - (base_alu.clone(), 20), - (mem_var.clone(), 18), - (ext_alu.clone(), 18), - (exp_reverse_bits_len.clone(), 17), - (mem_const.clone(), 17), - (poseidon2_wide.clone(), 16), - (select.clone(), 18), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 20), - (mem_var.clone(), 18), - (ext_alu.clone(), 18), - (exp_reverse_bits_len.clone(), 17), - (mem_const.clone(), 16), - (poseidon2_wide.clone(), 16), - (select.clone(), 18), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (ext_alu.clone(), 20), - (base_alu.clone(), 19), - (mem_var.clone(), 19), - (poseidon2_wide.clone(), 17), - (mem_const.clone(), 16), - (exp_reverse_bits_len.clone(), 16), - (select.clone(), 18), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 19), - (mem_var.clone(), 18), - (ext_alu.clone(), 18), - (exp_reverse_bits_len.clone(), 17), - (mem_const.clone(), 16), - (poseidon2_wide.clone(), 16), - (select.clone(), 18), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 19), - (mem_var.clone(), 18), - (ext_alu.clone(), 18), - (exp_reverse_bits_len.clone(), 16), - (mem_const.clone(), 16), - (poseidon2_wide.clone(), 16), - (select.clone(), 18), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 20), + (ext_alu.clone(), 21), + (base_alu.clone(), 16), (mem_var.clone(), 19), - (ext_alu.clone(), 19), - (exp_reverse_bits_len.clone(), 17), - (mem_const.clone(), 17), (poseidon2_wide.clone(), 17), - (select.clone(), 19), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 21), - (mem_var.clone(), 19), - (ext_alu.clone(), 19), - (exp_reverse_bits_len.clone(), 18), (mem_const.clone(), 18), - (poseidon2_wide.clone(), 17), - (select.clone(), 19), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 21), - (mem_var.clone(), 19), - (ext_alu.clone(), 19), (exp_reverse_bits_len.clone(), 18), - (mem_const.clone(), 17), - (poseidon2_wide.clone(), 17), (select.clone(), 19), (public_values.clone(), PUB_VALUES_LOG_HEIGHT), ], [ - (ext_alu.clone(), 21), - (base_alu.clone(), 20), - (mem_var.clone(), 20), - (poseidon2_wide.clone(), 18), - (mem_const.clone(), 17), - (exp_reverse_bits_len.clone(), 17), - (select.clone(), 19), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 20), - (mem_var.clone(), 19), - (ext_alu.clone(), 19), - (exp_reverse_bits_len.clone(), 18), - (mem_const.clone(), 17), - (poseidon2_wide.clone(), 17), - (select.clone(), 19), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 20), - (mem_var.clone(), 19), - (ext_alu.clone(), 19), - (exp_reverse_bits_len.clone(), 17), - (mem_const.clone(), 17), - (poseidon2_wide.clone(), 17), - (select.clone(), 19), - (public_values.clone(), PUB_VALUES_LOG_HEIGHT), - ], - [ - (base_alu.clone(), 21), - (mem_var.clone(), 20), (ext_alu.clone(), 20), - (exp_reverse_bits_len.clone(), 18), + (base_alu.clone(), 16), + (mem_var.clone(), 19), + (poseidon2_wide.clone(), 16), (mem_const.clone(), 18), - (poseidon2_wide.clone(), 18), + (exp_reverse_bits_len.clone(), 18), (select.clone(), 19), (public_values.clone(), PUB_VALUES_LOG_HEIGHT), ], From 5f2590a74374c9bc9e710b41e9647161b43cd24c Mon Sep 17 00:00:00 2001 From: erabinov Date: Wed, 30 Oct 2024 12:35:40 -0700 Subject: [PATCH 2/2] small refactor --- crates/prover/src/lib.rs | 79 ++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index fa393ec16b..6fe7678c33 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -103,7 +103,6 @@ const SHRINK_DEGREE: usize = 3; const WRAP_DEGREE: usize = 9; const CORE_CACHE_SIZE: usize = 5; -const COMPRESS_CACHE_SIZE: usize = 3; pub const REDUCE_BATCH_SIZE: usize = 2; // TODO: FIX @@ -223,26 +222,13 @@ impl SP1Prover { compress_prover.machine(), &compress_shape, ); - let mut builder = Builder::::default(); - // read the input. - let input = input.read(&mut builder); - // Verify the proof. - SP1CompressWithVKeyVerifier::verify( - &mut builder, - compress_prover.machine(), - input, + let program = compress_program_from_input::( + recursion_shape_config.as_ref(), + &compress_prover, vk_verification, - PublicValuesOutputDigest::Reduce, + &input, ); - let operations = builder.into_operations(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile compress program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - config.fix_shape(&mut program); let program = Arc::new(program); - compiler_span.exit(); compress_programs.insert(compress_shape, program); }); } @@ -386,32 +372,12 @@ impl SP1Prover { let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed); tracing::debug!("compress cache miss, misses: {}", misses); // Get the operations. - let builder_span = tracing::debug_span!("build compress program").entered(); - let mut builder = Builder::::default(); - - // read the input. - let input = input.read(&mut builder); - // Verify the proof. - SP1CompressWithVKeyVerifier::verify( - &mut builder, - self.compress_prover.machine(), - input, + Arc::new(compress_program_from_input::( + self.recursion_shape_config.as_ref(), + &self.compress_prover, self.vk_verification, - PublicValuesOutputDigest::Reduce, - ); - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile compress program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - if let Some(recursion_shape_config) = &self.recursion_shape_config { - recursion_shape_config.fix_shape(&mut program); - } - let program = Arc::new(program); - compiler_span.exit(); - program + input, + )) }) } @@ -1241,6 +1207,33 @@ impl SP1Prover { } } +pub fn compress_program_from_input( + config: Option<&RecursionShapeConfig>>, + compress_prover: &C::CompressProver, + vk_verification: bool, + input: &SP1CompressWithVKeyWitnessValues, +) -> RecursionProgram { + let mut builder = Builder::::default(); + // read the input. + let input = input.read(&mut builder); + // Verify the proof. + SP1CompressWithVKeyVerifier::verify( + &mut builder, + compress_prover.machine(), + input, + vk_verification, + PublicValuesOutputDigest::Reduce, + ); + let operations = builder.into_operations(); + + // Compile the program. + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + if let Some(config) = config { + config.fix_shape(&mut program); + } + program +} #[cfg(any(test, feature = "export-tests"))] pub mod tests {