Skip to content

Commit

Permalink
fix bug in poly-api regarding copy_coeffs() in rust (#715)
Browse files Browse the repository at this point in the history
  • Loading branch information
yshekel authored Dec 31, 2024
1 parent 46a9c72 commit eec29d0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ namespace icicle {

const int64_t deg_a = degree(a);
const int64_t deg_b = degree(b);
ICICLE_ASSERT(deg_b >= 0) << "Polynomial division: divide by zero polynomial";
if (deg_b < 0) {
ICICLE_LOG_ERROR << "Polynomial division: divide by zero polynomial. Skipping computation.";
return;
}

// init: Q=0, R=a
Q->allocate(deg_a - deg_b + 1, State::Coefficients, true /*=memset zeros*/);
Expand Down Expand Up @@ -346,7 +349,10 @@ namespace icicle {
numerator->transform_to_coefficients();
auto numerator_coeffs = get_context_storage_mutable(numerator);
const auto N = numerator->get_nof_elements();
ICICLE_ASSERT(vanishing_poly_degree <= N) << "divide_by_vanishing_polynomial(): degree is too large";
if (vanishing_poly_degree > N) {
ICICLE_LOG_ERROR << "divide_by_vanishing_polynomial(): degree is too large. Skipping computation.";
return;
}

out->allocate(N, State::Coefficients, true /*=set zeros*/);
add_monomial_inplace(out, C::zero() - C::one(), 0); //-1
Expand Down Expand Up @@ -384,8 +390,10 @@ namespace icicle {
void divide_by_vanishing_case_2N(PolyContext out, PolyContext numerator, uint64_t vanishing_poly_degree)
{
// in that special case the numertaor has 2N elements and output will be N elements
ICICLE_ASSERT(numerator->get_nof_elements() == 2 * vanishing_poly_degree)
<< "invalid input size. Expecting numerator to be of size 2N";
if (numerator->get_nof_elements() != 2 * vanishing_poly_degree) {
ICICLE_LOG_ERROR << "divide_by_vanishing_case_2N(): invalid input size. Skipping computation.";
return;
}

// In the case where deg(P)=2N, I can transform numerator to Reversed-evals -> The second half is
// a reversed-coset of size N with coset-gen the 2N-th root of unity.
Expand Down Expand Up @@ -423,8 +431,10 @@ namespace icicle {
void divide_by_vanishing_case_N(PolyContext out, PolyContext numerator, uint64_t vanishing_poly_degree)
{
// in that special case the numertaor has N elements and output will be N elements
ICICLE_ASSERT(numerator->get_nof_elements() == vanishing_poly_degree)
<< "invalid input size. Expecting numerator to be of size N";
if (numerator->get_nof_elements() != vanishing_poly_degree) {
ICICLE_LOG_ERROR << "divide_by_vanishing_case_N(): invalid input size. Skipping computation.";
return;
}

const int N = vanishing_poly_degree;
numerator->transform_to_coefficients(N);
Expand Down Expand Up @@ -572,7 +582,8 @@ namespace icicle {
ntt(d_evals, poly_size, NTTDir::kForward, ntt_config, d_evals);
} break;
default:
ICICLE_ASSERT(false) << "Invalid state to compute evaluations";
ICICLE_LOG_ERROR << "Panic: Invalid state to compute evaluations";
ICICLE_ASSERT(false);
break;
}
}
Expand All @@ -596,8 +607,9 @@ namespace icicle {
const bool is_valid_end_idx = end_idx < nof_coeffs && end_idx >= start_idx;
const bool is_valid_indices = is_valid_start_idx && is_valid_end_idx;
if (!is_valid_indices) {
// return -1 instead? I could but 'get_coeff()' cannot with its current declaration
ICICLE_ASSERT(false) << "copy_coeffs() invalid indices";
ICICLE_LOG_ERROR << "copy_coeffs() invalid indices (start=" << start_idx << ", end=" << end_idx
<< ", nof_coeffs=" << nof_coeffs << "). Skipping copy...";
return 0;
}

op->transform_to_coefficients();
Expand Down
7 changes: 6 additions & 1 deletion wrappers/rust/icicle-core/src/polynomials/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ macro_rules! impl_univariate_polynomial_api {
fn copy_coeffs<S: HostOrDeviceSlice<Self::Field> + ?Sized>(&self, start_idx: u64, coeffs: &mut S) {
let coeffs_len = coeffs.len() as u64;
let nof_coeffs = self.get_nof_coeffs();
let end_idx = cmp::min(nof_coeffs, start_idx + coeffs_len - 1);
let end_idx = cmp::min(nof_coeffs - 1, start_idx + coeffs_len - 1);

unsafe {
copy_coeffs(self.handle, coeffs.as_mut_ptr(), start_idx, end_idx);
Expand Down Expand Up @@ -626,6 +626,11 @@ macro_rules! impl_polynomial_tests {
f.copy_coeffs(0, HostSlice::from_mut_slice(&mut host_mem));
assert_eq!(host_mem, coeffs);

// read into larger buffer
let mut host_mem_large = vec![$field::zero(); coeffs.len() + 10];
f.copy_coeffs(0, HostSlice::from_mut_slice(&mut host_mem_large));
assert_eq!(host_mem_large[..coeffs.len()], coeffs);

// read coeffs to device memory
let mut device_mem = DeviceVec::<$field>::device_malloc(coeffs.len()).unwrap();
f.copy_coeffs(0, &mut device_mem[..]);
Expand Down

0 comments on commit eec29d0

Please sign in to comment.