From 378c18eb42d75852b97f849d05c9e3f650601339 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:02:35 +0000 Subject: [PATCH] feat: Avoid unnecessary range checks by inspecting instructions for casts (#4039) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … # Description ## Problem\* Resolves ## Summary\* This PR removes unnecessary overflow checks in more scenarios than previously. We do this by inspecting whether the inputs are `Instruction::Cast`s which allows us to determine an more restrictive upper bound on the value and reason about how this affects the result. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/ssa/function_builder/mod.rs | 10 ++- compiler/noirc_evaluator/src/ssa/ir/dfg.rs | 19 +++++ .../src/ssa/ssa_gen/context.rs | 84 ++++++++++++++----- 3 files changed, 89 insertions(+), 24 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 2871f149b41..35782ea85ae 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -293,8 +293,9 @@ impl FunctionBuilder { if let Some(rhs_constant) = self.current_function.dfg.get_numeric_constant(rhs) { // Happy case is that we know precisely by how many bits the the integer will // increase: lhs_bit_size + rhs - let (rhs_bit_size_pow_2, overflows) = - 2_u128.overflowing_pow(rhs_constant.to_u128() as u32); + let bit_shift_size = rhs_constant.to_u128() as u32; + + let (rhs_bit_size_pow_2, overflows) = 2_u128.overflowing_pow(bit_shift_size); if overflows { assert!(bit_size < 128, "ICE - shift left with big integers are not supported"); if bit_size < 128 { @@ -303,7 +304,10 @@ impl FunctionBuilder { } } let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ); - (bit_size + (rhs_constant.to_u128() as u32), pow) + + let max_lhs_bits = self.current_function.dfg.get_value_max_num_bits(lhs); + + (max_lhs_bits + bit_shift_size, pow) } else { // we use a predicate to nullify the result in case of overflow let bit_size_var = diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 9942a48a38a..870b5e602f1 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -337,6 +337,25 @@ impl DataFlowGraph { self.values[value].get_type().clone() } + /// Returns the maximum possible number of bits that `value` can potentially be. + /// + /// Should `value` be a numeric constant then this function will return the exact number of bits required, + /// otherwise it will return the minimum number of bits based on type information. + pub(crate) fn get_value_max_num_bits(&self, value: ValueId) -> u32 { + match self[value] { + Value::Instruction { instruction, .. } => { + if let Instruction::Cast(original_value, _) = self[instruction] { + self.type_of_value(original_value).bit_size() + } else { + self.type_of_value(value).bit_size() + } + } + + Value::NumericConstant { constant, .. } => constant.num_bits(), + _ => self.type_of_value(value).bit_size(), + } + } + /// True if the type of this value is Type::Reference. /// Using this method over type_of_value avoids cloning the value's type. pub(crate) fn value_is_reference(&self, value: ValueId) -> bool { diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index b34b667c31a..f1a2154d3a8 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -335,29 +335,71 @@ impl<'a> FunctionContext<'a> { } } Type::Numeric(NumericType::Unsigned { bit_size }) => { - let op_name = match operator { - BinaryOpKind::Add => "add", - BinaryOpKind::Subtract => "subtract", - BinaryOpKind::Multiply => "multiply", - BinaryOpKind::ShiftLeft => "left shift", - _ => unreachable!("operator {} should not overflow", operator), - }; + let dfg = &self.builder.current_function.dfg; - if operator == BinaryOpKind::Multiply && bit_size == 1 { - result - } else if operator == BinaryOpKind::ShiftLeft - || operator == BinaryOpKind::ShiftRight - { - self.check_shift_overflow(result, rhs, bit_size, location, false) - } else { - let message = format!("attempt to {} with overflow", op_name); - self.builder.set_location(location).insert_range_check( - result, - bit_size, - Some(message), - ); - result + let max_lhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(lhs); + let max_rhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(rhs); + + match operator { + BinaryOpKind::Add => { + if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size { + // `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. + return result; + } + + let message = "attempt to add with overflow".to_string(); + self.builder.set_location(location).insert_range_check( + result, + bit_size, + Some(message), + ); + } + BinaryOpKind::Subtract => { + if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits { + // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` + // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. + return result; + } + + let message = "attempt to subtract with overflow".to_string(); + self.builder.set_location(location).insert_range_check( + result, + bit_size, + Some(message), + ); + } + BinaryOpKind::Multiply => { + if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size { + // Either performing boolean multiplication (which cannot overflow), + // or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. + return result; + } + + let message = "attempt to multiply with overflow".to_string(); + self.builder.set_location(location).insert_range_check( + result, + bit_size, + Some(message), + ); + } + BinaryOpKind::ShiftLeft => { + if let Some(rhs_const) = dfg.get_numeric_constant(rhs) { + let bit_shift_size = rhs_const.to_u128() as u32; + + if max_lhs_bits + bit_shift_size <= bit_size { + // `lhs` has been casted up from a smaller type such that shifting it by a constant + // `rhs` is known not to exceed the maximum bit size. + return result; + } + } + + self.check_shift_overflow(result, rhs, bit_size, location, false); + } + + _ => unreachable!("operator {} should not overflow", operator), } + + result } _ => result, }