Skip to content

Commit

Permalink
feat(simulator): Fetch return values at circuit execution (#5642)
Browse files Browse the repository at this point in the history
We were deserializing a `Program` twice in the simulator. Once to
execute the circuit and again to fetch the return witness. Every call to
`executeCircuitWithBlackBoxSolver` is followed by a call to
`getReturnWitness` in the simulator. We should just pass the return
witness right away as to not cross the WASM boundary a second time and
as to avoid a second deserialization.

I settled on a new ACVM method rather than using abi decoding as the
return witnesses are stripped from the ABI. We can have a method that
returns both the fully solved witness and the return witness based upon
the circuit. This allows us to both avoid storing duplicate return
witness information and an unnecessary deserialization.
  • Loading branch information
vezenovm authored Apr 10, 2024
1 parent 25cc70d commit 413a4e0
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 29 deletions.
56 changes: 55 additions & 1 deletion noir/noir-repo/acvm-repo/acvm_js/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use wasm_bindgen::prelude::wasm_bindgen;

use crate::{
foreign_call::{resolve_brillig, ForeignCallHandler},
JsExecutionError, JsWitnessMap, JsWitnessStack,
public_witness::extract_indices,
JsExecutionError, JsSolvedAndReturnWitness, JsWitnessMap, JsWitnessStack,
};

#[wasm_bindgen]
Expand Down Expand Up @@ -58,6 +59,44 @@ pub async fn execute_circuit(
Ok(witness_map.into())
}

/// Executes an ACIR circuit to generate the solved witness from the initial witness.
/// This method also extracts the public return values from the solved witness into its own return witness.
///
/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver.
/// @param {Uint8Array} circuit - A serialized representation of an ACIR circuit
/// @param {WitnessMap} initial_witness - The initial witness map defining all of the inputs to `circuit`..
/// @param {ForeignCallHandler} foreign_call_handler - A callback to process any foreign calls from the circuit.
/// @returns {SolvedAndReturnWitness} The solved witness calculated by executing the circuit on the provided inputs, as well as the return witness indices as specified by the circuit.
#[wasm_bindgen(js_name = executeCircuitWithReturnWitness, skip_jsdoc)]
pub async fn execute_circuit_with_return_witness(
solver: &WasmBlackBoxFunctionSolver,
program: Vec<u8>,
initial_witness: JsWitnessMap,
foreign_call_handler: ForeignCallHandler,
) -> Result<JsSolvedAndReturnWitness, Error> {
console_error_panic_hook::set_once();

let program: Program = Program::deserialize_program(&program)
.map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?;

let mut witness_stack = execute_program_with_native_program_and_return(
solver,
&program,
initial_witness,
&foreign_call_handler,
)
.await?;
let solved_witness =
witness_stack.pop().expect("Should have at least one witness on the stack").witness;

let main_circuit = &program.functions[0];
let return_witness =
extract_indices(&solved_witness, main_circuit.return_values.0.iter().copied().collect())
.map_err(|err| JsExecutionError::new(err, None))?;

Ok((solved_witness, return_witness).into())
}

/// Executes an ACIR circuit to generate the solved witness from the initial witness.
///
/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver.
Expand Down Expand Up @@ -127,6 +166,21 @@ async fn execute_program_with_native_type_return(
let program: Program = Program::deserialize_program(&program)
.map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?;

execute_program_with_native_program_and_return(
solver,
&program,
initial_witness,
foreign_call_executor,
)
.await
}

async fn execute_program_with_native_program_and_return(
solver: &WasmBlackBoxFunctionSolver,
program: &Program,
initial_witness: JsWitnessMap,
foreign_call_executor: &ForeignCallHandler,
) -> Result<WitnessStack, Error> {
let executor = ProgramExecutor::new(&program.functions, &solver.0, foreign_call_executor);
let witness_stack = executor.execute(initial_witness.into()).await?;

Expand Down
38 changes: 37 additions & 1 deletion noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@ use acvm::{
acir::native_types::{Witness, WitnessMap},
FieldElement,
};
use js_sys::{JsString, Map};
use js_sys::{JsString, Map, Object};
use wasm_bindgen::prelude::{wasm_bindgen, JsValue};

#[wasm_bindgen(typescript_custom_section)]
const WITNESS_MAP: &'static str = r#"
// Map from witness index to hex string value of witness.
export type WitnessMap = Map<number, string>;
/**
* An execution result containing two witnesses.
* 1. The full solved witness of the execution.
* 2. The return witness which contains the given public return values within the full witness.
*/
export type SolvedAndReturnWitness = {
solvedWitness: WitnessMap;
returnWitness: WitnessMap;
}
"#;

// WitnessMap
Expand All @@ -21,6 +31,12 @@ extern "C" {
#[wasm_bindgen(constructor, js_class = "Map")]
pub fn new() -> JsWitnessMap;

#[wasm_bindgen(extends = Object, js_name = "SolvedAndReturnWitness", typescript_type = "SolvedAndReturnWitness")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub type JsSolvedAndReturnWitness;

#[wasm_bindgen(constructor, js_class = "Object")]
pub fn new() -> JsSolvedAndReturnWitness;
}

impl Default for JsWitnessMap {
Expand All @@ -29,6 +45,12 @@ impl Default for JsWitnessMap {
}
}

impl Default for JsSolvedAndReturnWitness {
fn default() -> Self {
Self::new()
}
}

impl From<WitnessMap> for JsWitnessMap {
fn from(witness_map: WitnessMap) -> Self {
let js_map = JsWitnessMap::new();
Expand All @@ -54,6 +76,20 @@ impl From<JsWitnessMap> for WitnessMap {
}
}

impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness {
fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self {
let js_solved_witness = JsWitnessMap::from(witness_maps.0);
let js_return_witness = JsWitnessMap::from(witness_maps.1);

let entry_map = Map::new();
entry_map.set(&JsValue::from_str("solvedWitness"), &js_solved_witness);
entry_map.set(&JsValue::from_str("returnWitness"), &js_return_witness);

let solved_and_return_witness = Object::from_entries(&entry_map).unwrap();
JsSolvedAndReturnWitness { obj: solved_and_return_witness }
}
}

pub(crate) fn js_value_to_field_element(js_value: JsValue) -> Result<FieldElement, JsString> {
let hex_str = js_value.as_string().ok_or("failed to parse field element from non-string")?;

Expand Down
3 changes: 2 additions & 1 deletion noir/noir-repo/acvm-repo/acvm_js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ pub use compression::{
};
pub use execute::{
create_black_box_solver, execute_circuit, execute_circuit_with_black_box_solver,
execute_program, execute_program_with_black_box_solver,
execute_circuit_with_return_witness, execute_program, execute_program_with_black_box_solver,
};
pub use js_execution_error::JsExecutionError;
pub use js_witness_map::JsSolvedAndReturnWitness;
pub use js_witness_map::JsWitnessMap;
pub use js_witness_stack::JsWitnessStack;
pub use logging::init_log_level;
Expand Down
9 changes: 6 additions & 3 deletions noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use wasm_bindgen::prelude::wasm_bindgen;

use crate::JsWitnessMap;

fn extract_indices(witness_map: &WitnessMap, indices: Vec<Witness>) -> Result<WitnessMap, String> {
pub(crate) fn extract_indices(
witness_map: &WitnessMap,
indices: Vec<Witness>,
) -> Result<WitnessMap, String> {
let mut extracted_witness_map = WitnessMap::new();
for witness in indices {
let witness_value = witness_map.get(&witness).ok_or(format!(
Expand Down Expand Up @@ -44,7 +47,7 @@ pub fn get_return_witness(
let witness_map = WitnessMap::from(witness_map);

let return_witness =
extract_indices(&witness_map, circuit.return_values.0.clone().into_iter().collect())?;
extract_indices(&witness_map, circuit.return_values.0.iter().copied().collect())?;

Ok(JsWitnessMap::from(return_witness))
}
Expand All @@ -71,7 +74,7 @@ pub fn get_public_parameters_witness(
let witness_map = WitnessMap::from(solved_witness);

let public_params_witness =
extract_indices(&witness_map, circuit.public_parameters.0.clone().into_iter().collect())?;
extract_indices(&witness_map, circuit.public_parameters.0.iter().copied().collect())?;

Ok(JsWitnessMap::from(public_params_witness))
}
Expand Down
11 changes: 7 additions & 4 deletions yarn-project/simulator/src/acvm/acvm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
type ForeignCallInput,
type ForeignCallOutput,
type WasmBlackBoxFunctionSolver,
executeCircuitWithBlackBoxSolver,
executeCircuitWithReturnWitness,
} from '@noir-lang/acvm_js';

import { traverseCauseChain } from '../common/errors.js';
Expand All @@ -27,9 +27,12 @@ type ACIRCallback = Record<
*/
export interface ACIRExecutionResult {
/**
* The partial witness of the execution.
* An execution result contains two witnesses.
* 1. The partial witness of the execution.
* 2. The return witness which contains the given public return values within the full witness.
*/
partialWitness: ACVMWitness;
returnWitness: ACVMWitness;
}

/**
Expand Down Expand Up @@ -89,7 +92,7 @@ export async function acvm(
): Promise<ACIRExecutionResult> {
const logger = createDebugLogger('aztec:simulator:acvm');

const partialWitness = await executeCircuitWithBlackBoxSolver(
const solvedAndReturnWitness = await executeCircuitWithReturnWitness(
solver,
acir,
initialWitness,
Expand Down Expand Up @@ -127,7 +130,7 @@ export async function acvm(
throw err;
});

return { partialWitness };
return { partialWitness: solvedAndReturnWitness.solvedWitness, returnWitness: solvedAndReturnWitness.returnWitness };
}

/**
Expand Down
14 changes: 5 additions & 9 deletions yarn-project/simulator/src/acvm/deserialize.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import { Fr } from '@aztec/foundation/fields';

import { getReturnWitness } from '@noir-lang/acvm_js';

import { type ACVMField, type ACVMWitness } from './acvm_types.js';

/**
Expand Down Expand Up @@ -32,13 +30,11 @@ export function frToBoolean(fr: Fr): boolean {
}

/**
* Extracts the return fields of a given partial witness.
* @param acir - The bytecode of the function.
* @param partialWitness - The witness to extract from.
* Transforms a witness map to its field elements.
* @param witness - The witness to extract from.
* @returns The return values.
*/
export function extractReturnWitness(acir: Buffer, partialWitness: ACVMWitness): Fr[] {
const returnWitness = getReturnWitness(acir, partialWitness);
const sortedKeys = [...returnWitness.keys()].sort((a, b) => a - b);
return sortedKeys.map(key => returnWitness.get(key)!).map(fromACVMField);
export function witnessMapToFields(witness: ACVMWitness): Fr[] {
const sortedKeys = [...witness.keys()].sort((a, b) => a - b);
return sortedKeys.map(key => witness.get(key)!).map(fromACVMField);
}
8 changes: 4 additions & 4 deletions yarn-project/simulator/src/client/private_execution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';

import { extractReturnWitness } from '../acvm/deserialize.js';
import { witnessMapToFields } from '../acvm/deserialize.js';
import { Oracle, acvm, extractCallStack } from '../acvm/index.js';
import { ExecutionError } from '../common/errors.js';
import { type ClientExecutionContext } from './client_execution_context.js';
Expand All @@ -26,7 +26,7 @@ export async function executePrivateFunction(
const acir = artifact.bytecode;
const initialWitness = context.getInitialWitness(artifact);
const acvmCallback = new Oracle(context);
const { partialWitness } = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch(
const acirExecutionResult = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch(
(err: Error) => {
throw new ExecutionError(
err.message,
Expand All @@ -39,8 +39,8 @@ export async function executePrivateFunction(
);
},
);

const returnWitness = extractReturnWitness(acir, partialWitness);
const partialWitness = acirExecutionResult.partialWitness;
const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness);
const publicInputs = PrivateCircuitPublicInputs.fromFields(returnWitness);

const encryptedLogs = context.getEncryptedLogs();
Expand Down
7 changes: 4 additions & 3 deletions yarn-project/simulator/src/client/unconstrained_execution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address';
import { type Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';

import { extractReturnWitness } from '../acvm/deserialize.js';
import { witnessMapToFields } from '../acvm/deserialize.js';
import { Oracle, acvm, extractCallStack, toACVMWitness } from '../acvm/index.js';
import { ExecutionError } from '../common/errors.js';
import { AcirSimulator } from './simulator.js';
Expand All @@ -27,7 +27,7 @@ export async function executeUnconstrainedFunction(

const acir = artifact.bytecode;
const initialWitness = toACVMWitness(0, args);
const { partialWitness } = await acvm(
const acirExecutionResult = await acvm(
await AcirSimulator.getSolver(),
acir,
initialWitness,
Expand All @@ -44,6 +44,7 @@ export async function executeUnconstrainedFunction(
);
});

return decodeReturnValues(artifact, extractReturnWitness(acir, partialWitness));
const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness);
return decodeReturnValues(artifact, returnWitness);
}
// docs:end:execute_unconstrained_function
8 changes: 5 additions & 3 deletions yarn-project/simulator/src/public/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { spawn } from 'child_process';
import fs from 'fs/promises';
import path from 'path';

import { Oracle, acvm, extractCallStack, extractReturnWitness } from '../acvm/index.js';
import { Oracle, acvm, extractCallStack, witnessMapToFields } from '../acvm/index.js';
import { AvmContext } from '../avm/avm_context.js';
import { AvmMachineState } from '../avm/avm_machine_state.js';
import { AvmSimulator } from '../avm/avm_simulator.js';
Expand Down Expand Up @@ -97,11 +97,12 @@ async function executePublicFunctionAcvm(
const initialWitness = context.getInitialWitness();
const acvmCallback = new Oracle(context);

const { partialWitness, reverted, revertReason } = await (async () => {
const { partialWitness, returnWitnessMap, reverted, revertReason } = await (async () => {
try {
const result = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback);
return {
partialWitness: result.partialWitness,
returnWitnessMap: result.returnWitness,
reverted: false,
revertReason: undefined,
};
Expand All @@ -123,6 +124,7 @@ async function executePublicFunctionAcvm(
} else {
return {
partialWitness: undefined,
returnWitnessMap: undefined,
reverted: true,
revertReason: createSimulationError(ee),
};
Expand Down Expand Up @@ -159,7 +161,7 @@ async function executePublicFunctionAcvm(
throw new Error('No partial witness returned from ACVM');
}

const returnWitness = extractReturnWitness(acir, partialWitness);
const returnWitness = witnessMapToFields(returnWitnessMap);
const {
returnValues,
nullifierReadRequests: nullifierReadRequestsPadded,
Expand Down

0 comments on commit 413a4e0

Please sign in to comment.