diff --git a/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Nargo.toml b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Nargo.toml new file mode 100644 index 00000000000..661f4f937d5 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "bit_shifts_runtime" +authors = [""] +compiler_version = "0.1" + +[dependencies] diff --git a/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Prover.toml b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Prover.toml new file mode 100644 index 00000000000..98d8630792e --- /dev/null +++ b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/Prover.toml @@ -0,0 +1,2 @@ +x = 64 +y = 1 \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/bit_shifts_runtime/src/main.nr b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/src/main.nr new file mode 100644 index 00000000000..271a1ecb880 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/bit_shifts_runtime/src/main.nr @@ -0,0 +1,9 @@ +fn main(x: u64, y: u64) { + // runtime shifts on comptime values + assert(64 << y == 128); + assert(64 >> y == 32); + + // runtime shifts on runtime values + assert(x << y == 128); + assert(x >> y == 32); +} \ No newline at end of file diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index c7779533a8a..a9bbe189e57 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -336,10 +336,10 @@ impl<'block> BrilligBlock<'block> { dfg.instruction_results(instruction_id)[0], dfg, ); - + let heap_vec = self.brillig_context.extract_heap_vector(target_slice); self.brillig_context.radix_instruction( source, - self.function_context.extract_heap_vector(target_slice), + heap_vec, radix, limb_count, matches!(endianness, Endian::Big), @@ -355,10 +355,10 @@ impl<'block> BrilligBlock<'block> { ); let radix = self.brillig_context.make_constant(2_usize.into()); - + let heap_vec = self.brillig_context.extract_heap_vector(target_slice); self.brillig_context.radix_instruction( source, - self.function_context.extract_heap_vector(target_slice), + heap_vec, radix, limb_count, matches!(endianness, Endian::Big), @@ -589,7 +589,7 @@ impl<'block> BrilligBlock<'block> { dfg.instruction_results(instruction_id)[0], dfg, ); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); let item_value = self.convert_ssa_register_value(arguments[1], dfg); slice_push_back_operation( self.brillig_context, @@ -604,7 +604,7 @@ impl<'block> BrilligBlock<'block> { dfg.instruction_results(instruction_id)[0], dfg, ); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); let item_value = self.convert_ssa_register_value(arguments[1], dfg); slice_push_front_operation( self.brillig_context, @@ -618,7 +618,7 @@ impl<'block> BrilligBlock<'block> { let target_variable = self.function_context.create_variable(self.brillig_context, results[0], dfg); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); let pop_item = self.function_context.create_register_variable( self.brillig_context, @@ -643,7 +643,7 @@ impl<'block> BrilligBlock<'block> { ); let target_variable = self.function_context.create_variable(self.brillig_context, results[1], dfg); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); slice_pop_front_operation( self.brillig_context, @@ -659,7 +659,7 @@ impl<'block> BrilligBlock<'block> { let target_variable = self.function_context.create_variable(self.brillig_context, results[0], dfg); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); slice_insert_operation( self.brillig_context, target_vector, @@ -674,7 +674,7 @@ impl<'block> BrilligBlock<'block> { let target_variable = self.function_context.create_variable(self.brillig_context, results[0], dfg); - let target_vector = self.function_context.extract_heap_vector(target_variable); + let target_vector = self.brillig_context.extract_heap_vector(target_variable); let removed_item_register = self.function_context.create_register_variable( self.brillig_context, @@ -877,7 +877,7 @@ impl<'block> BrilligBlock<'block> { Type::Slice(_) => { let variable = self.function_context.create_variable(self.brillig_context, result, dfg); - let vector = self.function_context.extract_heap_vector(variable); + let vector = self.brillig_context.extract_heap_vector(variable); // Set the pointer to the current stack frame // The stack pointer will then be update by the caller of this method @@ -981,8 +981,6 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op( BinaryOp::And => BinaryIntOp::And, BinaryOp::Or => BinaryIntOp::Or, BinaryOp::Xor => BinaryIntOp::Xor, - BinaryOp::Shl => BinaryIntOp::Shl, - BinaryOp::Shr => BinaryIntOp::Shr, }; BrilligBinaryOp::Integer { op: operation, bit_size } diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 1a751d28b23..210d6da7be6 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -115,13 +115,6 @@ impl FunctionContext { } } - pub(crate) fn extract_heap_vector(&self, variable: RegisterOrMemory) -> HeapVector { - match variable { - RegisterOrMemory::HeapVector(vector) => vector, - _ => unreachable!("ICE: Expected vector, got {variable:?}"), - } - } - /// Collects the registers that a given variable is stored in. pub(crate) fn extract_registers(&self, variable: RegisterOrMemory) -> Vec { match variable { diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir.rs b/crates/noirc_evaluator/src/brillig/brillig_ir.rs index ac0103dd9ed..4471d507579 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir.rs @@ -951,6 +951,18 @@ impl BrilligContext { self.deallocate_register(end_value_register); self.deallocate_register(index_at_end_of_array); } + + pub(crate) fn extract_heap_vector(&mut self, variable: RegisterOrMemory) -> HeapVector { + match variable { + RegisterOrMemory::HeapVector(vector) => vector, + RegisterOrMemory::HeapArray(array) => { + let size = self.allocate_register(); + self.const_instruction(size, array.size.into()); + HeapVector { pointer: array.pointer, size } + } + _ => unreachable!("ICE: Expected vector, got {variable:?}"), + } + } } /// Type to encapsulate the binary operation types in Brillig diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 75716e49177..2bb753de760 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -73,8 +73,9 @@ impl DebugToString for BinaryIntOp { BinaryIntOp::And => "&&".into(), BinaryIntOp::Or => "||".into(), BinaryIntOp::Xor => "^".into(), - BinaryIntOp::Shl => "<<".into(), - BinaryIntOp::Shr => ">>".into(), + BinaryIntOp::Shl | BinaryIntOp::Shr => { + unreachable!("bit shift should have been replaced") + } } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 4a7d2e46775..f00f15d8f05 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -796,13 +796,6 @@ impl Context { bit_count, self.current_side_effects_enabled_var, ), - BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type), - BinaryOp::Shr => self.acir_context.shift_right_var( - lhs, - rhs, - binary_type, - self.current_side_effects_enabled_var, - ), BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type), BinaryOp::And => self.acir_context.and_var(lhs, rhs, binary_type), BinaryOp::Or => self.acir_context.or_var(lhs, rhs, binary_type), diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index b7a3ea02ae9..bfdc5f5069a 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -748,16 +748,6 @@ impl Binary { return SimplifyResult::SimplifiedTo(zero); } } - BinaryOp::Shl => { - if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); - } - } - BinaryOp::Shr => { - if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); - } - } } SimplifyResult::None } @@ -813,8 +803,6 @@ impl BinaryOp { BinaryOp::And => None, BinaryOp::Or => None, BinaryOp::Xor => None, - BinaryOp::Shl => None, - BinaryOp::Shr => None, } } @@ -828,8 +816,6 @@ impl BinaryOp { BinaryOp::And => |x, y| Some(x & y), BinaryOp::Or => |x, y| Some(x | y), BinaryOp::Xor => |x, y| Some(x ^ y), - BinaryOp::Shl => |x, y| x.checked_shl(y.try_into().ok()?), - BinaryOp::Shr => |x, y| Some(x >> y), BinaryOp::Eq => |x, y| Some((x == y) as u128), BinaryOp::Lt => |x, y| Some((x < y) as u128), } @@ -870,10 +856,6 @@ pub(crate) enum BinaryOp { Or, /// Bitwise xor (^) Xor, - /// Shift lhs left by rhs bits (<<) - Shl, - /// Shift lhs right by rhs bits (>>) - Shr, } impl std::fmt::Display for BinaryOp { @@ -889,8 +871,6 @@ impl std::fmt::Display for BinaryOp { BinaryOp::And => write!(f, "and"), BinaryOp::Or => write!(f, "or"), BinaryOp::Xor => write!(f, "xor"), - BinaryOp::Shl => write!(f, "shl"), - BinaryOp::Shr => write!(f, "shr"), } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs index c485200a53e..a526d93f85b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs @@ -7,12 +7,12 @@ use iter_extended::vecmap; use noirc_errors::Location; use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters}; use noirc_frontend::monomorphization::ast::{FuncId, Program}; -use noirc_frontend::Signedness; +use noirc_frontend::{BinaryOpKind, Signedness}; use crate::ssa_refactor::ir::dfg::DataFlowGraph; use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId; use crate::ssa_refactor::ir::function::{Function, RuntimeType}; -use crate::ssa_refactor::ir::instruction::BinaryOp; +use crate::ssa_refactor::ir::instruction::{BinaryOp, Endian, Intrinsic}; use crate::ssa_refactor::ir::map::AtomicCounter; use crate::ssa_refactor::ir::types::{NumericType, Type}; use crate::ssa_refactor::ir::value::ValueId; @@ -236,6 +236,46 @@ impl<'a> FunctionContext<'a> { Values::empty() } + /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs + fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let base = self.builder.field_constant(FieldElement::from(2_u128)); + let pow = self.pow(base, rhs); + self.builder.insert_binary(lhs, BinaryOp::Mul, pow) + } + + /// Insert ssa instructions which computes lhs << rhs by doing lhs/2^rhs + fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let base = self.builder.field_constant(FieldElement::from(2_u128)); + let pow = self.pow(base, rhs); + self.builder.insert_binary(lhs, BinaryOp::Div, pow) + } + + /// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs + fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let typ = self.builder.current_function.dfg.type_of_value(rhs); + if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ { + let to_bits = self.builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little)); + let length = self.builder.field_constant(FieldElement::from(bit_size as i128)); + let result_types = vec![Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)]; + let rhs_bits = self.builder.insert_call(to_bits, vec![rhs, length], result_types)[0]; + let one = self.builder.field_constant(FieldElement::one()); + let mut r = one; + for i in 1..bit_size + 1 { + let r1 = self.builder.insert_binary(r, BinaryOp::Mul, r); + let a = self.builder.insert_binary(r1, BinaryOp::Mul, lhs); + let idx = self.builder.field_constant(FieldElement::from((bit_size - i) as i128)); + let b = self.builder.insert_array_get(rhs_bits, idx, Type::field()); + let r2 = self.builder.insert_binary(a, BinaryOp::Mul, b); + let c = self.builder.insert_binary(one, BinaryOp::Sub, b); + let r3 = self.builder.insert_binary(c, BinaryOp::Mul, r1); + r = self.builder.insert_binary(r2, BinaryOp::Add, r3); + } + r + } else { + unreachable!("Value must be unsigned in power operation"); + } + } + /// Insert a binary instruction at the end of the current block. /// Converts the form of the binary instruction as necessary /// (e.g. swapping arguments, inserting a not) to represent it in the IR. @@ -247,17 +287,22 @@ impl<'a> FunctionContext<'a> { mut rhs: ValueId, location: Location, ) -> Values { - let op = convert_operator(operator); - - if op == BinaryOp::Eq && matches!(self.builder.type_of_value(lhs), Type::Array(..)) { - return self.insert_array_equality(lhs, operator, rhs, location); - } - - if operator_requires_swapped_operands(operator) { - std::mem::swap(&mut lhs, &mut rhs); - } - - let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs); + let mut result = match operator { + BinaryOpKind::ShiftLeft => self.insert_shift_left(lhs, rhs), + BinaryOpKind::ShiftRight => self.insert_shift_right(lhs, rhs), + BinaryOpKind::Equal | BinaryOpKind::NotEqual + if matches!(self.builder.type_of_value(lhs), Type::Array(..)) => + { + return self.insert_array_equality(lhs, operator, rhs, location) + } + _ => { + let op = convert_operator(operator); + if operator_requires_swapped_operands(operator) { + std::mem::swap(&mut lhs, &mut rhs); + } + self.builder.set_location(location).insert_binary(lhs, op, rhs) + } + }; if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate( operator, @@ -704,7 +749,6 @@ fn operator_result_max_bit_size_to_truncate( /// checking operator_requires_not and operator_requires_swapped_operands /// to represent the full operation correctly. fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp { - use noirc_frontend::BinaryOpKind; match op { BinaryOpKind::Add => BinaryOp::Add, BinaryOpKind::Subtract => BinaryOp::Sub, @@ -720,8 +764,9 @@ fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp { BinaryOpKind::And => BinaryOp::And, BinaryOpKind::Or => BinaryOp::Or, BinaryOpKind::Xor => BinaryOp::Xor, - BinaryOpKind::ShiftRight => BinaryOp::Shr, - BinaryOpKind::ShiftLeft => BinaryOp::Shl, + BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft => unreachable!( + "ICE - bit shift operators do not exist in SSA and should have been replaced" + ), } } diff --git a/crates/noirc_frontend/src/ast/expression.rs b/crates/noirc_frontend/src/ast/expression.rs index b1829e8c1ee..b1170ff0ed0 100644 --- a/crates/noirc_frontend/src/ast/expression.rs +++ b/crates/noirc_frontend/src/ast/expression.rs @@ -268,6 +268,10 @@ impl BinaryOpKind { BinaryOpKind::Modulo => Token::Percent, } } + + pub fn is_bit_shift(&self) -> bool { + matches!(self, BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft) + } } #[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)] diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 12c11bf20e1..24ac5f3443e 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -12,7 +12,7 @@ use crate::{ }, node_interner::{DefinitionKind, ExprId, FuncId}, token::Attribute::Deprecated, - CompTime, Shared, TypeBinding, TypeVariableKind, UnaryOp, + CompTime, Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -954,7 +954,7 @@ impl<'interner> TypeChecker<'interner> { if op.is_bitwise() && (other.is_bindable() || other.is_field()) { let other = other.follow_bindings(); - + let kind = op.kind; // This will be an error if these types later resolve to a Field, or stay // polymorphic as the bit size will be unknown. Delay this error until the function // finishes resolving so we can still allow cases like `let x: u8 = 1 << 2;`. @@ -963,6 +963,12 @@ impl<'interner> TypeChecker<'interner> { Err(TypeCheckError::InvalidBitwiseOperationOnField { span }) } else if other.is_bindable() { Err(TypeCheckError::AmbiguousBitWidth { span }) + } else if kind.is_bit_shift() && other.is_signed() { + Err(TypeCheckError::TypeCannotBeUsed { + typ: other, + place: "bit shift", + span, + }) } else { Ok(()) } @@ -1001,8 +1007,14 @@ impl<'interner> TypeChecker<'interner> { span, }); } - let comptime = comptime_x.and(comptime_y, op.location.span); - Ok(Integer(comptime, *sign_x, *bit_width_x)) + if op.is_bit_shift() + && (*sign_x == Signedness::Signed || *sign_y == Signedness::Signed) + { + Err(TypeCheckError::InvalidInfixOp { kind: "Signed integer", span }) + } else { + let comptime = comptime_x.and(comptime_y, op.location.span); + Ok(Integer(comptime, *sign_x, *bit_width_x)) + } } (Integer(..), FieldElement(..)) | (FieldElement(..), Integer(..)) => { Err(TypeCheckError::IntegerAndFieldBinaryOperation { span }) diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index 5db9751591a..db7db0a803d 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -72,6 +72,10 @@ impl HirBinaryOp { use BinaryOpKind::*; matches!(self.kind, And | Or | Xor | ShiftRight | ShiftLeft) } + + pub fn is_bit_shift(&self) -> bool { + self.kind.is_bit_shift() + } } #[derive(Debug, Clone)] diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index df4c2f6c229..ff0a4e53fae 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -659,6 +659,10 @@ impl Type { matches!(self.follow_bindings(), Type::FieldElement(_)) } + pub fn is_signed(&self) -> bool { + matches!(self.follow_bindings(), Type::Integer(_, Signedness::Signed, _)) + } + fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool { // True if the given type is a NamedGeneric with the target_id let named_generic_id_matches_target = |typ: &Type| {