Skip to content

Commit

Permalink
fix: Nested array equality (#4903)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4383 (both examples in the issue)

## Summary\*

We were relying on the built in array equality for arrays in SSA
previously, but this did not recurse on the array element in the case of
nested arrays. I've just removed this since it is no longer needed. We
now use the impl Eq for arrays in the stdlib instead since the built in
version provided no speedup anyway.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <tom@tomfren.ch>
  • Loading branch information
jfecher and TomAFrench authored Apr 24, 2024
1 parent b380dc4 commit 0cf2e2a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 133 deletions.
108 changes: 5 additions & 103 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,22 +566,12 @@ impl<'a> FunctionContext<'a> {
mut rhs: ValueId,
location: Location,
) -> Values {
let result_type = self.builder.type_of_value(lhs);
let mut result = match operator {
BinaryOpKind::Equal | BinaryOpKind::NotEqual
if matches!(result_type, Type::Array(..)) =>
{
return self.insert_array_equality(lhs, operator, rhs, location)
}
_ => {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}

self.builder.set_location(location).insert_binary(lhs, op, rhs)
}
};
let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs);

// Check for integer overflow
if matches!(
Expand All @@ -600,94 +590,6 @@ impl<'a> FunctionContext<'a> {
result.into()
}

/// The frontend claims to support equality (==) on arrays, so we must support it in SSA here.
/// The actual BinaryOp::Eq in SSA is meant only for primitive numeric types so we encode an
/// entire equality loop on each array element. The generated IR is as follows:
///
/// ...
/// result_alloc = allocate
/// store u1 1 in result_alloc
/// jmp loop_start(0)
/// loop_start(i: Field):
/// v0 = lt i, array_len
/// jmpif v0, then: loop_body, else: loop_end
/// loop_body():
/// v1 = array_get lhs, index i
/// v2 = array_get rhs, index i
/// v3 = eq v1, v2
/// v4 = load result_alloc
/// v5 = and v4, v3
/// store v5 in result_alloc
/// v6 = add i, Field 1
/// jmp loop_start(v6)
/// loop_end():
/// result = load result_alloc
fn insert_array_equality(
&mut self,
lhs: ValueId,
operator: BinaryOpKind,
rhs: ValueId,
location: Location,
) -> Values {
let lhs_type = self.builder.type_of_value(lhs);
let rhs_type = self.builder.type_of_value(rhs);

let (array_length, element_type) = match (lhs_type, rhs_type) {
(
Type::Array(lhs_composite_type, lhs_length),
Type::Array(rhs_composite_type, rhs_length),
) => {
assert!(
lhs_composite_type.len() == 1 && rhs_composite_type.len() == 1,
"== is unimplemented for arrays of structs"
);
assert_eq!(lhs_composite_type[0], rhs_composite_type[0]);
assert_eq!(lhs_length, rhs_length, "Expected two arrays of equal length");
(lhs_length, lhs_composite_type[0].clone())
}
_ => unreachable!("Expected two array values"),
};

let loop_start = self.builder.insert_block();
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

// pre-loop
let result_alloc = self.builder.set_location(location).insert_allocate(Type::bool());
let true_value = self.builder.numeric_constant(1u128, Type::bool());
self.builder.insert_store(result_alloc, true_value);
let zero = self.builder.length_constant(0u128);
self.builder.terminate_with_jmp(loop_start, vec![zero]);

// loop_start
self.builder.switch_to_block(loop_start);
let i = self.builder.add_block_parameter(loop_start, Type::length_type());
let array_length = self.builder.length_constant(array_length as u128);
let v0 = self.builder.insert_binary(i, BinaryOp::Lt, array_length);
self.builder.terminate_with_jmpif(v0, loop_body, loop_end);

// loop body
self.builder.switch_to_block(loop_body);
let v1 = self.builder.insert_array_get(lhs, i, element_type.clone());
let v2 = self.builder.insert_array_get(rhs, i, element_type);
let v3 = self.builder.insert_binary(v1, BinaryOp::Eq, v2);
let v4 = self.builder.insert_load(result_alloc, Type::bool());
let v5 = self.builder.insert_binary(v4, BinaryOp::And, v3);
self.builder.insert_store(result_alloc, v5);
let one = self.builder.length_constant(1u128);
let v6 = self.builder.insert_binary(i, BinaryOp::Add, one);
self.builder.terminate_with_jmp(loop_start, vec![v6]);

// loop end
self.builder.switch_to_block(loop_end);
let mut result = self.builder.insert_load(result_alloc, Type::bool());

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}

/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
Expand Down
30 changes: 0 additions & 30 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -890,36 +890,6 @@ impl<'interner> TypeChecker<'interner> {
// <= and friends are technically valid for booleans, just not very useful
(Bool, Bool) => Ok((Bool, false)),

// Special-case == and != for arrays
(Array(x_size, x_type), Array(y_size, y_type))
if matches!(op.kind, BinaryOpKind::Equal | BinaryOpKind::NotEqual) =>
{
self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::ArrayLen,
span: op.location.span,
});

let (_, use_impl) = self.comparator_operand_type_rules(x_type, y_type, op, span)?;

// If the size is not constant, we must fall back to a user-provided impl for
// equality on slices.
let size = x_size.follow_bindings();
let use_impl = use_impl || size.evaluate_to_u64().is_none();
Ok((Bool, use_impl))
}

(String(x_size), String(y_size)) => {
self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: *x_size.clone(),
actual: *y_size.clone(),
span: op.location.span,
source: Source::StringLen,
});

Ok((Bool, false))
}
(lhs, rhs) => {
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
Expand Down
7 changes: 7 additions & 0 deletions test_programs/execution_success/regression_4383/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_4383"
type = "bin"
authors = [""]
compiler_version = ">=0.26.0"

[dependencies]
3 changes: 3 additions & 0 deletions test_programs/execution_success/regression_4383/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn main() {
assert([[1]] == [[1]]);
}

0 comments on commit 0cf2e2a

Please sign in to comment.