diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr index 572e6603cc5..e5f4a2f6598 100644 --- a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr @@ -1,7 +1,7 @@ use dep::std; fn main() -> pub Field { - let f = if 3 * 7 > 200 { foo } else { bar }; + let f = if 3 * 7 > 200 as u32 { foo } else { bar }; assert(f()[1] == 2); // Lambdas: diff --git a/crates/noirc_frontend/src/ast/expression.rs b/crates/noirc_frontend/src/ast/expression.rs index 9e9ff2f592e..4d19d47d484 100644 --- a/crates/noirc_frontend/src/ast/expression.rs +++ b/crates/noirc_frontend/src/ast/expression.rs @@ -198,7 +198,7 @@ impl BinaryOpKind { /// Comparator operators return a 0 or 1 /// When seen in the middle of an infix operator, /// they transform the infix expression into a predicate expression - pub fn is_comparator(&self) -> bool { + pub fn is_comparator(self) -> bool { matches!( self, BinaryOpKind::Equal @@ -210,6 +210,10 @@ impl BinaryOpKind { ) } + pub fn is_valid_for_field_type(self) -> bool { + matches!(self, BinaryOpKind::Equal | BinaryOpKind::NotEqual) + } + pub fn as_string(self) -> &'static str { match self { BinaryOpKind::Add => "+", diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 7160d1f153d..031adaf75ec 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -526,35 +526,41 @@ impl<'interner> TypeChecker<'interner> { lhs_type: &Type, rhs_type: &Type, op: &HirBinaryOp, + span: Span, ) -> Result { use crate::BinaryOpKind::{Equal, NotEqual}; use Type::*; + let make_error = move |msg| TypeCheckError::Unstructured { msg, span }; + match (lhs_type, rhs_type) { // Avoid reporting errors multiple times (Error, _) | (_,Error) => Ok(Bool(CompTime::Yes(None))), // Matches on PolymorphicInteger and TypeVariable must be first to follow any type // bindings. - (PolymorphicInteger(comptime, int), other) - | (other, PolymorphicInteger(comptime, int)) => { + (var @ PolymorphicInteger(_, int), other) + | (other, var @ PolymorphicInteger(_, int)) + | (var @ TypeVariable(int), other) + | (other, var @ TypeVariable(int)) => { if let TypeBinding::Bound(binding) = &*int.borrow() { - return self.comparator_operand_type_rules(other, binding, op); + return self.comparator_operand_type_rules(other, binding, op, span); } - if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error { - Ok(Bool(comptime.clone())) - } else { - Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")) - } - } - (TypeVariable(var), other) - | (other, TypeVariable(var)) => { - if let TypeBinding::Bound(binding) = &*var.borrow() { - return self.comparator_operand_type_rules(binding, other, op); + + if !op.kind.is_valid_for_field_type() && (other.is_bindable() || other.is_field()) { + let other = other.follow_bindings(); + + self.push_delayed_type_check(Box::new(move || { + if other.is_field() || other.is_bindable() { + Err(make_error("Comparisons are invalid on Field types. Try casting the operands to a sized integer type first".into())) + } else { + Ok(()) + } + })); } - let comptime = CompTime::No(None); - if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error { - Ok(Bool(comptime)) + let comptime = var.try_get_comptime(); + if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error { + Ok(Bool(comptime.into_owned())) } else { Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")) } @@ -576,14 +582,11 @@ impl<'interner> TypeChecker<'interner> { Err(format!("Integer cannot be used with type {typ}")) } (FieldElement(comptime_x), FieldElement(comptime_y)) => { - match op.kind { - Equal | NotEqual => { - let comptime = comptime_x.and(comptime_y, op.location.span); - Ok(Bool(comptime)) - }, - _ => { - Err("Fields cannot be compared, try casting to an integer first".into()) - } + if op.kind.is_valid_for_field_type() { + let comptime = comptime_x.and(comptime_y, op.location.span); + Ok(Bool(comptime)) + } else { + Err("Fields cannot be compared, try casting to an integer first".into()) } } @@ -741,7 +744,9 @@ impl<'interner> TypeChecker<'interner> { let make_error = move |msg| TypeCheckError::Unstructured { msg, span }; if op.kind.is_comparator() { - return self.comparator_operand_type_rules(lhs_type, rhs_type, op).map_err(make_error); + return self + .comparator_operand_type_rules(lhs_type, rhs_type, op, span) + .map_err(make_error); } use Type::*; @@ -751,8 +756,10 @@ impl<'interner> TypeChecker<'interner> { // Matches on PolymorphicInteger and TypeVariable must be first so that we follow any type // bindings. - (PolymorphicInteger(comptime, int), other) - | (other, PolymorphicInteger(comptime, int)) => { + (var @ PolymorphicInteger(_, int), other) + | (other, var @ PolymorphicInteger(_, int)) + | (var @ TypeVariable(int), other) + | (other, var @ TypeVariable(int)) => { if let TypeBinding::Bound(binding) = &*int.borrow() { return self.infix_operand_type_rules(binding, op, other, span); } @@ -774,20 +781,8 @@ impl<'interner> TypeChecker<'interner> { })); } - if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error { - Ok(other.clone()) - } else { - Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))) - } - } - (TypeVariable(var), other) - | (other, TypeVariable(var)) => { - if let TypeBinding::Bound(binding) = &*var.borrow() { - return self.infix_operand_type_rules(binding, op, other, span); - } - - let comptime = CompTime::No(None); - if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error { + let comptime = var.try_get_comptime(); + if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error { Ok(other.clone()) } else { Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))) diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 133d8a79055..7769b0b1153 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, cell::RefCell, collections::{BTreeSet, HashMap}, rc::Rc, @@ -594,6 +595,16 @@ impl Type { Type::Vec(element) => element.contains_numeric_typevar(target_id), } } + + pub(crate) fn try_get_comptime(&self) -> Cow { + match self { + Type::FieldElement(comptime) + | Type::Integer(comptime, _, _) + | Type::Bool(comptime) + | Type::PolymorphicInteger(comptime, _) => Cow::Borrowed(comptime), + _ => Cow::Owned(CompTime::No(None)), + } + } } impl std::fmt::Display for Type { diff --git a/noir_stdlib/src/sha256.nr b/noir_stdlib/src/sha256.nr index cf72fec7266..63107442bf8 100644 --- a/noir_stdlib/src/sha256.nr +++ b/noir_stdlib/src/sha256.nr @@ -104,11 +104,11 @@ fn digest(msg: [u8; N]) -> [u8; 32] { let mut h: [u32; 8] = [1779033703,3144134277,1013904242,2773480762,1359893119,2600822924,528734635,1541459225]; // Intermediate hash, starting with the canonical initial value let mut c: [u32; 8] = [0; 8]; // Compression of current message block as sequence of u32 let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes - let mut i = 0; // Message byte pointer + let mut i: u64 = 0; // Message byte pointer for k in 0 .. msg.len() { // Populate msg_block - msg_block[i] = msg[k]; + msg_block[i as Field] = msg[k]; i = i + 1; if i == 64 { // Enough to hash block c = sha_c(msg_u8_to_u32(msg_block), h); @@ -122,7 +122,7 @@ fn digest(msg: [u8; N]) -> [u8; 32] { // Pad the rest such that we have a [u32; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i] = 1 << 7; + msg_block[i as Field] = 1 << 7; i = i + 1; // If i >= 57, there aren't enough bits in the current message block to accomplish this, so @@ -131,7 +131,7 @@ fn digest(msg: [u8; N]) -> [u8; 32] { if i < 64 { for _i in 57..64 { if i <= 63 { - msg_block[i] = 0; + msg_block[i as Field] = 0; i += 1; } } @@ -147,7 +147,7 @@ fn digest(msg: [u8; N]) -> [u8; 32] { for _i in 0..64 {// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). if i < 56 { - msg_block[i] = 0; + msg_block[i as Field] = 0; i = i + 1; } else if i < 64 { let mut len = 8 * msg.len() as u64; diff --git a/noir_stdlib/src/sha512.nr b/noir_stdlib/src/sha512.nr index 66a5cc5a169..1617af8e598 100644 --- a/noir_stdlib/src/sha512.nr +++ b/noir_stdlib/src/sha512.nr @@ -104,11 +104,11 @@ fn digest(msg: [u8; N]) -> [u8; 64] let mut h: [u64; 8] = [7640891576956012808, 13503953896175478587, 4354685564936845355, 11912009170470909681, 5840696475078001361, 11170449401992604703, 2270897969802886507, 6620516959819538809]; // Intermediate hash, starting with the canonical initial value let mut c: [u64; 8] = [0; 8]; // Compression of current message block as sequence of u64 let mut out_h: [u8; 64] = [0; 64]; // Digest as sequence of bytes - let mut i = 0; // Message byte pointer + let mut i: u64 = 0; // Message byte pointer for k in 0 .. msg.len() { // Populate msg_block - msg_block[i] = msg[k]; + msg_block[i as Field] = msg[k]; i = i + 1; if i == 128 { // Enough to hash block c = sha_c(msg_u8_to_u64(msg_block), h); @@ -122,7 +122,7 @@ fn digest(msg: [u8; N]) -> [u8; 64] // Pad the rest such that we have a [u64; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i] = 1 << 7; + msg_block[i as Field] = 1 << 7; i += 1; // If i >= 113, there aren't enough bits in the current message block to accomplish this, so @@ -131,7 +131,7 @@ fn digest(msg: [u8; N]) -> [u8; 64] if i < 128 { for _i in 113..128 { if i <= 127 { - msg_block[i] = 0; + msg_block[i as Field] = 0; i += 1; } } @@ -146,7 +146,7 @@ fn digest(msg: [u8; N]) -> [u8; 64] for _i in 0..128 {// In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112). if i < 112 { - msg_block[i] = 0; + msg_block[i as Field] = 0; i += 1; } else if i < 128 { let mut len = 8 * msg.len() as u64; // u128 unsupported