From 7da7f2bb6c26a7c55a5869d21c3a5f546880a001 Mon Sep 17 00:00:00 2001 From: Zachary James Williamson Date: Mon, 6 Jan 2025 22:08:25 +0000 Subject: [PATCH] feat: improve witness generation for cycle_group::batch_mul (#9563) Problem: `cycle_group` has a heavy witness generation cost. Existing code performs multiple modular inversions for every cycle_group group operation in `batch_mul` This was leading to 40% of the Prover time for `cycle_group` operations being raw witness generation. Batch inversion techniques are now employed to remove this cost. --- .../examples/join_split/join_split.test.cpp | 2 +- .../stdlib/primitives/group/cycle_group.cpp | 379 ++++++++++++++---- .../stdlib/primitives/group/cycle_group.hpp | 27 +- 3 files changed, 314 insertions(+), 94 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/examples/join_split/join_split.test.cpp b/barretenberg/cpp/src/barretenberg/examples/join_split/join_split.test.cpp index 1c516b7ae9d..b02d82664d2 100644 --- a/barretenberg/cpp/src/barretenberg/examples/join_split/join_split.test.cpp +++ b/barretenberg/cpp/src/barretenberg/examples/join_split/join_split.test.cpp @@ -703,7 +703,7 @@ TEST_F(join_split_tests, test_0_input_notes_and_detect_circuit_change) // The below part detects any changes in the join-split circuit constexpr size_t DYADIC_CIRCUIT_SIZE = 1 << 16; - constexpr uint256_t CIRCUIT_HASH("0x9ffbbd2c3ebd45cba861d3da6f75e2f73c448cc5747c9e34b44d6bc8a90b4a9c"); + constexpr uint256_t CIRCUIT_HASH("0x48687216f00a81d2a0f64f0a10cce056fce2ad13c47f8329229eb3712d3f7566"); const uint256_t circuit_hash = circuit.hash_circuit(); // circuit is finalized now diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp index ab196d6d1c2..080e5a20087 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp @@ -202,10 +202,11 @@ template cycle_group cycle_group::get_stand * @brief Evaluates a doubling. Does not use Ultra double gate * * @tparam Builder + * @param unused param is due to interface-compatibility with the UltraArithmetic version of `dbl` * @return cycle_group */ template -cycle_group cycle_group::dbl() const +cycle_group cycle_group::dbl([[maybe_unused]] const std::optional /*unused*/) const requires IsNotUltraArithmetic { auto modified_y = field_t::conditional_assign(is_point_at_infinity(), 1, y); @@ -219,35 +220,54 @@ cycle_group cycle_group::dbl() const * @brief Evaluates a doubling. Uses Ultra double gate * * @tparam Builder + * @param hint : value of output point witness, if known ahead of time (used to avoid modular inversions during witgen) * @return cycle_group */ template -cycle_group cycle_group::dbl() const +cycle_group cycle_group::dbl(const std::optional hint) const requires IsUltraArithmetic { // ensure we use a value of y that is not zero. (only happens if point at infinity) // this costs 0 gates if `is_infinity` is a circuit constant auto modified_y = field_t::conditional_assign(is_point_at_infinity(), 1, y).normalize(); - auto x1 = x.get_value(); - auto y1 = modified_y.get_value(); - - // N.B. the formula to derive the witness value for x3 mirrors the formula in elliptic_relation.hpp - // Specifically, we derive x^4 via the Short Weierstrass curve formula `y^2 = x^3 + b` - // i.e. 
x^4 = x * (y^2 - b) - // We must follow this pattern exactly to support the edge-case where the input is the point at infinity. - auto y_pow_2 = y1.sqr(); - auto x_pow_4 = x1 * (y_pow_2 - Group::curve_b); - auto lambda_squared = (x_pow_4 * 9) / (y_pow_2 * 4); - auto lambda = (x1 * x1 * 3) / (y1 + y1); - auto x3 = lambda_squared - x1 - x1; - auto y3 = lambda * (x1 - x3) - y1; - if (is_constant()) { - auto result = cycle_group(x3, y3, is_point_at_infinity().get_value()); - // We need to manually propagate the origin tag - result.set_origin_tag(get_origin_tag()); - return result; + + cycle_group result; + if (hint.has_value()) { + auto x3 = hint.value().x; + auto y3 = hint.value().y; + if (is_constant()) { + result = cycle_group(x3, y3, is_point_at_infinity()); + // We need to manually propagate the origin tag + result.set_origin_tag(get_origin_tag()); + + return result; + } + + result = cycle_group(witness_t(context, x3), witness_t(context, y3), is_point_at_infinity()); + } else { + auto x1 = x.get_value(); + auto y1 = modified_y.get_value(); + + // N.B. the formula to derive the witness value for x3 mirrors the formula in elliptic_relation.hpp + // Specifically, we derive x^4 via the Short Weierstrass curve formula `y^2 = x^3 + b` + // i.e. x^4 = x * (y^2 - b) + // We must follow this pattern exactly to support the edge-case where the input is the point at infinity. + auto y_pow_2 = y1.sqr(); + auto x_pow_4 = x1 * (y_pow_2 - Group::curve_b); + auto lambda_squared = (x_pow_4 * 9) / (y_pow_2 * 4); + auto lambda = (x1 * x1 * 3) / (y1 + y1); + auto x3 = lambda_squared - x1 - x1; + auto y3 = lambda * (x1 - x3) - y1; + if (is_constant()) { + auto result = cycle_group(x3, y3, is_point_at_infinity().get_value()); + // We need to manually propagate the origin tag + result.set_origin_tag(get_origin_tag()); + return result; + } + + result = cycle_group(witness_t(context, x3), witness_t(context, y3), is_point_at_infinity()); } - cycle_group result(witness_t(context, x3), witness_t(context, y3), is_point_at_infinity()); + context->create_ecc_dbl_gate(bb::ecc_dbl_gate_{ .x1 = x.get_witness_index(), .y1 = modified_y.normalize().get_witness_index(), @@ -272,7 +292,8 @@ cycle_group cycle_group::dbl() const * @return cycle_group */ template -cycle_group cycle_group::unconditional_add(const cycle_group& other) const +cycle_group cycle_group::unconditional_add( + const cycle_group& other, [[maybe_unused]] const std::optional /*unused*/) const requires IsNotUltraArithmetic { auto x_diff = other.x - x; @@ -294,10 +315,12 @@ cycle_group cycle_group::unconditional_add(const cycle_group& * * @tparam Builder * @param other + * @param hint : value of output point witness, if known ahead of time (used to avoid modular inversions during witgen) * @return cycle_group */ template -cycle_group cycle_group::unconditional_add(const cycle_group& other) const +cycle_group cycle_group::unconditional_add(const cycle_group& other, + const std::optional hint) const requires IsUltraArithmetic { auto context = get_context(other); @@ -308,28 +331,36 @@ cycle_group cycle_group::unconditional_add(const cycle_group& auto lhs = cycle_group::from_constant_witness(context, get_value()); // We need to manually propagate the origin tag lhs.set_origin_tag(get_origin_tag()); - return lhs.unconditional_add(other); + return lhs.unconditional_add(other, hint); } if (!lhs_constant && rhs_constant) { auto rhs = cycle_group::from_constant_witness(context, other.get_value()); // We need to manually propagate the origin tag 
rhs.set_origin_tag(other.get_origin_tag()); - return unconditional_add(rhs); + return unconditional_add(rhs, hint); } - - const auto p1 = get_value(); - const auto p2 = other.get_value(); - AffineElement p3(Element(p1) + Element(p2)); - if (lhs_constant && rhs_constant) { - auto result = cycle_group(p3); - // We need to manually propagate the origin tag - result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); - return result; + cycle_group result; + if (hint.has_value()) { + auto x3 = hint.value().x; + auto y3 = hint.value().y; + if (lhs_constant && rhs_constant) { + return cycle_group(x3, y3, false); + } + result = cycle_group(witness_t(context, x3), witness_t(context, y3), false); + } else { + const auto p1 = get_value(); + const auto p2 = other.get_value(); + AffineElement p3(Element(p1) + Element(p2)); + if (lhs_constant && rhs_constant) { + auto result = cycle_group(p3); + // We need to manually propagate the origin tag + result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); + return result; + } + field_t r_x(witness_t(context, p3.x)); + field_t r_y(witness_t(context, p3.y)); + result = cycle_group(r_x, r_y, false); } - field_t r_x(witness_t(context, p3.x)); - field_t r_y(witness_t(context, p3.y)); - cycle_group result(r_x, r_y, false); - bb::ecc_add_gate_ add_gate{ .x1 = x.get_witness_index(), .y1 = y.get_witness_index(), @@ -353,13 +384,15 @@ cycle_group cycle_group::unconditional_add(const cycle_group& * * @tparam Builder * @param other + * @param hint : value of output point witness, if known ahead of time (used to avoid modular inversions during witgen) * @return cycle_group */ template -cycle_group cycle_group::unconditional_subtract(const cycle_group& other) const +cycle_group cycle_group::unconditional_subtract(const cycle_group& other, + const std::optional hint) const { if constexpr (!IS_ULTRA) { - return unconditional_add(-other); + return unconditional_add(-other, hint); } else { auto context = get_context(other); @@ -370,7 +403,7 @@ cycle_group cycle_group::unconditional_subtract(const cycle_gr auto lhs = cycle_group::from_constant_witness(context, get_value()); // We need to manually propagate the origin tag lhs.set_origin_tag(get_origin_tag()); - return lhs.unconditional_subtract(other); + return lhs.unconditional_subtract(other, hint); } if (!lhs_constant && rhs_constant) { auto rhs = cycle_group::from_constant_witness(context, other.get_value()); @@ -378,19 +411,28 @@ cycle_group cycle_group::unconditional_subtract(const cycle_gr rhs.set_origin_tag(other.get_origin_tag()); return unconditional_subtract(rhs); } - auto p1 = get_value(); - auto p2 = other.get_value(); - AffineElement p3(Element(p1) - Element(p2)); - if (lhs_constant && rhs_constant) { - auto result = cycle_group(p3); - // We need to manually propagate the origin tag - result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); - return result; + cycle_group result; + if (hint.has_value()) { + auto x3 = hint.value().x; + auto y3 = hint.value().y; + if (lhs_constant && rhs_constant) { + return cycle_group(x3, y3, false); + } + result = cycle_group(witness_t(context, x3), witness_t(context, y3), is_point_at_infinity()); + } else { + auto p1 = get_value(); + auto p2 = other.get_value(); + AffineElement p3(Element(p1) - Element(p2)); + if (lhs_constant && rhs_constant) { + auto result = cycle_group(p3); + // We need to manually propagate the origin tag + result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); + return result; + } + 
field_t r_x(witness_t(context, p3.x)); + field_t r_y(witness_t(context, p3.y)); + result = cycle_group(r_x, r_y, false); } - field_t r_x(witness_t(context, p3.x)); - field_t r_y(witness_t(context, p3.y)); - cycle_group result(r_x, r_y, false); - bb::ecc_add_gate_ add_gate{ .x1 = x.get_witness_index(), .y1 = y.get_witness_index(), @@ -418,14 +460,16 @@ cycle_group cycle_group::unconditional_subtract(const cycle_gr * * @tparam Builder * @param other + * @param hint : value of output point witness, if known ahead of time (used to avoid modular inversions during witgen) * @return cycle_group */ template -cycle_group cycle_group::checked_unconditional_add(const cycle_group& other) const +cycle_group cycle_group::checked_unconditional_add(const cycle_group& other, + const std::optional hint) const { field_t x_delta = x - other.x; x_delta.assert_is_not_zero("cycle_group::checked_unconditional_add, x-coordinate collision"); - return unconditional_add(other); + return unconditional_add(other, hint); } /** @@ -438,14 +482,16 @@ cycle_group cycle_group::checked_unconditional_add(const cycle * * @tparam Builder * @param other + * @param hint : value of output point witness, if known ahead of time (used to avoid modular inversions during witgen) * @return cycle_group */ template -cycle_group cycle_group::checked_unconditional_subtract(const cycle_group& other) const +cycle_group cycle_group::checked_unconditional_subtract(const cycle_group& other, + const std::optional hint) const { field_t x_delta = x - other.x; x_delta.assert_is_not_zero("cycle_group::checked_unconditional_subtract, x-coordinate collision"); - return unconditional_subtract(other); + return unconditional_subtract(other, hint); } /** @@ -901,7 +947,13 @@ cycle_group::straus_scalar_slice::straus_scalar_slice(Builder* context, // convert an input cycle_scalar object into a vector of slices, each containing `table_bits` bits. 
// this also performs an implicit range check on the input slices const auto slice_scalar = [&](const field_t& scalar, const size_t num_bits) { - std::vector result; + // we record the scalar slices both as field_t circuit elements and u64 values + // (u64 values are used to index arrays and we don't want to repeatedly cast a stdlib value to a numeric + // primitive as this gets expensive when repeated enough times) + std::pair, std::vector> result; + result.first.reserve(static_cast(1ULL) << table_bits); + result.second.reserve(static_cast(1ULL) << table_bits); + if (num_bits == 0) { return result; } @@ -911,12 +963,22 @@ cycle_group::straus_scalar_slice::straus_scalar_slice(Builder* context, uint256_t raw_value = scalar.get_value(); for (size_t i = 0; i < num_slices; ++i) { uint64_t slice_v = static_cast(raw_value.data[0]) & table_mask; - result.push_back(field_t(slice_v)); + result.first.push_back(field_t(slice_v)); + result.second.push_back(slice_v); raw_value = raw_value >> table_bits; } return result; } + uint256_t raw_value = scalar.get_value(); + const uint64_t table_mask = (1ULL << table_bits) - 1ULL; + const size_t num_slices = (num_bits + table_bits - 1) / table_bits; + for (size_t i = 0; i < num_slices; ++i) { + uint64_t slice_v = static_cast(raw_value.data[0]) & table_mask; + result.second.push_back(slice_v); + raw_value = raw_value >> table_bits; + } + if constexpr (IS_ULTRA) { const auto slice_indices = context->decompose_into_default_range(scalar.normalize().get_witness_index(), @@ -924,26 +986,22 @@ cycle_group::straus_scalar_slice::straus_scalar_slice(Builder* context, table_bits, "straus_scalar_slice decompose_into_default_range"); for (auto& idx : slice_indices) { - result.emplace_back(field_t::from_witness_index(context, idx)); + result.first.emplace_back(field_t::from_witness_index(context, idx)); } } else { - uint256_t raw_value = scalar.get_value(); - const uint64_t table_mask = (1ULL << table_bits) - 1ULL; - const size_t num_slices = (num_bits + table_bits - 1) / table_bits; for (size_t i = 0; i < num_slices; ++i) { - uint64_t slice_v = static_cast(raw_value.data[0]) & table_mask; + uint64_t slice_v = result.second[i]; field_t slice(witness_t(context, slice_v)); context->create_range_constraint( slice.get_witness_index(), table_bits, "straus_scalar_slice create_range_constraint"); - result.emplace_back(slice); - raw_value = raw_value >> table_bits; + result.first.push_back(slice); } std::vector linear_elements; FF scaling_factor = 1; for (size_t i = 0; i < num_slices; ++i) { - linear_elements.emplace_back(result[i] * scaling_factor); + linear_elements.emplace_back(result.first[i] * scaling_factor); scaling_factor += scaling_factor; } field_t::accumulate(linear_elements).assert_equal(scalar); @@ -956,8 +1014,10 @@ cycle_group::straus_scalar_slice::straus_scalar_slice(Builder* context, auto hi_slices = slice_scalar(scalar.hi, hi_bits); auto lo_slices = slice_scalar(scalar.lo, lo_bits); - std::copy(lo_slices.begin(), lo_slices.end(), std::back_inserter(slices)); - std::copy(hi_slices.begin(), hi_slices.end(), std::back_inserter(slices)); + std::copy(lo_slices.first.begin(), lo_slices.first.end(), std::back_inserter(slices)); + std::copy(hi_slices.first.begin(), hi_slices.first.end(), std::back_inserter(slices)); + std::copy(lo_slices.second.begin(), lo_slices.second.end(), std::back_inserter(slices_native)); + std::copy(hi_slices.second.begin(), hi_slices.second.end(), std::back_inserter(slices_native)); const auto tag = scalar.get_origin_tag(); for (auto& element : 
slices) { // All slices need to have the same origin tag @@ -983,6 +1043,35 @@ std::optional> cycle_group::straus_scalar_slice::read( return slices[index]; } +/** + * @brief Compute the output points generated when computing the Straus lookup table + * @details When performing an MSM, we first compute all the witness values as Element types (with a Z-coordinate), + * and then we batch-convert the points into affine representation `AffineElement` + * This avoids the need to compute a modular inversion for every group operation, + * which dramatically cuts witness generation times + * + * @tparam Builder + * @param base_point + * @param offset_generator + * @param table_bits + * @return std::vector::Element> + */ +template +std::vector::Element> cycle_group< + Builder>::straus_lookup_table::compute_straus_lookup_table_hints(const Element& base_point, + const Element& offset_generator, + size_t table_bits) +{ + const size_t table_size = 1UL << table_bits; + Element base = base_point.is_point_at_infinity() ? Group::one : base_point; + std::vector hints; + hints.emplace_back(offset_generator); + for (size_t i = 1; i < table_size; ++i) { + hints.emplace_back(hints[i - 1] + base); + } + return hints; +} + /** * @brief Construct a new cycle group::straus lookup table::straus lookup table object * @@ -1001,7 +1090,8 @@ template cycle_group::straus_lookup_table::straus_lookup_table(Builder* context, const cycle_group& base_point, const cycle_group& offset_generator, - size_t table_bits) + size_t table_bits, + std::optional> hints) : _table_bits(table_bits) , _context(context) , tag(OriginTag(base_point.get_origin_tag(), offset_generator.get_origin_tag())) @@ -1022,11 +1112,41 @@ cycle_group::straus_lookup_table::straus_lookup_table(Builder* context, field_t modded_x = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.x, base_point.x); field_t modded_y = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.y, base_point.y); cycle_group modded_base_point(modded_x, modded_y, false); - for (size_t i = 1; i < table_size; ++i) { - auto add_output = point_table[i - 1].checked_unconditional_add(modded_base_point); - field_t x = field_t::conditional_assign(base_point.is_point_at_infinity(), offset_generator.x, add_output.x); - field_t y = field_t::conditional_assign(base_point.is_point_at_infinity(), offset_generator.y, add_output.y); - point_table[i] = cycle_group(x, y, false); + + // if the input point is constant, it is cheaper to fix the point as a witness and then derive the table, than it is + // to derive the table and fix its witnesses to be constant! (due to group additions = 1 gate, and fixing x/y coords + // to be constant = 2 gates) + if (modded_base_point.is_constant() && !base_point.is_point_at_infinity().get_value()) { + modded_base_point = cycle_group::from_constant_witness(_context, modded_base_point.get_value()); + point_table[0] = cycle_group::from_constant_witness(_context, offset_generator.get_value()); + for (size_t i = 1; i < table_size; ++i) { + std::optional hint = + hints.has_value() ? std::optional(hints.value()[i - 1]) : std::nullopt; + point_table[i] = point_table[i - 1].unconditional_add(modded_base_point, hint); + } + } else { + std::vector> x_coordinate_checks; + // ensure all of the ecc add gates are lined up so that we can pay 1 gate per add and not 2 + for (size_t i = 1; i < table_size; ++i) { + std::optional hint = + hints.has_value() ? 
std::optional(hints.value()[i - 1]) : std::nullopt; + x_coordinate_checks.emplace_back(point_table[i - 1].x, modded_base_point.x); + point_table[i] = point_table[i - 1].unconditional_add(modded_base_point, hint); + } + + // batch the x-coordinate checks together + // because `assert_is_not_zero` witness generation needs a modular inversion (expensive) + field_t coordinate_check_product = 1; + for (auto& [x1, x2] : x_coordinate_checks) { + auto x_diff = x2 - x1; + coordinate_check_product *= x_diff; + } + coordinate_check_product.assert_is_not_zero("straus_lookup_table x-coordinate collision"); + + for (size_t i = 1; i < table_size; ++i) { + point_table[i] = + cycle_group::conditional_assign(base_point.is_point_at_infinity(), offset_generator, point_table[i]); + } } if constexpr (IS_ULTRA) { rom_id = context->create_ROM_array(table_size); @@ -1137,16 +1257,78 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ const size_t num_points = scalars.size(); std::vector scalar_slices; + + /** + * Compute the witness values of the batch_mul algorithm natively, as Element types with a Z-coordinate. + * We then batch-convert to AffineElement types, and feed these points as "hints" into the cycle_group methods. + * This avoids the need to compute modular inversions for every group operation, which dramatically reduces witness + * generation times + */ + std::vector operation_transcript; + std::vector> native_straus_tables; + Element offset_generator_accumulator = offset_generators[0]; + { + for (size_t i = 0; i < num_points; ++i) { + std::vector native_straus_table; + native_straus_table.emplace_back(offset_generators[i + 1]); + size_t table_size = 1ULL << TABLE_BITS; + for (size_t j = 1; j < table_size; ++j) { + native_straus_table.emplace_back(native_straus_table[j - 1] + base_points[i].get_value()); + } + native_straus_tables.emplace_back(native_straus_table); + } + for (size_t i = 0; i < num_points; ++i) { + scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS)); + + auto table_transcript = straus_lookup_table::compute_straus_lookup_table_hints( + base_points[i].get_value(), offset_generators[i + 1], TABLE_BITS); + std::copy(table_transcript.begin() + 1, table_transcript.end(), std::back_inserter(operation_transcript)); + } + Element accumulator = offset_generators[0]; + + for (size_t i = 0; i < num_rounds; ++i) { + if (i != 0) { + for (size_t j = 0; j < TABLE_BITS; ++j) { + // offset_generator_accuulator is a regular Element, so dbl() won't add constraints + accumulator = accumulator.dbl(); + operation_transcript.emplace_back(accumulator); + offset_generator_accumulator = offset_generator_accumulator.dbl(); + } + } + for (size_t j = 0; j < num_points; ++j) { + + const Element point = + native_straus_tables[j][static_cast(scalar_slices[j].slices_native[num_rounds - i - 1])]; + + accumulator += point; + + operation_transcript.emplace_back(accumulator); + offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]); + } + } + } + + // Normalize the computed witness points and convert into AffineElement type + Element::batch_normalize(&operation_transcript[0], operation_transcript.size()); + + std::vector operation_hints; + operation_hints.reserve(operation_transcript.size()); + for (auto& element : operation_transcript) { + operation_hints.emplace_back(AffineElement(element.x, element.y)); + } + std::vector point_tables; + const size_t hints_per_table = (1ULL << TABLE_BITS) - 1; OriginTag tag{}; for (size_t i = 0; i < 
num_points; ++i) { + std::span table_hints(&operation_hints[i * hints_per_table], hints_per_table); // Merge tags tag = OriginTag(tag, scalars[i].get_origin_tag(), base_points[i].get_origin_tag()); scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS)); point_tables.emplace_back(straus_lookup_table(context, base_points[i], offset_generators[i + 1], TABLE_BITS)); } - Element offset_generator_accumulator = offset_generators[0]; + AffineElement* hint_ptr = &operation_hints[num_points * hints_per_table]; cycle_group accumulator = offset_generators[0]; // populate the set of points we are going to add into our accumulator, *before* we do any ECC operations @@ -1165,36 +1347,42 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ } } } + std::vector> x_coordinate_checks; size_t point_counter = 0; for (size_t i = 0; i < num_rounds; ++i) { if (i != 0) { for (size_t j = 0; j < TABLE_BITS; ++j) { - // offset_generator_accuulator is a regular Element, so dbl() won't add constraints - accumulator = accumulator.dbl(); - offset_generator_accumulator = offset_generator_accumulator.dbl(); + accumulator = accumulator.dbl(*hint_ptr); + hint_ptr++; } } for (size_t j = 0; j < num_points; ++j) { const std::optional scalar_slice = scalar_slices[j].read(num_rounds - i - 1); - // if we are doing a batch mul over scalars of different bit-lengths, we may not have a bit slice for a - // given round and a given scalar + // if we are doing a batch mul over scalars of different bit-lengths, we may not have a bit slice + // for a given round and a given scalar + ASSERT(scalar_slice.value().get_value() == scalar_slices[j].slices_native[num_rounds - i - 1]); if (scalar_slice.has_value()) { const auto& point = points_to_add[point_counter++]; if (!unconditional_add) { x_coordinate_checks.push_back({ accumulator.x, point.x }); } - accumulator = accumulator.unconditional_add(point); - offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]); + accumulator = accumulator.unconditional_add(point, *hint_ptr); + hint_ptr++; } } } + // validate that none of the x-coordinate differences are zero + // we batch the x-coordinate checks together + // because `assert_is_not_zero` witness generation needs a modular inversion (expensive) + field_t coordinate_check_product = 1; for (auto& [x1, x2] : x_coordinate_checks) { auto x_diff = x2 - x1; - x_diff.assert_is_not_zero("_variable_base_batch_mul_internal x-coordinate collision"); + coordinate_check_product *= x_diff; } + coordinate_check_product.assert_is_not_zero("_variable_base_batch_mul_internal x-coordinate collision"); // Set the final accumulator's tag to the union of all points' and scalars' tags accumulator.set_origin_tag(tag); @@ -1268,12 +1456,33 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ ASSERT(offset_1.has_value()); offset_generator_accumulator += offset_1.value(); } + /** + * Compute the witness values of the batch_mul algorithm natively, as Element types with a Z-coordinate. + * We then batch-convert to AffineElement types, and feed these points as "hints" into the cycle_group methods. 
+ * This avoids the need to compute modular inversions for every group operation, which dramatically reduces witness + * generation times + */ + std::vector operation_transcript; + { + Element accumulator = lookup_points[0].get_value(); + for (size_t i = 1; i < lookup_points.size(); ++i) { + accumulator = accumulator + (lookup_points[i].get_value()); + operation_transcript.emplace_back(accumulator); + } + } + Element::batch_normalize(&operation_transcript[0], operation_transcript.size()); + std::vector operation_hints; + operation_hints.reserve(operation_transcript.size()); + for (auto& element : operation_transcript) { + operation_hints.emplace_back(AffineElement(element.x, element.y)); + } + cycle_group accumulator = lookup_points[0]; // Perform all point additions sequentially. The Ultra ecc_addition relation costs 1 gate iff additions are chained // and output point of previous addition = input point of current addition. // If this condition is not met, the addition relation costs 2 gates. So it's good to do these sequentially! for (size_t i = 1; i < lookup_points.size(); ++i) { - accumulator = accumulator.unconditional_add(lookup_points[i]); + accumulator = accumulator.unconditional_add(lookup_points[i], operation_hints[i - 1]); } /** * offset_generator_accumulator represents the sum of all the offset generator terms present in `accumulator`. diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp index cee097dfa4a..ab008de22a3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.hpp @@ -136,6 +136,7 @@ template class cycle_group { std::optional read(size_t index); size_t _table_bits; std::vector slices; + std::vector slices_native; }; /** @@ -165,11 +166,16 @@ template class cycle_group { */ struct straus_lookup_table { public: + static std::vector compute_straus_lookup_table_hints(const Element& base_point, + const Element& offset_generator, + size_t table_bits); + straus_lookup_table() = default; straus_lookup_table(Builder* context, const cycle_group& base_point, const cycle_group& offset_generator, - size_t table_bits); + size_t table_bits, + std::optional> hints = std::nullopt); cycle_group read(const field_t& index); size_t _table_bits; Builder* _context; @@ -204,17 +210,22 @@ template class cycle_group { void set_point_at_infinity(const bool_t& is_infinity) { _is_infinity = is_infinity; } cycle_group get_standard_form() const; void validate_is_on_curve() const; - cycle_group dbl() const + cycle_group dbl(const std::optional hint = std::nullopt) const requires IsUltraArithmetic; - cycle_group dbl() const + cycle_group dbl(const std::optional hint = std::nullopt) const requires IsNotUltraArithmetic; - cycle_group unconditional_add(const cycle_group& other) const + cycle_group unconditional_add(const cycle_group& other, + const std::optional hint = std::nullopt) const requires IsUltraArithmetic; - cycle_group unconditional_add(const cycle_group& other) const + cycle_group unconditional_add(const cycle_group& other, + const std::optional hint = std::nullopt) const requires IsNotUltraArithmetic; - cycle_group unconditional_subtract(const cycle_group& other) const; - cycle_group checked_unconditional_add(const cycle_group& other) const; - cycle_group checked_unconditional_subtract(const cycle_group& other) const; + cycle_group unconditional_subtract(const cycle_group& 
other,
+                                        const std::optional<AffineElement> hint = std::nullopt) const;
+    cycle_group checked_unconditional_add(const cycle_group& other,
+                                          const std::optional<AffineElement> hint = std::nullopt) const;
+    cycle_group checked_unconditional_subtract(const cycle_group& other,
+                                               const std::optional<AffineElement> hint = std::nullopt) const;
     cycle_group operator+(const cycle_group& other) const;
     cycle_group operator-(const cycle_group& other) const;
     cycle_group operator-() const;
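
Note on the witness-generation technique used in this patch: intermediate points are accumulated natively as projective Elements, converted to AffineElements in a single call to `Element::batch_normalize`, and fed back into the circuit methods as hints; the x-coordinate collision checks are likewise batched by multiplying the differences together so that `assert_is_not_zero` (whose witness generation needs a modular inversion) is paid once rather than once per addition. Both optimizations come down to trading many field inversions for one. The sketch below is illustrative only and is not the barretenberg API: it assumes a hypothetical field type `Field` exposing `operator*` and an `invert()` member, and that every input element is non-zero.

// Minimal Montgomery batch-inversion sketch (illustrative, not the barretenberg API).
// Replaces n modular inversions with 1 inversion plus O(n) multiplications.
#include <cstddef>
#include <vector>

template <typename Field> void batch_invert(std::vector<Field>& elements)
{
    const size_t n = elements.size();
    if (n == 0) {
        return;
    }
    // prefix[i] = a_0 * a_1 * ... * a_i
    std::vector<Field> prefix(n);
    prefix[0] = elements[0];
    for (size_t i = 1; i < n; ++i) {
        prefix[i] = prefix[i - 1] * elements[i];
    }
    // the single modular inversion: (a_0 * ... * a_{n-1})^{-1}
    Field running_inverse = prefix[n - 1].invert();
    // peel off individual inverses in reverse order:
    // a_i^{-1} = (a_0 ... a_{i-1}) * (a_0 ... a_i)^{-1}
    for (size_t i = n - 1; i > 0; --i) {
        const Field inverse_i = running_inverse * prefix[i - 1];
        running_inverse = running_inverse * elements[i]; // now equals (a_0 ... a_{i-1})^{-1}
        elements[i] = inverse_i;
    }
    elements[0] = running_inverse;
}

Normalizing a transcript of projective points follows the same pattern: batch-invert the Z terms and scale each coordinate, which is effectively what `Element::batch_normalize` provides for the hint computation above.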