From 9fb7e4d306041edc5158e2dffd71a19ccc578ac2 Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 3 Jul 2024 09:57:58 -0500 Subject: [PATCH] fix: ICE when using a comptime let variable in runtime code (#5391) # Description ## Problem\* Resolves https://github.com/noir-lang/noir/issues/5388 ## Summary\* We were getting an ICE error because all comptime blocks are removed, but comptime let variables were not inlined at call sites. So given the program ```rs fn main() { comptime let x = 2; assert_eq(x, 2); } ``` The monomorphizer would see: ```rs fn main() { assert_eq(x, 2); } ``` ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- aztec_macros/src/utils/hir_utils.rs | 1 + compiler/noirc_frontend/src/elaborator/mod.rs | 33 ++++++++++++++----- .../noirc_frontend/src/elaborator/patterns.rs | 26 ++++++++++++--- .../src/elaborator/statements.rs | 5 ++- .../noirc_frontend/src/elaborator/types.rs | 13 ++++---- .../src/hir/def_collector/dc_mod.rs | 2 ++ .../src/hir/resolution/resolver.rs | 8 ++--- .../noirc_frontend/src/hir/type_check/mod.rs | 29 ++++++++++++---- compiler/noirc_frontend/src/node_interner.rs | 15 ++++++--- compiler/noirc_frontend/src/tests.rs | 11 +++++++ 10 files changed, 105 insertions(+), 38 deletions(-) diff --git a/aztec_macros/src/utils/hir_utils.rs b/aztec_macros/src/utils/hir_utils.rs index d4b55e1311f..3f47fe5ca25 100644 --- a/aztec_macros/src/utils/hir_utils.rs +++ b/aztec_macros/src/utils/hir_utils.rs @@ -230,6 +230,7 @@ pub fn inject_global( file_id, global.attributes.clone(), false, + false, ); // Add the statement to the scope so its path can be looked up later diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 4b4cf33aa2a..d7087a5ab07 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -136,8 +136,6 @@ pub struct Elaborator<'context> { /// Each constraint in the `where` clause of the function currently being resolved. trait_bounds: Vec, - current_function: Option, - /// This is a stack of function contexts. Most of the time, for each function we /// expect this to be of length one, containing each type variable and trait constraint /// used in the function. This is also pushed to when a `comptime {}` block is used within @@ -196,7 +194,6 @@ impl<'context> Elaborator<'context> { crate_id, resolving_ids: BTreeSet::new(), trait_bounds: Vec::new(), - current_function: None, function_context: vec![FunctionContext::default()], current_trait_impl: None, comptime_scopes: vec![HashMap::default()], @@ -329,8 +326,6 @@ impl<'context> Elaborator<'context> { FunctionBody::Resolving => return, }; - let old_function = std::mem::replace(&mut self.current_function, Some(id)); - self.scopes.start_function(); let old_item = std::mem::replace(&mut self.current_item, Some(DependencyId::Function(id))); @@ -407,7 +402,6 @@ impl<'context> Elaborator<'context> { self.trait_bounds.clear(); self.interner.update_fn(id, hir_func); - self.current_function = old_function; self.current_item = old_item; } @@ -615,8 +609,6 @@ impl<'context> Elaborator<'context> { func_id: FuncId, is_trait_function: bool, ) { - self.current_function = Some(func_id); - let in_contract = if self.self_type.is_some() { // Without this, impl methods can accidentally be placed in contracts. // See: https://github.com/noir-lang/noir/issues/3254 @@ -738,7 +730,6 @@ impl<'context> Elaborator<'context> { }; self.interner.push_fn_meta(meta, func_id); - self.current_function = None; self.scopes.end_function(); self.current_item = None; } @@ -1468,6 +1459,30 @@ impl<'context> Elaborator<'context> { } } + /// True if we're currently within a `comptime` block, function, or global + fn in_comptime_context(&self) -> bool { + // The first context is the global context, followed by the function-specific context. + // Any context after that is a `comptime {}` block's. + if self.function_context.len() > 2 { + return true; + } + + match self.current_item { + Some(DependencyId::Function(id)) => self.interner.function_modifiers(&id).is_comptime, + Some(DependencyId::Global(id)) => self.interner.get_global_definition(id).comptime, + _ => false, + } + } + + /// True if we're currently within a constrained function. + /// Defaults to `true` if the current function is unknown. + fn in_constrained_function(&self) -> bool { + self.current_item.map_or(true, |id| match id { + DependencyId::Function(id) => !self.interner.function_modifiers(&id).is_unconstrained, + _ => true, + }) + } + /// Filters out comptime items from non-comptime items. /// Returns a pair of (comptime items, non-comptime items) fn filter_comptime_items(mut items: CollectedItems) -> (CollectedItems, CollectedItems) { diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 94f03fb511b..61d30a915fc 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -5,6 +5,7 @@ use rustc_hash::FxHashSet as HashSet; use crate::{ ast::ERROR_IDENT, hir::{ + comptime::Interpreter, def_collector::dc_crate::CompilationError, resolution::errors::ResolverError, type_check::{Source, TypeCheckError}, @@ -285,19 +286,21 @@ impl<'context> Elaborator<'context> { } let location = Location::new(name.span(), self.file); + let name = name.0.contents; + let comptime = self.in_comptime_context(); let id = - self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); + self.interner.push_definition(name.clone(), mutable, comptime, definition, location); let ident = HirIdent::non_trait_method(id, location); let resolver_meta = ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; let scope = self.scopes.get_mut_scope(); - let old_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + let old_value = scope.add_key_value(name.clone(), resolver_meta); if !allow_shadowing { if let Some(old_value) = old_value { self.push_err(ResolverError::DuplicateDefinition { - name: name.0.contents, + name, first_span: old_value.ident.location.span, second_span: location.span, }); @@ -329,6 +332,7 @@ impl<'context> Elaborator<'context> { name: Ident, definition: DefinitionKind, ) -> HirIdent { + let comptime = self.in_comptime_context(); let scope = self.scopes.get_mut_scope(); // This check is necessary to maintain the same definition ids in the interner. Currently, each function uses a new resolver that has its own ScopeForest and thus global scope. @@ -350,8 +354,8 @@ impl<'context> Elaborator<'context> { (hir_ident, resolver_meta) } else { let location = Location::new(name.span(), self.file); - let id = - self.interner.push_definition(name.0.contents.clone(), false, definition, location); + let name = name.0.contents.clone(); + let id = self.interner.push_definition(name, false, comptime, definition, location); let ident = HirIdent::non_trait_method(id, location); let resolver_meta = ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; @@ -400,12 +404,24 @@ impl<'context> Elaborator<'context> { ) -> (ExprId, Type) { let span = variable.span; let expr = self.resolve_variable(variable); + let definition_id = expr.id; let id = self.interner.push_expr(HirExpression::Ident(expr.clone(), generics.clone())); self.interner.push_expr_location(id, span, self.file); let typ = self.type_check_variable(expr, id, generics); self.interner.push_expr_type(id, typ.clone()); + + // Comptime variables must be replaced with their values + if let Some(definition) = self.interner.try_definition(definition_id) { + if definition.comptime && !self.in_comptime_context() { + let mut interpreter = + Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id); + let value = interpreter.evaluate(id); + return self.inline_comptime_value(value, span); + } + } + (id, typ) } diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index e2d44919c5e..8d97bd1a25d 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -206,9 +206,8 @@ impl<'context> Elaborator<'context> { } fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { - let in_constrained_function = self - .current_function - .map_or(true, |func_id| !self.interner.function_modifiers(&func_id).is_unconstrained); + let in_constrained_function = self.in_constrained_function(); + if in_constrained_function { self.push_err(ResolverError::JumpInConstrainedFn { is_break, span }); } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 924950bd0b6..d27e150d649 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1169,9 +1169,11 @@ impl<'context> Elaborator<'context> { None } Type::NamedGeneric(_, _, _) => { - let func_meta = self.interner.function_meta( - &self.current_function.expect("unexpected method outside a function"), - ); + let func_id = match self.current_item { + Some(DependencyId::Function(id)) => id, + _ => panic!("unexpected method outside a function"), + }; + let func_meta = self.interner.function_meta(&func_id); for constraint in &func_meta.trait_constraints { if *object_type == constraint.typ { @@ -1242,9 +1244,8 @@ impl<'context> Elaborator<'context> { lints::deprecated_function(elaborator.interner, call.func).map(Into::into) }); - let func_mod = self.current_function.map(|func| self.interner.function_modifiers(&func)); - let is_current_func_constrained = - func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained); + let is_current_func_constrained = self.in_constrained_function(); + let is_unconstrained_call = self.is_unconstrained_call(call.func); let crossing_runtime_boundary = is_current_func_constrained && is_unconstrained_call; if crossing_runtime_boundary { diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 935a891170c..e908f5c1545 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -474,6 +474,7 @@ impl<'a> ModCollector<'a> { self.file_id, vec![], false, + false, ); if let Err((first_def, second_def)) = self.def_collector.def_map.modules @@ -811,6 +812,7 @@ pub(crate) fn collect_global( file_id, global.attributes.clone(), matches!(global.pattern, Pattern::Mutable { .. }), + global.comptime, ); // Add the statement to the scope so its path can be looked up later diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 2eb33f7603f..c97de6d3e05 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -399,8 +399,8 @@ impl<'a> Resolver<'a> { } let location = Location::new(name.span(), self.file); - let id = - self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); + let var_name = name.0.contents.clone(); + let id = self.interner.push_definition(var_name, mutable, false, definition, location); let ident = HirIdent::non_trait_method(id, location); let resolver_meta = ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; @@ -445,8 +445,8 @@ impl<'a> Resolver<'a> { (hir_ident, resolver_meta) } else { let location = Location::new(name.span(), self.file); - let id = - self.interner.push_definition(name.0.contents.clone(), false, definition, location); + let var_name = name.0.contents.clone(); + let id = self.interner.push_definition(var_name, false, false, definition, location); let ident = HirIdent::non_trait_method(id, location); let resolver_meta = ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index b8fd59e015b..3f1678f4dba 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -483,19 +483,34 @@ pub mod test { // let z = x + y; // // Push x variable - let x_id = - interner.push_definition("x".into(), false, DefinitionKind::Local(None), location); + let x_id = interner.push_definition( + "x".into(), + false, + false, + DefinitionKind::Local(None), + location, + ); let x = HirIdent::non_trait_method(x_id, location); // Push y variable - let y_id = - interner.push_definition("y".into(), false, DefinitionKind::Local(None), location); + let y_id = interner.push_definition( + "y".into(), + false, + false, + DefinitionKind::Local(None), + location, + ); let y = HirIdent::non_trait_method(y_id, location); // Push z variable - let z_id = - interner.push_definition("z".into(), false, DefinitionKind::Local(None), location); + let z_id = interner.push_definition( + "z".into(), + false, + false, + DefinitionKind::Local(None), + location, + ); let z = HirIdent::non_trait_method(z_id, location); // Push x and y as expressions @@ -531,7 +546,7 @@ pub mod test { let func_id = interner.push_fn(func); let definition = DefinitionKind::Local(None); - let id = interner.push_definition("test_func".into(), false, definition, location); + let id = interner.push_definition("test_func".into(), false, false, definition, location); let name = HirIdent::non_trait_method(id, location); // Add function meta diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 618ed6a9ffa..7c30ccf5b8f 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -460,6 +460,7 @@ pub(crate) enum Node { pub struct DefinitionInfo { pub name: String, pub mutable: bool, + pub comptime: bool, pub kind: DefinitionKind, pub location: Location, } @@ -721,6 +722,7 @@ impl NodeInterner { self.type_ref_locations.push((typ, location)); } + #[allow(clippy::too_many_arguments)] fn push_global( &mut self, ident: Ident, @@ -729,12 +731,13 @@ impl NodeInterner { file: FileId, attributes: Vec, mutable: bool, + comptime: bool, ) -> GlobalId { let id = GlobalId(self.globals.len()); let location = Location::new(ident.span(), file); let name = ident.to_string(); let definition_id = - self.push_definition(name, mutable, DefinitionKind::Global(id), location); + self.push_definition(name, mutable, comptime, DefinitionKind::Global(id), location); self.globals.push(GlobalInfo { id, @@ -761,10 +764,11 @@ impl NodeInterner { file: FileId, attributes: Vec, mutable: bool, + comptime: bool, ) -> GlobalId { let statement = self.push_stmt(HirStatement::Error); let span = name.span(); - let id = self.push_global(name, local_id, statement, file, attributes, mutable); + let id = self.push_global(name, local_id, statement, file, attributes, mutable, comptime); self.push_stmt_location(statement, span, file); id } @@ -808,6 +812,7 @@ impl NodeInterner { &mut self, name: String, mutable: bool, + comptime: bool, definition: DefinitionKind, location: Location, ) -> DefinitionId { @@ -816,7 +821,8 @@ impl NodeInterner { self.function_definition_ids.insert(func_id, id); } - self.definitions.push(DefinitionInfo { name, mutable, kind: definition, location }); + let kind = definition; + self.definitions.push(DefinitionInfo { name, mutable, comptime, kind, location }); id } @@ -864,9 +870,10 @@ impl NodeInterner { location: Location, ) -> DefinitionId { let name = modifiers.name.clone(); + let comptime = modifiers.is_comptime; self.function_modifiers.insert(func, modifiers); self.function_modules.insert(func, module); - self.push_definition(name, false, DefinitionKind::Function(func), location) + self.push_definition(name, false, comptime, DefinitionKind::Function(func), location) } pub fn set_function_trait(&mut self, func: FuncId, self_type: Type, trait_id: TraitId) { diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 9251eb3db6b..dbfa5222ca4 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1896,3 +1896,14 @@ fn quote_code_fragments() { use InterpreterError::FailingConstraint; assert!(matches!(&errors[0].0, CompilationError::InterpreterError(FailingConstraint { .. }))); } + +// Regression for #5388 +#[test] +fn comptime_let() { + let src = r#"fn main() { + comptime let my_var = 2; + assert_eq(my_var, 2); + }"#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 0); +}