Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add limit to unique contract call #10640

Merged
merged 11 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ enum class AvmError : uint32_t {
SIDE_EFFECT_LIMIT_REACHED,
OUT_OF_GAS,
STATIC_CALL_ALTERATION,
NO_BYTECODE_FOUND,
FAILED_BYTECODE_RETRIEVAL,
};

} // namespace bb::avm_trace
23 changes: 14 additions & 9 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,16 +451,15 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
bytecode =
trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership);
} catch ([[maybe_unused]] const std::runtime_error& e) {
info("AVM enqueued call exceptionally halted. Error: No bytecode found for enqueued call");
info("AVM enqueued call exceptionally halted. Failed bytecode retrieval.");
// FIXME: properly handle case when bytecode is not found!
// For now, we add a dummy row in main trace to mutate later.
// Dummy row in main trace to mutate afterwards.
// This error was encountered before any opcodes were executed, but
// we need at least one row in the execution trace to then mutate and say "it halted and consumed all gas!"
trace_builder.op_add(0, 0, 0, 0, OpCode::ADD_8);
trace_builder.handle_exceptional_halt();
return AvmError::NO_BYTECODE_FOUND;
;
return AvmError::FAILED_BYTECODE_RETRIEVAL;
}

trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call);
Expand Down Expand Up @@ -881,7 +880,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
info("AVM CALL failed bytecode retrieval.");
error = AvmError::FAILED_BYTECODE_RETRIEVAL;
}
debug_counter_stack.push(counter);
counter = 0;
Expand All @@ -901,7 +901,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
info("AVM STATICCALL failed bytecode retrieval.");
error = AvmError::FAILED_BYTECODE_RETRIEVAL;
}
debug_counter_stack.push(counter);
counter = 0;
Expand All @@ -919,7 +920,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
} else if (is_ok(error)) {
// switch back to caller's bytecode
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
/*check_membership=*/false,
/*jumping_to_parent=*/true);
counter = debug_counter_stack.top();
debug_counter_stack.pop();
}
Expand All @@ -939,7 +941,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
} else if (is_ok(error)) {
// switch back to caller's bytecode
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
/*check_membership=*/false,
/*jumping_to_parent=*/true);
counter = debug_counter_stack.top();
debug_counter_stack.pop();
}
Expand All @@ -959,7 +962,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
} else if (is_ok(error)) {
// switch back to caller's bytecode
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
/*check_membership=*/false,
/*jumping_to_parent=*/true);
counter = debug_counter_stack.top();
debug_counter_stack.pop();
}
Expand Down Expand Up @@ -1054,7 +1058,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
// otherwise, handle exceptional halt and proceed with execution in caller/parent
// We hack it in here the logic to change contract address that we are processing
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
/*check_membership=*/false,
/*jumping_to_parent=*/true);
counter = debug_counter_stack.top();
debug_counter_stack.pop();

Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ std::string to_name(AvmError error)
return "SIDE EFFECT LIMIT REACHED";
case AvmError::OUT_OF_GAS:
return "OUT OF GAS";
case AvmError::NO_BYTECODE_FOUND:
return "NO BYTECODE FOUND";
case AvmError::FAILED_BYTECODE_RETRIEVAL:
return "FAILED BYTECODE RETRIEVAL";
default:
throw std::runtime_error("Invalid error type");
break;
Expand Down
22 changes: 20 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,27 @@ void AvmTraceBuilder::rollback_to_non_revertible_checkpoint()
merkle_tree_trace_builder.rollback_to_non_revertible_checkpoint();
}

std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership)
std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address,
bool check_membership,
bool jumping_to_parent)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// The cache contains all the unique contract class ids we have seen so far
// If this bytecode retrievals have reached the number of unique contract class IDs, can't make any more retrievals!
//
// If we could actually allow contract calls after the limit was reached, we would let you make more calls
// as long as they were to class IDs that are already here.
// BUT, the issue with this approach is that the sequencer could lie and say "this call was to a new class ID",
// and the circuit cannot prove that it's not true without deriving the class ID from bytecode,
// proving that it corresponds to the called contract address, and proving that the class ID wasn't already
// present/used. That would require more bytecode hashing which is exactly what this limit exists to avoid.
if (!jumping_to_parent && contract_class_id_cache.size() >= MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS) {
info("Limit reached for contract calls to unique class id. Limit: ",
MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS);
throw std::runtime_error("Limit reached for contract calls to unique class id.");
}

// Find the bytecode based on contract address of the public call request
const AvmContractBytecode bytecode_hint =
*std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_address](const auto& contract) {
Expand Down Expand Up @@ -196,6 +213,7 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
// Assert that the hint's exists flag matches. The flag isn't really necessary...
ASSERT(bytecode_hint.contract_instance.exists);
bytecode_membership_cache.insert(contract_address);
contract_class_id_cache.insert(bytecode_hint.contract_instance.contract_class_id);
} else {
// This was a non-membership proof!
// Enforce that the tree access membership checked a low-leaf that skips the contract address nullifier.
Expand All @@ -208,7 +226,7 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
vinfo("Found bytecode for contract address: ", contract_address);
return bytecode_hint.bytecode;
}
vinfo("Bytecode not found for contract address: ", contract_address);
info("Bytecode not found for contract address: ", contract_address);
throw std::runtime_error("Bytecode not found");
}

Expand Down
6 changes: 5 additions & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,11 @@ class AvmTraceBuilder {

void checkpoint_non_revertible_state();
void rollback_to_non_revertible_checkpoint();
std::vector<uint8_t> get_bytecode(const FF contract_address, bool check_membership = false);
std::vector<uint8_t> get_bytecode(const FF contract_address,
bool check_membership = false,
bool jumping_to_parent = false);
// Used to track the unique class ids, could also be used to cache membership checks of class ids
std::unordered_set<FF> contract_class_id_cache;
std::unordered_set<FF> bytecode_membership_cache;
void insert_private_state(const std::vector<FF>& siloed_nullifiers, const std::vector<FF>& unique_note_hashes);
void insert_private_revertible_state(const std::vector<FF>& siloed_nullifiers,
Expand Down
1 change: 1 addition & 0 deletions barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#define AVM_ACCUMULATED_DATA_LENGTH 320
#define AVM_CIRCUIT_PUBLIC_INPUTS_LENGTH 1011
#define AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS 86
#define MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS 21
#define AVM_PROOF_LENGTH_IN_FIELDS 4155
#define AVM_PUBLIC_COLUMN_MAX_SIZE 1024
#define AVM_PUBLIC_INPUTS_FLATTENED_SIZE 2915
Expand Down
2 changes: 2 additions & 0 deletions l1-contracts/src/core/libraries/ConstantsGen.sol
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ library Constants {
uint256 internal constant TUBE_PROOF_LENGTH = 459;
uint256 internal constant HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128;
uint256 internal constant CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143;
uint256 internal constant MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000;
dbanks12 marked this conversation as resolved.
Show resolved Hide resolved
uint256 internal constant MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS = 21;
uint256 internal constant MEM_TAG_FF = 0;
uint256 internal constant MEM_TAG_U1 = 1;
uint256 internal constant MEM_TAG_U8 = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ contract AvmTest {
scalar::Scalar,
};
use dep::aztec::protocol_types::{
constants::MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS,
contract_class_id::ContractClassId, storage::map::derive_storage_slot_in_map,
};
use dep::aztec::state_vars::PublicMutable;
Expand Down Expand Up @@ -596,6 +597,16 @@ contract AvmTest {
AvmTest::at(context.this_address()).add_args_return(arg_a, arg_b).call(&mut context)
}

#[public]
fn nested_call_to_add_n_times_different_addresses(addrs: [AztecAddress; MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+1]) {
for i in 0..MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS+1 {
let addr = addrs[i];
if addr != AztecAddress::empty() {
let _ = AvmTest::at(addr).add_args_return(1, 2).call(&mut context);
}
}
}

// Indirectly call_static the external call opcode to initiate a nested call to the add function
#[public]
fn nested_static_call_to_add(arg_a: Field, arg_b: Field) -> pub Field {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,12 @@ pub global CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 143; // size of a
// 21 above refers to the constant AvmFlavor::NUM_PRECOMPUTED_ENTITIES
pub global AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 2 + 21 * 4;

// Setting limits for MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS
// This value is determined by the length of the AVM trace and the MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES
// (i.e. 2^21 / MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES ==> 2^21 / 96,000 = 21
pub global MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES: u32 = MAX_PACKED_PUBLIC_BYTECODE_SIZE_IN_FIELDS * 32;
dbanks12 marked this conversation as resolved.
Show resolved Hide resolved
pub global MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS: u32 = 21;

// `AVM_PROOF_LENGTH_IN_FIELDS` must be updated when AVM circuit changes.
// To determine latest value, hover `COMPUTED_AVM_PROOF_LENGTH_IN_FIELDS`
// in barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.hpp
Expand Down
40 changes: 39 additions & 1 deletion yarn-project/bb-prover/src/avm_proving.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import {
} from '@aztec/circuits.js';
import { Fr } from '@aztec/foundation/fields';
import { createLogger } from '@aztec/foundation/log';
import { simulateAvmTestContractGenerateCircuitInputs } from '@aztec/simulator/public/fixtures';
import {
simulateAvmTestCallingTooManyContractClassesGenerateCircuitInputs,
simulateAvmTestContractGenerateCircuitInputs,
} from '@aztec/simulator/public/fixtures';

import fs from 'node:fs/promises';
import { tmpdir } from 'node:os';
Expand Down Expand Up @@ -85,6 +88,13 @@ describe('AVM WitGen, proof generation and verification', () => {
},
TIMEOUT,
);
it(
'Should prove and verify test that attempts too many calls to unique contract class ids',
async () => {
await proveAndVerifyAvmTestCallingTooManyContractClasses();
},
TIMEOUT,
);
it(
'Should prove and verify a top-level exceptional halt',
async () => {
Expand Down Expand Up @@ -175,3 +185,31 @@ async function proveAndVerifyAvmTestContract(
const verificationRes = await verifyAvmProof(bbPath, succeededRes.proofPath!, rawVkPath, logger);
expect(verificationRes.status).toBe(BB_RESULT.SUCCESS);
}

async function proveAndVerifyAvmTestCallingTooManyContractClasses() {
const avmCircuitInputs = await simulateAvmTestCallingTooManyContractClassesGenerateCircuitInputs();

const internalLogger = createLogger('bb-prover:avm-proving-test');
const logger = (msg: string, _data?: any) => internalLogger.verbose(msg);

// The paths for the barretenberg binary and the write path are hardcoded for now.
const bbPath = path.resolve('../../barretenberg/cpp/build/bin/bb');
const bbWorkingDirectory = await fs.mkdtemp(path.join(tmpdir(), 'bb-'));

// Then we prove.
const proofRes = await generateAvmProof(bbPath, bbWorkingDirectory, avmCircuitInputs, internalLogger);
if (proofRes.status === BB_RESULT.FAILURE) {
internalLogger.error(`Proof generation failed: ${proofRes.reason}`);
}
expect(proofRes.status).toEqual(BB_RESULT.SUCCESS);

// Then we test VK extraction and serialization.
const succeededRes = proofRes as BBSuccess;
const vkData = await extractAvmVkData(succeededRes.vkPath!);
VerificationKeyData.fromBuffer(vkData.toBuffer());

// Then we verify.
const rawVkPath = path.join(succeededRes.vkPath!, 'vk');
const verificationRes = await verifyAvmProof(bbPath, succeededRes.proofPath!, rawVkPath, logger);
expect(verificationRes.status).toBe(BB_RESULT.SUCCESS);
}
2 changes: 2 additions & 0 deletions yarn-project/circuits.js/src/constants.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ export const TUBE_PROOF_LENGTH = 459;
export const HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128;
export const CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143;
export const AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS = 86;
export const MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000;
export const MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS = 21;
export const AVM_PROOF_LENGTH_IN_FIELDS = 4155;
export const AVM_PUBLIC_COLUMN_MAX_SIZE = 1024;
export const AVM_PUBLIC_INPUTS_FLATTENED_SIZE = 2915;
Expand Down
1 change: 1 addition & 0 deletions yarn-project/circuits.js/src/scripts/constants.in.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ const CPP_CONSTANTS = [
'FEE_JUICE_ADDRESS',
'ROUTER_ADDRESS',
'FEE_JUICE_BALANCES_SLOT',
'MAX_PUBLIC_CALLS_TO_UNIQUE_CONTRACT_CLASS_IDS',
];

const CPP_GENERATORS: string[] = [
Expand Down
25 changes: 17 additions & 8 deletions yarn-project/circuits.js/src/structs/avm/avm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ export class AvmExecutionHints {
public readonly enqueuedCalls: Vector<AvmEnqueuedCallHint>;

public readonly contractInstances: Vector<AvmContractInstanceHint>;
public readonly contractBytecodeHints: Vector<AvmContractBytecodeHints>;

public readonly publicDataReads: Vector<AvmPublicDataReadTreeHint>;
public readonly publicDataWrites: Vector<AvmPublicDataWriteTreeHint>;
Expand All @@ -696,7 +695,8 @@ export class AvmExecutionHints {
constructor(
enqueuedCalls: AvmEnqueuedCallHint[],
contractInstances: AvmContractInstanceHint[],
contractBytecodeHints: AvmContractBytecodeHints[],
// string here is the contract class id
public contractBytecodeHints: Map<string, AvmContractBytecodeHints>,
publicDataReads: AvmPublicDataReadTreeHint[],
publicDataWrites: AvmPublicDataWriteTreeHint[],
nullifierReads: AvmNullifierReadTreeHint[],
Expand All @@ -707,7 +707,6 @@ export class AvmExecutionHints {
) {
this.enqueuedCalls = new Vector(enqueuedCalls);
this.contractInstances = new Vector(contractInstances);
this.contractBytecodeHints = new Vector(contractBytecodeHints);
this.publicDataReads = new Vector(publicDataReads);
this.publicDataWrites = new Vector(publicDataWrites);
this.nullifierReads = new Vector(nullifierReads);
Expand All @@ -722,7 +721,7 @@ export class AvmExecutionHints {
* @returns an empty instance.
*/
static empty() {
return new AvmExecutionHints([], [], [], [], [], [], [], [], [], []);
return new AvmExecutionHints([], [], new Map(), [], [], [], [], [], [], []);
}

/**
Expand All @@ -749,7 +748,7 @@ export class AvmExecutionHints {
return (
this.enqueuedCalls.items.length == 0 &&
this.contractInstances.items.length == 0 &&
this.contractBytecodeHints.items.length == 0 &&
this.contractBytecodeHints.size == 0 &&
this.publicDataReads.items.length == 0 &&
this.publicDataWrites.items.length == 0 &&
this.nullifierReads.items.length == 0 &&
Expand All @@ -769,7 +768,7 @@ export class AvmExecutionHints {
return new AvmExecutionHints(
fields.enqueuedCalls.items,
fields.contractInstances.items,
fields.contractBytecodeHints.items,
fields.contractBytecodeHints,
fields.publicDataReads.items,
fields.publicDataWrites.items,
fields.nullifierReads.items,
Expand All @@ -789,7 +788,7 @@ export class AvmExecutionHints {
return [
fields.enqueuedCalls,
fields.contractInstances,
fields.contractBytecodeHints,
new Vector(Array.from(fields.contractBytecodeHints.values())),
fields.publicDataReads,
fields.publicDataWrites,
fields.nullifierReads,
Expand All @@ -807,10 +806,20 @@ export class AvmExecutionHints {
*/
static fromBuffer(buff: Buffer | BufferReader): AvmExecutionHints {
const reader = BufferReader.asReader(buff);

const readMap = (r: BufferReader) => {
const map = new Map();
const values = r.readVector(AvmContractBytecodeHints);
for (const value of values) {
map.set(value.contractInstanceHint.address.toString(), value);
}
return map;
};

return new AvmExecutionHints(
reader.readVector(AvmEnqueuedCallHint),
reader.readVector(AvmContractInstanceHint),
reader.readVector(AvmContractBytecodeHints),
readMap(reader),
reader.readVector(AvmPublicDataReadTreeHint),
reader.readVector(AvmPublicDataWriteTreeHint),
reader.readVector(AvmNullifierReadTreeHint),
Expand Down
Loading
Loading