From 9f6b4efdbcf8a5ffed31013aae46ca5caab80948 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Mon, 29 Apr 2024 12:41:29 +0000 Subject: [PATCH 01/24] initial commit. biggroup objects track whether they are points at infinity, and have +/- methods that correctly handle points at infinity --- .../stdlib/primitives/bigfield/bigfield.hpp | 6 + .../primitives/bigfield/bigfield.test.cpp | 44 +++++ .../primitives/bigfield/bigfield_impl.hpp | 52 ++++++ .../stdlib/primitives/biggroup/biggroup.hpp | 74 ++++++--- .../primitives/biggroup/biggroup.test.cpp | 94 ++++++++++- .../biggroup/biggroup_batch_mul.hpp | 41 ++++- .../primitives/biggroup/biggroup_bn254.hpp | 32 ++-- .../primitives/biggroup/biggroup_goblin.hpp | 1 + .../biggroup/biggroup_goblin.test.cpp | 4 +- .../primitives/biggroup/biggroup_impl.hpp | 154 ++++++++++++++++-- .../primitives/biggroup/biggroup_nafs.hpp | 15 +- .../biggroup/biggroup_secp256k1.hpp | 7 +- .../primitives/biggroup/biggroup_tables.hpp | 108 ++++++------ .../stdlib/primitives/curves/secp256r1.hpp | 10 +- .../primitives/databus/databus.test.cpp | 2 + 15 files changed, 517 insertions(+), 127 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp index 2fc3572cec3..7643afe8ad6 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp @@ -241,6 +241,12 @@ template class bigfield { bigfield conditional_negate(const bool_t& predicate) const; bigfield conditional_select(const bigfield& other, const bool_t& predicate) const; + static bigfield conditional_assign(const bool_t& predicate, const bigfield& lhs, const bigfield& rhs) + { + return rhs.conditional_select(lhs, predicate); + } + + bool_t operator==(const bigfield& other) const; void assert_is_in_field() const; void assert_less_than(const uint256_t upper_limit) const; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp index 3aa7f6090ce..8ec46f817de 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp @@ -841,6 +841,45 @@ template class stdlib_bigfield : public testing::Test { fq_ct ret = fq_ct::div_check_denominator_nonzero({}, a_ct); EXPECT_NE(ret.get_context(), nullptr); } + + static void test_assert_equal_not_equal() + { + auto builder = Builder(); + size_t num_repetitions = 10; + for (size_t i = 0; i < num_repetitions; ++i) { + fq inputs[4]{ fq::random_element(), fq::random_element(), fq::random_element(), fq::random_element() }; + + fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct c(witness_ct(&builder, fr(uint256_t(inputs[2]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[2]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + + fq_ct two(witness_ct(&builder, fr(2)), + witness_ct(&builder, fr(0)), + witness_ct(&builder, fr(0)), + witness_ct(&builder, fr(0))); + fq_ct t0 = a + a; + fq_ct t1 = a * two; + + t0.assert_equal(t1); + t0.assert_is_not_equal(c); + t0.assert_is_not_equal(d); + stdlib::bool_t is_equal_a = t0 == t1; + stdlib::bool_t is_equal_b = t0 == c; + EXPECT_TRUE(is_equal_a.get_value()); + EXPECT_FALSE(is_equal_b.get_value()); + } + bool result = CircuitChecker::check(builder); + EXPECT_EQ(result, true); + } }; // Define types for which the above tests will be constructed. @@ -930,6 +969,11 @@ TYPED_TEST(stdlib_bigfield, division_context) TestFixture::test_division_context(); } +TYPED_TEST(stdlib_bigfield, assert_equal_not_equal) +{ + TestFixture::test_assert_equal_not_equal(); +} + // // This test was disabled before the refactor to use TYPED_TEST's/ // TEST(stdlib_bigfield, DISABLED_test_div_against_constants) // { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp index 3e6fc79a994..f8773225ad7 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp @@ -1562,6 +1562,57 @@ bigfield bigfield::conditional_select(const bigfield& ot return result; } +/** + * @brief Validate whether two bigfield elements are equal to each other + * @details To evaluate whether `(a == b)`, we use result boolean `r` to evaluate the following logic: + * (n.b all algebra involving bigfield elements is done in the bigfield) + * 1. If `r == 1` , `a - b == 0` + * 2. If `r == 0`, `a - b` posesses an inverse `I` i.e. `(a - b) * I - 1 == 0` + * We efficiently evaluate this logic by evaluating a single expression `(a - b)*X = Y` + * We use conditional assignment logic to define `X, Y` to be the following: + * If `r == 1` then `X = 1, Y = 0` + * If `r == 0` then `X = I, Y = 1` + * This allows us to evaluate `operator==` using only 1 bigfield multiplication operation. + * We can check the product equals 0 or 1 by directly evaluating the binary basis/prime basis limbs of Y. + * i.e. if `r == 1` then `(a - b)*X` should have 0 for all limb values + * if `r == 0` then `(a - b)*X` should have 1 in the least significant binary basis limb and 0 elsewhere + * @tparam Builder + * @tparam T + * @param other + * @return bool_t + */ +template bool_t bigfield::operator==(const bigfield& other) const +{ + Builder* ctx = context ? context : other.get_context(); + auto lhs = get_value() % modulus_u512; + auto rhs = other.get_value() % modulus_u512; + bool is_equal_raw = (lhs == rhs); + bool_t is_equal = witness_t(ctx, is_equal_raw); + + bigfield diff = (*this) - other; + + // TODO: get native values efficiently (i.e. if u512 value fits in a u256, subtract off modulus until u256 fits + // into finite field) + native diff_native = native((diff.get_value() % modulus_u512).lo); + native inverse_native = is_equal_raw ? 0 : diff_native.invert(); + + bigfield inverse = bigfield::from_witness(ctx, inverse_native); + + bigfield multiplicand = bigfield::conditional_assign(is_equal, one(), inverse); + + bigfield product = diff * multiplicand; + + field_t result = field_t::conditional_assign(is_equal, 0, 1); + + product.prime_basis_limb.assert_equal(result); + product.binary_basis_limbs[0].element.assert_equal(result); + product.binary_basis_limbs[1].element.assert_equal(0); + product.binary_basis_limbs[2].element.assert_equal(0); + product.binary_basis_limbs[3].element.assert_equal(0); + + return is_equal; +} + /** * REDUCTION CHECK * @@ -1747,6 +1798,7 @@ template void bigfield::assert_equal( std::cerr << "bigfield: calling assert equal on 2 CONSTANT bigfield elements...is this intended?" << std::endl; return; } else if (other.is_constant()) { + // TODO: wtf? // evaluate a strict equality - make sure *this is reduced first, or an honest prover // might not be able to satisfy these constraints. field_t t0 = (binary_basis_limbs[0].element - other.binary_basis_limbs[0].element); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp index 51cdb25c790..4cbe262e5d9 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp @@ -21,6 +21,8 @@ namespace bb::stdlib { // ( ͡° ͜ʖ ͡°) template class element { public: + using bool_t = stdlib::bool_t; + struct secp256k1_wnaf { std::vector> wnaf; field_t positive_skew; @@ -38,13 +40,23 @@ template class element { element(const Fq& x, const Fq& y); element(const element& other); - element(element&& other); + element(element&& other) noexcept; static element from_witness(Builder* ctx, const typename NativeGroup::affine_element& input) { - Fq x = Fq::from_witness(ctx, input.x); - Fq y = Fq::from_witness(ctx, input.y); - element out(x, y); + element out; + if (input.is_point_at_infinity()) { + Fq x = Fq::from_witness(ctx, NativeGroup::affine_one.x); + Fq y = Fq::from_witness(ctx, NativeGroup::affine_one.y); + out.x = x; + out.y = y; + } else { + Fq x = Fq::from_witness(ctx, input.x); + Fq y = Fq::from_witness(ctx, input.y); + out.x = x; + out.y = y; + } + out.set_point_at_infinity(witness_t(ctx, input.is_point_at_infinity())); out.validate_on_curve(); return out; } @@ -52,13 +64,17 @@ template class element { void validate_on_curve() const { Fq b(get_context(), uint256_t(NativeGroup::curve_b)); + Fq _b = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), b); + Fq _x = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), x); + Fq _y = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), y); if constexpr (!NativeGroup::has_a) { // we validate y^2 = x^3 + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ x.sqr(), y }, { x, -y }, { b }, true); + Fq::mult_madd({ _x.sqr(), _y }, { _x, -_y }, { _b }, true); } else { Fq a(get_context(), uint256_t(NativeGroup::curve_a)); + Fq _a = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), a); // we validate y^2 = x^3 + ax + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ x.sqr(), x, y }, { x, a, -y }, { b }, true); + Fq::mult_madd({ _x.sqr(), _x, _y }, { _x, _a, -_y }, { _b }, true); } } @@ -72,7 +88,7 @@ template class element { } element& operator=(const element& other); - element& operator=(element&& other); + element& operator=(element&& other) noexcept; byte_array to_byte_array() const { @@ -82,6 +98,9 @@ template class element { return result; } + element checked_unconditional_add(const element& other) const; + element checked_unconditional_subtract(const element& other) const; + element operator+(const element& other) const; element operator-(const element& other) const; element operator-() const @@ -100,11 +119,11 @@ template class element { *this = *this - other; return *this; } - std::array add_sub(const element& other) const; + std::array checked_unconditional_add_sub(const element& other) const; element operator*(const Fr& other) const; - element conditional_negate(const bool_t& predicate) const + element conditional_negate(const bool_t& predicate) const { element result(*this); result.y = result.y.conditional_negate(predicate); @@ -176,9 +195,13 @@ template class element { typename NativeGroup::affine_element get_value() const { - uint512_t x_val = x.get_value(); - uint512_t y_val = y.get_value(); - return typename NativeGroup::affine_element(x_val.lo, y_val.lo); + uint512_t x_val = x.get_value() % Fq::modulus_u512; + uint512_t y_val = y.get_value() % Fq::modulus_u512; + auto result = typename NativeGroup::affine_element(x_val.lo, y_val.lo); + if (is_point_at_infinity().get_value()) { + result.self_set_infinity(); + } + return result; } // compute a multi-scalar-multiplication by creating a precomputed lookup table for each point, @@ -229,7 +252,7 @@ template class element { template ::value>> static element secp256k1_ecdsa_mul(const element& pubkey, const Fr& u1, const Fr& u2); - static std::vector> compute_naf(const Fr& scalar, const size_t max_num_bits = 0); + static std::vector compute_naf(const Fr& scalar, const size_t max_num_bits = 0); template static std::vector> compute_wnaf(const Fr& scalar); @@ -265,10 +288,15 @@ template class element { return nullptr; } + bool_t is_point_at_infinity() const { return _is_infinity; } + void set_point_at_infinity(const bool_t& is_infinity) { _is_infinity = is_infinity; } + Fq x; Fq y; private: + bool_t _is_infinity; + template >> static std::array, 5> create_group_element_rom_tables( const std::array& elements, std::array& limb_max); @@ -367,7 +395,7 @@ template class element { lookup_table_base(const lookup_table_base& other) = default; lookup_table_base& operator=(const lookup_table_base& other) = default; - element get(const std::array, length>& bits) const; + element get(const std::array& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -397,7 +425,7 @@ template class element { lookup_table_plookup(const lookup_table_plookup& other) = default; lookup_table_plookup& operator=(const lookup_table_plookup& other) = default; - element get(const std::array, length>& bits) const; + element get(const std::array& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -608,7 +636,7 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -660,7 +688,7 @@ template class element { return (accumulator); } - element get(std::vector>& naf_entries) const + element get(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -812,21 +840,21 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { - round_accumulator.push_back(quad_tables[j].get(std::array, 4>{ + round_accumulator.push_back(quad_tables[j].get(std::array{ naf_entries[4 * j], naf_entries[4 * j + 1], naf_entries[4 * j + 2], naf_entries[4 * j + 3] })); } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ + round_accumulator.push_back(triple_tables[0].get(std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { round_accumulator.push_back(twin_tables[0].get( - std::array, 2>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); + std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); } if (has_singleton) { round_accumulator.push_back(singletons[0].conditional_negate(naf_entries[num_points - 1])); @@ -849,7 +877,7 @@ template class element { return (accumulator); } - element get(std::vector>& naf_entries) const + element get(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { @@ -858,7 +886,7 @@ template class element { } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ + round_accumulator.push_back(triple_tables[0].get(std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp index 44201423b28..a8de2df775b 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/stdlib/primitives/curves/secp256k1.hpp" #include "barretenberg/stdlib/primitives/curves/secp256r1.hpp" +using namespace bb; + namespace { auto& engine = numeric::get_debug_randomness(); } -using namespace bb; - // One can only define a TYPED_TEST with a single template paramter. // Our workaround is to pass parameters of the following type. template struct TestType { @@ -41,6 +41,8 @@ template class stdlib_biggroup : public testing::Test { using element = typename g1::element; using Builder = typename Curve::Builder; + using witness_ct = stdlib::witness_t; + using bool_ct = stdlib::bool_t; static constexpr auto EXPECT_CIRCUIT_CORRECTNESS = [](Builder& builder, bool expected_result = true) { info("num gates = ", builder.get_num_gates()); @@ -82,6 +84,45 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_add_points_at_infnity() + { + Builder builder; + size_t num_repetitions = 1; + for (size_t i = 0; i < num_repetitions; ++i) { + affine_element input_a(element::random_element()); + affine_element input_b(element::random_element()); + input_b.self_set_infinity(); + element_ct a = element_ct::from_witness(&builder, input_a); + // create copy of a with different witness + element_ct a_alternate = element_ct::from_witness(&builder, input_a); + element_ct a_negated = element_ct::from_witness(&builder, -input_a); + element_ct b = element_ct::from_witness(&builder, input_b); + + element_ct c = a + b; + element_ct d = b + a; + element_ct e = b + b; + element_ct f = a + a; + element_ct g = a + a_alternate; + element_ct h = a + a_negated; + + affine_element c_expected = affine_element(element(input_a) + element(input_b)); + affine_element d_expected = affine_element(element(input_b) + element(input_a)); + affine_element e_expected = affine_element(element(input_b) + element(input_b)); + affine_element f_expected = affine_element(element(input_a) + element(input_a)); + affine_element g_expected = affine_element(element(input_a) + element(input_a)); + affine_element h_expected = affine_element(element(input_a) + element(-input_a)); + + EXPECT_EQ(c.get_value(), c_expected); + EXPECT_EQ(d.get_value(), d_expected); + EXPECT_EQ(e.get_value(), e_expected); + EXPECT_EQ(f.get_value(), f_expected); + EXPECT_EQ(g.get_value(), g_expected); + EXPECT_EQ(h.get_value(), h_expected); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + static void test_sub() { Builder builder; @@ -110,6 +151,45 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_sub_points_at_infnity() + { + Builder builder; + size_t num_repetitions = 1; + for (size_t i = 0; i < num_repetitions; ++i) { + affine_element input_a(element::random_element()); + affine_element input_b(element::random_element()); + input_b.self_set_infinity(); + element_ct a = element_ct::from_witness(&builder, input_a); + // create copy of a with different witness + element_ct a_alternate = element_ct::from_witness(&builder, input_a); + element_ct a_negated = element_ct::from_witness(&builder, -input_a); + element_ct b = element_ct::from_witness(&builder, input_b); + + element_ct c = a - b; + element_ct d = b - a; + element_ct e = b - b; + element_ct f = a - a; + element_ct g = a - a_alternate; + element_ct h = a - a_negated; + + affine_element c_expected = affine_element(element(input_a) - element(input_b)); + affine_element d_expected = affine_element(element(input_b) - element(input_a)); + affine_element e_expected = affine_element(element(input_b) - element(input_b)); + affine_element f_expected = affine_element(element(input_a) - element(input_a)); + affine_element g_expected = affine_element(element(input_a) - element(input_a)); + affine_element h_expected = affine_element(element(input_a) - element(-input_a)); + + EXPECT_EQ(c.get_value(), c_expected); + EXPECT_EQ(d.get_value(), d_expected); + EXPECT_EQ(e.get_value(), e_expected); + EXPECT_EQ(f.get_value(), f_expected); + EXPECT_EQ(g.get_value(), g_expected); + EXPECT_EQ(h.get_value(), h_expected); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + static void test_dbl() { Builder builder; @@ -833,10 +913,20 @@ TYPED_TEST(stdlib_biggroup, add) TestFixture::test_add(); } +TYPED_TEST(stdlib_biggroup, add_points_at_infinity) +{ + + TestFixture::test_add_points_at_infnity(); +} TYPED_TEST(stdlib_biggroup, sub) { TestFixture::test_sub(); } +TYPED_TEST(stdlib_biggroup, sub_points_at_infinity) +{ + + TestFixture::test_sub_points_at_infnity(); +} TYPED_TEST(stdlib_biggroup, dbl) { TestFixture::test_dbl(); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp index a10198286c3..004538a3e5d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp @@ -1,21 +1,50 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" +#include namespace bb::stdlib { /** * only works for Plookup (otherwise falls back on batch_mul)! Multiscalar multiplication that utilizes 4-bit wNAF * lookup tables is more efficient than points-as-linear-combinations lookup tables, if the number of points is 3 or * fewer + * TODO: when we nuke standard and turbo plonk we should remove the fallback batch mul method! */ template template -element element::wnaf_batch_mul(const std::vector& points, - const std::vector& scalars) +element element::wnaf_batch_mul(const std::vector& _points, + const std::vector& _scalars) { constexpr size_t WNAF_SIZE = 4; - ASSERT(points.size() == scalars.size()); + ASSERT(_points.size() == _scalars.size()); if constexpr (!HasPlookup) { - return batch_mul(points, scalars, max_num_bits); + return batch_mul(_points, _scalars, max_num_bits); + } + + // treat inputs for points at infinity. + // if a base point is at infinity, we substitute for element::one, and set the scalar multiplier to 0 + // this (partially) ensures the mul algorithm does not need to account for points at infinity + std::vector points; + std::vector scalars; + element one = element::one(nullptr); + for (size_t i = 0; i < points.size(); ++i) { + bool_t is_point_at_infinity = points[i].is_point_at_infinity(); + if (is_point_at_infinity.get_value() && static_cast(is_point_at_infinity.is_constant())) { + // if point is at infinity and a circuit constant we can just skip. + continue; + } + if (_scalars[i].get_value() == 0 && _scalars[i].is_constant()) { + // if scalar multiplier is 0 and also a constant, we can skip + continue; + } + element point(_points[i]); + point.x = Fq::conditional_assign(is_point_at_infinity, one.x, point.x); + point.y = Fq::conditional_assign(is_point_at_infinity, one.y, point.y); + Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalars[i]); + points.push_back(point); + scalars.push_back(scalar); + + // TODO: if both point and scalar are constant, don't bother adding constraints } std::vector> point_tables; @@ -49,8 +78,8 @@ element element::wnaf_batch_mul(const std::vector(wnaf_entries[i][num_rounds])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(wnaf_entries[i][num_rounds])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(wnaf_entries[i][num_rounds])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(wnaf_entries[i][num_rounds])); accumulator = element(out_x, out_y); } accumulator -= offset_generators.second; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp index 5e03f8a58da..0836b29bc87 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp @@ -7,6 +7,8 @@ * We use a special case algorithm to split bn254 scalar multipliers into endomorphism scalars * **/ +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" +#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp" namespace bb::stdlib { /** @@ -18,6 +20,7 @@ namespace bb::stdlib { * `small_scalars/small_points` : 128-bit scalar multipliers * `generator_scalar` : a 254-bit scalar multiplier over the bn254 generator point * + * TODO: this is plonk only. kill method when we deprecate standard/turbo plonk **/ template template @@ -54,9 +57,9 @@ element element::bn254_endo_batch_mul_with_generator auto& big_table = big_table_pair.first; auto& endo_table = big_table_pair.second; batch_lookup_table small_table(small_points); - std::vector>> big_naf_entries; - std::vector>> endo_naf_entries; - std::vector>> small_naf_entries; + std::vector> big_naf_entries; + std::vector> endo_naf_entries; + std::vector> small_naf_entries; const auto split_into_endomorphism_scalars = [ctx](const Fr& scalar) { bb::fr k = scalar.get_value(); @@ -99,9 +102,9 @@ element element::bn254_endo_batch_mul_with_generator element accumulator = element::chain_add_end(init_point); const auto get_point_to_add = [&](size_t naf_index) { - std::vector> small_nafs; - std::vector> big_nafs; - std::vector> endo_nafs; + std::vector small_nafs; + std::vector big_nafs; + std::vector endo_nafs; for (size_t i = 0; i < small_points.size(); ++i) { small_nafs.emplace_back(small_naf_entries[i][naf_index]); } @@ -178,16 +181,14 @@ element element::bn254_endo_batch_mul_with_generator } { element skew = accumulator - generator_table[128]; - Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_wnaf[generator_wnaf.size() - 1])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } { element skew = accumulator - generator_endo_table[128]; - Fq out_x = - accumulator.x.conditional_select(skew.x, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); - Fq out_y = - accumulator.y.conditional_select(skew.y, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } @@ -213,6 +214,7 @@ element element::bn254_endo_batch_mul_with_generator * max_num_small_bits : MINIMUM value must be 128 bits * (we will be splitting `big_scalars` into two 128-bit scalars, we assume all scalars after this transformation are 128 *bits) + * TODO: this does not seem to be used anywhere except turbo plonk. delete once we deprecate turbo? **/ template template @@ -320,7 +322,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ const size_t num_rounds = max_num_small_bits; const size_t num_points = points.size(); - std::vector>> naf_entries; + std::vector> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_small_bits)); } @@ -354,7 +356,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ for (size_t i = 1; i < num_rounds / 2; ++i) { // `nafs` tracks the naf value for each point for the current round - std::vector> nafs; + std::vector nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][i * 2 - 1]); } @@ -383,7 +385,7 @@ element element::bn254_endo_batch_mul(const std::vec // we need to iterate 1 more time if the number of rounds is even if ((num_rounds & 0x01ULL) == 0x00ULL) { - std::vector> nafs; + std::vector nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][num_rounds - 1]); } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp index 62404fc055e..15d8a16c372 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp @@ -1,5 +1,6 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { /** diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp index 1ac09c4e69d..6e6e38d9358 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/numeric/random/engine.hpp" #include +using namespace bb; + namespace { auto& engine = numeric::get_debug_randomness(); } -using namespace bb; - template class stdlib_biggroup_goblin : public testing::Test { using element_ct = typename Curve::Element; using scalar_ct = typename Curve::ScalarField; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp index 35b1c477d72..d446cfa06a3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp @@ -2,8 +2,7 @@ #include "../bit_array/bit_array.hpp" #include "../circuit_builders/circuit_builders.hpp" - -using namespace bb; +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { @@ -11,50 +10,181 @@ template element::element() : x() , y() + , _is_infinity() {} template element::element(const typename G::affine_element& input) : x(nullptr, input.x) , y(nullptr, input.y) + , _is_infinity(nullptr, input.is_point_at_infinity()) {} template element::element(const Fq& x_in, const Fq& y_in) : x(x_in) , y(y_in) + , _is_infinity(x.get_context() ? x.get_context() : y.get_context(), false) {} template element::element(const element& other) : x(other.x) , y(other.y) + , _is_infinity(other.is_point_at_infinity()) {} template -element::element(element&& other) +element::element(element&& other) noexcept : x(other.x) , y(other.y) + , _is_infinity(other.is_point_at_infinity()) {} template element& element::operator=(const element& other) { + if (&other == this) { + return *this; + } x = other.x; y = other.y; + _is_infinity = other.is_point_at_infinity(); return *this; } template -element& element::operator=(element&& other) +element& element::operator=(element&& other) noexcept { + if (&other == this) { + return *this; + } x = other.x; y = other.y; + _is_infinity = other.is_point_at_infinity(); return *this; } template element element::operator+(const element& other) const +{ + // return checked_unconditional_add(other); + if constexpr (IsGoblinBuilder && std::same_as) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize + // Current gate count: 6398 + std::vector points{ *this, other }; + std::vector scalars{ 1, 1 }; + return goblin_batch_mul(points, scalars); + } + + // if x_coordinates match, lambda triggers a divide by zero error. + // Adding in `x_coordinates_match` ensures that lambda will always be well-formed + const bool_t x_coordinates_match = other.x == x; + const bool_t y_coordinates_match = (y == other.y); + const bool_t infinity_predicate = (x_coordinates_match && !y_coordinates_match); + const bool_t double_predicate = (x_coordinates_match && y_coordinates_match); + const bool_t lhs_infinity = is_point_at_infinity(); + const bool_t rhs_infinity = other.is_point_at_infinity(); + + // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 + const Fq add_lambda_numerator = other.y - y; + const Fq xx = x * x; + const Fq dbl_lambda_numerator = xx + xx + xx; + const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); + + const Fq add_lambda_denominator = other.x - x; + const Fq dbl_lambda_denominator = y + y; + Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); + // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a + // divide by zero error. + // (if either inputs are points at infinity we will not use the result of this computation) + Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); + lambda_denominator = Fq::conditional_assign( + lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); + const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); + + const Fq x3 = lambda.sqradd({ -other.x, -x }); + const Fq y3 = lambda.madd(x - x3, { -y }); + + element result(x3, y3); + // if lhs infinity, return rhs + result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); + result.y = Fq::conditional_assign(lhs_infinity, other.y, result.y); + // if rhs infinity, return lhs + result.x = Fq::conditional_assign(rhs_infinity, x, result.x); + result.y = Fq::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. can likely optimize this + bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + return result; +} + +template +element element::operator-(const element& other) const +{ + // return checked_unconditional_add(other); + if constexpr (IsGoblinBuilder && std::same_as) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize + // Current gate count: 6398 + std::vector points{ *this, other }; + std::vector scalars{ 1, -Fr(1) }; + return goblin_batch_mul(points, scalars); + } + + // if x_coordinates match, lambda triggers a divide by zero error. + // Adding in `x_coordinates_match` ensures that lambda will always be well-formed + const bool_t x_coordinates_match = other.x == x; + const bool_t y_coordinates_match = (y == other.y); + const bool_t infinity_predicate = (x_coordinates_match && y_coordinates_match); + const bool_t double_predicate = (x_coordinates_match && !y_coordinates_match); + const bool_t lhs_infinity = is_point_at_infinity(); + const bool_t rhs_infinity = other.is_point_at_infinity(); + + // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 + const Fq add_lambda_numerator = -other.y - y; + const Fq xx = x * x; + const Fq dbl_lambda_numerator = xx + xx + xx; + const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); + + const Fq add_lambda_denominator = other.x - x; + const Fq dbl_lambda_denominator = y + y; + Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); + // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a + // divide by zero error. + // (if either inputs are points at infinity we will not use the result of this computation) + Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); + lambda_denominator = Fq::conditional_assign( + lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); + const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); + + const Fq x3 = lambda.sqradd({ -other.x, -x }); + const Fq y3 = lambda.madd(x - x3, { -y }); + + element result(x3, y3); + // if lhs infinity, return rhs + result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); + result.y = Fq::conditional_assign(lhs_infinity, -other.y, result.y); + // if rhs infinity, return lhs + result.x = Fq::conditional_assign(rhs_infinity, x, result.x); + result.y = Fq::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. can likely optimize this + bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + return result; +} + +template +element element::checked_unconditional_add(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -72,7 +202,7 @@ element element::operator+(const element& other) con } template -element element::operator-(const element& other) const +element element::checked_unconditional_subtract(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -105,7 +235,7 @@ element element::operator-(const element& other) con */ // TODO(https://github.com/AztecProtocol/barretenberg/issues/657): This function is untested template -std::array, 2> element::add_sub(const element& other) const +std::array, 2> element::checked_unconditional_add_sub(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { return { *this + other, *this - other }; @@ -140,7 +270,9 @@ template element element Fq neg_lambda = Fq::msub_div({ x }, { (two_x + x) }, (y + y), {}); Fq x_3 = neg_lambda.sqradd({ -(two_x) }); Fq y_3 = neg_lambda.madd(x_3 - x, { -y }); - return element(x_3, y_3); + element result = element(x_3, y_3); + result.set_point_at_infinity(is_point_at_infinity()); + return result; } /** @@ -631,7 +763,7 @@ element element::batch_mul(const std::vector>> naf_entries; + std::vector> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_bits)); } @@ -646,7 +778,7 @@ element element::batch_mul(const std::vector> nafs(num_points); + std::vector nafs(num_points); std::vector to_add; const size_t inner_num_rounds = (i != num_iterations - 1) ? num_rounds_per_iteration : num_rounds_per_final_iteration; @@ -709,14 +841,14 @@ element element::operator*(const Fr& scalar) const } else { constexpr uint64_t num_rounds = Fr::modulus.get_msb() + 1; - std::vector> naf_entries = compute_naf(scalar); + std::vector naf_entries = compute_naf(scalar); const auto offset_generators = compute_offset_generators(num_rounds); element accumulator = *this + offset_generators.first; for (size_t i = 1; i < num_rounds; ++i) { - bool_t predicate = naf_entries[i]; + bool_t predicate = naf_entries[i]; bigfield y_test = y.conditional_negate(predicate); element to_add(x, y_test); accumulator = accumulator.montgomery_ladder(to_add); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp index 32a8a3876c1..f1dd10cd30e 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp @@ -1,5 +1,6 @@ #pragma once #include "barretenberg/ecc/curves/secp256k1/secp256k1.hpp" +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { @@ -481,17 +482,17 @@ std::vector> element::compute_naf(const Fr& scalar, cons uint256_t scalar_multiplier = scalar_multiplier_512.lo; const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; - std::vector> naf_entries(num_rounds + 1); + std::vector naf_entries(num_rounds + 1); // if boolean is false => do NOT flip y // if boolean is true => DO flip y // first entry is skew. i.e. do we subtract one from the final result or not if (scalar_multiplier.get_bit(0) == false) { // add skew - naf_entries[num_rounds] = bool_t(witness_t(ctx, true)); + naf_entries[num_rounds] = bool_t(witness_t(ctx, true)); scalar_multiplier += uint256_t(1); } else { - naf_entries[num_rounds] = bool_t(witness_t(ctx, false)); + naf_entries[num_rounds] = bool_t(witness_t(ctx, false)); } for (size_t i = 0; i < num_rounds - 1; ++i) { bool next_entry = scalar_multiplier.get_bit(i + 1); @@ -499,7 +500,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons // This is a VERY hacky workaround to ensure that UltraPlonkBuilder will apply a basic // range constraint per bool, and not a full 1-bit range gate if (next_entry == false) { - bool_t bit(ctx, true); + bool_t bit(ctx, true); bit.context = ctx; bit.witness_index = witness_t(ctx, true).witness_index; // flip sign bit.witness_bool = true; @@ -512,7 +513,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons } naf_entries[num_rounds - i - 1] = bit; } else { - bool_t bit(ctx, false); + bool_t bit(ctx, false); bit.witness_index = witness_t(ctx, false).witness_index; // don't flip sign bit.witness_bool = false; if constexpr (HasPlookup) { @@ -525,7 +526,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons naf_entries[num_rounds - i - 1] = bit; } } - naf_entries[0] = bool_t(ctx, false); // most significant entry is always true + naf_entries[0] = bool_t(ctx, false); // most significant entry is always true // validate correctness of NAF if constexpr (!Fr::is_composite) { @@ -542,7 +543,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons Fr accumulator_result = Fr::accumulate(accumulators); scalar.assert_equal(accumulator_result); } else { - const auto reconstruct_half_naf = [](bool_t* nafs, const size_t half_round_length) { + const auto reconstruct_half_naf = [](bool_t* nafs, const size_t half_round_length) { // Q: need constraint to start from zero? field_t negative_accumulator(0); field_t positive_accumulator(0); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp index 6f898f6a217..b9b363ba8ea 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp @@ -5,6 +5,7 @@ * TODO: we should try to genericize this, but this method is super fiddly and we need it to be efficient! * **/ +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { template @@ -119,14 +120,14 @@ element element::secp256k1_ecdsa_mul(const element& const element& base_point, const field_t& positive_skew, const field_t& negative_skew) { - const bool_t positive_skew_bool(positive_skew); - const bool_t negative_skew_bool(negative_skew); + const bool_t positive_skew_bool(positive_skew); + const bool_t negative_skew_bool(negative_skew); auto to_add = base_point; to_add.y = to_add.y.conditional_negate(negative_skew_bool); element result = accumulator + to_add; // when computing the wNAF we have already validated that positive_skew and negative_skew cannot both be true - bool_t skew_combined = positive_skew_bool ^ negative_skew_bool; + bool_t skew_combined = positive_skew_bool ^ negative_skew_bool; result.x = accumulator.x.conditional_select(result.x, skew_combined); result.y = accumulator.y.conditional_select(result.y, skew_combined); return result; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp index 78cc53e03b7..bdb6a9cd61f 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp @@ -1,4 +1,6 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" +#include "barretenberg/stdlib/primitives/memory/twin_rom_table.hpp" #include "barretenberg/stdlib_circuit_builders/plookup_tables/types.hpp" namespace bb::stdlib { @@ -180,27 +182,27 @@ template element::lookup_table_plookup::lookup_table_plookup(const std::array& inputs) { if constexpr (length == 2) { - auto [A0, A1] = inputs[1].add_sub(inputs[0]); + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); element_table[0] = A0; element_table[1] = A1; } else if constexpr (length == 3) { - auto [R0, R1] = inputs[1].add_sub(inputs[0]); // B ± A + auto [R0, R1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A - auto [T0, T1] = inputs[2].add_sub(R0); // C ± (B + A) - auto [T2, T3] = inputs[2].add_sub(R1); // C ± (B - A) + auto [T0, T1] = inputs[2].checked_unconditional_add_sub(R0); // C ± (B + A) + auto [T2, T3] = inputs[2].checked_unconditional_add_sub(R1); // C ± (B - A) element_table[0] = T0; element_table[1] = T2; element_table[2] = T3; element_table[3] = T1; } else if constexpr (length == 4) { - auto [T0, T1] = inputs[1].add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C + auto [T0, T1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C - auto [F0, F3] = T2.add_sub(T0); // (D + C) ± (B + A) - auto [F1, F2] = T2.add_sub(T1); // (D + C) ± (B - A) - auto [F4, F7] = T3.add_sub(T0); // (D - C) ± (B + A) - auto [F5, F6] = T3.add_sub(T1); // (D - C) ± (B - A) + auto [F0, F3] = T2.checked_unconditional_add_sub(T0); // (D + C) ± (B + A) + auto [F1, F2] = T2.checked_unconditional_add_sub(T1); // (D + C) ± (B - A) + auto [F4, F7] = T3.checked_unconditional_add_sub(T0); // (D - C) ± (B + A) + auto [F5, F6] = T3.checked_unconditional_add_sub(T1); // (D - C) ± (B - A) element_table[0] = F0; element_table[1] = F1; @@ -211,20 +213,20 @@ element::lookup_table_plookup::lookup_table_plookup(con element_table[6] = F6; element_table[7] = F7; } else if constexpr (length == 5) { - auto [A0, A1] = inputs[1].add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C - auto [E0, E3] = inputs[4].add_sub(T2); // E ± (D + C) - auto [E1, E2] = inputs[4].add_sub(T3); // E ± (D - C) + auto [E0, E3] = inputs[4].checked_unconditional_add_sub(T2); // E ± (D + C) + auto [E1, E2] = inputs[4].checked_unconditional_add_sub(T3); // E ± (D - C) - auto [F0, F3] = E0.add_sub(A0); - auto [F1, F2] = E0.add_sub(A1); - auto [F4, F7] = E1.add_sub(A0); - auto [F5, F6] = E1.add_sub(A1); - auto [F8, F11] = E2.add_sub(A0); - auto [F9, F10] = E2.add_sub(A1); - auto [F12, F15] = E3.add_sub(A0); - auto [F13, F14] = E3.add_sub(A1); + auto [F0, F3] = E0.checked_unconditional_add_sub(A0); + auto [F1, F2] = E0.checked_unconditional_add_sub(A1); + auto [F4, F7] = E1.checked_unconditional_add_sub(A0); + auto [F5, F6] = E1.checked_unconditional_add_sub(A1); + auto [F8, F11] = E2.checked_unconditional_add_sub(A0); + auto [F9, F10] = E2.checked_unconditional_add_sub(A1); + auto [F12, F15] = E3.checked_unconditional_add_sub(A0); + auto [F13, F14] = E3.checked_unconditional_add_sub(A1); element_table[0] = F0; element_table[1] = F1; @@ -245,33 +247,33 @@ element::lookup_table_plookup::lookup_table_plookup(con } else if constexpr (length == 6) { // 44 adds! Only use this if it saves us adding another table to a multi-scalar-multiplication - auto [A0, A1] = inputs[1].add_sub(inputs[0]); - auto [E0, E1] = inputs[4].add_sub(inputs[3]); - auto [C0, C3] = inputs[2].add_sub(A0); - auto [C1, C2] = inputs[2].add_sub(A1); + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); + auto [E0, E1] = inputs[4].checked_unconditional_add_sub(inputs[3]); + auto [C0, C3] = inputs[2].checked_unconditional_add_sub(A0); + auto [C1, C2] = inputs[2].checked_unconditional_add_sub(A1); - auto [F0, F3] = inputs[5].add_sub(E0); - auto [F1, F2] = inputs[5].add_sub(E1); + auto [F0, F3] = inputs[5].checked_unconditional_add_sub(E0); + auto [F1, F2] = inputs[5].checked_unconditional_add_sub(E1); - auto [R0, R7] = F0.add_sub(C0); - auto [R1, R6] = F0.add_sub(C1); - auto [R2, R5] = F0.add_sub(C2); - auto [R3, R4] = F0.add_sub(C3); + auto [R0, R7] = F0.checked_unconditional_add_sub(C0); + auto [R1, R6] = F0.checked_unconditional_add_sub(C1); + auto [R2, R5] = F0.checked_unconditional_add_sub(C2); + auto [R3, R4] = F0.checked_unconditional_add_sub(C3); - auto [S0, S7] = F1.add_sub(C0); - auto [S1, S6] = F1.add_sub(C1); - auto [S2, S5] = F1.add_sub(C2); - auto [S3, S4] = F1.add_sub(C3); + auto [S0, S7] = F1.checked_unconditional_add_sub(C0); + auto [S1, S6] = F1.checked_unconditional_add_sub(C1); + auto [S2, S5] = F1.checked_unconditional_add_sub(C2); + auto [S3, S4] = F1.checked_unconditional_add_sub(C3); - auto [U0, U7] = F2.add_sub(C0); - auto [U1, U6] = F2.add_sub(C1); - auto [U2, U5] = F2.add_sub(C2); - auto [U3, U4] = F2.add_sub(C3); + auto [U0, U7] = F2.checked_unconditional_add_sub(C0); + auto [U1, U6] = F2.checked_unconditional_add_sub(C1); + auto [U2, U5] = F2.checked_unconditional_add_sub(C2); + auto [U3, U4] = F2.checked_unconditional_add_sub(C3); - auto [W0, W7] = F3.add_sub(C0); - auto [W1, W6] = F3.add_sub(C1); - auto [W2, W5] = F3.add_sub(C2); - auto [W3, W4] = F3.add_sub(C3); + auto [W0, W7] = F3.checked_unconditional_add_sub(C0); + auto [W1, W6] = F3.checked_unconditional_add_sub(C1); + auto [W2, W5] = F3.checked_unconditional_add_sub(C2); + auto [W3, W4] = F3.checked_unconditional_add_sub(C3); element_table[0] = R0; element_table[1] = R1; @@ -408,7 +410,7 @@ element::lookup_table_plookup::lookup_table_plookup(con template template element element::lookup_table_plookup::get( - const std::array, length>& bits) const + const std::array& bits) const { std::vector> accumulators; for (size_t i = 0; i < length; ++i) { @@ -558,20 +560,20 @@ element::lookup_table_base::lookup_table_base(const std::a template template element element::lookup_table_base::get( - const std::array, length>& bits) const + const std::array& bits) const { static_assert(length <= 4 && length >= 2); if constexpr (length == 2) { - bool_t table_selector = bits[0] ^ bits[1]; - bool_t sign_selector = bits[1]; + bool_t table_selector = bits[0] ^ bits[1]; + bool_t sign_selector = bits[1]; Fq to_add_x = twin0.x.conditional_select(twin1.x, table_selector); Fq to_add_y = twin0.y.conditional_select(twin1.y, table_selector); element to_add(to_add_x, to_add_y.conditional_negate(sign_selector)); return to_add; } else if constexpr (length == 3) { - bool_t t0 = bits[2] ^ bits[0]; - bool_t t1 = bits[2] ^ bits[1]; + bool_t t0 = bits[2] ^ bits[0]; + bool_t t1 = bits[2] ^ bits[1]; field_t x_b0 = field_t::select_from_two_bit_table(x_b0_table, t1, t0); field_t x_b1 = field_t::select_from_two_bit_table(x_b1_table, t1, t0); @@ -604,9 +606,9 @@ element element::lookup_table_base::get( return to_add; } else if constexpr (length == 4) { - bool_t t0 = bits[3] ^ bits[0]; - bool_t t1 = bits[3] ^ bits[1]; - bool_t t2 = bits[3] ^ bits[2]; + bool_t t0 = bits[3] ^ bits[0]; + bool_t t1 = bits[3] ^ bits[1]; + bool_t t2 = bits[3] ^ bits[2]; field_t x_b0 = field_t::select_from_three_bit_table(x_b0_table, t2, t1, t0); field_t x_b1 = field_t::select_from_three_bit_table(x_b1_table, t2, t1, t0); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp index a6593e4f831..5b7a5106f3f 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp @@ -11,9 +11,9 @@ namespace bb::stdlib { template struct secp256r1 { static constexpr bb::CurveType type = bb::CurveType::SECP256R1; - typedef ::secp256r1::fq fq; - typedef ::secp256r1::fr fr; - typedef ::secp256r1::g1 g1; + typedef bb::secp256r1::fq fq; + typedef bb::secp256r1::fr fr; + typedef bb::secp256r1::g1 g1; typedef CircuitType Builder; typedef witness_t witness_ct; @@ -23,8 +23,8 @@ template struct secp256r1 { typedef bool_t bool_ct; typedef stdlib::uint32 uint32_ct; - typedef bigfield fq_ct; - typedef bigfield bigfr_ct; + typedef bigfield fq_ct; + typedef bigfield bigfr_ct; typedef element g1_ct; typedef element g1_bigfr_ct; }; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp index e8daaa52170..5d8f05b50b3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp @@ -6,6 +6,8 @@ #include "barretenberg/stdlib_circuit_builders/goblin_ultra_circuit_builder.hpp" #include "databus.hpp" +using namespace bb; + using Builder = GoblinUltraCircuitBuilder; using field_ct = stdlib::field_t; using witness_ct = stdlib::witness_t; From e4f45bac1cecee880fe5f7ac9fc34f545189a32f Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Thu, 2 May 2024 17:37:15 +0000 Subject: [PATCH 02/24] initial commit. Updated transcript relations --- .../src/barretenberg/eccvm/eccvm_flavor.hpp | 67 ++++- .../eccvm/eccvm_transcript.test.cpp | 7 + .../src/barretenberg/eccvm/eccvm_verifier.cpp | 7 + .../barretenberg/eccvm/transcript_builder.hpp | 66 ++++- .../ecc_vm/ecc_transcript_relation.cpp | 233 +++++++++++------- .../ecc_vm/ecc_transcript_relation.hpp | 4 +- 6 files changed, 292 insertions(+), 92 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index c3c26a6ef70..23e0356b162 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -35,17 +35,17 @@ class ECCVMFlavor { using VerifierCommitmentKey = bb::VerifierCommitmentKey; using RelationSeparator = FF; - static constexpr size_t NUM_WIRES = 74; + static constexpr size_t NUM_WIRES = 81; // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. // Note: this number does not include the individual sorted list polynomials. - static constexpr size_t NUM_ALL_ENTITIES = 105; + static constexpr size_t NUM_ALL_ENTITIES = 112; // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying // assignment of witnesses. We again choose a neutral name. static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; // The total number of witness entities not including shifts. - static constexpr size_t NUM_WITNESS_ENTITIES = 76; + static constexpr size_t NUM_WITNESS_ENTITIES = 83; using GrandProductRelations = std::tuple>; // define the tuple of Relations that comprise the Sumcheck relation @@ -180,7 +180,14 @@ class ECCVMFlavor { transcript_reset_accumulator, // column 70 precompute_select, // column 71 lookup_read_counts_0, // column 72 - lookup_read_counts_1); // column 73 + lookup_read_counts_1, // column 73 + transcript_base_infinity, // column 74 + transcript_base_x_inverse, // column 75 + transcript_base_y_inverse, // column 76 + transcript_add_x_equal, // column 77 + transcript_add_y_equal, // column 78 + transcript_y_collision_check, // column 79 + transcript_add_lambda); // column 80 }; /** @@ -572,6 +579,13 @@ class ECCVMFlavor { transcript_msm_x[i] = transcript_state[i].msm_output_x; transcript_msm_y[i] = transcript_state[i].msm_output_y; transcript_collision_check[i] = transcript_state[i].collision_check; + transcript_base_infinity[i] = transcript_state[i].base_infinity; + transcript_base_x_inverse[i] = transcript_state[i].base_x_inverse; + transcript_base_y_inverse[i] = transcript_state[i].base_y_inverse; + transcript_add_x_equal[i] = transcript_state[i].transcript_add_x_equal; + transcript_add_y_equal[i] = transcript_state[i].transcript_add_y_equal; + transcript_y_collision_check[i] = transcript_state[i].transcript_y_collision_check; + transcript_add_lambda[i] = transcript_state[i].transcript_add_lambda; } }); @@ -583,6 +597,13 @@ class ECCVMFlavor { transcript_accumulator_empty[i] = 1; } } + // in addition, unless the accumulator is reset, it contains the value from the previous row so this + // must be propagated + for (size_t i = transcript_state.size(); i < num_rows_pow2; ++i) { + transcript_accumulator_x[i] = transcript_accumulator_x[i - 1]; + transcript_accumulator_y[i] = transcript_accumulator_y[i - 1]; + } + run_loop_in_parallel(precompute_table_state.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { // first row is always an empty row (to accommodate shifted polynomials which must have 0 as 1st @@ -768,6 +789,13 @@ class ECCVMFlavor { Base::precompute_select = "PRECOMPUTE_SELECT"; Base::lookup_read_counts_0 = "LOOKUP_READ_COUNTS_0"; Base::lookup_read_counts_1 = "LOOKUP_READ_COUNTS_1"; + Base::transcript_base_infinity = "TRANSCRIPT_BASE_INFINITY"; + Base::transcript_base_x_inverse = "TRANSCRIPT_BASE_X_INVERSE"; + Base::transcript_base_y_inverse = "TRANSCRIPT_BASE_Y_INVERSE"; + Base::transcript_add_x_equal = "TRANSCRIPT_ADD_X_EQUAL"; + Base::transcript_add_y_equal = "TRANSCRIPT_ADD_Y_EQUAL"; + Base::transcript_y_collision_check = "TRANSCRIPT_Y_COLLISION_CHECK"; + Base::transcript_add_lambda = "TRANSCRIPT_ADD_LAMBDA"; Base::z_perm = "Z_PERM"; Base::lookup_inverses = "LOOKUP_INVERSES"; // The ones beginning with "__" are only used for debugging @@ -868,6 +896,13 @@ class ECCVMFlavor { Commitment precompute_select_comm; Commitment lookup_read_counts_0_comm; Commitment lookup_read_counts_1_comm; + Commitment transcript_base_infinity_comm; + Commitment transcript_base_x_inverse_comm; + Commitment transcript_base_y_inverse_comm; + Commitment transcript_add_x_equal_comm; + Commitment transcript_add_y_equal_comm; + Commitment transcript_y_collision_check_comm; + Commitment transcript_add_lambda_comm; Commitment z_perm_comm; Commitment lookup_inverses_comm; std::vector> sumcheck_univariates; @@ -1051,6 +1086,20 @@ class ECCVMFlavor { NativeTranscript::proof_data, num_frs_read); lookup_read_counts_1_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); + transcript_base_infinity_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_base_x_inverse_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_base_y_inverse_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_add_x_equal_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_add_y_equal_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_y_collision_check_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_add_lambda_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); lookup_inverses_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); z_perm_comm = NativeTranscript::template deserialize_from_buffer(NativeTranscript::proof_data, @@ -1195,6 +1244,16 @@ class ECCVMFlavor { NativeTranscript::template serialize_to_buffer(precompute_select_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(lookup_read_counts_0_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(lookup_read_counts_1_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_base_infinity_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_base_x_inverse_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_base_y_inverse_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_add_x_equal_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_add_y_equal_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_y_collision_check_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_add_lambda_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(lookup_inverses_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(z_perm_comm, NativeTranscript::proof_data); for (size_t i = 0; i < log_n; ++i) { diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp index 240f74bb0b9..05a5be24034 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp @@ -116,6 +116,13 @@ class ECCVMTranscriptTests : public ::testing::Test { manifest_expected.add_entry(round, "PRECOMPUTE_SELECT", frs_per_G); manifest_expected.add_entry(round, "LOOKUP_READ_COUNTS_0", frs_per_G); manifest_expected.add_entry(round, "LOOKUP_READ_COUNTS_1", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_BASE_INFINITY", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_BASE_X_INVERSE", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_BASE_Y_INVERSE", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_ADD_X_EQUAL", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_ADD_Y_EQUAL", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_Y_COLLISION_CHECK", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_ADD_LAMBDA", frs_per_G); manifest_expected.add_challenge(round, "beta", "gamma"); round++; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp index aa3cb7f19fa..95f4e5967d3 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp @@ -102,6 +102,13 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof) commitments.precompute_select = receive_commitment(commitment_labels.precompute_select); commitments.lookup_read_counts_0 = receive_commitment(commitment_labels.lookup_read_counts_0); commitments.lookup_read_counts_1 = receive_commitment(commitment_labels.lookup_read_counts_1); + commitments.transcript_base_infinity = receive_commitment(commitment_labels.transcript_base_infinity); + commitments.transcript_base_x_inverse = receive_commitment(commitment_labels.transcript_base_x_inverse); + commitments.transcript_base_y_inverse = receive_commitment(commitment_labels.transcript_base_y_inverse); + commitments.transcript_add_x_equal = receive_commitment(commitment_labels.transcript_add_x_equal); + commitments.transcript_add_y_equal = receive_commitment(commitment_labels.transcript_add_y_equal); + commitments.transcript_y_collision_check = receive_commitment(commitment_labels.transcript_y_collision_check); + commitments.transcript_add_lambda = receive_commitment(commitment_labels.transcript_add_lambda); // Get challenge for sorted list batching and wire four memory records auto [beta, gamma] = transcript->template get_challenges("beta", "gamma"); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index 106d83b5d4b..c723efa43eb 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -32,6 +32,13 @@ class ECCVMTranscriptBuilder { FF msm_output_x = 0; FF msm_output_y = 0; FF collision_check = 0; + bool base_infinity = 0; + FF base_x_inverse = 0; + FF base_y_inverse = 0; + bool transcript_add_x_equal = false; + bool transcript_add_y_equal = false; + FF transcript_y_collision_check = 0; + FF transcript_add_lambda = 0; }; struct VMState { uint32_t pc = 0; @@ -64,6 +71,11 @@ class ECCVMTranscriptBuilder { std::vector transcript_state(num_transcript_entries); std::vector inverse_trace(num_transcript_entries - 2); + std::vector inverse_trace_x(num_transcript_entries - 2); + std::vector inverse_trace_y(num_transcript_entries - 2); + std::vector transcript_y_collision_check(num_transcript_entries - 2); + std::vector transcript_add_lambda(num_transcript_entries - 2); + VMState state{ .pc = total_number_of_muls, .count = 0, @@ -138,8 +150,53 @@ class ECCVMTranscriptBuilder { row.msm_transition = msm_transition; row.pc = state.pc; row.msm_count = state.count; - row.base_x = (entry.add || entry.mul || entry.eq) ? entry.base_point.x : 0; - row.base_y = (entry.add || entry.mul || entry.eq) ? entry.base_point.y : 0; + auto base_point_infinity = entry.base_point.is_point_at_infinity(); + auto base_point_x = entry.base_point.x; + auto base_point_y = entry.base_point.y; + if ((entry.add || entry.mul || entry.eq) && base_point_infinity) { + base_point_x = 0; + base_point_y = 0; + } + row.base_x = (entry.add || entry.mul || entry.eq) ? base_point_x : 0; + row.base_y = (entry.add || entry.mul || entry.eq) ? base_point_y : 0; + row.base_infinity = (entry.add || entry.mul || entry.eq) ? (base_point_infinity ? 1 : 0) : 0; + if (msm_transition) { + auto lhsx = AffineElement(updated_state.msm_accumulator).x; + auto lhsy = AffineElement(updated_state.msm_accumulator).y; + auto rhsx = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; + auto rhsy = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + inverse_trace_x[i] = lhsx - rhsx; + inverse_trace_y[i] = lhsy - rhsy; + } else if (entry.add) { + auto lhsx = base_point_x; + auto lhsy = base_point_y; + auto rhsx = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; + auto rhsy = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + inverse_trace_x[i] = lhsx - rhsx; + inverse_trace_y[i] = lhsy - rhsy; + } else { + inverse_trace_x[i] = 0; + inverse_trace_y[i] = 0; + } + + if (entry.add || msm_transition) { + auto lhs = entry.add ? entry.base_point : updated_state.msm_accumulator; + auto rhs = state.accumulator; + row.transcript_add_x_equal = lhs.x == rhs.x; // check infinity? + row.transcript_add_y_equal = lhs.y == rhs.y; + if (lhs.x == rhs.x && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { + row.transcript_add_lambda = (lhs.x * lhs.x * 3) / (lhs.y * 2); + } else if (!lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { + row.transcript_add_lambda = (rhs.y - lhs.y) / (rhs.x - lhs.x); + } else { + row.transcript_add_lambda = 0; + } + } else { + row.transcript_add_x_equal = 0; + row.transcript_add_y_equal = 0; + row.transcript_add_lambda = 0; + } + row.z1 = (entry.mul) ? entry.z1 : 0; row.z2 = (entry.mul) ? entry.z2 : 0; row.z1_zero = z1_zero; @@ -177,15 +234,18 @@ class ECCVMTranscriptBuilder { } FF::batch_invert(&inverse_trace[0], inverse_trace.size()); + FF::batch_invert(&inverse_trace_x[0], inverse_trace.size()); + FF::batch_invert(&inverse_trace_y[0], inverse_trace.size()); for (size_t i = 0; i < inverse_trace.size(); ++i) { transcript_state[i + 1].collision_check = inverse_trace[i]; + transcript_state[i + 1].base_x_inverse = inverse_trace_x[i]; + transcript_state[i + 1].base_y_inverse = inverse_trace_y[i]; } TranscriptState& final_row = transcript_state.back(); final_row.pc = updated_state.pc; final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x; final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y; final_row.accumulator_empty = updated_state.is_accumulator_empty; - return transcript_state; } }; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 4e7a4cdbdb6..2372c64d81b 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -67,10 +67,17 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto is_accumulator_empty_shift = View(in.transcript_accumulator_empty_shift); auto q_reset_accumulator = View(in.transcript_reset_accumulator); auto lagrange_second = View(in.lagrange_second); - auto transcript_collision_check = View(in.transcript_collision_check); + auto transcript_Pinfinity = View(in.transcript_base_infinity); + auto transcript_Px_inverse = View(in.transcript_base_x_inverse); + auto transcript_Py_inverse = View(in.transcript_base_y_inverse); + auto transcript_add_x_equal = View(in.transcript_add_x_equal); + auto transcript_add_y_equal = View(in.transcript_add_y_equal); + auto transcript_add_lambda = View(in.transcript_add_lambda); auto is_not_first_row = (-lagrange_first + 1); auto is_not_first_or_last_row = (-lagrange_first + -lagrange_last + 1); + auto is_not_infinity = (-transcript_Pinfinity + 1); + /** * @brief Validate correctness of z1_zero, z2_zero. * If z1_zero = 0 and operation is a MUL, we will write a scalar mul instruction into our multiplication table. @@ -135,56 +142,6 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu std::get<6>(accumulator) += is_not_first_row * (-msm_transition + 1) * (msm_count_delta - q_mul * ((-z1_zero + 1) + (-z2_zero + 1))) * scaling_factor; - /** - * @brief Add multiscalar multiplication result into transcript accumulator. - * If `msm_transition == 1`, we expect msm output to be present on (transcript_msm_x, transcript_msm_y). - * (this is enforced via a lookup protocol). - * If `is_accumulator_empty == 0`, we ADD msm output into transcript_accumulator. - * If `is_accumulator_empty = =1`, we ASSIGN msm output to transcript_accumulator. - * @note the output of an msm cannot be point at infinity (will create unsatisfiable constraints in - * ecc_msm_relation). We assume this does not affect statistical completeness for honest provers. We should validate - * this! - */ - auto add_msm_into_accumulator = msm_transition * (-is_accumulator_empty + 1); - auto x3 = transcript_accumulator_x_shift; - auto y3 = transcript_accumulator_y_shift; - auto x1 = transcript_accumulator_x; - auto y1 = transcript_accumulator_y; - auto x2 = transcript_msm_x; - auto y2 = transcript_msm_y; - auto tmpx = (x3 + x2 + x1) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1); - auto tmpy = (y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3); - std::get<7>(accumulator) += tmpx * add_msm_into_accumulator * scaling_factor; // degree 5 - std::get<8>(accumulator) += tmpy * add_msm_into_accumulator * scaling_factor; // degree 4 - - /** - * @brief If is_accumulator_empty == 1, assign transcript_accumulator output into accumulator - * - * @note The accumulator point for all operations at row `i` is the accumulator point at row `i + 1`! - */ - auto assign_msm_into_accumulator = msm_transition * is_accumulator_empty; - std::get<9>(accumulator) += assign_msm_into_accumulator * (x3 - x2) * scaling_factor; // degree 3 - std::get<10>(accumulator) += assign_msm_into_accumulator * (y3 - y2) * scaling_factor; - - /** - * @brief Constrain `add` opcode. - * - * add will add the input point in (transcript_Px, transcript_Py) into the accumulator. - * Correctly handles case where accumulator is point at infinity. - * TODO: need to add constraints to rule out point doubling case (x2 != x1) - * TODO: need to assert input point is on the curve! - */ - x2 = transcript_Px; - y2 = transcript_Py; - auto add_into_accumulator = q_add * (-is_accumulator_empty + 1); - tmpx = (x3 + x2 + x1) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1); - tmpy = (y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3); - std::get<11>(accumulator) += tmpx * add_into_accumulator * scaling_factor; // degree 5 - std::get<12>(accumulator) += tmpy * add_into_accumulator * scaling_factor; // degree 4 - auto assign_into_accumulator = q_add * is_accumulator_empty; - std::get<13>(accumulator) += (x3 - x2) * assign_into_accumulator * scaling_factor; // degree 3 - std::get<14>(accumulator) += (y3 - y2) * assign_into_accumulator * scaling_factor; - /** * @brief Opcode exclusion tests. We have the following assertions: * 1. If q_mul = 1, (q_add, eq, reset) are zero @@ -195,33 +152,31 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * @note point 3: both q_add = 1, msm_transition = 1 cannot occur because of point 1 (msm_transition only 1 when * q_mul 1) we can use a slightly more efficient relation than a pure binary OR */ - std::get<15>(accumulator) += q_mul * (q_add + q_eq + q_reset_accumulator) * scaling_factor; - std::get<16>(accumulator) += q_add * (q_mul + q_eq + q_reset_accumulator) * scaling_factor; - std::get<17>(accumulator) += q_reset_accumulator * (-is_accumulator_empty_shift + 1) * scaling_factor; - std::get<18>(accumulator) += (q_add + msm_transition) * is_accumulator_empty_shift * scaling_factor; + std::get<7>(accumulator) += q_mul * (q_add + q_eq + q_reset_accumulator) * scaling_factor; + std::get<8>(accumulator) += q_add * (q_mul + q_eq + q_reset_accumulator) * scaling_factor; + std::get<9>(accumulator) += q_reset_accumulator * (-is_accumulator_empty_shift + 1) * scaling_factor; + // std::get<18>(accumulator) += (q_add + msm_transition) * is_accumulator_empty_shift * scaling_factor; auto accumulator_state_not_modified = -(q_add + msm_transition + q_reset_accumulator) + 1; - std::get<19>(accumulator) += accumulator_state_not_modified * is_not_first_or_last_row * + std::get<10>(accumulator) += accumulator_state_not_modified * is_not_first_or_last_row * (is_accumulator_empty_shift - is_accumulator_empty) * scaling_factor; /** * @brief `eq` opcode. - * If eq = 1, assert transcript_Px/y = transcript_accumulator_x/y. - * If eq = 1, assert is_accumulator_empty = 0 (input point cannot be point at infinity) - */ - std::get<20>(accumulator) += q_eq * (transcript_accumulator_x - transcript_Px) * scaling_factor; - std::get<21>(accumulator) += - q_eq * (-is_accumulator_empty + 1) * (transcript_accumulator_y - transcript_Py) * scaling_factor; - std::get<22>(accumulator) += q_eq * is_accumulator_empty * scaling_factor; + * Let lhs = transcript_P and rhs = transcript_accumulator + * If eq = 1, we must validate the following cases: + * IF lhs and rhs are not at infinity THEN lhs == rhs + * ELSE lhs and rhs are BOTH points at infinity + **/ + auto both_infinity = transcript_Pinfinity * is_accumulator_empty; + auto both_not_infinity = (-transcript_Pinfinity + 1) * (-is_accumulator_empty + 1); + auto infinity_exclusion_check = transcript_Pinfinity + is_accumulator_empty - both_infinity - both_infinity; + auto eq_x_diff = transcript_Px - transcript_accumulator_x; + auto eq_y_diff = transcript_Py - transcript_accumulator_y; + auto eq_x_diff_relation = q_eq * (eq_x_diff * both_not_infinity + infinity_exclusion_check); // degree 4 + auto eq_y_diff_relation = q_eq * (eq_y_diff * both_not_infinity + infinity_exclusion_check); // degree 4 + std::get<11>(accumulator) += eq_x_diff_relation * scaling_factor; - // validate selectors are boolean (put somewhere else? these are low degree) - std::get<23>(accumulator) += q_eq * (q_eq - 1) * scaling_factor; - std::get<24>(accumulator) += q_add * (q_add - 1) * scaling_factor; - std::get<25>(accumulator) += q_mul * (q_mul - 1) * scaling_factor; - std::get<26>(accumulator) += q_reset_accumulator * (q_reset_accumulator - 1) * scaling_factor; - std::get<27>(accumulator) += msm_transition * (msm_transition - 1) * scaling_factor; - std::get<28>(accumulator) += is_accumulator_empty * (is_accumulator_empty - 1) * scaling_factor; - std::get<29>(accumulator) += z1_zero * (z1_zero - 1) * scaling_factor; - std::get<30>(accumulator) += z2_zero * (z2_zero - 1) * scaling_factor; + std::get<12>(accumulator) += eq_y_diff_relation * scaling_factor; /** * @brief Initial condition check on 1st row. @@ -231,28 +186,140 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * note...actually second row? bleurgh * NOTE: we want pc = 0 at lagrange_last :o */ - std::get<31>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; - std::get<32>(accumulator) += lagrange_second * msm_count * scaling_factor; + std::get<13>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; + std::get<14>(accumulator) += lagrange_second * msm_count * scaling_factor; /** * @brief On-curve validation checks. * If q_mul = 1 OR q_add = 1 OR q_eq = 1, require (transcript_Px, transcript_Py) is valid ecc point * q_mul/q_add/q_eq mutually exclusive, can represent as sum of 3 */ - const auto validate_on_curve = q_mul; // q_add + q_mul + q_eq; + const auto validate_on_curve = q_mul + q_add + q_mul + q_eq; const auto on_curve_check = transcript_Py * transcript_Py - transcript_Px * transcript_Px * transcript_Px - get_curve_b(); - std::get<33>(accumulator) += validate_on_curve * on_curve_check * scaling_factor; + std::get<15>(accumulator) += validate_on_curve * on_curve_check * is_not_infinity * scaling_factor; // degree 6 /** - * @brief If performing an add, validate x-coordintes of inputs do not collide. - * If adding msm output into accumulator, validate x-coordinates of inputs do not collide + * @brief Validate correctness of ECC Group Operation + * An ECC group operation is performed if q_add = 1 or msm_transition = 1. + * Because input points can be points at infinity, we must support COMPLETE addition and handle points at infinity */ - auto x_coordinate_collision_check = - add_msm_into_accumulator * ((transcript_msm_x - transcript_accumulator_x) * transcript_collision_check - FF(1)); - x_coordinate_collision_check += - add_into_accumulator * ((transcript_Px - transcript_accumulator_x) * transcript_collision_check - FF(1)); - std::get<34>(accumulator) += x_coordinate_collision_check * scaling_factor; + // define the lhs point: either transcript_Px/y or transcript_accumulator_x/y + auto lhs_x = transcript_Px * q_add + transcript_msm_x * msm_transition; + auto lhs_y = transcript_Py * q_add + transcript_msm_y * msm_transition; + // the rhs point will always be the accumulator point at the next row in the trace + auto rhs_x = transcript_accumulator_x; + auto rhs_y = transcript_accumulator_y; + // the group operation will be either an ADD or a DOUBLE depending on whether x/y coordinates of lhs/rhs match. + // If lhs_x == rhs_x, we evaluate a DOUBLE, otherwise an ADD + // (we will only activate this relation if lhs_y != rhs_y, but this is done later) + auto ecc_op_is_dbl = transcript_add_x_equal; + auto ecc_op_is_add = (-transcript_add_x_equal + 1); + // Are the lhs/rhs points at infinity? + // MSM output CANNOT be point at infinity without triggering unsatisfiable constraints in msm_relation + // lhs can only be at infinity if q_add is active + auto lhs_infinity = transcript_Pinfinity * q_add; + auto rhs_infinity = is_accumulator_empty; + // Determine where the group operation output is sourced from + // | lhs_infinity | rhs_infinity | lhs_x == rhs_x && lhs_y != rhs_y | output | + // | ------------ | ------------ | -------------------------------- | --------- | + // | 0 | 0 | 0 | lhs + rhs | + // | 0 | 0 | 1 | infinity | + // | 0 | 1 | n/a | lhs | + // | 1 | 0 | n/a | rhs | + // | 1 | 1 | n/a | infinity | + auto add_result_is_lhs = rhs_infinity * (-lhs_infinity + 1); // degree 3 + auto add_result_is_rhs = lhs_infinity * (-rhs_infinity + 1); // degree 3 + auto add_result_is_out = (-lhs_infinity + 1) * (-rhs_infinity + 1); // degree 3 + auto add_result_infinity_from_inputs = lhs_infinity * rhs_infinity; // degree 2 + auto add_result_infinity_from_operation = transcript_add_x_equal * (-transcript_add_y_equal + 1); // degree 2 + auto add_result_is_infinity = add_result_infinity_from_inputs + add_result_infinity_from_operation; // degree 2?? + + // Determine the gradient `lambda` of the group operation + // If lhs_x == rhs_x, lambda = (3 * lhs_x * lhs_x) / (2 * lhs_y) + // Else, lambda = (rhs_y - lhs_y) / (rhs_x - lhs_x) + auto lhs_xx = lhs_x * lhs_x; + auto lambda_numerator = (rhs_y - lhs_y) * ecc_op_is_add + (lhs_xx + lhs_xx + lhs_xx) * ecc_op_is_dbl; + auto lambda_denominator = (rhs_x - lhs_x) * ecc_op_is_add + (lhs_y + lhs_y) * ecc_op_is_dbl; // degree 3 + auto lambda_term = lambda_denominator * transcript_add_lambda - lambda_numerator; // degree 4 + // We only activate lambda relation if we don't have points at infinity - this is to avoid divide-by-zero problems + // N.B. check this is needed + auto any_add_is_active = q_add + msm_transition; + auto lambda_relation_active = any_add_is_active * add_result_is_out; // degree 4 + auto lambda_relation = lambda_term * lambda_relation_active; // degree 8! + std::get<16>(accumulator) += lambda_relation * scaling_factor; // degree 8 + + // Determine the x/y coordinates of the shifted accumulator + // add_x3/add_y3 = result of group operation computation + auto add_x3 = transcript_add_lambda * transcript_add_lambda - lhs_x - rhs_x; // degree 2 + auto add_y3 = transcript_add_lambda * (lhs_x - add_x3) - lhs_y; // degree 3 + // x3/y3 = result of group operation computation that considers input points at infinity + auto x3 = (add_x3 * add_result_is_out + lhs_x * add_result_is_lhs + rhs_x * add_result_is_rhs); // degree 5 + auto y3 = (add_y3 * add_result_is_out + lhs_y * add_result_is_lhs + rhs_y * add_result_is_rhs); // degree 6 + + auto propagate_transcript_accumulator = (-q_add - msm_transition - q_reset_accumulator + 1); + auto add_point_x_relation = + (x3 - transcript_accumulator_x_shift * (add_result_is_out + add_result_is_lhs + add_result_is_rhs)) * + any_add_is_active; // degree 7 + add_point_x_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * + (transcript_accumulator_x_shift - transcript_accumulator_x); + auto add_point_y_relation = + (y3 - transcript_accumulator_y_shift * (add_result_is_out + add_result_is_lhs + add_result_is_rhs)) * + any_add_is_active; // degree 7 + add_point_y_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * + (transcript_accumulator_y_shift - transcript_accumulator_y); + std::get<17>(accumulator) += add_point_x_relation * scaling_factor; // degree 7 + std::get<18>(accumulator) += add_point_y_relation * scaling_factor; // degree 8 + + /** + * @brief Validate `is_accumulator_empty` is updated correctly + * An add operation can produce a point at infinity + * Resetting the accumulator produces a point at infinity + * If we are not adding, performing an msm or resetting the accumulator, is_accumulator_empty should not update + */ + auto accumulator_infinity_preserve_flag = (-(q_add + msm_transition + q_reset_accumulator) + 1); + auto accumulator_infinity_preserve = + accumulator_infinity_preserve_flag * (is_accumulator_empty - is_accumulator_empty_shift) * (-lagrange_last + 1); + auto accumulator_infinity_q_reset = q_reset_accumulator * (-is_accumulator_empty_shift + 1); + auto accumulator_infinity_from_add = any_add_is_active * (add_result_is_infinity - is_accumulator_empty_shift); + auto accumulator_infinity_relation = + accumulator_infinity_preserve + accumulator_infinity_q_reset + accumulator_infinity_from_add; + std::get<19>(accumulator) += (accumulator_infinity_relation * is_not_first_row) * scaling_factor; // degree 5? + + /** + * @brief Validate `transcript_add_x_equal` is well-formed + * If lhs_x == rhs_x, transcript_add_x_equal = 1 + * If transcript_add_x_equal = 0, a valid inverse must exist for (lhs_x - rhs_x) + */ + auto x_diff = lhs_x - rhs_x; + auto x_product = transcript_Px_inverse * (-transcript_add_x_equal + 1) + transcript_add_x_equal; + auto x_constant = transcript_add_x_equal - 1; + auto transcript_add_x_equal_check_relation = (x_diff * x_product + x_constant) * any_add_is_active; + std::get<20>(accumulator) += transcript_add_x_equal_check_relation * scaling_factor; // degree 6 + + /** + * @brief Validate `transcript_add_y_equal` is well-formed + * If lhs_y == rhs_y, transcript_add_y_equal = 1 + * If transcript_add_y_equal = 0, a valid inverse must exist for (lhs_y - rhs_y) + */ + auto y_diff = lhs_y - rhs_y; + auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; + auto y_constant = transcript_add_y_equal - 1; + auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; + std::get<21>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 6 + + // validate selectors are boolean (put somewhere else? these are low degree) + std::get<22>(accumulator) += q_eq * (q_eq - 1) * scaling_factor; + std::get<23>(accumulator) += q_add * (q_add - 1) * scaling_factor; + std::get<24>(accumulator) += q_mul * (q_mul - 1) * scaling_factor; + std::get<25>(accumulator) += q_reset_accumulator * (q_reset_accumulator - 1) * scaling_factor; + std::get<26>(accumulator) += msm_transition * (msm_transition - 1) * scaling_factor; + std::get<27>(accumulator) += is_accumulator_empty * (is_accumulator_empty - 1) * scaling_factor; + std::get<28>(accumulator) += z1_zero * (z1_zero - 1) * scaling_factor; + std::get<29>(accumulator) += z2_zero * (z2_zero - 1) * scaling_factor; + std::get<30>(accumulator) += transcript_add_x_equal * (transcript_add_x_equal - 1) * scaling_factor; + std::get<31>(accumulator) += transcript_add_y_equal * (transcript_add_y_equal - 1) * scaling_factor; + std::get<32>(accumulator) += transcript_Pinfinity * (transcript_Pinfinity - 1) * scaling_factor; } template class ECCVMTranscriptRelationImpl; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp index ef511e41331..ee2904f82dd 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp @@ -30,8 +30,8 @@ template class ECCVMTranscriptRelationImpl { public: using FF = FF_; - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, }; template From 223daa728e154c196987e8a8c781553a225a3b59 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Fri, 3 May 2024 12:28:44 +0000 Subject: [PATCH 03/24] added tests for eccvm points at infinity --- .../eccvm/eccvm_circuit_builder.test.cpp | 34 +++++++++++++++++++ .../barretenberg/eccvm/transcript_builder.hpp | 8 +++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp index 253c1af019c..395fedf0ed4 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp @@ -243,3 +243,37 @@ TEST(ECCVMCircuitBuilderTests, MSM) bool result = ECCVMTraceChecker::check(circuit); EXPECT_EQ(result, true); } + +TEST(ECCVMCircuitBuilderTests, EqAgainstPointAtInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + a.self_set_infinity(); + + op_queue->add_accumulate(a); + op_queue->eq_and_reset(); + + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitBuilderTests, AddPointAtInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + typename G1::element b = generators[0]; + b.self_set_infinity(); + + op_queue->add_accumulate(a); + op_queue->add_accumulate(b); + op_queue->eq_and_reset(); + + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index c723efa43eb..40b04148b5c 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -140,7 +140,7 @@ class ECCVMTranscriptBuilder { } else { updated_state.accumulator = typename CycleGroup::element(state.accumulator) + entry.base_point; } - updated_state.is_accumulator_empty = false; + updated_state.is_accumulator_empty = updated_state.accumulator.is_point_at_infinity(); } row.accumulator_empty = state.is_accumulator_empty; row.q_add = entry.add; @@ -182,8 +182,10 @@ class ECCVMTranscriptBuilder { if (entry.add || msm_transition) { auto lhs = entry.add ? entry.base_point : updated_state.msm_accumulator; auto rhs = state.accumulator; - row.transcript_add_x_equal = lhs.x == rhs.x; // check infinity? - row.transcript_add_y_equal = lhs.y == rhs.y; + row.transcript_add_x_equal = + lhs.x == rhs.x || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); // check infinity? + row.transcript_add_y_equal = + lhs.y == rhs.y || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); if (lhs.x == rhs.x && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { row.transcript_add_lambda = (lhs.x * lhs.x * 3) / (lhs.y * 2); } else if (!lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { From 24a3c8771826e40936d904387f7ca7d4a39e932a Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Fri, 10 May 2024 13:53:56 +0000 Subject: [PATCH 04/24] points at infinity now fully supported by eccvm --- .../src/barretenberg/eccvm/eccvm_flavor.hpp | 199 +++++++++------- .../eccvm/eccvm_transcript.test.cpp | 4 + .../src/barretenberg/eccvm/eccvm_verifier.cpp | 4 + .../src/barretenberg/eccvm/msm_builder.hpp | 49 ++-- .../barretenberg/eccvm/transcript_builder.hpp | 212 +++++++++++------- .../relations/ecc_vm/ecc_msm_relation.cpp | 43 +++- .../ecc_vm/ecc_transcript_relation.cpp | 58 ++++- .../ecc_vm/ecc_transcript_relation.hpp | 5 +- 8 files changed, 372 insertions(+), 202 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 23e0356b162..64742c084d9 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -35,17 +35,17 @@ class ECCVMFlavor { using VerifierCommitmentKey = bb::VerifierCommitmentKey; using RelationSeparator = FF; - static constexpr size_t NUM_WIRES = 81; + static constexpr size_t NUM_WIRES = 85; // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. // Note: this number does not include the individual sorted list polynomials. - static constexpr size_t NUM_ALL_ENTITIES = 112; + static constexpr size_t NUM_ALL_ENTITIES = 116; // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying // assignment of witnesses. We again choose a neutral name. static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; // The total number of witness entities not including shifts. - static constexpr size_t NUM_WITNESS_ENTITIES = 83; + static constexpr size_t NUM_WITNESS_ENTITIES = 87; using GrandProductRelations = std::tuple>; // define the tuple of Relations that comprise the Sumcheck relation @@ -107,87 +107,91 @@ class ECCVMFlavor { template class WireEntities { public: DEFINE_FLAVOR_MEMBERS(DataType, - transcript_add, // column 0 - transcript_mul, // column 1 - transcript_eq, // column 2 - transcript_collision_check, // column 3 - transcript_msm_transition, // column 4 - transcript_pc, // column 5 - transcript_msm_count, // column 6 - transcript_Px, // column 7 - transcript_Py, // column 8 - transcript_z1, // column 9 - transcript_z2, // column 10 - transcript_z1zero, // column 11 - transcript_z2zero, // column 12 - transcript_op, // column 13 - transcript_accumulator_x, // column 14 - transcript_accumulator_y, // column 15 - transcript_msm_x, // column 16 - transcript_msm_y, // column 17 - precompute_pc, // column 18 - precompute_point_transition, // column 19 - precompute_round, // column 20 - precompute_scalar_sum, // column 21 - precompute_s1hi, // column 22 - precompute_s1lo, // column 23 - precompute_s2hi, // column 24 - precompute_s2lo, // column 25 - precompute_s3hi, // column 26 - precompute_s3lo, // column 27 - precompute_s4hi, // column 28 - precompute_s4lo, // column 29 - precompute_skew, // column 30 - precompute_dx, // column 31 - precompute_dy, // column 32 - precompute_tx, // column 33 - precompute_ty, // column 34 - msm_transition, // column 35 - msm_add, // column 36 - msm_double, // column 37 - msm_skew, // column 38 - msm_accumulator_x, // column 39 - msm_accumulator_y, // column 40 - msm_pc, // column 41 - msm_size_of_msm, // column 42 - msm_count, // column 43 - msm_round, // column 44 - msm_add1, // column 45 - msm_add2, // column 46 - msm_add3, // column 47 - msm_add4, // column 48 - msm_x1, // column 49 - msm_y1, // column 50 - msm_x2, // column 51 - msm_y2, // column 52 - msm_x3, // column 53 - msm_y3, // column 54 - msm_x4, // column 55 - msm_y4, // column 56 - msm_collision_x1, // column 57 - msm_collision_x2, // column 58 - msm_collision_x3, // column 59 - msm_collision_x4, // column 60 - msm_lambda1, // column 61 - msm_lambda2, // column 62 - msm_lambda3, // column 63 - msm_lambda4, // column 64 - msm_slice1, // column 65 - msm_slice2, // column 66 - msm_slice3, // column 67 - msm_slice4, // column 68 - transcript_accumulator_empty, // column 69 - transcript_reset_accumulator, // column 70 - precompute_select, // column 71 - lookup_read_counts_0, // column 72 - lookup_read_counts_1, // column 73 - transcript_base_infinity, // column 74 - transcript_base_x_inverse, // column 75 - transcript_base_y_inverse, // column 76 - transcript_add_x_equal, // column 77 - transcript_add_y_equal, // column 78 - transcript_y_collision_check, // column 79 - transcript_add_lambda); // column 80 + transcript_add, // column 0 + transcript_mul, // column 1 + transcript_eq, // column 2 + transcript_collision_check, // column 3 + transcript_msm_transition, // column 4 + transcript_pc, // column 5 + transcript_msm_count, // column 6 + transcript_Px, // column 7 + transcript_Py, // column 8 + transcript_z1, // column 9 + transcript_z2, // column 10 + transcript_z1zero, // column 11 + transcript_z2zero, // column 12 + transcript_op, // column 13 + transcript_accumulator_x, // column 14 + transcript_accumulator_y, // column 15 + transcript_msm_x, // column 16 + transcript_msm_y, // column 17 + precompute_pc, // column 18 + precompute_point_transition, // column 19 + precompute_round, // column 20 + precompute_scalar_sum, // column 21 + precompute_s1hi, // column 22 + precompute_s1lo, // column 23 + precompute_s2hi, // column 24 + precompute_s2lo, // column 25 + precompute_s3hi, // column 26 + precompute_s3lo, // column 27 + precompute_s4hi, // column 28 + precompute_s4lo, // column 29 + precompute_skew, // column 30 + precompute_dx, // column 31 + precompute_dy, // column 32 + precompute_tx, // column 33 + precompute_ty, // column 34 + msm_transition, // column 35 + msm_add, // column 36 + msm_double, // column 37 + msm_skew, // column 38 + msm_accumulator_x, // column 39 + msm_accumulator_y, // column 40 + msm_pc, // column 41 + msm_size_of_msm, // column 42 + msm_count, // column 43 + msm_round, // column 44 + msm_add1, // column 45 + msm_add2, // column 46 + msm_add3, // column 47 + msm_add4, // column 48 + msm_x1, // column 49 + msm_y1, // column 50 + msm_x2, // column 51 + msm_y2, // column 52 + msm_x3, // column 53 + msm_y3, // column 54 + msm_x4, // column 55 + msm_y4, // column 56 + msm_collision_x1, // column 57 + msm_collision_x2, // column 58 + msm_collision_x3, // column 59 + msm_collision_x4, // column 60 + msm_lambda1, // column 61 + msm_lambda2, // column 62 + msm_lambda3, // column 63 + msm_lambda4, // column 64 + msm_slice1, // column 65 + msm_slice2, // column 66 + msm_slice3, // column 67 + msm_slice4, // column 68 + transcript_accumulator_empty, // column 69 + transcript_reset_accumulator, // column 70 + precompute_select, // column 71 + lookup_read_counts_0, // column 72 + lookup_read_counts_1, // column 73 + transcript_base_infinity, // column 74 + transcript_base_x_inverse, // column 75 + transcript_base_y_inverse, // column 76 + transcript_add_x_equal, // column 77 + transcript_add_y_equal, // column 78 + transcript_y_collision_check, // column 79 + transcript_add_lambda, // column 80 + transcript_msm_intermediate_x, // column 81 + transcript_msm_intermediate_y, // column 82 + transcript_msm_infinity, // column 83 + transcript_msm_x_inverse); // column 84 }; /** @@ -586,6 +590,10 @@ class ECCVMFlavor { transcript_add_y_equal[i] = transcript_state[i].transcript_add_y_equal; transcript_y_collision_check[i] = transcript_state[i].transcript_y_collision_check; transcript_add_lambda[i] = transcript_state[i].transcript_add_lambda; + transcript_msm_intermediate_x[i] = transcript_state[i].transcript_msm_intermediate_x; + transcript_msm_intermediate_y[i] = transcript_state[i].transcript_msm_intermediate_y; + transcript_msm_infinity[i] = transcript_state[i].transcript_msm_infinity; + transcript_msm_x_inverse[i] = transcript_state[i].transcript_msm_x_inverse; } }); @@ -796,6 +804,10 @@ class ECCVMFlavor { Base::transcript_add_y_equal = "TRANSCRIPT_ADD_Y_EQUAL"; Base::transcript_y_collision_check = "TRANSCRIPT_Y_COLLISION_CHECK"; Base::transcript_add_lambda = "TRANSCRIPT_ADD_LAMBDA"; + Base::transcript_msm_intermediate_x = "TRANSCRIPT_MSM_INTERMEDIATE_X"; + Base::transcript_msm_intermediate_y = "TRANSCRIPT_MSM_INTERMEDIATE_Y"; + Base::transcript_msm_infinity = "TRANSCRIPT_MSM_INFINITY"; + Base::transcript_msm_x_inverse = "TRANSCRIPT_MSM_X_INVERSE"; Base::z_perm = "Z_PERM"; Base::lookup_inverses = "LOOKUP_INVERSES"; // The ones beginning with "__" are only used for debugging @@ -903,6 +915,10 @@ class ECCVMFlavor { Commitment transcript_add_y_equal_comm; Commitment transcript_y_collision_check_comm; Commitment transcript_add_lambda_comm; + Commitment transcript_msm_intermediate_x_comm; + Commitment transcript_msm_intermediate_y_comm; + Commitment transcript_msm_infinity_comm; + Commitment transcript_msm_x_inverse_comm; Commitment z_perm_comm; Commitment lookup_inverses_comm; std::vector> sumcheck_univariates; @@ -1100,6 +1116,14 @@ class ECCVMFlavor { NativeTranscript::proof_data, num_frs_read); transcript_add_lambda_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); + transcript_msm_intermediate_x_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_msm_intermediate_y_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_msm_infinity_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); + transcript_msm_x_inverse_comm = NativeTranscript::template deserialize_from_buffer( + NativeTranscript::proof_data, num_frs_read); lookup_inverses_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); z_perm_comm = NativeTranscript::template deserialize_from_buffer(NativeTranscript::proof_data, @@ -1254,6 +1278,13 @@ class ECCVMFlavor { NativeTranscript::template serialize_to_buffer(transcript_y_collision_check_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_add_lambda_comm, NativeTranscript::proof_data); + + NativeTranscript::template serialize_to_buffer(transcript_msm_intermediate_x_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_msm_intermediate_y_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_msm_infinity_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_msm_x_inverse_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(lookup_inverses_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(z_perm_comm, NativeTranscript::proof_data); for (size_t i = 0; i < log_n; ++i) { diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp index 05a5be24034..08821bf76b7 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp @@ -123,6 +123,10 @@ class ECCVMTranscriptTests : public ::testing::Test { manifest_expected.add_entry(round, "TRANSCRIPT_ADD_Y_EQUAL", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_Y_COLLISION_CHECK", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_ADD_LAMBDA", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INTERMEDIATE_X", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INTERMEDIATE_Y", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INFINITY", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_X_INVERSE", frs_per_G); manifest_expected.add_challenge(round, "beta", "gamma"); round++; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp index 95f4e5967d3..8a6e97cb770 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp @@ -109,6 +109,10 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof) commitments.transcript_add_y_equal = receive_commitment(commitment_labels.transcript_add_y_equal); commitments.transcript_y_collision_check = receive_commitment(commitment_labels.transcript_y_collision_check); commitments.transcript_add_lambda = receive_commitment(commitment_labels.transcript_add_lambda); + commitments.transcript_msm_intermediate_x = receive_commitment(commitment_labels.transcript_msm_intermediate_x); + commitments.transcript_msm_intermediate_y = receive_commitment(commitment_labels.transcript_msm_intermediate_y); + commitments.transcript_msm_infinity = receive_commitment(commitment_labels.transcript_msm_infinity); + commitments.transcript_msm_x_inverse = receive_commitment(commitment_labels.transcript_msm_x_inverse); // Get challenge for sorted list batching and wire four memory records auto [beta, gamma] = transcript->template get_challenges("beta", "gamma"); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index 3be74f357aa..2d01476ecd9 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -204,14 +204,15 @@ class ECCVMMSMMBuilder { // accumulator_trace tracks the value of the ECCVM accumulator for each row std::span accumulator_trace(&point_trace[num_point_adds_and_doubles * 3], num_accumulators); - // we start the accumulator at the point at infinity - accumulator_trace[0] = (CycleGroup::affine_point_at_infinity); + // we start the accumulator at the offset generator point. This ensures we can support an MSM that produces a + constexpr auto offset_generator = bb::g1::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; + accumulator_trace[0] = offset_generator; // populate point trace data, and the components of the MSM execution trace that do not relate to affine point // operations run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { - Element accumulator = CycleGroup::affine_point_at_infinity; + Element accumulator = offset_generator; const auto& msm = msms[i]; size_t msm_row_index = msm_row_indices[i]; const size_t msm_size = msm.size(); @@ -247,20 +248,9 @@ class ECCVMMSMMBuilder { ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] : AffineElement{ 0, 0 }; - // predicate logic: - // add_predicate should normally equal add_state.add - // However! if j == 0 AND k == 0 AND m == 0 this implies we are examing the 1st point - // addition of a new MSM In this case, we do NOT add the 1st point into the accumulator, - // instead we SET the accumulator to equal the 1st point. add_predicate is used to - // determine whether we add the output of a point addition into the accumulator, - // therefore if j == 0 AND k == 0 AND m == 0, add_predicate = 0 even if add_state.add = - // true - bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); - - Element p1 = (m == 0) ? Element(add_state.point) : accumulator; - Element p2 = (m == 0) ? accumulator : Element(add_state.point); - - accumulator = add_predicate ? (accumulator + add_state.point) : Element(p1); + Element p1 = accumulator; + Element p2 = Element(add_state.point); + accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); p1_trace[trace_index] = p1; p2_trace[trace_index] = p2; p3_trace[trace_index] = accumulator; @@ -385,20 +375,16 @@ class ECCVMMSMMBuilder { for (size_t k = 0; k < rows_per_round; ++k) { auto& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; - const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; - row.accumulator_x = acc_x; - row.accumulator_y = acc_y; - + ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + row.accumulator_x = normalized_accumulator.x; + row.accumulator_y = normalized_accumulator.y; for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { auto& add_state = row.add_state[m]; - bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); - const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_predicate ? inverse : 0; - add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; + add_state.collision_inverse = add_state.add ? inverse : 0; + add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; trace_index++; } accumulator_index++; @@ -427,15 +413,10 @@ class ECCVMMSMMBuilder { for (size_t k = 0; k < rows_per_round; ++k) { MSMState& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - + ASSERT(normalized_accumulator.is_point_at_infinity() == 0); const size_t idx = k * ADDITIONS_PER_ROW; - - const FF& acc_x = - normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; - const FF& acc_y = - normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; - row.accumulator_x = acc_x; - row.accumulator_y = acc_y; + row.accumulator_x = normalized_accumulator.x; + row.accumulator_y = normalized_accumulator.y; for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { auto& add_state = row.add_state[m]; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index 40b04148b5c..9cce05ae6bc 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -39,12 +39,28 @@ class ECCVMTranscriptBuilder { bool transcript_add_y_equal = false; FF transcript_y_collision_check = 0; FF transcript_add_lambda = 0; + FF transcript_msm_intermediate_x = 0; + FF transcript_msm_intermediate_y = 0; + bool transcript_msm_infinity = false; + FF transcript_msm_x_inverse = 0; }; + + static AffineElement offset_generator() + { + static constexpr auto offset_generator_base = CycleGroup::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; + static const AffineElement result = + AffineElement(Element(offset_generator_base) * grumpkin::fq(uint256_t(1) << 124)); + return result; + } + static AffineElement remove_offset_generator(const AffineElement& other) + { + return AffineElement(Element(other) - offset_generator()); + } struct VMState { uint32_t pc = 0; uint32_t count = 0; - AffineElement accumulator = CycleGroup::affine_point_at_infinity; - AffineElement msm_accumulator = CycleGroup::affine_point_at_infinity; + Element accumulator = CycleGroup::affine_point_at_infinity; + Element msm_accumulator = offset_generator(); bool is_accumulator_empty = true; }; struct Opcode { @@ -68,19 +84,27 @@ class ECCVMTranscriptBuilder { const std::vector>& vm_operations, const uint32_t total_number_of_muls) { const size_t num_transcript_entries = vm_operations.size() + 2; - + const size_t num_vm_entries = vm_operations.size(); std::vector transcript_state(num_transcript_entries); - std::vector inverse_trace(num_transcript_entries - 2); - std::vector inverse_trace_x(num_transcript_entries - 2); - std::vector inverse_trace_y(num_transcript_entries - 2); - std::vector transcript_y_collision_check(num_transcript_entries - 2); - std::vector transcript_add_lambda(num_transcript_entries - 2); + // These vectors track quantities that we need to invert. + // We fill these vectors and then perform batch inversions to amortize the cost of FF inverts + std::vector inverse_trace(num_vm_entries); + std::vector inverse_trace_x(num_vm_entries); + std::vector inverse_trace_y(num_vm_entries); + std::vector transcript_y_collision_check(num_vm_entries); + std::vector transcript_add_lambda(num_vm_entries); + std::vector transcript_msm_x_inverse_trace(num_vm_entries); + std::vector add_lambda_denominator(num_vm_entries); + std::vector add_lambda_numerator(num_vm_entries); + std::vector msm_accumulator_trace(num_vm_entries); + std::vector accumulator_trace(num_vm_entries); + std::vector intermediate_accumulator_trace(num_vm_entries); VMState state{ .pc = total_number_of_muls, .count = 0, .accumulator = CycleGroup::affine_point_at_infinity, - .msm_accumulator = CycleGroup::affine_point_at_infinity, + .msm_accumulator = offset_generator(), .is_accumulator_empty = true, }; VMState updated_state; @@ -99,7 +123,7 @@ class ECCVMTranscriptBuilder { if (entry.reset) { updated_state.is_accumulator_empty = true; - updated_state.msm_accumulator = CycleGroup::affine_point_at_infinity; + updated_state.msm_accumulator = offset_generator(); } updated_state.pc = state.pc - num_muls; @@ -124,12 +148,12 @@ class ECCVMTranscriptBuilder { if (entry.mul && next_not_msm) { if (state.is_accumulator_empty) { - updated_state.accumulator = updated_state.msm_accumulator; + updated_state.accumulator = updated_state.msm_accumulator - offset_generator(); } else { const auto R = typename CycleGroup::element(state.accumulator); - updated_state.accumulator = R + updated_state.msm_accumulator; + updated_state.accumulator = R + updated_state.msm_accumulator - offset_generator(); } - updated_state.is_accumulator_empty = false; + updated_state.is_accumulator_empty = updated_state.accumulator.is_point_at_infinity(); } bool add_accumulate = entry.add; @@ -151,102 +175,140 @@ class ECCVMTranscriptBuilder { row.pc = state.pc; row.msm_count = state.count; auto base_point_infinity = entry.base_point.is_point_at_infinity(); - auto base_point_x = entry.base_point.x; - auto base_point_y = entry.base_point.y; - if ((entry.add || entry.mul || entry.eq) && base_point_infinity) { - base_point_x = 0; - base_point_y = 0; - } - row.base_x = (entry.add || entry.mul || entry.eq) ? base_point_x : 0; - row.base_y = (entry.add || entry.mul || entry.eq) ? base_point_y : 0; + row.base_x = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.x : 0; + row.base_y = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.y : 0; row.base_infinity = (entry.add || entry.mul || entry.eq) ? (base_point_infinity ? 1 : 0) : 0; if (msm_transition) { - auto lhsx = AffineElement(updated_state.msm_accumulator).x; - auto lhsy = AffineElement(updated_state.msm_accumulator).y; - auto rhsx = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; - auto rhsy = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + Element msm_output = updated_state.msm_accumulator - offset_generator(); + row.transcript_msm_infinity = msm_output.is_point_at_infinity(); + } + + row.z1 = (entry.mul) ? entry.z1 : 0; + row.z2 = (entry.mul) ? entry.z2 : 0; + row.z1_zero = z1_zero; + row.z2_zero = z2_zero; + row.opcode = Opcode{ .add = entry.add, .mul = entry.mul, .eq = entry.eq, .reset = entry.reset }.value(); + accumulator_trace[i] = state.accumulator; + msm_accumulator_trace[i] = msm_transition ? updated_state.msm_accumulator : Element::infinity(); + intermediate_accumulator_trace[i] = + msm_transition ? (updated_state.msm_accumulator - offset_generator()) : Element::infinity(); + if (entry.mul && next_not_msm && !row.accumulator_empty) { + state.msm_accumulator = offset_generator(); + } + + state = updated_state; + + if (entry.mul && next_not_msm) { + state.msm_accumulator = offset_generator(); + } + } + Element::batch_normalize(&accumulator_trace[0], accumulator_trace.size()); + Element::batch_normalize(&msm_accumulator_trace[0], msm_accumulator_trace.size()); + Element::batch_normalize(&intermediate_accumulator_trace[0], intermediate_accumulator_trace.size()); + + for (size_t i = 0; i < accumulator_trace.size(); ++i) { + if (!accumulator_trace[i].is_point_at_infinity()) { + transcript_state[i + 1].accumulator_x = accumulator_trace[i].x; + transcript_state[i + 1].accumulator_y = accumulator_trace[i].y; + } + if (!msm_accumulator_trace[i].is_point_at_infinity()) { + transcript_state[i + 1].msm_output_x = msm_accumulator_trace[i].x; + transcript_state[i + 1].msm_output_y = msm_accumulator_trace[i].y; + } + if (!intermediate_accumulator_trace[i].is_point_at_infinity()) { + transcript_state[i + 1].transcript_msm_intermediate_x = intermediate_accumulator_trace[i].x; + transcript_state[i + 1].transcript_msm_intermediate_y = intermediate_accumulator_trace[i].y; + } + } + for (size_t i = 0; i < accumulator_trace.size(); ++i) { + auto& row = transcript_state[i + 1]; + const bool msm_transition = row.msm_transition; + const bool add = row.q_add; + if (msm_transition) { + Element msm_output = intermediate_accumulator_trace[i]; + row.transcript_msm_infinity = msm_output.is_point_at_infinity(); + if (!row.transcript_msm_infinity) { + transcript_msm_x_inverse_trace[i] = (msm_accumulator_trace[i].x - offset_generator().x); + } else { + transcript_msm_x_inverse_trace[i] = 0; + } + auto lhsx = msm_output.x; + auto lhsy = msm_output.y; + auto rhsx = accumulator_trace[i].is_point_at_infinity() ? 0 : accumulator_trace[i].x; + auto rhsy = accumulator_trace[i].is_point_at_infinity() ? (0) : accumulator_trace[i].y; inverse_trace_x[i] = lhsx - rhsx; inverse_trace_y[i] = lhsy - rhsy; - } else if (entry.add) { - auto lhsx = base_point_x; - auto lhsy = base_point_y; - auto rhsx = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; - auto rhsy = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + } else if (add) { + auto lhsx = row.base_x; + auto lhsy = row.base_y; + auto rhsx = accumulator_trace[i].is_point_at_infinity() ? 0 : accumulator_trace[i].x; + auto rhsy = accumulator_trace[i].is_point_at_infinity() ? (0) : accumulator_trace[i].y; inverse_trace_x[i] = lhsx - rhsx; inverse_trace_y[i] = lhsy - rhsy; } else { inverse_trace_x[i] = 0; inverse_trace_y[i] = 0; } - + bool last_row = i == (vm_operations.size() - 1); + // msm transition = current row is doing a lookup to validate output = msm output + // i.e. next row is not part of MSM and current row is part of MSM + // or next row is irrelevent and current row is a straight MUL + bool next_not_msm = last_row ? true : !vm_operations[i + 1].mul; + if (row.q_mul && next_not_msm && !row.accumulator_empty) { + ASSERT((row.msm_output_x != row.accumulator_x) && + "eccvm: attempting msm. Result point x-coordinate matches accumulator x-coordinate."); + inverse_trace[i] = (row.msm_output_x - row.accumulator_x); + } else if (row.q_add && !row.accumulator_empty) { + ASSERT((row.base_x != row.accumulator_x) && + "eccvm: attempting to add points with matching x-coordinates"); + inverse_trace[i] = (row.base_x - row.accumulator_x); + } else { + inverse_trace[i] = (0); + } + const bb::eccvm::VMOperation& entry = vm_operations[i]; if (entry.add || msm_transition) { - auto lhs = entry.add ? entry.base_point : updated_state.msm_accumulator; - auto rhs = state.accumulator; + Element lhs = entry.add ? Element(entry.base_point) : intermediate_accumulator_trace[i]; + Element rhs = accumulator_trace[i]; row.transcript_add_x_equal = lhs.x == rhs.x || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); // check infinity? row.transcript_add_y_equal = lhs.y == rhs.y || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); if (lhs.x == rhs.x && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { - row.transcript_add_lambda = (lhs.x * lhs.x * 3) / (lhs.y * 2); + add_lambda_denominator[i] = lhs.y + lhs.y; + add_lambda_numerator[i] = lhs.x * lhs.x * 3; } else if (!lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { - row.transcript_add_lambda = (rhs.y - lhs.y) / (rhs.x - lhs.x); + add_lambda_denominator[i] = rhs.x - lhs.x; + add_lambda_numerator[i] = rhs.y - lhs.y; } else { - row.transcript_add_lambda = 0; + add_lambda_numerator[i] = 0; + add_lambda_denominator[i] = 0; } } else { row.transcript_add_x_equal = 0; row.transcript_add_y_equal = 0; - row.transcript_add_lambda = 0; - } - - row.z1 = (entry.mul) ? entry.z1 : 0; - row.z2 = (entry.mul) ? entry.z2 : 0; - row.z1_zero = z1_zero; - row.z2_zero = z2_zero; - row.opcode = Opcode{ .add = entry.add, .mul = entry.mul, .eq = entry.eq, .reset = entry.reset }.value(); - row.accumulator_x = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; - row.accumulator_y = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; - row.msm_output_x = - msm_transition - ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.x) - : 0; - row.msm_output_y = - msm_transition - ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.y) - : 0; - - if (entry.mul && next_not_msm && !row.accumulator_empty) { - ASSERT((row.msm_output_x != row.accumulator_x) && - "eccvm: attempting msm. Result point x-coordinate matches accumulator x-coordinate."); - state.msm_accumulator = CycleGroup::affine_point_at_infinity; - inverse_trace[i] = (row.msm_output_x - row.accumulator_x); - } else if (entry.add && !row.accumulator_empty) { - ASSERT((row.base_x != row.accumulator_x) && - "eccvm: attempting to add points with matching x-coordinates"); - inverse_trace[i] = (row.base_x - row.accumulator_x); - } else { - inverse_trace[i] = (0); - } - - state = updated_state; - - if (entry.mul && next_not_msm) { - state.msm_accumulator = CycleGroup::affine_point_at_infinity; + add_lambda_numerator[i] = 0; + add_lambda_denominator[i] = 0; } } - FF::batch_invert(&inverse_trace[0], inverse_trace.size()); FF::batch_invert(&inverse_trace_x[0], inverse_trace.size()); FF::batch_invert(&inverse_trace_y[0], inverse_trace.size()); + FF::batch_invert(&transcript_msm_x_inverse_trace[0], inverse_trace.size()); + FF::batch_invert(&add_lambda_denominator[0], inverse_trace.size()); + for (size_t i = 0; i < inverse_trace.size(); ++i) { transcript_state[i + 1].collision_check = inverse_trace[i]; transcript_state[i + 1].base_x_inverse = inverse_trace_x[i]; transcript_state[i + 1].base_y_inverse = inverse_trace_y[i]; + transcript_state[i + 1].transcript_msm_x_inverse = transcript_msm_x_inverse_trace[i]; + transcript_state[i + 1].transcript_add_lambda = add_lambda_numerator[i] * add_lambda_denominator[i]; } TranscriptState& final_row = transcript_state.back(); final_row.pc = updated_state.pc; - final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x; - final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y; + final_row.accumulator_x = + (updated_state.accumulator.is_point_at_infinity()) ? 0 : AffineElement(updated_state.accumulator).x; + final_row.accumulator_y = + (updated_state.accumulator.is_point_at_infinity()) ? 0 : AffineElement(updated_state.accumulator).y; final_row.accumulator_empty = updated_state.is_accumulator_empty; return transcript_state; } diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp index b71b5a6e4a0..c2c3e97460d 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp @@ -176,7 +176,6 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator auto& selector, auto& relation, auto& collision_relation) { - // (L * (xb - xa) - yb - ya) * s = 0 // L * (1 - s) = 0 // (combine) (L * (xb - xa - 1) - yb - ya) * s + L = 0 relation += selector * (lambda * (xb - xa - 1) - (yb - ya)) + lambda; @@ -189,6 +188,42 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator return std::array{ x_out, y_out }; }; + /** + * @brief First Addition relation + * + * The first add operation per row is treated differently. + * Normally we add the point xa/ya with an accumulator xb/yb, + * BUT, if this row STARTS a multiscalar multiplication, + * We need to add the point xa/ya with the "offset generator point" xo/yo + * The offset generator point's purpose is to ensure that no intermediate computations in the MSM will produce + * points at infinity, for an honest Prover. + * (we ensure soundness by validating the x-coordinates of xa/xb are not the same i.e. incomplete addition formula + * edge cases have not been hit) + * Note: this technique is only statistically complete, as there is a chance of an honest Prover creating a + * collision, but this probability is equivalent to solving the discrete logarithm problem + */ + auto first_add = [&](auto& xb, + auto& yb, + auto& xa, + auto& ya, + auto& lambda, + auto& selector, + auto& relation, + auto& collision_relation) { + constexpr auto offset_generator = bb::g1::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; + constexpr uint256_t oxu = offset_generator.x; + constexpr uint256_t oyu = offset_generator.y; + const Accumulator xo(oxu); + const Accumulator yo(oyu); + + auto x = xo * selector + xb * (-selector + 1); + auto y = yo * selector + yb * (-selector + 1); + relation += lambda * (x - xa) - (y - ya); // degree 3 + collision_relation += (xa - x); + auto x_out = lambda * lambda + (-x - xa); + auto y_out = lambda * (xa - x_out) - ya; + return std::array{ x_out, y_out }; + }; // ADD operations (if row represents ADD round, not SKEW or DOUBLE) Accumulator add_relation(0); Accumulator x1_collision_relation(0); @@ -197,8 +232,8 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator Accumulator x4_collision_relation(0); // If msm_transition = 1, we have started a new MSM. We need to treat the current value of [Acc] as the point at // infinity! - auto add_into_accumulator = -msm_transition + 1; - auto [x_t1, y_t1] = add(acc_x, acc_y, x1, y1, lambda1, add_into_accumulator, add_relation, x1_collision_relation); + // auto add_into_accumulator = -msm_transition + 1; + auto [x_t1, y_t1] = first_add(acc_x, acc_y, x1, y1, lambda1, msm_transition, add_relation, x1_collision_relation); auto [x_t2, y_t2] = add(x2, y2, x_t1, y_t1, lambda2, add2, add_relation, x2_collision_relation); auto [x_t3, y_t3] = add(x3, y3, x_t2, y_t2, lambda3, add3, add_relation, x3_collision_relation); auto [x_t4, y_t4] = add(x4, y4, x_t3, y_t3, lambda4, add4, add_relation, x4_collision_relation); @@ -285,7 +320,7 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator // Check x-coordinates do not collide if row is an ADD row or a SKEW row // if either q_add or q_skew = 1, an inverse should exist for each computed relation // Step 1: construct boolean selectors that describe whether we added a point at the current row - const auto add_first_point = add_into_accumulator * q_add + q_skew * skew1_select; + const auto add_first_point = add1 * q_add + q_skew * skew1_select; const auto add_second_point = add2 * q_add + q_skew * skew2_select; const auto add_third_point = add3 * q_add + q_skew * skew3_select; const auto add_fourth_point = add4 * q_add + q_skew * skew4_select; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 2372c64d81b..6bf2ff6e6c1 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -39,6 +39,16 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu using Accumulator = typename std::tuple_element_t<0, ContainerOverSubrelations>; using View = typename Accumulator::View; + static const auto offset_generator = [&]() { + static constexpr auto offset_generator_base = bb::g1::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; + static bb::g1::affine_element result = + bb::g1::affine_element(bb::g1::element(offset_generator_base) * grumpkin::fq(uint256_t(1) << 124)); + static const FF qx = result.x; + static const FF qy = result.y; + static const Accumulator ox(qx); + static const Accumulator oy(qy); + return std::array{ ox, oy }; + }; auto z1 = View(in.transcript_z1); auto z2 = View(in.transcript_z2); auto z1_zero = View(in.transcript_z1zero); @@ -57,8 +67,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto transcript_accumulator_y_shift = View(in.transcript_accumulator_y_shift); auto transcript_accumulator_x = View(in.transcript_accumulator_x); auto transcript_accumulator_y = View(in.transcript_accumulator_y); - auto transcript_msm_x = View(in.transcript_msm_x); - auto transcript_msm_y = View(in.transcript_msm_y); + auto transcript_msm_x = View(in.transcript_msm_intermediate_x); + auto transcript_msm_y = View(in.transcript_msm_intermediate_y); auto transcript_Px = View(in.transcript_Px); auto transcript_Py = View(in.transcript_Py); auto is_accumulator_empty = View(in.transcript_accumulator_empty); @@ -73,6 +83,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto transcript_add_x_equal = View(in.transcript_add_x_equal); auto transcript_add_y_equal = View(in.transcript_add_y_equal); auto transcript_add_lambda = View(in.transcript_add_lambda); + auto transcript_msm_infinity = View(in.transcript_msm_infinity); auto is_not_first_row = (-lagrange_first + 1); auto is_not_first_or_last_row = (-lagrange_first + -lagrange_last + 1); @@ -218,7 +229,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu // Are the lhs/rhs points at infinity? // MSM output CANNOT be point at infinity without triggering unsatisfiable constraints in msm_relation // lhs can only be at infinity if q_add is active - auto lhs_infinity = transcript_Pinfinity * q_add; + auto lhs_infinity = transcript_Pinfinity * q_add + transcript_msm_infinity * msm_transition; auto rhs_infinity = is_accumulator_empty; // Determine where the group operation output is sourced from // | lhs_infinity | rhs_infinity | lhs_x == rhs_x && lhs_y != rhs_y | output | @@ -320,6 +331,47 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu std::get<30>(accumulator) += transcript_add_x_equal * (transcript_add_x_equal - 1) * scaling_factor; std::get<31>(accumulator) += transcript_add_y_equal * (transcript_add_y_equal - 1) * scaling_factor; std::get<32>(accumulator) += transcript_Pinfinity * (transcript_Pinfinity - 1) * scaling_factor; + + // step 1: subtract offset generator from msm_accumulator + // this might produce a point at infinity + { + const auto offset = offset_generator(); + const auto x1 = offset[0]; + const auto y1 = -offset[1]; + const auto x2 = View(in.transcript_msm_x); + const auto y2 = View(in.transcript_msm_y); + const auto x3 = View(in.transcript_msm_intermediate_x); + const auto y3 = View(in.transcript_msm_intermediate_y); + const auto transcript_msm_infinity = View(in.transcript_msm_infinity); + // cases: + // x2 == x1, y2 == y1 + // x2 != x1 + // (x2 - x1) + const auto x_term = (x3 + x2 + x1) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1); + const auto y_term = (x1 - x3) * (y2 - y1) - (x2 - x1) * (y1 + y3); + // IF msm_infinity = false, transcript_msm_intermediate_x/y is either the result of subtracting offset generator + // from msm_x/y IF msm_infinity = true, transcript_msm_intermediate_x/y is 0 + const auto transcript_offset_generator_subtract_x = + x_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * x3; + const auto transcript_offset_generator_subtract_y = + y_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * y3; + std::get<33>(accumulator) += msm_transition * transcript_offset_generator_subtract_x * scaling_factor; + std::get<34>(accumulator) += msm_transition * transcript_offset_generator_subtract_y * scaling_factor; + + // validate transcript_msm_infinity is correct + // if transcript_msm_infinity = 1, (x2 == x1) and (y2 + y1 == 0) + const auto x_diff = x2 - x1; + const auto y_sum = y2 + y1; + std::get<35>(accumulator) += msm_transition * transcript_msm_infinity * x_diff * scaling_factor; + std::get<36>(accumulator) += msm_transition * transcript_msm_infinity * y_sum * scaling_factor; + // if transcript_msm_infinity = 1, then x_diff must have an inverse + const auto transcript_msm_x_inverse = View(in.transcript_msm_x_inverse); + const auto inverse_term = (-transcript_msm_infinity + 1) * (x_diff * transcript_msm_x_inverse - 1); + std::get<37>(accumulator) += msm_transition * inverse_term * scaling_factor; + } + + // Validate correctness of + {} } template class ECCVMTranscriptRelationImpl; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp index ee2904f82dd..b214f3c941c 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp @@ -30,8 +30,9 @@ template class ECCVMTranscriptRelationImpl { public: using FF = FF_; - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, }; template From f22670edbfe5086011c64dec6538549467d70cf2 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 10:03:49 +0000 Subject: [PATCH 05/24] eccvm: added comprehensive tests to handle points at infinity / point doubling edge cases upgraded eccvm transcript to correctly handle an MSM composed entirely of points at infinity and/or zero scalars --- .../eccvm/eccvm_circuit_builder.hpp | 75 +++++---- .../eccvm/eccvm_circuit_builder.test.cpp | 157 ++++++++++++++++++ .../src/barretenberg/eccvm/eccvm_flavor.hpp | 26 ++- .../eccvm/eccvm_transcript.test.cpp | 2 + .../src/barretenberg/eccvm/eccvm_verifier.cpp | 4 + .../barretenberg/eccvm/transcript_builder.hpp | 59 +++---- .../relations/ecc_vm/ecc_set_relation.cpp | 20 ++- .../relations/ecc_vm/ecc_set_relation.hpp | 4 +- .../ecc_vm/ecc_transcript_relation.cpp | 72 +++++--- .../ecc_vm/ecc_transcript_relation.hpp | 6 +- .../op_queue/ecc_op_queue.hpp | 33 +++- 11 files changed, 347 insertions(+), 111 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp index b295133b12a..8f493da5735 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp @@ -120,14 +120,14 @@ class ECCVMCircuitBuilder { size_t op_idx = 0; for (const auto& op : raw_ops) { if (op.mul) { - if (op.z1 != 0 || op.z2 != 0) { + if ((op.z1 != 0 || op.z2 != 0) && !op.base_point.is_point_at_infinity()) { msm_opqueue_index.push_back(op_idx); msm_mul_index.emplace_back(msm_count, active_mul_count); } - if (op.z1 != 0) { + if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { active_mul_count++; } - if (op.z2 != 0) { + if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { active_mul_count++; } } else if (active_mul_count > 0) { @@ -138,7 +138,7 @@ class ECCVMCircuitBuilder { op_idx++; } // if last op is a mul we have not correctly computed the total number of msms - if (raw_ops.back().mul) { + if (raw_ops.back().mul && active_mul_count > 0) { msm_sizes.push_back(active_mul_count); msm_count++; } @@ -148,39 +148,42 @@ class ECCVMCircuitBuilder { msm.resize(msm_sizes[i]); } - run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - const size_t opqueue_index = msm_opqueue_index[i]; - const auto& op = raw_ops[opqueue_index]; - auto [msm_index, mul_index] = msm_mul_index[i]; - if (op.z1 != 0) { - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); - msms_test[msm_index][mul_index] = (ScalarMul{ - .pc = 0, - .scalar = op.z1, - .base_point = op.base_point, - .wnaf_slices = compute_wnaf_slices(op.z1), - .wnaf_skew = (op.z1 & 1) == 0, - .precomputed_table = compute_precomputed_table(op.base_point), - }); - mul_index++; - } - if (op.z2 != 0) { - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); - auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }; - msms_test[msm_index][mul_index] = (ScalarMul{ - .pc = 0, - .scalar = op.z2, - .base_point = endo_point, - .wnaf_slices = compute_wnaf_slices(op.z2), - .wnaf_skew = (op.z2 & 1) == 0, - .precomputed_table = compute_precomputed_table(endo_point), - }); - } + // run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) { + size_t start = 0; + size_t end = msm_opqueue_index.size(); + for (size_t i = start; i < end; i++) { + const size_t opqueue_index = msm_opqueue_index[i]; + const auto& op = raw_ops[opqueue_index]; + auto [msm_index, mul_index] = msm_mul_index[i]; + if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { + + ASSERT(msms_test.size() > msm_index); + ASSERT(msms_test[msm_index].size() > mul_index); + msms_test[msm_index][mul_index] = (ScalarMul{ + .pc = 0, + .scalar = op.z1, + .base_point = op.base_point, + .wnaf_slices = compute_wnaf_slices(op.z1), + .wnaf_skew = (op.z1 & 1) == 0, + .precomputed_table = compute_precomputed_table(op.base_point), + }); + mul_index++; } - }); + if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { + ASSERT(msms_test.size() > msm_index); + ASSERT(msms_test[msm_index].size() > mul_index); + auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }; + msms_test[msm_index][mul_index] = (ScalarMul{ + .pc = 0, + .scalar = op.z2, + .base_point = endo_point, + .wnaf_slices = compute_wnaf_slices(op.z2), + .wnaf_skew = (op.z2 & 1) == 0, + .precomputed_table = compute_precomputed_table(endo_point), + }); + } + } + // }); // update pc. easier to do this serially but in theory could be optimised out // We start pc at `num_muls` and decrement for each mul processed. diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp index 395fedf0ed4..7f302f9a1e5 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp @@ -18,8 +18,10 @@ TEST(ECCVMCircuitBuilderTests, BaseCase) typename G1::element a = generators[0]; typename G1::element b = generators[1]; typename G1::element c = generators[2]; + typename G1::element point_at_infinity = G1::point_at_infinity; Fr x = Fr::random_element(&engine); Fr y = Fr::random_element(&engine); + Fr zero_scalar = 0; std::shared_ptr op_queue = std::make_shared(); @@ -32,11 +34,33 @@ TEST(ECCVMCircuitBuilderTests, BaseCase) op_queue->eq_and_reset(); op_queue->add_accumulate(c); op_queue->mul_accumulate(a, x); + op_queue->mul_accumulate(point_at_infinity, x); op_queue->mul_accumulate(b, x); op_queue->eq_and_reset(); op_queue->mul_accumulate(a, x); op_queue->mul_accumulate(b, x); + op_queue->mul_accumulate(point_at_infinity, zero_scalar); op_queue->mul_accumulate(c, x); + op_queue->eq_and_reset(); + op_queue->mul_accumulate(point_at_infinity, zero_scalar); + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->mul_accumulate(point_at_infinity, zero_scalar); + op_queue->add_accumulate(a); + op_queue->eq_and_reset(); + op_queue->add_accumulate(a); + op_queue->add_accumulate(point_at_infinity); + op_queue->eq_and_reset(); + op_queue->add_accumulate(point_at_infinity); + op_queue->eq_and_reset(); + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->mul_accumulate(point_at_infinity, -x); + op_queue->eq_and_reset(); + op_queue->add_accumulate(a); + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->mul_accumulate(point_at_infinity, -x); + op_queue->add_accumulate(a); + op_queue->add_accumulate(a); + op_queue->eq_and_reset(); ECCVMCircuitBuilder circuit{ op_queue }; bool result = ECCVMTraceChecker::check(circuit); @@ -72,6 +96,109 @@ TEST(ECCVMCircuitBuilderTests, Mul) EXPECT_EQ(result, true); } +TEST(ECCVMCircuitBuilderTests, MulInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + Fr x = Fr::random_element(&engine); + G1::element b = -a * x; + // G1::affine_element c = G1::affine_point_at_infinity; + op_queue->add_accumulate(b); + op_queue->mul_accumulate(a, x); + // op_queue->eq_and_resetb(c); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +// Validate we do not trigger edge cases of addition formulae when we have identical mul inputs +TEST(ECCVMCircuitBuilderTests, MulOverIdenticalInputs) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + Fr x = Fr::random_element(&engine); + op_queue->mul_accumulate(a, x); + op_queue->mul_accumulate(a, x); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitBuilderTests, MSMProducesInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + Fr x = Fr::random_element(&engine); + op_queue->add_accumulate(a); + op_queue->mul_accumulate(a, x); + op_queue->mul_accumulate(a, -x); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitBuilderTests, MSMOverPointAtInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element point_at_infinity = G1::point_at_infinity; + typename G1::element b = generators[0]; + Fr x = Fr::random_element(&engine); + Fr zero_scalar = 0; + + // validate including points at infinity in a multiscalar multiplication does not effect result + { + op_queue->mul_accumulate(b, x); + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); + } + // validate multiplying a point at infinity by nonzero scalar produces point at infinity + { + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); + } + // validate multiplying a point by zero produces point at infinity + { + op_queue->mul_accumulate(b, zero_scalar); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); + } + // validate multiplying a point at infinity by zero produces a point at infinity + { + op_queue->mul_accumulate(point_at_infinity, zero_scalar); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); + } + // validate an MSM made entirely of points at infinity / zero scalars produces a point at infinity + { + op_queue->mul_accumulate(point_at_infinity, x); + op_queue->mul_accumulate(b, zero_scalar); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); + } +} + TEST(ECCVMCircuitBuilderTests, ShortMul) { std::shared_ptr op_queue = std::make_shared(); @@ -273,6 +400,36 @@ TEST(ECCVMCircuitBuilderTests, AddPointAtInfinity) op_queue->add_accumulate(b); op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitBuilderTests, AddProducesPointAtInfinity) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + + op_queue->add_accumulate(a); + op_queue->add_accumulate(-a); + op_queue->eq_and_reset(); + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitBuilderTests, AddProducesDouble) +{ + std::shared_ptr op_queue = std::make_shared(); + + auto generators = G1::derive_generators("test generators", 3); + typename G1::element a = generators[0]; + + op_queue->add_accumulate(a); + op_queue->add_accumulate(a); + op_queue->eq_and_reset(); ECCVMCircuitBuilder circuit{ op_queue }; bool result = ECCVMTraceChecker::check(circuit); EXPECT_EQ(result, true); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 64742c084d9..074011b887d 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -35,17 +35,17 @@ class ECCVMFlavor { using VerifierCommitmentKey = bb::VerifierCommitmentKey; using RelationSeparator = FF; - static constexpr size_t NUM_WIRES = 85; + static constexpr size_t NUM_WIRES = 87; // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. // Note: this number does not include the individual sorted list polynomials. - static constexpr size_t NUM_ALL_ENTITIES = 116; + static constexpr size_t NUM_ALL_ENTITIES = 118; // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying // assignment of witnesses. We again choose a neutral name. static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; // The total number of witness entities not including shifts. - static constexpr size_t NUM_WITNESS_ENTITIES = 87; + static constexpr size_t NUM_WITNESS_ENTITIES = 89; using GrandProductRelations = std::tuple>; // define the tuple of Relations that comprise the Sumcheck relation @@ -191,7 +191,9 @@ class ECCVMFlavor { transcript_msm_intermediate_x, // column 81 transcript_msm_intermediate_y, // column 82 transcript_msm_infinity, // column 83 - transcript_msm_x_inverse); // column 84 + transcript_msm_x_inverse, + transcript_msm_count_zero_at_transition, + transcript_msm_count_at_transition_inverse); // column 86 }; /** @@ -594,6 +596,8 @@ class ECCVMFlavor { transcript_msm_intermediate_y[i] = transcript_state[i].transcript_msm_intermediate_y; transcript_msm_infinity[i] = transcript_state[i].transcript_msm_infinity; transcript_msm_x_inverse[i] = transcript_state[i].transcript_msm_x_inverse; + transcript_msm_count_zero_at_transition[i] = transcript_state[i].msm_count_zero_at_transition; + transcript_msm_count_at_transition_inverse[i] = transcript_state[i].msm_count_at_transition_inverse; } }); @@ -808,6 +812,8 @@ class ECCVMFlavor { Base::transcript_msm_intermediate_y = "TRANSCRIPT_MSM_INTERMEDIATE_Y"; Base::transcript_msm_infinity = "TRANSCRIPT_MSM_INFINITY"; Base::transcript_msm_x_inverse = "TRANSCRIPT_MSM_X_INVERSE"; + Base::transcript_msm_count_zero_at_transition = "TRANSCRIPT_MSM_COUNT_ZERO_AT_TRANSITION"; + Base::transcript_msm_count_at_transition_inverse = "TRANSCRIPT_MSM_COUNT_AT_TRANSITION_INVERSE"; Base::z_perm = "Z_PERM"; Base::lookup_inverses = "LOOKUP_INVERSES"; // The ones beginning with "__" are only used for debugging @@ -919,6 +925,8 @@ class ECCVMFlavor { Commitment transcript_msm_intermediate_y_comm; Commitment transcript_msm_infinity_comm; Commitment transcript_msm_x_inverse_comm; + Commitment transcript_msm_count_zero_at_transition_comm; + Commitment transcript_msm_count_at_transition_inverse_comm; Commitment z_perm_comm; Commitment lookup_inverses_comm; std::vector> sumcheck_univariates; @@ -1124,6 +1132,12 @@ class ECCVMFlavor { NativeTranscript::proof_data, num_frs_read); transcript_msm_x_inverse_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); + transcript_msm_count_zero_at_transition_comm = + NativeTranscript::template deserialize_from_buffer(NativeTranscript::proof_data, + num_frs_read); + transcript_msm_count_at_transition_inverse_comm = + NativeTranscript::template deserialize_from_buffer(NativeTranscript::proof_data, + num_frs_read); lookup_inverses_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); z_perm_comm = NativeTranscript::template deserialize_from_buffer(NativeTranscript::proof_data, @@ -1285,6 +1299,10 @@ class ECCVMFlavor { NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_msm_infinity_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_msm_x_inverse_comm, NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_msm_count_zero_at_transition_comm, + NativeTranscript::proof_data); + NativeTranscript::template serialize_to_buffer(transcript_msm_count_at_transition_inverse_comm, + NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(lookup_inverses_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(z_perm_comm, NativeTranscript::proof_data); for (size_t i = 0; i < log_n; ++i) { diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp index 08821bf76b7..720567d92c5 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp @@ -127,6 +127,8 @@ class ECCVMTranscriptTests : public ::testing::Test { manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INTERMEDIATE_Y", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INFINITY", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_X_INVERSE", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_COUNT_ZERO_AT_TRANSITION", frs_per_G); + manifest_expected.add_entry(round, "TRANSCRIPT_MSM_COUNT_AT_TRANSITION_INVERSE", frs_per_G); manifest_expected.add_challenge(round, "beta", "gamma"); round++; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp index 8a6e97cb770..884e0af3cfc 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp @@ -113,6 +113,10 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof) commitments.transcript_msm_intermediate_y = receive_commitment(commitment_labels.transcript_msm_intermediate_y); commitments.transcript_msm_infinity = receive_commitment(commitment_labels.transcript_msm_infinity); commitments.transcript_msm_x_inverse = receive_commitment(commitment_labels.transcript_msm_x_inverse); + commitments.transcript_msm_count_zero_at_transition = + receive_commitment(commitment_labels.transcript_msm_count_zero_at_transition); + commitments.transcript_msm_count_at_transition_inverse = + receive_commitment(commitment_labels.transcript_msm_count_at_transition_inverse); // Get challenge for sorted list batching and wire four memory records auto [beta, gamma] = transcript->template get_challenges("beta", "gamma"); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index 9cce05ae6bc..a95806f671e 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -43,6 +43,8 @@ class ECCVMTranscriptBuilder { FF transcript_msm_intermediate_y = 0; bool transcript_msm_infinity = false; FF transcript_msm_x_inverse = 0; + bool msm_count_zero_at_transition = false; + FF msm_count_at_transition_inverse = 0; }; static AffineElement offset_generator() @@ -89,7 +91,6 @@ class ECCVMTranscriptBuilder { // These vectors track quantities that we need to invert. // We fill these vectors and then perform batch inversions to amortize the cost of FF inverts - std::vector inverse_trace(num_vm_entries); std::vector inverse_trace_x(num_vm_entries); std::vector inverse_trace_y(num_vm_entries); std::vector transcript_y_collision_check(num_vm_entries); @@ -117,12 +118,19 @@ class ECCVMTranscriptBuilder { const bool is_mul = entry.mul; const bool z1_zero = (entry.mul) ? entry.z1 == 0 : true; const bool z2_zero = (entry.mul) ? entry.z2 == 0 : true; - const uint32_t num_muls = is_mul ? (static_cast(!z1_zero) + static_cast(!z2_zero)) : 0; - + uint32_t num_muls = 0; + auto base_point_infinity = entry.base_point.is_point_at_infinity(); + if (is_mul) { + num_muls = static_cast(!z1_zero) + static_cast(!z2_zero); + if (base_point_infinity) { + num_muls = 0; + } + } updated_state = state; if (entry.reset) { updated_state.is_accumulator_empty = true; + updated_state.accumulator = CycleGroup::point_at_infinity; updated_state.msm_accumulator = offset_generator(); } updated_state.pc = state.pc - num_muls; @@ -133,20 +141,20 @@ class ECCVMTranscriptBuilder { // or next row is irrelevent and current row is a straight MUL bool next_not_msm = last_row ? true : !vm_operations[i + 1].mul; - bool msm_transition = entry.mul && next_not_msm; + bool msm_transition = entry.mul && next_not_msm && (state.count + num_muls > 0); // we reset the count in updated state if we are not accumulating and not doing an msm bool current_msm = entry.mul; bool current_ongoing_msm = entry.mul && !next_not_msm; updated_state.count = current_ongoing_msm ? state.count + num_muls : 0; - if (current_msm) { const auto P = typename CycleGroup::element(entry.base_point); const auto R = typename CycleGroup::element(state.msm_accumulator); updated_state.msm_accumulator = R + P * entry.mul_scalar_full; } - if (entry.mul && next_not_msm) { + // TODO IF FAKE TRANSITION FIGURE OUT WHAT TO DO WITH ACCUMULATORS BLAH BLAH BLAH + if (msm_transition) { if (state.is_accumulator_empty) { updated_state.accumulator = updated_state.msm_accumulator - offset_generator(); } else { @@ -159,7 +167,6 @@ class ECCVMTranscriptBuilder { bool add_accumulate = entry.add; if (add_accumulate) { if (state.is_accumulator_empty) { - updated_state.accumulator = entry.base_point; } else { updated_state.accumulator = typename CycleGroup::element(state.accumulator) + entry.base_point; @@ -174,7 +181,9 @@ class ECCVMTranscriptBuilder { row.msm_transition = msm_transition; row.pc = state.pc; row.msm_count = state.count; - auto base_point_infinity = entry.base_point.is_point_at_infinity(); + row.msm_count_zero_at_transition = ((state.count + num_muls) == 0) && (entry.mul && next_not_msm); + row.msm_count_at_transition_inverse = + ((state.count + num_muls) == 0) ? 0 : FF(state.count + num_muls).invert(); // TODO BATCH row.base_x = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.x : 0; row.base_y = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.y : 0; row.base_infinity = (entry.add || entry.mul || entry.eq) ? (base_point_infinity ? 1 : 0) : 0; @@ -232,8 +241,8 @@ class ECCVMTranscriptBuilder { } else { transcript_msm_x_inverse_trace[i] = 0; } - auto lhsx = msm_output.x; - auto lhsy = msm_output.y; + auto lhsx = msm_output.is_point_at_infinity() ? 0 : msm_output.x; + auto lhsy = msm_output.is_point_at_infinity() ? 0 : msm_output.y; auto rhsx = accumulator_trace[i].is_point_at_infinity() ? 0 : accumulator_trace[i].x; auto rhsy = accumulator_trace[i].is_point_at_infinity() ? (0) : accumulator_trace[i].y; inverse_trace_x[i] = lhsx - rhsx; @@ -249,22 +258,9 @@ class ECCVMTranscriptBuilder { inverse_trace_x[i] = 0; inverse_trace_y[i] = 0; } - bool last_row = i == (vm_operations.size() - 1); // msm transition = current row is doing a lookup to validate output = msm output // i.e. next row is not part of MSM and current row is part of MSM // or next row is irrelevent and current row is a straight MUL - bool next_not_msm = last_row ? true : !vm_operations[i + 1].mul; - if (row.q_mul && next_not_msm && !row.accumulator_empty) { - ASSERT((row.msm_output_x != row.accumulator_x) && - "eccvm: attempting msm. Result point x-coordinate matches accumulator x-coordinate."); - inverse_trace[i] = (row.msm_output_x - row.accumulator_x); - } else if (row.q_add && !row.accumulator_empty) { - ASSERT((row.base_x != row.accumulator_x) && - "eccvm: attempting to add points with matching x-coordinates"); - inverse_trace[i] = (row.base_x - row.accumulator_x); - } else { - inverse_trace[i] = (0); - } const bb::eccvm::VMOperation& entry = vm_operations[i]; if (entry.add || msm_transition) { Element lhs = entry.add ? Element(entry.base_point) : intermediate_accumulator_trace[i]; @@ -273,10 +269,11 @@ class ECCVMTranscriptBuilder { lhs.x == rhs.x || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); // check infinity? row.transcript_add_y_equal = lhs.y == rhs.y || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); - if (lhs.x == rhs.x && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { + if ((lhs.x == rhs.x) && (lhs.y == rhs.y) && !lhs.is_point_at_infinity() && + !rhs.is_point_at_infinity()) { add_lambda_denominator[i] = lhs.y + lhs.y; add_lambda_numerator[i] = lhs.x * lhs.x * 3; - } else if (!lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { + } else if ((lhs.x != rhs.x) && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { add_lambda_denominator[i] = rhs.x - lhs.x; add_lambda_numerator[i] = rhs.y - lhs.y; } else { @@ -290,14 +287,12 @@ class ECCVMTranscriptBuilder { add_lambda_denominator[i] = 0; } } - FF::batch_invert(&inverse_trace[0], inverse_trace.size()); - FF::batch_invert(&inverse_trace_x[0], inverse_trace.size()); - FF::batch_invert(&inverse_trace_y[0], inverse_trace.size()); - FF::batch_invert(&transcript_msm_x_inverse_trace[0], inverse_trace.size()); - FF::batch_invert(&add_lambda_denominator[0], inverse_trace.size()); + FF::batch_invert(&inverse_trace_x[0], num_vm_entries); + FF::batch_invert(&inverse_trace_y[0], num_vm_entries); + FF::batch_invert(&transcript_msm_x_inverse_trace[0], num_vm_entries); + FF::batch_invert(&add_lambda_denominator[0], num_vm_entries); - for (size_t i = 0; i < inverse_trace.size(); ++i) { - transcript_state[i + 1].collision_check = inverse_trace[i]; + for (size_t i = 0; i < num_vm_entries; ++i) { transcript_state[i + 1].base_x_inverse = inverse_trace_x[i]; transcript_state[i + 1].base_y_inverse = inverse_trace_y[i]; transcript_state[i + 1].transcript_msm_x_inverse = transcript_msm_x_inverse_trace[i]; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp index 852ceded699..5d9a7c04501 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp @@ -294,6 +294,7 @@ Accumulator ECCVMSetRelationImpl::compute_permutation_denominator(const AllE auto z2 = View(in.transcript_z2); auto z1_zero = View(in.transcript_z1zero); auto z2_zero = View(in.transcript_z2zero); + auto base_infinity = View(in.transcript_base_infinity); auto transcript_mul = View(in.transcript_mul); auto lookup_first = (-z1_zero + 1); @@ -312,10 +313,12 @@ Accumulator ECCVMSetRelationImpl::compute_permutation_denominator(const AllE // | 1 | 1 | 1 | (X + gamma)(Y + gamma) | transcript_input1 = (transcript_input1 + gamma) * lookup_first + (-lookup_first + 1); transcript_input2 = (transcript_input2 + gamma) * lookup_second + (-lookup_second + 1); - // point_table_init_write = degree 2 + // transcript_product = degree 3 + auto transcript_product = (transcript_input1 * transcript_input2) * (-base_infinity + 1) + base_infinity; - auto point_table_init_write = transcript_mul * transcript_input1 * transcript_input2 + (-transcript_mul + 1); - denominator *= point_table_init_write; // degree-13 + // point_table_init_write = degree 4 + auto point_table_init_write = transcript_mul * transcript_product + (-transcript_mul + 1); + denominator *= point_table_init_write; // degree-14 // auto point_table_init_write_1 = transcript_mul * transcript_input1 + (-transcript_mul + 1); // denominator *= point_table_init_write_1; // degree-11 @@ -339,15 +342,20 @@ Accumulator ECCVMSetRelationImpl::compute_permutation_denominator(const AllE auto z1_zero = View(in.transcript_z1zero); auto z2_zero = View(in.transcript_z2zero); auto transcript_mul = View(in.transcript_mul); + auto base_infinity = View(in.transcript_base_infinity); + // auto transcript_msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); - auto full_msm_count = transcript_msm_count + transcript_mul * ((-z1_zero + 1) + (-z2_zero + 1)); + // do not add to count if point at infinity! + auto full_msm_count = + transcript_msm_count + transcript_mul * ((-z1_zero + 1) + (-z2_zero + 1)) * (-base_infinity + 1); // auto count_test = transcript_msm_count // msm_result_read = degree 2 auto msm_result_read = transcript_pc_shift + transcript_msm_x * beta + transcript_msm_y * beta_sqr + full_msm_count * beta_cube; - + // N.B. NOT COUNT ZERO NOT NEEDED IS FACTORED INTO MSM TRANSITION + // auto read_active = transcript_msm_transition; msm_result_read = transcript_msm_transition * (msm_result_read + gamma) + (-transcript_msm_transition + 1); - denominator *= msm_result_read; // degree-17 + denominator *= msm_result_read; // degree-20 } return denominator; } diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.hpp index 41043a88134..550f2474f7d 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.hpp @@ -14,8 +14,8 @@ template class ECCVMSetRelationImpl { using FF = FF_; static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 19, // grand product construction sub-relation - 19 // left-shiftable polynomial sub-relation + 21, // grand product construction sub-relation + 21 // left-shiftable polynomial sub-relation }; template static Accumulator convert_to_wnaf(const auto& s0, const auto& s1) diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 6bf2ff6e6c1..4941feaece0 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -124,17 +124,32 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * @note pc starts out at its max value and decrements down to 0. This keeps the degree of the pc polynomial smol */ Accumulator pc_delta = pc - pc_shift; - std::get<3>(accumulator) += - is_not_first_row * (pc_delta - q_mul * ((-z1_zero + 1) + (-z2_zero + 1))) * scaling_factor; + auto num_muls_in_row = ((-z1_zero + 1) + (-z2_zero + 1)) * (-transcript_Pinfinity + 1); + std::get<3>(accumulator) += is_not_first_row * (pc_delta - q_mul * num_muls_in_row) * scaling_factor; /** * @brief Validate `msm_transition` is well-formed. * * If the current row is the last mul instruction in a multiscalar multiplication, msm_transition = 1. * i.e. if q_mul == 1 and q_mul_shift == 0, msm_transition = 1, else is 0 + * We also require that `msm_count + [current msm number] > 0` */ auto msm_transition_check = q_mul * (-q_mul_shift + 1); - std::get<4>(accumulator) += (msm_transition - msm_transition_check) * scaling_factor; + // auto num_muls_total = msm_count + num_muls_in_row; + auto msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); + auto msm_count_at_transition_inverse = View(in.transcript_msm_count_at_transition_inverse); + + auto msm_count_total = msm_count + num_muls_in_row; // degree 3 + + auto msm_count_zero_at_transition_check = msm_count_zero_at_transition * msm_count_total; + msm_count_zero_at_transition_check += + (msm_count_total * msm_count_at_transition_inverse - 1) * (-msm_count_zero_at_transition + 1); + std::get<40>(accumulator) += msm_transition_check * msm_count_zero_at_transition_check * scaling_factor; + + // Validate msm_transition_msm_count is correct + // ensure msm_transition is zero if count is zero + std::get<4>(accumulator) += + (msm_transition - msm_transition_check * (-msm_count_zero_at_transition + 1)) * scaling_factor; /** * @brief Validate `msm_count` resets when we end a multiscalar multiplication. @@ -150,8 +165,9 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * row). */ auto msm_count_delta = msm_count_shift - msm_count; // degree 4 - std::get<6>(accumulator) += is_not_first_row * (-msm_transition + 1) * - (msm_count_delta - q_mul * ((-z1_zero + 1) + (-z2_zero + 1))) * scaling_factor; + auto num_counts = ((-z1_zero + 1) + (-z2_zero + 1)) * (-transcript_Pinfinity + 1); + std::get<6>(accumulator) += + is_not_first_row * (-msm_transition + 1) * (msm_count_delta - q_mul * (num_counts)) * scaling_factor; /** * @brief Opcode exclusion tests. We have the following assertions: @@ -241,11 +257,11 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu // | 1 | 1 | n/a | infinity | auto add_result_is_lhs = rhs_infinity * (-lhs_infinity + 1); // degree 3 auto add_result_is_rhs = lhs_infinity * (-rhs_infinity + 1); // degree 3 - auto add_result_is_out = (-lhs_infinity + 1) * (-rhs_infinity + 1); // degree 3 auto add_result_infinity_from_inputs = lhs_infinity * rhs_infinity; // degree 2 auto add_result_infinity_from_operation = transcript_add_x_equal * (-transcript_add_y_equal + 1); // degree 2 auto add_result_is_infinity = add_result_infinity_from_inputs + add_result_infinity_from_operation; // degree 2?? + auto lambda_relation_valid = (-lhs_infinity + 1) * (-rhs_infinity + 1) * (-add_result_is_infinity + 1); // degree 4 // Determine the gradient `lambda` of the group operation // If lhs_x == rhs_x, lambda = (3 * lhs_x * lhs_x) / (2 * lhs_y) // Else, lambda = (rhs_y - lhs_y) / (rhs_x - lhs_x) @@ -254,29 +270,30 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto lambda_denominator = (rhs_x - lhs_x) * ecc_op_is_add + (lhs_y + lhs_y) * ecc_op_is_dbl; // degree 3 auto lambda_term = lambda_denominator * transcript_add_lambda - lambda_numerator; // degree 4 // We only activate lambda relation if we don't have points at infinity - this is to avoid divide-by-zero problems - // N.B. check this is needed + // N.B. check this is need auto any_add_is_active = q_add + msm_transition; - auto lambda_relation_active = any_add_is_active * add_result_is_out; // degree 4 - auto lambda_relation = lambda_term * lambda_relation_active; // degree 8! - std::get<16>(accumulator) += lambda_relation * scaling_factor; // degree 8 + auto lambda_relation_active = any_add_is_active * lambda_relation_valid; // degree 5 + auto lambda_relation = lambda_term * lambda_relation_active; // degree 9! + // if lambda relation is not active, assert lambda = 0 + lambda_relation += (-lambda_relation_active + 1) * transcript_add_lambda; + std::get<16>(accumulator) += lambda_relation * scaling_factor; // degree 9 // Determine the x/y coordinates of the shifted accumulator // add_x3/add_y3 = result of group operation computation auto add_x3 = transcript_add_lambda * transcript_add_lambda - lhs_x - rhs_x; // degree 2 - auto add_y3 = transcript_add_lambda * (lhs_x - add_x3) - lhs_y; // degree 3 - // x3/y3 = result of group operation computation that considers input points at infinity - auto x3 = (add_x3 * add_result_is_out + lhs_x * add_result_is_lhs + rhs_x * add_result_is_rhs); // degree 5 - auto y3 = (add_y3 * add_result_is_out + lhs_y * add_result_is_lhs + rhs_y * add_result_is_rhs); // degree 6 + add_x3 += (lhs_x + lhs_x + rhs_x) * add_result_is_lhs; + add_x3 += (rhs_x + rhs_x + lhs_x) * add_result_is_rhs; + add_x3 += (lhs_x + rhs_x) * add_result_is_infinity; + auto add_y3 = transcript_add_lambda * (lhs_x - add_x3) - lhs_y; // degree 3 + add_y3 += (lhs_y + lhs_y) * add_result_is_lhs; + add_y3 += (lhs_y + rhs_y) * add_result_is_rhs; + add_y3 += (lhs_y)*add_result_is_infinity; auto propagate_transcript_accumulator = (-q_add - msm_transition - q_reset_accumulator + 1); - auto add_point_x_relation = - (x3 - transcript_accumulator_x_shift * (add_result_is_out + add_result_is_lhs + add_result_is_rhs)) * - any_add_is_active; // degree 7 + auto add_point_x_relation = (add_x3 - transcript_accumulator_x_shift) * any_add_is_active; // degree 7 add_point_x_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * (transcript_accumulator_x_shift - transcript_accumulator_x); - auto add_point_y_relation = - (y3 - transcript_accumulator_y_shift * (add_result_is_out + add_result_is_lhs + add_result_is_rhs)) * - any_add_is_active; // degree 7 + auto add_point_y_relation = (add_y3 - transcript_accumulator_y_shift) * any_add_is_active; // degree 7 add_point_y_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * (transcript_accumulator_y_shift - transcript_accumulator_y); std::get<17>(accumulator) += add_point_x_relation * scaling_factor; // degree 7 @@ -308,6 +325,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto transcript_add_x_equal_check_relation = (x_diff * x_product + x_constant) * any_add_is_active; std::get<20>(accumulator) += transcript_add_x_equal_check_relation * scaling_factor; // degree 6 + // TODO: IF MUL PRODUCES 0 POINTS DUE TO Z1=0, Z2=0 OR POINTS AT INFINITY, ENSURE THAT MSM_OUTPUT IS ALWAYS POINT AT + // INFINITY /** * @brief Validate `transcript_add_y_equal` is well-formed * If lhs_y == rhs_y, transcript_add_y_equal = 1 @@ -331,7 +350,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu std::get<30>(accumulator) += transcript_add_x_equal * (transcript_add_x_equal - 1) * scaling_factor; std::get<31>(accumulator) += transcript_add_y_equal * (transcript_add_y_equal - 1) * scaling_factor; std::get<32>(accumulator) += transcript_Pinfinity * (transcript_Pinfinity - 1) * scaling_factor; - + std::get<33>(accumulator) += transcript_msm_infinity * (transcript_msm_infinity - 1) * scaling_factor; + std::get<39>(accumulator) += msm_count_zero_at_transition * (msm_count_zero_at_transition - 1) * scaling_factor; // step 1: subtract offset generator from msm_accumulator // this might produce a point at infinity { @@ -355,19 +375,19 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu x_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * x3; const auto transcript_offset_generator_subtract_y = y_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * y3; - std::get<33>(accumulator) += msm_transition * transcript_offset_generator_subtract_x * scaling_factor; - std::get<34>(accumulator) += msm_transition * transcript_offset_generator_subtract_y * scaling_factor; + std::get<34>(accumulator) += msm_transition * transcript_offset_generator_subtract_x * scaling_factor; + std::get<35>(accumulator) += msm_transition * transcript_offset_generator_subtract_y * scaling_factor; // validate transcript_msm_infinity is correct // if transcript_msm_infinity = 1, (x2 == x1) and (y2 + y1 == 0) const auto x_diff = x2 - x1; const auto y_sum = y2 + y1; - std::get<35>(accumulator) += msm_transition * transcript_msm_infinity * x_diff * scaling_factor; - std::get<36>(accumulator) += msm_transition * transcript_msm_infinity * y_sum * scaling_factor; + std::get<36>(accumulator) += msm_transition * transcript_msm_infinity * x_diff * scaling_factor; + std::get<37>(accumulator) += msm_transition * transcript_msm_infinity * y_sum * scaling_factor; // if transcript_msm_infinity = 1, then x_diff must have an inverse const auto transcript_msm_x_inverse = View(in.transcript_msm_x_inverse); const auto inverse_term = (-transcript_msm_infinity + 1) * (x_diff * transcript_msm_x_inverse - 1); - std::get<37>(accumulator) += msm_transition * inverse_term * scaling_factor; + std::get<38>(accumulator) += msm_transition * inverse_term * scaling_factor; } // Validate correctness of diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp index b214f3c941c..448db8613f2 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp @@ -30,9 +30,9 @@ template class ECCVMTranscriptRelationImpl { public: using FF = FF_; - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, }; template diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index 4ef2ef12ef8..efc17614cca 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -372,6 +372,35 @@ class ECCOpQueue { return ultra_op; } + /** + * @brief Write equality op using internal accumulator point + * + * @return current internal accumulator point (prior to reset to 0) + */ + UltraOp eq_and_resetb(Point& expected) + { + accumulator.self_set_infinity(); + + // Construct and store the operation in the ultra op format + auto ultra_op = construct_and_populate_ultra_ops(EQUALITY, expected); + + // Store raw operation + raw_ops.emplace_back(ECCVMOperation{ + .add = false, + .mul = false, + .eq = true, + .reset = true, + .base_point = expected, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + num_transcript_rows += 1; + update_cached_msms(raw_ops.back()); + + return ultra_op; + } + /** * @brief Write equality op using internal accumulator point * @@ -411,10 +440,10 @@ class ECCOpQueue { void update_cached_msms(const ECCVMOperation& op) { if (op.mul) { - if (op.z1 != 0) { + if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { cached_active_msm_count++; } - if (op.z2 != 0) { + if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { cached_active_msm_count++; } } else if (cached_active_msm_count != 0) { From e435687c08662cb0c722a5153cd8df1d2d62a0ac Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 12:36:10 +0000 Subject: [PATCH 06/24] tidied eccvm code moved eccvm boolean relations into low degree relation class, added missing boolean checks removed unused transcript columns --- .../relations_bench/relations.bench.cpp | 1 + .../eccvm/eccvm_circuit_builder.test.cpp | 13 + .../src/barretenberg/eccvm/eccvm_flavor.hpp | 197 ++++----- .../eccvm/eccvm_trace_checker.cpp | 1 + .../eccvm/eccvm_transcript.test.cpp | 2 - .../src/barretenberg/eccvm/eccvm_verifier.cpp | 2 - .../barretenberg/eccvm/transcript_builder.hpp | 11 +- .../relations/ecc_vm/ecc_bools_relation.cpp | 90 ++++ .../relations/ecc_vm/ecc_bools_relation.hpp | 33 ++ .../ecc_vm/ecc_transcript_relation.cpp | 392 ++++++++++-------- .../ecc_vm/ecc_transcript_relation.hpp | 7 +- .../op_queue/ecc_op_queue.hpp | 27 ++ 12 files changed, 477 insertions(+), 299 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp create mode 100644 barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp diff --git a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp index 48959e431e9..837764376b3 100644 --- a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp +++ b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp @@ -53,6 +53,7 @@ BENCHMARK(execute_relation>); BENCHMARK(execute_relation>); BENCHMARK(execute_relation>); BENCHMARK(execute_relation>); +BENCHMARK(execute_relation>); } // namespace bb::benchmark::relations diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp index 7f302f9a1e5..e7393974f80 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp @@ -31,6 +31,8 @@ TEST(ECCVMCircuitBuilderTests, BaseCase) op_queue->mul_accumulate(b, y); op_queue->add_accumulate(a); op_queue->mul_accumulate(b, x); + op_queue->no_op(); + op_queue->add_accumulate(b); op_queue->eq_and_reset(); op_queue->add_accumulate(c); op_queue->mul_accumulate(a, x); @@ -67,6 +69,17 @@ TEST(ECCVMCircuitBuilderTests, BaseCase) EXPECT_EQ(result, true); } +TEST(ECCVMCircuitBuilderTests, NoOp) +{ + std::shared_ptr op_queue = std::make_shared(); + + op_queue->no_op(); + + ECCVMCircuitBuilder circuit{ op_queue }; + bool result = ECCVMTraceChecker::check(circuit); + EXPECT_EQ(result, true); +} + TEST(ECCVMCircuitBuilderTests, Add) { std::shared_ptr op_queue = std::make_shared(); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 074011b887d..862723a9fc0 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -8,6 +8,7 @@ #include "barretenberg/flavor/flavor_macros.hpp" #include "barretenberg/flavor/relation_definitions.hpp" #include "barretenberg/polynomials/univariate.hpp" +#include "barretenberg/relations/ecc_vm/ecc_bools_relation.hpp" #include "barretenberg/relations/ecc_vm/ecc_lookup_relation.hpp" #include "barretenberg/relations/ecc_vm/ecc_msm_relation.hpp" #include "barretenberg/relations/ecc_vm/ecc_point_table_relation.hpp" @@ -35,17 +36,17 @@ class ECCVMFlavor { using VerifierCommitmentKey = bb::VerifierCommitmentKey; using RelationSeparator = FF; - static constexpr size_t NUM_WIRES = 87; + static constexpr size_t NUM_WIRES = 85; // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. // Note: this number does not include the individual sorted list polynomials. - static constexpr size_t NUM_ALL_ENTITIES = 118; + static constexpr size_t NUM_ALL_ENTITIES = 116; // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying // assignment of witnesses. We again choose a neutral name. static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; // The total number of witness entities not including shifts. - static constexpr size_t NUM_WITNESS_ENTITIES = 89; + static constexpr size_t NUM_WITNESS_ENTITIES = 87; using GrandProductRelations = std::tuple>; // define the tuple of Relations that comprise the Sumcheck relation @@ -54,7 +55,8 @@ class ECCVMFlavor { ECCVMWnafRelation, ECCVMMSMRelation, ECCVMSetRelation, - ECCVMLookupRelation>; + ECCVMLookupRelation, + ECCVMBoolsRelation>; using LookupRelation = ECCVMLookupRelation; static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); @@ -107,93 +109,91 @@ class ECCVMFlavor { template class WireEntities { public: DEFINE_FLAVOR_MEMBERS(DataType, - transcript_add, // column 0 - transcript_mul, // column 1 - transcript_eq, // column 2 - transcript_collision_check, // column 3 - transcript_msm_transition, // column 4 - transcript_pc, // column 5 - transcript_msm_count, // column 6 - transcript_Px, // column 7 - transcript_Py, // column 8 - transcript_z1, // column 9 - transcript_z2, // column 10 - transcript_z1zero, // column 11 - transcript_z2zero, // column 12 - transcript_op, // column 13 - transcript_accumulator_x, // column 14 - transcript_accumulator_y, // column 15 - transcript_msm_x, // column 16 - transcript_msm_y, // column 17 - precompute_pc, // column 18 - precompute_point_transition, // column 19 - precompute_round, // column 20 - precompute_scalar_sum, // column 21 - precompute_s1hi, // column 22 - precompute_s1lo, // column 23 - precompute_s2hi, // column 24 - precompute_s2lo, // column 25 - precompute_s3hi, // column 26 - precompute_s3lo, // column 27 - precompute_s4hi, // column 28 - precompute_s4lo, // column 29 - precompute_skew, // column 30 - precompute_dx, // column 31 - precompute_dy, // column 32 - precompute_tx, // column 33 - precompute_ty, // column 34 - msm_transition, // column 35 - msm_add, // column 36 - msm_double, // column 37 - msm_skew, // column 38 - msm_accumulator_x, // column 39 - msm_accumulator_y, // column 40 - msm_pc, // column 41 - msm_size_of_msm, // column 42 - msm_count, // column 43 - msm_round, // column 44 - msm_add1, // column 45 - msm_add2, // column 46 - msm_add3, // column 47 - msm_add4, // column 48 - msm_x1, // column 49 - msm_y1, // column 50 - msm_x2, // column 51 - msm_y2, // column 52 - msm_x3, // column 53 - msm_y3, // column 54 - msm_x4, // column 55 - msm_y4, // column 56 - msm_collision_x1, // column 57 - msm_collision_x2, // column 58 - msm_collision_x3, // column 59 - msm_collision_x4, // column 60 - msm_lambda1, // column 61 - msm_lambda2, // column 62 - msm_lambda3, // column 63 - msm_lambda4, // column 64 - msm_slice1, // column 65 - msm_slice2, // column 66 - msm_slice3, // column 67 - msm_slice4, // column 68 - transcript_accumulator_empty, // column 69 - transcript_reset_accumulator, // column 70 - precompute_select, // column 71 - lookup_read_counts_0, // column 72 - lookup_read_counts_1, // column 73 - transcript_base_infinity, // column 74 - transcript_base_x_inverse, // column 75 - transcript_base_y_inverse, // column 76 - transcript_add_x_equal, // column 77 - transcript_add_y_equal, // column 78 - transcript_y_collision_check, // column 79 - transcript_add_lambda, // column 80 - transcript_msm_intermediate_x, // column 81 - transcript_msm_intermediate_y, // column 82 - transcript_msm_infinity, // column 83 - transcript_msm_x_inverse, - transcript_msm_count_zero_at_transition, - transcript_msm_count_at_transition_inverse); // column 86 + transcript_add, // column 0 + transcript_mul, // column 1 + transcript_eq, // column 2 + transcript_msm_transition, // column 3 + transcript_pc, // column 4 + transcript_msm_count, // column 5 + transcript_Px, // column 6 + transcript_Py, // column 7 + transcript_z1, // column 8 + transcript_z2, // column 9 + transcript_z1zero, // column 10 + transcript_z2zero, // column 11 + transcript_op, // column 12 + transcript_accumulator_x, // column 13 + transcript_accumulator_y, // column 14 + transcript_msm_x, // column 15 + transcript_msm_y, // column 16 + precompute_pc, // column 17 + precompute_point_transition, // column 18 + precompute_round, // column 19 + precompute_scalar_sum, // column 20 + precompute_s1hi, // column 21 + precompute_s1lo, // column 22 + precompute_s2hi, // column 23 + precompute_s2lo, // column 24 + precompute_s3hi, // column 25 + precompute_s3lo, // column 26 + precompute_s4hi, // column 27 + precompute_s4lo, // column 28 + precompute_skew, // column 29 + precompute_dx, // column 30 + precompute_dy, // column 31 + precompute_tx, // column 32 + precompute_ty, // column 33 + msm_transition, // column 34 + msm_add, // column 35 + msm_double, // column 36 + msm_skew, // column 37 + msm_accumulator_x, // column 38 + msm_accumulator_y, // column 39 + msm_pc, // column 40 + msm_size_of_msm, // column 41 + msm_count, // column 42 + msm_round, // column 43 + msm_add1, // column 44 + msm_add2, // column 45 + msm_add3, // column 46 + msm_add4, // column 47 + msm_x1, // column 48 + msm_y1, // column 49 + msm_x2, // column 50 + msm_y2, // column 51 + msm_x3, // column 52 + msm_y3, // column 53 + msm_x4, // column 54 + msm_y4, // column 55 + msm_collision_x1, // column 56 + msm_collision_x2, // column 57 + msm_collision_x3, // column 58 + msm_collision_x4, // column 59 + msm_lambda1, // column 60 + msm_lambda2, // column 61 + msm_lambda3, // column 62 + msm_lambda4, // column 63 + msm_slice1, // column 64 + msm_slice2, // column 65 + msm_slice3, // column 66 + msm_slice4, // column 67 + transcript_accumulator_empty, // column 68 + transcript_reset_accumulator, // column 69 + precompute_select, // column 70 + lookup_read_counts_0, // column 71 + lookup_read_counts_1, // column 72 + transcript_base_infinity, // column 73 + transcript_base_x_inverse, // column 74 + transcript_base_y_inverse, // column 75 + transcript_add_x_equal, // column 76 + transcript_add_y_equal, // column 77 + transcript_add_lambda, // column 78 + transcript_msm_intermediate_x, // column 79 + transcript_msm_intermediate_y, // column 80 + transcript_msm_infinity, // column 81 + transcript_msm_x_inverse, // column 82 + transcript_msm_count_zero_at_transition, // column 83 + transcript_msm_count_at_transition_inverse); // column 84 }; /** @@ -456,7 +456,6 @@ class ECCVMFlavor { * lagrange_second: lagrange_second[1] = 1, 0 elsewhere * lagrange_last: lagrange_last[lagrange_last.size() - 1] = 1, 0 elsewhere * transcript_add/mul/eq/reset_accumulator: boolean selectors that toggle add/mul/eq/reset opcodes - * transcript_collision_check: used to ensure any point being added into eccvm accumulator does not trigger * incomplete addition rules * transcript_msm_transition: is current transcript row the final `mul` opcode of a multiscalar @@ -584,13 +583,11 @@ class ECCVMFlavor { transcript_accumulator_y[i] = transcript_state[i].accumulator_y; transcript_msm_x[i] = transcript_state[i].msm_output_x; transcript_msm_y[i] = transcript_state[i].msm_output_y; - transcript_collision_check[i] = transcript_state[i].collision_check; transcript_base_infinity[i] = transcript_state[i].base_infinity; transcript_base_x_inverse[i] = transcript_state[i].base_x_inverse; transcript_base_y_inverse[i] = transcript_state[i].base_y_inverse; transcript_add_x_equal[i] = transcript_state[i].transcript_add_x_equal; transcript_add_y_equal[i] = transcript_state[i].transcript_add_y_equal; - transcript_y_collision_check[i] = transcript_state[i].transcript_y_collision_check; transcript_add_lambda[i] = transcript_state[i].transcript_add_lambda; transcript_msm_intermediate_x[i] = transcript_state[i].transcript_msm_intermediate_x; transcript_msm_intermediate_y[i] = transcript_state[i].transcript_msm_intermediate_y; @@ -730,7 +727,6 @@ class ECCVMFlavor { Base::transcript_add = "TRANSCRIPT_ADD"; Base::transcript_mul = "TRANSCRIPT_MUL"; Base::transcript_eq = "TRANSCRIPT_EQ"; - Base::transcript_collision_check = "TRANSCRIPT_COLLISION_CHECK"; Base::transcript_msm_transition = "TRANSCRIPT_MSM_TRANSITION"; Base::transcript_pc = "TRANSCRIPT_PC"; Base::transcript_msm_count = "TRANSCRIPT_MSM_COUNT"; @@ -806,7 +802,6 @@ class ECCVMFlavor { Base::transcript_base_y_inverse = "TRANSCRIPT_BASE_Y_INVERSE"; Base::transcript_add_x_equal = "TRANSCRIPT_ADD_X_EQUAL"; Base::transcript_add_y_equal = "TRANSCRIPT_ADD_Y_EQUAL"; - Base::transcript_y_collision_check = "TRANSCRIPT_Y_COLLISION_CHECK"; Base::transcript_add_lambda = "TRANSCRIPT_ADD_LAMBDA"; Base::transcript_msm_intermediate_x = "TRANSCRIPT_MSM_INTERMEDIATE_X"; Base::transcript_msm_intermediate_y = "TRANSCRIPT_MSM_INTERMEDIATE_Y"; @@ -843,7 +838,6 @@ class ECCVMFlavor { Commitment transcript_add_comm; Commitment transcript_mul_comm; Commitment transcript_eq_comm; - Commitment transcript_collision_check_comm; Commitment transcript_msm_transition_comm; Commitment transcript_pc_comm; Commitment transcript_msm_count_comm; @@ -919,7 +913,6 @@ class ECCVMFlavor { Commitment transcript_base_y_inverse_comm; Commitment transcript_add_x_equal_comm; Commitment transcript_add_y_equal_comm; - Commitment transcript_y_collision_check_comm; Commitment transcript_add_lambda_comm; Commitment transcript_msm_intermediate_x_comm; Commitment transcript_msm_intermediate_y_comm; @@ -968,8 +961,6 @@ class ECCVMFlavor { NativeTranscript::proof_data, num_frs_read); transcript_eq_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); - transcript_collision_check_comm = NativeTranscript::template deserialize_from_buffer( - NativeTranscript::proof_data, num_frs_read); transcript_msm_transition_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); transcript_pc_comm = NativeTranscript::template deserialize_from_buffer( @@ -1120,8 +1111,6 @@ class ECCVMFlavor { NativeTranscript::proof_data, num_frs_read); transcript_add_y_equal_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); - transcript_y_collision_check_comm = NativeTranscript::template deserialize_from_buffer( - NativeTranscript::proof_data, num_frs_read); transcript_add_lambda_comm = NativeTranscript::template deserialize_from_buffer( NativeTranscript::proof_data, num_frs_read); transcript_msm_intermediate_x_comm = NativeTranscript::template deserialize_from_buffer( @@ -1206,8 +1195,6 @@ class ECCVMFlavor { NativeTranscript::template serialize_to_buffer(transcript_add_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_mul_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_eq_comm, NativeTranscript::proof_data); - NativeTranscript::template serialize_to_buffer(transcript_collision_check_comm, - NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_msm_transition_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_pc_comm, NativeTranscript::proof_data); @@ -1289,8 +1276,6 @@ class ECCVMFlavor { NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_add_x_equal_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_add_y_equal_comm, NativeTranscript::proof_data); - NativeTranscript::template serialize_to_buffer(transcript_y_collision_check_comm, - NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_add_lambda_comm, NativeTranscript::proof_data); NativeTranscript::template serialize_to_buffer(transcript_msm_intermediate_x_comm, diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_trace_checker.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_trace_checker.cpp index 349ec65c675..f46ba36b44b 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_trace_checker.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_trace_checker.cpp @@ -65,6 +65,7 @@ bool ECCVMTraceChecker::check(Builder& builder) result = result && evaluate_relation.template operator()>("ECCVMWnafRelation"); result = result && evaluate_relation.template operator()>("ECCVMMSMRelation"); result = result && evaluate_relation.template operator()>("ECCVMSetRelation"); + result = result && evaluate_relation.template operator()>("ECCVMBoolsRelation"); using LookupRelation = ECCVMLookupRelation; typename ECCVMLookupRelation::SumcheckArrayOfValuesOverSubrelations lookup_result; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp index 720567d92c5..0b2e13a7850 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_transcript.test.cpp @@ -45,7 +45,6 @@ class ECCVMTranscriptTests : public ::testing::Test { manifest_expected.add_entry(round, "TRANSCRIPT_ADD", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MUL", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_EQ", frs_per_G); - manifest_expected.add_entry(round, "TRANSCRIPT_COLLISION_CHECK", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_TRANSITION", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_PC", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_COUNT", frs_per_G); @@ -121,7 +120,6 @@ class ECCVMTranscriptTests : public ::testing::Test { manifest_expected.add_entry(round, "TRANSCRIPT_BASE_Y_INVERSE", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_ADD_X_EQUAL", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_ADD_Y_EQUAL", frs_per_G); - manifest_expected.add_entry(round, "TRANSCRIPT_Y_COLLISION_CHECK", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_ADD_LAMBDA", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INTERMEDIATE_X", frs_per_G); manifest_expected.add_entry(round, "TRANSCRIPT_MSM_INTERMEDIATE_Y", frs_per_G); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp index 884e0af3cfc..fbacd9ca6a4 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp @@ -31,7 +31,6 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof) commitments.transcript_add = receive_commitment(commitment_labels.transcript_add); commitments.transcript_mul = receive_commitment(commitment_labels.transcript_mul); commitments.transcript_eq = receive_commitment(commitment_labels.transcript_eq); - commitments.transcript_collision_check = receive_commitment(commitment_labels.transcript_collision_check); commitments.transcript_msm_transition = receive_commitment(commitment_labels.transcript_msm_transition); commitments.transcript_pc = receive_commitment(commitment_labels.transcript_pc); commitments.transcript_msm_count = receive_commitment(commitment_labels.transcript_msm_count); @@ -107,7 +106,6 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof) commitments.transcript_base_y_inverse = receive_commitment(commitment_labels.transcript_base_y_inverse); commitments.transcript_add_x_equal = receive_commitment(commitment_labels.transcript_add_x_equal); commitments.transcript_add_y_equal = receive_commitment(commitment_labels.transcript_add_y_equal); - commitments.transcript_y_collision_check = receive_commitment(commitment_labels.transcript_y_collision_check); commitments.transcript_add_lambda = receive_commitment(commitment_labels.transcript_add_lambda); commitments.transcript_msm_intermediate_x = receive_commitment(commitment_labels.transcript_msm_intermediate_x); commitments.transcript_msm_intermediate_y = receive_commitment(commitment_labels.transcript_msm_intermediate_y); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index a95806f671e..ed087e5cd62 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -31,13 +31,11 @@ class ECCVMTranscriptBuilder { FF accumulator_y = 0; FF msm_output_x = 0; FF msm_output_y = 0; - FF collision_check = 0; bool base_infinity = 0; FF base_x_inverse = 0; FF base_y_inverse = 0; bool transcript_add_x_equal = false; bool transcript_add_y_equal = false; - FF transcript_y_collision_check = 0; FF transcript_add_lambda = 0; FF transcript_msm_intermediate_x = 0; FF transcript_msm_intermediate_y = 0; @@ -93,11 +91,11 @@ class ECCVMTranscriptBuilder { // We fill these vectors and then perform batch inversions to amortize the cost of FF inverts std::vector inverse_trace_x(num_vm_entries); std::vector inverse_trace_y(num_vm_entries); - std::vector transcript_y_collision_check(num_vm_entries); std::vector transcript_add_lambda(num_vm_entries); std::vector transcript_msm_x_inverse_trace(num_vm_entries); std::vector add_lambda_denominator(num_vm_entries); std::vector add_lambda_numerator(num_vm_entries); + std::vector msm_count_at_transition_inverse_trace(num_vm_entries); std::vector msm_accumulator_trace(num_vm_entries); std::vector accumulator_trace(num_vm_entries); std::vector intermediate_accumulator_trace(num_vm_entries); @@ -153,7 +151,6 @@ class ECCVMTranscriptBuilder { updated_state.msm_accumulator = R + P * entry.mul_scalar_full; } - // TODO IF FAKE TRANSITION FIGURE OUT WHAT TO DO WITH ACCUMULATORS BLAH BLAH BLAH if (msm_transition) { if (state.is_accumulator_empty) { updated_state.accumulator = updated_state.msm_accumulator - offset_generator(); @@ -182,8 +179,7 @@ class ECCVMTranscriptBuilder { row.pc = state.pc; row.msm_count = state.count; row.msm_count_zero_at_transition = ((state.count + num_muls) == 0) && (entry.mul && next_not_msm); - row.msm_count_at_transition_inverse = - ((state.count + num_muls) == 0) ? 0 : FF(state.count + num_muls).invert(); // TODO BATCH + msm_count_at_transition_inverse_trace[i] = ((state.count + num_muls) == 0) ? 0 : FF(state.count + num_muls); row.base_x = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.x : 0; row.base_y = ((entry.add || entry.mul || entry.eq) && !base_point_infinity) ? entry.base_point.y : 0; row.base_infinity = (entry.add || entry.mul || entry.eq) ? (base_point_infinity ? 1 : 0) : 0; @@ -291,12 +287,13 @@ class ECCVMTranscriptBuilder { FF::batch_invert(&inverse_trace_y[0], num_vm_entries); FF::batch_invert(&transcript_msm_x_inverse_trace[0], num_vm_entries); FF::batch_invert(&add_lambda_denominator[0], num_vm_entries); - + FF::batch_invert(&msm_count_at_transition_inverse_trace[0], num_vm_entries); for (size_t i = 0; i < num_vm_entries; ++i) { transcript_state[i + 1].base_x_inverse = inverse_trace_x[i]; transcript_state[i + 1].base_y_inverse = inverse_trace_y[i]; transcript_state[i + 1].transcript_msm_x_inverse = transcript_msm_x_inverse_trace[i]; transcript_state[i + 1].transcript_add_lambda = add_lambda_numerator[i] * add_lambda_denominator[i]; + transcript_state[i + 1].msm_count_at_transition_inverse = msm_count_at_transition_inverse_trace[i]; } TranscriptState& final_row = transcript_state.back(); final_row.pc = updated_state.pc; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp new file mode 100644 index 00000000000..53fcdee6ef8 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp @@ -0,0 +1,90 @@ +#include +#include + +#include "./ecc_bools_relation.hpp" +#include "barretenberg/eccvm/eccvm_flavor.hpp" +#include "barretenberg/flavor/relation_definitions.hpp" + +namespace bb { + +/** + * @brief ECCVMBoolsRelationImpl evaluates the correctness of ECCVM boolean checks + * + * @details There are a lot of columns in ECCVM that are boolean. As these are all low-degree we place them in a + * separate relation class + * @tparam FF + * @tparam ContainerOverSubrelations + * @tparam AllEntities + * @tparam Parameters + */ +template +template +void ECCVMBoolsRelationImpl::accumulate(ContainerOverSubrelations& accumulator, + const AllEntities& in, + const Parameters& /*unused*/, + const FF& scaling_factor) +{ + using Accumulator = typename std::tuple_element_t<0, ContainerOverSubrelations>; + using View = typename Accumulator::View; + + auto z1 = View(in.transcript_z1); + auto z2 = View(in.transcript_z2); + auto z1_zero = View(in.transcript_z1zero); + auto z2_zero = View(in.transcript_z2zero); + auto msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); + auto q_add = View(in.transcript_add); + auto q_mul = View(in.transcript_mul); + auto q_eq = View(in.transcript_eq); + auto transcript_msm_transition = View(in.transcript_msm_transition); + auto is_accumulator_empty = View(in.transcript_accumulator_empty); + auto q_reset_accumulator = View(in.transcript_reset_accumulator); + auto transcript_Pinfinity = View(in.transcript_base_infinity); + auto transcript_msm_infinity = View(in.transcript_msm_infinity); + auto transcript_add_x_equal = View(in.transcript_add_x_equal); + auto transcript_add_y_equal = View(in.transcript_add_y_equal); + auto precompute_point_transition = View(in.precompute_point_transition); + auto msm_transition = View(in.msm_transition); + auto msm_add = View(in.msm_add); + auto msm_double = View(in.msm_double); + auto msm_skew = View(in.msm_skew); + auto precompute_select = View(in.precompute_select); + + std::get<0>(accumulator) += q_eq * (q_eq - 1) * scaling_factor; + std::get<1>(accumulator) += q_add * (q_add - 1) * scaling_factor; + std::get<2>(accumulator) += q_mul * (q_mul - 1) * scaling_factor; + std::get<3>(accumulator) += q_reset_accumulator * (q_reset_accumulator - 1) * scaling_factor; + std::get<4>(accumulator) += transcript_msm_transition * (transcript_msm_transition - 1) * scaling_factor; + std::get<5>(accumulator) += is_accumulator_empty * (is_accumulator_empty - 1) * scaling_factor; + std::get<6>(accumulator) += z1_zero * (z1_zero - 1) * scaling_factor; + std::get<7>(accumulator) += z2_zero * (z2_zero - 1) * scaling_factor; + std::get<8>(accumulator) += transcript_add_x_equal * (transcript_add_x_equal - 1) * scaling_factor; + std::get<9>(accumulator) += transcript_add_y_equal * (transcript_add_y_equal - 1) * scaling_factor; + std::get<10>(accumulator) += transcript_Pinfinity * (transcript_Pinfinity - 1) * scaling_factor; + std::get<11>(accumulator) += transcript_msm_infinity * (transcript_msm_infinity - 1) * scaling_factor; + std::get<12>(accumulator) += msm_count_zero_at_transition * (msm_count_zero_at_transition - 1) * scaling_factor; + std::get<13>(accumulator) += msm_transition * (msm_transition - 1) * scaling_factor; + std::get<14>(accumulator) += precompute_point_transition * (precompute_point_transition - 1) * scaling_factor; + std::get<15>(accumulator) += msm_add * (msm_add - 1) * scaling_factor; + std::get<16>(accumulator) += msm_double * (msm_double - 1) * scaling_factor; + std::get<17>(accumulator) += msm_skew * (msm_skew - 1) * scaling_factor; + std::get<18>(accumulator) += precompute_select * (precompute_select - 1) * scaling_factor; + + /** + * @brief Validate correctness of z1_zero, z2_zero. + * If z1_zero = 0 and operation is a MUL, we will write a scalar mul instruction into our multiplication table. + * If z1_zero = 1 and operation is a MUL, we will NOT write a scalar mul instruction. + * (same with z2_zero). + * z1_zero / z2_zero is user-defined. + * We constraint z1_zero such that if z1_zero == 1, we require z1 == 0. (same for z2_zero). + * We do *NOT* constrain z1 != 0 if z1_zero = 0. If the user sets z1_zero = 0 and z1 = 0, + * this will add a scalar mul instruction into the multiplication table, where the scalar multiplier is 0. + * This is inefficient but will still produce the correct output. + */ + std::get<19>(accumulator) += (z1 * z1_zero) * scaling_factor; // if z1_zero = 1, z1 must be 0 + std::get<20>(accumulator) += (z2 * z2_zero) * scaling_factor; +} + +template class ECCVMBoolsRelationImpl; +DEFINE_SUMCHECK_RELATION_CLASS(ECCVMBoolsRelationImpl, ECCVMFlavor); + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp new file mode 100644 index 00000000000..3a91f559cfb --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "barretenberg/ecc/curves/bn254/g1.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include "barretenberg/relations/relation_types.hpp" + +namespace bb { + +/** + * @brief ECCVMBoolsRelationImpl evaluates the correctness of ECCVM boolean checks + * + * @details There are a lot of columns in ECCVM that are boolean. As these are all low-degree we place them in a + * separate relation class + * @tparam FF + */ +template class ECCVMBoolsRelationImpl { + public: + using FF = FF_; + + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + }; + + template + static void accumulate(ContainerOverSubrelations& accumulator, + const AllEntities& in, + const Parameters& /* unused */, + const FF& scaling_factor); +}; + +template using ECCVMBoolsRelation = Relation>; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 4941feaece0..7269f81e8d1 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -86,9 +86,9 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto transcript_msm_infinity = View(in.transcript_msm_infinity); auto is_not_first_row = (-lagrange_first + 1); + auto is_not_last_row = (-lagrange_last + 1); auto is_not_first_or_last_row = (-lagrange_first + -lagrange_last + 1); auto is_not_infinity = (-transcript_Pinfinity + 1); - /** * @brief Validate correctness of z1_zero, z2_zero. * If z1_zero = 0 and operation is a MUL, we will write a scalar mul instruction into our multiplication table. @@ -144,11 +144,11 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto msm_count_zero_at_transition_check = msm_count_zero_at_transition * msm_count_total; msm_count_zero_at_transition_check += (msm_count_total * msm_count_at_transition_inverse - 1) * (-msm_count_zero_at_transition + 1); - std::get<40>(accumulator) += msm_transition_check * msm_count_zero_at_transition_check * scaling_factor; + std::get<4>(accumulator) += msm_transition_check * msm_count_zero_at_transition_check * scaling_factor; // Validate msm_transition_msm_count is correct // ensure msm_transition is zero if count is zero - std::get<4>(accumulator) += + std::get<5>(accumulator) += (msm_transition - msm_transition_check * (-msm_count_zero_at_transition + 1)) * scaling_factor; /** @@ -157,7 +157,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * (if no msm active, msm_count == 0) * If current row ends an MSM, `msm_count_shift = 0` (msm_count value at next row) */ - std::get<5>(accumulator) += (msm_transition * msm_count_shift) * scaling_factor; + std::get<6>(accumulator) += (msm_transition * msm_count_shift) * scaling_factor; /** * @brief Validate `msm_count` updates correctly for mul operations. @@ -166,26 +166,18 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu */ auto msm_count_delta = msm_count_shift - msm_count; // degree 4 auto num_counts = ((-z1_zero + 1) + (-z2_zero + 1)) * (-transcript_Pinfinity + 1); - std::get<6>(accumulator) += + std::get<7>(accumulator) += is_not_first_row * (-msm_transition + 1) * (msm_count_delta - q_mul * (num_counts)) * scaling_factor; /** * @brief Opcode exclusion tests. We have the following assertions: * 1. If q_mul = 1, (q_add, eq, reset) are zero - * 2. If q_reset = 1, is_accumulator_empty at next row = 1 - * 3. If q_add = 1 OR msm_transition = 1, is_accumulator_empty at next row = 0 - * 4. If q_add = 0 AND msm_transition = 0 AND q_reset_accumulator = 0, is_accumulator at next row = current row - * value - * @note point 3: both q_add = 1, msm_transition = 1 cannot occur because of point 1 (msm_transition only 1 when - * q_mul 1) we can use a slightly more efficient relation than a pure binary OR + * 2. If q_add = 1, (q_mul, eq, reset) are zero + * 3. If q_eq = 1, (q_add, q_mul) are zero (is established by previous 2 checks) */ - std::get<7>(accumulator) += q_mul * (q_add + q_eq + q_reset_accumulator) * scaling_factor; - std::get<8>(accumulator) += q_add * (q_mul + q_eq + q_reset_accumulator) * scaling_factor; - std::get<9>(accumulator) += q_reset_accumulator * (-is_accumulator_empty_shift + 1) * scaling_factor; - // std::get<18>(accumulator) += (q_add + msm_transition) * is_accumulator_empty_shift * scaling_factor; - auto accumulator_state_not_modified = -(q_add + msm_transition + q_reset_accumulator) + 1; - std::get<10>(accumulator) += accumulator_state_not_modified * is_not_first_or_last_row * - (is_accumulator_empty_shift - is_accumulator_empty) * scaling_factor; + auto opcode_exclusion_relation = q_mul * (q_add + q_eq + q_reset_accumulator); + opcode_exclusion_relation += q_add * (q_mul + q_eq + q_reset_accumulator); + std::get<8>(accumulator) += opcode_exclusion_relation * scaling_factor; // degree 2 /** * @brief `eq` opcode. @@ -201,9 +193,9 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto eq_y_diff = transcript_Py - transcript_accumulator_y; auto eq_x_diff_relation = q_eq * (eq_x_diff * both_not_infinity + infinity_exclusion_check); // degree 4 auto eq_y_diff_relation = q_eq * (eq_y_diff * both_not_infinity + infinity_exclusion_check); // degree 4 - std::get<11>(accumulator) += eq_x_diff_relation * scaling_factor; + std::get<9>(accumulator) += eq_x_diff_relation * scaling_factor; - std::get<12>(accumulator) += eq_y_diff_relation * scaling_factor; + std::get<10>(accumulator) += eq_y_diff_relation * scaling_factor; /** * @brief Initial condition check on 1st row. @@ -213,8 +205,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * note...actually second row? bleurgh * NOTE: we want pc = 0 at lagrange_last :o */ - std::get<13>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; - std::get<14>(accumulator) += lagrange_second * msm_count * scaling_factor; + std::get<11>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; + std::get<12>(accumulator) += lagrange_second * msm_count * scaling_factor; /** * @brief On-curve validation checks. @@ -224,174 +216,218 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu const auto validate_on_curve = q_mul + q_add + q_mul + q_eq; const auto on_curve_check = transcript_Py * transcript_Py - transcript_Px * transcript_Px * transcript_Px - get_curve_b(); - std::get<15>(accumulator) += validate_on_curve * on_curve_check * is_not_infinity * scaling_factor; // degree 6 + std::get<13>(accumulator) += validate_on_curve * on_curve_check * is_not_infinity * scaling_factor; // degree 6 /** - * @brief Validate correctness of ECC Group Operation - * An ECC group operation is performed if q_add = 1 or msm_transition = 1. - * Because input points can be points at infinity, we must support COMPLETE addition and handle points at infinity + * @brief Validate relations from ECC Group Operations are well formed + * */ - // define the lhs point: either transcript_Px/y or transcript_accumulator_x/y - auto lhs_x = transcript_Px * q_add + transcript_msm_x * msm_transition; - auto lhs_y = transcript_Py * q_add + transcript_msm_y * msm_transition; - // the rhs point will always be the accumulator point at the next row in the trace - auto rhs_x = transcript_accumulator_x; - auto rhs_y = transcript_accumulator_y; - // the group operation will be either an ADD or a DOUBLE depending on whether x/y coordinates of lhs/rhs match. - // If lhs_x == rhs_x, we evaluate a DOUBLE, otherwise an ADD - // (we will only activate this relation if lhs_y != rhs_y, but this is done later) - auto ecc_op_is_dbl = transcript_add_x_equal; - auto ecc_op_is_add = (-transcript_add_x_equal + 1); - // Are the lhs/rhs points at infinity? - // MSM output CANNOT be point at infinity without triggering unsatisfiable constraints in msm_relation - // lhs can only be at infinity if q_add is active - auto lhs_infinity = transcript_Pinfinity * q_add + transcript_msm_infinity * msm_transition; - auto rhs_infinity = is_accumulator_empty; - // Determine where the group operation output is sourced from - // | lhs_infinity | rhs_infinity | lhs_x == rhs_x && lhs_y != rhs_y | output | - // | ------------ | ------------ | -------------------------------- | --------- | - // | 0 | 0 | 0 | lhs + rhs | - // | 0 | 0 | 1 | infinity | - // | 0 | 1 | n/a | lhs | - // | 1 | 0 | n/a | rhs | - // | 1 | 1 | n/a | infinity | - auto add_result_is_lhs = rhs_infinity * (-lhs_infinity + 1); // degree 3 - auto add_result_is_rhs = lhs_infinity * (-rhs_infinity + 1); // degree 3 - auto add_result_infinity_from_inputs = lhs_infinity * rhs_infinity; // degree 2 - auto add_result_infinity_from_operation = transcript_add_x_equal * (-transcript_add_y_equal + 1); // degree 2 - auto add_result_is_infinity = add_result_infinity_from_inputs + add_result_infinity_from_operation; // degree 2?? + { + Accumulator transcript_lambda_relation(0); + auto is_double = transcript_add_x_equal * transcript_add_y_equal; + auto is_add = (-transcript_add_x_equal + 1); + auto add_result_is_infinity = transcript_add_x_equal * (-transcript_add_y_equal + 1); // degree 2 + auto rhs_x = transcript_accumulator_x; + auto rhs_y = transcript_accumulator_y; + auto out_x = transcript_accumulator_x_shift; + auto out_y = transcript_accumulator_y_shift; + auto lambda = transcript_add_lambda; + auto lhs_x = transcript_Px * q_add + transcript_msm_x * msm_transition; + auto lhs_y = transcript_Py * q_add + transcript_msm_y * msm_transition; + auto lhs_infinity = transcript_Pinfinity * q_add + transcript_msm_infinity * msm_transition; + auto rhs_infinity = is_accumulator_empty; + auto result_is_lhs = rhs_infinity * (-lhs_infinity + 1); // degree 2 + auto result_is_rhs = (-rhs_infinity + 1) * lhs_infinity; // degree 2 + auto result_infinity_from_inputs = lhs_infinity * rhs_infinity; // degree 2 + auto result_infinity_from_operation = transcript_add_x_equal * (-transcript_add_y_equal + 1); // degree 2 + // infinity_from_inputs and infinity_from_operation mutually exclusive so we can perform an OR by adding + // (mutually exclusive because if result_infinity_from_inputs then transcript_add_y_equal = 1 (both y are 0) + auto result_is_infinity = result_infinity_from_inputs + result_infinity_from_operation; // degree 2 + auto any_add_is_active = q_add + msm_transition; - auto lambda_relation_valid = (-lhs_infinity + 1) * (-rhs_infinity + 1) * (-add_result_is_infinity + 1); // degree 4 - // Determine the gradient `lambda` of the group operation - // If lhs_x == rhs_x, lambda = (3 * lhs_x * lhs_x) / (2 * lhs_y) - // Else, lambda = (rhs_y - lhs_y) / (rhs_x - lhs_x) - auto lhs_xx = lhs_x * lhs_x; - auto lambda_numerator = (rhs_y - lhs_y) * ecc_op_is_add + (lhs_xx + lhs_xx + lhs_xx) * ecc_op_is_dbl; - auto lambda_denominator = (rhs_x - lhs_x) * ecc_op_is_add + (lhs_y + lhs_y) * ecc_op_is_dbl; // degree 3 - auto lambda_term = lambda_denominator * transcript_add_lambda - lambda_numerator; // degree 4 - // We only activate lambda relation if we don't have points at infinity - this is to avoid divide-by-zero problems - // N.B. check this is need - auto any_add_is_active = q_add + msm_transition; - auto lambda_relation_active = any_add_is_active * lambda_relation_valid; // degree 5 - auto lambda_relation = lambda_term * lambda_relation_active; // degree 9! - // if lambda relation is not active, assert lambda = 0 - lambda_relation += (-lambda_relation_active + 1) * transcript_add_lambda; - std::get<16>(accumulator) += lambda_relation * scaling_factor; // degree 9 + // Valdiate `transcript_add_lambda` is well formed if we are adding msm output into accumulator + { + Accumulator transcript_msm_lambda_relation(0); + auto msm_x = transcript_msm_x; + auto msm_y = transcript_msm_y; + // Group operation is point addition + { + auto lambda_denominator = (rhs_x - msm_x); + auto lambda_numerator = (rhs_y - msm_y); + auto lambda_relation = lambda * lambda_denominator - lambda_numerator; // degree 2 + transcript_msm_lambda_relation += lambda_relation * is_add; // degree 3 + } + // Group operation is point doubling + { + auto lambda_denominator = msm_y + msm_y; + auto lambda_numerator = msm_x * msm_x * 3; + auto lambda_relation = lambda * lambda_denominator - lambda_numerator; // degree 2 + transcript_msm_lambda_relation += lambda_relation * is_double; // degree 4 + } + auto transcript_add_or_dbl_from_msm_output_is_valid = + (-transcript_msm_infinity + 1) * (-is_accumulator_empty + 1); // degree 2 + transcript_msm_lambda_relation *= transcript_add_or_dbl_from_msm_output_is_valid; // degree 6 + // No group operation because of points at infinity + { + auto lambda_relation_invalid = + (transcript_msm_infinity + is_accumulator_empty + add_result_is_infinity); // degree 2 + auto lambda_relation = lambda * lambda_relation_invalid; // degree 4 + transcript_msm_lambda_relation += lambda_relation; // (still degree 6) + } + transcript_lambda_relation = transcript_msm_lambda_relation * msm_transition; // degree 7 + } + // Valdiate `transcript_add_lambda` is well formed if we are adding base point into accumulator + { + Accumulator transcript_add_lambda_relation(0); + auto add_x = transcript_Px; + auto add_y = transcript_Py; + // Group operation is point addition + { + auto lambda_denominator = (rhs_x - add_x); + auto lambda_numerator = (rhs_y - add_y); + auto lambda_relation = lambda * lambda_denominator - lambda_numerator; // degree 2 + transcript_add_lambda_relation += lambda_relation * is_add; // degree 3 + } + // Group operation is point doubling + { + auto lambda_denominator = add_y + add_y; + auto lambda_numerator = add_x * add_x * 3; + auto lambda_relation = lambda * lambda_denominator - lambda_numerator; // degree 2 + transcript_add_lambda_relation += lambda_relation * is_double; // degree 4 + } + auto transcript_add_or_dbl_from_add_output_is_valid = + (-transcript_Pinfinity + 1) * (-is_accumulator_empty + 1); // degree 2 + transcript_add_lambda_relation *= transcript_add_or_dbl_from_add_output_is_valid; // degree 6 + // No group operation because of points at infinity + { + auto lambda_relation_invalid = + (transcript_Pinfinity + is_accumulator_empty + add_result_is_infinity); // degree 2 + auto lambda_relation = lambda * lambda_relation_invalid; // degree 4 + transcript_add_lambda_relation += lambda_relation; // (still degree 6) + } + transcript_lambda_relation += transcript_add_lambda_relation * q_add; + std::get<14>(accumulator) += transcript_lambda_relation * scaling_factor; // degree 7 + } + /** + * @brief Validate transcript_accumulator_x_shift / transcript_accumulator_y_shift are well formed. + * Conditions (one of the following): + * 1. The result of a group operation involving transcript_accumulator and msm_output (q_add = 1) + * 2. The result of a group operation involving transcript_accumulator and transcript_P (msm_transition = + * 1) + * 3. Is equal to transcript_accumulator (no group operation, no reset) + * 4. Is 0 (reset) + */ + { + auto lambda_sqr = lambda * lambda; + // add relation that validates result_infinity_from_operation * result_is_infinity = 0 - // Determine the x/y coordinates of the shifted accumulator - // add_x3/add_y3 = result of group operation computation - auto add_x3 = transcript_add_lambda * transcript_add_lambda - lhs_x - rhs_x; // degree 2 - add_x3 += (lhs_x + lhs_x + rhs_x) * add_result_is_lhs; - add_x3 += (rhs_x + rhs_x + lhs_x) * add_result_is_rhs; - add_x3 += (lhs_x + rhs_x) * add_result_is_infinity; + // N.B. these relations rely on the fact that `lambda = 0` if we are not evaluating add/double formula + // (i.e. one or both outputs are points at infinity, or produce a point at infinity) + // This should be validated by the lambda_relation + auto x3 = lambda_sqr - lhs_x - rhs_x; // degree 2 + x3 += result_is_lhs * (rhs_x + lhs_x + lhs_x); // degree 4 + x3 += result_is_rhs * (lhs_x + rhs_x + rhs_x); // degree 4 + x3 += result_is_infinity * (lhs_x + rhs_x); // degree 4 + auto y3 = lambda * (lhs_x - out_x) - lhs_y; // degree 3 + y3 += result_is_lhs * (lhs_y + lhs_y); // degree 4 + y3 += result_is_rhs * (lhs_y + rhs_y); // degree 4 + y3 += result_is_infinity * lhs_y; // degree 4 - auto add_y3 = transcript_add_lambda * (lhs_x - add_x3) - lhs_y; // degree 3 - add_y3 += (lhs_y + lhs_y) * add_result_is_lhs; - add_y3 += (lhs_y + rhs_y) * add_result_is_rhs; - add_y3 += (lhs_y)*add_result_is_infinity; - auto propagate_transcript_accumulator = (-q_add - msm_transition - q_reset_accumulator + 1); - auto add_point_x_relation = (add_x3 - transcript_accumulator_x_shift) * any_add_is_active; // degree 7 - add_point_x_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * - (transcript_accumulator_x_shift - transcript_accumulator_x); - auto add_point_y_relation = (add_y3 - transcript_accumulator_y_shift) * any_add_is_active; // degree 7 - add_point_y_relation += propagate_transcript_accumulator * (-lagrange_last + 1) * - (transcript_accumulator_y_shift - transcript_accumulator_y); - std::get<17>(accumulator) += add_point_x_relation * scaling_factor; // degree 7 - std::get<18>(accumulator) += add_point_y_relation * scaling_factor; // degree 8 + auto propagate_transcript_accumulator = (-q_add - msm_transition - q_reset_accumulator + 1); + auto add_point_x_relation = (x3 - out_x) * any_add_is_active; // degree 5 + add_point_x_relation += + propagate_transcript_accumulator * is_not_last_row * (out_x - transcript_accumulator_x); + // validate out_x = 0 if q_reset_accumulator = 1 + add_point_x_relation += (out_x * q_reset_accumulator); + auto add_point_y_relation = (y3 - out_y) * any_add_is_active; // degree 5 + add_point_y_relation += + propagate_transcript_accumulator * is_not_last_row * (out_y - transcript_accumulator_y); + // validate out_y = 0 if q_reset_accumulator = 1 + add_point_y_relation += (out_y * q_reset_accumulator); + std::get<15>(accumulator) += add_point_x_relation * scaling_factor; // degree 5 + std::get<16>(accumulator) += add_point_y_relation * scaling_factor; // degree 5 + } - /** - * @brief Validate `is_accumulator_empty` is updated correctly - * An add operation can produce a point at infinity - * Resetting the accumulator produces a point at infinity - * If we are not adding, performing an msm or resetting the accumulator, is_accumulator_empty should not update - */ - auto accumulator_infinity_preserve_flag = (-(q_add + msm_transition + q_reset_accumulator) + 1); - auto accumulator_infinity_preserve = - accumulator_infinity_preserve_flag * (is_accumulator_empty - is_accumulator_empty_shift) * (-lagrange_last + 1); - auto accumulator_infinity_q_reset = q_reset_accumulator * (-is_accumulator_empty_shift + 1); - auto accumulator_infinity_from_add = any_add_is_active * (add_result_is_infinity - is_accumulator_empty_shift); - auto accumulator_infinity_relation = - accumulator_infinity_preserve + accumulator_infinity_q_reset + accumulator_infinity_from_add; - std::get<19>(accumulator) += (accumulator_infinity_relation * is_not_first_row) * scaling_factor; // degree 5? + // step 1: subtract offset generator from msm_accumulator + // this might produce a point at infinity + { + const auto offset = offset_generator(); + const auto x1 = offset[0]; + const auto y1 = -offset[1]; + const auto x2 = View(in.transcript_msm_x); + const auto y2 = View(in.transcript_msm_y); + const auto x3 = View(in.transcript_msm_intermediate_x); + const auto y3 = View(in.transcript_msm_intermediate_y); + const auto transcript_msm_infinity = View(in.transcript_msm_infinity); + // cases: + // x2 == x1, y2 == y1 + // x2 != x1 + // (x2 - x1) + const auto x_term = (x3 + x2 + x1) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1); // degree 3 + const auto y_term = (x1 - x3) * (y2 - y1) - (x2 - x1) * (y1 + y3); // degree 2 + // IF msm_infinity = false, transcript_msm_intermediate_x/y is either the result of subtracting offset + // generator from msm_x/y IF msm_infinity = true, transcript_msm_intermediate_x/y is 0 + const auto transcript_offset_generator_subtract_x = + x_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * x3; // degree 4 + const auto transcript_offset_generator_subtract_y = + y_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * y3; // degree 3 + std::get<17>(accumulator) += + msm_transition * transcript_offset_generator_subtract_x * scaling_factor; // degree 5 + std::get<18>(accumulator) += + msm_transition * transcript_offset_generator_subtract_y * scaling_factor; // degree 5 - /** - * @brief Validate `transcript_add_x_equal` is well-formed - * If lhs_x == rhs_x, transcript_add_x_equal = 1 - * If transcript_add_x_equal = 0, a valid inverse must exist for (lhs_x - rhs_x) - */ - auto x_diff = lhs_x - rhs_x; - auto x_product = transcript_Px_inverse * (-transcript_add_x_equal + 1) + transcript_add_x_equal; - auto x_constant = transcript_add_x_equal - 1; - auto transcript_add_x_equal_check_relation = (x_diff * x_product + x_constant) * any_add_is_active; - std::get<20>(accumulator) += transcript_add_x_equal_check_relation * scaling_factor; // degree 6 + // validate transcript_msm_infinity is correct + // if transcript_msm_infinity = 1, (x2 == x1) and (y2 + y1 == 0) + const auto x_diff = x2 - x1; + const auto y_sum = y2 + y1; + std::get<19>(accumulator) += msm_transition * transcript_msm_infinity * x_diff * scaling_factor; // degree 3 + std::get<20>(accumulator) += msm_transition * transcript_msm_infinity * y_sum * scaling_factor; // degree 3 + // if transcript_msm_infinity = 1, then x_diff must have an inverse + const auto transcript_msm_x_inverse = View(in.transcript_msm_x_inverse); + const auto inverse_term = (-transcript_msm_infinity + 1) * (x_diff * transcript_msm_x_inverse - 1); + std::get<21>(accumulator) += msm_transition * inverse_term * scaling_factor; // degree 3 + } - // TODO: IF MUL PRODUCES 0 POINTS DUE TO Z1=0, Z2=0 OR POINTS AT INFINITY, ENSURE THAT MSM_OUTPUT IS ALWAYS POINT AT - // INFINITY - /** - * @brief Validate `transcript_add_y_equal` is well-formed - * If lhs_y == rhs_y, transcript_add_y_equal = 1 - * If transcript_add_y_equal = 0, a valid inverse must exist for (lhs_y - rhs_y) - */ - auto y_diff = lhs_y - rhs_y; - auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; - auto y_constant = transcript_add_y_equal - 1; - auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; - std::get<21>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 6 + /** + * @brief Validate `is_accumulator_empty` is updated correctly + * An add operation can produce a point at infinity + * Resetting the accumulator produces a point at infinity + * If we are not adding, performing an msm or resetting the accumulator, is_accumulator_empty should not update + */ + auto accumulator_infinity_preserve_flag = (-(q_add + msm_transition + q_reset_accumulator) + 1); // degree 1 + auto accumulator_infinity_preserve = accumulator_infinity_preserve_flag * + (is_accumulator_empty - is_accumulator_empty_shift) * + is_not_first_or_last_row; // degree 3 + auto accumulator_infinity_q_reset = q_reset_accumulator * (-is_accumulator_empty_shift + 1); // degree 2 + auto accumulator_infinity_from_add = + any_add_is_active * (result_is_infinity - is_accumulator_empty_shift); // degree 3 + auto accumulator_infinity_relation = + accumulator_infinity_preserve + + (accumulator_infinity_q_reset + accumulator_infinity_from_add) * is_not_first_row; // degree 4 + std::get<22>(accumulator) += accumulator_infinity_relation * scaling_factor; // degree 4 - // validate selectors are boolean (put somewhere else? these are low degree) - std::get<22>(accumulator) += q_eq * (q_eq - 1) * scaling_factor; - std::get<23>(accumulator) += q_add * (q_add - 1) * scaling_factor; - std::get<24>(accumulator) += q_mul * (q_mul - 1) * scaling_factor; - std::get<25>(accumulator) += q_reset_accumulator * (q_reset_accumulator - 1) * scaling_factor; - std::get<26>(accumulator) += msm_transition * (msm_transition - 1) * scaling_factor; - std::get<27>(accumulator) += is_accumulator_empty * (is_accumulator_empty - 1) * scaling_factor; - std::get<28>(accumulator) += z1_zero * (z1_zero - 1) * scaling_factor; - std::get<29>(accumulator) += z2_zero * (z2_zero - 1) * scaling_factor; - std::get<30>(accumulator) += transcript_add_x_equal * (transcript_add_x_equal - 1) * scaling_factor; - std::get<31>(accumulator) += transcript_add_y_equal * (transcript_add_y_equal - 1) * scaling_factor; - std::get<32>(accumulator) += transcript_Pinfinity * (transcript_Pinfinity - 1) * scaling_factor; - std::get<33>(accumulator) += transcript_msm_infinity * (transcript_msm_infinity - 1) * scaling_factor; - std::get<39>(accumulator) += msm_count_zero_at_transition * (msm_count_zero_at_transition - 1) * scaling_factor; - // step 1: subtract offset generator from msm_accumulator - // this might produce a point at infinity - { - const auto offset = offset_generator(); - const auto x1 = offset[0]; - const auto y1 = -offset[1]; - const auto x2 = View(in.transcript_msm_x); - const auto y2 = View(in.transcript_msm_y); - const auto x3 = View(in.transcript_msm_intermediate_x); - const auto y3 = View(in.transcript_msm_intermediate_y); - const auto transcript_msm_infinity = View(in.transcript_msm_infinity); - // cases: - // x2 == x1, y2 == y1 - // x2 != x1 - // (x2 - x1) - const auto x_term = (x3 + x2 + x1) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1); - const auto y_term = (x1 - x3) * (y2 - y1) - (x2 - x1) * (y1 + y3); - // IF msm_infinity = false, transcript_msm_intermediate_x/y is either the result of subtracting offset generator - // from msm_x/y IF msm_infinity = true, transcript_msm_intermediate_x/y is 0 - const auto transcript_offset_generator_subtract_x = - x_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * x3; - const auto transcript_offset_generator_subtract_y = - y_term * (-transcript_msm_infinity + 1) + transcript_msm_infinity * y3; - std::get<34>(accumulator) += msm_transition * transcript_offset_generator_subtract_x * scaling_factor; - std::get<35>(accumulator) += msm_transition * transcript_offset_generator_subtract_y * scaling_factor; + /** + * @brief Validate `transcript_add_x_equal` is well-formed + * If lhs_x == rhs_x, transcript_add_x_equal = 1 + * If transcript_add_x_equal = 0, a valid inverse must exist for (lhs_x - rhs_x) + */ + auto x_diff = lhs_x - rhs_x; // degree 2 + auto x_product = transcript_Px_inverse * (-transcript_add_x_equal + 1) + transcript_add_x_equal; // degree 2 + auto x_constant = transcript_add_x_equal - 1; // degree 1 + auto transcript_add_x_equal_check_relation = (x_diff * x_product + x_constant) * any_add_is_active; // degree 5 + std::get<23>(accumulator) += transcript_add_x_equal_check_relation * scaling_factor; // degree 5 - // validate transcript_msm_infinity is correct - // if transcript_msm_infinity = 1, (x2 == x1) and (y2 + y1 == 0) - const auto x_diff = x2 - x1; - const auto y_sum = y2 + y1; - std::get<36>(accumulator) += msm_transition * transcript_msm_infinity * x_diff * scaling_factor; - std::get<37>(accumulator) += msm_transition * transcript_msm_infinity * y_sum * scaling_factor; - // if transcript_msm_infinity = 1, then x_diff must have an inverse - const auto transcript_msm_x_inverse = View(in.transcript_msm_x_inverse); - const auto inverse_term = (-transcript_msm_infinity + 1) * (x_diff * transcript_msm_x_inverse - 1); - std::get<38>(accumulator) += msm_transition * inverse_term * scaling_factor; + /** + * @brief Validate `transcript_add_y_equal` is well-formed + * If lhs_y == rhs_y, transcript_add_y_equal = 1 + * If transcript_add_y_equal = 0, a valid inverse must exist for (lhs_y - rhs_y) + */ + auto y_diff = lhs_y - rhs_y; + auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; + auto y_constant = transcript_add_y_equal - 1; + auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; + std::get<24>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 5 } - - // Validate correctness of - {} } template class ECCVMTranscriptRelationImpl; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp index 448db8613f2..c4dc075f0ce 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp @@ -30,9 +30,8 @@ template class ECCVMTranscriptRelationImpl { public: using FF = FF_; - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, }; template @@ -41,7 +40,7 @@ template class ECCVMTranscriptRelationImpl { const Parameters& /* unused */, const FF& scaling_factor); - // TODO(@zac-williamson #2609 find more generic way of doing this) + // TODO(@zac-williamson #2809 find more generic way of doing this) static constexpr FF get_curve_b() { if constexpr (FF::modulus == bb::fq::modulus) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index efc17614cca..d9464c5c5a9 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -401,6 +401,33 @@ class ECCOpQueue { return ultra_op; } + /** + * @brief Write no op (i.e. empty row) + * + */ + UltraOp no_op() + { + + // Construct and store the operation in the ultra op format + auto ultra_op = construct_and_populate_ultra_ops(NULL_OP, accumulator); + + // Store raw operation + raw_ops.emplace_back(ECCVMOperation{ + .add = false, + .mul = false, + .eq = false, + .reset = false, + .base_point = { 0, 0 }, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + num_transcript_rows += 1; + update_cached_msms(raw_ops.back()); + + return ultra_op; + } + /** * @brief Write equality op using internal accumulator point * From cf5d93d457a451b89aadab4fea451447f8e0c4b8 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 12:41:37 +0000 Subject: [PATCH 07/24] code tidy --- .../relations/ecc_vm/ecc_bools_relation.cpp | 16 ------------ .../relations/ecc_vm/ecc_bools_relation.hpp | 4 +-- .../ecc_vm/ecc_transcript_relation.cpp | 25 +++++++++---------- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp index 53fcdee6ef8..6d765c8a3c4 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.cpp @@ -27,8 +27,6 @@ void ECCVMBoolsRelationImpl::accumulate(ContainerOverSubrelations& accumulat using Accumulator = typename std::tuple_element_t<0, ContainerOverSubrelations>; using View = typename Accumulator::View; - auto z1 = View(in.transcript_z1); - auto z2 = View(in.transcript_z2); auto z1_zero = View(in.transcript_z1zero); auto z2_zero = View(in.transcript_z2zero); auto msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); @@ -68,20 +66,6 @@ void ECCVMBoolsRelationImpl::accumulate(ContainerOverSubrelations& accumulat std::get<16>(accumulator) += msm_double * (msm_double - 1) * scaling_factor; std::get<17>(accumulator) += msm_skew * (msm_skew - 1) * scaling_factor; std::get<18>(accumulator) += precompute_select * (precompute_select - 1) * scaling_factor; - - /** - * @brief Validate correctness of z1_zero, z2_zero. - * If z1_zero = 0 and operation is a MUL, we will write a scalar mul instruction into our multiplication table. - * If z1_zero = 1 and operation is a MUL, we will NOT write a scalar mul instruction. - * (same with z2_zero). - * z1_zero / z2_zero is user-defined. - * We constraint z1_zero such that if z1_zero == 1, we require z1 == 0. (same for z2_zero). - * We do *NOT* constrain z1 != 0 if z1_zero = 0. If the user sets z1_zero = 0 and z1 = 0, - * this will add a scalar mul instruction into the multiplication table, where the scalar multiplier is 0. - * This is inefficient but will still produce the correct output. - */ - std::get<19>(accumulator) += (z1 * z1_zero) * scaling_factor; // if z1_zero = 1, z1 must be 0 - std::get<20>(accumulator) += (z2 * z2_zero) * scaling_factor; } template class ECCVMBoolsRelationImpl; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp index 3a91f559cfb..88d6eef5dc8 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_bools_relation.hpp @@ -17,8 +17,8 @@ template class ECCVMBoolsRelationImpl { public: using FF = FF_; - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, }; template diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 7269f81e8d1..47a0a3f8731 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -100,8 +100,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * this will add a scalar mul instruction into the multiplication table, where the scalar multiplier is 0. * This is inefficient but will still produce the correct output. */ - std::get<0>(accumulator) += (z1 * z1_zero) * scaling_factor; // if z1_zero = 1, z1 must be 0 - std::get<1>(accumulator) += (z2 * z2_zero) * scaling_factor; + std::get<0>(accumulator) += (z1 * z1_zero) * scaling_factor; // if z1_zero = 1, z1 must be 0. degree 2 + std::get<1>(accumulator) += (z2 * z2_zero) * scaling_factor; // degree 2 /** * @brief Validate `op` opcode is well formed. @@ -115,7 +115,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu tmp += q_eq; tmp += tmp; tmp += q_reset_accumulator; - std::get<2>(accumulator) += (tmp - op) * scaling_factor; + std::get<2>(accumulator) += (tmp - op) * scaling_factor; // degree 1 /** * @brief Validate `pc` is updated correctly. @@ -125,7 +125,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu */ Accumulator pc_delta = pc - pc_shift; auto num_muls_in_row = ((-z1_zero + 1) + (-z2_zero + 1)) * (-transcript_Pinfinity + 1); - std::get<3>(accumulator) += is_not_first_row * (pc_delta - q_mul * num_muls_in_row) * scaling_factor; + std::get<3>(accumulator) += is_not_first_row * (pc_delta - q_mul * num_muls_in_row) * scaling_factor; // degree 4 /** * @brief Validate `msm_transition` is well-formed. @@ -134,7 +134,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * i.e. if q_mul == 1 and q_mul_shift == 0, msm_transition = 1, else is 0 * We also require that `msm_count + [current msm number] > 0` */ - auto msm_transition_check = q_mul * (-q_mul_shift + 1); + auto msm_transition_check = q_mul * (-q_mul_shift + 1); // degree 2 // auto num_muls_total = msm_count + num_muls_in_row; auto msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); auto msm_count_at_transition_inverse = View(in.transcript_msm_count_at_transition_inverse); @@ -144,12 +144,12 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto msm_count_zero_at_transition_check = msm_count_zero_at_transition * msm_count_total; msm_count_zero_at_transition_check += (msm_count_total * msm_count_at_transition_inverse - 1) * (-msm_count_zero_at_transition + 1); - std::get<4>(accumulator) += msm_transition_check * msm_count_zero_at_transition_check * scaling_factor; + std::get<4>(accumulator) += msm_transition_check * msm_count_zero_at_transition_check * scaling_factor; // degree 3 // Validate msm_transition_msm_count is correct // ensure msm_transition is zero if count is zero std::get<5>(accumulator) += - (msm_transition - msm_transition_check * (-msm_count_zero_at_transition + 1)) * scaling_factor; + (msm_transition - msm_transition_check * (-msm_count_zero_at_transition + 1)) * scaling_factor; // degree 3 /** * @brief Validate `msm_count` resets when we end a multiscalar multiplication. @@ -157,7 +157,7 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * (if no msm active, msm_count == 0) * If current row ends an MSM, `msm_count_shift = 0` (msm_count value at next row) */ - std::get<6>(accumulator) += (msm_transition * msm_count_shift) * scaling_factor; + std::get<6>(accumulator) += (msm_transition * msm_count_shift) * scaling_factor; // degree 2 /** * @brief Validate `msm_count` updates correctly for mul operations. @@ -193,9 +193,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto eq_y_diff = transcript_Py - transcript_accumulator_y; auto eq_x_diff_relation = q_eq * (eq_x_diff * both_not_infinity + infinity_exclusion_check); // degree 4 auto eq_y_diff_relation = q_eq * (eq_y_diff * both_not_infinity + infinity_exclusion_check); // degree 4 - std::get<9>(accumulator) += eq_x_diff_relation * scaling_factor; - - std::get<10>(accumulator) += eq_y_diff_relation * scaling_factor; + std::get<9>(accumulator) += eq_x_diff_relation * scaling_factor; // degree 4 + std::get<10>(accumulator) += eq_y_diff_relation * scaling_factor; // degree 4 /** * @brief Initial condition check on 1st row. @@ -205,8 +204,8 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu * note...actually second row? bleurgh * NOTE: we want pc = 0 at lagrange_last :o */ - std::get<11>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; - std::get<12>(accumulator) += lagrange_second * msm_count * scaling_factor; + std::get<11>(accumulator) += lagrange_second * (-is_accumulator_empty + 1) * scaling_factor; // degree 2 + std::get<12>(accumulator) += lagrange_second * msm_count * scaling_factor; // degree 2 /** * @brief On-curve validation checks. From 212c6a0bbdc9bfad8513a69523ad7516b344f52e Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 12:46:29 +0000 Subject: [PATCH 08/24] formatting fix. uncommented parallelism in eccvm circuit builder --- .../eccvm/eccvm_circuit_builder.hpp | 68 +++++++++---------- .../relations/ecc_vm/ecc_msm_relation.cpp | 1 - .../ecc_vm/ecc_transcript_relation.hpp | 2 +- 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp index 8f493da5735..beece79c4a4 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp @@ -148,42 +148,40 @@ class ECCVMCircuitBuilder { msm.resize(msm_sizes[i]); } - // run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) { - size_t start = 0; - size_t end = msm_opqueue_index.size(); - for (size_t i = start; i < end; i++) { - const size_t opqueue_index = msm_opqueue_index[i]; - const auto& op = raw_ops[opqueue_index]; - auto [msm_index, mul_index] = msm_mul_index[i]; - if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { - - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); - msms_test[msm_index][mul_index] = (ScalarMul{ - .pc = 0, - .scalar = op.z1, - .base_point = op.base_point, - .wnaf_slices = compute_wnaf_slices(op.z1), - .wnaf_skew = (op.z1 & 1) == 0, - .precomputed_table = compute_precomputed_table(op.base_point), - }); - mul_index++; - } - if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); - auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }; - msms_test[msm_index][mul_index] = (ScalarMul{ - .pc = 0, - .scalar = op.z2, - .base_point = endo_point, - .wnaf_slices = compute_wnaf_slices(op.z2), - .wnaf_skew = (op.z2 & 1) == 0, - .precomputed_table = compute_precomputed_table(endo_point), - }); + run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + const size_t opqueue_index = msm_opqueue_index[i]; + const auto& op = raw_ops[opqueue_index]; + auto [msm_index, mul_index] = msm_mul_index[i]; + if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { + + ASSERT(msms_test.size() > msm_index); + ASSERT(msms_test[msm_index].size() > mul_index); + msms_test[msm_index][mul_index] = (ScalarMul{ + .pc = 0, + .scalar = op.z1, + .base_point = op.base_point, + .wnaf_slices = compute_wnaf_slices(op.z1), + .wnaf_skew = (op.z1 & 1) == 0, + .precomputed_table = compute_precomputed_table(op.base_point), + }); + mul_index++; + } + if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { + ASSERT(msms_test.size() > msm_index); + ASSERT(msms_test[msm_index].size() > mul_index); + auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }; + msms_test[msm_index][mul_index] = (ScalarMul{ + .pc = 0, + .scalar = op.z2, + .base_point = endo_point, + .wnaf_slices = compute_wnaf_slices(op.z2), + .wnaf_skew = (op.z2 & 1) == 0, + .precomputed_table = compute_precomputed_table(endo_point), + }); + } } - } - // }); + }); // update pc. easier to do this serially but in theory could be optimised out // We start pc at `num_muls` and decrement for each mul processed. diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp index c2c3e97460d..d79e3ab44c0 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp @@ -232,7 +232,6 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator Accumulator x4_collision_relation(0); // If msm_transition = 1, we have started a new MSM. We need to treat the current value of [Acc] as the point at // infinity! - // auto add_into_accumulator = -msm_transition + 1; auto [x_t1, y_t1] = first_add(acc_x, acc_y, x1, y1, lambda1, msm_transition, add_relation, x1_collision_relation); auto [x_t2, y_t2] = add(x2, y2, x_t1, y_t1, lambda2, add2, add_relation, x2_collision_relation); auto [x_t3, y_t3] = add(x3, y3, x_t2, y_t2, lambda3, add3, add_relation, x3_collision_relation); diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp index c4dc075f0ce..21546026046 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.hpp @@ -40,7 +40,7 @@ template class ECCVMTranscriptRelationImpl { const Parameters& /* unused */, const FF& scaling_factor); - // TODO(@zac-williamson #2809 find more generic way of doing this) + // TODO(@zac-williamson #2609 find more generic way of doing this) static constexpr FF get_curve_b() { if constexpr (FF::modulus == bb::fq::modulus) { From 27c75340ea0a97e7c58067636cd4550563f45a34 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 12:48:06 +0000 Subject: [PATCH 09/24] removed unused code --- .../eccvm/eccvm_circuit_builder.test.cpp | 2 +- .../op_queue/ecc_op_queue.hpp | 29 ------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp index e7393974f80..10dc708399e 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.test.cpp @@ -120,7 +120,7 @@ TEST(ECCVMCircuitBuilderTests, MulInfinity) // G1::affine_element c = G1::affine_point_at_infinity; op_queue->add_accumulate(b); op_queue->mul_accumulate(a, x); - // op_queue->eq_and_resetb(c); + op_queue->eq_and_reset(); ECCVMCircuitBuilder circuit{ op_queue }; bool result = ECCVMTraceChecker::check(circuit); EXPECT_EQ(result, true); diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index d9464c5c5a9..e9ed60c6042 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -372,35 +372,6 @@ class ECCOpQueue { return ultra_op; } - /** - * @brief Write equality op using internal accumulator point - * - * @return current internal accumulator point (prior to reset to 0) - */ - UltraOp eq_and_resetb(Point& expected) - { - accumulator.self_set_infinity(); - - // Construct and store the operation in the ultra op format - auto ultra_op = construct_and_populate_ultra_ops(EQUALITY, expected); - - // Store raw operation - raw_ops.emplace_back(ECCVMOperation{ - .add = false, - .mul = false, - .eq = true, - .reset = true, - .base_point = expected, - .z1 = 0, - .z2 = 0, - .mul_scalar_full = 0, - }); - num_transcript_rows += 1; - update_cached_msms(raw_ops.back()); - - return ultra_op; - } - /** * @brief Write no op (i.e. empty row) * From b911c5288cdc4469051e84396b069e06e868e8ec Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Tue, 14 May 2024 14:41:09 +0000 Subject: [PATCH 10/24] formatting --- .../eccvm/eccvm_circuit_builder.hpp | 8 +------ .../src/barretenberg/eccvm/eccvm_flavor.hpp | 12 +++++++++++ .../barretenberg/eccvm/transcript_builder.hpp | 10 ++++++++- .../relations/ecc_vm/ecc_msm_relation.cpp | 1 + .../relations/ecc_vm/ecc_set_relation.cpp | 21 ++++++++++--------- 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp index beece79c4a4..5ae37ca9745 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp @@ -123,12 +123,7 @@ class ECCVMCircuitBuilder { if ((op.z1 != 0 || op.z2 != 0) && !op.base_point.is_point_at_infinity()) { msm_opqueue_index.push_back(op_idx); msm_mul_index.emplace_back(msm_count, active_mul_count); - } - if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { - active_mul_count++; - } - if (op.z2 != 0 && !op.base_point.is_point_at_infinity()) { - active_mul_count++; + active_mul_count += static_cast(op.z1 != 0) + static_cast(op.z2 != 0); } } else if (active_mul_count > 0) { msm_sizes.push_back(active_mul_count); @@ -154,7 +149,6 @@ class ECCVMCircuitBuilder { const auto& op = raw_ops[opqueue_index]; auto [msm_index, mul_index] = msm_mul_index[i]; if (op.z1 != 0 && !op.base_point.is_point_at_infinity()) { - ASSERT(msms_test.size() > msm_index); ASSERT(msms_test[msm_index].size() > mul_index); msms_test[msm_index][mul_index] = (ScalarMul{ diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 862723a9fc0..4779acdbe38 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -476,6 +476,18 @@ class ECCVMFlavor { * transcript_msm_x: x-coordinate of MSM output * transcript_msm_y: y-coordinate of MSM output * transcript_accumulator_empty: if 1, transcript_accumulator = point at infinity + * transcript_base_infinity: if 1, transcript_Px, transcript_Py is a point at infinity + * transcript_add_x_equal: if adding a point into the accumulator, is 1 if x-coordinates are equal + * transcript_add_y_equal: if adding a point into the accumulator, is 1 if y-coordinates are equal + * transcript_base_x_inverse: to check transcript_add_x_equal (if x-vals not equal inverse exists) + * transcript_base_y_inverse: to check transcript_add_x_equal (if y-vals not equal inverse exists) + * transcript_add_lambda: if adding a point into the accumulator, contains the lambda gradient + * transcript_msm_intermediate_x: if add MSM result into accumulator, is msm_output - offset_generator + * transcript_msm_intermediate_y: if add MSM result into accumulator, is msm_output - offset_generator + * transcript_msm_infinity: is MSM result the point at infinity? + * transcript_msm_x_inverse: used to validate transcript_msm_infinity correct + * transcript_msm_count_zero_at_transition: does an MSM only contain points at infinity/zero scalars + * transcript_msm_count_at_transition_inverse: used to validate transcript_msm_count_zero_at_transition * precompute_pc: point counter for Straus precomputation columns * precompute_select: if 1, evaluate Straus precomputation algorithm at current row * precompute_point_transition: 1 if current row operating on a different point to previous row diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index ed087e5cd62..c4aef2c45eb 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -45,6 +45,15 @@ class ECCVMTranscriptBuilder { FF msm_count_at_transition_inverse = 0; }; + /** + * @brief Computes offset_generator group element + * @details "offset generator" is used when performing multi-scalar-multiplications to ensure an HONEST prover never + * triggers incomplete point addition formulae. + * i.e. we don't need to constrain point doubling or points at infinity when computing an MSM + * The MSM accumulator is initialized to `offset_generator`. When adding the MSM result into the transcript + * accumulator, the contribution of the offset generator to the MSM result is removed (offset_generator * 2^{124}) + * @return AffineElement + */ static AffineElement offset_generator() { static constexpr auto offset_generator_base = CycleGroup::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; @@ -91,7 +100,6 @@ class ECCVMTranscriptBuilder { // We fill these vectors and then perform batch inversions to amortize the cost of FF inverts std::vector inverse_trace_x(num_vm_entries); std::vector inverse_trace_y(num_vm_entries); - std::vector transcript_add_lambda(num_vm_entries); std::vector transcript_msm_x_inverse_trace(num_vm_entries); std::vector add_lambda_denominator(num_vm_entries); std::vector add_lambda_numerator(num_vm_entries); diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp index d79e3ab44c0..dc33ba4a095 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_msm_relation.cpp @@ -210,6 +210,7 @@ void ECCVMMSMRelationImpl::accumulate(ContainerOverSubrelations& accumulator auto& selector, auto& relation, auto& collision_relation) { + // N.B. this is brittle - should be curve agnostic but we don't propagate the curve parameter into relations! constexpr auto offset_generator = bb::g1::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; constexpr uint256_t oxu = offset_generator.x; constexpr uint256_t oyu = offset_generator.y; diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp index 5d9a7c04501..7288de054bf 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_set_relation.cpp @@ -305,12 +305,17 @@ Accumulator ECCVMSetRelationImpl::compute_permutation_denominator(const AllE auto transcript_input2 = (transcript_pc - 1) + transcript_Px * endomorphism_base_field_shift * beta - transcript_Py * beta_sqr + z2 * beta_cube; - // | q_mul | z2_zero | z1_zero | lookup | - // | ----- | ------- | ------- | ---------------------- | - // | 0 | - | - | 1 | - // | 1 | 0 | 1 | X + gamma | - // | 1 | 1 | 0 | Y + gamma | - // | 1 | 1 | 1 | (X + gamma)(Y + gamma) | + // | q_mul | z2_zero | z1_zero | base_infinity | lookup | + // | ----- | ------- | ------- | ------------- |----------------------- | + // | 0 | - | - | - | 1 | + // | 1 | 0 | 0 | 0 | 1 | + // | 1 | 0 | 1 | 0 | X + gamma | + // | 1 | 1 | 0 | 0 | Y + gamma | + // | 1 | 1 | 1 | 0 | (X + gamma)(Y + gamma) | + // | 1 | 0 | 0 | 1 | 1 | + // | 1 | 0 | 1 | 1 | 1 | + // | 1 | 1 | 0 | 1 | 1 | + // | 1 | 1 | 1 | 1 | 1 | transcript_input1 = (transcript_input1 + gamma) * lookup_first + (-lookup_first + 1); transcript_input2 = (transcript_input2 + gamma) * lookup_second + (-lookup_second + 1); // transcript_product = degree 3 @@ -343,17 +348,13 @@ Accumulator ECCVMSetRelationImpl::compute_permutation_denominator(const AllE auto z2_zero = View(in.transcript_z2zero); auto transcript_mul = View(in.transcript_mul); auto base_infinity = View(in.transcript_base_infinity); - // auto transcript_msm_count_zero_at_transition = View(in.transcript_msm_count_zero_at_transition); // do not add to count if point at infinity! auto full_msm_count = transcript_msm_count + transcript_mul * ((-z1_zero + 1) + (-z2_zero + 1)) * (-base_infinity + 1); - // auto count_test = transcript_msm_count // msm_result_read = degree 2 auto msm_result_read = transcript_pc_shift + transcript_msm_x * beta + transcript_msm_y * beta_sqr + full_msm_count * beta_cube; - // N.B. NOT COUNT ZERO NOT NEEDED IS FACTORED INTO MSM TRANSITION - // auto read_active = transcript_msm_transition; msm_result_read = transcript_msm_transition * (msm_result_read + gamma) + (-transcript_msm_transition + 1); denominator *= msm_result_read; // degree-20 } From db9e0704da3ed2ac02588371654f87249e6669fb Mon Sep 17 00:00:00 2001 From: codygunton Date: Wed, 15 May 2024 22:06:21 +0000 Subject: [PATCH 11/24] Revert "initial commit. biggroup objects track whether they are points at infinity, and have +/- methods that correctly handle points at infinity" This reverts commit 9f6b4efdbcf8a5ffed31013aae46ca5caab80948. --- .../stdlib/primitives/bigfield/bigfield.hpp | 6 - .../primitives/bigfield/bigfield.test.cpp | 44 ----- .../primitives/bigfield/bigfield_impl.hpp | 52 ------ .../stdlib/primitives/biggroup/biggroup.hpp | 74 +++------ .../primitives/biggroup/biggroup.test.cpp | 94 +---------- .../biggroup/biggroup_batch_mul.hpp | 41 +---- .../primitives/biggroup/biggroup_bn254.hpp | 32 ++-- .../primitives/biggroup/biggroup_goblin.hpp | 1 - .../biggroup/biggroup_goblin.test.cpp | 4 +- .../primitives/biggroup/biggroup_impl.hpp | 154 ++---------------- .../primitives/biggroup/biggroup_nafs.hpp | 15 +- .../biggroup/biggroup_secp256k1.hpp | 7 +- .../primitives/biggroup/biggroup_tables.hpp | 108 ++++++------ .../stdlib/primitives/curves/secp256r1.hpp | 10 +- .../primitives/databus/databus.test.cpp | 2 - 15 files changed, 127 insertions(+), 517 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp index 7643afe8ad6..2fc3572cec3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp @@ -241,12 +241,6 @@ template class bigfield { bigfield conditional_negate(const bool_t& predicate) const; bigfield conditional_select(const bigfield& other, const bool_t& predicate) const; - static bigfield conditional_assign(const bool_t& predicate, const bigfield& lhs, const bigfield& rhs) - { - return rhs.conditional_select(lhs, predicate); - } - - bool_t operator==(const bigfield& other) const; void assert_is_in_field() const; void assert_less_than(const uint256_t upper_limit) const; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp index 8ec46f817de..3aa7f6090ce 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp @@ -841,45 +841,6 @@ template class stdlib_bigfield : public testing::Test { fq_ct ret = fq_ct::div_check_denominator_nonzero({}, a_ct); EXPECT_NE(ret.get_context(), nullptr); } - - static void test_assert_equal_not_equal() - { - auto builder = Builder(); - size_t num_repetitions = 10; - for (size_t i = 0; i < num_repetitions; ++i) { - fq inputs[4]{ fq::random_element(), fq::random_element(), fq::random_element(), fq::random_element() }; - - fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, - fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, - fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - fq_ct c(witness_ct(&builder, fr(uint256_t(inputs[2]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, - fr(uint256_t(inputs[2]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, - fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - - fq_ct two(witness_ct(&builder, fr(2)), - witness_ct(&builder, fr(0)), - witness_ct(&builder, fr(0)), - witness_ct(&builder, fr(0))); - fq_ct t0 = a + a; - fq_ct t1 = a * two; - - t0.assert_equal(t1); - t0.assert_is_not_equal(c); - t0.assert_is_not_equal(d); - stdlib::bool_t is_equal_a = t0 == t1; - stdlib::bool_t is_equal_b = t0 == c; - EXPECT_TRUE(is_equal_a.get_value()); - EXPECT_FALSE(is_equal_b.get_value()); - } - bool result = CircuitChecker::check(builder); - EXPECT_EQ(result, true); - } }; // Define types for which the above tests will be constructed. @@ -969,11 +930,6 @@ TYPED_TEST(stdlib_bigfield, division_context) TestFixture::test_division_context(); } -TYPED_TEST(stdlib_bigfield, assert_equal_not_equal) -{ - TestFixture::test_assert_equal_not_equal(); -} - // // This test was disabled before the refactor to use TYPED_TEST's/ // TEST(stdlib_bigfield, DISABLED_test_div_against_constants) // { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp index f8773225ad7..3e6fc79a994 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp @@ -1562,57 +1562,6 @@ bigfield bigfield::conditional_select(const bigfield& ot return result; } -/** - * @brief Validate whether two bigfield elements are equal to each other - * @details To evaluate whether `(a == b)`, we use result boolean `r` to evaluate the following logic: - * (n.b all algebra involving bigfield elements is done in the bigfield) - * 1. If `r == 1` , `a - b == 0` - * 2. If `r == 0`, `a - b` posesses an inverse `I` i.e. `(a - b) * I - 1 == 0` - * We efficiently evaluate this logic by evaluating a single expression `(a - b)*X = Y` - * We use conditional assignment logic to define `X, Y` to be the following: - * If `r == 1` then `X = 1, Y = 0` - * If `r == 0` then `X = I, Y = 1` - * This allows us to evaluate `operator==` using only 1 bigfield multiplication operation. - * We can check the product equals 0 or 1 by directly evaluating the binary basis/prime basis limbs of Y. - * i.e. if `r == 1` then `(a - b)*X` should have 0 for all limb values - * if `r == 0` then `(a - b)*X` should have 1 in the least significant binary basis limb and 0 elsewhere - * @tparam Builder - * @tparam T - * @param other - * @return bool_t - */ -template bool_t bigfield::operator==(const bigfield& other) const -{ - Builder* ctx = context ? context : other.get_context(); - auto lhs = get_value() % modulus_u512; - auto rhs = other.get_value() % modulus_u512; - bool is_equal_raw = (lhs == rhs); - bool_t is_equal = witness_t(ctx, is_equal_raw); - - bigfield diff = (*this) - other; - - // TODO: get native values efficiently (i.e. if u512 value fits in a u256, subtract off modulus until u256 fits - // into finite field) - native diff_native = native((diff.get_value() % modulus_u512).lo); - native inverse_native = is_equal_raw ? 0 : diff_native.invert(); - - bigfield inverse = bigfield::from_witness(ctx, inverse_native); - - bigfield multiplicand = bigfield::conditional_assign(is_equal, one(), inverse); - - bigfield product = diff * multiplicand; - - field_t result = field_t::conditional_assign(is_equal, 0, 1); - - product.prime_basis_limb.assert_equal(result); - product.binary_basis_limbs[0].element.assert_equal(result); - product.binary_basis_limbs[1].element.assert_equal(0); - product.binary_basis_limbs[2].element.assert_equal(0); - product.binary_basis_limbs[3].element.assert_equal(0); - - return is_equal; -} - /** * REDUCTION CHECK * @@ -1798,7 +1747,6 @@ template void bigfield::assert_equal( std::cerr << "bigfield: calling assert equal on 2 CONSTANT bigfield elements...is this intended?" << std::endl; return; } else if (other.is_constant()) { - // TODO: wtf? // evaluate a strict equality - make sure *this is reduced first, or an honest prover // might not be able to satisfy these constraints. field_t t0 = (binary_basis_limbs[0].element - other.binary_basis_limbs[0].element); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp index 4cbe262e5d9..51cdb25c790 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp @@ -21,8 +21,6 @@ namespace bb::stdlib { // ( ͡° ͜ʖ ͡°) template class element { public: - using bool_t = stdlib::bool_t; - struct secp256k1_wnaf { std::vector> wnaf; field_t positive_skew; @@ -40,23 +38,13 @@ template class element { element(const Fq& x, const Fq& y); element(const element& other); - element(element&& other) noexcept; + element(element&& other); static element from_witness(Builder* ctx, const typename NativeGroup::affine_element& input) { - element out; - if (input.is_point_at_infinity()) { - Fq x = Fq::from_witness(ctx, NativeGroup::affine_one.x); - Fq y = Fq::from_witness(ctx, NativeGroup::affine_one.y); - out.x = x; - out.y = y; - } else { - Fq x = Fq::from_witness(ctx, input.x); - Fq y = Fq::from_witness(ctx, input.y); - out.x = x; - out.y = y; - } - out.set_point_at_infinity(witness_t(ctx, input.is_point_at_infinity())); + Fq x = Fq::from_witness(ctx, input.x); + Fq y = Fq::from_witness(ctx, input.y); + element out(x, y); out.validate_on_curve(); return out; } @@ -64,17 +52,13 @@ template class element { void validate_on_curve() const { Fq b(get_context(), uint256_t(NativeGroup::curve_b)); - Fq _b = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), b); - Fq _x = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), x); - Fq _y = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), y); if constexpr (!NativeGroup::has_a) { // we validate y^2 = x^3 + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ _x.sqr(), _y }, { _x, -_y }, { _b }, true); + Fq::mult_madd({ x.sqr(), y }, { x, -y }, { b }, true); } else { Fq a(get_context(), uint256_t(NativeGroup::curve_a)); - Fq _a = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), a); // we validate y^2 = x^3 + ax + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ _x.sqr(), _x, _y }, { _x, _a, -_y }, { _b }, true); + Fq::mult_madd({ x.sqr(), x, y }, { x, a, -y }, { b }, true); } } @@ -88,7 +72,7 @@ template class element { } element& operator=(const element& other); - element& operator=(element&& other) noexcept; + element& operator=(element&& other); byte_array to_byte_array() const { @@ -98,9 +82,6 @@ template class element { return result; } - element checked_unconditional_add(const element& other) const; - element checked_unconditional_subtract(const element& other) const; - element operator+(const element& other) const; element operator-(const element& other) const; element operator-() const @@ -119,11 +100,11 @@ template class element { *this = *this - other; return *this; } - std::array checked_unconditional_add_sub(const element& other) const; + std::array add_sub(const element& other) const; element operator*(const Fr& other) const; - element conditional_negate(const bool_t& predicate) const + element conditional_negate(const bool_t& predicate) const { element result(*this); result.y = result.y.conditional_negate(predicate); @@ -195,13 +176,9 @@ template class element { typename NativeGroup::affine_element get_value() const { - uint512_t x_val = x.get_value() % Fq::modulus_u512; - uint512_t y_val = y.get_value() % Fq::modulus_u512; - auto result = typename NativeGroup::affine_element(x_val.lo, y_val.lo); - if (is_point_at_infinity().get_value()) { - result.self_set_infinity(); - } - return result; + uint512_t x_val = x.get_value(); + uint512_t y_val = y.get_value(); + return typename NativeGroup::affine_element(x_val.lo, y_val.lo); } // compute a multi-scalar-multiplication by creating a precomputed lookup table for each point, @@ -252,7 +229,7 @@ template class element { template ::value>> static element secp256k1_ecdsa_mul(const element& pubkey, const Fr& u1, const Fr& u2); - static std::vector compute_naf(const Fr& scalar, const size_t max_num_bits = 0); + static std::vector> compute_naf(const Fr& scalar, const size_t max_num_bits = 0); template static std::vector> compute_wnaf(const Fr& scalar); @@ -288,15 +265,10 @@ template class element { return nullptr; } - bool_t is_point_at_infinity() const { return _is_infinity; } - void set_point_at_infinity(const bool_t& is_infinity) { _is_infinity = is_infinity; } - Fq x; Fq y; private: - bool_t _is_infinity; - template >> static std::array, 5> create_group_element_rom_tables( const std::array& elements, std::array& limb_max); @@ -395,7 +367,7 @@ template class element { lookup_table_base(const lookup_table_base& other) = default; lookup_table_base& operator=(const lookup_table_base& other) = default; - element get(const std::array& bits) const; + element get(const std::array, length>& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -425,7 +397,7 @@ template class element { lookup_table_plookup(const lookup_table_plookup& other) = default; lookup_table_plookup& operator=(const lookup_table_plookup& other) = default; - element get(const std::array& bits) const; + element get(const std::array, length>& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -636,7 +608,7 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -688,7 +660,7 @@ template class element { return (accumulator); } - element get(std::vector& naf_entries) const + element get(std::vector>& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -840,21 +812,21 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { - round_accumulator.push_back(quad_tables[j].get(std::array{ + round_accumulator.push_back(quad_tables[j].get(std::array, 4>{ naf_entries[4 * j], naf_entries[4 * j + 1], naf_entries[4 * j + 2], naf_entries[4 * j + 3] })); } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array{ + round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { round_accumulator.push_back(twin_tables[0].get( - std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); + std::array, 2>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); } if (has_singleton) { round_accumulator.push_back(singletons[0].conditional_negate(naf_entries[num_points - 1])); @@ -877,7 +849,7 @@ template class element { return (accumulator); } - element get(std::vector& naf_entries) const + element get(std::vector>& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { @@ -886,7 +858,7 @@ template class element { } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array{ + round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp index a8de2df775b..44201423b28 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/stdlib/primitives/curves/secp256k1.hpp" #include "barretenberg/stdlib/primitives/curves/secp256r1.hpp" -using namespace bb; - namespace { auto& engine = numeric::get_debug_randomness(); } +using namespace bb; + // One can only define a TYPED_TEST with a single template paramter. // Our workaround is to pass parameters of the following type. template struct TestType { @@ -41,8 +41,6 @@ template class stdlib_biggroup : public testing::Test { using element = typename g1::element; using Builder = typename Curve::Builder; - using witness_ct = stdlib::witness_t; - using bool_ct = stdlib::bool_t; static constexpr auto EXPECT_CIRCUIT_CORRECTNESS = [](Builder& builder, bool expected_result = true) { info("num gates = ", builder.get_num_gates()); @@ -84,45 +82,6 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } - static void test_add_points_at_infnity() - { - Builder builder; - size_t num_repetitions = 1; - for (size_t i = 0; i < num_repetitions; ++i) { - affine_element input_a(element::random_element()); - affine_element input_b(element::random_element()); - input_b.self_set_infinity(); - element_ct a = element_ct::from_witness(&builder, input_a); - // create copy of a with different witness - element_ct a_alternate = element_ct::from_witness(&builder, input_a); - element_ct a_negated = element_ct::from_witness(&builder, -input_a); - element_ct b = element_ct::from_witness(&builder, input_b); - - element_ct c = a + b; - element_ct d = b + a; - element_ct e = b + b; - element_ct f = a + a; - element_ct g = a + a_alternate; - element_ct h = a + a_negated; - - affine_element c_expected = affine_element(element(input_a) + element(input_b)); - affine_element d_expected = affine_element(element(input_b) + element(input_a)); - affine_element e_expected = affine_element(element(input_b) + element(input_b)); - affine_element f_expected = affine_element(element(input_a) + element(input_a)); - affine_element g_expected = affine_element(element(input_a) + element(input_a)); - affine_element h_expected = affine_element(element(input_a) + element(-input_a)); - - EXPECT_EQ(c.get_value(), c_expected); - EXPECT_EQ(d.get_value(), d_expected); - EXPECT_EQ(e.get_value(), e_expected); - EXPECT_EQ(f.get_value(), f_expected); - EXPECT_EQ(g.get_value(), g_expected); - EXPECT_EQ(h.get_value(), h_expected); - } - - EXPECT_CIRCUIT_CORRECTNESS(builder); - } - static void test_sub() { Builder builder; @@ -151,45 +110,6 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } - static void test_sub_points_at_infnity() - { - Builder builder; - size_t num_repetitions = 1; - for (size_t i = 0; i < num_repetitions; ++i) { - affine_element input_a(element::random_element()); - affine_element input_b(element::random_element()); - input_b.self_set_infinity(); - element_ct a = element_ct::from_witness(&builder, input_a); - // create copy of a with different witness - element_ct a_alternate = element_ct::from_witness(&builder, input_a); - element_ct a_negated = element_ct::from_witness(&builder, -input_a); - element_ct b = element_ct::from_witness(&builder, input_b); - - element_ct c = a - b; - element_ct d = b - a; - element_ct e = b - b; - element_ct f = a - a; - element_ct g = a - a_alternate; - element_ct h = a - a_negated; - - affine_element c_expected = affine_element(element(input_a) - element(input_b)); - affine_element d_expected = affine_element(element(input_b) - element(input_a)); - affine_element e_expected = affine_element(element(input_b) - element(input_b)); - affine_element f_expected = affine_element(element(input_a) - element(input_a)); - affine_element g_expected = affine_element(element(input_a) - element(input_a)); - affine_element h_expected = affine_element(element(input_a) - element(-input_a)); - - EXPECT_EQ(c.get_value(), c_expected); - EXPECT_EQ(d.get_value(), d_expected); - EXPECT_EQ(e.get_value(), e_expected); - EXPECT_EQ(f.get_value(), f_expected); - EXPECT_EQ(g.get_value(), g_expected); - EXPECT_EQ(h.get_value(), h_expected); - } - - EXPECT_CIRCUIT_CORRECTNESS(builder); - } - static void test_dbl() { Builder builder; @@ -913,20 +833,10 @@ TYPED_TEST(stdlib_biggroup, add) TestFixture::test_add(); } -TYPED_TEST(stdlib_biggroup, add_points_at_infinity) -{ - - TestFixture::test_add_points_at_infnity(); -} TYPED_TEST(stdlib_biggroup, sub) { TestFixture::test_sub(); } -TYPED_TEST(stdlib_biggroup, sub_points_at_infinity) -{ - - TestFixture::test_sub_points_at_infnity(); -} TYPED_TEST(stdlib_biggroup, dbl) { TestFixture::test_dbl(); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp index 004538a3e5d..a10198286c3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp @@ -1,50 +1,21 @@ #pragma once -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" -#include namespace bb::stdlib { /** * only works for Plookup (otherwise falls back on batch_mul)! Multiscalar multiplication that utilizes 4-bit wNAF * lookup tables is more efficient than points-as-linear-combinations lookup tables, if the number of points is 3 or * fewer - * TODO: when we nuke standard and turbo plonk we should remove the fallback batch mul method! */ template template -element element::wnaf_batch_mul(const std::vector& _points, - const std::vector& _scalars) +element element::wnaf_batch_mul(const std::vector& points, + const std::vector& scalars) { constexpr size_t WNAF_SIZE = 4; - ASSERT(_points.size() == _scalars.size()); + ASSERT(points.size() == scalars.size()); if constexpr (!HasPlookup) { - return batch_mul(_points, _scalars, max_num_bits); - } - - // treat inputs for points at infinity. - // if a base point is at infinity, we substitute for element::one, and set the scalar multiplier to 0 - // this (partially) ensures the mul algorithm does not need to account for points at infinity - std::vector points; - std::vector scalars; - element one = element::one(nullptr); - for (size_t i = 0; i < points.size(); ++i) { - bool_t is_point_at_infinity = points[i].is_point_at_infinity(); - if (is_point_at_infinity.get_value() && static_cast(is_point_at_infinity.is_constant())) { - // if point is at infinity and a circuit constant we can just skip. - continue; - } - if (_scalars[i].get_value() == 0 && _scalars[i].is_constant()) { - // if scalar multiplier is 0 and also a constant, we can skip - continue; - } - element point(_points[i]); - point.x = Fq::conditional_assign(is_point_at_infinity, one.x, point.x); - point.y = Fq::conditional_assign(is_point_at_infinity, one.y, point.y); - Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalars[i]); - points.push_back(point); - scalars.push_back(scalar); - - // TODO: if both point and scalar are constant, don't bother adding constraints + return batch_mul(points, scalars, max_num_bits); } std::vector> point_tables; @@ -78,8 +49,8 @@ element element::wnaf_batch_mul(const std::vector(wnaf_entries[i][num_rounds])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(wnaf_entries[i][num_rounds])); accumulator = element(out_x, out_y); } accumulator -= offset_generators.second; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp index 0836b29bc87..5e03f8a58da 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp @@ -7,8 +7,6 @@ * We use a special case algorithm to split bn254 scalar multipliers into endomorphism scalars * **/ -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" -#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp" namespace bb::stdlib { /** @@ -20,7 +18,6 @@ namespace bb::stdlib { * `small_scalars/small_points` : 128-bit scalar multipliers * `generator_scalar` : a 254-bit scalar multiplier over the bn254 generator point * - * TODO: this is plonk only. kill method when we deprecate standard/turbo plonk **/ template template @@ -57,9 +54,9 @@ element element::bn254_endo_batch_mul_with_generator auto& big_table = big_table_pair.first; auto& endo_table = big_table_pair.second; batch_lookup_table small_table(small_points); - std::vector> big_naf_entries; - std::vector> endo_naf_entries; - std::vector> small_naf_entries; + std::vector>> big_naf_entries; + std::vector>> endo_naf_entries; + std::vector>> small_naf_entries; const auto split_into_endomorphism_scalars = [ctx](const Fr& scalar) { bb::fr k = scalar.get_value(); @@ -102,9 +99,9 @@ element element::bn254_endo_batch_mul_with_generator element accumulator = element::chain_add_end(init_point); const auto get_point_to_add = [&](size_t naf_index) { - std::vector small_nafs; - std::vector big_nafs; - std::vector endo_nafs; + std::vector> small_nafs; + std::vector> big_nafs; + std::vector> endo_nafs; for (size_t i = 0; i < small_points.size(); ++i) { small_nafs.emplace_back(small_naf_entries[i][naf_index]); } @@ -181,14 +178,16 @@ element element::bn254_endo_batch_mul_with_generator } { element skew = accumulator - generator_table[128]; - Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_wnaf[generator_wnaf.size() - 1])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } { element skew = accumulator - generator_endo_table[128]; - Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + Fq out_x = + accumulator.x.conditional_select(skew.x, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + Fq out_y = + accumulator.y.conditional_select(skew.y, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } @@ -214,7 +213,6 @@ element element::bn254_endo_batch_mul_with_generator * max_num_small_bits : MINIMUM value must be 128 bits * (we will be splitting `big_scalars` into two 128-bit scalars, we assume all scalars after this transformation are 128 *bits) - * TODO: this does not seem to be used anywhere except turbo plonk. delete once we deprecate turbo? **/ template template @@ -322,7 +320,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ const size_t num_rounds = max_num_small_bits; const size_t num_points = points.size(); - std::vector> naf_entries; + std::vector>> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_small_bits)); } @@ -356,7 +354,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ for (size_t i = 1; i < num_rounds / 2; ++i) { // `nafs` tracks the naf value for each point for the current round - std::vector nafs; + std::vector> nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][i * 2 - 1]); } @@ -385,7 +383,7 @@ element element::bn254_endo_batch_mul(const std::vec // we need to iterate 1 more time if the number of rounds is even if ((num_rounds & 0x01ULL) == 0x00ULL) { - std::vector nafs; + std::vector> nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][num_rounds - 1]); } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp index 15d8a16c372..62404fc055e 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp @@ -1,6 +1,5 @@ #pragma once -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { /** diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp index 6e6e38d9358..1ac09c4e69d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/numeric/random/engine.hpp" #include -using namespace bb; - namespace { auto& engine = numeric::get_debug_randomness(); } +using namespace bb; + template class stdlib_biggroup_goblin : public testing::Test { using element_ct = typename Curve::Element; using scalar_ct = typename Curve::ScalarField; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp index d446cfa06a3..35b1c477d72 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp @@ -2,7 +2,8 @@ #include "../bit_array/bit_array.hpp" #include "../circuit_builders/circuit_builders.hpp" -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" + +using namespace bb; namespace bb::stdlib { @@ -10,181 +11,50 @@ template element::element() : x() , y() - , _is_infinity() {} template element::element(const typename G::affine_element& input) : x(nullptr, input.x) , y(nullptr, input.y) - , _is_infinity(nullptr, input.is_point_at_infinity()) {} template element::element(const Fq& x_in, const Fq& y_in) : x(x_in) , y(y_in) - , _is_infinity(x.get_context() ? x.get_context() : y.get_context(), false) {} template element::element(const element& other) : x(other.x) , y(other.y) - , _is_infinity(other.is_point_at_infinity()) {} template -element::element(element&& other) noexcept +element::element(element&& other) : x(other.x) , y(other.y) - , _is_infinity(other.is_point_at_infinity()) {} template element& element::operator=(const element& other) { - if (&other == this) { - return *this; - } x = other.x; y = other.y; - _is_infinity = other.is_point_at_infinity(); return *this; } template -element& element::operator=(element&& other) noexcept +element& element::operator=(element&& other) { - if (&other == this) { - return *this; - } x = other.x; y = other.y; - _is_infinity = other.is_point_at_infinity(); return *this; } template element element::operator+(const element& other) const -{ - // return checked_unconditional_add(other); - if constexpr (IsGoblinBuilder && std::same_as) { - // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize - // Current gate count: 6398 - std::vector points{ *this, other }; - std::vector scalars{ 1, 1 }; - return goblin_batch_mul(points, scalars); - } - - // if x_coordinates match, lambda triggers a divide by zero error. - // Adding in `x_coordinates_match` ensures that lambda will always be well-formed - const bool_t x_coordinates_match = other.x == x; - const bool_t y_coordinates_match = (y == other.y); - const bool_t infinity_predicate = (x_coordinates_match && !y_coordinates_match); - const bool_t double_predicate = (x_coordinates_match && y_coordinates_match); - const bool_t lhs_infinity = is_point_at_infinity(); - const bool_t rhs_infinity = other.is_point_at_infinity(); - - // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 - const Fq add_lambda_numerator = other.y - y; - const Fq xx = x * x; - const Fq dbl_lambda_numerator = xx + xx + xx; - const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); - - const Fq add_lambda_denominator = other.x - x; - const Fq dbl_lambda_denominator = y + y; - Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); - // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a - // divide by zero error. - // (if either inputs are points at infinity we will not use the result of this computation) - Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); - lambda_denominator = Fq::conditional_assign( - lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); - const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); - - const Fq x3 = lambda.sqradd({ -other.x, -x }); - const Fq y3 = lambda.madd(x - x3, { -y }); - - element result(x3, y3); - // if lhs infinity, return rhs - result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); - result.y = Fq::conditional_assign(lhs_infinity, other.y, result.y); - // if rhs infinity, return lhs - result.x = Fq::conditional_assign(rhs_infinity, x, result.x); - result.y = Fq::conditional_assign(rhs_infinity, y, result.y); - - // is result point at infinity? - // yes = infinity_predicate && !lhs_infinity && !rhs_infinity - // yes = lhs_infinity && rhs_infinity - // n.b. can likely optimize this - bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); - result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); - result.set_point_at_infinity(result_is_infinity); - return result; -} - -template -element element::operator-(const element& other) const -{ - // return checked_unconditional_add(other); - if constexpr (IsGoblinBuilder && std::same_as) { - // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize - // Current gate count: 6398 - std::vector points{ *this, other }; - std::vector scalars{ 1, -Fr(1) }; - return goblin_batch_mul(points, scalars); - } - - // if x_coordinates match, lambda triggers a divide by zero error. - // Adding in `x_coordinates_match` ensures that lambda will always be well-formed - const bool_t x_coordinates_match = other.x == x; - const bool_t y_coordinates_match = (y == other.y); - const bool_t infinity_predicate = (x_coordinates_match && y_coordinates_match); - const bool_t double_predicate = (x_coordinates_match && !y_coordinates_match); - const bool_t lhs_infinity = is_point_at_infinity(); - const bool_t rhs_infinity = other.is_point_at_infinity(); - - // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 - const Fq add_lambda_numerator = -other.y - y; - const Fq xx = x * x; - const Fq dbl_lambda_numerator = xx + xx + xx; - const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); - - const Fq add_lambda_denominator = other.x - x; - const Fq dbl_lambda_denominator = y + y; - Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); - // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a - // divide by zero error. - // (if either inputs are points at infinity we will not use the result of this computation) - Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); - lambda_denominator = Fq::conditional_assign( - lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); - const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); - - const Fq x3 = lambda.sqradd({ -other.x, -x }); - const Fq y3 = lambda.madd(x - x3, { -y }); - - element result(x3, y3); - // if lhs infinity, return rhs - result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); - result.y = Fq::conditional_assign(lhs_infinity, -other.y, result.y); - // if rhs infinity, return lhs - result.x = Fq::conditional_assign(rhs_infinity, x, result.x); - result.y = Fq::conditional_assign(rhs_infinity, y, result.y); - - // is result point at infinity? - // yes = infinity_predicate && !lhs_infinity && !rhs_infinity - // yes = lhs_infinity && rhs_infinity - // n.b. can likely optimize this - bool_t result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); - result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); - result.set_point_at_infinity(result_is_infinity); - return result; -} - -template -element element::checked_unconditional_add(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -202,7 +72,7 @@ element element::checked_unconditional_add(const ele } template -element element::checked_unconditional_subtract(const element& other) const +element element::operator-(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -235,7 +105,7 @@ element element::checked_unconditional_subtract(cons */ // TODO(https://github.com/AztecProtocol/barretenberg/issues/657): This function is untested template -std::array, 2> element::checked_unconditional_add_sub(const element& other) const +std::array, 2> element::add_sub(const element& other) const { if constexpr (IsGoblinBuilder && std::same_as) { return { *this + other, *this - other }; @@ -270,9 +140,7 @@ template element element Fq neg_lambda = Fq::msub_div({ x }, { (two_x + x) }, (y + y), {}); Fq x_3 = neg_lambda.sqradd({ -(two_x) }); Fq y_3 = neg_lambda.madd(x_3 - x, { -y }); - element result = element(x_3, y_3); - result.set_point_at_infinity(is_point_at_infinity()); - return result; + return element(x_3, y_3); } /** @@ -763,7 +631,7 @@ element element::batch_mul(const std::vector> naf_entries; + std::vector>> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_bits)); } @@ -778,7 +646,7 @@ element element::batch_mul(const std::vector nafs(num_points); + std::vector> nafs(num_points); std::vector to_add; const size_t inner_num_rounds = (i != num_iterations - 1) ? num_rounds_per_iteration : num_rounds_per_final_iteration; @@ -841,14 +709,14 @@ element element::operator*(const Fr& scalar) const } else { constexpr uint64_t num_rounds = Fr::modulus.get_msb() + 1; - std::vector naf_entries = compute_naf(scalar); + std::vector> naf_entries = compute_naf(scalar); const auto offset_generators = compute_offset_generators(num_rounds); element accumulator = *this + offset_generators.first; for (size_t i = 1; i < num_rounds; ++i) { - bool_t predicate = naf_entries[i]; + bool_t predicate = naf_entries[i]; bigfield y_test = y.conditional_negate(predicate); element to_add(x, y_test); accumulator = accumulator.montgomery_ladder(to_add); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp index f1dd10cd30e..32a8a3876c1 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp @@ -1,6 +1,5 @@ #pragma once #include "barretenberg/ecc/curves/secp256k1/secp256k1.hpp" -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { @@ -482,17 +481,17 @@ std::vector> element::compute_naf(const Fr& scalar, cons uint256_t scalar_multiplier = scalar_multiplier_512.lo; const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; - std::vector naf_entries(num_rounds + 1); + std::vector> naf_entries(num_rounds + 1); // if boolean is false => do NOT flip y // if boolean is true => DO flip y // first entry is skew. i.e. do we subtract one from the final result or not if (scalar_multiplier.get_bit(0) == false) { // add skew - naf_entries[num_rounds] = bool_t(witness_t(ctx, true)); + naf_entries[num_rounds] = bool_t(witness_t(ctx, true)); scalar_multiplier += uint256_t(1); } else { - naf_entries[num_rounds] = bool_t(witness_t(ctx, false)); + naf_entries[num_rounds] = bool_t(witness_t(ctx, false)); } for (size_t i = 0; i < num_rounds - 1; ++i) { bool next_entry = scalar_multiplier.get_bit(i + 1); @@ -500,7 +499,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons // This is a VERY hacky workaround to ensure that UltraPlonkBuilder will apply a basic // range constraint per bool, and not a full 1-bit range gate if (next_entry == false) { - bool_t bit(ctx, true); + bool_t bit(ctx, true); bit.context = ctx; bit.witness_index = witness_t(ctx, true).witness_index; // flip sign bit.witness_bool = true; @@ -513,7 +512,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons } naf_entries[num_rounds - i - 1] = bit; } else { - bool_t bit(ctx, false); + bool_t bit(ctx, false); bit.witness_index = witness_t(ctx, false).witness_index; // don't flip sign bit.witness_bool = false; if constexpr (HasPlookup) { @@ -526,7 +525,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons naf_entries[num_rounds - i - 1] = bit; } } - naf_entries[0] = bool_t(ctx, false); // most significant entry is always true + naf_entries[0] = bool_t(ctx, false); // most significant entry is always true // validate correctness of NAF if constexpr (!Fr::is_composite) { @@ -543,7 +542,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons Fr accumulator_result = Fr::accumulate(accumulators); scalar.assert_equal(accumulator_result); } else { - const auto reconstruct_half_naf = [](bool_t* nafs, const size_t half_round_length) { + const auto reconstruct_half_naf = [](bool_t* nafs, const size_t half_round_length) { // Q: need constraint to start from zero? field_t negative_accumulator(0); field_t positive_accumulator(0); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp index b9b363ba8ea..6f898f6a217 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp @@ -5,7 +5,6 @@ * TODO: we should try to genericize this, but this method is super fiddly and we need it to be efficient! * **/ -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { template @@ -120,14 +119,14 @@ element element::secp256k1_ecdsa_mul(const element& const element& base_point, const field_t& positive_skew, const field_t& negative_skew) { - const bool_t positive_skew_bool(positive_skew); - const bool_t negative_skew_bool(negative_skew); + const bool_t positive_skew_bool(positive_skew); + const bool_t negative_skew_bool(negative_skew); auto to_add = base_point; to_add.y = to_add.y.conditional_negate(negative_skew_bool); element result = accumulator + to_add; // when computing the wNAF we have already validated that positive_skew and negative_skew cannot both be true - bool_t skew_combined = positive_skew_bool ^ negative_skew_bool; + bool_t skew_combined = positive_skew_bool ^ negative_skew_bool; result.x = accumulator.x.conditional_select(result.x, skew_combined); result.y = accumulator.y.conditional_select(result.y, skew_combined); return result; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp index bdb6a9cd61f..78cc53e03b7 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp @@ -1,6 +1,4 @@ #pragma once -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" -#include "barretenberg/stdlib/primitives/memory/twin_rom_table.hpp" #include "barretenberg/stdlib_circuit_builders/plookup_tables/types.hpp" namespace bb::stdlib { @@ -182,27 +180,27 @@ template element::lookup_table_plookup::lookup_table_plookup(const std::array& inputs) { if constexpr (length == 2) { - auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); + auto [A0, A1] = inputs[1].add_sub(inputs[0]); element_table[0] = A0; element_table[1] = A1; } else if constexpr (length == 3) { - auto [R0, R1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A + auto [R0, R1] = inputs[1].add_sub(inputs[0]); // B ± A - auto [T0, T1] = inputs[2].checked_unconditional_add_sub(R0); // C ± (B + A) - auto [T2, T3] = inputs[2].checked_unconditional_add_sub(R1); // C ± (B - A) + auto [T0, T1] = inputs[2].add_sub(R0); // C ± (B + A) + auto [T2, T3] = inputs[2].add_sub(R1); // C ± (B - A) element_table[0] = T0; element_table[1] = T2; element_table[2] = T3; element_table[3] = T1; } else if constexpr (length == 4) { - auto [T0, T1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C + auto [T0, T1] = inputs[1].add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C - auto [F0, F3] = T2.checked_unconditional_add_sub(T0); // (D + C) ± (B + A) - auto [F1, F2] = T2.checked_unconditional_add_sub(T1); // (D + C) ± (B - A) - auto [F4, F7] = T3.checked_unconditional_add_sub(T0); // (D - C) ± (B + A) - auto [F5, F6] = T3.checked_unconditional_add_sub(T1); // (D - C) ± (B - A) + auto [F0, F3] = T2.add_sub(T0); // (D + C) ± (B + A) + auto [F1, F2] = T2.add_sub(T1); // (D + C) ± (B - A) + auto [F4, F7] = T3.add_sub(T0); // (D - C) ± (B + A) + auto [F5, F6] = T3.add_sub(T1); // (D - C) ± (B - A) element_table[0] = F0; element_table[1] = F1; @@ -213,20 +211,20 @@ element::lookup_table_plookup::lookup_table_plookup(con element_table[6] = F6; element_table[7] = F7; } else if constexpr (length == 5) { - auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C + auto [A0, A1] = inputs[1].add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C - auto [E0, E3] = inputs[4].checked_unconditional_add_sub(T2); // E ± (D + C) - auto [E1, E2] = inputs[4].checked_unconditional_add_sub(T3); // E ± (D - C) + auto [E0, E3] = inputs[4].add_sub(T2); // E ± (D + C) + auto [E1, E2] = inputs[4].add_sub(T3); // E ± (D - C) - auto [F0, F3] = E0.checked_unconditional_add_sub(A0); - auto [F1, F2] = E0.checked_unconditional_add_sub(A1); - auto [F4, F7] = E1.checked_unconditional_add_sub(A0); - auto [F5, F6] = E1.checked_unconditional_add_sub(A1); - auto [F8, F11] = E2.checked_unconditional_add_sub(A0); - auto [F9, F10] = E2.checked_unconditional_add_sub(A1); - auto [F12, F15] = E3.checked_unconditional_add_sub(A0); - auto [F13, F14] = E3.checked_unconditional_add_sub(A1); + auto [F0, F3] = E0.add_sub(A0); + auto [F1, F2] = E0.add_sub(A1); + auto [F4, F7] = E1.add_sub(A0); + auto [F5, F6] = E1.add_sub(A1); + auto [F8, F11] = E2.add_sub(A0); + auto [F9, F10] = E2.add_sub(A1); + auto [F12, F15] = E3.add_sub(A0); + auto [F13, F14] = E3.add_sub(A1); element_table[0] = F0; element_table[1] = F1; @@ -247,33 +245,33 @@ element::lookup_table_plookup::lookup_table_plookup(con } else if constexpr (length == 6) { // 44 adds! Only use this if it saves us adding another table to a multi-scalar-multiplication - auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); - auto [E0, E1] = inputs[4].checked_unconditional_add_sub(inputs[3]); - auto [C0, C3] = inputs[2].checked_unconditional_add_sub(A0); - auto [C1, C2] = inputs[2].checked_unconditional_add_sub(A1); + auto [A0, A1] = inputs[1].add_sub(inputs[0]); + auto [E0, E1] = inputs[4].add_sub(inputs[3]); + auto [C0, C3] = inputs[2].add_sub(A0); + auto [C1, C2] = inputs[2].add_sub(A1); - auto [F0, F3] = inputs[5].checked_unconditional_add_sub(E0); - auto [F1, F2] = inputs[5].checked_unconditional_add_sub(E1); + auto [F0, F3] = inputs[5].add_sub(E0); + auto [F1, F2] = inputs[5].add_sub(E1); - auto [R0, R7] = F0.checked_unconditional_add_sub(C0); - auto [R1, R6] = F0.checked_unconditional_add_sub(C1); - auto [R2, R5] = F0.checked_unconditional_add_sub(C2); - auto [R3, R4] = F0.checked_unconditional_add_sub(C3); + auto [R0, R7] = F0.add_sub(C0); + auto [R1, R6] = F0.add_sub(C1); + auto [R2, R5] = F0.add_sub(C2); + auto [R3, R4] = F0.add_sub(C3); - auto [S0, S7] = F1.checked_unconditional_add_sub(C0); - auto [S1, S6] = F1.checked_unconditional_add_sub(C1); - auto [S2, S5] = F1.checked_unconditional_add_sub(C2); - auto [S3, S4] = F1.checked_unconditional_add_sub(C3); + auto [S0, S7] = F1.add_sub(C0); + auto [S1, S6] = F1.add_sub(C1); + auto [S2, S5] = F1.add_sub(C2); + auto [S3, S4] = F1.add_sub(C3); - auto [U0, U7] = F2.checked_unconditional_add_sub(C0); - auto [U1, U6] = F2.checked_unconditional_add_sub(C1); - auto [U2, U5] = F2.checked_unconditional_add_sub(C2); - auto [U3, U4] = F2.checked_unconditional_add_sub(C3); + auto [U0, U7] = F2.add_sub(C0); + auto [U1, U6] = F2.add_sub(C1); + auto [U2, U5] = F2.add_sub(C2); + auto [U3, U4] = F2.add_sub(C3); - auto [W0, W7] = F3.checked_unconditional_add_sub(C0); - auto [W1, W6] = F3.checked_unconditional_add_sub(C1); - auto [W2, W5] = F3.checked_unconditional_add_sub(C2); - auto [W3, W4] = F3.checked_unconditional_add_sub(C3); + auto [W0, W7] = F3.add_sub(C0); + auto [W1, W6] = F3.add_sub(C1); + auto [W2, W5] = F3.add_sub(C2); + auto [W3, W4] = F3.add_sub(C3); element_table[0] = R0; element_table[1] = R1; @@ -410,7 +408,7 @@ element::lookup_table_plookup::lookup_table_plookup(con template template element element::lookup_table_plookup::get( - const std::array& bits) const + const std::array, length>& bits) const { std::vector> accumulators; for (size_t i = 0; i < length; ++i) { @@ -560,20 +558,20 @@ element::lookup_table_base::lookup_table_base(const std::a template template element element::lookup_table_base::get( - const std::array& bits) const + const std::array, length>& bits) const { static_assert(length <= 4 && length >= 2); if constexpr (length == 2) { - bool_t table_selector = bits[0] ^ bits[1]; - bool_t sign_selector = bits[1]; + bool_t table_selector = bits[0] ^ bits[1]; + bool_t sign_selector = bits[1]; Fq to_add_x = twin0.x.conditional_select(twin1.x, table_selector); Fq to_add_y = twin0.y.conditional_select(twin1.y, table_selector); element to_add(to_add_x, to_add_y.conditional_negate(sign_selector)); return to_add; } else if constexpr (length == 3) { - bool_t t0 = bits[2] ^ bits[0]; - bool_t t1 = bits[2] ^ bits[1]; + bool_t t0 = bits[2] ^ bits[0]; + bool_t t1 = bits[2] ^ bits[1]; field_t x_b0 = field_t::select_from_two_bit_table(x_b0_table, t1, t0); field_t x_b1 = field_t::select_from_two_bit_table(x_b1_table, t1, t0); @@ -606,9 +604,9 @@ element element::lookup_table_base::get( return to_add; } else if constexpr (length == 4) { - bool_t t0 = bits[3] ^ bits[0]; - bool_t t1 = bits[3] ^ bits[1]; - bool_t t2 = bits[3] ^ bits[2]; + bool_t t0 = bits[3] ^ bits[0]; + bool_t t1 = bits[3] ^ bits[1]; + bool_t t2 = bits[3] ^ bits[2]; field_t x_b0 = field_t::select_from_three_bit_table(x_b0_table, t2, t1, t0); field_t x_b1 = field_t::select_from_three_bit_table(x_b1_table, t2, t1, t0); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp index 5b7a5106f3f..a6593e4f831 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp @@ -11,9 +11,9 @@ namespace bb::stdlib { template struct secp256r1 { static constexpr bb::CurveType type = bb::CurveType::SECP256R1; - typedef bb::secp256r1::fq fq; - typedef bb::secp256r1::fr fr; - typedef bb::secp256r1::g1 g1; + typedef ::secp256r1::fq fq; + typedef ::secp256r1::fr fr; + typedef ::secp256r1::g1 g1; typedef CircuitType Builder; typedef witness_t witness_ct; @@ -23,8 +23,8 @@ template struct secp256r1 { typedef bool_t bool_ct; typedef stdlib::uint32 uint32_ct; - typedef bigfield fq_ct; - typedef bigfield bigfr_ct; + typedef bigfield fq_ct; + typedef bigfield bigfr_ct; typedef element g1_ct; typedef element g1_bigfr_ct; }; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp index 5d8f05b50b3..e8daaa52170 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp @@ -6,8 +6,6 @@ #include "barretenberg/stdlib_circuit_builders/goblin_ultra_circuit_builder.hpp" #include "databus.hpp" -using namespace bb; - using Builder = GoblinUltraCircuitBuilder; using field_ct = stdlib::field_t; using witness_ct = stdlib::witness_t; From 51b95896db0693184b43125c3beb3f7815b127e2 Mon Sep 17 00:00:00 2001 From: codygunton Date: Thu, 16 May 2024 13:13:33 +0000 Subject: [PATCH 12/24] Try to get earthly run --- .../relations_bench/relations.bench.cpp | 128 +++++++++++++----- 1 file changed, 96 insertions(+), 32 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp index 837764376b3..86f246dc79f 100644 --- a/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp +++ b/barretenberg/cpp/src/barretenberg/benchmark/relations_bench/relations.bench.cpp @@ -1,4 +1,5 @@ #include "barretenberg/eccvm/eccvm_flavor.hpp" +#include "barretenberg/protogalaxy/protogalaxy_prover.hpp" #include "barretenberg/stdlib_circuit_builders/goblin_ultra_flavor.hpp" #include "barretenberg/stdlib_circuit_builders/ultra_flavor.hpp" #include "barretenberg/translator_vm/goblin_translator_flavor.hpp" @@ -13,47 +14,110 @@ namespace bb::benchmark::relations { using Fr = bb::fr; using Fq = grumpkin::fr; -template void execute_relation(::benchmark::State& state) +// Generic helper for executing Relation::accumulate for the template specified input type +template +void execute_relation(::benchmark::State& state) { using FF = typename Flavor::FF; - using AllValues = typename Flavor::AllValues; - using SumcheckArrayOfValuesOverSubrelations = typename Relation::SumcheckArrayOfValuesOverSubrelations; auto params = bb::RelationParameters::get_random(); - // Extract an array containing all the polynomial evaluations at a given row i - AllValues new_value{}; - // Define the appropriate SumcheckArrayOfValuesOverSubrelations type for this relation and initialize to zero - SumcheckArrayOfValuesOverSubrelations accumulator; - // Evaluate each constraint in the relation and check that each is satisfied + // Instantiate zero-initialized inputs and accumulator + Input input{}; + Accumulator accumulator; for (auto _ : state) { - Relation::accumulate(accumulator, new_value, params, 1); + Relation::accumulate(accumulator, input, params, 1); } } -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); - -BENCHMARK(execute_relation>); - -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); - -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); -BENCHMARK(execute_relation>); + +// Single execution of relation on values (FF), e.g. Sumcheck verifier / PG perturbator work +template void execute_relation_for_values(::benchmark::State& state) +{ + using Input = typename Flavor::AllValues; + using Accumulator = typename Relation::SumcheckArrayOfValuesOverSubrelations; + + execute_relation(state); +} + +// Single execution of relation on Sumcheck univariates, i.e. Sumcheck/Decider prover work +template void execute_relation_for_univariates(::benchmark::State& state) +{ + using Input = typename Flavor::ExtendedEdges; + using Accumulator = typename Relation::SumcheckTupleOfUnivariatesOverSubrelations; + + execute_relation(state); +} + +// Single execution of relation on PG univariates, i.e. PG combiner work +template void execute_relation_for_pg_univariates(::benchmark::State& state) +{ + using ProverInstances = ProverInstances_; + using ProtoGalaxyProver = ProtoGalaxyProver_; + using Input = ProtoGalaxyProver::ExtendedUnivariates; + using Accumulator = typename Relation::template ProtogalaxyTupleOfUnivariatesOverSubrelations; + + execute_relation(state); +} + +// Ultra relations (PG prover combiner work) +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); + +// Goblin-Ultra only relations (PG prover combiner work) +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); +BENCHMARK(execute_relation_for_pg_univariates>); + +// Ultra relations (Sumcheck prover work) +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); + +// Goblin-Ultra only relations (Sumcheck prover work) +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); +BENCHMARK(execute_relation_for_univariates>); + +// Ultra relations (verifier work) +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); + +// Goblin-Ultra only relations (verifier work) +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); + +// Translator VM +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); + +// ECCVM +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); +BENCHMARK(execute_relation_for_values>); } // namespace bb::benchmark::relations From 4bbb7ee15bf763b155a07fe9ab13b779a5ba9e9a Mon Sep 17 00:00:00 2001 From: codygunton Date: Thu, 16 May 2024 16:57:29 +0000 Subject: [PATCH 13/24] Partial builder merge --- .../src/barretenberg/eccvm/eccvm_flavor.hpp | 9 +-- .../src/barretenberg/eccvm/msm_builder.hpp | 69 ++++++++----------- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 5f193dbe927..42afa570abb 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -492,8 +492,9 @@ class ECCVMFlavor { const std::vector msms = builder.get_msms(); const auto point_table_rows = ECCVMPointTablePrecomputationBuilder::compute_rows(CircuitBuilder::get_flattened_scalar_muls(msms)); - const auto [msm_rows, point_table_read_counts] = ECCVMMSMMBuilder::compute_rows( - msms, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); + std::array, 2> point_table_read_counts; + const auto msm_rows = ECCVMMSMMBuilder::compute_rows( + msms, point_table_read_counts, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); const size_t num_rows = std::max({ point_table_rows.size(), msm_rows.size(), transcript_rows.size() }); const auto log_num_rows = static_cast(numeric::get_msb64(num_rows)); @@ -564,12 +565,12 @@ class ECCVMFlavor { } // in addition, unless the accumulator is reset, it contains the value from the previous row so this // must be propagated - for (size_t i = transcript_state.size(); i < num_rows_pow2; ++i) { + for (size_t i = transcript_rows.size(); i < dyadic_num_rows; ++i) { transcript_accumulator_x[i] = transcript_accumulator_x[i - 1]; transcript_accumulator_y[i] = transcript_accumulator_y[i - 1]; } - run_loop_in_parallel(precompute_table_rows.size(), [&](size_t start, size_t end) { + run_loop_in_parallel(point_table_rows.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { // first row is always an empty row (to accommodate shifted polynomials which must have 0 as 1st // coefficient). All other rows in the point_table_rows represent active wnaf gates (i.e. diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index 2d01476ecd9..676dbe00249 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -15,11 +15,12 @@ class ECCVMMSMMBuilder { using AffineElement = typename CycleGroup::affine_element; static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW; - static constexpr size_t NUM_SCALAR_BITS = bb::eccvm::NUM_SCALAR_BITS; - static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS; + static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR; - struct alignas(64) MSMState { + struct alignas(64) MSMRow { + // counter over all half-length scalar muls used to compute the required MSMs uint32_t pc = 0; + // the number of points that will be scaled and summed uint32_t msm_size = 0; uint32_t msm_count = 0; uint32_t msm_round = 0; @@ -43,21 +44,6 @@ class ECCVMMSMMBuilder { FF accumulator_y = 0; }; - struct alignas(64) MSMRowTranscript { - std::array lambda_numerator; - std::array lambda_denominator; - Element accumulator_in; - Element accumulator_out; - }; - - struct alignas(64) AdditionTrace { - Element p1; - Element p2; - Element p3; - bool predicate; - bool is_double; - }; - /** * @brief Computes the row values for the Straus MSM columns of the ECCVM. * @@ -67,12 +53,12 @@ class ECCVMMSMMBuilder { * @param msms * @param point_table_read_counts * @param total_number_of_muls - * @return std::vector + * @return std::vector */ - static std::vector compute_msm_state(const std::vector>& msms, - std::array, 2>& point_table_read_counts, - const uint32_t total_number_of_muls, - const size_t num_msm_rows) + static std::vector compute_rows(const std::vector>& msms, + std::array, 2>& point_table_read_counts, + const uint32_t total_number_of_muls, + const size_t num_msm_rows) { // N.B. the following comments refer to a "point lookup table" frequently. // To perform a scalar multiplicaiton of a point [P] by a scalar x, we compute multiples of [P] and store in a @@ -129,15 +115,14 @@ class ECCVMMSMMBuilder { msm_row_indices.push_back(1); pc_indices.push_back(total_number_of_muls); for (const auto& msm : msms) { - const size_t rows = ECCOpQueue::get_msm_row_count_for_single_msm(msm.size()); + const size_t rows = ECCOpQueue::num_eccvm_msm_rows(msm.size()); msm_row_indices.push_back(msm_row_indices.back() + rows); pc_indices.push_back(pc_indices.back() - msm.size()); } - static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS; - std::vector msm_state(num_msm_rows); + std::vector msm_state(num_msm_rows); // start with empty row (shiftable polynomials must have 0 as first coefficient) - msm_state[0] = (MSMState{}); + msm_state[0] = (MSMRow{}); // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup // tables are called. @@ -145,7 +130,7 @@ class ECCVMMSMMBuilder { // concern. for (size_t i = 0; i < msms.size(); ++i) { - for (size_t j = 0; j < num_rounds; ++j) { + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { uint32_t pc = static_cast(pc_indices[i]); const auto& msm = msms[i]; const size_t msm_size = msm.size(); @@ -159,13 +144,13 @@ class ECCVMMSMMBuilder { for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { bool add = points_per_row > m; if (add) { - int slice = add ? msm[idx + m].wnaf_slices[j] : 0; + int slice = add ? msm[idx + m].wnaf_digits[j] : 0; update_read_counts(pc - idx - m, slice); } } } - if (j == num_rounds - 1) { + if (j == NUM_WNAF_DIGITS_PER_SCALAR - 1) { for (size_t k = 0; k < rows_per_round; ++k) { const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; @@ -220,7 +205,7 @@ class ECCVMMSMMBuilder { (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); size_t trace_index = (msm_row_indices[i] - 1) * 4; - for (size_t j = 0; j < num_rounds; ++j) { + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { const uint32_t pc = static_cast(pc_indices[i]); for (size_t k = 0; k < rows_per_round; ++k) { @@ -233,7 +218,7 @@ class ECCVMMSMMBuilder { auto& add_state = row.add_state[m]; add_state.add = points_per_row > m; - int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0; + int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of @@ -268,7 +253,7 @@ class ECCVMMSMMBuilder { msm_row_index++; } // doubling - if (j < num_rounds - 1) { + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { auto& row = msm_state[msm_row_index]; row.msm_transition = false; row.msm_round = static_cast(j + 1); @@ -371,7 +356,7 @@ class ECCVMMSMMBuilder { const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - for (size_t j = 0; j < num_rounds; ++j) { + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { for (size_t k = 0; k < rows_per_round; ++k) { auto& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; @@ -391,8 +376,8 @@ class ECCVMMSMMBuilder { msm_row_index++; } - if (j < num_rounds - 1) { - MSMState& row = msm_state[msm_row_index]; + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + MSMRow& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; @@ -411,7 +396,7 @@ class ECCVMMSMMBuilder { msm_row_index++; } else { for (size_t k = 0; k < rows_per_round; ++k) { - MSMState& row = msm_state[msm_row_index]; + MSMRow& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; ASSERT(normalized_accumulator.is_point_at_infinity() == 0); const size_t idx = k * ADDITIONS_PER_ROW; @@ -441,7 +426,7 @@ class ECCVMMSMMBuilder { // we always require 1 extra row at the end of the trace, because the accumulator x/y coordinates for row `i` // are present at row `i+1` Element final_accumulator(accumulator_trace.back()); - MSMState& final_row = msm_state.back(); + MSMRow& final_row = msm_state.back(); final_row.pc = static_cast(pc_indices.back()); final_row.msm_transition = true; final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x; @@ -451,10 +436,10 @@ class ECCVMMSMMBuilder { final_row.q_add = false; final_row.q_double = false; final_row.q_skew = false; - final_row.add_state = { typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; + final_row.add_state = { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; return msm_state; } From d6023b0b208ccd267231cd9a5468ca23834559d9 Mon Sep 17 00:00:00 2001 From: codygunton Date: Thu, 16 May 2024 19:59:00 +0000 Subject: [PATCH 14/24] Temp avoid multithreading bug --- .../src/barretenberg/eccvm/msm_builder.hpp | 332 +++++++++--------- 1 file changed, 171 insertions(+), 161 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index 676dbe00249..d9050f0a7e4 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -195,133 +195,138 @@ class ECCVMMSMMBuilder { // populate point trace data, and the components of the MSM execution trace that do not relate to affine point // operations - run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - Element accumulator = offset_generator; - const auto& msm = msms[i]; - size_t msm_row_index = msm_row_indices[i]; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - size_t trace_index = (msm_row_indices[i] - 1) * 4; - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - const uint32_t pc = static_cast(pc_indices[i]); + run_loop_in_parallel( + msms.size(), + [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + Element accumulator = offset_generator; + const auto& msm = msms[i]; + size_t msm_row_index = msm_row_indices[i]; + const size_t msm_size = msm.size(); + const size_t rows_per_round = + (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + size_t trace_index = (msm_row_indices[i] - 1) * 4; + + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { + const uint32_t pc = static_cast(pc_indices[i]); - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - auto& row = msm_state[msm_row_index]; - const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = (j == 0) && (k == 0); - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; - // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. - // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in - // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of - // the scalar multiplier associated with the point we are adding (the specific slice - // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value - // range is `-15, -13, - // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* - // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, - // -13, ..., 15 maps to 0, ... , 15) - add_state.slice = add_state.add ? (slice + 15) / 2 : 0; - add_state.point = add_state.add - ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; - - Element p1 = accumulator; - Element p2 = Element(add_state.point); - accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); - p1_trace[trace_index] = p1; - p2_trace[trace_index] = p2; - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = false; - trace_index++; - } - accumulator_trace[msm_row_index] = accumulator; - row.q_add = true; - row.q_double = false; - row.q_skew = false; - row.msm_round = static_cast(j); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); - row.pc = pc; - msm_row_index++; - } - // doubling - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - auto& row = msm_state[msm_row_index]; - row.msm_transition = false; - row.msm_round = static_cast(j + 1); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(0); - row.q_add = false; - row.q_double = true; - row.q_skew = false; - for (size_t m = 0; m < 4; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = false; - add_state.slice = 0; - add_state.point = { 0, 0 }; - add_state.collision_inverse = 0; - - p1_trace[trace_index] = accumulator; - p2_trace[trace_index] = accumulator; - accumulator = accumulator.dbl(); - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = true; - trace_index++; - } - accumulator_trace[msm_row_index] = accumulator; - msm_row_index++; - } else { for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; - const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + auto& row = msm_state[msm_row_index]; const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = false; - - Element acc_expected = accumulator; + row.msm_transition = (j == 0) && (k == 0); + for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - for (size_t m = 0; m < 4; ++m) { auto& add_state = row.add_state[m]; add_state.add = points_per_row > m; - add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; - + int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; + // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. + // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in + // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of + // the scalar multiplier associated with the point we are adding (the specific slice + // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value + // range is `-15, -13, + // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* + // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, + // -13, ..., 15 maps to 0, ... , 15) + add_state.slice = add_state.add ? (slice + 15) / 2 : 0; add_state.point = add_state.add ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] : AffineElement{ 0, 0 }; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - auto p1 = accumulator; - accumulator = add_predicate ? accumulator + add_state.point : accumulator; + + Element p1 = accumulator; + Element p2 = Element(add_state.point); + accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); p1_trace[trace_index] = p1; - p2_trace[trace_index] = add_state.point; + p2_trace[trace_index] = p2; p3_trace[trace_index] = accumulator; operation_trace[trace_index] = false; trace_index++; } - row.q_add = false; + accumulator_trace[msm_row_index] = accumulator; + row.q_add = true; row.q_double = false; - row.q_skew = true; - row.msm_round = static_cast(j + 1); + row.q_skew = false; + row.msm_round = static_cast(j); row.msm_size = static_cast(msm_size); row.msm_count = static_cast(idx); row.pc = pc; + msm_row_index++; + } + // doubling + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + auto& row = msm_state[msm_row_index]; + row.msm_transition = false; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(0); + row.q_add = false; + row.q_double = true; + row.q_skew = false; + for (size_t m = 0; m < 4; ++m) { + + auto& add_state = row.add_state[m]; + add_state.add = false; + add_state.slice = 0; + add_state.point = { 0, 0 }; + add_state.collision_inverse = 0; + + p1_trace[trace_index] = accumulator; + p2_trace[trace_index] = accumulator; + accumulator = accumulator.dbl(); + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = true; + trace_index++; + } accumulator_trace[msm_row_index] = accumulator; msm_row_index++; + } else { + for (size_t k = 0; k < rows_per_round; ++k) { + auto& row = msm_state[msm_row_index]; + + const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size + ? msm_size % ADDITIONS_PER_ROW + : ADDITIONS_PER_ROW; + const size_t idx = k * ADDITIONS_PER_ROW; + row.msm_transition = false; + + Element acc_expected = accumulator; + + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; + + add_state.point = + add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + auto p1 = accumulator; + accumulator = add_predicate ? accumulator + add_state.point : accumulator; + p1_trace[trace_index] = p1; + p2_trace[trace_index] = add_state.point; + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = false; + trace_index++; + } + row.q_add = false; + row.q_double = false; + row.q_skew = true; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + row.pc = pc; + accumulator_trace[msm_row_index] = accumulator; + msm_row_index++; + } } } } - } - }); + }, + 1 << 30); // Normalize the points in the point trace run_loop_in_parallel(point_trace.size(), [&](size_t start, size_t end) { @@ -344,83 +349,88 @@ class ECCVMMSMMBuilder { // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse, // row.add_state[0...3].lambda - run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - const auto& msm = msms[i]; - size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); - size_t msm_row_index = msm_row_indices[i]; - // 1st MSM row will have accumulator equal to the previous MSM output - // (or point at infinity for 1st MSM) - size_t accumulator_index = msm_row_indices[i] - 1; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - ASSERT(normalized_accumulator.is_point_at_infinity() == 0); - row.accumulator_x = normalized_accumulator.x; - row.accumulator_y = normalized_accumulator.y; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; - const auto& inverse = inverse_trace[trace_index]; - const auto& p1 = p1_trace[trace_index]; - const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_state.add ? inverse : 0; - add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } - - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - MSMRow& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; - const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; - row.accumulator_x = acc_x; - row.accumulator_y = acc_y; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.collision_inverse = 0; - const FF& dx = p1_trace[trace_index].x; - const FF& inverse = inverse_trace[trace_index]; - add_state.lambda = ((dx + dx + dx) * dx) * inverse; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } else { + run_loop_in_parallel( + msms.size(), + [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + const auto& msm = msms[i]; + size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); + size_t msm_row_index = msm_row_indices[i]; + // 1st MSM row will have accumulator equal to the previous MSM output + // (or point at infinity for 1st MSM) + size_t accumulator_index = msm_row_indices[i] - 1; + const size_t msm_size = msm.size(); + const size_t rows_per_round = + (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { for (size_t k = 0; k < rows_per_round; ++k) { - MSMRow& row = msm_state[msm_row_index]; + auto& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; ASSERT(normalized_accumulator.is_point_at_infinity() == 0); - const size_t idx = k * ADDITIONS_PER_ROW; row.accumulator_x = normalized_accumulator.x; row.accumulator_y = normalized_accumulator.y; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { auto& add_state = row.add_state[m]; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_predicate ? inverse : 0; - add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; + add_state.collision_inverse = add_state.add ? inverse : 0; + add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } + + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + MSMRow& row = msm_state[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + const FF& acc_x = + normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; + const FF& acc_y = + normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; + row.accumulator_x = acc_x; + row.accumulator_y = acc_y; + + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.collision_inverse = 0; + const FF& dx = p1_trace[trace_index].x; + const FF& inverse = inverse_trace[trace_index]; + add_state.lambda = ((dx + dx + dx) * dx) * inverse; trace_index++; } accumulator_index++; msm_row_index++; + } else { + for (size_t k = 0; k < rows_per_round; ++k) { + MSMRow& row = msm_state[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + const size_t idx = k * ADDITIONS_PER_ROW; + row.accumulator_x = normalized_accumulator.x; + row.accumulator_y = normalized_accumulator.y; + + for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { + auto& add_state = row.add_state[m]; + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + + const auto& inverse = inverse_trace[trace_index]; + const auto& p1 = p1_trace[trace_index]; + const auto& p2 = p2_trace[trace_index]; + add_state.collision_inverse = add_predicate ? inverse : 0; + add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } } } } - } - }); + }, + 1 << 30); // populate the final row in the MSM execution trace. // we always require 1 extra row at the end of the trace, because the accumulator x/y coordinates for row `i` From cc3d26129fdbf2c2b4b14d0f0cde74202c57794a Mon Sep 17 00:00:00 2001 From: codygunton Date: Thu, 16 May 2024 19:59:34 +0000 Subject: [PATCH 15/24] ClientIVC test now fails consistently with: `Relation ECCVMTranscriptRelation, subrelation index 24 failed at row 138` --- .../client_ivc/client_ivc.test.cpp | 90 ++++++++++--------- .../cpp/src/barretenberg/goblin/goblin.hpp | 2 + 2 files changed, 50 insertions(+), 42 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp index 0dd189112b8..c52e3632d96 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp @@ -113,53 +113,59 @@ class ClientIVCTests : public ::testing::Test { */ // TODO fix with https://github.com/AztecProtocol/barretenberg/issues/930 // intermittent failures, presumably due to uninitialized memory -TEST_F(ClientIVCTests, DISABLED_Full) +TEST_F(ClientIVCTests, Full) { using VerificationKey = Flavor::VerificationKey; - ClientIVC ivc; - // Initialize IVC with function circuit - Builder function_circuit = create_mock_circuit(ivc); - ivc.initialize(function_circuit); - - auto function_vk = std::make_shared(ivc.prover_fold_output.accumulator->proving_key); - auto foo_verifier_instance = std::make_shared(function_vk); - // Accumulate kernel circuit (first kernel mocked as simple circuit since no folding proofs yet) - Builder kernel_circuit = create_mock_circuit(ivc); - FoldProof kernel_fold_proof = ivc.accumulate(kernel_circuit); - // This will have a different verification key because we added the recursive merge verification to the circuit - auto function_vk_with_merge = std::make_shared(ivc.prover_instance->proving_key); - auto kernel_vk = function_vk_with_merge; - auto intermediary_acc = update_accumulator_and_decide_native( - ivc.prover_fold_output.accumulator, kernel_fold_proof, foo_verifier_instance, kernel_vk); - - VerifierFoldData kernel_fold_output = { kernel_fold_proof, function_vk_with_merge }; - size_t NUM_CIRCUITS = 1; - for (size_t circuit_idx = 0; circuit_idx < NUM_CIRCUITS; ++circuit_idx) { - // Accumulate function circuit + const auto run_test = []() { + ClientIVC ivc; + // Initialize IVC with function circuit Builder function_circuit = create_mock_circuit(ivc); - FoldProof function_fold_proof = ivc.accumulate(function_circuit); + ivc.initialize(function_circuit); - intermediary_acc = update_accumulator_and_decide_native( - ivc.prover_fold_output.accumulator, function_fold_proof, intermediary_acc, function_vk_with_merge); - - VerifierFoldData function_fold_output = { function_fold_proof, function_vk_with_merge }; - // Accumulate kernel circuit - Builder kernel_circuit{ ivc.goblin.op_queue }; - foo_verifier_instance = construct_mock_folding_kernel( - kernel_circuit, kernel_fold_output, function_fold_output, foo_verifier_instance); + auto function_vk = std::make_shared(ivc.prover_fold_output.accumulator->proving_key); + auto foo_verifier_instance = std::make_shared(function_vk); + // Accumulate kernel circuit (first kernel mocked as simple circuit since no folding proofs yet) + Builder kernel_circuit = create_mock_circuit(ivc); FoldProof kernel_fold_proof = ivc.accumulate(kernel_circuit); - kernel_vk = std::make_shared(ivc.prover_instance->proving_key); - - intermediary_acc = update_accumulator_and_decide_native( - ivc.prover_fold_output.accumulator, kernel_fold_proof, intermediary_acc, kernel_vk); - - VerifierFoldData kernel_fold_output = { kernel_fold_proof, kernel_vk }; + // This will have a different verification key because we added the recursive merge verification to the circuit + auto function_vk_with_merge = std::make_shared(ivc.prover_instance->proving_key); + auto kernel_vk = function_vk_with_merge; + auto intermediary_acc = update_accumulator_and_decide_native( + ivc.prover_fold_output.accumulator, kernel_fold_proof, foo_verifier_instance, kernel_vk); + + VerifierFoldData kernel_fold_output = { kernel_fold_proof, function_vk_with_merge }; + size_t NUM_CIRCUITS = 1; + for (size_t circuit_idx = 0; circuit_idx < NUM_CIRCUITS; ++circuit_idx) { + // Accumulate function circuit + Builder function_circuit = create_mock_circuit(ivc); + FoldProof function_fold_proof = ivc.accumulate(function_circuit); + + intermediary_acc = update_accumulator_and_decide_native( + ivc.prover_fold_output.accumulator, function_fold_proof, intermediary_acc, function_vk_with_merge); + + VerifierFoldData function_fold_output = { function_fold_proof, function_vk_with_merge }; + // Accumulate kernel circuit + Builder kernel_circuit{ ivc.goblin.op_queue }; + foo_verifier_instance = construct_mock_folding_kernel( + kernel_circuit, kernel_fold_output, function_fold_output, foo_verifier_instance); + FoldProof kernel_fold_proof = ivc.accumulate(kernel_circuit); + kernel_vk = std::make_shared(ivc.prover_instance->proving_key); + + intermediary_acc = update_accumulator_and_decide_native( + ivc.prover_fold_output.accumulator, kernel_fold_proof, intermediary_acc, kernel_vk); + + VerifierFoldData kernel_fold_output = { kernel_fold_proof, kernel_vk }; + } + + // Constuct four proofs: merge, eccvm, translator, decider + auto proof = ivc.prove(); + auto inst = std::make_shared(kernel_vk); + // Verify all four proofs + EXPECT_TRUE(ivc.verify(proof, { foo_verifier_instance, inst })); + }; + for (size_t idx = 0; idx < 1024; idx++) { + info("run ", idx); + run_test(); } - - // Constuct four proofs: merge, eccvm, translator, decider - auto proof = ivc.prove(); - auto inst = std::make_shared(kernel_vk); - // Verify all four proofs - EXPECT_TRUE(ivc.verify(proof, { foo_verifier_instance, inst })); }; \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp b/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp index 3823a510059..f1618caaa74 100644 --- a/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp +++ b/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp @@ -2,6 +2,7 @@ #include "barretenberg/eccvm/eccvm_circuit_builder.hpp" #include "barretenberg/eccvm/eccvm_prover.hpp" +#include "barretenberg/eccvm/eccvm_trace_checker.hpp" #include "barretenberg/eccvm/eccvm_verifier.hpp" #include "barretenberg/goblin/mock_circuits.hpp" #include "barretenberg/plonk_honk_shared/instance_inspector.hpp" @@ -162,6 +163,7 @@ class Goblin { void prove_eccvm() { eccvm_builder = std::make_unique(op_queue); + ASSERT(ECCVMTraceChecker::check(*eccvm_builder)); eccvm_prover = std::make_unique(*eccvm_builder); goblin_proof.eccvm_proof = eccvm_prover->construct_proof(); goblin_proof.translation_evaluations = eccvm_prover->translation_evaluations; From 00167feceb4471f93e6a5f43b37871e48d777834 Mon Sep 17 00:00:00 2001 From: codygunton Date: Thu, 16 May 2024 22:07:53 +0000 Subject: [PATCH 16/24] Trying to debug relation failure in ClientIVCTests.Full --- .../relations/ecc_vm/ecc_transcript_relation.cpp | 4 ++++ .../verifier/protogalaxy_recursive_verifier.hpp | 12 ++++++++++++ .../op_queue/ecc_op_queue.hpp | 11 +++++++++++ 3 files changed, 27 insertions(+) diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 47a0a3f8731..40029847105 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -424,6 +424,10 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto y_diff = lhs_y - rhs_y; auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; auto y_constant = transcript_add_y_equal - 1; + info("y_diff : ", y_diff); + info("transcript_Py_inverse : ", transcript_Py_inverse); + info("prod that is 1 or 0 : ", y_diff * transcript_Py_inverse); + info("transcript_add_y_equal: ", transcript_add_y_equal); auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; std::get<24>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 5 } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp index dbc78b944e6..f3cac974d19 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp @@ -183,6 +183,18 @@ template class ProtoGalaxyRecursiveVerifier_ { for (auto& instance : instances) { commitments.emplace_back(instance->verification_key->get_all()[vk_idx]); } + if (vk_idx == 26) { + info("folding ", accumulator->verification_key->get_labels()[vk_idx]); + info("scalars: "); + for (auto& scalar : lagranges) { + info(scalar.get_value()); + } + info("points: "); + for (auto& point : commitments) { + info(point.get_value()); + info("on curve?: ", point.get_value().on_curve()); + } + } expected_vk = Commitment::batch_mul(commitments, lagranges); vk_idx++; } diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index ae8dff828b9..5eb39135267 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -320,6 +320,9 @@ class ECCOpQueue { */ UltraOp add_accumulate(const Point& to_add) { + if (raw_ops.size() == 137) { + info("about to append bad op! adding ", to_add); + } // Update the accumulator natively accumulator = accumulator + to_add; @@ -350,6 +353,10 @@ class ECCOpQueue { */ UltraOp mul_accumulate(const Point& to_mul, const Fr& scalar) { + if (raw_ops.size() == 137) { + info("about to append bad op! adding ", scalar, " * ", to_mul); + } + // Update the accumulator natively accumulator = accumulator + to_mul * scalar; @@ -407,6 +414,10 @@ class ECCOpQueue { */ UltraOp eq_and_reset() { + if (raw_ops.size() == 137) { + info("about to append bad op! eq and reset with ", accumulator); + } + auto expected = accumulator; accumulator.self_set_infinity(); From 959d5503f01141b72760200d4d72ef7aab96ff82 Mon Sep 17 00:00:00 2001 From: codygunton Date: Mon, 20 May 2024 15:45:14 +0000 Subject: [PATCH 17/24] Hide logs --- .../ecc_vm/ecc_transcript_relation.cpp | 8 +++---- .../protogalaxy_recursive_verifier.hpp | 24 +++++++++---------- .../op_queue/ecc_op_queue.hpp | 18 +++++++------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index 40029847105..e50e970ce8f 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -424,10 +424,10 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto y_diff = lhs_y - rhs_y; auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; auto y_constant = transcript_add_y_equal - 1; - info("y_diff : ", y_diff); - info("transcript_Py_inverse : ", transcript_Py_inverse); - info("prod that is 1 or 0 : ", y_diff * transcript_Py_inverse); - info("transcript_add_y_equal: ", transcript_add_y_equal); + // info("y_diff : ", y_diff); + // info("transcript_Py_inverse : ", transcript_Py_inverse); + // info("prod that is 1 or 0 : ", y_diff * transcript_Py_inverse); + // info("transcript_add_y_equal: ", transcript_add_y_equal); auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; std::get<24>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 5 } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp index f3cac974d19..cdb0f1312fd 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp @@ -183,18 +183,18 @@ template class ProtoGalaxyRecursiveVerifier_ { for (auto& instance : instances) { commitments.emplace_back(instance->verification_key->get_all()[vk_idx]); } - if (vk_idx == 26) { - info("folding ", accumulator->verification_key->get_labels()[vk_idx]); - info("scalars: "); - for (auto& scalar : lagranges) { - info(scalar.get_value()); - } - info("points: "); - for (auto& point : commitments) { - info(point.get_value()); - info("on curve?: ", point.get_value().on_curve()); - } - } + // if (vk_idx == 26) { + // info("folding ", accumulator->verification_key->get_labels()[vk_idx]); + // info("scalars: "); + // for (auto& scalar : lagranges) { + // info(scalar.get_value()); + // } + // info("points: "); + // for (auto& point : commitments) { + // info(point.get_value()); + // info("on curve?: ", point.get_value().on_curve()); + // } + // } expected_vk = Commitment::batch_mul(commitments, lagranges); vk_idx++; } diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index 5eb39135267..dde601aa90c 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -320,9 +320,9 @@ class ECCOpQueue { */ UltraOp add_accumulate(const Point& to_add) { - if (raw_ops.size() == 137) { - info("about to append bad op! adding ", to_add); - } + // if (raw_ops.size() == 137) { + // info("about to append bad op! adding ", to_add); + // } // Update the accumulator natively accumulator = accumulator + to_add; @@ -353,9 +353,9 @@ class ECCOpQueue { */ UltraOp mul_accumulate(const Point& to_mul, const Fr& scalar) { - if (raw_ops.size() == 137) { - info("about to append bad op! adding ", scalar, " * ", to_mul); - } + // if (raw_ops.size() == 137) { + // info("about to append bad op! adding ", scalar, " * ", to_mul); + // } // Update the accumulator natively accumulator = accumulator + to_mul * scalar; @@ -414,9 +414,9 @@ class ECCOpQueue { */ UltraOp eq_and_reset() { - if (raw_ops.size() == 137) { - info("about to append bad op! eq and reset with ", accumulator); - } + // if (raw_ops.size() == 137) { + // info("about to append bad op! eq and reset with ", accumulator); + // } auto expected = accumulator; accumulator.self_set_infinity(); From 31bb78fd1bfda7a38d23fbf6195c6002ce741825 Mon Sep 17 00:00:00 2001 From: codygunton Date: Mon, 20 May 2024 21:16:39 +0000 Subject: [PATCH 18/24] Make test repeat --- .../client_ivc/client_ivc.test.cpp | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp index 6a1676c4883..b4b9ccd0431 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp @@ -121,21 +121,28 @@ TEST_F(ClientIVCTests, BasicFailure) */ TEST_F(ClientIVCTests, BasicLarge) { - ClientIVC ivc; + const auto run_test = []() { + ClientIVC ivc; + + // Construct a set of arbitrary circuits + size_t NUM_CIRCUITS = 5; + std::vector circuits; + for (size_t idx = 0; idx < NUM_CIRCUITS; ++idx) { + circuits.emplace_back(create_mock_circuit(ivc)); + } - // Construct a set of arbitrary circuits - size_t NUM_CIRCUITS = 5; - std::vector circuits; - for (size_t idx = 0; idx < NUM_CIRCUITS; ++idx) { - circuits.emplace_back(create_mock_circuit(ivc)); - } + // Accumulate each circuit + for (auto& circuit : circuits) { + ivc.accumulate(circuit); + } - // Accumulate each circuit - for (auto& circuit : circuits) { - ivc.accumulate(circuit); + EXPECT_TRUE(prove_and_verify(ivc)); + }; + for (size_t idx = 0; idx < 256; idx++) { + numeric::get_debug_randomness(true, idx); + info("run ", idx); + run_test(); } - - EXPECT_TRUE(prove_and_verify(ivc)); }; /** From 7ffad2a4d24c5d6ff37f1509ecd9ad8600f8cced Mon Sep 17 00:00:00 2001 From: codygunton Date: Sat, 25 May 2024 22:13:57 +0000 Subject: [PATCH 19/24] Cleanup --- .../src/barretenberg/eccvm/msm_builder.hpp | 334 +++++++++--------- .../protogalaxy_recursive_verifier.hpp | 12 - .../op_queue/ecc_op_queue.hpp | 11 - 3 files changed, 162 insertions(+), 195 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index d9050f0a7e4..565ba750cb1 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -126,7 +126,7 @@ class ECCVMMSMMBuilder { // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup // tables are called. - // Note: this part is single-threaded. THe amount of compute is low, however, so this is likely not a big + // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big // concern. for (size_t i = 0; i < msms.size(); ++i) { @@ -195,138 +195,133 @@ class ECCVMMSMMBuilder { // populate point trace data, and the components of the MSM execution trace that do not relate to affine point // operations - run_loop_in_parallel( - msms.size(), - [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - Element accumulator = offset_generator; - const auto& msm = msms[i]; - size_t msm_row_index = msm_row_indices[i]; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - size_t trace_index = (msm_row_indices[i] - 1) * 4; - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - const uint32_t pc = static_cast(pc_indices[i]); + run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + Element accumulator = offset_generator; + const auto& msm = msms[i]; + size_t msm_row_index = msm_row_indices[i]; + const size_t msm_size = msm.size(); + const size_t rows_per_round = + (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + size_t trace_index = (msm_row_indices[i] - 1) * 4; + + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { + const uint32_t pc = static_cast(pc_indices[i]); + for (size_t k = 0; k < rows_per_round; ++k) { + const size_t points_per_row = + (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + auto& row = msm_state[msm_row_index]; + const size_t idx = k * ADDITIONS_PER_ROW; + row.msm_transition = (j == 0) && (k == 0); + for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { + + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; + // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. + // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in + // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of + // the scalar multiplier associated with the point we are adding (the specific slice + // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value + // range is `-15, -13, + // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* + // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, + // -13, ..., 15 maps to 0, ... , 15) + add_state.slice = add_state.add ? (slice + 15) / 2 : 0; + add_state.point = add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + + Element p1 = accumulator; + Element p2 = Element(add_state.point); + accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); + p1_trace[trace_index] = p1; + p2_trace[trace_index] = p2; + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = false; + trace_index++; + } + accumulator_trace[msm_row_index] = accumulator; + row.q_add = true; + row.q_double = false; + row.q_skew = false; + row.msm_round = static_cast(j); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + row.pc = pc; + msm_row_index++; + } + // doubling + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + auto& row = msm_state[msm_row_index]; + row.msm_transition = false; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(0); + row.q_add = false; + row.q_double = true; + row.q_skew = false; + for (size_t m = 0; m < 4; ++m) { + + auto& add_state = row.add_state[m]; + add_state.add = false; + add_state.slice = 0; + add_state.point = { 0, 0 }; + add_state.collision_inverse = 0; + + p1_trace[trace_index] = accumulator; + p2_trace[trace_index] = accumulator; + accumulator = accumulator.dbl(); + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = true; + trace_index++; + } + accumulator_trace[msm_row_index] = accumulator; + msm_row_index++; + } else { for (size_t k = 0; k < rows_per_round; ++k) { + auto& row = msm_state[msm_row_index]; + const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - auto& row = msm_state[msm_row_index]; const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = (j == 0) && (k == 0); - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { + row.msm_transition = false; + + Element acc_expected = accumulator; + for (size_t m = 0; m < 4; ++m) { auto& add_state = row.add_state[m]; add_state.add = points_per_row > m; - int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; - // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. - // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in - // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of - // the scalar multiplier associated with the point we are adding (the specific slice - // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value - // range is `-15, -13, - // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* - // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, - // -13, ..., 15 maps to 0, ... , 15) - add_state.slice = add_state.add ? (slice + 15) / 2 : 0; + add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; + add_state.point = add_state.add ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] : AffineElement{ 0, 0 }; - - Element p1 = accumulator; - Element p2 = Element(add_state.point); - accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + auto p1 = accumulator; + accumulator = add_predicate ? accumulator + add_state.point : accumulator; p1_trace[trace_index] = p1; - p2_trace[trace_index] = p2; + p2_trace[trace_index] = add_state.point; p3_trace[trace_index] = accumulator; operation_trace[trace_index] = false; trace_index++; } - accumulator_trace[msm_row_index] = accumulator; - row.q_add = true; + row.q_add = false; row.q_double = false; - row.q_skew = false; - row.msm_round = static_cast(j); + row.q_skew = true; + row.msm_round = static_cast(j + 1); row.msm_size = static_cast(msm_size); row.msm_count = static_cast(idx); row.pc = pc; - msm_row_index++; - } - // doubling - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - auto& row = msm_state[msm_row_index]; - row.msm_transition = false; - row.msm_round = static_cast(j + 1); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(0); - row.q_add = false; - row.q_double = true; - row.q_skew = false; - for (size_t m = 0; m < 4; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = false; - add_state.slice = 0; - add_state.point = { 0, 0 }; - add_state.collision_inverse = 0; - - p1_trace[trace_index] = accumulator; - p2_trace[trace_index] = accumulator; - accumulator = accumulator.dbl(); - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = true; - trace_index++; - } accumulator_trace[msm_row_index] = accumulator; msm_row_index++; - } else { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; - - const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size - ? msm_size % ADDITIONS_PER_ROW - : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = false; - - Element acc_expected = accumulator; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; - - add_state.point = - add_state.add - ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - auto p1 = accumulator; - accumulator = add_predicate ? accumulator + add_state.point : accumulator; - p1_trace[trace_index] = p1; - p2_trace[trace_index] = add_state.point; - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = false; - trace_index++; - } - row.q_add = false; - row.q_double = false; - row.q_skew = true; - row.msm_round = static_cast(j + 1); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); - row.pc = pc; - accumulator_trace[msm_row_index] = accumulator; - msm_row_index++; - } } } } - }, - 1 << 30); + } + }); // Normalize the points in the point trace run_loop_in_parallel(point_trace.size(), [&](size_t start, size_t end) { @@ -349,88 +344,83 @@ class ECCVMMSMMBuilder { // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse, // row.add_state[0...3].lambda - run_loop_in_parallel( - msms.size(), - [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - const auto& msm = msms[i]; - size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); - size_t msm_row_index = msm_row_indices[i]; - // 1st MSM row will have accumulator equal to the previous MSM output - // (or point at infinity for 1st MSM) - size_t accumulator_index = msm_row_indices[i] - 1; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { + run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + const auto& msm = msms[i]; + size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); + size_t msm_row_index = msm_row_indices[i]; + // 1st MSM row will have accumulator equal to the previous MSM output + // (or point at infinity for 1st MSM) + size_t accumulator_index = msm_row_indices[i] - 1; + const size_t msm_size = msm.size(); + const size_t rows_per_round = + (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + + for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { + for (size_t k = 0; k < rows_per_round; ++k) { + auto& row = msm_state[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + row.accumulator_x = normalized_accumulator.x; + row.accumulator_y = normalized_accumulator.y; + for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { + auto& add_state = row.add_state[m]; + const auto& inverse = inverse_trace[trace_index]; + const auto& p1 = p1_trace[trace_index]; + const auto& p2 = p2_trace[trace_index]; + add_state.collision_inverse = add_state.add ? inverse : 0; + add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } + + if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + MSMRow& row = msm_state[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; + const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; + row.accumulator_x = acc_x; + row.accumulator_y = acc_y; + + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.collision_inverse = 0; + const FF& dx = p1_trace[trace_index].x; + const FF& inverse = inverse_trace[trace_index]; + add_state.lambda = ((dx + dx + dx) * dx) * inverse; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } else { for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; + MSMRow& row = msm_state[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + const size_t idx = k * ADDITIONS_PER_ROW; row.accumulator_x = normalized_accumulator.x; row.accumulator_y = normalized_accumulator.y; + for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { auto& add_state = row.add_state[m]; + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_state.add ? inverse : 0; - add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } - - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - MSMRow& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - const FF& acc_x = - normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; - const FF& acc_y = - normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; - row.accumulator_x = acc_x; - row.accumulator_y = acc_y; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.collision_inverse = 0; - const FF& dx = p1_trace[trace_index].x; - const FF& inverse = inverse_trace[trace_index]; - add_state.lambda = ((dx + dx + dx) * dx) * inverse; + add_state.collision_inverse = add_predicate ? inverse : 0; + add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; trace_index++; } accumulator_index++; msm_row_index++; - } else { - for (size_t k = 0; k < rows_per_round; ++k) { - MSMRow& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - ASSERT(normalized_accumulator.is_point_at_infinity() == 0); - const size_t idx = k * ADDITIONS_PER_ROW; - row.accumulator_x = normalized_accumulator.x; - row.accumulator_y = normalized_accumulator.y; - - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - - const auto& inverse = inverse_trace[trace_index]; - const auto& p1 = p1_trace[trace_index]; - const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_predicate ? inverse : 0; - add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } } } } - }, - 1 << 30); + } + }); // populate the final row in the MSM execution trace. // we always require 1 extra row at the end of the trace, because the accumulator x/y coordinates for row `i` diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp index 9dbe3eccfc9..e16c6677f95 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp @@ -183,18 +183,6 @@ template class ProtoGalaxyRecursiveVerifier_ { for (auto& instance : instances) { commitments.emplace_back(instance->verification_key->get_all()[vk_idx]); } - // if (vk_idx == 26) { - // info("folding ", accumulator->verification_key->get_labels()[vk_idx]); - // info("scalars: "); - // for (auto& scalar : lagranges) { - // info(scalar.get_value()); - // } - // info("points: "); - // for (auto& point : commitments) { - // info(point.get_value()); - // info("on curve?: ", point.get_value().on_curve()); - // } - // } expected_vk = Commitment::batch_mul(commitments, lagranges); vk_idx++; } diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index b14081200b9..0992924b2e3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -321,9 +321,6 @@ class ECCOpQueue { */ UltraOp add_accumulate(const Point& to_add) { - // if (raw_ops.size() == 137) { - // info("about to append bad op! adding ", to_add); - // } // Update the accumulator natively accumulator = accumulator + to_add; @@ -354,10 +351,6 @@ class ECCOpQueue { */ UltraOp mul_accumulate(const Point& to_mul, const Fr& scalar) { - // if (raw_ops.size() == 137) { - // info("about to append bad op! adding ", scalar, " * ", to_mul); - // } - // Update the accumulator natively accumulator = accumulator + to_mul * scalar; @@ -415,10 +408,6 @@ class ECCOpQueue { */ UltraOp eq_and_reset() { - // if (raw_ops.size() == 137) { - // info("about to append bad op! eq and reset with ", accumulator); - // } - auto expected = accumulator; accumulator.self_set_infinity(); From 49c71d4ddc79f18710c4434862f1f84b7492b19b Mon Sep 17 00:00:00 2001 From: codygunton Date: Sun, 26 May 2024 03:50:57 +0000 Subject: [PATCH 20/24] Remove circuit checking in ASSERT --- barretenberg/cpp/src/barretenberg/goblin/goblin.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp b/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp index ea9ce1fc174..95f0ca86435 100644 --- a/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp +++ b/barretenberg/cpp/src/barretenberg/goblin/goblin.hpp @@ -164,7 +164,6 @@ class Goblin { void prove_eccvm() { eccvm_builder = std::make_unique(op_queue); - ASSERT(ECCVMTraceChecker::check(*eccvm_builder)); eccvm_prover = std::make_unique(*eccvm_builder); goblin_proof.eccvm_proof = eccvm_prover->construct_proof(); goblin_proof.translation_evaluations = eccvm_prover->translation_evaluations; From c2ef8a516c937630ff8177b8d7a5db1743371ed4 Mon Sep 17 00:00:00 2001 From: codygunton Date: Sun, 26 May 2024 04:39:21 +0000 Subject: [PATCH 21/24] Manually merge msm_builder refactor --- .../src/barretenberg/eccvm/eccvm_flavor.hpp | 5 +- .../src/barretenberg/eccvm/msm_builder.hpp | 565 +++++++++--------- 2 files changed, 291 insertions(+), 279 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index 88e072433bd..c57a92b9894 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -493,9 +493,8 @@ class ECCVMFlavor { const std::vector msms = builder.get_msms(); const auto point_table_rows = ECCVMPointTablePrecomputationBuilder::compute_rows(CircuitBuilder::get_flattened_scalar_muls(msms)); - std::array, 2> point_table_read_counts; - const auto msm_rows = ECCVMMSMMBuilder::compute_rows( - msms, point_table_read_counts, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); + const auto [msm_rows, point_table_read_counts] = ECCVMMSMMBuilder::compute_rows( + msms, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); const size_t num_rows = std::max({ point_table_rows.size(), msm_rows.size(), transcript_rows.size() }); const auto log_num_rows = static_cast(numeric::get_msb64(num_rows)); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index 565ba750cb1..006e21b7685 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -13,6 +13,7 @@ class ECCVMMSMMBuilder { using FF = curve::Grumpkin::ScalarField; using Element = typename CycleGroup::element; using AffineElement = typename CycleGroup::affine_element; + using MSM = bb::eccvm::MSM; static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW; static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR; @@ -50,116 +51,132 @@ class ECCVMMSMMBuilder { * For a detailed description of the Straus algorithm and its relation to the ECCVM, please see * https://hackmd.io/@aztec-network/rJ5xhuCsn * - * @param msms - * @param point_table_read_counts - * @param total_number_of_muls + * @param msms A vector of vectors of ScalarMuls. + * @param point_table_read_counts Table of read counts to be populated. + * @param total_number_of_muls A mul op in the OpQueue adds up to two muls, one for each nonzero z_i (i=1,2). + * @param num_msm_rows * @return std::vector */ - static std::vector compute_rows(const std::vector>& msms, - std::array, 2>& point_table_read_counts, - const uint32_t total_number_of_muls, - const size_t num_msm_rows) + static std::tuple, std::array, 2>> compute_rows( + const std::vector& msms, const uint32_t total_number_of_muls, const size_t num_msm_rows) { - // N.B. the following comments refer to a "point lookup table" frequently. - // To perform a scalar multiplicaiton of a point [P] by a scalar x, we compute multiples of [P] and store in a - // table: specifically: -15[P], -13[P], ..., -3[P], -[P], [P], 3[P], ..., 15[P] when we define our point lookup - // table, we have 2 write columns and 4 read columns when we perform a read on a given row, we need to increment - // the read count on the respective write column by 1 we can define the following struture: 1st write column = - // positive 2nd write column = negative the row number is a function of pc and slice value row = pc_delta * - // rows_per_point_table + some function of the slice value pc_delta = total_number_of_muls - pc - // std::vector point_table_read_counts; - const size_t table_rows = static_cast(total_number_of_muls) * 8; - point_table_read_counts[0].reserve(table_rows); - point_table_read_counts[1].reserve(table_rows); - for (size_t i = 0; i < table_rows; ++i) { + // To perform a scalar multiplication of a point P by a scalar x, we precompute a table of points + // -15P, -13P, ..., -3P, -P, P, 3P, ..., 15P + // When we perform a scalar multiplication, we decompose x into base-16 wNAF digits then look these precomputed + // values up with digit-by-digit. We record read counts in a table with the following structure: + // 1st write column = positive wNAF digits + // 2nd write column = negative wNAF digits + // the row number is a function of pc and wnaf digit: + // point_idx = total_number_of_muls - pc + // row = point_idx * rows_per_point_table + (some function of the slice value) + // + // Illustration: + // Block Structure Table structure: + // | 0 | 1 | | Block_{0} | <-- pc = total_number_of_muls + // | - | - | | Block_{1} | <-- pc = total_number_of_muls-(num muls in msm 0) + // 1 | # | # | -1 | ... | ... + // 3 | # | # | -3 | Block_{total_number_of_muls-1} | <-- pc = num muls in last msm + // 5 | # | # | -5 + // 7 | # | # | -7 + // 9 | # | # | -9 + // 11 | # | # | -11 + // 13 | # | # | -13 + // 15 | # | # | -15 + + const size_t num_rows_in_read_counts_table = + static_cast(total_number_of_muls) * (eccvm::POINT_TABLE_SIZE >> 1); + std::array, 2> point_table_read_counts; + point_table_read_counts[0].reserve(num_rows_in_read_counts_table); + point_table_read_counts[1].reserve(num_rows_in_read_counts_table); + for (size_t i = 0; i < num_rows_in_read_counts_table; ++i) { point_table_read_counts[0].emplace_back(0); point_table_read_counts[1].emplace_back(0); } - const auto update_read_counts = [&](const size_t pc, const int slice) { - // When we compute our wnaf/point tables, we start with the point with the largest pc value. - // i.e. if we are reading a slice for point with a point counter value `pc`, - // its position in the wnaf/point table (relative to other points) will be `total_number_of_muls - pc` - const size_t pc_delta = total_number_of_muls - pc; - const size_t pc_offset = pc_delta * 8; - bool slice_negative = slice < 0; - const int slice_row = (slice + 15) / 2; - - const size_t column_index = slice_negative ? 1 : 0; + const auto update_read_count = [&point_table_read_counts](const size_t point_idx, const int slice) { /** - * When computing `point_table_read_counts`, we need the *table index* that a given point belongs to. - * the slice value is in *compressed* windowed-non-adjacent-form format: - * A non-compressed WNAF slice is in the range: `-15, -13, ..., 15` - * In compressed form, tney become `0, ..., 15` + * The wNAF digits for base 16 lie in the range -15, -13, ..., 13, 15. * The *point table* format is the following: - * (for positive point table) T[0] = P, T[1] = PT, ..., T[7] = 15P + * (for positive point table) T[0] = P, T[1] = 3P, ..., T[7] = 15P * (for negative point table) T[0] = -P, T[1] = -3P, ..., T[15] = -15P * i.e. if the slice value is negative, we can use the compressed WNAF directly as the table index - * if the slice value is positive, we must take `15 - compressedWNAF` to get the table index + * if the slice value is positive, we must take 15 - (compressed wNAF) to get the table index */ - if (slice_negative) { - point_table_read_counts[column_index][pc_offset + static_cast(slice_row)]++; + const size_t row_index_offset = point_idx * 8; + const bool digit_is_negative = slice < 0; + const auto relative_row_idx = static_cast((slice + 15) / 2); + const size_t column_index = digit_is_negative ? 1 : 0; + + if (digit_is_negative) { + point_table_read_counts[column_index][row_index_offset + relative_row_idx]++; } else { - point_table_read_counts[column_index][pc_offset + 15 - static_cast(slice_row)]++; + point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++; } }; // compute which row index each multiscalar multiplication will start at. - // also compute the program counter index that each multiscalar multiplication will start at. - // we use this information to populate the MSM row data across multiple threads - std::vector msm_row_indices; - std::vector pc_indices; - msm_row_indices.reserve(msms.size() + 1); - pc_indices.reserve(msms.size() + 1); - - msm_row_indices.push_back(1); - pc_indices.push_back(total_number_of_muls); + std::vector msm_row_counts; + msm_row_counts.reserve(msms.size() + 1); + msm_row_counts.push_back(1); + // compute the program counter (i.e. the index among all single scalar muls) that each multiscalar + // multiplication will start at. + std::vector pc_values; + pc_values.reserve(msms.size() + 1); + pc_values.push_back(total_number_of_muls); for (const auto& msm : msms) { - const size_t rows = ECCOpQueue::num_eccvm_msm_rows(msm.size()); - msm_row_indices.push_back(msm_row_indices.back() + rows); - pc_indices.push_back(pc_indices.back() - msm.size()); + const size_t num_rows_required = ECCOpQueue::num_eccvm_msm_rows(msm.size()); + msm_row_counts.push_back(msm_row_counts.back() + num_rows_required); + pc_values.push_back(pc_values.back() - msm.size()); } + ASSERT(pc_values.back() == 0); - std::vector msm_state(num_msm_rows); - // start with empty row (shiftable polynomials must have 0 as first coefficient) - msm_state[0] = (MSMRow{}); + // compute the MSM rows + std::vector msm_rows(num_msm_rows); + // start with empty row (shiftable polynomials must have 0 as first coefficient) + msm_rows[0] = (MSMRow{}); // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup // tables are called. // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big // concern. - for (size_t i = 0; i < msms.size(); ++i) { - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - uint32_t pc = static_cast(pc_indices[i]); - const auto& msm = msms[i]; + for (size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) { + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + auto pc = static_cast(pc_values[msm_idx]); + const auto& msm = msms[msm_idx]; const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - bool add = points_per_row > m; + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); + + for (size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) { + const size_t num_points_in_row = (relative_row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + const size_t offset = relative_row_idx * ADDITIONS_PER_ROW; + for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; ++relative_point_idx) { + const size_t point_idx = offset + relative_point_idx; + const bool add = num_points_in_row > relative_point_idx; if (add) { - int slice = add ? msm[idx + m].wnaf_digits[j] : 0; - update_read_counts(pc - idx - m, slice); + int slice = msm[point_idx].wnaf_digits[digit_idx]; + // pc starts at total_number_of_muls and decreses non-uniformly to 0 + update_read_count((total_number_of_muls - pc) + point_idx, slice); } } } - if (j == NUM_WNAF_DIGITS_PER_SCALAR - 1) { - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - for (size_t m = 0; m < 4; ++m) { - bool add = points_per_row > m; - + if (digit_idx == NUM_WNAF_DIGITS_PER_SCALAR - 1) { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + const size_t offset = row_idx * ADDITIONS_PER_ROW; + for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; + ++relative_point_idx) { + bool add = num_points_in_row > relative_point_idx; + const size_t point_idx = offset + relative_point_idx; if (add) { - update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15); + // pc starts at total_number_of_muls and decreses non-uniformly to 0 + int slice = msm[point_idx].wnaf_skew ? -1 : -15; + update_read_count((total_number_of_muls - pc) + point_idx, slice); } } } @@ -169,173 +186,173 @@ class ECCVMMSMMBuilder { // The execution trace data for the MSM columns requires knowledge of intermediate values from *affine* point // addition. The naive solution to compute this data requires 2 field inversions per in-circuit group addition - // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps. Step 1: - // compute the execution trace group operations in *projective* coordinates Step 2: use batch inversion trick to - // convert all point traces into affine coordinates Step 3: populate the full execution trace, including the - // intermediate values from affine group operations This section sets up the data structures we need to store - // all intermediate ECC operations in projective form + // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps. + // Step 1: compute the execution trace group operations in *projective* coordinates + // Step 2: use batch inversion trick to convert all points into affine coordinates + // Step 3: populate the full execution trace, including the intermediate values from affine group operations + // This section sets up the data structures we need to store all intermediate ECC operations in projective form const size_t num_point_adds_and_doubles = (num_msm_rows - 2) * 4; const size_t num_accumulators = num_msm_rows - 1; - const size_t num_points_in_trace = (num_point_adds_and_doubles * 3) + num_accumulators; + // In what fallows, either p1 + p2 = p3, or p1.dbl() = p3 // We create 1 vector to store the entire point trace. We split into multiple containers using std::span // (we want 1 vector object to more efficiently batch normalize points) - std::vector point_trace(num_points_in_trace); - // the point traces record group operations. Either p1 + p2 = p3, or p1.dbl() = p3 - std::span p1_trace(&point_trace[0], num_point_adds_and_doubles); - std::span p2_trace(&point_trace[num_point_adds_and_doubles], num_point_adds_and_doubles); - std::span p3_trace(&point_trace[num_point_adds_and_doubles * 2], num_point_adds_and_doubles); + static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3; + const size_t num_points_to_normalize = + (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators; + std::vector points_to_normalize(num_points_to_normalize); + std::span p1_trace(&points_to_normalize[0], num_point_adds_and_doubles); + std::span p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles); + std::span p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles); // operation_trace records whether an entry in the p1/p2/p3 trace represents a point addition or doubling std::vector operation_trace(num_point_adds_and_doubles); // accumulator_trace tracks the value of the ECCVM accumulator for each row - std::span accumulator_trace(&point_trace[num_point_adds_and_doubles * 3], num_accumulators); + std::span accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators); // we start the accumulator at the offset generator point. This ensures we can support an MSM that produces a constexpr auto offset_generator = bb::g1::derive_generators("ECCVM_OFFSET_GENERATOR", 1)[0]; accumulator_trace[0] = offset_generator; - // populate point trace data, and the components of the MSM execution trace that do not relate to affine point + // TODO(https://github.com/AztecProtocol/barretenberg/issues/973): Reinstate multitreading? + // populate point trace, and the components of the MSM execution trace that do not relate to affine point // operations - run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - Element accumulator = offset_generator; - const auto& msm = msms[i]; - size_t msm_row_index = msm_row_indices[i]; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - size_t trace_index = (msm_row_indices[i] - 1) * 4; - - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - const uint32_t pc = static_cast(pc_indices[i]); - - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - auto& row = msm_state[msm_row_index]; - const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = (j == 0) && (k == 0); - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - int slice = add_state.add ? msm[idx + m].wnaf_digits[j] : 0; - // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. - // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in - // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of - // the scalar multiplier associated with the point we are adding (the specific slice - // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value - // range is `-15, -13, - // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* - // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, - // -13, ..., 15 maps to 0, ... , 15) - add_state.slice = add_state.add ? (slice + 15) / 2 : 0; - add_state.point = add_state.add - ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; - - Element p1 = accumulator; - Element p2 = Element(add_state.point); - accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); + for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) { + Element accumulator = offset_generator; + const auto& msm = msms[msm_idx]; + size_t msm_row_index = msm_row_counts[msm_idx]; + const size_t msm_size = msm.size(); + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); + size_t trace_index = (msm_row_counts[msm_idx] - 1) * 4; + + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + const auto pc = static_cast(pc_values[msm_idx]); + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + auto& row = msm_rows[msm_row_index]; + const size_t offset = row_idx * ADDITIONS_PER_ROW; + row.msm_transition = (digit_idx == 0) && (row_idx == 0); + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + + auto& add_state = row.add_state[point_idx]; + add_state.add = num_points_in_row > point_idx; + int slice = add_state.add ? msm[offset + point_idx].wnaf_digits[digit_idx] : 0; + // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. + // if `row.add_state[point_idx].add = 1`, this indicates that we want to add the + // `point_idx`'th point in the MSM columns into the MSM accumulator `add_state.slice` = A + // 4-bit WNAF slice of the scalar multiplier associated with the point we are adding (the + // specific slice chosen depends on the value of msm_round) (WNAF = + // windowed-non-adjacent-form. Value range is `-15, -13, + // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* + // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, + // -13, ..., 15 maps to 0, ... , 15) + add_state.slice = add_state.add ? (slice + 15) / 2 : 0; + add_state.point = + add_state.add + ? msm[offset + point_idx].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + + Element p1(accumulator); + Element p2(add_state.point); + accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1); + p1_trace[trace_index] = p1; + p2_trace[trace_index] = p2; + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = false; + trace_index++; + } + accumulator_trace[msm_row_index] = accumulator; + row.q_add = true; + row.q_double = false; + row.q_skew = false; + row.msm_round = static_cast(digit_idx); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(offset); + row.pc = pc; + msm_row_index++; + } + // doubling + if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + auto& row = msm_rows[msm_row_index]; + row.msm_transition = false; + row.msm_round = static_cast(digit_idx + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(0); + row.q_add = false; + row.q_double = true; + row.q_skew = false; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + add_state.add = false; + add_state.slice = 0; + add_state.point = { 0, 0 }; + add_state.collision_inverse = 0; + + p1_trace[trace_index] = accumulator; + p2_trace[trace_index] = accumulator; + accumulator = accumulator.dbl(); + p3_trace[trace_index] = accumulator; + operation_trace[trace_index] = true; + trace_index++; + } + accumulator_trace[msm_row_index] = accumulator; + msm_row_index++; + } else { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + auto& row = msm_rows[msm_row_index]; + + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? msm_size % ADDITIONS_PER_ROW + : ADDITIONS_PER_ROW; + const size_t offset = row_idx * ADDITIONS_PER_ROW; + row.msm_transition = false; + Element acc_expected = accumulator; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + add_state.add = num_points_in_row > point_idx; + add_state.slice = add_state.add ? msm[offset + point_idx].wnaf_skew ? 7 : 0 : 0; + + add_state.point = + add_state.add + ? msm[offset + point_idx].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false; + auto p1 = accumulator; + accumulator = add_predicate ? accumulator + add_state.point : accumulator; p1_trace[trace_index] = p1; - p2_trace[trace_index] = p2; + p2_trace[trace_index] = add_state.point; p3_trace[trace_index] = accumulator; operation_trace[trace_index] = false; trace_index++; } - accumulator_trace[msm_row_index] = accumulator; - row.q_add = true; + row.q_add = false; row.q_double = false; - row.q_skew = false; - row.msm_round = static_cast(j); + row.q_skew = true; + row.msm_round = static_cast(digit_idx + 1); row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); + row.msm_count = static_cast(offset); row.pc = pc; - msm_row_index++; - } - // doubling - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - auto& row = msm_state[msm_row_index]; - row.msm_transition = false; - row.msm_round = static_cast(j + 1); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(0); - row.q_add = false; - row.q_double = true; - row.q_skew = false; - for (size_t m = 0; m < 4; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = false; - add_state.slice = 0; - add_state.point = { 0, 0 }; - add_state.collision_inverse = 0; - - p1_trace[trace_index] = accumulator; - p2_trace[trace_index] = accumulator; - accumulator = accumulator.dbl(); - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = true; - trace_index++; - } accumulator_trace[msm_row_index] = accumulator; msm_row_index++; - } else { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; - - const size_t points_per_row = (k + 1) * ADDITIONS_PER_ROW > msm_size - ? msm_size % ADDITIONS_PER_ROW - : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = false; - - Element acc_expected = accumulator; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; - - add_state.point = - add_state.add ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - auto p1 = accumulator; - accumulator = add_predicate ? accumulator + add_state.point : accumulator; - p1_trace[trace_index] = p1; - p2_trace[trace_index] = add_state.point; - p3_trace[trace_index] = accumulator; - operation_trace[trace_index] = false; - trace_index++; - } - row.q_add = false; - row.q_double = false; - row.q_skew = true; - row.msm_round = static_cast(j + 1); - row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); - row.pc = pc; - accumulator_trace[msm_row_index] = accumulator; - msm_row_index++; - } } } } - }); + } // Normalize the points in the point trace - run_loop_in_parallel(point_trace.size(), [&](size_t start, size_t end) { - Element::batch_normalize(&point_trace[start], end - start); + run_loop_in_parallel(points_to_normalize.size(), [&](size_t start, size_t end) { + Element::batch_normalize(&points_to_normalize[start], end - start); }); // inverse_trace is used to compute the value of the `collision_inverse` column in the ECCVM. std::vector inverse_trace(num_point_adds_and_doubles); run_loop_in_parallel(num_point_adds_and_doubles, [&](size_t start, size_t end) { - for (size_t i = start; i < end; ++i) { - if (operation_trace[i]) { - inverse_trace[i] = (p1_trace[i].y + p1_trace[i].y); + for (size_t operation_idx = start; operation_idx < end; ++operation_idx) { + if (operation_trace[operation_idx]) { + inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y); } else { - inverse_trace[i] = (p2_trace[i].x - p1_trace[i].x); + inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x); } } FF::batch_invert(&inverse_trace[start], end - start); @@ -344,90 +361,86 @@ class ECCVMMSMMBuilder { // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse, // row.add_state[0...3].lambda - run_loop_in_parallel(msms.size(), [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - const auto& msm = msms[i]; - size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); - size_t msm_row_index = msm_row_indices[i]; - // 1st MSM row will have accumulator equal to the previous MSM output - // (or point at infinity for 1st MSM) - size_t accumulator_index = msm_row_indices[i] - 1; - const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) { + const auto& msm = msms[msm_idx]; + size_t trace_index = ((msm_row_counts[msm_idx] - 1) * ADDITIONS_PER_ROW); + size_t msm_row_index = msm_row_counts[msm_idx]; + // 1st MSM row will have accumulator equal to the previous MSM output + // (or point at infinity for 1st MSM) + size_t accumulator_index = msm_row_counts[msm_idx] - 1; + const size_t msm_size = msm.size(); + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); + + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + auto& row = msm_rows[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + row.accumulator_x = normalized_accumulator.x; + row.accumulator_y = normalized_accumulator.y; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + const auto& inverse = inverse_trace[trace_index]; + const auto& p1 = p1_trace[trace_index]; + const auto& p2 = p2_trace[trace_index]; + add_state.collision_inverse = add_state.add ? inverse : 0; + add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } - for (size_t j = 0; j < NUM_WNAF_DIGITS_PER_SCALAR; ++j) { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; + if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + MSMRow& row = msm_rows[msm_row_index]; + const Element& normalized_accumulator = accumulator_trace[accumulator_index]; + const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; + const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; + row.accumulator_x = acc_x; + row.accumulator_y = acc_y; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + add_state.collision_inverse = 0; + const FF& dx = p1_trace[trace_index].x; + const FF& inverse = inverse_trace[trace_index]; + add_state.lambda = ((dx + dx + dx) * dx) * inverse; + trace_index++; + } + accumulator_index++; + msm_row_index++; + } else { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + MSMRow& row = msm_rows[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; ASSERT(normalized_accumulator.is_point_at_infinity() == 0); + const size_t offset = row_idx * ADDITIONS_PER_ROW; row.accumulator_x = normalized_accumulator.x; row.accumulator_y = normalized_accumulator.y; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false; + const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_state.add ? inverse : 0; - add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0; + add_state.collision_inverse = add_predicate ? inverse : 0; + add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; trace_index++; } accumulator_index++; msm_row_index++; } - - if (j < NUM_WNAF_DIGITS_PER_SCALAR - 1) { - MSMRow& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; - const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; - row.accumulator_x = acc_x; - row.accumulator_y = acc_y; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.collision_inverse = 0; - const FF& dx = p1_trace[trace_index].x; - const FF& inverse = inverse_trace[trace_index]; - add_state.lambda = ((dx + dx + dx) * dx) * inverse; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } else { - for (size_t k = 0; k < rows_per_round; ++k) { - MSMRow& row = msm_state[msm_row_index]; - const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - ASSERT(normalized_accumulator.is_point_at_infinity() == 0); - const size_t idx = k * ADDITIONS_PER_ROW; - row.accumulator_x = normalized_accumulator.x; - row.accumulator_y = normalized_accumulator.y; - - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; - - const auto& inverse = inverse_trace[trace_index]; - const auto& p1 = p1_trace[trace_index]; - const auto& p2 = p2_trace[trace_index]; - add_state.collision_inverse = add_predicate ? inverse : 0; - add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0; - trace_index++; - } - accumulator_index++; - msm_row_index++; - } - } } } - }); + } // populate the final row in the MSM execution trace. // we always require 1 extra row at the end of the trace, because the accumulator x/y coordinates for row `i` // are present at row `i+1` Element final_accumulator(accumulator_trace.back()); - MSMRow& final_row = msm_state.back(); - final_row.pc = static_cast(pc_indices.back()); + MSMRow& final_row = msm_rows.back(); + final_row.pc = static_cast(pc_values.back()); final_row.msm_transition = true; final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x; final_row.accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y; @@ -441,7 +454,7 @@ class ECCVMMSMMBuilder { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; - return msm_state; + return { msm_rows, point_table_read_counts }; } }; } // namespace bb From a726a114a2967df27b5724f7fc94880dfc0781e6 Mon Sep 17 00:00:00 2001 From: codygunton Date: Sun, 26 May 2024 14:21:21 +0000 Subject: [PATCH 22/24] Reinstate failing tests --- .../client_ivc/client_ivc.test.cpp | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp index 43852f24d21..8730a8793b0 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp @@ -119,37 +119,30 @@ TEST_F(ClientIVCTests, BasicFailure) * @brief Prove and verify accumulation of an arbitrary set of circuits * */ -TEST_F(ClientIVCTests, DISABLED_BasicLarge) +TEST_F(ClientIVCTests, BasicLarge) { - const auto run_test = []() { - ClientIVC ivc; - - // Construct a set of arbitrary circuits - size_t NUM_CIRCUITS = 5; - std::vector circuits; - for (size_t idx = 0; idx < NUM_CIRCUITS; ++idx) { - circuits.emplace_back(create_mock_circuit(ivc)); - } + ClientIVC ivc; - // Accumulate each circuit - for (auto& circuit : circuits) { - ivc.accumulate(circuit); - } + // Construct a set of arbitrary circuits + size_t NUM_CIRCUITS = 5; + std::vector circuits; + for (size_t idx = 0; idx < NUM_CIRCUITS; ++idx) { + circuits.emplace_back(create_mock_circuit(ivc)); + } - EXPECT_TRUE(prove_and_verify(ivc)); - }; - for (size_t idx = 0; idx < 256; idx++) { - numeric::get_debug_randomness(true, idx); - info("run ", idx); - run_test(); + // Accumulate each circuit + for (auto& circuit : circuits) { + ivc.accumulate(circuit); } + + EXPECT_TRUE(prove_and_verify(ivc)); }; /** * @brief Using a structured trace allows for the accumulation of circuits of varying size * */ -TEST_F(ClientIVCTests, DISABLED_BasicStructured) +TEST_F(ClientIVCTests, BasicStructured) { ClientIVC ivc; ivc.structured_flag = true; From 0d76d0857212f58018217bcaad12356e2a3244fd Mon Sep 17 00:00:00 2001 From: codygunton Date: Tue, 21 May 2024 17:08:36 +0000 Subject: [PATCH 23/24] Fix from Zac --- .../barretenberg/eccvm/transcript_builder.hpp | 28 +++++++++++++------ .../op_queue/ecc_op_queue.hpp | 1 - 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index 58db3bd50e3..7bb6f8ede60 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -269,17 +269,29 @@ class ECCVMTranscriptBuilder { if (entry.add || msm_transition) { Element lhs = entry.add ? Element(entry.base_point) : intermediate_accumulator_trace[i]; Element rhs = accumulator_trace[i]; + FF lhs_y = lhs.y; + FF lhs_x = lhs.x; + FF rhs_y = rhs.y; + FF rhs_x = rhs.x; + if (rhs.is_point_at_infinity()) { + rhs_y = 0; + rhs_x = 0; + } + if (lhs.is_point_at_infinity()) { + lhs_y = 0; + lhs_x = 0; + } row.transcript_add_x_equal = - lhs.x == rhs.x || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); // check infinity? + lhs_x == rhs_x || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); // check infinity? row.transcript_add_y_equal = - lhs.y == rhs.y || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); - if ((lhs.x == rhs.x) && (lhs.y == rhs.y) && !lhs.is_point_at_infinity() && + lhs_y == rhs_y || (lhs.is_point_at_infinity() && rhs.is_point_at_infinity()); + if ((lhs_x == rhs_x) && (lhs_y == rhs_y) && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { - add_lambda_denominator[i] = lhs.y + lhs.y; - add_lambda_numerator[i] = lhs.x * lhs.x * 3; - } else if ((lhs.x != rhs.x) && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { - add_lambda_denominator[i] = rhs.x - lhs.x; - add_lambda_numerator[i] = rhs.y - lhs.y; + add_lambda_denominator[i] = lhs_y + lhs_y; + add_lambda_numerator[i] = lhs_x * lhs_x * 3; + } else if ((lhs_x != rhs_x) && !lhs.is_point_at_infinity() && !rhs.is_point_at_infinity()) { + add_lambda_denominator[i] = rhs_x - lhs_x; + add_lambda_numerator[i] = rhs_y - lhs_y; } else { add_lambda_numerator[i] = 0; add_lambda_denominator[i] = 0; diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index 0992924b2e3..91e64b764a0 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -380,7 +380,6 @@ class ECCOpQueue { */ UltraOp no_op() { - // Construct and store the operation in the ultra op format auto ultra_op = construct_and_populate_ultra_ops(NULL_OP, accumulator); From 6eb3665da17c7f2f0276f688d9a9285188fcd4a5 Mon Sep 17 00:00:00 2001 From: codygunton Date: Sun, 26 May 2024 14:38:43 +0000 Subject: [PATCH 24/24] Cleanup --- .../barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp index e50e970ce8f..47a0a3f8731 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ecc_vm/ecc_transcript_relation.cpp @@ -424,10 +424,6 @@ void ECCVMTranscriptRelationImpl::accumulate(ContainerOverSubrelations& accu auto y_diff = lhs_y - rhs_y; auto y_product = transcript_Py_inverse * (-transcript_add_y_equal + 1) + transcript_add_y_equal; auto y_constant = transcript_add_y_equal - 1; - // info("y_diff : ", y_diff); - // info("transcript_Py_inverse : ", transcript_Py_inverse); - // info("prod that is 1 or 0 : ", y_diff * transcript_Py_inverse); - // info("transcript_add_y_equal: ", transcript_add_y_equal); auto transcript_add_y_equal_check_relation = (y_diff * y_product + y_constant) * any_add_is_active; std::get<24>(accumulator) += transcript_add_y_equal_check_relation * scaling_factor; // degree 5 }