Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow numeric generics to be referenced and add map #997

Merged
merged 4 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ fn test_array_functions() {

let descending = myarray.sort_via(|a, b| a > b);
constrain descending == [3, 2, 1];

constrain evens.map(|n| n / 2) == myarray;
}

fn foo() -> [u32; 2] {
Expand Down
5 changes: 4 additions & 1 deletion crates/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub use noirc_errors::Span;
use noirc_errors::{CustomDiagnostic as Diagnostic, FileDiagnostic};
use thiserror::Error;

use crate::{Ident, Shared, StructType, Type};
use crate::{parser::ParserError, Ident, Shared, StructType, Type};

#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ResolverError {
Expand Down Expand Up @@ -57,6 +57,8 @@ pub enum ResolverError {
actual: usize,
expected: usize,
},
#[error("{0}")]
ParserError(ParserError),
}

impl ResolverError {
Expand Down Expand Up @@ -252,6 +254,7 @@ impl From<ResolverError> for Diagnostic {
span,
)
}
ResolverError::ParserError(error) => error.into(),
}
}
}
244 changes: 122 additions & 122 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions crates/noirc_frontend/src/hir/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,17 @@ impl<K: std::hash::Hash + Eq + Clone, V> ScopeForest<K, V> {
fn extend_current_scope_tree(&mut self) {
self.current_scope_tree().push_scope()
}

fn remove_scope_tree_extension(&mut self) -> Scope<K, V> {
self.current_scope_tree().pop_scope()
}

/// Starting a function requires a new scope tree, as you do not want the functions scope to
/// have access to the scope of the caller
pub fn start_function(&mut self) {
self.0.push(ScopeTree::default())
}

/// Ending a function requires that we removes it's whole tree of scope
/// This is by design the current scope, which is the last element in the vector
pub fn end_function(&mut self) -> ScopeTree<K, V> {
Expand Down
8 changes: 6 additions & 2 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use noirc_errors::Span;

use crate::{
hir_def::{
expr::{self, HirBinaryOp, HirExpression, HirLiteral},
expr::{self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral},
types::Type,
},
node_interner::{ExprId, FuncId, NodeInterner},
Expand Down Expand Up @@ -38,7 +38,7 @@ pub(crate) fn type_check_expression(
}
HirExpression::Literal(literal) => {
match literal {
HirLiteral::Array(arr) => {
HirLiteral::Array(HirArrayLiteral::Standard(arr)) => {
let elem_types =
vecmap(&arr, |arg| type_check_expression(interner, arg, errors));

Expand Down Expand Up @@ -68,6 +68,10 @@ pub(crate) fn type_check_expression(

arr_type
}
HirLiteral::Array(HirArrayLiteral::Repeated { repeated_element, length }) => {
let elem_type = type_check_expression(interner, &repeated_element, errors);
Type::Array(Box::new(length), Box::new(elem_type))
}
HirLiteral::Bool(_) => Type::Bool(CompTime::new(interner)),
HirLiteral::Integer(_) => {
let id = interner.next_type_variable_id();
Expand Down
8 changes: 7 additions & 1 deletion crates/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,18 @@ impl HirBinaryOp {

#[derive(Debug, Clone)]
pub enum HirLiteral {
Array(Vec<ExprId>),
Array(HirArrayLiteral),
Bool(bool),
Integer(FieldElement),
Str(String),
}

#[derive(Debug, Clone)]
pub enum HirArrayLiteral {
Standard(Vec<ExprId>),
Repeated { repeated_element: ExprId, length: Type },
}

#[derive(Debug, Clone)]
pub struct HirPrefixExpression {
pub operator: UnaryOp,
Expand Down
59 changes: 59 additions & 0 deletions crates/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ impl StructType {
self.fields.keys().cloned().collect()
}

/// True if the given index is the same index as a generic type of this struct
/// which is expected to be a numeric generic.
/// This is needed because we infer type kinds in Noir and don't have extensive kind checking.
pub fn generic_is_numeric(&self, index_of_generic: usize) -> bool {
let target_id = self.generics[index_of_generic].0;
self.fields.iter().any(|(_, field)| field.contains_numeric_typevar(target_id))
}

/// Instantiate this struct type, returning a Vec of the new generic args (in
/// the same order as self.generics)
pub fn instantiate(&self, interner: &mut NodeInterner) -> Vec<Type> {
Expand Down Expand Up @@ -532,6 +540,57 @@ impl Type {
pub fn is_field(&self) -> bool {
matches!(self.follow_bindings(), Type::FieldElement(_))
}

fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool {
// True if the given type is a NamedGeneric with the target_id
let named_generic_id_matches_target = |typ: &Type| {
if let Type::NamedGeneric(type_variable, _) = typ {
match &*type_variable.borrow() {
TypeBinding::Bound(_) => {
unreachable!("Named generics should not be bound until monomorphization")
}
TypeBinding::Unbound(id) => target_id == *id,
}
} else {
false
}
};

match self {
Type::FieldElement(_)
| Type::Integer(_, _, _)
| Type::Bool(_)
| Type::String(_)
| Type::Unit
| Type::Error
| Type::TypeVariable(_)
| Type::PolymorphicInteger(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _) => false,

Type::Array(length, elem) => {
elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length)
}

Type::Tuple(fields) => {
fields.iter().any(|field| field.contains_numeric_typevar(target_id))
}
Type::Function(parameters, return_type) => {
parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id))
|| return_type.contains_numeric_typevar(target_id)
}
Type::Struct(struct_type, generics) => {
generics.iter().enumerate().any(|(i, generic)| {
if named_generic_id_matches_target(generic) {
struct_type.borrow().generic_is_numeric(i)
} else {
generic.contains_numeric_typevar(target_id)
}
})
}
}
}
}

impl std::fmt::Display for Type {
Expand Down
34 changes: 30 additions & 4 deletions crates/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,24 @@ impl<'interner> Monomorphizer<'interner> {
let typ = Self::convert_type(&self.interner.id_type(expr));
Literal(Integer(value, typ))
}
HirExpression::Literal(HirLiteral::Array(array)) => {
HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Standard(array))) => {
let element_type = Self::convert_type(&self.interner.id_type(array[0]));
let contents = vecmap(array, |id| self.expr_infer(id));
Literal(Array(ast::ArrayLiteral { contents, element_type }))
}
HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Repeated {
repeated_element,
length,
})) => {
let element_type = Self::convert_type(&self.interner.id_type(repeated_element));
let contents = self.expr_infer(repeated_element);
let length = length
.evaluate_to_u64()
.expect("Length of array is unknown when evaluating numeric generic");

let contents = vec![contents; length as usize];
Literal(Array(ast::ArrayLiteral { contents, element_type }))
}
HirExpression::Block(block) => self.block(block.0),

HirExpression::Prefix(prefix) => ast::Expression::Unary(ast::Unary {
Expand Down Expand Up @@ -479,23 +492,36 @@ impl<'interner> Monomorphizer<'interner> {

fn ident(&mut self, ident: HirIdent, expr_id: node_interner::ExprId) -> ast::Expression {
let definition = self.interner.definition(ident.id);
match definition.kind {
match &definition.kind {
DefinitionKind::Function(func_id) => {
let mutable = definition.mutable;
let location = Some(ident.location);
let name = definition.name.clone();
let typ = self.interner.id_type(expr_id);

let definition = self.lookup_function(func_id, expr_id, &typ);
let definition = self.lookup_function(*func_id, expr_id, &typ);
let typ = Self::convert_type(&typ);
let ident = ast::Ident { location, mutable, definition, name, typ };
ast::Expression::Ident(ident)
}
DefinitionKind::Global(expr_id) => self.expr_infer(expr_id),
DefinitionKind::Global(expr_id) => self.expr_infer(*expr_id),
DefinitionKind::Local(_) => {
let ident = self.local_ident(&ident).unwrap();
ast::Expression::Ident(ident)
}
DefinitionKind::GenericType(type_variable) => {
let value = match &*type_variable.borrow() {
TypeBinding::Unbound(_) => {
unreachable!("Unbound type variable used in expression")
}
TypeBinding::Bound(binding) => binding.evaluate_to_u64().unwrap_or_else(|| {
panic!("Non-numeric type variable used in expression expecting a value")
}),
};

let value = FieldElement::from(value as u128);
ast::Expression::Literal(ast::Literal::Integer(value, ast::Type::Field))
}
}
}

Expand Down
16 changes: 10 additions & 6 deletions crates/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::hir_def::{
function::{FuncMeta, HirFunction},
stmt::HirStatement,
};
use crate::{Shared, TypeBinding, TypeBindings, TypeVariableId};
use crate::{Shared, TypeBinding, TypeBindings, TypeVariable, TypeVariableId};

/// The node interner is the central storage location of all nodes in Noir's Hir (the
/// various node types can be found in hir_def). The interner is also used to collect
Expand Down Expand Up @@ -199,14 +199,18 @@ impl DefinitionInfo {
}
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DefinitionKind {
Function(FuncId),
Global(ExprId),

/// Locals may be defined in let statements or parameters,
/// in which case they will not have an associated ExprId
Local(Option<ExprId>),

/// Generic types in functions (T, U in `fn foo<T, U>(...)` are declared as variables
/// in scope in case they resolve to numeric generics later.
GenericType(TypeVariable),
}

impl DefinitionKind {
Expand All @@ -221,6 +225,7 @@ impl DefinitionKind {
DefinitionKind::Function(_) => None,
DefinitionKind::Global(id) => Some(id),
DefinitionKind::Local(id) => id,
DefinitionKind::GenericType(_) => None,
}
}
}
Expand Down Expand Up @@ -393,13 +398,12 @@ impl NodeInterner {
mutable: bool,
definition: DefinitionKind,
) -> DefinitionId {
let id = self.definitions.len();
self.definitions.push(DefinitionInfo { name, mutable, kind: definition });

let id = DefinitionId(id);
let id = DefinitionId(self.definitions.len());
if let DefinitionKind::Function(func_id) = definition {
self.function_definition_ids.insert(func_id, id);
}

self.definitions.push(DefinitionInfo { name, mutable, kind: definition });
id
}

Expand Down
13 changes: 13 additions & 0 deletions noir_stdlib/src/array.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ impl<T, N> [T; N] {
a
}

// Apply a function to each element of an array, returning a new array
// containing the mapped elements.
fn map<U>(self, f: fn(T) -> U) -> [U; N] {
let first_elem = f(self[0]);
let mut ret = [first_elem; N];

for i in 1 .. self.len() {
ret[i] = f(self[i]);
}

ret
}

// Apply a function to each element of the array and an accumulator value,
// returning the final accumulated value. This function is also sometimes
// called `foldl`, `fold_left`, `reduce`, or `inject`.
Expand Down