Skip to content

Commit

Permalink
todo proof aggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Oct 26, 2024
1 parent a9cd28a commit ee4caa8
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 43 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = ["lib", "program", "script", "embedder"]
members = ["lib", "program", "script", "embedder", "aggregator"]
resolver = "2"

[workspace.dependencies]
alloy-sol-types = "0.7.7"
sha2 = "0.10.8"
11 changes: 11 additions & 0 deletions aggregator/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "aggregation-program"
version = "1.1.0"
edition = "2021"
publish = false

[dependencies]
sha2.workspace = true
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", tag = "v1.0.1", features = [
"verify",
] }
94 changes: 94 additions & 0 deletions aggregator/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//! The MIT License (MIT)
//!
//! Copyright (c) 2023 Succinct Labs
//!
//! Permission is hereby granted, free of charge, to any person obtaining a copy
//! of this software and associated documentation files (the "Software"), to deal
//! in the Software without restriction, including without limitation the rights
//! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
//! copies of the Software, and to permit persons to whom the Software is
//! furnished to do so, subject to the following conditions:
//!
//! The above copyright notice and this permission notice shall be included in
//! all copies or substantial portions of the Software.
//!
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
//! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
//! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
//! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
//! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
//! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
//! THE SOFTWARE.
//!
//! A simple program that aggregates the proofs of multiple programs proven with the zkVM.
//!
//! https://github.com/succinctlabs/sp1/tree/main/examples/aggregation
//!
//! cargo prove build --elf-name riscv32im-succinct-aggregator-elf

#![no_main]
sp1_zkvm::entrypoint!(main);

use sha2::Digest;
use sha2::Sha256;

/// Serialize eight `u32` words into a 32-byte array, each word little-endian.
///
/// Word `i` occupies bytes `4*i .. 4*i+4` of the output.
pub fn words_to_bytes_le(words: &[u32; 8]) -> [u8; 32] {
    let mut out = [0u8; 32];
    for (dst, word) in out.chunks_exact_mut(4).zip(words.iter()) {
        dst.copy_from_slice(&word.to_le_bytes());
    }
    out
}

/// Encode a list of vkeys and committed values into a single byte array. In the future this could
/// be a merkle tree or some other commitment scheme.
///
/// ( vkeys.len() || vkeys || committed_values[0].len as u32 || committed_values[0] || ... )
/// Encode a list of vkeys and committed values into a single byte array. In the future this could
/// be a merkle tree or some other commitment scheme.
///
/// Layout: `( vkeys.len() || vkeys || committed_values[0].len as u32 || committed_values[0] || ... )`
/// where the count and lengths are big-endian `u32`s (to match Solidity's
/// `abi.encodePacked`) and each vkey word is serialized little-endian.
///
/// # Panics
/// Panics if `vkeys` and `committed_values` have different lengths.
pub fn commit_proof_pairs(vkeys: &[[u32; 8]], committed_values: &[Vec<u8>]) -> Vec<u8> {
    assert_eq!(vkeys.len(), committed_values.len());

    // Exact output size: 4-byte count + 32 bytes per vkey + (4-byte length +
    // payload) per committed value. Preallocating avoids regrowth.
    let payload_bytes: usize = committed_values.iter().map(Vec::len).sum();
    let total = 4 + vkeys.len() * 32 + committed_values.len() * 4 + payload_bytes;
    let mut out = Vec::with_capacity(total);

    // Big-endian count, matching abi.encodePacked in Solidity.
    out.extend_from_slice(&(vkeys.len() as u32).to_be_bytes());

    // Each vkey word emitted little-endian (same layout as words_to_bytes_le).
    for vkey in vkeys {
        for word in vkey {
            out.extend_from_slice(&word.to_le_bytes());
        }
    }

    // Length-prefixed committed values, lengths big-endian.
    for vals in committed_values {
        out.extend_from_slice(&(vals.len() as u32).to_be_bytes());
        out.extend_from_slice(vals);
    }

    out
}

/// Entry point: read N verification keys and N public-value blobs from the
/// zkVM input stream, verify one recursive proof per (vkey, values) pair,
/// then commit a packed encoding of all verified pairs.
pub fn main() {
    // Read the verification keys.
    let vkeys = sp1_zkvm::io::read::<Vec<[u32; 8]>>();

    // Read the public values.
    let public_values = sp1_zkvm::io::read::<Vec<Vec<u8>>>();

    // Verify one proof per (vkey, public values) pair; each proof is checked
    // against the SHA-256 digest of its public values.
    assert_eq!(vkeys.len(), public_values.len());
    for (vkey, values) in vkeys.iter().zip(public_values.iter()) {
        let digest = Sha256::digest(values);
        sp1_zkvm::lib::verify::verify_sp1_proof(vkey, &digest.into());
    }

    // TODO: Do something interesting with the proofs here.
    //
    // For example, commit to the verified proofs in a merkle tree. For now, we'll just commit to
    // all the (vkey, input) pairs.
    let commitment = commit_proof_pairs(&vkeys, &public_values);
    sp1_zkvm::io::commit_slice(&commitment);
}
Binary file added elf/riscv32im-succinct-aggregator-elf
Binary file not shown.
Binary file modified elf/riscv32im-succinct-zkvm-elf
Binary file not shown.
42 changes: 42 additions & 0 deletions lib/src/hora.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//! Kept as backup.
//!
//! This HNSW-based search did not work due to OOM inside the zkVM; maybe the
//! issue can be solved with proof aggregation?
//
// Fix: the original wrote `#![depracated]` — the attribute name is misspelled
// (rustc rejects unknown attributes), and an inner `#![...]` attribute may not
// follow outer `///` doc comments, so the file failed to compile. The docs are
// now inner (`//!`) module docs and the attribute is spelled correctly.
#![deprecated(note = "OOMs in the zkVM; superseded by the iterative similarity search in lib.rs")]

use hora::core::{ann_index::ANNIndex, metrics::Metric};
use hora::index::{hnsw_idx::HNSWIndex, hnsw_params::HNSWParams};

alloy_sol_types::sol! {
    /// The public values encoded as a struct that can be easily deserialized inside Solidity.
    struct PublicValuesStruct {
        uint32 k;
        uint32[] dest;
    }
}

/// Create a HNSW index from `samples` with `DotProduct` metric, and query it with a given `query` vector.
///
/// HNSW and dot-product is chosen as they turn out to be the least demanding for the zkVM.
///
/// Returns the indices of the top `top_k` samples in the index.
///
/// # Panics
/// Panics if any sample's dimension differs from the query's, or if index
/// insertion/construction fails.
pub fn index_and_query(samples: Vec<Vec<f32>>, query: Vec<f32>, top_k: u32) -> Vec<u32> {
    // Ensure each sample has the same dimension as the query.
    let len = query.len();
    for sample in &samples {
        assert_eq!(sample.len(), len);
    }

    // Create the index and add every sample, keyed by its position.
    let mut index = HNSWIndex::<f32, u32>::new(len, &HNSWParams::<f32>::default());
    for (i, sample) in samples.iter().enumerate() {
        index.add(sample, i as u32).unwrap();
    }

    // Construct the HNSW graph.
    index.build(Metric::DotProduct).unwrap();

    // Run the query.
    index.search(&query, top_k as usize)
}
53 changes: 26 additions & 27 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
use hora::core::{ann_index::ANNIndex, metrics::Metric};
use hora::index::{hnsw_idx::HNSWIndex, hnsw_params::HNSWParams};

alloy_sol_types::sol! {
/// The public values encoded as a struct that can be easily deserialized inside Solidity.
struct PublicValuesStruct {
uint32 k;
uint32[] dest;
uint32 idx; // index of the most similar sample
}
}

/// Create a HNSW index from `samples` with `DotProduct` metric, and query it with a given `query` vector.
///
/// HNSW and dot-product is chosen as they turn out to be the least demanding for the zkVM.
///
/// Returns the indices of the top `top_k` samples in the index.
pub fn index_and_query(samples: Vec<Vec<f32>>, query: Vec<f32>, top_k: u32) -> Vec<u32> {
// let raw_samples = include_bytes!("../../data/foods-smol.json");
// let x = Vec::<Vec<f32>>::fr(raw_samples.iter());
// ensure each sample has the same dimension as the query
let len = query.len();
for sample in &samples {
assert_eq!(sample.len(), len);
}
/// Return `(index, score)` of the sample with the highest dot-product
/// similarity to `query`.
///
/// On ties the LAST maximal sample wins (matching `Iterator::max_by`).
///
/// # Panics
/// Panics if `samples` is empty or a score comparison involves NaN.
pub fn compute_best_sample(samples: &[Vec<f32>], query: &[f32]) -> (usize, f32) {
    let mut best: Option<(usize, f32)> = None;
    for (idx, sample) in samples.iter().enumerate() {
        // Dot product of this sample against the query.
        let mut score = 0.0f32;
        for (a, b) in sample.iter().zip(query) {
            score += a * b;
        }
        // Keep the current best only when it is strictly greater; on equality
        // the newer candidate replaces it, preserving last-max semantics.
        best = match best {
            Some((i, s))
                if score.partial_cmp(&s).expect("similarity score was NaN")
                    == std::cmp::Ordering::Less =>
            {
                Some((i, s))
            }
            _ => Some((idx, score)),
        };
    }
    best.expect("samples must be non-empty")
}

// create & add samples to index
let mut index = HNSWIndex::<f32, u32>::new(len, &HNSWParams::<f32>::default());
for (i, sample) in samples.iter().enumerate() {
index.add(sample, i as u32).unwrap();
/// Find the index (into `samples`) of the sample most similar to `query` by
/// dot product, reducing the candidate set in batches of `batch_size` until a
/// single batch remains.
///
/// Fix: the previous version pushed chunk-LOCAL winner positions (always in
/// `0..batch_size`) and used them to index the full shrinking candidate list,
/// so any input spanning more than one batch selected the wrong samples — and
/// the final result was an index into an intermediate list rather than into
/// the original `samples`. Original indices are now tracked through every
/// reduction round.
///
/// # Panics
/// Panics if `samples` is empty, `batch_size` is zero, or a similarity score
/// is NaN.
pub fn iterative_similarity_search(
    samples: Vec<Vec<f32>>,
    query: Vec<f32>,
    batch_size: usize,
) -> usize {
    assert!(batch_size > 0, "batch_size must be positive");
    assert!(!samples.is_empty(), "samples must be non-empty");

    // Dot-product similarity of sample `i` against the query.
    let score =
        |i: usize| -> f32 { samples[i].iter().zip(query.iter()).map(|(a, b)| a * b).sum() };

    // Winner of one batch of ORIGINAL indices; each score is computed once.
    // Last maximum wins on ties, matching `compute_best_sample`.
    let best_of = |batch: &[usize]| -> usize {
        let (pos, _) = batch
            .iter()
            .map(|&i| score(i))
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(&b.1).expect("similarity score was NaN"))
            .expect("batch is non-empty");
        batch[pos]
    };

    // Candidate indices into the original `samples`, shrunk batch by batch.
    let mut survivors: Vec<usize> = (0..samples.len()).collect();
    while survivors.len() > batch_size {
        let next: Vec<usize> = survivors
            .chunks(batch_size)
            .map(|chunk| best_of(chunk))
            .collect();
        survivors = next;
    }
    best_of(&survivors)
}
4 changes: 3 additions & 1 deletion program/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ edition = "2021"

[dependencies]
alloy-sol-types = { workspace = true }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", tag = "v1.0.1" }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", tag = "v1.0.1", features = [
"verify",
] }
zkvdb-lib = { path = "../lib" }
16 changes: 6 additions & 10 deletions program/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
//! A simple program that takes a number `n` as input, and writes the `n-1`th and `n`th fibonacci
//! number as an output.
//! Given a set of samples, a query, and a number k, this program finds the k samples that are most similar to the query.

// These two lines are necessary for the program to properly compile.
//
// Under the hood, we wrap your main function with some extra code so that it behaves properly
// inside the zkVM.
#![no_main]
sp1_zkvm::entrypoint!(main);

use alloy_sol_types::SolValue;
use zkvdb_lib::{index_and_query, PublicValuesStruct};
use zkvdb_lib::{compute_best_sample, PublicValuesStruct};

/// Entry point: read the samples, the query, and `k` from the zkVM input
/// stream and commit the index of the single most similar sample.
pub fn main() {
    // Inputs must be read in the same order the host script writes them.
    let samples = sp1_zkvm::io::read::<Vec<Vec<f32>>>();
    let query = sp1_zkvm::io::read::<Vec<f32>>();
    // `k` is still written by the host; it must be consumed to keep the input
    // stream in sync, but the current search returns only the single best
    // match, so the value itself is unused (hence the `_` prefix).
    let _k = sp1_zkvm::io::read::<u32>();

    // Index of the sample with the highest dot-product similarity.
    let (idx, _) = compute_best_sample(&samples, &query);

    // ABI-encode so the committed value can be decoded inside Solidity.
    let bytes = PublicValuesStruct::abi_encode(&PublicValuesStruct { idx: idx as u32 });
    sp1_zkvm::io::commit_slice(&bytes);

    // TODO: return hash of the returned vector here as well
}
7 changes: 3 additions & 4 deletions script/src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,10 @@ fn main() {

// Read the output.
let decoded = PublicValuesStruct::abi_decode(output.as_slice(), true).unwrap();
let PublicValuesStruct { k, dest } = decoded;
println!("k: {}", k);
println!("dest: {:?}", dest);
let PublicValuesStruct { idx } = decoded;
println!("Closest idx: {}", idx);

let expected_dest = zkvdb_lib::index_and_query(samples, query, k);
let expected_dest = zkvdb_lib::similarity_search(samples, query, k);
assert_eq!(dest, expected_dest);
println!("Values are correct!");

Expand Down

0 comments on commit ee4caa8

Please sign in to comment.