From 2c9fe44e7d049c6ccc56d0c63e625f36b06262a5 Mon Sep 17 00:00:00 2001 From: fcarreiro Date: Thu, 24 Oct 2024 15:51:37 +0000 Subject: [PATCH] feat(avm)!: revert/rethrow oracle --- avm-transpiler/src/transpile.rs | 78 +++++++++++++------ .../aztec/src/context/public_context.nr | 15 ++++ .../contracts/avm_test_contract/src/main.nr | 6 ++ noir/noir-repo/acvm-repo/acvm/tests/solver.rs | 14 +++- .../acvm-repo/brillig/src/opcodes.rs | 2 +- .../noir-repo/acvm-repo/brillig_vm/src/lib.rs | 19 ++++- .../noirc_evaluator/src/brillig/brillig_ir.rs | 12 ++- .../brillig_ir/codegen_control_flow.rs | 22 ++++-- .../src/brillig/brillig_ir/debug_show.rs | 2 +- .../src/brillig/brillig_ir/instructions.rs | 4 +- .../simulator/src/avm/avm_simulator.test.ts | 10 +++ .../src/avm/opcodes/external_calls.test.ts | 9 ++- .../src/avm/opcodes/external_calls.ts | 14 ++-- 13 files changed, 154 insertions(+), 53 deletions(-) diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index b5251154a725..0cfa1889a4a6 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -316,29 +316,11 @@ pub fn brillig_to_avm( }); } BrilligOpcode::Trap { revert_data } => { - let bits_needed = - *[bits_needed_for(&revert_data.pointer), bits_needed_for(&revert_data.size)] - .iter() - .max() - .unwrap(); - let avm_opcode = match bits_needed { - 8 => AvmOpcode::REVERT_8, - 16 => AvmOpcode::REVERT_16, - _ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed), - }; - avm_instrs.push(AvmInstruction { - opcode: avm_opcode, - indirect: Some( - AddressingModeBuilder::default() - .indirect_operand(&revert_data.pointer) - .build(), - ), - operands: vec![ - make_operand(bits_needed, &revert_data.pointer.to_usize()), - make_operand(bits_needed, &revert_data.size), - ], - ..Default::default() - }); + generate_revert_instruction( + &mut avm_instrs, + &revert_data.pointer, + &revert_data.size, + ); } BrilligOpcode::Cast { destination, source, bit_size } => { handle_cast(&mut avm_instrs, source, destination, *bit_size); @@ -418,6 +400,7 @@ fn handle_foreign_call( } "avmOpcodeCalldataCopy" => handle_calldata_copy(avm_instrs, destinations, inputs), "avmOpcodeReturn" => handle_return(avm_instrs, destinations, inputs), + "avmOpcodeRevert" => handle_revert(avm_instrs, destinations, inputs), "avmOpcodeStorageRead" => handle_storage_read(avm_instrs, destinations, inputs), "avmOpcodeStorageWrite" => handle_storage_write(avm_instrs, destinations, inputs), "debugLog" => handle_debug_log(avm_instrs, destinations, inputs), @@ -929,6 +912,35 @@ fn generate_cast_instruction( } } +/// Generates an AVM REVERT instruction. +fn generate_revert_instruction( + avm_instrs: &mut Vec, + revert_data_pointer: &MemoryAddress, + revert_data_size_offset: &MemoryAddress, +) { + let bits_needed = + *[revert_data_pointer, revert_data_size_offset].map(bits_needed_for).iter().max().unwrap(); + let avm_opcode = match bits_needed { + 8 => AvmOpcode::REVERT_8, + 16 => AvmOpcode::REVERT_16, + _ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed), + }; + avm_instrs.push(AvmInstruction { + opcode: avm_opcode, + indirect: Some( + AddressingModeBuilder::default() + .indirect_operand(revert_data_pointer) + .direct_operand(revert_data_size_offset) + .build(), + ), + operands: vec![ + make_operand(bits_needed, &revert_data_pointer.to_usize()), + make_operand(bits_needed, &revert_data_size_offset.to_usize()), + ], + ..Default::default() + }); +} + /// Generates an AVM MOV instruction. fn generate_mov_instruction( indirect: Option, @@ -1214,7 +1226,6 @@ fn handle_return( assert!(inputs.len() == 1); assert!(destinations.len() == 0); - // First arg is the size, which is ignored because it's redundant. let (return_data_offset, return_data_size) = match inputs[0] { ValueOrArray::HeapArray(HeapArray { pointer, size }) => (pointer, size as u32), _ => panic!("Return instruction's args input should be a HeapArray"), @@ -1233,6 +1244,25 @@ fn handle_return( }); } +// #[oracle(avmOpcodeRevert)] +// unconstrained fn revert_opcode(revertdata: [Field]) {} +fn handle_revert( + avm_instrs: &mut Vec, + destinations: &Vec, + inputs: &Vec, +) { + assert!(inputs.len() == 2); + assert!(destinations.len() == 0); + + // First arg is the size, which is ignored because it's redundant. + let (revert_data_offset, revert_data_size_offset) = match inputs[1] { + ValueOrArray::HeapVector(HeapVector { pointer, size }) => (pointer, size), + _ => panic!("Revert instruction's args input should be a HeapVector"), + }; + + generate_revert_instruction(avm_instrs, &revert_data_offset, &revert_data_size_offset); +} + /// Emit a storage write opcode /// The current implementation writes an array of values into storage ( contiguous slots in memory ) fn handle_storage_write( diff --git a/noir-projects/aztec-nr/aztec/src/context/public_context.nr b/noir-projects/aztec-nr/aztec/src/context/public_context.nr index 8c09a49c8c0e..d68ae89b1a1f 100644 --- a/noir-projects/aztec-nr/aztec/src/context/public_context.nr +++ b/noir-projects/aztec-nr/aztec/src/context/public_context.nr @@ -293,6 +293,14 @@ unconstrained fn avm_return(returndata: [Field; N]) { return_opcode(returndata) } +// This opcode reverts using the exact data given. In general it should only be used +// to do rethrows, where the revert data is the same as the original revert data. +// For normal reverts, use Noir's `assert` which, on top of reverting, will also add +// an error selector to the revert data. +unconstrained fn avm_revert(revertdata: [Field]) { + revert_opcode(revertdata) +} + unconstrained fn storage_read(storage_slot: Field) -> Field { storage_read_opcode(storage_slot) } @@ -378,6 +386,13 @@ unconstrained fn calldata_copy_opcode(cdoffset: u32, copy_size: u32) #[oracle(avmOpcodeReturn)] unconstrained fn return_opcode(returndata: [Field; N]) {} +// This opcode reverts using the exact data given. In general it should only be used +// to do rethrows, where the revert data is the same as the original revert data. +// For normal reverts, use Noir's `assert` which, on top of reverting, will also add +// an error selector to the revert data. +#[oracle(avmOpcodeRevert)] +unconstrained fn revert_opcode(revertdata: [Field]) {} + #[oracle(avmOpcodeCall)] unconstrained fn call_opcode( gas: [Field; 2], // gas allocation: [l2_gas, da_gas] diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index ff31737f6abc..f3974edd6e81 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -252,6 +252,12 @@ contract AvmTest { [4, 5, 6] // Should not get here. } + #[public] + fn revert_oracle() -> [Field; 3] { + dep::aztec::context::public_context::avm_revert([1, 2, 3]); + [4, 5, 6] // Should not get here. + } + /************************************************************************ * Hashing functions ************************************************************************/ diff --git a/noir/noir-repo/acvm-repo/acvm/tests/solver.rs b/noir/noir-repo/acvm-repo/acvm/tests/solver.rs index efa8de289e59..2fcf452845bb 100644 --- a/noir/noir-repo/acvm-repo/acvm/tests/solver.rs +++ b/noir/noir-repo/acvm-repo/acvm/tests/solver.rs @@ -1,7 +1,7 @@ use std::collections::{BTreeMap, HashSet}; use std::sync::Arc; -use acir::brillig::{BitSize, IntegerBitSize}; +use acir::brillig::{BitSize, HeapVector, IntegerBitSize}; use acir::{ acir_field::GenericFieldElement, brillig::{BinaryFieldOp, HeapArray, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray}, @@ -667,7 +667,12 @@ fn unsatisfied_opcode_resolved_brillig() { let jmp_if_opcode = BrilligOpcode::JumpIf { condition: MemoryAddress::direct(2), location: location_of_stop }; - let trap_opcode = BrilligOpcode::Trap { revert_data: HeapArray::default() }; + let trap_opcode = BrilligOpcode::Trap { + revert_data: HeapVector { + pointer: MemoryAddress::direct(0), + size: MemoryAddress::direct(3), + }, + }; let stop_opcode = BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 0 }; let brillig_bytecode = BrilligBytecode { @@ -682,6 +687,11 @@ fn unsatisfied_opcode_resolved_brillig() { bit_size: BitSize::Integer(IntegerBitSize::U32), value: FieldElement::from(0u64), }, + BrilligOpcode::Const { + destination: MemoryAddress::direct(3), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(0u64), + }, calldata_copy_opcode, equal_opcode, jmp_if_opcode, diff --git a/noir/noir-repo/acvm-repo/brillig/src/opcodes.rs b/noir/noir-repo/acvm-repo/brillig/src/opcodes.rs index 69ca9ed379a0..0d87c5b9410e 100644 --- a/noir/noir-repo/acvm-repo/brillig/src/opcodes.rs +++ b/noir/noir-repo/acvm-repo/brillig/src/opcodes.rs @@ -305,7 +305,7 @@ pub enum BrilligOpcode { BlackBox(BlackBoxOp), /// Used to denote execution failure, returning data after the offset Trap { - revert_data: HeapArray, + revert_data: HeapVector, }, /// Stop execution, returning data after the offset Stop { diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs index 1e5ad84eb8fb..07bde85724d1 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs @@ -314,10 +314,11 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { self.increment_program_counter() } Opcode::Trap { revert_data } => { - if revert_data.size > 0 { + let revert_data_size = self.memory.read(revert_data.size).to_usize(); + if revert_data_size > 0 { self.trap( self.memory.read_ref(revert_data.pointer).unwrap_direct(), - revert_data.size, + revert_data_size, ) } else { self.trap(0, 0) @@ -904,8 +905,18 @@ mod tests { size_address: MemoryAddress::direct(0), offset_address: MemoryAddress::direct(1), }, - Opcode::Jump { location: 5 }, - Opcode::Trap { revert_data: HeapArray::default() }, + Opcode::Jump { location: 6 }, + Opcode::Const { + destination: MemoryAddress::direct(0), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(0u64), + }, + Opcode::Trap { + revert_data: HeapVector { + pointer: MemoryAddress::direct(0), + size: MemoryAddress::direct(0), + }, + }, Opcode::BinaryFieldOp { op: BinaryFieldOp::Equals, lhs: MemoryAddress::direct(0), diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 4964ff27f608..4ef5556299ba 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -289,7 +289,17 @@ pub(crate) mod tests { // uses unresolved jumps which requires a block to be constructed in SSA and // we don't need this for Brillig IR tests context.push_opcode(BrilligOpcode::JumpIf { condition: r_equality, location: 8 }); - context.push_opcode(BrilligOpcode::Trap { revert_data: HeapArray::default() }); + context.push_opcode(BrilligOpcode::Const { + destination: MemoryAddress::direct(0), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(0u64), + }); + context.push_opcode(BrilligOpcode::Trap { + revert_data: HeapVector { + pointer: MemoryAddress::direct(0), + size: MemoryAddress::direct(0), + }, + }); context.stop_instruction(); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_control_flow.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_control_flow.rs index c305d8c78f31..6935ebb0f530 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_control_flow.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_control_flow.rs @@ -1,5 +1,5 @@ use acvm::{ - acir::brillig::{HeapArray, MemoryAddress}, + acir::brillig::{HeapVector, MemoryAddress}, AcirField, }; @@ -157,12 +157,12 @@ impl BrilligContext< assert!(condition.bit_size == 1); self.codegen_if_not(condition.address, |ctx| { - let revert_data = HeapArray { - pointer: ctx.allocate_register(), - // + 1 due to the revert data id being the first item returned - size: Self::flattened_tuple_size(&revert_data_types) + 1, - }; - ctx.codegen_allocate_immediate_mem(revert_data.pointer, revert_data.size); + // + 1 due to the revert data id being the first item returned + let revert_data_size = Self::flattened_tuple_size(&revert_data_types) + 1; + let revert_data_size_var = ctx.make_usize_constant_instruction(revert_data_size.into()); + let revert_data = + HeapVector { pointer: ctx.allocate_register(), size: revert_data_size_var.address }; + ctx.codegen_allocate_immediate_mem(revert_data.pointer, revert_data_size); let current_revert_data_pointer = ctx.allocate_register(); ctx.mov_instruction(current_revert_data_pointer, revert_data.pointer); @@ -208,6 +208,7 @@ impl BrilligContext< ); } ctx.trap_instruction(revert_data); + ctx.deallocate_single_addr(revert_data_size_var); ctx.deallocate_register(revert_data.pointer); ctx.deallocate_register(current_revert_data_pointer); }); @@ -223,7 +224,12 @@ impl BrilligContext< assert!(condition.bit_size == 1); self.codegen_if_not(condition.address, |ctx| { - ctx.trap_instruction(HeapArray::default()); + let revert_data_size_var = ctx.make_usize_constant_instruction(F::zero()); + ctx.trap_instruction(HeapVector { + pointer: MemoryAddress::direct(0), + size: revert_data_size_var.address, + }); + ctx.deallocate_single_addr(revert_data_size_var); if let Some(assert_message) = assert_message { ctx.obj.add_assert_message_to_last_opcode(assert_message); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 2a46a04cc91a..4e82a0d3af57 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -117,7 +117,7 @@ impl DebugShow { } /// Emits a `trap` instruction. - pub(crate) fn trap_instruction(&self, revert_data: HeapArray) { + pub(crate) fn trap_instruction(&self, revert_data: HeapVector) { debug_println!(self.enable_debug_trace, " TRAP {}", revert_data); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs index 1ac672687f34..5f0aedb9c5e4 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs @@ -1,7 +1,7 @@ use acvm::{ acir::{ brillig::{ - BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapArray, HeapValueType, + BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapValueType, HeapVector, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray, }, AcirField, @@ -425,7 +425,7 @@ impl BrilligContext< self.deallocate_single_addr(offset_var); } - pub(super) fn trap_instruction(&mut self, revert_data: HeapArray) { + pub(super) fn trap_instruction(&mut self, revert_data: HeapVector) { self.debug_show.trap_instruction(revert_data); self.push_opcode(BrilligOpcode::Trap { revert_data }); diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index a4c39654a5a4..77714c3bd7c6 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -190,6 +190,16 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.output).toEqual([new Fr(1), new Fr(2), new Fr(3)]); }); + it('Should handle revert oracle', async () => { + const context = initContext(); + + const bytecode = getAvmTestContractBytecode('revert_oracle'); + const results = await new AvmSimulator(context).executeBytecode(bytecode); + + expect(results.reverted).toBe(true); + expect(results.output).toEqual([new Fr(1), new Fr(2), new Fr(3)]); + }); + it('ec_add should not revert', async () => { // This test performs the same doubling as in elliptic_curve_add_and_double // But the optimizer is not able to optimize out the addition diff --git a/yarn-project/simulator/src/avm/opcodes/external_calls.test.ts b/yarn-project/simulator/src/avm/opcodes/external_calls.test.ts index 0f74ab3a6eb8..839509d75e2b 100644 --- a/yarn-project/simulator/src/avm/opcodes/external_calls.test.ts +++ b/yarn-project/simulator/src/avm/opcodes/external_calls.test.ts @@ -290,9 +290,9 @@ describe('External Calls', () => { Opcode.REVERT_16, // opcode 0x01, // indirect ...Buffer.from('1234', 'hex'), // returnOffset - ...Buffer.from('a234', 'hex'), // retSize + ...Buffer.from('a234', 'hex'), // retSizeOffset ]); - const inst = new Revert(/*indirect=*/ 0x01, /*returnOffset=*/ 0x1234, /*retSize=*/ 0xa234).as( + const inst = new Revert(/*indirect=*/ 0x01, /*returnOffset=*/ 0x1234, /*retSizeOffset=*/ 0xa234).as( Opcode.REVERT_16, Revert.wireFormat16, ); @@ -305,9 +305,10 @@ describe('External Calls', () => { const returnData = [...'assert message'].flatMap(c => new Field(c.charCodeAt(0))); returnData.unshift(new Field(0n)); // Prepend an error selector - context.machineState.memory.setSlice(0, returnData); + context.machineState.memory.set(0, new Uint32(returnData.length)); + context.machineState.memory.setSlice(10, returnData); - const instruction = new Revert(/*indirect=*/ 0, /*returnOffset=*/ 0, returnData.length); + const instruction = new Revert(/*indirect=*/ 0, /*returnOffset=*/ 10, /*retSizeOffset=*/ 0); await instruction.execute(context); expect(context.machineState.getHalted()).toBe(true); diff --git a/yarn-project/simulator/src/avm/opcodes/external_calls.ts b/yarn-project/simulator/src/avm/opcodes/external_calls.ts index 2507885f2e0f..df8950e71532 100644 --- a/yarn-project/simulator/src/avm/opcodes/external_calls.ts +++ b/yarn-project/simulator/src/avm/opcodes/external_calls.ts @@ -204,22 +204,24 @@ export class Revert extends Instruction { OperandType.UINT16, ]; - constructor(private indirect: number, private returnOffset: number, private retSize: number) { + constructor(private indirect: number, private returnOffset: number, private retSizeOffset: number) { super(); } public async execute(context: AvmContext): Promise { const memory = context.machineState.memory.track(this.type); - context.machineState.consumeGas(this.gasCost(this.retSize)); - const operands = [this.returnOffset]; + const operands = [this.returnOffset, this.retSizeOffset]; const addressing = Addressing.fromWire(this.indirect, operands.length); - const [returnOffset] = addressing.resolve(operands, memory); + const [returnOffset, retSizeOffset] = addressing.resolve(operands, memory); - const output = memory.getSlice(returnOffset, this.retSize).map(word => word.toFr()); + memory.checkTag(TypeTag.UINT32, retSizeOffset); + const retSize = memory.get(retSizeOffset).toNumber(); + context.machineState.consumeGas(this.gasCost(retSize)); + const output = memory.getSlice(returnOffset, retSize).map(word => word.toFr()); context.machineState.revert(output); - memory.assert({ reads: this.retSize, addressing }); + memory.assert({ reads: retSize + 1, addressing }); } }