Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent comparisons from being used on Fields #1860

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use dep::std;

fn main() -> pub Field {
let f = if 3 * 7 > 200 { foo } else { bar };
let f = if 3 * 7 > 200 as u32 { foo } else { bar };
assert(f()[1] == 2);

// Lambdas:
Expand Down
6 changes: 5 additions & 1 deletion crates/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl BinaryOpKind {
/// Comparator operators return a 0 or 1
/// When seen in the middle of an infix operator,
/// they transform the infix expression into a predicate expression
pub fn is_comparator(&self) -> bool {
pub fn is_comparator(self) -> bool {
matches!(
self,
BinaryOpKind::Equal
Expand All @@ -210,6 +210,10 @@ impl BinaryOpKind {
)
}

pub fn is_valid_for_field_type(self) -> bool {
matches!(self, BinaryOpKind::Equal | BinaryOpKind::NotEqual)
}

pub fn as_string(self) -> &'static str {
match self {
BinaryOpKind::Add => "+",
Expand Down
77 changes: 36 additions & 41 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,41 @@ impl<'interner> TypeChecker<'interner> {
lhs_type: &Type,
rhs_type: &Type,
op: &HirBinaryOp,
span: Span,
) -> Result<Type, String> {
use crate::BinaryOpKind::{Equal, NotEqual};
use Type::*;
let make_error = move |msg| TypeCheckError::Unstructured { msg, span };

match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_,Error) => Ok(Bool(CompTime::Yes(None))),

// Matches on PolymorphicInteger and TypeVariable must be first to follow any type
// bindings.
(PolymorphicInteger(comptime, int), other)
| (other, PolymorphicInteger(comptime, int)) => {
(var @ PolymorphicInteger(_, int), other)
| (other, var @ PolymorphicInteger(_, int))
| (var @ TypeVariable(int), other)
| (other, var @ TypeVariable(int)) => {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.comparator_operand_type_rules(other, binding, op);
return self.comparator_operand_type_rules(other, binding, op, span);
}
if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime.clone()))
} else {
Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))
}
}
(TypeVariable(var), other)
| (other, TypeVariable(var)) => {
if let TypeBinding::Bound(binding) = &*var.borrow() {
return self.comparator_operand_type_rules(binding, other, op);

if !op.kind.is_valid_for_field_type() && (other.is_bindable() || other.is_field()) {
let other = other.follow_bindings();

self.push_delayed_type_check(Box::new(move || {
if other.is_field() || other.is_bindable() {
Err(make_error("Comparisons are invalid on Field types. Try casting the operands to a sized integer type first".into()))
} else {
Ok(())
}
}));
}

let comptime = CompTime::No(None);
if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime))
let comptime = var.try_get_comptime();
if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(Bool(comptime.into_owned()))
} else {
Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))
}
Expand All @@ -576,14 +582,11 @@ impl<'interner> TypeChecker<'interner> {
Err(format!("Integer cannot be used with type {typ}"))
}
(FieldElement(comptime_x), FieldElement(comptime_y)) => {
match op.kind {
Equal | NotEqual => {
let comptime = comptime_x.and(comptime_y, op.location.span);
Ok(Bool(comptime))
},
_ => {
Err("Fields cannot be compared, try casting to an integer first".into())
}
if op.kind.is_valid_for_field_type() {
let comptime = comptime_x.and(comptime_y, op.location.span);
Ok(Bool(comptime))
} else {
Err("Fields cannot be compared, try casting to an integer first".into())
}
}

Expand Down Expand Up @@ -741,7 +744,9 @@ impl<'interner> TypeChecker<'interner> {
let make_error = move |msg| TypeCheckError::Unstructured { msg, span };

if op.kind.is_comparator() {
return self.comparator_operand_type_rules(lhs_type, rhs_type, op).map_err(make_error);
return self
.comparator_operand_type_rules(lhs_type, rhs_type, op, span)
.map_err(make_error);
}

use Type::*;
Expand All @@ -751,8 +756,10 @@ impl<'interner> TypeChecker<'interner> {

// Matches on PolymorphicInteger and TypeVariable must be first so that we follow any type
// bindings.
(PolymorphicInteger(comptime, int), other)
| (other, PolymorphicInteger(comptime, int)) => {
(var @ PolymorphicInteger(_, int), other)
| (other, var @ PolymorphicInteger(_, int))
| (var @ TypeVariable(int), other)
| (other, var @ TypeVariable(int)) => {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}
Expand All @@ -774,20 +781,8 @@ impl<'interner> TypeChecker<'interner> {
}));
}

if other.try_bind_to_polymorphic_int(int, comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(other.clone())
} else {
Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")))
}
}
(TypeVariable(var), other)
| (other, TypeVariable(var)) => {
if let TypeBinding::Bound(binding) = &*var.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}

let comptime = CompTime::No(None);
if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
let comptime = var.try_get_comptime();
if other.try_bind_to_polymorphic_int(int, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(other.clone())
} else {
Err(make_error(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}")))
Expand Down
11 changes: 11 additions & 0 deletions crates/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
cell::RefCell,
collections::{BTreeSet, HashMap},
rc::Rc,
Expand Down Expand Up @@ -594,6 +595,16 @@ impl Type {
Type::Vec(element) => element.contains_numeric_typevar(target_id),
}
}

pub(crate) fn try_get_comptime(&self) -> Cow<CompTime> {
match self {
Type::FieldElement(comptime)
| Type::Integer(comptime, _, _)
| Type::Bool(comptime)
| Type::PolymorphicInteger(comptime, _) => Cow::Borrowed(comptime),
_ => Cow::Owned(CompTime::No(None)),
}
}
}

impl std::fmt::Display for Type {
Expand Down
10 changes: 5 additions & 5 deletions noir_stdlib/src/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {
let mut h: [u32; 8] = [1779033703,3144134277,1013904242,2773480762,1359893119,2600822924,528734635,1541459225]; // Intermediate hash, starting with the canonical initial value
let mut c: [u32; 8] = [0; 8]; // Compression of current message block as sequence of u32
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes
let mut i = 0; // Message byte pointer
let mut i: u64 = 0; // Message byte pointer

for k in 0 .. msg.len() {
// Populate msg_block
msg_block[i] = msg[k];
msg_block[i as Field] = msg[k];
i = i + 1;
if i == 64 { // Enough to hash block
c = sha_c(msg_u8_to_u32(msg_block), h);
Expand All @@ -122,7 +122,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {

// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
msg_block[i as Field] = 1 << 7;
i = i + 1;

// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
Expand All @@ -131,7 +131,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {
if i < 64 {
for _i in 57..64 {
if i <= 63 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
}
}
Expand All @@ -147,7 +147,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 32] {

for _i in 0..64 {// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56).
if i < 56 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i = i + 1;
} else if i < 64 {
let mut len = 8 * msg.len() as u64;
Expand Down
10 changes: 5 additions & 5 deletions noir_stdlib/src/sha512.nr
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]
let mut h: [u64; 8] = [7640891576956012808, 13503953896175478587, 4354685564936845355, 11912009170470909681, 5840696475078001361, 11170449401992604703, 2270897969802886507, 6620516959819538809]; // Intermediate hash, starting with the canonical initial value
let mut c: [u64; 8] = [0; 8]; // Compression of current message block as sequence of u64
let mut out_h: [u8; 64] = [0; 64]; // Digest as sequence of bytes
let mut i = 0; // Message byte pointer
let mut i: u64 = 0; // Message byte pointer

for k in 0 .. msg.len() {
// Populate msg_block
msg_block[i] = msg[k];
msg_block[i as Field] = msg[k];
i = i + 1;
if i == 128 { // Enough to hash block
c = sha_c(msg_u8_to_u64(msg_block), h);
Expand All @@ -122,7 +122,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]

// Pad the rest such that we have a [u64; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
msg_block[i as Field] = 1 << 7;
i += 1;

// If i >= 113, there aren't enough bits in the current message block to accomplish this, so
Expand All @@ -131,7 +131,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]
if i < 128 {
for _i in 113..128 {
if i <= 127 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
}
}
Expand All @@ -146,7 +146,7 @@ fn digest<N>(msg: [u8; N]) -> [u8; 64]

for _i in 0..128 {// In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112).
if i < 112 {
msg_block[i] = 0;
msg_block[i as Field] = 0;
i += 1;
} else if i < 128 {
let mut len = 8 * msg.len() as u64; // u128 unsupported
Expand Down