Skip to content

Commit

Permalink
chore(avm): bugfixing witness generation for add, sub, mul for FF (#9938
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jeanmon authored Nov 13, 2024
1 parent b36c137 commit d8db656
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 38 deletions.
1 change: 1 addition & 0 deletions barretenberg/cpp/pil/avm/alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ namespace alu(256);

// This holds the product over the integers
// (u1 multiplication only cares about a_lo and b_lo)
// TODO(9937): The following is not well constrained as this expression overflows the field.
pol PRODUCT = a_lo * b_lo + (1 - u1_tag) * (LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi));

// =============== ADDITION/SUBTRACTION Operation Constraints =================================================
Expand Down
39 changes: 24 additions & 15 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,74 +403,83 @@ std::vector<std::array<FF, 3>> positive_op_div_test_values = { {
// Test on basic addition over finite field type.
TEST_F(AvmArithmeticTestsFF, addition)
{
std::vector<FF> const calldata = { 37, 4, 11 };
const FF a = FF::modulus - 19;
const FF b = FF::modulus - 5;
const FF c = FF::modulus - 24; // c = a + b
std::vector<FF> const calldata = { a, b, 4 };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [37,4,11,0,0,0,....]
trace_builder.op_add(0, 0, 1, 4); // [37,4,11,0,41,0,....]
// Memory layout: [a,b,4,0,0,....]
trace_builder.op_add(0, 0, 1, 4); // [a,b,4,0,c,0,....]
trace_builder.op_set(0, 5, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row = common_validate_add(trace, FF(37), FF(4), FF(41), FF(0), FF(1), FF(4), AvmMemoryTag::FF);
auto alu_row = common_validate_add(trace, a, b, c, FF(0), FF(1), FF(4), AvmMemoryTag::FF);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 37, 4, 11, 0, 41 };
std::vector<FF> const returndata = { a, b, 4, 0, c };

validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

// Test on basic subtraction over finite field type.
TEST_F(AvmArithmeticTestsFF, subtraction)
{
std::vector<FF> const calldata = { 8, 4, 17 };
const FF a = 8;
const FF b = FF::modulus - 5;
const FF c = 13; // c = a - b
std::vector<FF> const calldata = { b, 4, a };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [8,4,17,0,0,0,....]
trace_builder.op_sub(0, 2, 0, 1); // [8,9,17,0,0,0....]
// Memory layout: [b,4,a,0,0,0,....]
trace_builder.op_sub(0, 2, 0, 1); // [b,c,a,0,0,0....]
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row = common_validate_sub(trace, FF(17), FF(8), FF(9), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row = common_validate_sub(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 8, 9, 17 };
std::vector<FF> const returndata = { b, c, a };
validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

// Test on basic multiplication over finite field type.
TEST_F(AvmArithmeticTestsFF, multiplication)
{
std::vector<FF> const calldata = { 5, 0, 20 };
const FF a = FF::modulus - 1;
const FF b = 278;
const FF c = FF::modulus - 278;
std::vector<FF> const calldata = { b, 0, a };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [5,0,20,0,0,0,....]
trace_builder.op_mul(0, 2, 0, 1); // [5,100,20,0,0,0....]
// Memory layout: [b,0,a,0,0,0,....]
trace_builder.op_mul(0, 2, 0, 1); // [b,c,a,0,0,0....]
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row_index = common_validate_mul(trace, FF(20), FF(5), FF(100), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row_index = common_validate_mul(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row = trace.at(alu_row_index);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 5, 100, 20 };
std::vector<FF> const returndata = { b, c, a };
validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

Expand Down
66 changes: 43 additions & 23 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ void AvmAluTraceBuilder::reset()
FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
bool carry = false;
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
FF c = cast_to_mem_tag(c_u256, in_tag);
FF c;

if (in_tag == AvmMemoryTag::FF) {
c = a + b;
} else {
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
c = cast_to_mem_tag(c_u256, in_tag);

if (in_tag != AvmMemoryTag::FF) {
// a_u128 + b_u128 >= 2^128 <==> c_u128 < a_u128
if (uint128_t(c) < uint128_t(a)) {
carry = true;
Expand Down Expand Up @@ -150,10 +154,14 @@ FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
bool carry = false;
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
FF c = cast_to_mem_tag(c_u256, in_tag);
FF c;

if (in_tag == AvmMemoryTag::FF) {
c = a - b;
} else {
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
c = cast_to_mem_tag(c_u256, in_tag);

if (in_tag != AvmMemoryTag::FF) {
// Underflow when a_u128 < b_u128
if (uint128_t(a) < uint128_t(b)) {
carry = true;
Expand Down Expand Up @@ -189,29 +197,41 @@ FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
*/
FF AvmAluTraceBuilder::op_mul(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
uint256_t a_u256{ a };
uint256_t b_u256{ b };
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)
FF c = 0;
uint256_t alu_a_lo = 0;
uint256_t alu_a_hi = 0;
uint256_t alu_b_lo = 0;
uint256_t alu_b_hi = 0;
uint256_t c_hi = 0;
uint256_t partial_prod_lo = 0;
uint256_t partial_prod_hi = 0;

FF c = cast_to_mem_tag(c_u256, in_tag);
if (in_tag == AvmMemoryTag::FF) {
c = a * b;
} else {

uint8_t bits = mem_tag_bits(in_tag);
// limbs are size 1 for u1
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
uint8_t num_bits = bits;
uint256_t a_u256{ a };
uint256_t b_u256{ b };
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)

// Decompose a
auto [alu_a_lo, alu_a_hi] = decompose(a_u256, limb_bits);
// Decompose b
auto [alu_b_lo, alu_b_hi] = decompose(b_u256, limb_bits);
c = cast_to_mem_tag(c_u256, in_tag);

uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
// Decompose the partial product
auto [partial_prod_lo, partial_prod_hi] = decompose(partial_prod, limb_bits);
uint8_t bits = mem_tag_bits(in_tag);
// limbs are size 1 for u1
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
uint8_t num_bits = bits;

auto c_hi = c_u256 >> num_bits;
// Decompose a
std::tie(alu_a_lo, alu_a_hi) = decompose(a_u256, limb_bits);
// Decompose b
std::tie(alu_b_lo, alu_b_hi) = decompose(b_u256, limb_bits);

uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
// Decompose the partial product
std::tie(partial_prod_lo, partial_prod_hi) = decompose(partial_prod, limb_bits);

c_hi = c_u256 >> num_bits;

if (in_tag != AvmMemoryTag::FF) {
cmp_builder.range_check_builder.assert_range(uint128_t(c), mem_tag_bits(in_tag), EventEmitter::ALU, clk);
}

Expand Down

0 comments on commit d8db656

Please sign in to comment.