Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: fix membership check #510

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions crates/proof-of-sql/src/base/database/join_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::base::{
map::{IndexMap, IndexSet},
scalar::Scalar,
};
use alloc::{vec, vec::Vec};
use alloc::vec::Vec;
use bumpalo::Bump;
use core::cmp::Ordering;
use itertools::Itertools;
Expand Down Expand Up @@ -66,27 +66,29 @@ pub(crate) fn ordered_set_union<'a, S: Scalar>(
pub(crate) fn get_multiplicities<'a, S: Scalar>(
data: &[Column<'a, S>],
unique: &[Column<'a, S>],
) -> Vec<u64> {
alloc: &'a Bump,
) -> &'a [i128] {
// If unique is empty, the multiplicities vector is empty
if unique.is_empty() {
return Vec::new();
return alloc.alloc_slice_fill_copy(0, 0_i128);
}
let num_unique_rows = unique[0].len();
// If data is empty, all multiplicities are 0
if data.is_empty() {
return vec![0; num_unique_rows];
return alloc.alloc_slice_fill_copy(num_unique_rows, 0_i128);
}
let num_rows = data[0].len();
(0..num_unique_rows)
let multiplicities = (0..num_unique_rows)
.map(|unique_index| {
(0..num_rows)
.filter(|&data_index| {
compare_single_row_of_tables(data, unique, data_index, unique_index)
== Ok(Ordering::Equal)
})
.count() as u64
.count() as i128
})
.collect::<Vec<_>>()
.collect::<Vec<_>>();
alloc.alloc_slice_copy(multiplicities.as_slice())
}

/// Compute the CROSS JOIN / cartesian product of two tables.
Expand Down Expand Up @@ -845,36 +847,38 @@ mod tests {
/// Get Multiplicities
#[test]
fn we_can_get_multiplicities_empty_scenarios() {
let alloc = Bump::new();
let empty_data: Vec<Column<TestScalar>> = vec![];
let empty_unique: Vec<Column<TestScalar>> = vec![];

// 1) Both 'data' and 'unique' empty
let result = get_multiplicities(&empty_data, &empty_unique);
let result = get_multiplicities(&empty_data, &empty_unique, &alloc);
assert!(
result.is_empty(),
"When both are empty, result should be empty"
);

// 2) 'unique' empty, 'data' non-empty
let nonempty_data = vec![Column::<TestScalar>::Boolean(&[true, false])];
let result = get_multiplicities(&nonempty_data, &empty_unique);
let result = get_multiplicities(&nonempty_data, &empty_unique, &alloc);
assert!(
result.is_empty(),
"When 'unique' is empty, result must be empty"
);

// 3) 'unique' non-empty, 'data' empty => all zeros
let nonempty_unique = vec![Column::<TestScalar>::Boolean(&[true, true, false])];
let result = get_multiplicities(&empty_data, &nonempty_unique);
let result = get_multiplicities(&empty_data, &nonempty_unique, &alloc);
assert_eq!(
result,
vec![0_u64; 3],
&[0_i128, 0, 0],
"If data is empty, multiplicities should be zeros"
);
}

#[test]
fn we_can_get_multiplicities() {
let alloc = Bump::new();
let data = vec![
Column::<TestScalar>::Boolean(&[true, false, true, true, true]),
Column::<TestScalar>::Int(&[1, 2, 1, 1, 2]),
Expand All @@ -886,7 +890,7 @@ mod tests {
Column::<TestScalar>::BigInt(&[2_i64, 4, 1, 1]),
];

let result = get_multiplicities(&data, &unique);
assert_eq!(result, vec![1, 0, 3, 1], "Expected multiplicities");
let result = get_multiplicities(&data, &unique, &alloc);
assert_eq!(result, &[1, 0, 3, 1], "Expected multiplicities");
}
}
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub trait Scalar:
+ for<'a> core::convert::From<&'a i64> // Required for `Column` to implement `MultilinearExtension`
+ for<'a> core::convert::From<&'a i128> // Required for `Column` to implement `MultilinearExtension`
+ for<'a> core::convert::From<&'a u8> // Required for `Column` to implement `MultilinearExtension`
+ for<'a> core::convert::From<&'a u64> // Required for `Column` to implement `MultilinearExtension`
+ core::convert::TryInto <bool>
+ core::convert::TryInto<u8>
+ core::convert::TryInto <i8>
Expand Down
51 changes: 40 additions & 11 deletions crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use crate::{
base::{database::Column, proof::ProofError, scalar::Scalar, slice_ops},
base::{
database::{join_util::get_multiplicities, Column},
proof::ProofError,
scalar::Scalar,
slice_ops,
},
sql::{
proof::{
FinalRoundBuilder, FirstRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder,
Expand All @@ -12,36 +17,59 @@ use bumpalo::Bump;
use num_traits::{One, Zero};

/// Perform first round evaluation of the membership check.
///
/// # Panics
/// Panics if the number of source and candidate columns are not equal
/// or if the number of columns is zero.
pub(crate) fn first_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FirstRoundBuilder<'a, S>,
multiplicities: &'a [i128],
) {
builder.produce_intermediate_mle(multiplicities);
alloc: &'a Bump,
columns: &[Column<'a, S>],
candidate_subset: &[Column<'a, S>],
) -> &'a [i128] {
assert_eq!(
columns.len(),
candidate_subset.len(),
"The number of source and candidate columns should be equal"
);
assert!(
!columns.is_empty(),
"The number of source columns should be greater than 0"
);
let multiplicities = get_multiplicities::<S>(candidate_subset, columns, alloc);
builder.produce_intermediate_mle(multiplicities as &[_]);
builder.request_post_result_challenges(2);
multiplicities
}

/// Perform final round evaluation of the membership check.
///
/// # Panics
/// Panics if the number of source and candidate columns are not equal.
/// Panics if the number of source and candidate columns are not equal
/// or if the number of columns is zero.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn final_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
alpha: S,
beta: S,
columns: &[Column<'a, S>],
candidate_subset: &[Column<'a, S>],
multiplicities: &'a [i128],
input_ones: &'a [bool],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: We still should have input_ones be an argument to mirror verifier_evaluate.

candidate_ones: &'a [bool],
) {
columns: &[Column<'a, S>],
candidate_subset: &[Column<'a, S>],
) -> &'a [i128] {
assert_eq!(
columns.len(),
candidate_subset.len(),
"The number of source and candidate columns should be equal"
);
assert!(
!columns.is_empty(),
"The number of source columns should be greater than 0"
);
let multiplicities = get_multiplicities::<S>(candidate_subset, columns, alloc);

// Fold the columns
let c_fold = alloc.alloc_slice_fill_copy(input_ones.len(), Zero::zero());
fold_columns(c_fold, alpha, beta, columns);
Expand Down Expand Up @@ -96,6 +124,7 @@ pub(crate) fn final_round_evaluate_membership_check<'a, S: Scalar>(
(-S::one(), vec![Box::new(candidate_ones as &[_])]),
],
);
multiplicities
}

#[allow(dead_code)]
Expand All @@ -107,7 +136,7 @@ pub(crate) fn verify_membership_check<S: Scalar>(
candidate_one_eval: S,
column_evals: &[S],
candidate_evals: &[S],
) -> Result<(), ProofError> {
) -> Result<S, ProofError> {
// Check that the source and candidate columns have the same amount of columns
if column_evals.len() != candidate_evals.len() {
return Err(ProofError::VerificationError {
Expand Down Expand Up @@ -141,5 +170,5 @@ pub(crate) fn verify_membership_check<S: Scalar>(
2,
)?;

Ok(())
Ok(multiplicity_eval)
}
Loading