From 1b6e26b06ceb12abf92fdd49b70e6e2d10852d3b Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:24:42 +0100 Subject: [PATCH] fix: map entry point indexes after all ssa passes (#6740) --- compiler/noirc_evaluator/src/acir/mod.rs | 15 ++++--- compiler/noirc_evaluator/src/ssa.rs | 2 +- .../src/ssa/opt/constant_folding.rs | 6 ++- .../src/ssa/ssa_gen/program.rs | 45 ++++++++++++------- 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 76f0dea95bb..63facac5a17 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -821,14 +821,12 @@ impl<'a> Context<'a> { }) .sum(); - let Some(acir_function_id) = - ssa.entry_point_to_generated_index.get(id) - else { + let Some(acir_function_id) = ssa.get_entry_point_index(id) else { unreachable!("Expected an associated final index for call to acir function {id} with args {arguments:?}"); }; let output_vars = self.acir_context.call_acir_function( - AcirFunctionId(*acir_function_id), + AcirFunctionId(acir_function_id), inputs, output_count, self.current_side_effects_enabled_var, @@ -2979,7 +2977,7 @@ mod test { build_basic_foo_with_return(&mut builder, foo_id, false, inline_type); - let ssa = builder.finish(); + let ssa = builder.finish().generate_entry_point_index(); let (acir_functions, _, _, _) = ssa .into_acir(&Brillig::default(), ExpressionWidth::default()) @@ -3087,6 +3085,7 @@ mod test { let ssa = builder.finish(); let (acir_functions, _, _, _) = ssa + .generate_entry_point_index() .into_acir(&Brillig::default(), ExpressionWidth::default()) .expect("Should compile manually written SSA into ACIR"); // The expected result should look very similar to the above test expect that the input witnesses of the `Call` @@ -3184,7 +3183,7 @@ mod test { build_basic_foo_with_return(&mut builder, foo_id, false, inline_type); - let ssa = builder.finish(); + let ssa = builder.finish().generate_entry_point_index(); let (acir_functions, _, _, _) = ssa .into_acir(&Brillig::default(), ExpressionWidth::default()) @@ -3311,6 +3310,7 @@ mod test { let brillig = ssa.to_brillig(false); let (acir_functions, brillig_functions, _, _) = ssa + .generate_entry_point_index() .into_acir(&brillig, ExpressionWidth::default()) .expect("Should compile manually written SSA into ACIR"); @@ -3375,6 +3375,7 @@ mod test { // The Brillig bytecode we insert for the stdlib is hardcoded so we do not need to provide any // Brillig artifacts to the ACIR gen pass. let (acir_functions, brillig_functions, _, _) = ssa + .generate_entry_point_index() .into_acir(&Brillig::default(), ExpressionWidth::default()) .expect("Should compile manually written SSA into ACIR"); @@ -3449,6 +3450,7 @@ mod test { println!("{}", ssa); let (acir_functions, brillig_functions, _, _) = ssa + .generate_entry_point_index() .into_acir(&brillig, ExpressionWidth::default()) .expect("Should compile manually written SSA into ACIR"); @@ -3537,6 +3539,7 @@ mod test { println!("{}", ssa); let (acir_functions, brillig_functions, _, _) = ssa + .generate_entry_point_index() .into_acir(&brillig, ExpressionWidth::default()) .expect("Should compile manually written SSA into ACIR"); diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 8f31023f790..426659949bf 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -449,7 +449,7 @@ impl SsaBuilder { } fn finish(self) -> Ssa { - self.ssa + self.ssa.generate_entry_point_index() } /// Runs the given SSA pass and prints the SSA afterward if `print_ssa_passes` is true. diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index e039b8f0f9e..56029a8fbd4 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -125,11 +125,13 @@ impl Ssa { } // The ones that remain are never called: let's remove them. - for func_id in brillig_functions.keys() { + for (func_id, func) in &brillig_functions { // We never want to remove the main function (it could be `unconstrained` or it // could have been turned into brillig if `--force-brillig` was given). // We also don't want to remove entry points. - if self.main_id == *func_id || self.entry_point_to_generated_index.contains_key(func_id) + let runtime = func.runtime(); + if self.main_id == *func_id + || (runtime.is_entry_point() && matches!(runtime, RuntimeType::Acir(_))) { continue; } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index 3dba6dc0a98..de01a4596ad 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -25,7 +25,7 @@ pub(crate) struct Ssa { /// This mapping is necessary to use the correct function pointer for an ACIR call, /// as the final program artifact will be a list of only entry point functions. #[serde(skip)] - pub(crate) entry_point_to_generated_index: BTreeMap, + entry_point_to_generated_index: BTreeMap, // We can skip serializing this field as the error selector types end up as part of the // ABI not the actual SSA IR. #[serde(skip)] @@ -47,25 +47,11 @@ impl Ssa { (f.id(), f) }); - let entry_point_to_generated_index = btree_map( - functions - .iter() - .filter(|(_, func)| { - let runtime = func.runtime(); - match func.runtime() { - RuntimeType::Acir(_) => runtime.is_entry_point() || func.id() == main_id, - RuntimeType::Brillig(_) => false, - } - }) - .enumerate(), - |(i, (id, _))| (*id, i as u32), - ); - Self { functions, main_id, next_id: AtomicCounter::starting_after(max_id), - entry_point_to_generated_index, + entry_point_to_generated_index: BTreeMap::new(), error_selector_to_type: error_types, } } @@ -98,6 +84,33 @@ impl Ssa { self.functions.insert(new_id, function); new_id } + pub(crate) fn generate_entry_point_index(mut self) -> Self { + self.entry_point_to_generated_index = btree_map( + self.functions + .iter() + .filter(|(_, func)| { + let runtime = func.runtime(); + match func.runtime() { + RuntimeType::Acir(_) => { + runtime.is_entry_point() || func.id() == self.main_id + } + RuntimeType::Brillig(_) => false, + } + }) + .enumerate(), + |(i, (id, _))| (*id, i as u32), + ); + self + } + + pub(crate) fn get_entry_point_index(&self, func_id: &FunctionId) -> Option { + // Ensure the map has been initialized + assert!( + !self.entry_point_to_generated_index.is_empty(), + "Trying to read uninitialized entry point index" + ); + self.entry_point_to_generated_index.get(func_id).copied() + } } impl Display for Ssa {