Skip to content

Commit

Permalink
feat: add limit to unique contract call (#10640)
Browse files Browse the repository at this point in the history
Resolves AztecProtocol/aztec-packages#10369

Note from @dbanks12:
Once the limit has been reached for contract calls to unique class IDs,
you can still call repeat contract addresses or even other contract
addresses that reuse an already checked class ID.

I had to change the call-ptr/space-id to just use a counter instead of
clk because space-id is uint8 and we were getting collisions.

Follow-up work:
- constrain that user-called address can be derived from the hinted
class ID & instance

---------

Co-authored-by: dbanks12 <david@aztecprotocol.com>
  • Loading branch information
2 people authored and AztecBot committed Dec 24, 2024
1 parent 7147ccb commit 9d5e289
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 56 deletions.
15 changes: 2 additions & 13 deletions cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,19 +220,7 @@ class AvmArithmeticTests : public ::testing::Test {
trace_builder =
AvmTraceBuilder(public_inputs, {}, 0).set_full_precomputed_tables(false).set_range_check_required(false);
trace_builder.set_all_calldata(calldata);
AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0,
.parent_id = 0,
.contract_address = FF(0),
.calldata = calldata,
.nested_returndata = {},
.last_pc = 0,
.success_offset = 0,
.start_l2_gas_left = 0,
.start_da_gas_left = 0,
.l2_gas_left = 0,
.da_gas_left = 0,
.internal_return_ptr_stack = {} });
trace_builder.current_ext_call_ctx = ext_call_ctx;
trace_builder.current_ext_call_ctx.calldata = calldata;
}

// Generate a trace with an EQ opcode operation.
Expand Down Expand Up @@ -614,6 +602,7 @@ TEST_F(AvmArithmeticTestsFF, fDivisionByZeroError)
// We check that the operator error flag is raised.
TEST_F(AvmArithmeticTestsFF, fDivisionZeroByZeroError)
{
gen_trace_builder({});
// Memory layout: [0,0,0,0,0,0,....]
trace_builder.op_fdiv(0, 0, 1, 2); // [0,0,0,0,0,0....]
trace_builder.op_set(0, 0, 100, AvmMemoryTag::U32);
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/barretenberg/vm/avm/tests/cast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
trace_builder.set_all_calldata(calldata);
AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0,
.parent_id = 0,
.is_top_level = true,
.contract_address = FF(0),
.calldata = calldata,
.nested_returndata = {},
Expand Down Expand Up @@ -219,6 +220,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus2)
trace_builder.set_all_calldata(calldata);
AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0,
.parent_id = 0,
.is_top_level = true,
.contract_address = FF(0),
.calldata = calldata,
.nested_returndata = {},
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/barretenberg/vm/avm/tests/execution.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class AvmExecutionTests : public ::testing::Test {
auto [contract_class_id, contract_instance] = gen_test_contract_hint(bytecode);
auto execution_hints = ExecutionHints().with_avm_contract_bytecode(
{ AvmContractBytecode{ bytecode, contract_instance, contract_class_id } });
execution_hints.contract_instance_hints.emplace(contract_instance.address, contract_instance);

vinfo("Calling execution::gen_trace");
return AvmExecutionTests::gen_trace(bytecode, calldata, public_inputs, returndata, execution_hints);
Expand All @@ -98,6 +99,7 @@ class AvmExecutionTests : public ::testing::Test {
auto [contract_class_id, contract_instance] = gen_test_contract_hint(bytecode);
execution_hints.with_avm_contract_bytecode(
{ AvmContractBytecode{ bytecode, contract_instance, contract_class_id } });
execution_hints.contract_instance_hints.emplace(contract_instance.address, contract_instance);

// These are magic values because of how some tests work! Don't change them
public_inputs.public_app_logic_call_requests[0].contract_address = contract_instance.address;
Expand Down
14 changes: 1 addition & 13 deletions cpp/src/barretenberg/vm/avm/tests/slice.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,7 @@ class AvmSliceTests : public ::testing::Test {
trace_builder =
AvmTraceBuilder(public_inputs, {}, 0).set_full_precomputed_tables(false).set_range_check_required(false);
trace_builder.set_all_calldata(calldata);
AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0,
.parent_id = 0,
.contract_address = FF(0),
.calldata = calldata,
.nested_returndata = {},
.last_pc = 0,
.success_offset = 0,
.start_l2_gas_left = 0,
.start_da_gas_left = 0,
.l2_gas_left = 0,
.da_gas_left = 0,
.internal_return_ptr_stack = {} });
trace_builder.current_ext_call_ctx = ext_call_ctx;
trace_builder.current_ext_call_ctx.calldata = calldata;
this->calldata = calldata;
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/barretenberg/vm/avm/trace/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ enum class AvmError : uint32_t {
SIDE_EFFECT_LIMIT_REACHED,
OUT_OF_GAS,
STATIC_CALL_ALTERATION,
NO_BYTECODE_FOUND,
FAILED_BYTECODE_RETRIEVAL,
};

} // namespace bb::avm_trace
21 changes: 12 additions & 9 deletions cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,13 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
AvmError error = AvmError::NO_ERROR;

// These hints help us to set up first call ctx
uint32_t clk = trace_builder.get_clk();
auto context_id = static_cast<uint8_t>(clk);
auto context_id = trace_builder.next_context_id;
uint32_t l2_gas_allocated_to_enqueued_call = trace_builder.get_l2_gas_left();
uint32_t da_gas_allocated_to_enqueued_call = trace_builder.get_da_gas_left();
trace_builder.current_ext_call_ctx = AvmTraceBuilder::ExtCallCtx{
.context_id = context_id,
.parent_id = 0,
.is_top_level = true,
.contract_address = enqueued_call_hint.contract_address,
.calldata = enqueued_call_hint.calldata,
.nested_returndata = {},
Expand All @@ -445,22 +445,22 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
.da_gas_left = da_gas_allocated_to_enqueued_call,
.internal_return_ptr_stack = {},
};
trace_builder.next_context_id++;
// Find the bytecode based on contract address of the public call request
std::vector<uint8_t> bytecode;
try {
bytecode =
trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership);
} catch ([[maybe_unused]] const std::runtime_error& e) {
info("AVM enqueued call exceptionally halted. Error: No bytecode found for enqueued call");
info("AVM enqueued call exceptionally halted. Failed bytecode retrieval.");
// FIXME: properly handle case when bytecode is not found!
// For now, we add a dummy row in main trace to mutate later.
// Dummy row in main trace to mutate afterwards.
// This error was encountered before any opcodes were executed, but
// we need at least one row in the execution trace to then mutate and say "it halted and consumed all gas!"
trace_builder.op_add(0, 0, 0, 0, OpCode::ADD_8);
trace_builder.handle_exceptional_halt();
return AvmError::NO_BYTECODE_FOUND;
;
return AvmError::FAILED_BYTECODE_RETRIEVAL;
}

trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call);
Expand All @@ -471,7 +471,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
uint32_t pc = 0;
std::stack<uint32_t> debug_counter_stack;
uint32_t counter = 0;
trace_builder.set_call_ptr(context_id);
// FIXME: this cast means that we can have duplicate call ptrs since clk will end up way bigger than 256
trace_builder.set_call_ptr(static_cast<uint8_t>(context_id));
while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) {
auto [inst, parse_error] = Deserialization::parse(bytecode, pc);

Expand Down Expand Up @@ -881,7 +882,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
info("AVM CALL failed bytecode retrieval.");
error = AvmError::FAILED_BYTECODE_RETRIEVAL;
}
debug_counter_stack.push(counter);
counter = 0;
Expand All @@ -901,7 +903,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
info("AVM STATICCALL failed bytecode retrieval.");
error = AvmError::FAILED_BYTECODE_RETRIEVAL;
}
debug_counter_stack.push(counter);
counter = 0;
Expand Down Expand Up @@ -1033,7 +1036,7 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
}

if (!is_ok(error)) {
const bool is_top_level = trace_builder.current_ext_call_ctx.context_id == 0;
const bool is_top_level = trace_builder.current_ext_call_ctx.is_top_level;

auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body
std::string call_type = is_top_level ? "enqueued" : "nested";
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/barretenberg/vm/avm/trace/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ std::string to_name(AvmError error)
return "SIDE EFFECT LIMIT REACHED";
case AvmError::OUT_OF_GAS:
return "OUT OF GAS";
case AvmError::NO_BYTECODE_FOUND:
return "NO BYTECODE FOUND";
case AvmError::FAILED_BYTECODE_RETRIEVAL:
return "FAILED BYTECODE RETRIEVAL";
default:
throw std::runtime_error("Invalid error type");
break;
Expand Down
74 changes: 56 additions & 18 deletions cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,28 +157,38 @@ void AvmTraceBuilder::rollback_to_non_revertible_checkpoint()
merkle_tree_trace_builder.rollback_to_non_revertible_checkpoint();
}

std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership)
std::vector<uint8_t> AvmTraceBuilder::get_bytecode_from_hints(const FF contract_class_id)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// Find the bytecode based on contract address of the public call request
// Find the bytecode based on the hinted contract class id
// TODO: still need to make sure that the contract address does correspond to this class id
const AvmContractBytecode bytecode_hint =
*std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_address](const auto& contract) {
return contract.contract_instance.address == contract_address;
*std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_class_id](const auto& contract) {
return contract.contract_instance.contract_class_id == contract_class_id;
});
return bytecode_hint.bytecode;
}

std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

ASSERT(execution_hints.contract_instance_hints.contains(contract_address));
const ContractInstanceHint instance_hint = execution_hints.contract_instance_hints.at(contract_address);
const FF contract_class_id = instance_hint.contract_class_id;

bool exists = true;
if (check_membership && !is_canonical(contract_address)) {
if (bytecode_membership_cache.find(contract_address) != bytecode_membership_cache.end()) {
// If we have already seen the contract address, we can skip the membership check and used the cached
// membership proof
vinfo("Found bytecode for contract address in cache: ", contract_address);
return bytecode_hint.bytecode;
return get_bytecode_from_hints(contract_class_id);
}
const auto contract_address_nullifier = AvmMerkleTreeTraceBuilder::unconstrained_silo_nullifier(
DEPLOYER_CONTRACT_ADDRESS, /*nullifier=*/contract_address);
// nullifier read hint for the contract address
NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint;
NullifierReadTreeHint nullifier_read_hint = instance_hint.membership_hint;

// If the hinted preimage matches the contract address nullifier, the membership check will prove its existence,
// otherwise the membership check will prove that a low-leaf exists that skips the contract address nullifier.
Expand All @@ -194,8 +204,20 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
if (exists) {
// This was a membership proof!
// Assert that the hint's exists flag matches. The flag isn't really necessary...
ASSERT(bytecode_hint.contract_instance.exists);
ASSERT(instance_hint.exists);
bytecode_membership_cache.insert(contract_address);

// The cache contains all the unique contract class ids we have seen so far
// If this bytecode retrievals have reached the number of unique contract class IDs, can't make
// more retrievals unless to already-checked contract class ids!
if (!contract_class_id_cache.contains(contract_class_id) &&
contract_class_id_cache.size() >= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS) {
info("Limit reached for contract calls to unique class id. Limit: ",
MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS);
throw std::runtime_error("Limit reached for contract calls to unique class id.");
}
contract_class_id_cache.insert(instance_hint.contract_class_id);
return get_bytecode_from_hints(contract_class_id);
} else {
// This was a non-membership proof!
// Enforce that the tree access membership checked a low-leaf that skips the contract address nullifier.
Expand All @@ -206,9 +228,9 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo

if (exists) {
vinfo("Found bytecode for contract address: ", contract_address);
return bytecode_hint.bytecode;
return get_bytecode_from_hints(contract_class_id);
}
vinfo("Bytecode not found for contract address: ", contract_address);
info("Bytecode not found for contract address: ", contract_address);
throw std::runtime_error("Bytecode not found");
}

Expand Down Expand Up @@ -357,7 +379,7 @@ void AvmTraceBuilder::allocate_gas_for_call(uint32_t l2_gas, uint32_t da_gas)

void AvmTraceBuilder::handle_exceptional_halt()
{
const bool is_top_level = current_ext_call_ctx.context_id == 0;
const bool is_top_level = current_ext_call_ctx.is_top_level;
if (is_top_level) {
vinfo("Handling exceptional halt in top-level call. Consuming all allocated gas:");
// Consume all remaining gas.
Expand Down Expand Up @@ -577,6 +599,21 @@ AvmTraceBuilder::AvmTraceBuilder(AvmPublicInputs public_inputs,
, bytecode_trace_builder(execution_hints.all_contract_bytecode)
, merkle_tree_trace_builder(public_inputs.start_tree_snapshots)
{
AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0,
.parent_id = 0,
.is_top_level = true,
.contract_address = FF(0),
.calldata = {},
.nested_returndata = {},
.last_pc = 0,
.success_offset = 0,
.start_l2_gas_left = 0,
.start_da_gas_left = 0,
.l2_gas_left = 0,
.da_gas_left = 0,
.internal_return_ptr_stack = {} });
current_ext_call_ctx = ext_call_ctx;

// Only allocate up to the maximum L2 gas for execution
// TODO: constrain this!
auto const l2_gas_left_after_private =
Expand Down Expand Up @@ -2111,8 +2148,7 @@ AvmError AvmTraceBuilder::op_calldata_copy(uint8_t indirect,
const uint32_t cd_offset = static_cast<uint32_t>(unconstrained_read_from_memory(cd_offset_resolved));
const uint32_t copy_size = static_cast<uint32_t>(unconstrained_read_from_memory(copy_size_offset_resolved));

// If the context_id == 0, then we are at the top level call so we read/write to a trace column
bool is_top_level = current_ext_call_ctx.context_id == 0;
bool is_top_level = current_ext_call_ctx.is_top_level;

auto calldata = current_ext_call_ctx.calldata;
if (is_ok(error)) {
Expand Down Expand Up @@ -3788,15 +3824,13 @@ AvmError AvmTraceBuilder::constrain_external_call(OpCode opcode,
std::vector<FF> calldata;
read_slice_from_memory(resolved_args_offset, args_size, calldata);

set_call_ptr(static_cast<uint8_t>(clk));

// Don't try allocating more than the gas that is actually left
const auto l2_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_l2.val), gas_trace_builder.get_l2_gas_left());
const auto da_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_da.val), gas_trace_builder.get_da_gas_left());
current_ext_call_ctx = ExtCallCtx{
.context_id = static_cast<uint8_t>(clk),
.context_id = next_context_id,
.parent_id = current_ext_call_ctx.context_id,
.is_static_call = opcode == OpCode::STATICCALL,
.contract_address = read_addr.val,
Expand All @@ -3811,8 +3845,12 @@ AvmError AvmTraceBuilder::constrain_external_call(OpCode opcode,
.internal_return_ptr_stack = {},
.tree_snapshot = {},
};
next_context_id++;

set_call_ptr(static_cast<uint8_t>(current_ext_call_ctx.context_id));

allocate_gas_for_call(l2_gas_allocated_to_nested_call, da_gas_allocated_to_nested_call);

set_pc(0);
}

Expand Down Expand Up @@ -3904,7 +3942,7 @@ ReturnDataError AvmTraceBuilder::op_return(uint8_t indirect, uint32_t ret_offset

const auto ret_size = static_cast<uint32_t>(unconstrained_read_from_memory(resolved_ret_size_offset));

const bool is_top_level = current_ext_call_ctx.context_id == 0;
const bool is_top_level = current_ext_call_ctx.is_top_level;

const auto [l2_gas_cost, da_gas_cost] = gas_trace_builder.unconstrained_compute_gas(OpCode::RETURN, ret_size);
bool out_of_gas =
Expand Down Expand Up @@ -4043,7 +4081,7 @@ ReturnDataError AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset
const auto ret_size =
is_ok(error) ? static_cast<uint32_t>(unconstrained_read_from_memory(resolved_ret_size_offset)) : 0;

const bool is_top_level = current_ext_call_ctx.context_id == 0;
const bool is_top_level = current_ext_call_ctx.is_top_level;

const auto [l2_gas_cost, da_gas_cost] = gas_trace_builder.unconstrained_compute_gas(OpCode::REVERT_8, ret_size);
bool out_of_gas =
Expand Down
Loading

0 comments on commit 9d5e289

Please sign in to comment.