Skip to content

Commit

Permalink
fix: Run macros within comptime contexts (#5576)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5575

## Summary\*

I've given the interpreter a reference to the elaborator. When a macro
is called we setup that elaborator to elaborate in the middle of the
current function by saving any previous state and setting any state we
can. This is sufficient to resolve local variables, global variables,
generics, and some metadata like constrained-ness.

## Additional Context

To implement this I had to tell the elaborator not to call macros that
are within a comptime scope (wait for the interpreter to do it). This
lead to type errors in e.g. `assert(unquote!(...))` since the elaborator
would now see a `Quoted` type instead of the type it unquotes to. I've
put a band-aid over this by returning an unbound type variable in this
case. This usually works ok but in certain situations users may need a
type annotation when using macros within other comptime code. I think
this is a reasonable limitation.

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jul 22, 2024
1 parent df0854e commit df44919
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 163 deletions.
71 changes: 71 additions & 0 deletions compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::mem::replace;

use crate::{
hir_def::expr::HirIdent,
macros_api::Expression,
node_interner::{DependencyId, ExprId, FuncId},
};

use super::{Elaborator, FunctionContext, ResolverMeta};

impl<'context> Elaborator<'context> {
/// Elaborate an expression from the middle of a comptime scope.
/// When this happens we require additional information to know
/// what variables should be in scope.
pub fn elaborate_expression_from_comptime(
&mut self,
expr: Expression,
function: Option<FuncId>,
) -> ExprId {
self.function_context.push(FunctionContext::default());
let old_scope = self.scopes.end_function();
self.scopes.start_function();
let function_id = function.map(DependencyId::Function);
let old_item = replace(&mut self.current_item, function_id);

// Note: recover_generics isn't good enough here because any existing generics
// should not be in scope of this new function
let old_generics = std::mem::take(&mut self.generics);

let old_crate_and_module = function.map(|function| {
let meta = self.interner.function_meta(&function);
let old_crate = replace(&mut self.crate_id, meta.source_crate);
let old_module = replace(&mut self.local_module, meta.source_module);
self.introduce_generics_into_scope(meta.all_generics.clone());
(old_crate, old_module)
});

self.populate_scope_from_comptime_scopes();
let expr = self.elaborate_expression(expr).0;

if let Some((old_crate, old_module)) = old_crate_and_module {
self.crate_id = old_crate;
self.local_module = old_module;
}

self.generics = old_generics;
self.current_item = old_item;
self.scopes.end_function();
self.scopes.0.push(old_scope);
self.check_and_pop_function_context();
expr
}

fn populate_scope_from_comptime_scopes(&mut self) {
// Take the comptime scope to be our runtime scope.
// Iterate from global scope to the most local scope so that the
// later definitions will naturally shadow the former.
for scope in &self.comptime_scopes {
for definition_id in scope.keys() {
let definition = self.interner.definition(*definition_id);
let name = definition.name.clone();
let location = definition.location;

let scope = self.scopes.get_mut_scope();
let ident = HirIdent::non_trait_method(*definition_id, location);
let meta = ResolverMeta { ident, num_times_used: 0, warn_if_unused: false };
scope.add_key_value(name.clone(), meta);
}
}
}
}
28 changes: 18 additions & 10 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,21 @@ impl<'context> Elaborator<'context> {
}

let location = Location::new(span, self.file);
let hir_call = HirCallExpression { func, arguments, location };
let typ = self.type_check_call(&hir_call, func_type, args, span);
let is_macro_call = call.is_macro_call;
let hir_call = HirCallExpression { func, arguments, location, is_macro_call };
let mut typ = self.type_check_call(&hir_call, func_type, args, span);

if call.is_macro_call {
self.call_macro(func, comptime_args, location, typ)
.unwrap_or_else(|| (HirExpression::Error, Type::Error))
} else {
(HirExpression::Call(hir_call), typ)
if is_macro_call {
if self.in_comptime_context() {
typ = self.interner.next_type_variable();
} else {
return self
.call_macro(func, comptime_args, location, typ)
.unwrap_or_else(|| (HirExpression::Error, Type::Error));
}
}

(HirExpression::Call(hir_call), typ)
}

fn elaborate_method_call(
Expand Down Expand Up @@ -368,6 +374,7 @@ impl<'context> Elaborator<'context> {
let location = Location::new(span, self.file);
let method = method_call.method_name;
let turbofish_generics = generics.clone();
let is_macro_call = method_call.is_macro_call;
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };

Expand All @@ -377,6 +384,7 @@ impl<'context> Elaborator<'context> {
let ((function_id, function_name), function_call) = method_call.into_function_call(
&method_ref,
object_type,
is_macro_call,
location,
self.interner,
);
Expand Down Expand Up @@ -721,7 +729,7 @@ impl<'context> Elaborator<'context> {
(id, typ)
}

pub(super) fn inline_comptime_value(
pub fn inline_comptime_value(
&mut self,
value: Result<comptime::Value, InterpreterError>,
span: Span,
Expand Down Expand Up @@ -801,14 +809,14 @@ impl<'context> Elaborator<'context> {
for argument in arguments {
match interpreter.evaluate(argument) {
Ok(arg) => {
let location = interpreter.interner.expr_location(&argument);
let location = interpreter.elaborator.interner.expr_location(&argument);
comptime_args.push((arg, location));
}
Err(error) => errors.push((error.into(), file)),
}
}

let bindings = interpreter.interner.get_instantiation_bindings(func).clone();
let bindings = interpreter.elaborator.interner.get_instantiation_bindings(func).clone();
let result = interpreter.call_function(function, comptime_args, bindings, location);

if !errors.is_empty() {
Expand Down
57 changes: 38 additions & 19 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
use crate::{
ast::{FunctionKind, UnresolvedTraitConstraint},
hir::{
comptime::{self, Interpreter, InterpreterError, Value},
comptime::{Interpreter, InterpreterError, Value},
def_collector::{
dc_crate::{
filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal,
Expand Down Expand Up @@ -60,6 +60,7 @@ use crate::{
macros_api::ItemVisibility,
};

mod comptime;
mod expressions;
mod lints;
mod patterns;
Expand Down Expand Up @@ -97,9 +98,9 @@ pub struct LambdaContext {
pub struct Elaborator<'context> {
scopes: ScopeForest,

errors: Vec<(CompilationError, FileId)>,
pub(crate) errors: Vec<(CompilationError, FileId)>,

interner: &'context mut NodeInterner,
pub(crate) interner: &'context mut NodeInterner,

def_maps: &'context mut BTreeMap<CrateId, CrateDefMap>,

Expand Down Expand Up @@ -167,7 +168,7 @@ pub struct Elaborator<'context> {
/// Each value currently in scope in the comptime interpreter.
/// Each element of the Vec represents a scope with every scope together making
/// up all currently visible definitions. The first scope is always the global scope.
comptime_scopes: Vec<HashMap<DefinitionId, comptime::Value>>,
pub(crate) comptime_scopes: Vec<HashMap<DefinitionId, Value>>,

/// The scope of --debug-comptime, or None if unset
debug_comptime_in_file: Option<FileId>,
Expand Down Expand Up @@ -228,6 +229,15 @@ impl<'context> Elaborator<'context> {
items: CollectedItems,
debug_comptime_in_file: Option<FileId>,
) -> Vec<(CompilationError, FileId)> {
Self::elaborate_and_return_self(context, crate_id, items, debug_comptime_in_file).errors
}

pub fn elaborate_and_return_self(
context: &'context mut Context,
crate_id: CrateId,
items: CollectedItems,
debug_comptime_in_file: Option<FileId>,
) -> Self {
let mut this = Self::new(context, crate_id, debug_comptime_in_file);

// Filter out comptime items to execute their functions first if needed.
Expand All @@ -238,7 +248,7 @@ impl<'context> Elaborator<'context> {
let (comptime_items, runtime_items) = Self::filter_comptime_items(items);
this.elaborate_items(comptime_items);
this.elaborate_items(runtime_items);
this.errors
this
}

fn elaborate_items(&mut self, mut items: CollectedItems) {
Expand Down Expand Up @@ -339,6 +349,21 @@ impl<'context> Elaborator<'context> {
self.trait_id = None;
}

fn introduce_generics_into_scope(&mut self, all_generics: Vec<ResolvedGeneric>) {
// Introduce all numeric generics into scope
for generic in &all_generics {
if let Kind::Numeric(typ) = &generic.kind {
let definition = DefinitionKind::GenericType(generic.type_var.clone());
let ident = Ident::new(generic.name.to_string(), generic.span);
let hir_ident =
self.add_variable_decl_inner(ident, false, false, false, definition);
self.interner.push_definition_type(hir_ident.id, *typ.clone());
}
}

self.generics = all_generics;
}

fn elaborate_function(&mut self, id: FuncId) {
let func_meta = self.interner.func_meta.get_mut(&id);
let func_meta =
Expand All @@ -360,16 +385,7 @@ impl<'context> Elaborator<'context> {
self.trait_bounds = func_meta.trait_constraints.clone();
self.function_context.push(FunctionContext::default());

// Introduce all numeric generics into scope
for generic in &func_meta.all_generics {
if let Kind::Numeric(typ) = &generic.kind {
let definition = DefinitionKind::GenericType(generic.type_var.clone());
let ident = Ident::new(generic.name.to_string(), generic.span);
let hir_ident =
self.add_variable_decl_inner(ident, false, false, false, definition);
self.interner.push_definition_type(hir_ident.id, *typ.clone());
}
}
self.introduce_generics_into_scope(func_meta.all_generics.clone());

// The DefinitionIds for each parameter were already created in define_function_meta
// so we need to reintroduce the same IDs into scope here.
Expand All @@ -378,8 +394,6 @@ impl<'context> Elaborator<'context> {
self.add_existing_variable_to_scope(name, parameter.clone(), true);
}

self.generics = func_meta.all_generics.clone();

self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type());
self.add_trait_constraints_to_scope(&func_meta);

Expand Down Expand Up @@ -758,6 +772,7 @@ impl<'context> Elaborator<'context> {
is_trait_function,
has_inline_attribute,
source_crate: self.crate_id,
source_module: self.local_module,
function_body: FunctionBody::Unresolved(func.kind, body, func.def.span),
};

Expand Down Expand Up @@ -1626,8 +1641,12 @@ impl<'context> Elaborator<'context> {
}
}

fn setup_interpreter(&mut self) -> Interpreter {
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id)
pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> {
let current_function = match self.current_item {
Some(DependencyId::Function(function)) => Some(function),
_ => None,
};
Interpreter::new(self, self.crate_id, current_function)
}

fn debug_comptime<T: Display, F: FnMut(&mut NodeInterner) -> T>(
Expand Down
4 changes: 1 addition & 3 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use rustc_hash::FxHashSet as HashSet;
use crate::{
ast::{UnresolvedType, ERROR_IDENT},
hir::{
comptime::Interpreter,
def_collector::dc_crate::CompilationError,
resolution::errors::ResolverError,
type_check::{Source, TypeCheckError},
Expand Down Expand Up @@ -460,8 +459,7 @@ impl<'context> Elaborator<'context> {
// Comptime variables must be replaced with their values
if let Some(definition) = self.interner.try_definition(definition_id) {
if definition.comptime && !self.in_comptime_context() {
let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let mut interpreter = self.setup_interpreter();
let value = interpreter.evaluate(id);
return self.inline_comptime_value(value, span);
}
Expand Down
Loading

0 comments on commit df44919

Please sign in to comment.