Skip to content

Commit

Permalink
fix: Prevent comparisons from being used on Fields (#1860)
Browse files Browse the repository at this point in the history
* Prevent comparisons from being used on fields

* Fix higher order functions test
  • Loading branch information
jfecher authored Jul 5, 2023
1 parent e55b5a8 commit c8858fd
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use dep::std;

fn main() -> pub Field {
let f = if 3 * 7 > 200 { foo } else { bar };
let f = if 3 * 7 > 200 as u32 { foo } else { bar };
assert(f()[1] == 2);

// Lambdas:
Expand Down
6 changes: 5 additions & 1 deletion crates/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl BinaryOpKind {
/// Comparator operators return a 0 or 1
/// When seen in the middle of an infix operator,
/// they transform the infix expression into a predicate expression
pub fn is_comparator(&self) -> bool {
pub fn is_comparator(self) -> bool {
matches!(
self,
BinaryOpKind::Equal
Expand All @@ -210,6 +210,10 @@ impl BinaryOpKind {
)
}

pub fn is_valid_for_field_type(self) -> bool {
matches!(self, BinaryOpKind::Equal | BinaryOpKind::NotEqual)
}

pub fn as_string(self) -> &'static str {
match self {
BinaryOpKind::Add => "+",
Expand Down
77 changes: 36 additions & 41 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,41 @@ impl<'interner> TypeChecker<'interner> {
lhs_type: &Type,
rhs_type: &Type,
op: &HirBinaryOp,
span: Span,
) -> Result<Type, String> {
use crate::BinaryOpKind::{Equal, NotEqual};
use Type::*;
let make_error = move |msg| TypeCheckError::Unstructured { msg, span };

match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_,Error) => Ok(Bool(CompTime::Yes(None))),

// Matches on PolymorphicInteger and TypeVariable must be first to follow any type
// bindings.
(PolymorphicInteger(comptime, int), other)
| (other, PolymorphicInteger(comptime, int)) => {
(var @ PolymorphicInteger(_, int), other)
| (other, var @ PolymorphicInteger(_, int))
| (var @ TypeVariable(int), other)
| (other, var @ TypeVariable(int)) => {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.comparator_operand_type_rules(other, binding, op);
return self.comparator_operand_type_rules(other, binding, op, span);
}
if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime.clone()))
} else {
Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))
}
}
(TypeVariable(var), other)
| (other, TypeVariable(var)) => {
if let TypeBinding::Bound(binding) = &*var.borrow() {
return self.comparator_operand_type_rules(binding, other, op);

if !op.kind.is_valid_for_field_type() && (other.is_bindable() || other.is_field()) {
let other = other.follow_bindings();

self.push_delayed_type_check(Box::new(move || {
if other.is_field() || other.is_bindable() {
Err(make_error("Comparisons are invalid on Field types. Try casting the operands to a sized integer type first".into()))
} else {
Ok(())
}
}));
}

let comptime = CompTime::No(None);
if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime))
let comptime = var.try_get_comptime();
if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime.into_owned()))
} else {
Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))
}
Expand All @@ -576,14 +582,11 @@ impl<'interner> TypeChecker<'interner> {
Err(format!("Integer cannot be used with type {typ}"))
}
(FieldElement(comptime_x), FieldElement(comptime_y)) => {
match op.kind {
Equal | NotEqual => {
let comptime = comptime_x.and(comptime_y, op.location.span);
Ok(Bool(comptime))
},
_ => {
Err("Fields cannot be compared, try casting to an integer first".into())
}
if op.kind.is_valid_for_field_type() {
let comptime = comptime_x.and(comptime_y, op.location.span);
Ok(Bool(comptime))
} else {
Err("Fields cannot be compared, try casting to an integer first".into())
}
}

Expand Down Expand Up @@ -741,7 +744,9 @@ impl<'interner> TypeChecker<'interner> {
let make_error = move |msg| TypeCheckError::Unstructured { msg, span };

if op.kind.is_comparator() {
return self.comparator_operand_type_rules(lhs_type, rhs_type, op).map_err(make_error);
return self
.comparator_operand_type_rules(lhs_type, rhs_type, op, span)
.map_err(make_error);
}

use Type::*;
Expand All @@ -751,8 +756,10 @@ impl<'interner> TypeChecker<'interner> {

// Matches on PolymorphicInteger and TypeVariable must be first so that we follow any type
// bindings.
(PolymorphicInteger(comptime, int), other)
| (other, PolymorphicInteger(comptime, int)) => {
(var @ PolymorphicInteger(_, int), other)
| (other, var @ PolymorphicInteger(_, int))
| (var @ TypeVariable(int), other)
| (other, var @ TypeVariable(int)) => {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}
Expand All @@ -774,20 +781,8 @@ impl<'interner> TypeChecker<'interner> {
}));
}

if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(other.clone())
} else {
Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")))
}
}
(TypeVariable(var), other)
| (other, TypeVariable(var)) => {
if let TypeBinding::Bound(binding) = &*var.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}

let comptime = CompTime::No(None);
if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
let comptime = var.try_get_comptime();
if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(other.clone())
} else {
Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")))
Expand Down
11 changes: 11 additions & 0 deletions crates/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
cell::RefCell,
collections::{BTreeSet, HashMap},
rc::Rc,
Expand Down Expand Up @@ -594,6 +595,16 @@ impl Type {
Type::Vec(element) => element.contains_numeric_typevar(target_id),
}
}

pub(crate) fn try_get_comptime(&self) -> Cow<CompTime> {
match self {
Type::FieldElement(comptime)
| Type::Integer(comptime, _, _)
| Type::Bool(comptime)
| Type::PolymorphicInteger(comptime, _) => Cow::Borrowed(comptime),
_ => Cow::Owned(CompTime::No(None)),
}
}
}

impl std::fmt::Display for Type {
Expand Down
10 changes: 5 additions & 5 deletions noir_stdlib/src/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {
let mut h: [u32; 8] = [1779033703,3144134277,1013904242,2773480762,1359893119,2600822924,528734635,1541459225]; // Intermediate hash, starting with the canonical initial value
let mut c: [u32; 8] = [0; 8]; // Compression of current message block as sequence of u32
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes
let mut i = 0; // Message byte pointer
let mut i: u64 = 0; // Message byte pointer

for k in 0 .. msg.len() {
// Populate msg_block
msg_block[i] = msg[k];
msg_block[i as Field] = msg[k];
i = i + 1;
if i == 64 { // Enough to hash block
c = sha_c(msg_u8_to_u32(msg_block), h);
Expand All @@ -122,7 +122,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {

// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
msg_block[i as Field] = 1 << 7;
i = i + 1;

// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
Expand All @@ -131,7 +131,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {
if i < 64 {
for _i in 57..64 {
if i <= 63 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
}
}
Expand All @@ -147,7 +147,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {

for _i in 0..64 {// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56).
if i < 56 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i = i + 1;
} else if i < 64 {
let mut len = 8 * msg.len() as u64;
Expand Down
10 changes: 5 additions & 5 deletions noir_stdlib/src/sha512.nr
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]
let mut h: [u64; 8] = [7640891576956012808, 13503953896175478587, 4354685564936845355, 11912009170470909681, 5840696475078001361, 11170449401992604703, 2270897969802886507, 6620516959819538809]; // Intermediate hash, starting with the canonical initial value
let mut c: [u64; 8] = [0; 8]; // Compression of current message block as sequence of u64
let mut out_h: [u8; 64] = [0; 64]; // Digest as sequence of bytes
let mut i = 0; // Message byte pointer
let mut i: u64 = 0; // Message byte pointer

for k in 0 .. msg.len() {
// Populate msg_block
msg_block[i] = msg[k];
msg_block[i as Field] = msg[k];
i = i + 1;
if i == 128 { // Enough to hash block
c = sha_c(msg_u8_to_u64(msg_block), h);
Expand All @@ -122,7 +122,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]

// Pad the rest such that we have a [u64; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
msg_block[i as Field] = 1 << 7;
i += 1;

// If i >= 113, there aren't enough bits in the current message block to accomplish this, so
Expand All @@ -131,7 +131,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]
if i < 128 {
for _i in 113..128 {
if i <= 127 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
}
}
Expand All @@ -146,7 +146,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]

for _i in 0..128 {// In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112).
if i < 112 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
} else if i < 128 {
let mut len = 8 * msg.len() as u64; // u128 unsupported
Expand Down

0 comments on commit c8858fd

Please sign in to comment.