Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in poly-api regarding copy_coeffs() in rust #715

Merged
merged 1 commit into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ namespace icicle {

const int64_t deg_a = degree(a);
const int64_t deg_b = degree(b);
ICICLE_ASSERT(deg_b >= 0) << "Polynomial division: divide by zero polynomial";
if (deg_b < 0) {
ICICLE_LOG_ERROR << "Polynomial division: divide by zero polynomial. Skipping computation.";
return;
}

// init: Q=0, R=a
Q->allocate(deg_a - deg_b + 1, State::Coefficients, true /*=memset zeros*/);
Expand Down Expand Up @@ -346,7 +349,10 @@ namespace icicle {
numerator->transform_to_coefficients();
auto numerator_coeffs = get_context_storage_mutable(numerator);
const auto N = numerator->get_nof_elements();
ICICLE_ASSERT(vanishing_poly_degree <= N) << "divide_by_vanishing_polynomial(): degree is too large";
if (vanishing_poly_degree > N) {
ICICLE_LOG_ERROR << "divide_by_vanishing_polynomial(): degree is too large. Skipping computation.";
return;
}

out->allocate(N, State::Coefficients, true /*=set zeros*/);
add_monomial_inplace(out, C::zero() - C::one(), 0); //-1
Expand Down Expand Up @@ -384,8 +390,10 @@ namespace icicle {
void divide_by_vanishing_case_2N(PolyContext out, PolyContext numerator, uint64_t vanishing_poly_degree)
{
// in that special case the numertaor has 2N elements and output will be N elements
ICICLE_ASSERT(numerator->get_nof_elements() == 2 * vanishing_poly_degree)
<< "invalid input size. Expecting numerator to be of size 2N";
if (numerator->get_nof_elements() != 2 * vanishing_poly_degree) {
ICICLE_LOG_ERROR << "divide_by_vanishing_case_2N(): invalid input size. Skipping computation.";
return;
}

// In the case where deg(P)=2N, I can transform numerator to Reversed-evals -> The second half is
// a reversed-coset of size N with coset-gen the 2N-th root of unity.
Expand Down Expand Up @@ -423,8 +431,10 @@ namespace icicle {
void divide_by_vanishing_case_N(PolyContext out, PolyContext numerator, uint64_t vanishing_poly_degree)
{
// in that special case the numertaor has N elements and output will be N elements
ICICLE_ASSERT(numerator->get_nof_elements() == vanishing_poly_degree)
<< "invalid input size. Expecting numerator to be of size N";
if (numerator->get_nof_elements() != vanishing_poly_degree) {
ICICLE_LOG_ERROR << "divide_by_vanishing_case_N(): invalid input size. Skipping computation.";
return;
}

const int N = vanishing_poly_degree;
numerator->transform_to_coefficients(N);
Expand Down Expand Up @@ -572,7 +582,8 @@ namespace icicle {
ntt(d_evals, poly_size, NTTDir::kForward, ntt_config, d_evals);
} break;
default:
ICICLE_ASSERT(false) << "Invalid state to compute evaluations";
ICICLE_LOG_ERROR << "Panic: Invalid state to compute evaluations";
ICICLE_ASSERT(false);
break;
}
}
Expand All @@ -596,8 +607,9 @@ namespace icicle {
const bool is_valid_end_idx = end_idx < nof_coeffs && end_idx >= start_idx;
const bool is_valid_indices = is_valid_start_idx && is_valid_end_idx;
if (!is_valid_indices) {
// return -1 instead? I could but 'get_coeff()' cannot with its current declaration
ICICLE_ASSERT(false) << "copy_coeffs() invalid indices";
ICICLE_LOG_ERROR << "copy_coeffs() invalid indices (start=" << start_idx << ", end=" << end_idx
<< ", nof_coeffs=" << nof_coeffs << "). Skipping copy...";
return 0;
}

op->transform_to_coefficients();
Expand Down
7 changes: 6 additions & 1 deletion wrappers/rust/icicle-core/src/polynomials/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ macro_rules! impl_univariate_polynomial_api {
fn copy_coeffs<S: HostOrDeviceSlice<Self::Field> + ?Sized>(&self, start_idx: u64, coeffs: &mut S) {
let coeffs_len = coeffs.len() as u64;
let nof_coeffs = self.get_nof_coeffs();
let end_idx = cmp::min(nof_coeffs, start_idx + coeffs_len - 1);
let end_idx = cmp::min(nof_coeffs - 1, start_idx + coeffs_len - 1);

unsafe {
copy_coeffs(self.handle, coeffs.as_mut_ptr(), start_idx, end_idx);
Expand Down Expand Up @@ -626,6 +626,11 @@ macro_rules! impl_polynomial_tests {
f.copy_coeffs(0, HostSlice::from_mut_slice(&mut host_mem));
assert_eq!(host_mem, coeffs);

// read into larger buffer
let mut host_mem_large = vec![$field::zero(); coeffs.len() + 10];
f.copy_coeffs(0, HostSlice::from_mut_slice(&mut host_mem_large));
assert_eq!(host_mem_large[..coeffs.len()], coeffs);

// read coeffs to device memory
let mut device_mem = DeviceVec::<$field>::device_malloc(coeffs.len()).unwrap();
f.copy_coeffs(0, &mut device_mem[..]);
Expand Down
Loading