Skip to content

Commit

Permalink
feat: introduce memory friendly proving path
Browse files Browse the repository at this point in the history
Storing the entire low-degree-extended trace takes up a lot of memory.
Just-in-time low-degree-extension is an alternative to this cached
approach that computes (and in some cases recomputes) low-degree-
extended columns or subtables when they are needed and drops them
afterwards.

The decision which codepath to follow is made automatically by default:
if enough RAM is available, the fast path is chosen.
This behavior can be overwritten using the environment variable
`TVM_LDE_TRACE`. Accepted values are “cache” and “no_cache”.
  • Loading branch information
jan-ferdinand committed May 7, 2024
2 parents 56f7d1f + 4721e75 commit 70b740e
Show file tree
Hide file tree
Showing 7 changed files with 1,027 additions and 242 deletions.
4 changes: 4 additions & 0 deletions triton-vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ workspace = true
name = "bezout_coeffs"
harness = false

[[bench]]
name = "cached_vs_jit_trace"
harness = false

[[bench]]
name = "mem_io"
harness = false
Expand Down
34 changes: 34 additions & 0 deletions triton-vm/benches/cached_vs_jit_trace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use criterion::criterion_group;
use criterion::criterion_main;
use criterion::Criterion;
use twenty_first::prelude::*;

use triton_vm::config::CacheDecision;
use triton_vm::prelude::*;

criterion_main!(benches);
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = prove_fib<1>, prove_fib<100>, prove_fib<1_000>,
);

fn prove_fib<const N: u64>(c: &mut Criterion) {
let stark = Stark::default();
let program = triton_vm::example_programs::FIBONACCI_SEQUENCE.clone();
let public_input = PublicInput::from(bfe_array![N]);
let non_determinism = NonDeterminism::default();
let (aet, output) = program
.trace_execution(public_input, non_determinism)
.unwrap();
let claim = Claim::about_program(&program)
.with_input(bfe_vec![N])
.with_output(output);

let mut group = c.benchmark_group(format!("prove_fib_{N}"));
triton_vm::config::overwrite_lde_trace_caching_to(CacheDecision::Cache);
group.bench_function("cache", |b| b.iter(|| stark.prove(&claim, &aet, &mut None)));
triton_vm::config::overwrite_lde_trace_caching_to(CacheDecision::NoCache);
group.bench_function("jit", |b| b.iter(|| stark.prove(&claim, &aet, &mut None)));
group.finish();
}
56 changes: 44 additions & 12 deletions triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use itertools::Itertools;
use std::ops::Mul;
use std::ops::MulAssign;

use num_traits::One;
use rayon::prelude::*;
use twenty_first::math::traits::FiniteField;
use twenty_first::math::traits::PrimitiveRootOfUnity;
use twenty_first::prelude::*;
Expand Down Expand Up @@ -58,16 +58,28 @@ impl ArithmeticDomain {
+ Mul<BFieldElement, Output = FF>
+ From<BFieldElement>,
{
// The limitation arises in `Polynomial::fast_coset_evaluate` in dependency `twenty-first`.
let batch_evaluation_is_possible = self.length >= polynomial.coefficients.len();
if batch_evaluation_is_possible {
let (offset, generator, length) = (self.offset, self.generator, self.length);
polynomial.fast_coset_evaluate::<BFieldElement>(offset, generator, length)
} else {
let domain_values = self.domain_values().into_iter();
let domain_values = domain_values.map(FF::from).collect_vec();
polynomial.batch_evaluate(&domain_values)
let (offset, generator, length) = (self.offset, self.generator, self.length);
let evaluate_from =
|chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, generator, length);

// avoid `enumerate` to directly get index of the right type
let mut indexed_chunks = (0..).zip(polynomial.coefficients.chunks(length));

// only allocate a bunch of zeros if there are no chunks
let mut values = indexed_chunks.next().map_or_else(
|| vec![FF::zero(); length],
|(_, first_chunk)| evaluate_from(first_chunk),
);
for (chunk_index, chunk) in indexed_chunks {
let coefficient_index = chunk_index * u64::try_from(length).unwrap();
let scaled_offset = offset.mod_pow(coefficient_index);
values
.par_iter_mut()
.zip(evaluate_from(chunk))
.for_each(|(value, evaluation)| *value += evaluation * scaled_offset);
}

values
}

pub fn interpolate<FF>(&self, values: &[FF]) -> Polynomial<FF>
Expand All @@ -88,8 +100,9 @@ impl ArithmeticDomain {
target_domain.evaluate(&self.interpolate(codeword))
}

pub fn domain_value(&self, index: u32) -> BFieldElement {
self.generator.mod_pow_u32(index) * self.offset
/// Compute the `n`th element of the domain.
pub fn domain_value(&self, n: u32) -> BFieldElement {
self.generator.mod_pow_u32(n) * self.offset
}

pub fn domain_values(&self) -> Vec<BFieldElement> {
Expand Down Expand Up @@ -124,6 +137,7 @@ impl ArithmeticDomain {
mod tests {
use assert2::let_assert;
use itertools::Itertools;
use proptest::collection::vec;
use proptest::prelude::*;
use proptest_arbitrary_interop::arb;
use test_strategy::proptest;
Expand Down Expand Up @@ -284,4 +298,22 @@ mod tests {
assert!(ArithmeticDomainError::TooSmallForHalving(i) == err);
}
}

#[proptest]
fn can_evaluate_polynomial_larger_than_domain(
#[strategy(1_usize..10)] _log_domain_length: usize,
#[strategy(1_usize..5)] _expansion_factor: usize,
#[strategy(Just(1 << #_log_domain_length))] domain_length: usize,
#[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec<BFieldElement>,
#[strategy(arb())] offset: BFieldElement,
) {
let domain = ArithmeticDomain::of_length(domain_length)
.unwrap()
.with_offset(offset);
let polynomial = Polynomial::new(coefficients);

let values0 = domain.evaluate(&polynomial);
let values1 = polynomial.batch_evaluate(&domain.domain_values());
assert_eq!(values0, values1);
}
}
106 changes: 106 additions & 0 deletions triton-vm/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::cell::RefCell;

use arbitrary::Arbitrary;

thread_local! {
pub(crate) static CONFIG: RefCell<Config> = RefCell::new(Config::default());
}

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub enum CacheDecision {
#[default]
Cache,
NoCache,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
struct Config {
/// Whether to cache the [low-degree extended trace][lde] when [proving].
/// `None` means the decision is made automatically, based on free memory.
/// Can be accessed via [`Config::cache_lde_trace`].
///
/// [lde]: crate::table::master_table::MasterTable::low_degree_extend_all_columns
/// [proving]: crate::stark::Stark::prove
pub cache_lde_trace_overwrite: Option<CacheDecision>,
}

impl Config {
pub fn new() -> Self {
let maybe_overwrite = std::env::var("TVM_LDE_TRACE").map(|s| s.to_ascii_lowercase());
let cache_lde_trace_overwrite = match maybe_overwrite {
Ok(t) if &t == "cache" => Some(CacheDecision::Cache),
Ok(f) if &f == "no_cache" => Some(CacheDecision::NoCache),
_ => None,
};

Self {
cache_lde_trace_overwrite,
}
}
}

impl Default for Config {
fn default() -> Self {
Self::new()
}
}

/// Overwrite the automatic decision whether to cache the [low-degree extended trace][lde] when
/// [proving]. Takes precedence over the environment variable `TVM_LDE_TRACE`.
///
/// Caching the low-degree extended trace improves proving speed but requires more memory. It is
/// generally recommended to cache the trace. Triton VM will make an automatic decision based on
/// free memory. Use this function if you know your requirements better.
///
/// [lde]: crate::table::master_table::MasterTable::low_degree_extend_all_columns
/// [proving]: crate::stark::Stark::prove
pub fn overwrite_lde_trace_caching_to(decision: CacheDecision) {
CONFIG.with_borrow_mut(|config| config.cache_lde_trace_overwrite = Some(decision));
}

/// Should the [low-degree extended trace][lde] be cached? `None` means the
/// decision is made automatically, based on free memory.
///
/// [lde]: crate::table::master_table::MasterTable::low_degree_extend_all_columns
pub(crate) fn cache_lde_trace() -> Option<CacheDecision> {
CONFIG.with_borrow(|config| config.cache_lde_trace_overwrite)
}

#[cfg(test)]
mod tests {
use assert2::assert;
use twenty_first::prelude::*;

use crate::example_programs::FIBONACCI_SEQUENCE;
use crate::prelude::*;
use crate::profiler::TritonProfiler;
use crate::shared_tests::prove_with_low_security_level;

use super::*;

#[test]
fn triton_vm_can_generate_valid_proof_with_just_in_time_lde() {
overwrite_lde_trace_caching_to(CacheDecision::NoCache);
prove_and_verify_a_triton_vm_program();
}

#[test]
fn triton_vm_can_generate_valid_proof_with_cached_lde_trace() {
overwrite_lde_trace_caching_to(CacheDecision::Cache);
prove_and_verify_a_triton_vm_program();
}

fn prove_and_verify_a_triton_vm_program() {
let stdin = PublicInput::from(bfe_array![100]);
let secret_in = NonDeterminism::default();

let mut profiler = Some(TritonProfiler::new("Prove Fib 100"));
let (stark, claim, proof) =
prove_with_low_security_level(&FIBONACCI_SEQUENCE, stdin, secret_in, &mut profiler);
assert!(let Ok(()) = stark.verify(&claim, &proof, &mut None));

let mut profiler = profiler.unwrap();
let report = profiler.report();
println!("{report}");
}
}
11 changes: 11 additions & 0 deletions triton-vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
//! For a full overview over all available instructions and their effects, see the
//! [specification](https://triton-vm.org/spec/instructions.html).
//!
//! # Time / Memory Trade-Offs
//!
//! Parts of the [proof generation](Stark::prove) process can trade time for memory. The
//! [config] module provides ways to control these trade-offs. Additionally, and with lower
//! precedence, they can be controlled via the following environment variables:
//!
//! - `TVM_LDE_TRACE`: Set to `cache` to cache the low-degree extended trace. Set to `no_cache`
//! to not cache it. If unset (or set to anything else), Triton VM will make an automatic decision
//! based on free memory.
//!
//! # Examples
//!
//! Convenience function [`prove_program()`] as well as the [`prove()`] and [`verify()`] methods
Expand Down Expand Up @@ -148,6 +158,7 @@ use crate::prelude::*;

pub mod aet;
pub mod arithmetic_domain;
pub mod config;
pub mod error;
pub mod example_programs;
pub mod fri;
Expand Down
Loading

0 comments on commit 70b740e

Please sign in to comment.