Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement numeric generics #620

Merged
merged 19 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions crates/nargo/tests/test_data/global_consts/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod baz;
global M: Field = 32;
global L: Field = 10; // Unused globals currently allowed
global N: Field = 5;
global T_LEN = 2; // Type inference is allowed on globals
//global N: Field = 5; // Uncomment to see duplicate globals error

struct Dummy {
Expand Down Expand Up @@ -40,9 +41,7 @@ fn main(a: [Field; M + N - N], b: [Field; 30 + N / 2], c : pub [Field; foo::MAGI

arrays_neq(a, b);

//let mut L: Field = 2; // Uncomment to show expected comptime error for array annotations
let L: comptime Field = 2;
let t: [Field; L] = [N, M];
let t: [Field; T_LEN] = [N, M];
constrain t[1] == 32;

constrain 15 == mysubmodule::my_helper();
Expand Down
7 changes: 7 additions & 0 deletions crates/nargo/tests/test_data/numeric_generics/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

[package]
authors = [""]
compiler_version = "0.1"

[dependencies]

Empty file.
28 changes: 28 additions & 0 deletions crates/nargo/tests/test_data/numeric_generics/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
fn main() {
let a = id([1, 2]);
let b = id([1, 2, 3]);

let itWorks1 = MyStruct { data: a };
constrain itWorks1.data[1] == 2;
let itWorks2 = MyStruct { data: b };
constrain itWorks2.data[1] == 2;

let c = [1, 2];
let itAlsoWorks = MyStruct { data: c };
constrain itAlsoWorks.data[1] == 2;

constrain foo(itWorks2).data[0] == itWorks2.data[0] + 1;
}

fn id<I>(x: [Field; I]) -> [Field; I] {
x
}

struct MyStruct<S> {
data: [Field; S]
}

fn foo(mut s: MyStruct<2+1>) -> MyStruct<10/2-2> {
s.data[0] = s.data[0] + 1;
s
}
100 changes: 96 additions & 4 deletions crates/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use noirc_errors::Span;
pub use statement::*;
pub use structure::*;

use crate::{token::IntType, Comptime};
use crate::{parser::ParserError, token::IntType, BinaryTypeOperator, Comptime};
use iter_extended::vecmap;

/// The parser parses types as 'UnresolvedType's which
Expand All @@ -21,10 +21,11 @@ use iter_extended::vecmap;
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum UnresolvedType {
FieldElement(Comptime),
Array(Option<Expression>, Box<UnresolvedType>), // [4]Witness = Array(4, Witness)
Integer(Comptime, Signedness, u32), // u32 = Integer(unsigned, 32)
Array(Option<UnresolvedTypeExpression>, Box<UnresolvedType>), // [4]Witness = Array(4, Witness)
Integer(Comptime, Signedness, u32), // u32 = Integer(unsigned, 32)
Bool(Comptime),
String(Option<Expression>),
Expression(UnresolvedTypeExpression),
String(Option<UnresolvedTypeExpression>),
Unit,

/// A Named UnresolvedType can be a struct type or a type variable
Expand All @@ -39,6 +40,21 @@ pub enum UnresolvedType {
Error,
}

/// The precursor to TypeExpression, this is the type that the parser allows
/// to be used in the length position of an array type. Only constants, variables,
/// and numeric binary operators are allowed here.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum UnresolvedTypeExpression {
Variable(Path),
Constant(u64, Span),
BinaryOperation(
Box<UnresolvedTypeExpression>,
BinaryTypeOperator,
Box<UnresolvedTypeExpression>,
Span,
),
}

impl Recoverable for UnresolvedType {
fn error(_: Span) -> Self {
UnresolvedType::Error
Expand Down Expand Up @@ -70,6 +86,7 @@ impl std::fmt::Display for UnresolvedType {
let elements = vecmap(elements, ToString::to_string);
write!(f, "({})", elements.join(", "))
}
Expression(expression) => expression.fmt(f),
Bool(is_const) => write!(f, "{is_const}bool"),
String(len) => match len {
None => write!(f, "str[]"),
Expand All @@ -86,6 +103,18 @@ impl std::fmt::Display for UnresolvedType {
}
}

impl std::fmt::Display for UnresolvedTypeExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
UnresolvedTypeExpression::Variable(name) => name.fmt(f),
UnresolvedTypeExpression::Constant(x, _) => x.fmt(f),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => {
write!(f, "({lhs} {op} {rhs})")
}
}
}
}

impl UnresolvedType {
pub fn from_int_token(token: (Comptime, IntType)) -> UnresolvedType {
use {IntType::*, UnresolvedType::Integer};
Expand All @@ -101,3 +130,66 @@ pub enum Signedness {
Unsigned,
Signed,
}

impl UnresolvedTypeExpression {
pub fn from_expr(
expr: Expression,
span: Span,
) -> Result<UnresolvedTypeExpression, ParserError> {
Self::from_expr_helper(expr).map_err(|err| {
ParserError::with_reason(
format!("Expression is invalid in an array-length type: '{err}'. Only unsigned integer constants, globals, generics, +, -, *, /, and % may be used in this context."),
span,
)
})
}

pub fn span(&self) -> Span {
match self {
UnresolvedTypeExpression::Variable(path) => path.span(),
UnresolvedTypeExpression::Constant(_, span) => *span,
UnresolvedTypeExpression::BinaryOperation(_, _, _, span) => *span,
}
}

fn from_expr_helper(expr: Expression) -> Result<UnresolvedTypeExpression, Expression> {
match expr.kind {
ExpressionKind::Literal(Literal::Integer(int)) => match int.try_to_u64() {
Some(int) => Ok(UnresolvedTypeExpression::Constant(int, expr.span)),
None => Err(expr),
},
ExpressionKind::Variable(path) => Ok(UnresolvedTypeExpression::Variable(path)),
ExpressionKind::Prefix(prefix) if prefix.operator == UnaryOp::Minus => {
let lhs = Box::new(UnresolvedTypeExpression::Constant(0, expr.span));
let rhs = Box::new(UnresolvedTypeExpression::from_expr_helper(prefix.rhs)?);
let op = BinaryTypeOperator::Subtraction;
Ok(UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, expr.span))
}
ExpressionKind::Infix(infix) if Self::operator_allowed(infix.operator.contents) => {
let lhs = Box::new(UnresolvedTypeExpression::from_expr_helper(infix.lhs)?);
let rhs = Box::new(UnresolvedTypeExpression::from_expr_helper(infix.rhs)?);
let op = match infix.operator.contents {
BinaryOpKind::Add => BinaryTypeOperator::Addition,
BinaryOpKind::Subtract => BinaryTypeOperator::Subtraction,
BinaryOpKind::Multiply => BinaryTypeOperator::Multiplication,
BinaryOpKind::Divide => BinaryTypeOperator::Division,
BinaryOpKind::Modulo => BinaryTypeOperator::Modulo,
_ => unreachable!(), // impossible via operator_allowed check
};
Ok(UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, expr.span))
}
_ => Err(expr),
}
}

fn operator_allowed(op: BinaryOpKind) -> bool {
matches!(
op,
BinaryOpKind::Add
| BinaryOpKind::Subtract
| BinaryOpKind::Multiply
| BinaryOpKind::Divide
| BinaryOpKind::Modulo
)
}
}
7 changes: 7 additions & 0 deletions crates/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ pub enum ResolverError {
InvalidArrayLengthExpr { span: Span },
#[error("Integer too large to be evaluated in an array length context")]
IntegerTooLarge { span: Span },
#[error("No global or generic type parameter found with the given name")]
NoSuchNumericTypeVariable { path: crate::Path },
#[error("Closures cannot capture mutable variables")]
CapturedMutableVariable { span: Span },
}
Expand Down Expand Up @@ -189,6 +191,11 @@ impl ResolverError {
"Array-lengths may be a maximum size of usize::MAX, including intermediate calculations".into(),
span,
),
ResolverError::NoSuchNumericTypeVariable { path } => Diagnostic::simple_error(
format!("Cannot find a global or generic type parameter named `{path}`"),
"Only globals or generic type parameters are allowed to be used as an array type's length".to_string(),
path.span(),
),
ResolverError::CapturedMutableVariable { span } => Diagnostic::simple_error(
"Closures cannot capture mutable variables".into(),
"Mutable variable".into(),
Expand Down
82 changes: 71 additions & 11 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
};
use crate::{
ArrayLiteral, Generics, LValue, NoirStruct, Path, Pattern, Shared, StructType, Type,
TypeBinding, TypeVariable, UnresolvedType, ERROR_IDENT,
TypeBinding, TypeVariable, UnresolvedType, UnresolvedTypeExpression, ERROR_IDENT,
};
use fm::FileId;
use iter_extended::vecmap;
Expand Down Expand Up @@ -302,6 +302,7 @@ impl<'a> Resolver<'a> {
let elem = Box::new(self.resolve_type_inner(*elem, new_variables));
Type::Array(Box::new(resolved_size), elem)
}
UnresolvedType::Expression(expr) => self.convert_expression_type(expr),
UnresolvedType::Integer(comptime, sign, bits) => Type::Integer(comptime, sign, bits),
UnresolvedType::Bool(comptime) => Type::Bool(comptime),
UnresolvedType::String(size) => {
Expand All @@ -313,7 +314,7 @@ impl<'a> Resolver<'a> {
UnresolvedType::Error => Type::Error,
UnresolvedType::Named(path, args) => {
// Check if the path is a type variable first. We currently disallow generics on type
// variables since this is what rust does.
// variables since we do not support higher-kinded types.
if args.is_empty() && path.segments.len() == 1 {
let name = &path.last_segment().0.contents;
if let Some((name, (var, _))) = self.generics.get_key_value(name) {
Expand Down Expand Up @@ -342,10 +343,10 @@ impl<'a> Resolver<'a> {

fn resolve_array_size(
&mut self,
size: Option<Expression>,
length: Option<UnresolvedTypeExpression>,
new_variables: &mut Generics,
) -> Type {
match &size {
match length {
None => {
let id = self.interner.next_type_variable_id();
let typevar = Shared::new(TypeBinding::Unbound(id));
Expand All @@ -356,9 +357,47 @@ impl<'a> Resolver<'a> {
// require users to explicitly be generic over array lengths.
Type::NamedGeneric(typevar, Rc::new("".into()))
}
Some(expr) => {
let len = self.eval_array_length(expr);
Type::ArrayLength(len)
Some(length) => self.convert_expression_type(length),
}
}

fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type {
match length {
UnresolvedTypeExpression::Variable(path) => {
if path.segments.len() == 1 {
let name = &path.last_segment().0.contents;
if let Some((name, (var, _))) = self.generics.get_key_value(name) {
return Type::NamedGeneric(var.clone(), name.clone());
}
}

// If we cannot find a local generic of the same name, try to look up a global
if let Ok(ModuleDefId::GlobalId(id)) =
self.path_resolver.resolve(self.def_maps, path.clone())
{
Type::Constant(self.eval_global_as_array_length(id))
} else {
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
Type::Constant(0)
}
}
UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => {
let (lhs_span, rhs_span) = (lhs.span(), rhs.span());
let lhs = self.convert_expression_type(*lhs);
let rhs = self.convert_expression_type(*rhs);

match (lhs, rhs) {
(Type::Constant(lhs), Type::Constant(rhs)) => {
Type::Constant(op.function()(lhs, rhs))
}
(lhs, _) => {
let span =
if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span };
self.push_err(ResolverError::InvalidArrayLengthExpr { span });
Type::Constant(0)
}
}
}
}
}
Expand Down Expand Up @@ -914,11 +953,32 @@ impl<'a> Resolver<'a> {
}

fn eval_array_length(&mut self, length: &Expression) -> u64 {
match self.try_eval_array_length(length).map(|length| length.try_into()) {
Ok(Ok(length_value)) => return length_value,
Ok(Err(_cast_err)) => {
self.push_err(ResolverError::IntegerTooLarge { span: length.span })
let result = self.try_eval_array_length(length);
self.unwrap_array_length_eval_result(result, length.span)
}

fn eval_global_as_array_length(&mut self, global: StmtId) -> u64 {
let stmt = match self.interner.statement(&global) {
HirStatement::Let(let_expr) => let_expr,
other => {
unreachable!("Expected global while evaluating array length, found {:?}", other)
}
};

let length = stmt.expression;
let span = self.interner.expr_span(&length);
let result = self.try_eval_array_length_id(length, span);
self.unwrap_array_length_eval_result(result, span)
}

fn unwrap_array_length_eval_result(
&mut self,
result: Result<u128, Option<ResolverError>>,
span: Span,
) -> u64 {
match result.map(|length| length.try_into()) {
Ok(Ok(length_value)) => return length_value,
Ok(Err(_cast_err)) => self.push_err(ResolverError::IntegerTooLarge { span }),
Err(Some(error)) => self.push_err(error),
Err(None) => (),
}
Expand Down
10 changes: 4 additions & 6 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@ pub(crate) fn type_check_expression(
HirExpression::Literal(literal) => {
match literal {
HirLiteral::Array(arr) => {
// Type check the contents of the array
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
let elem_types =
vecmap(&arr, |arg| type_check_expression(interner, arg, errors));

let first_elem_type = elem_types.get(0).cloned().unwrap_or(Type::Error);

// Specify the type of the Array
// Note: This assumes that the array is homogeneous, which will be checked next
let arr_type = Type::Array(
Box::new(Type::ArrayLength(arr.len() as u64)),
Box::new(Type::Constant(arr.len() as u64)),
Box::new(first_elem_type.clone()),
);

Expand Down Expand Up @@ -72,7 +69,8 @@ pub(crate) fn type_check_expression(
)
}
HirLiteral::Str(string) => {
Type::String(Box::new(Type::ArrayLength(string.len() as u64)))
let len = Type::Constant(string.len() as u64);
Type::String(Box::new(len))
}
}
}
Expand Down Expand Up @@ -783,7 +781,7 @@ pub fn comparator_operand_type_rules(

x_size.unify(y_size, op.location.span, errors, || {
TypeCheckError::Unstructured {
msg: format!("Can only compare arrays of the same length. Here LHS is of length {x_size}, and RHS is {y_size} "),
msg: format!("Can only compare arrays of the same length. Here LHS is of length {x_size}, and RHS is {y_size}"),
span: op.location.span,
}
});
Expand Down
Loading