From d8db656980c09ad219c375e831443bd523100d4b Mon Sep 17 00:00:00 2001 From: Jean M <132435771+jeanmon@users.noreply.github.com> Date: Wed, 13 Nov 2024 20:15:38 +0100 Subject: [PATCH] chore(avm): bugfixing witness generation for add, sub, mul for FF (#9938) --- barretenberg/cpp/pil/avm/alu.pil | 1 + .../vm/avm/tests/arithmetic.test.cpp | 39 ++++++----- .../barretenberg/vm/avm/trace/alu_trace.cpp | 66 ++++++++++++------- 3 files changed, 68 insertions(+), 38 deletions(-) diff --git a/barretenberg/cpp/pil/avm/alu.pil b/barretenberg/cpp/pil/avm/alu.pil index 218a21c4b03..7014885fbd5 100644 --- a/barretenberg/cpp/pil/avm/alu.pil +++ b/barretenberg/cpp/pil/avm/alu.pil @@ -147,6 +147,7 @@ namespace alu(256); // This holds the product over the integers // (u1 multiplication only cares about a_lo and b_lo) + // TODO(9937): The following is not well constrained as this expression overflows the field. pol PRODUCT = a_lo * b_lo + (1 - u1_tag) * (LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi)); // =============== ADDITION/SUBTRACTION Operation Constraints ================================================= diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp index ebf0264e43b..46a00c1a004 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp @@ -403,24 +403,27 @@ std::vector> positive_op_div_test_values = { { // Test on basic addition over finite field type. TEST_F(AvmArithmeticTestsFF, addition) { - std::vector const calldata = { 37, 4, 11 }; + const FF a = FF::modulus - 19; + const FF b = FF::modulus - 5; + const FF c = FF::modulus - 24; // c = a + b + std::vector const calldata = { a, b, 4 }; gen_trace_builder(calldata); trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32); trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32); trace_builder.op_calldata_copy(0, 0, 1, 0); - // Memory layout: [37,4,11,0,0,0,....] - trace_builder.op_add(0, 0, 1, 4); // [37,4,11,0,41,0,....] + // Memory layout: [a,b,4,0,0,....] + trace_builder.op_add(0, 0, 1, 4); // [a,b,4,0,c,0,....] trace_builder.op_set(0, 5, 100, AvmMemoryTag::U32); trace_builder.op_return(0, 0, 100); auto trace = trace_builder.finalize(); - auto alu_row = common_validate_add(trace, FF(37), FF(4), FF(41), FF(0), FF(1), FF(4), AvmMemoryTag::FF); + auto alu_row = common_validate_add(trace, a, b, c, FF(0), FF(1), FF(4), AvmMemoryTag::FF); EXPECT_EQ(alu_row.alu_ff_tag, FF(1)); EXPECT_EQ(alu_row.alu_cf, FF(0)); - std::vector const returndata = { 37, 4, 11, 0, 41 }; + std::vector const returndata = { a, b, 4, 0, c }; validate_trace(std::move(trace), public_inputs, calldata, returndata); } @@ -428,49 +431,55 @@ TEST_F(AvmArithmeticTestsFF, addition) // Test on basic subtraction over finite field type. TEST_F(AvmArithmeticTestsFF, subtraction) { - std::vector const calldata = { 8, 4, 17 }; + const FF a = 8; + const FF b = FF::modulus - 5; + const FF c = 13; // c = a - b + std::vector const calldata = { b, 4, a }; gen_trace_builder(calldata); trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32); trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32); trace_builder.op_calldata_copy(0, 0, 1, 0); - // Memory layout: [8,4,17,0,0,0,....] - trace_builder.op_sub(0, 2, 0, 1); // [8,9,17,0,0,0....] + // Memory layout: [b,4,a,0,0,0,....] + trace_builder.op_sub(0, 2, 0, 1); // [b,c,a,0,0,0....] trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32); trace_builder.op_return(0, 0, 100); auto trace = trace_builder.finalize(); - auto alu_row = common_validate_sub(trace, FF(17), FF(8), FF(9), FF(2), FF(0), FF(1), AvmMemoryTag::FF); + auto alu_row = common_validate_sub(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF); EXPECT_EQ(alu_row.alu_ff_tag, FF(1)); EXPECT_EQ(alu_row.alu_cf, FF(0)); - std::vector const returndata = { 8, 9, 17 }; + std::vector const returndata = { b, c, a }; validate_trace(std::move(trace), public_inputs, calldata, returndata); } // Test on basic multiplication over finite field type. TEST_F(AvmArithmeticTestsFF, multiplication) { - std::vector const calldata = { 5, 0, 20 }; + const FF a = FF::modulus - 1; + const FF b = 278; + const FF c = FF::modulus - 278; + std::vector const calldata = { b, 0, a }; gen_trace_builder(calldata); trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32); trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32); trace_builder.op_calldata_copy(0, 0, 1, 0); - // Memory layout: [5,0,20,0,0,0,....] - trace_builder.op_mul(0, 2, 0, 1); // [5,100,20,0,0,0....] + // Memory layout: [b,0,a,0,0,0,....] + trace_builder.op_mul(0, 2, 0, 1); // [b,c,a,0,0,0....] trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32); trace_builder.op_return(0, 0, 100); auto trace = trace_builder.finalize(); - auto alu_row_index = common_validate_mul(trace, FF(20), FF(5), FF(100), FF(2), FF(0), FF(1), AvmMemoryTag::FF); + auto alu_row_index = common_validate_mul(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF); auto alu_row = trace.at(alu_row_index); EXPECT_EQ(alu_row.alu_ff_tag, FF(1)); EXPECT_EQ(alu_row.alu_cf, FF(0)); - std::vector const returndata = { 5, 100, 20 }; + std::vector const returndata = { b, c, a }; validate_trace(std::move(trace), public_inputs, calldata, returndata); } diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp index f41843a91c3..3b4f198b127 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp @@ -106,10 +106,14 @@ void AvmAluTraceBuilder::reset() FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk) { bool carry = false; - uint256_t c_u256 = uint256_t(a) + uint256_t(b); - FF c = cast_to_mem_tag(c_u256, in_tag); + FF c; + + if (in_tag == AvmMemoryTag::FF) { + c = a + b; + } else { + uint256_t c_u256 = uint256_t(a) + uint256_t(b); + c = cast_to_mem_tag(c_u256, in_tag); - if (in_tag != AvmMemoryTag::FF) { // a_u128 + b_u128 >= 2^128 <==> c_u128 < a_u128 if (uint128_t(c) < uint128_t(a)) { carry = true; @@ -150,10 +154,14 @@ FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uin FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk) { bool carry = false; - uint256_t c_u256 = uint256_t(a) - uint256_t(b); - FF c = cast_to_mem_tag(c_u256, in_tag); + FF c; + + if (in_tag == AvmMemoryTag::FF) { + c = a - b; + } else { + uint256_t c_u256 = uint256_t(a) - uint256_t(b); + c = cast_to_mem_tag(c_u256, in_tag); - if (in_tag != AvmMemoryTag::FF) { // Underflow when a_u128 < b_u128 if (uint128_t(a) < uint128_t(b)) { carry = true; @@ -189,29 +197,41 @@ FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uin */ FF AvmAluTraceBuilder::op_mul(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk) { - uint256_t a_u256{ a }; - uint256_t b_u256{ b }; - uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128) + FF c = 0; + uint256_t alu_a_lo = 0; + uint256_t alu_a_hi = 0; + uint256_t alu_b_lo = 0; + uint256_t alu_b_hi = 0; + uint256_t c_hi = 0; + uint256_t partial_prod_lo = 0; + uint256_t partial_prod_hi = 0; - FF c = cast_to_mem_tag(c_u256, in_tag); + if (in_tag == AvmMemoryTag::FF) { + c = a * b; + } else { - uint8_t bits = mem_tag_bits(in_tag); - // limbs are size 1 for u1 - uint8_t limb_bits = bits == 1 ? 1 : bits / 2; - uint8_t num_bits = bits; + uint256_t a_u256{ a }; + uint256_t b_u256{ b }; + uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128) - // Decompose a - auto [alu_a_lo, alu_a_hi] = decompose(a_u256, limb_bits); - // Decompose b - auto [alu_b_lo, alu_b_hi] = decompose(b_u256, limb_bits); + c = cast_to_mem_tag(c_u256, in_tag); - uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo; - // Decompose the partial product - auto [partial_prod_lo, partial_prod_hi] = decompose(partial_prod, limb_bits); + uint8_t bits = mem_tag_bits(in_tag); + // limbs are size 1 for u1 + uint8_t limb_bits = bits == 1 ? 1 : bits / 2; + uint8_t num_bits = bits; - auto c_hi = c_u256 >> num_bits; + // Decompose a + std::tie(alu_a_lo, alu_a_hi) = decompose(a_u256, limb_bits); + // Decompose b + std::tie(alu_b_lo, alu_b_hi) = decompose(b_u256, limb_bits); + + uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo; + // Decompose the partial product + std::tie(partial_prod_lo, partial_prod_hi) = decompose(partial_prod, limb_bits); + + c_hi = c_u256 >> num_bits; - if (in_tag != AvmMemoryTag::FF) { cmp_builder.range_check_builder.assert_range(uint128_t(c), mem_tag_bits(in_tag), EventEmitter::ALU, clk); }