Skip to content

Commit

Permalink
feat: replace eval_global_as_array_length with type/interpreter eva…
Browse files Browse the repository at this point in the history
…luation (#6469)

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
Co-authored-by: jfecher <jake@aztecprotocol.com>
Co-authored-by: Ary Borenszweig <asterite@gmail.com>
  • Loading branch information
4 people authored Dec 16, 2024
1 parent b4325b4 commit ddb4673
Show file tree
Hide file tree
Showing 18 changed files with 248 additions and 252 deletions.
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl StatementKind {
r#type,
expression,
comptime: false,
is_global_let: false,
attributes,
})
}
Expand Down Expand Up @@ -562,6 +563,7 @@ pub struct LetStatement {

// True if this should only be run during compile-time
pub comptime: bool,
pub is_global_let: bool,
}

#[derive(Debug, PartialEq, Eq, Clone)]
Expand Down
23 changes: 13 additions & 10 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1652,20 +1652,13 @@ impl<'context> Elaborator<'context> {
self.push_err(ResolverError::MutableGlobal { span });
}

let comptime = let_stmt.comptime;

let (let_statement, _typ) = if comptime {
self.elaborate_in_comptime_context(|this| this.elaborate_let(let_stmt, Some(global_id)))
} else {
self.elaborate_let(let_stmt, Some(global_id))
};
let (let_statement, _typ) = self
.elaborate_in_comptime_context(|this| this.elaborate_let(let_stmt, Some(global_id)));

let statement_id = self.interner.get_global(global_id).let_statement;
self.interner.replace_statement(statement_id, let_statement);

if comptime {
self.elaborate_comptime_global(global_id);
}
self.elaborate_comptime_global(global_id);

if let Some(name) = name {
self.interner.register_global(global_id, name, global.visibility, self.module_id());
Expand Down Expand Up @@ -1700,6 +1693,16 @@ impl<'context> Elaborator<'context> {
}
}

/// If the given global is unresolved, elaborate it and return true
fn elaborate_global_if_unresolved(&mut self, global_id: &GlobalId) -> bool {
if let Some(global) = self.unresolved_globals.remove(global_id) {
self.elaborate_global(global);
true
} else {
false
}
}

fn define_function_metas(
&mut self,
functions: &mut [UnresolvedFunctions],
Expand Down
4 changes: 1 addition & 3 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,7 @@ impl<'context> Elaborator<'context> {
self.interner.add_function_reference(func_id, hir_ident.location);
}
DefinitionKind::Global(global_id) => {
if let Some(global) = self.unresolved_globals.remove(&global_id) {
self.elaborate_global(global);
}
self.elaborate_global_if_unresolved(&global_id);
if let Some(current_item) = self.current_item {
self.interner.add_global_dependency(current_item, global_id);
}
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<'context> Elaborator<'context> {
) -> (HirStatement, Type) {
let expr_span = let_stmt.expression.span;
let (expression, expr_type) = self.elaborate_expression(let_stmt.expression);

let type_contains_unspecified = let_stmt.r#type.contains_unspecified();
let annotated_type = self.resolve_inferred_type(let_stmt.r#type);

Expand Down Expand Up @@ -123,7 +124,9 @@ impl<'context> Elaborator<'context> {

let attributes = let_stmt.attributes;
let comptime = let_stmt.comptime;
let let_ = HirLetStatement { pattern, r#type, expression, attributes, comptime };
let is_global_let = let_stmt.is_global_let;
let let_ =
HirLetStatement::new(pattern, r#type, expression, attributes, comptime, is_global_let);
(HirStatement::Let(let_), Type::Unit)
}

Expand Down
158 changes: 36 additions & 122 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{borrow::Cow, rc::Rc};

use acvm::acir::AcirField;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use rustc_hash::FxHashMap as HashMap;
Expand All @@ -12,7 +11,6 @@ use crate::{
UnresolvedTypeData, UnresolvedTypeExpression, WILDCARD_TYPE,
},
hir::{
comptime::{Interpreter, Value},
def_collector::dc_crate::CompilationError,
resolution::{errors::ResolverError, import::PathResolutionError},
type_check::{
Expand All @@ -30,8 +28,8 @@ use crate::{
traits::{NamedType, ResolvedTraitBound, Trait, TraitConstraint},
},
node_interner::{
DefinitionKind, DependencyId, ExprId, GlobalId, ImplSearchErrorKind, NodeInterner, TraitId,
TraitImplKind, TraitMethodId,
DependencyId, ExprId, ImplSearchErrorKind, NodeInterner, TraitId, TraitImplKind,
TraitMethodId,
},
token::SecondaryAttribute,
Generics, Kind, ResolvedGeneric, Type, TypeBinding, TypeBindings, UnificationError,
Expand Down Expand Up @@ -415,17 +413,40 @@ impl<'context> Elaborator<'context> {
.map(|let_statement| Kind::numeric(let_statement.r#type))
.unwrap_or(Kind::u32());

// TODO(https://github.com/noir-lang/noir/issues/6238):
// support non-u32 generics here
if !kind.unifies(&Kind::u32()) {
let error = TypeCheckError::EvaluatedGlobalIsntU32 {
expected_kind: Kind::u32().to_string(),
expr_kind: kind.to_string(),
expr_span: path.span(),
};
self.push_err(error);
}
Some(Type::Constant(self.eval_global_as_array_length(id, path).into(), kind))
let Some(stmt) = self.interner.get_global_let_statement(id) else {
if self.elaborate_global_if_unresolved(&id) {
return self.lookup_generic_or_global_type(path);
} else {
let path = path.clone();
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
return None;
}
};

let rhs = stmt.expression;
let span = self.interner.expr_span(&rhs);

let Some(global_value) = &self.interner.get_global(id).value else {
self.push_err(ResolverError::UnevaluatedGlobalType { span });
return None;
};

let Some(global_value) = global_value.to_field_element() else {
let global_value = global_value.clone();
if global_value.is_integral() {
self.push_err(ResolverError::NegativeGlobalType { span, global_value });
} else {
self.push_err(ResolverError::NonIntegralGlobalType { span, global_value });
}
return None;
};

let Ok(global_value) = kind.ensure_value_fits(global_value, span) else {
self.push_err(ResolverError::GlobalLargerThanKind { span, global_value, kind });
return None;
};

Some(Type::Constant(global_value, kind))
}
_ => None,
}
Expand Down Expand Up @@ -633,31 +654,6 @@ impl<'context> Elaborator<'context> {
.or_else(|| self.resolve_trait_method_by_named_generic(path))
}

fn eval_global_as_array_length(&mut self, global_id: GlobalId, path: &Path) -> u32 {
let Some(stmt) = self.interner.get_global_let_statement(global_id) else {
if let Some(global) = self.unresolved_globals.remove(&global_id) {
self.elaborate_global(global);
return self.eval_global_as_array_length(global_id, path);
} else {
let path = path.clone();
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
return 0;
}
};

let length = stmt.expression;
let span = self.interner.expr_span(&length);
let result = try_eval_array_length_id(self.interner, length, span);

match result.map(|length| length.try_into()) {
Ok(Ok(length_value)) => return length_value,
Ok(Err(_cast_err)) => self.push_err(ResolverError::IntegerTooLarge { span }),
Err(Some(error)) => self.push_err(error),
Err(None) => (),
}
0
}

pub(super) fn unify(
&mut self,
actual: &Type,
Expand Down Expand Up @@ -1834,88 +1830,6 @@ fn bind_generic(param: &ResolvedGeneric, arg: &Type, bindings: &mut TypeBindings
}
}

pub fn try_eval_array_length_id(
interner: &NodeInterner,
rhs: ExprId,
span: Span,
) -> Result<u128, Option<ResolverError>> {
// Arbitrary amount of recursive calls to try before giving up
let fuel = 100;
try_eval_array_length_id_with_fuel(interner, rhs, span, fuel)
}

fn try_eval_array_length_id_with_fuel(
interner: &NodeInterner,
rhs: ExprId,
span: Span,
fuel: u32,
) -> Result<u128, Option<ResolverError>> {
if fuel == 0 {
// If we reach here, it is likely from evaluating cyclic globals. We expect an error to
// be issued for them after name resolution so issue no error now.
return Err(None);
}

match interner.expression(&rhs) {
HirExpression::Literal(HirLiteral::Integer(int, false)) => {
int.try_into_u128().ok_or(Some(ResolverError::IntegerTooLarge { span }))
}
HirExpression::Ident(ident, _) => {
if let Some(definition) = interner.try_definition(ident.id) {
match definition.kind {
DefinitionKind::Global(global_id) => {
let let_statement = interner.get_global_let_statement(global_id);
if let Some(let_statement) = let_statement {
let expression = let_statement.expression;
try_eval_array_length_id_with_fuel(interner, expression, span, fuel - 1)
} else {
Err(Some(ResolverError::InvalidArrayLengthExpr { span }))
}
}
_ => Err(Some(ResolverError::InvalidArrayLengthExpr { span })),
}
} else {
Err(Some(ResolverError::InvalidArrayLengthExpr { span }))
}
}
HirExpression::Infix(infix) => {
let lhs = try_eval_array_length_id_with_fuel(interner, infix.lhs, span, fuel - 1)?;
let rhs = try_eval_array_length_id_with_fuel(interner, infix.rhs, span, fuel - 1)?;

match infix.operator.kind {
BinaryOpKind::Add => Ok(lhs + rhs),
BinaryOpKind::Subtract => Ok(lhs - rhs),
BinaryOpKind::Multiply => Ok(lhs * rhs),
BinaryOpKind::Divide => Ok(lhs / rhs),
BinaryOpKind::Equal => Ok((lhs == rhs) as u128),
BinaryOpKind::NotEqual => Ok((lhs != rhs) as u128),
BinaryOpKind::Less => Ok((lhs < rhs) as u128),
BinaryOpKind::LessEqual => Ok((lhs <= rhs) as u128),
BinaryOpKind::Greater => Ok((lhs > rhs) as u128),
BinaryOpKind::GreaterEqual => Ok((lhs >= rhs) as u128),
BinaryOpKind::And => Ok(lhs & rhs),
BinaryOpKind::Or => Ok(lhs | rhs),
BinaryOpKind::Xor => Ok(lhs ^ rhs),
BinaryOpKind::ShiftRight => Ok(lhs >> rhs),
BinaryOpKind::ShiftLeft => Ok(lhs << rhs),
BinaryOpKind::Modulo => Ok(lhs % rhs),
}
}
HirExpression::Cast(cast) => {
let lhs = try_eval_array_length_id_with_fuel(interner, cast.lhs, span, fuel - 1)?;
let lhs_value = Value::Field(lhs.into());
let evaluated_value =
Interpreter::evaluate_cast_one_step(&cast, rhs, lhs_value, interner)
.map_err(|error| Some(ResolverError::ArrayLengthInterpreter { error }))?;

evaluated_value
.to_u128()
.ok_or_else(|| Some(ResolverError::InvalidArrayLengthExpr { span }))
}
_other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })),
}
}

/// Gives an error if a user tries to create a mutable reference
/// to an immutable variable.
fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result<(), ResolverError> {
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
},
)?;

if let_.comptime || crate_of_global != self.crate_id {
if let_.runs_comptime() || crate_of_global != self.crate_id {
self.evaluate_let(let_.clone())?;
}

Expand Down
39 changes: 25 additions & 14 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, rc::Rc, vec};

use acvm::{AcirField, FieldElement};
use acvm::FieldElement;
use im::Vector;
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::{Location, Span};
Expand Down Expand Up @@ -409,6 +409,9 @@ impl Value {
Value::Pointer(element, true) => {
return element.unwrap_or_clone().into_hir_expression(interner, location);
}
Value::Closure(hir_lambda, _args, _typ, _opt_func_id, _module_id) => {
HirExpression::Lambda(hir_lambda)
}
Value::TypedExpr(TypedExpr::StmtId(..))
| Value::Expr(..)
| Value::Pointer(..)
Expand All @@ -420,7 +423,6 @@ impl Value {
| Value::Zeroed(_)
| Value::Type(_)
| Value::UnresolvedType(_)
| Value::Closure(..)
| Value::ModuleDefinition(_) => {
let typ = self.get_type().into_owned();
let value = self.display(interner).to_string();
Expand Down Expand Up @@ -502,19 +504,28 @@ impl Value {
Ok(vec![token])
}

/// Converts any unsigned `Value` into a `u128`.
/// Returns `None` for negative integers.
pub(crate) fn to_u128(&self) -> Option<u128> {
/// Returns false for non-integral `Value`s.
pub(crate) fn is_integral(&self) -> bool {
use Value::*;
matches!(
self,
Field(_) | I8(_) | I16(_) | I32(_) | I64(_) | U8(_) | U16(_) | U32(_) | U64(_)
)
}

/// Converts any non-negative `Value` into a `FieldElement`.
/// Returns `None` for negative integers and non-integral `Value`s.
pub(crate) fn to_field_element(&self) -> Option<FieldElement> {
match self {
Self::Field(value) => Some(value.to_u128()),
Self::I8(value) => (*value >= 0).then_some(*value as u128),
Self::I16(value) => (*value >= 0).then_some(*value as u128),
Self::I32(value) => (*value >= 0).then_some(*value as u128),
Self::I64(value) => (*value >= 0).then_some(*value as u128),
Self::U8(value) => Some(*value as u128),
Self::U16(value) => Some(*value as u128),
Self::U32(value) => Some(*value as u128),
Self::U64(value) => Some(*value as u128),
Self::Field(value) => Some(*value),
Self::I8(value) => (*value >= 0).then_some((*value as u128).into()),
Self::I16(value) => (*value >= 0).then_some((*value as u128).into()),
Self::I32(value) => (*value >= 0).then_some((*value as u128).into()),
Self::I64(value) => (*value >= 0).then_some((*value as u128).into()),
Self::U8(value) => Some((*value as u128).into()),
Self::U16(value) => Some((*value as u128).into()),
Self::U32(value) => Some((*value as u128).into()),
Self::U64(value) => Some((*value as u128).into()),
_ => None,
}
}
Expand Down
Loading

0 comments on commit ddb4673

Please sign in to comment.