diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp index 8c69de39672..3d2ddb3ccdd 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp @@ -220,19 +220,7 @@ class AvmArithmeticTests : public ::testing::Test { trace_builder = AvmTraceBuilder(public_inputs, {}, 0).set_full_precomputed_tables(false).set_range_check_required(false); trace_builder.set_all_calldata(calldata); - AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, - .parent_id = 0, - .contract_address = FF(0), - .calldata = calldata, - .nested_returndata = {}, - .last_pc = 0, - .success_offset = 0, - .start_l2_gas_left = 0, - .start_da_gas_left = 0, - .l2_gas_left = 0, - .da_gas_left = 0, - .internal_return_ptr_stack = {} }); - trace_builder.current_ext_call_ctx = ext_call_ctx; + trace_builder.current_ext_call_ctx.calldata = calldata; } // Generate a trace with an EQ opcode operation. @@ -614,6 +602,7 @@ TEST_F(AvmArithmeticTestsFF, fDivisionByZeroError) // We check that the operator error flag is raised. TEST_F(AvmArithmeticTestsFF, fDivisionZeroByZeroError) { + gen_trace_builder({}); // Memory layout: [0,0,0,0,0,0,....] trace_builder.op_fdiv(0, 0, 1, 2); // [0,0,0,0,0,0....] trace_builder.op_set(0, 0, 100, AvmMemoryTag::U32); diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp index 8086db12147..51364e48153 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp @@ -188,6 +188,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1) trace_builder.set_all_calldata(calldata); AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, .parent_id = 0, + .is_top_level = true, .contract_address = FF(0), .calldata = calldata, .nested_returndata = {}, @@ -219,6 +220,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus2) trace_builder.set_all_calldata(calldata); AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, .parent_id = 0, + .is_top_level = true, .contract_address = FF(0), .calldata = calldata, .nested_returndata = {}, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp index 86b62b8a64e..5ed0171cd44 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp @@ -84,6 +84,7 @@ class AvmExecutionTests : public ::testing::Test { auto [contract_class_id, contract_instance] = gen_test_contract_hint(bytecode); auto execution_hints = ExecutionHints().with_avm_contract_bytecode( { AvmContractBytecode{ bytecode, contract_instance, contract_class_id } }); + execution_hints.contract_instance_hints.emplace(contract_instance.address, contract_instance); vinfo("Calling execution::gen_trace"); return AvmExecutionTests::gen_trace(bytecode, calldata, public_inputs, returndata, execution_hints); @@ -98,6 +99,7 @@ class AvmExecutionTests : public ::testing::Test { auto [contract_class_id, contract_instance] = gen_test_contract_hint(bytecode); execution_hints.with_avm_contract_bytecode( { AvmContractBytecode{ bytecode, contract_instance, contract_class_id } }); + execution_hints.contract_instance_hints.emplace(contract_instance.address, contract_instance); // These are magic values because of how some tests work! Don't change them public_inputs.public_app_logic_call_requests[0].contract_address = contract_instance.address; diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp index ea8830a353c..dbdd0ae929f 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp @@ -30,19 +30,7 @@ class AvmSliceTests : public ::testing::Test { trace_builder = AvmTraceBuilder(public_inputs, {}, 0).set_full_precomputed_tables(false).set_range_check_required(false); trace_builder.set_all_calldata(calldata); - AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, - .parent_id = 0, - .contract_address = FF(0), - .calldata = calldata, - .nested_returndata = {}, - .last_pc = 0, - .success_offset = 0, - .start_l2_gas_left = 0, - .start_da_gas_left = 0, - .l2_gas_left = 0, - .da_gas_left = 0, - .internal_return_ptr_stack = {} }); - trace_builder.current_ext_call_ctx = ext_call_ctx; + trace_builder.current_ext_call_ctx.calldata = calldata; this->calldata = calldata; } diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp index eb17641710b..88ef959bec4 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp @@ -22,7 +22,7 @@ enum class AvmError : uint32_t { SIDE_EFFECT_LIMIT_REACHED, OUT_OF_GAS, STATIC_CALL_ALTERATION, - NO_BYTECODE_FOUND, + FAILED_BYTECODE_RETRIEVAL, }; } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp index e5bbbeb2251..a56f7f9008f 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp @@ -427,13 +427,13 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, AvmError error = AvmError::NO_ERROR; // These hints help us to set up first call ctx - uint32_t clk = trace_builder.get_clk(); - auto context_id = static_cast(clk); + auto context_id = trace_builder.next_context_id; uint32_t l2_gas_allocated_to_enqueued_call = trace_builder.get_l2_gas_left(); uint32_t da_gas_allocated_to_enqueued_call = trace_builder.get_da_gas_left(); trace_builder.current_ext_call_ctx = AvmTraceBuilder::ExtCallCtx{ .context_id = context_id, .parent_id = 0, + .is_top_level = true, .contract_address = enqueued_call_hint.contract_address, .calldata = enqueued_call_hint.calldata, .nested_returndata = {}, @@ -445,13 +445,14 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, .da_gas_left = da_gas_allocated_to_enqueued_call, .internal_return_ptr_stack = {}, }; + trace_builder.next_context_id++; // Find the bytecode based on contract address of the public call request std::vector bytecode; try { bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership); } catch ([[maybe_unused]] const std::runtime_error& e) { - info("AVM enqueued call exceptionally halted. Error: No bytecode found for enqueued call"); + info("AVM enqueued call exceptionally halted. Failed bytecode retrieval."); // FIXME: properly handle case when bytecode is not found! // For now, we add a dummy row in main trace to mutate later. // Dummy row in main trace to mutate afterwards. @@ -459,8 +460,7 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, // we need at least one row in the execution trace to then mutate and say "it halted and consumed all gas!" trace_builder.op_add(0, 0, 0, 0, OpCode::ADD_8); trace_builder.handle_exceptional_halt(); - return AvmError::NO_BYTECODE_FOUND; - ; + return AvmError::FAILED_BYTECODE_RETRIEVAL; } trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call); @@ -471,7 +471,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, uint32_t pc = 0; std::stack debug_counter_stack; uint32_t counter = 0; - trace_builder.set_call_ptr(context_id); + // FIXME: this cast means that we can have duplicate call ptrs since clk will end up way bigger than 256 + trace_builder.set_call_ptr(static_cast(context_id)); while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) { auto [inst, parse_error] = Deserialization::parse(bytecode, pc); @@ -881,7 +882,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, /*check_membership=*/true); } catch ([[maybe_unused]] const std::runtime_error& e) { - error = AvmError::NO_BYTECODE_FOUND; + info("AVM CALL failed bytecode retrieval."); + error = AvmError::FAILED_BYTECODE_RETRIEVAL; } debug_counter_stack.push(counter); counter = 0; @@ -901,7 +903,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, /*check_membership=*/true); } catch ([[maybe_unused]] const std::runtime_error& e) { - error = AvmError::NO_BYTECODE_FOUND; + info("AVM STATICCALL failed bytecode retrieval."); + error = AvmError::FAILED_BYTECODE_RETRIEVAL; } debug_counter_stack.push(counter); counter = 0; @@ -1033,7 +1036,7 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, } if (!is_ok(error)) { - const bool is_top_level = trace_builder.current_ext_call_ctx.context_id == 0; + const bool is_top_level = trace_builder.current_ext_call_ctx.is_top_level; auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body std::string call_type = is_top_level ? "enqueued" : "nested"; diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp index 7f525bd136e..23f30d12807 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp @@ -135,8 +135,8 @@ std::string to_name(AvmError error) return "SIDE EFFECT LIMIT REACHED"; case AvmError::OUT_OF_GAS: return "OUT OF GAS"; - case AvmError::NO_BYTECODE_FOUND: - return "NO BYTECODE FOUND"; + case AvmError::FAILED_BYTECODE_RETRIEVAL: + return "FAILED BYTECODE RETRIEVAL"; default: throw std::runtime_error("Invalid error type"); break; diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp index f632962d980..367e854a910 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp @@ -157,15 +157,25 @@ void AvmTraceBuilder::rollback_to_non_revertible_checkpoint() merkle_tree_trace_builder.rollback_to_non_revertible_checkpoint(); } -std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership) +std::vector AvmTraceBuilder::get_bytecode_from_hints(const FF contract_class_id) { - auto clk = static_cast(main_trace.size()) + 1; - // Find the bytecode based on contract address of the public call request + // Find the bytecode based on the hinted contract class id + // TODO: still need to make sure that the contract address does correspond to this class id const AvmContractBytecode bytecode_hint = - *std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_address](const auto& contract) { - return contract.contract_instance.address == contract_address; + *std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_class_id](const auto& contract) { + return contract.contract_instance.contract_class_id == contract_class_id; }); + return bytecode_hint.bytecode; +} + +std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership) +{ + auto clk = static_cast(main_trace.size()) + 1; + + ASSERT(execution_hints.contract_instance_hints.contains(contract_address)); + const ContractInstanceHint instance_hint = execution_hints.contract_instance_hints.at(contract_address); + const FF contract_class_id = instance_hint.contract_class_id; bool exists = true; if (check_membership && !is_canonical(contract_address)) { @@ -173,12 +183,12 @@ std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bo // If we have already seen the contract address, we can skip the membership check and used the cached // membership proof vinfo("Found bytecode for contract address in cache: ", contract_address); - return bytecode_hint.bytecode; + return get_bytecode_from_hints(contract_class_id); } const auto contract_address_nullifier = AvmMerkleTreeTraceBuilder::unconstrained_silo_nullifier( DEPLOYER_CONTRACT_ADDRESS, /*nullifier=*/contract_address); // nullifier read hint for the contract address - NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint; + NullifierReadTreeHint nullifier_read_hint = instance_hint.membership_hint; // If the hinted preimage matches the contract address nullifier, the membership check will prove its existence, // otherwise the membership check will prove that a low-leaf exists that skips the contract address nullifier. @@ -194,8 +204,20 @@ std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bo if (exists) { // This was a membership proof! // Assert that the hint's exists flag matches. The flag isn't really necessary... - ASSERT(bytecode_hint.contract_instance.exists); + ASSERT(instance_hint.exists); bytecode_membership_cache.insert(contract_address); + + // The cache contains all the unique contract class ids we have seen so far + // If this bytecode retrievals have reached the number of unique contract class IDs, can't make + // more retrievals unless to already-checked contract class ids! + if (!contract_class_id_cache.contains(contract_class_id) && + contract_class_id_cache.size() >= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS) { + info("Limit reached for contract calls to unique class id. Limit: ", + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS); + throw std::runtime_error("Limit reached for contract calls to unique class id."); + } + contract_class_id_cache.insert(instance_hint.contract_class_id); + return get_bytecode_from_hints(contract_class_id); } else { // This was a non-membership proof! // Enforce that the tree access membership checked a low-leaf that skips the contract address nullifier. @@ -206,9 +228,9 @@ std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bo if (exists) { vinfo("Found bytecode for contract address: ", contract_address); - return bytecode_hint.bytecode; + return get_bytecode_from_hints(contract_class_id); } - vinfo("Bytecode not found for contract address: ", contract_address); + info("Bytecode not found for contract address: ", contract_address); throw std::runtime_error("Bytecode not found"); } @@ -357,7 +379,7 @@ void AvmTraceBuilder::allocate_gas_for_call(uint32_t l2_gas, uint32_t da_gas) void AvmTraceBuilder::handle_exceptional_halt() { - const bool is_top_level = current_ext_call_ctx.context_id == 0; + const bool is_top_level = current_ext_call_ctx.is_top_level; if (is_top_level) { vinfo("Handling exceptional halt in top-level call. Consuming all allocated gas:"); // Consume all remaining gas. @@ -577,6 +599,21 @@ AvmTraceBuilder::AvmTraceBuilder(AvmPublicInputs public_inputs, , bytecode_trace_builder(execution_hints.all_contract_bytecode) , merkle_tree_trace_builder(public_inputs.start_tree_snapshots) { + AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, + .parent_id = 0, + .is_top_level = true, + .contract_address = FF(0), + .calldata = {}, + .nested_returndata = {}, + .last_pc = 0, + .success_offset = 0, + .start_l2_gas_left = 0, + .start_da_gas_left = 0, + .l2_gas_left = 0, + .da_gas_left = 0, + .internal_return_ptr_stack = {} }); + current_ext_call_ctx = ext_call_ctx; + // Only allocate up to the maximum L2 gas for execution // TODO: constrain this! auto const l2_gas_left_after_private = @@ -2111,8 +2148,7 @@ AvmError AvmTraceBuilder::op_calldata_copy(uint8_t indirect, const uint32_t cd_offset = static_cast(unconstrained_read_from_memory(cd_offset_resolved)); const uint32_t copy_size = static_cast(unconstrained_read_from_memory(copy_size_offset_resolved)); - // If the context_id == 0, then we are at the top level call so we read/write to a trace column - bool is_top_level = current_ext_call_ctx.context_id == 0; + bool is_top_level = current_ext_call_ctx.is_top_level; auto calldata = current_ext_call_ctx.calldata; if (is_ok(error)) { @@ -3788,15 +3824,13 @@ AvmError AvmTraceBuilder::constrain_external_call(OpCode opcode, std::vector calldata; read_slice_from_memory(resolved_args_offset, args_size, calldata); - set_call_ptr(static_cast(clk)); - // Don't try allocating more than the gas that is actually left const auto l2_gas_allocated_to_nested_call = std::min(static_cast(read_gas_l2.val), gas_trace_builder.get_l2_gas_left()); const auto da_gas_allocated_to_nested_call = std::min(static_cast(read_gas_da.val), gas_trace_builder.get_da_gas_left()); current_ext_call_ctx = ExtCallCtx{ - .context_id = static_cast(clk), + .context_id = next_context_id, .parent_id = current_ext_call_ctx.context_id, .is_static_call = opcode == OpCode::STATICCALL, .contract_address = read_addr.val, @@ -3811,8 +3845,12 @@ AvmError AvmTraceBuilder::constrain_external_call(OpCode opcode, .internal_return_ptr_stack = {}, .tree_snapshot = {}, }; + next_context_id++; + + set_call_ptr(static_cast(current_ext_call_ctx.context_id)); allocate_gas_for_call(l2_gas_allocated_to_nested_call, da_gas_allocated_to_nested_call); + set_pc(0); } @@ -3904,7 +3942,7 @@ ReturnDataError AvmTraceBuilder::op_return(uint8_t indirect, uint32_t ret_offset const auto ret_size = static_cast(unconstrained_read_from_memory(resolved_ret_size_offset)); - const bool is_top_level = current_ext_call_ctx.context_id == 0; + const bool is_top_level = current_ext_call_ctx.is_top_level; const auto [l2_gas_cost, da_gas_cost] = gas_trace_builder.unconstrained_compute_gas(OpCode::RETURN, ret_size); bool out_of_gas = @@ -4043,7 +4081,7 @@ ReturnDataError AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset const auto ret_size = is_ok(error) ? static_cast(unconstrained_read_from_memory(resolved_ret_size_offset)) : 0; - const bool is_top_level = current_ext_call_ctx.context_id == 0; + const bool is_top_level = current_ext_call_ctx.is_top_level; const auto [l2_gas_cost, da_gas_cost] = gas_trace_builder.unconstrained_compute_gas(OpCode::REVERT_8, ret_size); bool out_of_gas = diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp index b988892dc50..f228501aa14 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp @@ -233,6 +233,8 @@ class AvmTraceBuilder { void checkpoint_non_revertible_state(); void rollback_to_non_revertible_checkpoint(); std::vector get_bytecode(const FF contract_address, bool check_membership = false); + // Used to track the unique class ids, could also be used to cache membership checks of class ids + std::unordered_set contract_class_id_cache; std::unordered_set bytecode_membership_cache; void insert_private_state(const std::vector& siloed_nullifiers, const std::vector& unique_note_hashes); void insert_private_revertible_state(const std::vector& siloed_nullifiers, @@ -253,6 +255,24 @@ class AvmTraceBuilder { full_precomputed_tables = required; return *this; } + AvmTraceBuilder& with_default_ctx() + { + AvmTraceBuilder::ExtCallCtx ext_call_ctx({ .context_id = 0, + .parent_id = 0, + .is_top_level = true, + .contract_address = FF(0), + .calldata = {}, + .nested_returndata = {}, + .last_pc = 0, + .success_offset = 0, + .start_l2_gas_left = 0, + .start_da_gas_left = 0, + .l2_gas_left = 0, + .da_gas_left = 0, + .internal_return_ptr_stack = {} }); + current_ext_call_ctx = ext_call_ctx; + return *this; + } struct MemOp { bool is_indirect; @@ -266,6 +286,7 @@ class AvmTraceBuilder { struct ExtCallCtx { uint32_t context_id; // This is the unique id of the ctx, we'll use the clk uint32_t parent_id; + bool is_top_level = false; bool is_static_call = false; FF contract_address{}; std::vector calldata; @@ -281,6 +302,7 @@ class AvmTraceBuilder { std::unordered_set public_data_unique_writes; }; + uint32_t next_context_id = 0; ExtCallCtx current_ext_call_ctx{}; std::stack external_call_ctx_stack; @@ -326,6 +348,7 @@ class AvmTraceBuilder { AvmBytecodeTraceBuilder bytecode_trace_builder; AvmMerkleTreeTraceBuilder merkle_tree_trace_builder; + std::vector get_bytecode_from_hints(const FF contract_class_id); RowWithError create_kernel_lookup_opcode(uint8_t indirect, uint32_t dst_offset, FF value, AvmMemoryTag w_tag); RowWithError create_kernel_output_opcode(uint8_t indirect, uint32_t clk, uint32_t data_offset); diff --git a/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp b/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp index d180cb2508b..518730ac15a 100644 --- a/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp @@ -48,6 +48,7 @@ #define AVM_ACCUMULATED_DATA_LENGTH 320 #define AVM_CIRCUIT_PUBLIC_INPUTS_LENGTH 1011 #define AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS 86 +#define MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS 21 #define AVM_PROOF_LENGTH_IN_FIELDS 4155 #define AVM_PUBLIC_COLUMN_MAX_SIZE 1024 #define AVM_PUBLIC_INPUTS_FLATTENED_SIZE 2915 diff --git a/l1-contracts/src/core/libraries/ConstantsGen.sol b/l1-contracts/src/core/libraries/ConstantsGen.sol index 6e274c79760..fdddf701b5a 100644 --- a/l1-contracts/src/core/libraries/ConstantsGen.sol +++ b/l1-contracts/src/core/libraries/ConstantsGen.sol @@ -225,6 +225,8 @@ library Constants { uint256 internal constant TUBE_PROOF_LENGTH = 459; uint256 internal constant HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128; uint256 internal constant CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143; + uint256 internal constant MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000; + uint256 internal constant MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS = 21; uint256 internal constant MEM_TAG_FF = 0; uint256 internal constant MEM_TAG_U1 = 1; uint256 internal constant MEM_TAG_U8 = 2; diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c18b0bfaa9f..fbf8816d8b2 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -44,6 +44,7 @@ contract AvmTest { scalar::Scalar, }; use dep::aztec::protocol_types::{ + constants::MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, contract_class_id::ContractClassId, storage::map::derive_storage_slot_in_map, }; use dep::aztec::state_vars::PublicMutable; @@ -596,6 +597,18 @@ contract AvmTest { AvmTest::at(context.this_address()).add_args_return(arg_a, arg_b).call(&mut context) } + #[public] + fn nested_call_to_add_n_times_different_addresses( + addrs: [AztecAddress; MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS + 2], + ) { + for i in 0..MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS + 2 { + let addr = addrs[i]; + if addr != AztecAddress::empty() { + let _ = AvmTest::at(addr).add_args_return(1, 2).call(&mut context); + } + } + } + // Indirectly call_static the external call opcode to initiate a nested call to the add function #[public] fn nested_static_call_to_add(arg_a: Field, arg_b: Field) -> pub Field { diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr index 12a824936fc..88d51b89f60 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr @@ -474,6 +474,12 @@ pub global CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 143; // size of a // 21 above refers to the constant AvmFlavor::NUM_PRECOMPUTED_ENTITIES pub global AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 2 + 21 * 4; +// Setting limits for MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS +// This value is determined by the length of the AVM trace and the MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES +// (i.e. 2^21 / MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES ==> 2^21 / 96,000 = 21 +pub global MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES: u32 = MAX_PACKED_PUBLIC_BYTECODE_SIZE_IN_FIELDS * 32; +pub global MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS: u32 = 21; + // `AVM_PROOF_LENGTH_IN_FIELDS` must be updated when AVM circuit changes. // To determine latest value, hover `COMPUTED_AVM_PROOF_LENGTH_IN_FIELDS` // in barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.hpp diff --git a/yarn-project/bb-prover/src/avm_proving.test.ts b/yarn-project/bb-prover/src/avm_proving.test.ts index ab3aad9dd39..0adfbab0607 100644 --- a/yarn-project/bb-prover/src/avm_proving.test.ts +++ b/yarn-project/bb-prover/src/avm_proving.test.ts @@ -2,13 +2,17 @@ import { MAX_L2_TO_L1_MSGS_PER_TX, MAX_NOTE_HASHES_PER_TX, MAX_NULLIFIERS_PER_TX, + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, MAX_UNENCRYPTED_LOGS_PER_TX, VerificationKeyData, } from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import { createLogger } from '@aztec/foundation/log'; -import { simulateAvmTestContractGenerateCircuitInputs } from '@aztec/simulator/public/fixtures'; +import { + MockedAvmTestContractDataSource, + simulateAvmTestContractGenerateCircuitInputs, +} from '@aztec/simulator/public/fixtures'; import fs from 'node:fs/promises'; import { tmpdir } from 'node:os'; @@ -17,7 +21,7 @@ import path from 'path'; import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js'; import { extractAvmVkData } from './verification_key/verification_key_data.js'; -const TIMEOUT = 180_000; +const TIMEOUT = 300_000; describe('AVM WitGen, proof generation and verification', () => { it( @@ -25,7 +29,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'bulk_testing', - /*calldata=*/ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), + /*args=*/ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), ); }, TIMEOUT, @@ -35,7 +39,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'n_storage_writes', - /*calldata=*/ [new Fr(MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX + 1)], + /*args=*/ [new Fr(MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX + 1)], /*expectRevert=*/ true, ); }, @@ -46,7 +50,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'n_new_note_hashes', - /*calldata=*/ [new Fr(MAX_NOTE_HASHES_PER_TX + 1)], + /*args=*/ [new Fr(MAX_NOTE_HASHES_PER_TX + 1)], /*expectRevert=*/ true, ); }, @@ -57,7 +61,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'n_new_nullifiers', - /*calldata=*/ [new Fr(MAX_NULLIFIERS_PER_TX + 1)], + /*args=*/ [new Fr(MAX_NULLIFIERS_PER_TX + 1)], /*expectRevert=*/ true, ); }, @@ -68,7 +72,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'n_new_l2_to_l1_msgs', - /*calldata=*/ [new Fr(MAX_L2_TO_L1_MSGS_PER_TX + 1)], + /*args=*/ [new Fr(MAX_L2_TO_L1_MSGS_PER_TX + 1)], /*expectRevert=*/ true, ); }, @@ -79,8 +83,51 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'n_new_unencrypted_logs', - /*calldata=*/ [new Fr(MAX_UNENCRYPTED_LOGS_PER_TX + 1)], + /*args=*/ [new Fr(MAX_UNENCRYPTED_LOGS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that calls the max number of unique contract classes', + async () => { + const contractDataSource = new MockedAvmTestContractDataSource(); + // args is initialized to MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS contract addresses with unique class IDs + const args = Array.from(contractDataSource.contractInstances.values()) + .map(instance => instance.address.toField()) + .slice(0, MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS); + // include the first contract again again at the end to ensure that we can call it even after the limit is reached + args.push(args[0]); + // include another contract address that reuses a class ID to ensure that we can call it even after the limit is reached + args.push(contractDataSource.instanceSameClassAsFirstContract.address.toField()); + await proveAndVerifyAvmTestContract( + 'nested_call_to_add_n_times_different_addresses', + args, + /*expectRevert=*/ false, + /*skipContractDeployments=*/ false, + contractDataSource, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that attempts too many calls to unique contract class ids', + async () => { + const contractDataSource = new MockedAvmTestContractDataSource(); + // args is initialized to MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+1 contract addresses with unique class IDs + // should fail because we are trying to call MAX+1 unique class IDs + const args = Array.from(contractDataSource.contractInstances.values()).map(instance => + instance.address.toField(), + ); + // push an empty one (just padding to match function calldata size of MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+2) + args.push(new Fr(0)); + await proveAndVerifyAvmTestContract( + 'nested_call_to_add_n_times_different_addresses', + args, /*expectRevert=*/ true, + /*skipContractDeployments=*/ false, + contractDataSource, ); }, TIMEOUT, @@ -88,14 +135,14 @@ describe('AVM WitGen, proof generation and verification', () => { it( 'Should prove and verify a top-level exceptional halt', async () => { - await proveAndVerifyAvmTestContract('divide_by_zero', /*calldata=*/ [], /*expectRevert=*/ true); + await proveAndVerifyAvmTestContract('divide_by_zero', /*args=*/ [], /*expectRevert=*/ true); }, TIMEOUT, ); it( 'Should prove and verify a nested exceptional halt that propagates to top-level', async () => { - await proveAndVerifyAvmTestContract('external_call_to_divide_by_zero', /*calldata=*/ [], /*expectRevert=*/ true); + await proveAndVerifyAvmTestContract('external_call_to_divide_by_zero', /*args=*/ [], /*expectRevert=*/ true); }, TIMEOUT, ); @@ -104,7 +151,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'external_call_to_divide_by_zero_recovers', - /*calldata=*/ [], + /*args=*/ [], /*expectRevert=*/ false, ); }, @@ -113,14 +160,14 @@ describe('AVM WitGen, proof generation and verification', () => { it( 'Should prove and verify an exceptional halt due to a nested call to non-existent contract that is propagated to top-level', async () => { - await proveAndVerifyAvmTestContract('nested_call_to_nothing', /*calldata=*/ [], /*expectRevert=*/ true); + await proveAndVerifyAvmTestContract('nested_call_to_nothing', /*args=*/ [], /*expectRevert=*/ true); }, TIMEOUT, ); it( 'Should prove and verify an exceptional halt due to a nested call to non-existent contract that is recovered from in caller', async () => { - await proveAndVerifyAvmTestContract('nested_call_to_nothing_recovers', /*calldata=*/ [], /*expectRevert=*/ false); + await proveAndVerifyAvmTestContract('nested_call_to_nothing_recovers', /*args=*/ [], /*expectRevert=*/ false); }, TIMEOUT, ); @@ -129,7 +176,7 @@ describe('AVM WitGen, proof generation and verification', () => { async () => { await proveAndVerifyAvmTestContract( 'add_args_return', - /*calldata=*/ [new Fr(1), new Fr(2)], + /*args=*/ [new Fr(1), new Fr(2)], /*expectRevert=*/ true, /*skipContractDeployments=*/ true, ); @@ -140,15 +187,16 @@ describe('AVM WitGen, proof generation and verification', () => { async function proveAndVerifyAvmTestContract( functionName: string, - calldata: Fr[] = [], + args: Fr[] = [], expectRevert = false, skipContractDeployments = false, + contractDataSource = new MockedAvmTestContractDataSource(skipContractDeployments), ) { const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs( functionName, - calldata, + args, expectRevert, - skipContractDeployments, + contractDataSource, ); const internalLogger = createLogger('bb-prover:avm-proving-test'); diff --git a/yarn-project/circuits.js/src/constants.gen.ts b/yarn-project/circuits.js/src/constants.gen.ts index ac7e24b713c..4139dd7aab5 100644 --- a/yarn-project/circuits.js/src/constants.gen.ts +++ b/yarn-project/circuits.js/src/constants.gen.ts @@ -206,6 +206,8 @@ export const TUBE_PROOF_LENGTH = 459; export const HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128; export const CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143; export const AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS = 86; +export const MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000; +export const MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS = 21; export const AVM_PROOF_LENGTH_IN_FIELDS = 4155; export const AVM_PUBLIC_COLUMN_MAX_SIZE = 1024; export const AVM_PUBLIC_INPUTS_FLATTENED_SIZE = 2915; diff --git a/yarn-project/circuits.js/src/scripts/constants.in.ts b/yarn-project/circuits.js/src/scripts/constants.in.ts index 39c2270d6fd..acf5a3a0181 100644 --- a/yarn-project/circuits.js/src/scripts/constants.in.ts +++ b/yarn-project/circuits.js/src/scripts/constants.in.ts @@ -91,6 +91,7 @@ const CPP_CONSTANTS = [ 'FEE_JUICE_ADDRESS', 'ROUTER_ADDRESS', 'FEE_JUICE_BALANCES_SLOT', + 'MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS', ]; const CPP_GENERATORS: string[] = [ diff --git a/yarn-project/circuits.js/src/structs/avm/avm.ts b/yarn-project/circuits.js/src/structs/avm/avm.ts index 4447516818d..d65fea2beb1 100644 --- a/yarn-project/circuits.js/src/structs/avm/avm.ts +++ b/yarn-project/circuits.js/src/structs/avm/avm.ts @@ -683,7 +683,6 @@ export class AvmExecutionHints { public readonly enqueuedCalls: Vector; public readonly contractInstances: Vector; - public readonly contractBytecodeHints: Vector; public readonly publicDataReads: Vector; public readonly publicDataWrites: Vector; @@ -696,7 +695,8 @@ export class AvmExecutionHints { constructor( enqueuedCalls: AvmEnqueuedCallHint[], contractInstances: AvmContractInstanceHint[], - contractBytecodeHints: AvmContractBytecodeHints[], + // string here is the contract class id + public contractBytecodeHints: Map, publicDataReads: AvmPublicDataReadTreeHint[], publicDataWrites: AvmPublicDataWriteTreeHint[], nullifierReads: AvmNullifierReadTreeHint[], @@ -707,7 +707,6 @@ export class AvmExecutionHints { ) { this.enqueuedCalls = new Vector(enqueuedCalls); this.contractInstances = new Vector(contractInstances); - this.contractBytecodeHints = new Vector(contractBytecodeHints); this.publicDataReads = new Vector(publicDataReads); this.publicDataWrites = new Vector(publicDataWrites); this.nullifierReads = new Vector(nullifierReads); @@ -722,7 +721,7 @@ export class AvmExecutionHints { * @returns an empty instance. */ static empty() { - return new AvmExecutionHints([], [], [], [], [], [], [], [], [], []); + return new AvmExecutionHints([], [], new Map(), [], [], [], [], [], [], []); } /** @@ -749,7 +748,7 @@ export class AvmExecutionHints { return ( this.enqueuedCalls.items.length == 0 && this.contractInstances.items.length == 0 && - this.contractBytecodeHints.items.length == 0 && + this.contractBytecodeHints.size == 0 && this.publicDataReads.items.length == 0 && this.publicDataWrites.items.length == 0 && this.nullifierReads.items.length == 0 && @@ -769,7 +768,7 @@ export class AvmExecutionHints { return new AvmExecutionHints( fields.enqueuedCalls.items, fields.contractInstances.items, - fields.contractBytecodeHints.items, + fields.contractBytecodeHints, fields.publicDataReads.items, fields.publicDataWrites.items, fields.nullifierReads.items, @@ -789,7 +788,7 @@ export class AvmExecutionHints { return [ fields.enqueuedCalls, fields.contractInstances, - fields.contractBytecodeHints, + new Vector(Array.from(fields.contractBytecodeHints.values())), fields.publicDataReads, fields.publicDataWrites, fields.nullifierReads, @@ -807,10 +806,20 @@ export class AvmExecutionHints { */ static fromBuffer(buff: Buffer | BufferReader): AvmExecutionHints { const reader = BufferReader.asReader(buff); + + const readMap = (r: BufferReader) => { + const map = new Map(); + const values = r.readVector(AvmContractBytecodeHints); + for (const value of values) { + map.set(value.contractInstanceHint.address.toString(), value); + } + return map; + }; + return new AvmExecutionHints( reader.readVector(AvmEnqueuedCallHint), reader.readVector(AvmContractInstanceHint), - reader.readVector(AvmContractBytecodeHints), + readMap(reader), reader.readVector(AvmPublicDataReadTreeHint), reader.readVector(AvmPublicDataWriteTreeHint), reader.readVector(AvmNullifierReadTreeHint), diff --git a/yarn-project/circuits.js/src/tests/factories.ts b/yarn-project/circuits.js/src/tests/factories.ts index a606dfb8684..ad683470c36 100644 --- a/yarn-project/circuits.js/src/tests/factories.ts +++ b/yarn-project/circuits.js/src/tests/factories.ts @@ -1282,6 +1282,10 @@ export function makeVector(length: number, fn: (i: number) return new Vector(makeArray(length, fn, offset)); } +export function makeMap(size: number, fn: (i: number) => [string, T], offset = 0) { + return new Map(makeArray(size, i => fn(i + offset))); +} + export function makeContractInstanceFromClassId(classId: Fr, seed = 0): ContractInstanceWithAddress { const salt = new Fr(seed); const initializationHash = new Fr(seed + 1); @@ -1417,7 +1421,14 @@ export function makeAvmExecutionHints( return AvmExecutionHints.from({ enqueuedCalls: makeVector(baseLength, makeAvmEnqueuedCallHint, seed + 0x4100), contractInstances: makeVector(baseLength + 5, makeAvmContractInstanceHint, seed + 0x4700), - contractBytecodeHints: makeVector(baseLength + 6, makeAvmBytecodeHints, seed + 0x4800), + contractBytecodeHints: makeMap( + baseLength + 6, + i => { + const h = makeAvmBytecodeHints(i); + return [h.contractInstanceHint.address.toString(), h]; + }, + seed + 0x4900, + ), publicDataReads: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900), publicDataWrites: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00), nullifierReads: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00), diff --git a/yarn-project/simulator/src/avm/avm_memory_types.ts b/yarn-project/simulator/src/avm/avm_memory_types.ts index 444af4376a2..a1c86f870b1 100644 --- a/yarn-project/simulator/src/avm/avm_memory_types.ts +++ b/yarn-project/simulator/src/avm/avm_memory_types.ts @@ -316,8 +316,9 @@ export class TaggedMemory implements TaggedMemoryInterface { * Check that the memory at the given offset matches the specified tag. */ public checkTag(tag: TypeTag, offset: number) { - if (this.getTag(offset) !== tag) { - throw TagCheckError.forOffset(offset, TypeTag[this.getTag(offset)], TypeTag[tag]); + const gotTag = this.getTag(offset); + if (gotTag !== tag) { + throw TagCheckError.forOffset(offset, TypeTag[gotTag], TypeTag[tag]); } } @@ -336,13 +337,13 @@ export class TaggedMemory implements TaggedMemoryInterface { public static checkIsValidTag(tagNumber: number) { if ( ![ + TypeTag.FIELD, TypeTag.UINT1, TypeTag.UINT8, TypeTag.UINT16, TypeTag.UINT32, TypeTag.UINT64, TypeTag.UINT128, - TypeTag.FIELD, ].includes(tagNumber) ) { throw new InvalidTagValueError(tagNumber); @@ -382,21 +383,23 @@ export class TaggedMemory implements TaggedMemoryInterface { public static getTag(v: MemoryValue | undefined): TypeTag { let tag = TypeTag.INVALID; + // Not sure why, but using instanceof here doesn't work and leads odd behavior, + // but using constructor.name does the job... if (v === undefined) { tag = TypeTag.FIELD; // uninitialized memory is Field(0) - } else if (v instanceof Field) { + } else if (v.constructor.name == 'Field') { tag = TypeTag.FIELD; - } else if (v instanceof Uint1) { + } else if (v.constructor.name == 'Uint1') { tag = TypeTag.UINT1; - } else if (v instanceof Uint8) { + } else if (v.constructor.name == 'Uint8') { tag = TypeTag.UINT8; - } else if (v instanceof Uint16) { + } else if (v.constructor.name == 'Uint16') { tag = TypeTag.UINT16; - } else if (v instanceof Uint32) { + } else if (v.constructor.name == 'Uint32') { tag = TypeTag.UINT32; - } else if (v instanceof Uint64) { + } else if (v.constructor.name == 'Uint64') { tag = TypeTag.UINT64; - } else if (v instanceof Uint128) { + } else if (v.constructor.name == 'Uint128') { tag = TypeTag.UINT128; } diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index 8ce5ae1cffb..7ee18457dbf 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -2,7 +2,7 @@ import { MerkleTreeId, type MerkleTreeWriteOperations } from '@aztec/circuit-typ import { DEPLOYER_CONTRACT_ADDRESS, GasFees, - GlobalVariables, + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, PublicDataTreeLeafPreimage, PublicKeys, SerializableContractInstance, @@ -30,8 +30,8 @@ import { randomInt } from 'crypto'; import { mock } from 'jest-mock-extended'; import { PublicEnqueuedCallSideEffectTrace } from '../public/enqueued_call_side_effect_trace.js'; -import { MockedAvmTestContractDataSource } from '../public/fixtures/index.js'; -import { WorldStateDB } from '../public/public_db_sources.js'; +import { MockedAvmTestContractDataSource, simulateAvmTestContractCall } from '../public/fixtures/index.js'; +import { type WorldStateDB } from '../public/public_db_sources.js'; import { type PublicSideEffectTraceInterface } from '../public/side_effect_trace_interface.js'; import { type AvmContext } from './avm_context.js'; import { type MemoryValue, TypeTag, type Uint8, type Uint64 } from './avm_memory_types.js'; @@ -41,7 +41,6 @@ import { isAvmBytecode, markBytecodeAsAvm } from './bytecode_utils.js'; import { getAvmTestContractArtifact, getAvmTestContractBytecode, - getAvmTestContractFunctionSelector, initContext, initExecutionEnvironment, initGlobalVariables, @@ -148,54 +147,43 @@ describe('AVM simulator: injected bytecode', () => { }); }); -const TIMESTAMP = new Fr(99833); - describe('AVM simulator: transpiled Noir contracts', () => { it('bulk testing', async () => { - const functionName = 'bulk_testing'; - const functionSelector = getAvmTestContractFunctionSelector(functionName); const args = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)); - const calldata = [functionSelector.toField(), ...args]; - const globals = GlobalVariables.empty(); - globals.timestamp = TIMESTAMP; + await simulateAvmTestContractCall('bulk_testing', args, /*expectRevert=*/ false); + }); - const telemetry = new NoopTelemetryClient(); - const merkleTrees = await (await MerkleTrees.new(openTmpStore(), telemetry)).fork(); + it('call max unique contract classes', async () => { const contractDataSource = new MockedAvmTestContractDataSource(); - const worldStateDB = new WorldStateDB(merkleTrees, contractDataSource); - - const contractInstance = contractDataSource.contractInstance; - const contractAddressNullifier = siloNullifier( - AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), - contractInstance.address.toField(), + // args is initialized to MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS contract addresses with unique class IDs + const args = Array.from(contractDataSource.contractInstances.values()) + .map(instance => instance.address.toField()) + .slice(0, MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS); + // include the first contract again again at the end to ensure that we can call it even after the limit is reached + args.push(args[0]); + // include another contract address that reuses a class ID to ensure that we can call it even after the limit is reached + args.push(contractDataSource.instanceSameClassAsFirstContract.address.toField()); + await simulateAvmTestContractCall( + 'nested_call_to_add_n_times_different_addresses', + args, + /*expectRevert=*/ false, + contractDataSource, ); - await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [contractAddressNullifier.toBuffer()], 0); - // other contract address used by the bulk test's GETCONTRACTINSTANCE test - const otherContractAddressNullifier = siloNullifier( - AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), - contractDataSource.otherContractInstance.address.toField(), - ); - await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [otherContractAddressNullifier.toBuffer()], 0); - - const trace = mock(); - const nestedTrace = mock(); - mockNoteHashCount(trace, 0); - mockTraceFork(trace, nestedTrace); - const ephemeralTrees = await AvmEphemeralForest.create(worldStateDB.getMerkleInterface()); - const persistableState = initPersistableStateManager({ worldStateDB, trace, merkleTrees: ephemeralTrees }); - const environment = initExecutionEnvironment({ - calldata, - globals, - address: contractInstance.address, - sender: AztecAddress.fromNumber(42), - }); - const context = initContext({ env: environment, persistableState }); - - // First we simulate (though it's not needed in this simple case). - const simulator = new AvmSimulator(context); - const results = await simulator.execute(); + }); - expect(results.reverted).toBe(false); + it('call too many unique contract classes fails', async () => { + const contractDataSource = new MockedAvmTestContractDataSource(); + // args is initialized to MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+1 contract addresses with unique class IDs + // should fail because we are trying to call MAX+1 unique class IDs + const args = Array.from(contractDataSource.contractInstances.values()).map(instance => instance.address.toField()); + // push an empty one (just padding to match function calldata size of MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+2) + args.push(new Fr(0)); + await simulateAvmTestContractCall( + 'nested_call_to_add_n_times_different_addresses', + args, + /*expectRevert=*/ true, + contractDataSource, + ); }); it('execution of a non-existent contract immediately reverts and consumes all allocated gas', async () => { diff --git a/yarn-project/simulator/src/avm/avm_simulator.ts b/yarn-project/simulator/src/avm/avm_simulator.ts index 1ff7cf60d61..fff7f7d7fd5 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.ts @@ -97,25 +97,22 @@ export class AvmSimulator { * Fetch the bytecode and execute it in the current context. */ public async execute(): Promise { - const bytecode = await this.context.persistableState.getBytecode(this.context.environment.address); - if (!bytecode) { - // revert, consuming all gas - const message = `No bytecode found at: ${this.context.environment.address}. Reverting...`; - const fnName = await this.context.persistableState.getPublicFunctionDebugName(this.context.environment); - const revertReason = new AvmRevertReason( - message, - /*failingFunction=*/ { - contractAddress: this.context.environment.address, - functionName: fnName, - }, - /*noirCallStack=*/ [], + let bytecode: Buffer | undefined; + try { + bytecode = await this.context.persistableState.getBytecode(this.context.environment.address); + } catch (err: any) { + if (!(err instanceof AvmExecutionError || err instanceof SideEffectLimitReachedError)) { + this.log.error(`Unknown error thrown by AVM during bytecode retrieval: ${err}`); + throw err; + } + return await this.handleFailureToRetrieveBytecode( + `Bytecode retrieval for contract '${this.context.environment.address}' failed with ${err}. Reverting...`, ); - this.log.warn(message); - return new AvmContractCallResult( - /*reverted=*/ true, - /*output=*/ [], - /*gasLeft=*/ { l2Gas: 0, daGas: 0 }, // consumes all allocated gas - revertReason, + } + + if (!bytecode) { + return await this.handleFailureToRetrieveBytecode( + `No bytecode found at: ${this.context.environment.address}. Reverting...`, ); } @@ -189,7 +186,16 @@ export class AvmSimulator { return results; } catch (err: any) { this.log.verbose('Exceptional halt (revert by something other than REVERT opcode)'); - if (!(err instanceof AvmExecutionError || err instanceof SideEffectLimitReachedError)) { + // FIXME: weird that we have to do this OutOfGasError check because: + // 1. OutOfGasError is an AvmExecutionError, so that check should cover both + // 2. We should at least be able to do instanceof OutOfGasError instead of checking the constructor name + if ( + !( + err.constructor.name == 'OutOfGasError' || + err instanceof AvmExecutionError || + err instanceof SideEffectLimitReachedError + ) + ) { this.log.error(`Unknown error thrown by AVM: ${err}`); throw err; } @@ -207,6 +213,26 @@ export class AvmSimulator { } } + private async handleFailureToRetrieveBytecode(message: string): Promise { + // revert, consuming all gas + const fnName = await this.context.persistableState.getPublicFunctionDebugName(this.context.environment); + const revertReason = new AvmRevertReason( + message, + /*failingFunction=*/ { + contractAddress: this.context.environment.address, + functionName: fnName, + }, + /*noirCallStack=*/ [], + ); + this.log.warn(message); + return new AvmContractCallResult( + /*reverted=*/ true, + /*output=*/ [], + /*gasLeft=*/ { l2Gas: 0, daGas: 0 }, // consumes all allocated gas + revertReason, + ); + } + private tallyInstruction(pc: number, opcode: string, gasUsed: Gas) { const opcodeTally = this.opcodeTallies.get(opcode) || ({ count: 0, gas: { l2Gas: 0, daGas: 0 } } as OpcodeTally); opcodeTally.count++; diff --git a/yarn-project/simulator/src/avm/journal/journal.ts b/yarn-project/simulator/src/avm/journal/journal.ts index 35d349add3b..f962f2aa3b8 100644 --- a/yarn-project/simulator/src/avm/journal/journal.ts +++ b/yarn-project/simulator/src/avm/journal/journal.ts @@ -64,29 +64,6 @@ export class AvmPersistableStateManager { public readonly txHash: TxHash, ) {} - /** - * Create a new state manager with some preloaded pending siloed nullifiers - */ - public static async newWithPendingSiloedNullifiers( - worldStateDB: WorldStateDB, - trace: PublicSideEffectTraceInterface, - pendingSiloedNullifiers: Fr[], - doMerkleOperations: boolean = false, - txHash: TxHash, - ) { - const parentNullifiers = NullifierManager.newWithPendingSiloedNullifiers(worldStateDB, pendingSiloedNullifiers); - const ephemeralForest = await AvmEphemeralForest.create(worldStateDB.getMerkleInterface()); - return new AvmPersistableStateManager( - worldStateDB, - trace, - /*publicStorage=*/ new PublicStorage(worldStateDB), - /*nullifiers=*/ parentNullifiers.fork(), - doMerkleOperations, - ephemeralForest, - txHash, - ); - } - /** * Create a new state manager */ diff --git a/yarn-project/simulator/src/avm/journal/nullifiers.ts b/yarn-project/simulator/src/avm/journal/nullifiers.ts index 1af35cc9e4c..8e482108990 100644 --- a/yarn-project/simulator/src/avm/journal/nullifiers.ts +++ b/yarn-project/simulator/src/avm/journal/nullifiers.ts @@ -17,17 +17,6 @@ export class NullifierManager { private readonly parent?: NullifierManager, ) {} - /** - * Create a new nullifiers manager with some preloaded pending siloed nullifiers - */ - public static newWithPendingSiloedNullifiers(hostNullifiers: CommitmentsDB, pendingSiloedNullifiers?: Fr[]) { - const cachedSiloedNullifiers = new Set(); - if (pendingSiloedNullifiers !== undefined) { - pendingSiloedNullifiers.forEach(nullifier => cachedSiloedNullifiers.add(nullifier.toBigInt())); - } - return new NullifierManager(hostNullifiers, cachedSiloedNullifiers); - } - /** * Create a new nullifiers manager forked from this one */ diff --git a/yarn-project/simulator/src/avm/journal/public_storage.ts b/yarn-project/simulator/src/avm/journal/public_storage.ts index da2565e8d2f..64abe824ddb 100644 --- a/yarn-project/simulator/src/avm/journal/public_storage.ts +++ b/yarn-project/simulator/src/avm/journal/public_storage.ts @@ -1,4 +1,4 @@ -import { AztecAddress } from '@aztec/circuits.js'; +import { type AztecAddress } from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import type { PublicStateDB } from '../../index.js'; @@ -33,13 +33,6 @@ export class PublicStorage { return new PublicStorage(this.hostPublicStorage, this); } - /** - * Get the pending storage. - */ - public getCache() { - return this.cache; - } - /** * Read a storage value from this' cache or parent's (recursively). * DOES NOT CHECK HOST STORAGE! @@ -108,17 +101,6 @@ export class PublicStorage { public acceptAndMerge(incomingPublicStorage: PublicStorage) { this.cache.acceptAndMerge(incomingPublicStorage.cache); } - - /** - * Commits ALL staged writes to the host's state. - */ - public async commitToDB() { - for (const [contractAddress, cacheAtContract] of this.cache.cachePerContract) { - for (const [slot, value] of cacheAtContract) { - await this.hostPublicStorage.storageWrite(AztecAddress.fromBigInt(contractAddress), new Fr(slot), value); - } - } - } } /** @@ -132,8 +114,7 @@ class PublicStorageCache { * One inner-map per contract storage address, * mapping storage slot to latest staged write value. */ - public cachePerContract: Map> = new Map(); - // FIXME: storage ^ should be private, but its value is used in commitToDB + private cachePerContract: Map> = new Map(); /** * Read a staged value from storage, if it has been previously written to. diff --git a/yarn-project/simulator/src/public/bytecode_errors.ts b/yarn-project/simulator/src/public/bytecode_errors.ts new file mode 100644 index 00000000000..c3d6c8a305e --- /dev/null +++ b/yarn-project/simulator/src/public/bytecode_errors.ts @@ -0,0 +1,6 @@ +export class ContractClassBytecodeError extends Error { + constructor(contractAddress: string) { + super(`Failed to get bytecode for contract at address ${contractAddress}`); + this.name = 'ContractClassBytecodeError'; + } +} diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts index a6b34e58e13..9faa6cbd0ec 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts @@ -13,6 +13,7 @@ import { MAX_L2_TO_L1_MSGS_PER_TX, MAX_NOTE_HASHES_PER_TX, MAX_NULLIFIERS_PER_TX, + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, MAX_UNENCRYPTED_LOGS_PER_TX, NoteHash, @@ -201,7 +202,7 @@ describe('Enqueued-call Side Effect Trace', () => { ); const membershipHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath); - expect(trace.getAvmCircuitHints().contractBytecodeHints.items).toEqual([ + expect(Array.from(trace.getAvmCircuitHints().contractBytecodeHints.values())).toEqual([ { bytecode, contractInstanceHint: { address, exists, ...instanceWithoutVersion, membershipHint: { ...membershipHint } }, @@ -319,6 +320,37 @@ describe('Enqueued-call Side Effect Trace', () => { ); }); + it('Should enforce maximum number of calls to unique contract class IDs', () => { + const firstAddr = AztecAddress.fromNumber(0); + const firstInstance = SerializableContractInstance.random(); + trace.traceGetBytecode(firstAddr, /*exists=*/ true, bytecode, firstInstance); + + for (let i = 1; i < MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS; i++) { + const addr = AztecAddress.fromNumber(i); + const instance = SerializableContractInstance.random(); + trace.traceGetBytecode(addr, /*exists=*/ true, bytecode, instance); + } + + const addr = AztecAddress.fromNumber(MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS); + const instance = SerializableContractInstance.random(); + expect(() => trace.traceGetBytecode(addr, /*exists=*/ true, bytecode, instance)).toThrow( + SideEffectLimitReachedError, + ); + + // can re-trace same contract address + trace.traceGetBytecode(firstAddr, /*exists=*/ true, bytecode, firstInstance); + + const differentAddr = AztecAddress.fromNumber(MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS + 1); + const instanceWithSameClassId = SerializableContractInstance.random({ + contractClassId: firstInstance.contractClassId, + }); + // can re-trace different contract address if it has a duplicate class ID + trace.traceGetBytecode(differentAddr, /*exists=*/ true, bytecode, instanceWithSameClassId); + + // can trace a call to a non-existent contract + trace.traceGetBytecode(differentAddr, /*exists=*/ false); + }); + it('PreviousValidationRequestArrayLengths and PreviousAccumulatedDataArrayLengths contribute to limits', () => { trace = new PublicEnqueuedCallSideEffectTrace( 0, @@ -405,7 +437,7 @@ describe('Enqueued-call Side Effect Trace', () => { const childHints = nestedTrace.getAvmCircuitHints(); expect(parentHints.enqueuedCalls.items).toEqual(childHints.enqueuedCalls.items); expect(parentHints.contractInstances.items).toEqual(childHints.contractInstances.items); - expect(parentHints.contractBytecodeHints.items).toEqual(childHints.contractBytecodeHints.items); + expect(parentHints.contractBytecodeHints).toEqual(childHints.contractBytecodeHints); expect(parentHints.publicDataReads.items).toEqual(childHints.publicDataReads.items); expect(parentHints.publicDataWrites.items).toEqual(childHints.publicDataWrites.items); expect(parentHints.nullifierReads.items).toEqual(childHints.nullifierReads.items); diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts index 040ebbaaa66..cfd3f76c11c 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts @@ -24,6 +24,7 @@ import { MAX_L2_TO_L1_MSGS_PER_TX, MAX_NOTE_HASHES_PER_TX, MAX_NULLIFIERS_PER_TX, + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, MAX_TOTAL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, MAX_UNENCRYPTED_LOGS_PER_TX, @@ -58,6 +59,7 @@ import { type AvmExecutionEnvironment } from '../avm/avm_execution_environment.j import { type EnqueuedPublicCallExecutionResultWithSideEffects, type PublicFunctionCallResult } from './execution.js'; import { SideEffectLimitReachedError } from './side_effect_errors.js'; import { type PublicSideEffectTraceInterface } from './side_effect_trace_interface.js'; +import { UniqueClassIds } from './unique_class_ids.js'; const emptyPublicDataPath = () => new Array(PUBLIC_DATA_TREE_HEIGHT).fill(Fr.zero()); const emptyNoteHashPath = () => new Array(NOTE_HASH_TREE_HEIGHT).fill(Fr.zero()); @@ -128,6 +130,8 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI * otherwise the public kernel can fail to prove because TX limits are breached. */ private readonly previousSideEffectArrayLengths: SideEffectArrayLengths = SideEffectArrayLengths.empty(), + /** We need to track the set of class IDs used for bytecode retrieval to deduplicate and enforce limits. */ + private gotBytecodeFromClassIds: UniqueClassIds = new UniqueClassIds(), ) { this.log.debug(`Creating trace instance with startSideEffectCounter: ${startSideEffectCounter}`); this.sideEffectCounter = startSideEffectCounter; @@ -145,6 +149,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI this.previousSideEffectArrayLengths.l2ToL1Msgs + this.l2ToL1Messages.length, this.previousSideEffectArrayLengths.unencryptedLogs + this.unencryptedLogs.length, ), + this.gotBytecodeFromClassIds.fork(), ); } @@ -152,7 +157,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI // sanity check to avoid merging the same forked trace twice assert( !forkedTrace.alreadyMergedIntoParent, - 'Cannot merge a forked trace that has already been merged into its parent!', + 'Bug! Cannot merge a forked trace that has already been merged into its parent!', ); forkedTrace.alreadyMergedIntoParent = true; @@ -171,10 +176,21 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI } private mergeHints(forkedTrace: this) { + this.gotBytecodeFromClassIds.acceptAndMerge(forkedTrace.gotBytecodeFromClassIds); + this.avmCircuitHints.enqueuedCalls.items.push(...forkedTrace.avmCircuitHints.enqueuedCalls.items); this.avmCircuitHints.contractInstances.items.push(...forkedTrace.avmCircuitHints.contractInstances.items); - this.avmCircuitHints.contractBytecodeHints.items.push(...forkedTrace.avmCircuitHints.contractBytecodeHints.items); + + // merge in contract bytecode hints + // UniqueClassIds should prevent duplication + for (const [contractClassId, bytecodeHint] of forkedTrace.avmCircuitHints.contractBytecodeHints) { + assert( + !this.avmCircuitHints.contractBytecodeHints.has(contractClassId), + 'Bug preventing duplication of contract bytecode hints', + ); + this.avmCircuitHints.contractBytecodeHints.set(contractClassId, bytecodeHint); + } this.avmCircuitHints.publicDataReads.items.push(...forkedTrace.avmCircuitHints.publicDataReads.items); this.avmCircuitHints.publicDataWrites.items.push(...forkedTrace.avmCircuitHints.publicDataWrites.items); @@ -405,6 +421,14 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI lowLeafIndex: Fr = Fr.zero(), lowLeafPath: Fr[] = emptyNullifierPath(), ) { + // FIXME: The way we are hinting contract bytecodes is fundamentally broken. + // We are mapping contract class ID to a bytecode hint + // But a bytecode hint is tied to a contract INSTANCE. + // What if you encounter another contract instance with the same class ID? + // We can't hint that instance too since there is already an entry in the hints set that class ID. + // But without that instance hinted, the circuit can't prove that the called contract address + // actually corresponds to any class ID. + const membershipHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); const instance = new AvmContractInstanceHint( contractAddress, @@ -416,13 +440,57 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI contractInstance.publicKeys, membershipHint, ); - // We need to deduplicate the contract instances based on addresses - this.avmCircuitHints.contractBytecodeHints.items.push( - new AvmContractBytecodeHints(bytecode, instance, contractClass), - ); + + // Always hint the contract instance separately from the bytecode hint. + // Since the bytecode hints are keyed by class ID, we need to hint the instance separately + // since there might be multiple instances hinted for the same class ID. + this.avmCircuitHints.contractInstances.items.push(instance); this.log.debug( - `Bytecode retrieval for contract execution traced: exists=${exists}, instance=${jsonStringify(contractInstance)}`, + `Tracing contract instance for bytecode retrieval: exists=${exists}, instance=${jsonStringify(contractInstance)}`, + ); + + if (!exists) { + // this ensures there are no duplicates + this.log.debug(`Contract address ${contractAddress} does not exist. Not tracing bytecode & class ID.`); + return; + } + // We already hinted this bytecode. No need to + // Don't we still need to hint if the class ID already exists? + // Because the circuit needs to prove that the called contract address corresponds to the class ID. + // To do so, the circuit needs to know the class ID in the + if (this.gotBytecodeFromClassIds.has(contractInstance.contractClassId.toString())) { + // this ensures there are no duplicates + this.log.debug( + `Contract class id ${contractInstance.contractClassId.toString()} already exists in previous hints`, + ); + return; + } + + // If we could actually allow contract calls after the limit was reached, we would hint even if we have + // surpassed the limit of unique class IDs (still trace the failed bytecode retrieval) + // because the circuit needs to know the class ID to know when the limit is hit. + // BUT, the issue with this approach is that the sequencer could lie and say "this call was to a new class ID", + // and the circuit cannot prove that it's not true without deriving the class ID from bytecode, + // proving that it corresponds to the called contract address, and proving that the class ID wasn't already + // present/used. That would require more bytecode hashing which is exactly what this limit exists to avoid. + if (this.gotBytecodeFromClassIds.size() >= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS) { + this.log.debug( + `Bytecode retrieval failure for contract class ID ${contractInstance.contractClassId.toString()} (limit reached)`, + ); + throw new SideEffectLimitReachedError( + 'contract calls to unique class IDs', + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, + ); + } + + this.log.debug(`Tracing bytecode & contract class for bytecode retrieval: class=${jsonStringify(contractClass)}`); + this.avmCircuitHints.contractBytecodeHints.set( + contractInstance.contractClassId.toString(), + new AvmContractBytecodeHints(bytecode, instance, contractClass), ); + // After adding the bytecode hint, mark the classId as retrieved to avoid duplication. + // The above map alone isn't sufficient because we need to check the parent trace's (and its parent) as well. + this.gotBytecodeFromClassIds.add(contractInstance.contractClassId.toString()); } /** diff --git a/yarn-project/simulator/src/public/fixtures/index.ts b/yarn-project/simulator/src/public/fixtures/index.ts index af8b3a5b3cd..e14b6f8c360 100644 --- a/yarn-project/simulator/src/public/fixtures/index.ts +++ b/yarn-project/simulator/src/public/fixtures/index.ts @@ -1,4 +1,4 @@ -import { MerkleTreeId, PublicExecutionRequest, Tx } from '@aztec/circuit-types'; +import { MerkleTreeId, type MerkleTreeWriteOperations, PublicExecutionRequest, Tx } from '@aztec/circuit-types'; import { type AvmCircuitInputs, BlockHeader, @@ -14,6 +14,7 @@ import { GasSettings, GlobalVariables, MAX_L2_GAS_PER_TX_PUBLIC_PORTION, + MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, PartialPrivateTailPublicInputsForPublic, PrivateKernelTailCircuitPublicInputs, type PublicFunction, @@ -31,63 +32,54 @@ import { AztecAddress } from '@aztec/foundation/aztec-address'; import { Fr, Point } from '@aztec/foundation/fields'; import { openTmpStore } from '@aztec/kv-store/lmdb'; import { AvmTestContractArtifact } from '@aztec/noir-contracts.js/AvmTest'; -import { PublicTxSimulator, WorldStateDB } from '@aztec/simulator'; +import { + AvmEphemeralForest, + AvmSimulator, + PublicEnqueuedCallSideEffectTrace, + PublicTxSimulator, + WorldStateDB, +} from '@aztec/simulator'; import { NoopTelemetryClient } from '@aztec/telemetry-client/noop'; import { MerkleTrees } from '@aztec/world-state'; import { strict as assert } from 'assert'; +import { initContext, initExecutionEnvironment, initPersistableStateManager } from '../../avm/fixtures/index.js'; + +const TIMESTAMP = new Fr(99833); + export async function simulateAvmTestContractGenerateCircuitInputs( functionName: string, - calldata: Fr[] = [], + args: Fr[] = [], expectRevert: boolean = false, - skipContractDeployments: boolean = false, + contractDataSource = new MockedAvmTestContractDataSource(), assertionErrString?: string, ): Promise { - const sender = AztecAddress.random(); - const functionSelector = getAvmTestContractFunctionSelector(functionName); - calldata = [functionSelector.toField(), ...calldata]; - - const globalVariables = GlobalVariables.empty(); - globalVariables.gasFees = GasFees.empty(); - globalVariables.timestamp = new Fr(99833); + const globals = GlobalVariables.empty(); + globals.timestamp = TIMESTAMP; - const telemetry = new NoopTelemetryClient(); - const merkleTrees = await (await MerkleTrees.new(openTmpStore(), telemetry)).fork(); - const contractDataSource = new MockedAvmTestContractDataSource(skipContractDeployments); + const merkleTrees = await (await MerkleTrees.new(openTmpStore(), new NoopTelemetryClient())).fork(); + await contractDataSource.deployContracts(merkleTrees); const worldStateDB = new WorldStateDB(merkleTrees, contractDataSource); - const contractInstance = contractDataSource.contractInstance; - - if (!skipContractDeployments) { - const contractAddressNullifier = siloNullifier( - AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), - contractInstance.address.toField(), - ); - await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [contractAddressNullifier.toBuffer()], 0); - // other contract address used by the bulk test's GETCONTRACTINSTANCE test - const otherContractAddressNullifier = siloNullifier( - AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), - contractDataSource.otherContractInstance.address.toField(), - ); - await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [otherContractAddressNullifier.toBuffer()], 0); - } - const simulator = new PublicTxSimulator( merkleTrees, worldStateDB, new NoopTelemetryClient(), - globalVariables, + globals, /*doMerkleOperations=*/ true, ); + const sender = AztecAddress.random(); + const functionSelector = getAvmTestContractFunctionSelector(functionName); + args = [functionSelector.toField(), ...args]; const callContext = new CallContext( sender, - contractInstance.address, + contractDataSource.firstContractInstance.address, contractDataSource.fnSelector, /*isStaticCall=*/ false, ); - const executionRequest = new PublicExecutionRequest(callContext, calldata); + const executionRequest = new PublicExecutionRequest(callContext, args); const tx: Tx = createTxForPublicCall(executionRequest); @@ -108,6 +100,46 @@ export async function simulateAvmTestContractGenerateCircuitInputs( return avmCircuitInputs; } +export async function simulateAvmTestContractCall( + functionName: string, + args: Fr[] = [], + expectRevert: boolean = false, + contractDataSource = new MockedAvmTestContractDataSource(), +) { + const globals = GlobalVariables.empty(); + globals.timestamp = TIMESTAMP; + + const merkleTrees = await (await MerkleTrees.new(openTmpStore(), new NoopTelemetryClient())).fork(); + await contractDataSource.deployContracts(merkleTrees); + const worldStateDB = new WorldStateDB(merkleTrees, contractDataSource); + + const trace = new PublicEnqueuedCallSideEffectTrace(); + const ephemeralTrees = await AvmEphemeralForest.create(worldStateDB.getMerkleInterface()); + const persistableState = initPersistableStateManager({ + worldStateDB, + trace, + merkleTrees: ephemeralTrees, + doMerkleOperations: true, + }); + + const sender = AztecAddress.random(); + const functionSelector = getAvmTestContractFunctionSelector(functionName); + args = [functionSelector.toField(), ...args]; + const environment = initExecutionEnvironment({ + calldata: args, + globals, + address: contractDataSource.firstContractInstance.address, + sender, + }); + const context = initContext({ env: environment, persistableState }); + + // First we simulate (though it's not needed in this simple case). + const simulator = new AvmSimulator(context); + const results = await simulator.execute(); + + expect(results.reverted).toBe(expectRevert); +} + /** * Craft a carrier transaction for a public call for simulation by PublicTxSimulator. */ @@ -151,21 +183,45 @@ export function createTxForPublicCall( export class MockedAvmTestContractDataSource implements ContractDataSource { private fnName = 'public_dispatch'; + public fnSelector: FunctionSelector = getAvmTestContractFunctionSelector(this.fnName); private bytecode: Buffer; - public fnSelector: FunctionSelector; private publicFn: PublicFunction; - private contractClass: ContractClassPublic; - public contractInstance: ContractInstanceWithAddress; private bytecodeCommitment: Fr; + + // maps contract class ID to class + private contractClasses: Map = new Map(); + // maps contract instance address to instance + public contractInstances: Map = new Map(); + + public firstContractInstance: ContractInstanceWithAddress = SerializableContractInstance.default().withAddress( + AztecAddress.fromNumber(0), + ); + public instanceSameClassAsFirstContract: ContractInstanceWithAddress = + SerializableContractInstance.default().withAddress(AztecAddress.fromNumber(0)); public otherContractInstance: ContractInstanceWithAddress; - constructor(private noContractsDeployed: boolean = false) { + constructor(private skipContractDeployments: boolean = false) { this.bytecode = getAvmTestContractBytecode(this.fnName); this.fnSelector = getAvmTestContractFunctionSelector(this.fnName); this.publicFn = { bytecode: this.bytecode, selector: this.fnSelector }; - this.contractClass = makeContractClassPublic(0, this.publicFn); - this.contractInstance = makeContractInstanceFromClassId(this.contractClass.id); this.bytecodeCommitment = computePublicBytecodeCommitment(this.bytecode); + + // create enough unique classes to hit the limit (plus two extra) + for (let i = 0; i < MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS + 1; i++) { + const contractClass = makeContractClassPublic(/*seed=*/ i, this.publicFn); + const contractInstance = makeContractInstanceFromClassId(contractClass.id, /*seed=*/ i); + this.contractClasses.set(contractClass.id.toString(), contractClass); + this.contractInstances.set(contractInstance.address.toString(), contractInstance); + if (i === 0) { + this.firstContractInstance = contractInstance; + } + } + // a contract with the same class but different instance/address as the first contract + this.instanceSameClassAsFirstContract = makeContractInstanceFromClassId( + this.firstContractInstance.contractClassId, + /*seed=*/ 1000, + ); + // The values here should match those in `avm_simulator.test.ts` // Used for GETCONTRACTINSTANCE test this.otherContractInstance = new SerializableContractInstance({ @@ -183,6 +239,46 @@ export class MockedAvmTestContractDataSource implements ContractDataSource { }).withAddress(AztecAddress.fromNumber(0x4444)); } + async deployContracts(merkleTrees: MerkleTreeWriteOperations) { + if (!this.skipContractDeployments) { + for (const contractInstance of this.contractInstances.values()) { + const contractAddressNullifier = siloNullifier( + AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), + contractInstance.address.toField(), + ); + await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [contractAddressNullifier.toBuffer()], 0); + } + + const instanceSameClassAsFirstContractNullifier = siloNullifier( + AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), + this.instanceSameClassAsFirstContract.address.toField(), + ); + await merkleTrees.batchInsert( + MerkleTreeId.NULLIFIER_TREE, + [instanceSameClassAsFirstContractNullifier.toBuffer()], + 0, + ); + + // other contract address used by the bulk test's GETCONTRACTINSTANCE test + const otherContractAddressNullifier = siloNullifier( + AztecAddress.fromNumber(DEPLOYER_CONTRACT_ADDRESS), + this.otherContractInstance.address.toField(), + ); + await merkleTrees.batchInsert(MerkleTreeId.NULLIFIER_TREE, [otherContractAddressNullifier.toBuffer()], 0); + } + } + + public static async create( + merkleTrees: MerkleTreeWriteOperations, + skipContractDeployments: boolean = false, + ): Promise { + const dataSource = new MockedAvmTestContractDataSource(skipContractDeployments); + if (!skipContractDeployments) { + await dataSource.deployContracts(merkleTrees); + } + return dataSource; + } + getPublicFunction(_address: AztecAddress, _selector: FunctionSelector): Promise { return Promise.resolve(this.publicFn); } @@ -191,8 +287,8 @@ export class MockedAvmTestContractDataSource implements ContractDataSource { throw new Error('Method not implemented.'); } - getContractClass(_id: Fr): Promise { - return Promise.resolve(this.contractClass); + getContractClass(id: Fr): Promise { + return Promise.resolve(this.contractClasses.get(id.toString())); } getBytecodeCommitment(_id: Fr): Promise { @@ -204,11 +300,13 @@ export class MockedAvmTestContractDataSource implements ContractDataSource { } getContract(address: AztecAddress): Promise { - if (!this.noContractsDeployed) { - if (address.equals(this.contractInstance.address)) { - return Promise.resolve(this.contractInstance); - } else if (address.equals(this.otherContractInstance.address)) { + if (!this.skipContractDeployments) { + if (address.equals(this.otherContractInstance.address)) { return Promise.resolve(this.otherContractInstance); + } else if (address.equals(this.instanceSameClassAsFirstContract.address)) { + return Promise.resolve(this.instanceSameClassAsFirstContract); + } else { + return Promise.resolve(this.contractInstances.get(address.toString())); } } return Promise.resolve(undefined); diff --git a/yarn-project/simulator/src/public/side_effect_errors.ts b/yarn-project/simulator/src/public/side_effect_errors.ts index 4953c72d2bc..fbdff387806 100644 --- a/yarn-project/simulator/src/public/side_effect_errors.ts +++ b/yarn-project/simulator/src/public/side_effect_errors.ts @@ -1,6 +1,6 @@ export class SideEffectLimitReachedError extends Error { constructor(sideEffectType: string, limit: number) { - super(`Reached the limit on number of '${sideEffectType}' side effects: ${limit}`); + super(`Reached the limit (${limit}) on number of '${sideEffectType}' per tx`); this.name = 'SideEffectLimitReachedError'; } } diff --git a/yarn-project/simulator/src/public/unique_class_ids.ts b/yarn-project/simulator/src/public/unique_class_ids.ts new file mode 100644 index 00000000000..97540cfd077 --- /dev/null +++ b/yarn-project/simulator/src/public/unique_class_ids.ts @@ -0,0 +1,80 @@ +import { MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS } from '@aztec/circuits.js'; + +import { strict as assert } from 'assert'; + +/** + * A class manage a de-duplicated set of class IDs that errors if you try to add a duplicate. + * Useful for bytecode retrieval hints to avoid duplicates in parent trace & grandparent trace.... + */ +export class UniqueClassIds { + private readonly classIds: Set = new Set(); + + constructor(private readonly parent?: UniqueClassIds) {} + + /** + * Create a fork that references this one as its parent + */ + public fork() { + return new UniqueClassIds(/*parent=*/ this); + } + + /** + * Check for a class ID here or in parent's (recursively). + * + * @param classId - the contract class ID (as a string) to check + * @returns boolean: whether the class ID is here + */ + public has(classId: string): boolean { + // First try check this' classIds + let here = this.classIds.has(classId); + // Then try parent's + if (!here && this.parent) { + // Note: this will recurse to grandparent/etc until we reach top or find it + here = this.parent.has(classId); + } + return here; + } + + /** + * Get the total number of classIds + */ + public size(): number { + return this.classIds.size + (this.parent ? this.parent.size() : 0); + } + + /** + * Add a class ID (if not already present) to the set. + * + * @param classId - the contract class ID (as a string) + */ + public add(classId: string) { + assert(!this.has(classId), `Bug! Tried to add duplicate classId ${classId} to set of unique classIds.`); + if (!this.has(classId)) { + this.classIds.add(classId); + assert( + this.size() <= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, + `Bug! Surpassed limit (${MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS}) of unique contract class IDs used for bytecode retrievals.`, + ); + } + } + + /** + * Merge in another set of unique class IDs into this one, but fail on duplicates. + * + * @param incoming: other unique class IDs + */ + public acceptAndMerge(incoming: UniqueClassIds) { + for (const classId of incoming.classIds.keys()) { + assert( + !this.has(classId), + `Bug! Cannot merge classId ${classId} into set of unique classIds as it already exists.`, + ); + this.classIds.add(classId); + } + // since set() has an assertion, and size() always checks parent, this should be impossible + assert( + this.size() <= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS, + `Bug! Merging unique class Ids should never exceed the limit of ${MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS}.`, + ); + } +}