Skip to content

Commit

Permalink
refactor!: make ProofStream non-generic
Browse files Browse the repository at this point in the history
Decrease verbosity by removing `ProofStream`'s generic `AlgebraicHasher`
type. The only type in use was `Tip5`.

Changing the hash function is a major re-write of Triton VM, requiring
much more work than the generic might suggest.
  • Loading branch information
jan-ferdinand committed Mar 20, 2024
1 parent 05bd271 commit bde928d
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 68 deletions.
2 changes: 1 addition & 1 deletion triton-vm/benches/proof_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ fn log_2_fri_domain_length(stark: Stark, proof: &Proof) -> u32 {
/// sizes for that type are accumulated.
fn break_down_proof_size(proof: &Proof) -> HashMap<String, usize> {
let mut proof_size_breakdown = HashMap::new();
let proof_stream: ProofStream<Tip5> = proof.try_into().unwrap();
let proof_stream = ProofStream::try_from(proof).unwrap();
for proof_item in &proof_stream.items {
let item_name = proof_item.to_string();
let item_len = proof_item.encode().len();
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub enum ProofStreamError {
TooManyLog2PaddedHeights,

#[error(transparent)]
DecodingError(#[from] <ProofStream<Tip5> as BFieldCodec>::Error),
DecodingError(#[from] <ProofStream as BFieldCodec>::Error),
}

#[non_exhaustive]
Expand Down
45 changes: 21 additions & 24 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct Fri<H: AlgebraicHasher> {
}

struct FriProver<'stream, H: AlgebraicHasher> {
proof_stream: &'stream mut ProofStream<H>,
proof_stream: &'stream mut ProofStream,
rounds: Vec<ProverRound<H>>,
first_round_domain: ArithmeticDomain,
num_rounds: usize,
Expand Down Expand Up @@ -190,14 +190,15 @@ impl<H: AlgebraicHasher> ProverRound<H> {
}

struct FriVerifier<'stream, H: AlgebraicHasher> {
proof_stream: &'stream mut ProofStream<H>,
proof_stream: &'stream mut ProofStream,
rounds: Vec<VerifierRound>,
first_round_domain: ArithmeticDomain,
last_round_codeword: Vec<XFieldElement>,
last_round_max_degree: usize,
num_rounds: usize,
num_collinearity_checks: usize,
first_round_collinearity_check_indices: Vec<usize>,
_hasher: PhantomData<H>,
}

struct VerifierRound {
Expand Down Expand Up @@ -564,7 +565,7 @@ impl<H: AlgebraicHasher> Fri<H> {
pub fn prove(
&self,
codeword: &[XFieldElement],
proof_stream: &mut ProofStream<H>,
proof_stream: &mut ProofStream,
) -> ProverResult<Vec<usize>> {
let mut prover = self.prover(proof_stream);

Expand All @@ -575,7 +576,7 @@ impl<H: AlgebraicHasher> Fri<H> {
Ok(indices)
}

fn prover<'stream>(&'stream self, proof_stream: &'stream mut ProofStream<H>) -> FriProver<H> {
fn prover<'stream>(&'stream self, proof_stream: &'stream mut ProofStream) -> FriProver<H> {
FriProver {
proof_stream,
rounds: vec![],
Expand All @@ -590,7 +591,7 @@ impl<H: AlgebraicHasher> Fri<H> {
/// Returns the indices and revealed elements of the codeword at the top level of the FRI proof.
pub fn verify(
&self,
proof_stream: &mut ProofStream<H>,
proof_stream: &mut ProofStream,
maybe_profiler: &mut Option<TritonProfiler>,
) -> VerifierResult<Vec<(usize, XFieldElement)>> {
prof_start!(maybe_profiler, "init");
Expand All @@ -609,10 +610,7 @@ impl<H: AlgebraicHasher> Fri<H> {
Ok(verifier.first_round_partially_revealed_codeword())
}

fn verifier<'stream>(
&'stream self,
proof_stream: &'stream mut ProofStream<H>,
) -> FriVerifier<H> {
fn verifier<'stream>(&'stream self, proof_stream: &'stream mut ProofStream) -> FriVerifier<H> {
FriVerifier {
proof_stream,
rounds: vec![],
Expand All @@ -622,6 +620,7 @@ impl<H: AlgebraicHasher> Fri<H> {
num_rounds: self.num_rounds(),
num_collinearity_checks: self.num_collinearity_checks,
first_round_collinearity_check_indices: vec![],
_hasher: PhantomData,
}
}

Expand Down Expand Up @@ -848,7 +847,7 @@ mod tests {
fri.prove(&codeword, &mut prover_proof_stream).unwrap();

let proof = (&prover_proof_stream).into();
let verifier_proof_stream = ProofStream::<Tip5>::try_from(&proof).unwrap();
let verifier_proof_stream = ProofStream::try_from(&proof).unwrap();

let prover_items = prover_proof_stream.items.iter();
let verifier_items = verifier_proof_stream.items.iter();
Expand Down Expand Up @@ -884,19 +883,17 @@ mod tests {
}

#[must_use]
fn prepare_proof_stream_for_verification<H: AlgebraicHasher>(
mut proof_stream: ProofStream<H>,
) -> ProofStream<H> {
fn prepare_proof_stream_for_verification(mut proof_stream: ProofStream) -> ProofStream {
proof_stream.items_index = 0;
proof_stream.sponge = H::init();
proof_stream.sponge = Tip5::init();
proof_stream
}

#[must_use]
fn modify_last_round_codeword_in_proof_stream_using_seed<H: AlgebraicHasher>(
mut proof_stream: ProofStream<H>,
fn modify_last_round_codeword_in_proof_stream_using_seed(
mut proof_stream: ProofStream,
seed: u64,
) -> ProofStream<H> {
) -> ProofStream {
let mut proof_items = proof_stream.items.iter_mut();
let last_round_codeword = proof_items.find_map(fri_codeword_filter()).unwrap();

Expand Down Expand Up @@ -937,10 +934,10 @@ mod tests {
}

#[must_use]
fn change_size_of_some_fri_response_in_proof_stream_using_seed<H: AlgebraicHasher>(
mut proof_stream: ProofStream<H>,
fn change_size_of_some_fri_response_in_proof_stream_using_seed(
mut proof_stream: ProofStream,
seed: u64,
) -> ProofStream<H> {
) -> ProofStream {
let proof_items = proof_stream.items.iter_mut();
let fri_responses = proof_items.filter_map(fri_response_filter());

Expand Down Expand Up @@ -989,10 +986,10 @@ mod tests {
}

#[must_use]
fn modify_some_auth_structure_in_proof_stream_using_seed<H: AlgebraicHasher>(
mut proof_stream: ProofStream<H>,
fn modify_some_auth_structure_in_proof_stream_using_seed(
mut proof_stream: ProofStream,
seed: u64,
) -> ProofStream<H> {
) -> ProofStream {
let proof_items = proof_stream.items.iter_mut();
let auth_structures = proof_items.filter_map(non_trivial_auth_structure_filter());

Expand Down Expand Up @@ -1021,7 +1018,7 @@ mod tests {
#[proptest]
fn verifying_arbitrary_proof_does_not_panic(
#[strategy(arbitrary_fri())] fri: Fri<Tip5>,
#[strategy(arb())] mut proof_stream: ProofStream<Tip5>,
#[strategy(arb())] mut proof_stream: ProofStream,
) {
let _ = fri.verify(&mut proof_stream, &mut None);
}
Expand Down
6 changes: 3 additions & 3 deletions triton-vm/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl Proof {
/// This is an upper bound on the length of the computation this proof is for.
/// It is one of the main contributing factors to the length of the FRI domain.
pub fn padded_height(&self) -> Result<usize, ProofStreamError> {
let proof_stream = ProofStream::<Tip5>::try_from(self)?;
let proof_stream = ProofStream::try_from(self)?;
let proof_items = proof_stream.items.into_iter();
let log_2_padded_heights = proof_items
.filter_map(|item| item.try_into_log2_padded_height().ok())
Expand Down Expand Up @@ -116,7 +116,7 @@ mod tests {

#[proptest(cases = 10)]
fn proof_with_no_log_2_padded_height_gives_err(#[strategy(arb())] root: Digest) {
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::MerkleRoot(root));
let proof: Proof = proof_stream.into();
let maybe_padded_height = proof.padded_height();
Expand All @@ -125,7 +125,7 @@ mod tests {

#[proptest(cases = 10)]
fn proof_with_multiple_log_2_padded_height_gives_err(#[strategy(arb())] root: Digest) {
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::Log2PaddedHeight(8));
proof_stream.enqueue(ProofItem::MerkleRoot(root));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(7));
Expand Down
9 changes: 4 additions & 5 deletions triton-vm/src/proof_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ pub(crate) mod tests {
use proptest::prelude::*;
use strum::IntoEnumIterator;
use test_strategy::proptest;
use twenty_first::prelude::Tip5;

use crate::proof::Proof;
use crate::proof_stream::ProofStream;
Expand All @@ -141,11 +140,11 @@ pub(crate) mod tests {
#[proptest]
fn serialize_fri_response_in_proof_stream(leaved_merkle_tree: LeavedMerkleTreeTestData) {
let fri_response = leaved_merkle_tree.into_fri_response();
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::FriResponse(fri_response.clone()));
let proof: Proof = proof_stream.into();

let_assert!(Ok(mut proof_stream) = ProofStream::<Tip5>::try_from(&proof));
let_assert!(Ok(mut proof_stream) = ProofStream::try_from(&proof));
let_assert!(Ok(proof_item) = proof_stream.dequeue());
let_assert!(Ok(fri_response_) = proof_item.try_into_fri_response());
prop_assert_eq!(fri_response, fri_response_);
Expand All @@ -166,11 +165,11 @@ pub(crate) mod tests {
leaved_merkle_tree: LeavedMerkleTreeTestData,
) {
let auth_structure = leaved_merkle_tree.auth_structure;
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::AuthenticationStructure(auth_structure.clone()));
let proof: Proof = proof_stream.into();

let_assert!(Ok(mut proof_stream) = ProofStream::<Tip5>::try_from(&proof));
let_assert!(Ok(mut proof_stream) = ProofStream::try_from(&proof));
let_assert!(Ok(proof_item) = proof_stream.dequeue());
let_assert!(Ok(auth_structure_) = proof_item.try_into_authentication_structure());
prop_assert_eq!(auth_structure, auth_structure_);
Expand Down
45 changes: 15 additions & 30 deletions triton-vm/src/proof_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,22 @@ use crate::proof::Proof;
use crate::proof_item::ProofItem;

#[derive(Debug, Default, Clone, Eq, PartialEq, BFieldCodec, Arbitrary)]
pub struct ProofStream<H>
where
H: AlgebraicHasher,
{
pub struct ProofStream {
pub items: Vec<ProofItem>,

#[bfield_codec(ignore)]
pub items_index: usize,

#[bfield_codec(ignore)]
pub sponge: H,
pub sponge: Tip5,
}

impl<H> ProofStream<H>
where
H: AlgebraicHasher,
{
impl ProofStream {
pub fn new() -> Self {
ProofStream {
items: vec![],
items_index: 0,
sponge: H::init(),
sponge: Tip5::init(),
}
}

Expand Down Expand Up @@ -99,10 +93,7 @@ where
}
}

impl<H> TryFrom<&Proof> for ProofStream<H>
where
H: AlgebraicHasher,
{
impl TryFrom<&Proof> for ProofStream {
type Error = ProofStreamError;

fn try_from(proof: &Proof) -> Result<Self, ProofStreamError> {
Expand All @@ -111,20 +102,14 @@ where
}
}

impl<H> From<&ProofStream<H>> for Proof
where
H: AlgebraicHasher,
{
fn from(proof_stream: &ProofStream<H>) -> Self {
impl From<&ProofStream> for Proof {
fn from(proof_stream: &ProofStream) -> Self {
Proof(proof_stream.encode())
}
}

impl<H> From<ProofStream<H>> for Proof
where
H: AlgebraicHasher,
{
fn from(proof_stream: ProofStream<H>) -> Self {
impl From<ProofStream> for Proof {
fn from(proof_stream: ProofStream) -> Self {
(&proof_stream).into()
}
}
Expand Down Expand Up @@ -165,7 +150,7 @@ mod tests {
let fri_response = leaved_merkle_tree.into_fri_response();

let mut sponge_states = VecDeque::new();
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();

sponge_states.push_back(proof_stream.sponge.state);
proof_stream.enqueue(ProofItem::AuthenticationStructure(auth_structure.clone()));
Expand All @@ -188,7 +173,7 @@ mod tests {
sponge_states.push_back(proof_stream.sponge.state);

let proof = proof_stream.into();
let mut proof_stream: ProofStream<Tip5> = ProofStream::try_from(&proof).unwrap();
let mut proof_stream = ProofStream::try_from(&proof).unwrap();

assert!(sponge_states.pop_front() == Some(proof_stream.sponge.state));
let_assert!(Ok(proof_item) = proof_stream.dequeue());
Expand Down Expand Up @@ -252,7 +237,7 @@ mod tests {
revealed_leaves,
};

let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::FriResponse(fri_response));

// TODO: Also check that deserializing from Proof works here.
Expand Down Expand Up @@ -280,13 +265,13 @@ mod tests {

#[test]
fn dequeuing_from_empty_stream_fails() {
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
let_assert!(Err(ProofStreamError::EmptyQueue) = proof_stream.dequeue());
}

#[test]
fn dequeuing_more_items_than_have_been_enqueued_fails() {
let mut proof_stream = ProofStream::<Tip5>::new();
let mut proof_stream = ProofStream::new();
proof_stream.enqueue(ProofItem::FriCodeword(vec![]));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(7));

Expand All @@ -297,6 +282,6 @@ mod tests {

#[test]
fn encoded_length_of_prove_stream_is_not_known_at_compile_time() {
assert!(ProofStream::<Tip5>::static_length().is_none());
assert!(ProofStream::static_length().is_none());
}
}
9 changes: 5 additions & 4 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ use crate::table::NUM_EXT_COLUMNS;
#[deprecated(since = "0.37.0", note = "Use `Tip5` directly instead.")]
pub type StarkHasher = Tip5;

pub type StarkProofStream = ProofStream<Tip5>;
#[deprecated(since = "0.39.0", note = "Use `ProofStream` directly instead.")]
pub type StarkProofStream = ProofStream;

/// The Merkle tree maker in use. Keeping this as a type alias should make it easier to switch
/// between different Merkle tree makers.
Expand Down Expand Up @@ -123,7 +124,7 @@ impl Stark {
maybe_profiler: &mut Option<TritonProfiler>,
) -> Result<Proof, ProvingError> {
prof_start!(maybe_profiler, "Fiat-Shamir: claim", "hash");
let mut proof_stream = StarkProofStream::new();
let mut proof_stream = ProofStream::new();
proof_stream.alter_fiat_shamir_state_with(claim);
prof_stop!(maybe_profiler, "Fiat-Shamir: claim");

Expand Down Expand Up @@ -499,7 +500,7 @@ impl Stark {
}

fn sample_linear_combination_weights(
proof_stream: &mut ProofStream<Tip5>,
proof_stream: &mut ProofStream,
) -> (
Array1<XFieldElement>,
Array1<XFieldElement>,
Expand Down Expand Up @@ -690,7 +691,7 @@ impl Stark {
maybe_profiler: &mut Option<TritonProfiler>,
) -> Result<(), VerificationError> {
prof_start!(maybe_profiler, "deserialize");
let mut proof_stream = StarkProofStream::try_from(proof)?;
let mut proof_stream = ProofStream::try_from(proof)?;
prof_stop!(maybe_profiler, "deserialize");

prof_start!(maybe_profiler, "Fiat-Shamir: Claim", "hash");
Expand Down

0 comments on commit bde928d

Please sign in to comment.