Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for bitshifts by distances known at runtime #2072

Merged
merged 12 commits into from
Aug 2, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "bit_shifts_runtime"
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = 64
y = 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main(x: u64, y: u64) {
// runtime shifts on comptime values
assert(64 << y == 128);
assert(64 >> y == 32);

// runtime shifts on runtime values
assert(x << y == 128);
assert(x >> y == 32);
}
24 changes: 11 additions & 13 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);

let heap_vec = self.brillig_context.extract_heap_vector(target_slice);
self.brillig_context.radix_instruction(
source,
self.function_context.extract_heap_vector(target_slice),
heap_vec,
radix,
limb_count,
matches!(endianness, Endian::Big),
Expand All @@ -355,10 +355,10 @@ impl<'block> BrilligBlock<'block> {
);

let radix = self.brillig_context.make_constant(2_usize.into());

let heap_vec = self.brillig_context.extract_heap_vector(target_slice);
self.brillig_context.radix_instruction(
source,
self.function_context.extract_heap_vector(target_slice),
heap_vec,
radix,
limb_count,
matches!(endianness, Endian::Big),
Expand Down Expand Up @@ -589,7 +589,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
let item_value = self.convert_ssa_register_value(arguments[1], dfg);
slice_push_back_operation(
self.brillig_context,
Expand All @@ -604,7 +604,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
let item_value = self.convert_ssa_register_value(arguments[1], dfg);
slice_push_front_operation(
self.brillig_context,
Expand All @@ -618,7 +618,7 @@ impl<'block> BrilligBlock<'block> {

let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

let pop_item = self.function_context.create_register_variable(
self.brillig_context,
Expand All @@ -643,7 +643,7 @@ impl<'block> BrilligBlock<'block> {
);
let target_variable =
self.function_context.create_variable(self.brillig_context, results[1], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

slice_pop_front_operation(
self.brillig_context,
Expand All @@ -659,7 +659,7 @@ impl<'block> BrilligBlock<'block> {
let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);

let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
slice_insert_operation(
self.brillig_context,
target_vector,
Expand All @@ -674,7 +674,7 @@ impl<'block> BrilligBlock<'block> {

let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

let removed_item_register = self.function_context.create_register_variable(
self.brillig_context,
Expand Down Expand Up @@ -877,7 +877,7 @@ impl<'block> BrilligBlock<'block> {
Type::Slice(_) => {
let variable =
self.function_context.create_variable(self.brillig_context, result, dfg);
let vector = self.function_context.extract_heap_vector(variable);
let vector = self.brillig_context.extract_heap_vector(variable);

// Set the pointer to the current stack frame
// The stack pointer will then be update by the caller of this method
Expand Down Expand Up @@ -981,8 +981,6 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
BinaryOp::And => BinaryIntOp::And,
BinaryOp::Or => BinaryIntOp::Or,
BinaryOp::Xor => BinaryIntOp::Xor,
BinaryOp::Shl => BinaryIntOp::Shl,
BinaryOp::Shr => BinaryIntOp::Shr,
};

BrilligBinaryOp::Integer { op: operation, bit_size }
Expand Down
7 changes: 0 additions & 7 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,6 @@ impl FunctionContext {
}
}

pub(crate) fn extract_heap_vector(&self, variable: RegisterOrMemory) -> HeapVector {
match variable {
RegisterOrMemory::HeapVector(vector) => vector,
_ => unreachable!("ICE: Expected vector, got {variable:?}"),
}
}

/// Collects the registers that a given variable is stored in.
pub(crate) fn extract_registers(&self, variable: RegisterOrMemory) -> Vec<RegisterIndex> {
match variable {
Expand Down
12 changes: 12 additions & 0 deletions crates/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,18 @@ impl BrilligContext {
self.deallocate_register(end_value_register);
self.deallocate_register(index_at_end_of_array);
}

pub(crate) fn extract_heap_vector(&mut self, variable: RegisterOrMemory) -> HeapVector {
match variable {
RegisterOrMemory::HeapVector(vector) => vector,
RegisterOrMemory::HeapArray(array) => {
let size = self.allocate_register();
self.const_instruction(size, array.size.into());
HeapVector { pointer: array.pointer, size }
}
_ => unreachable!("ICE: Expected vector, got {variable:?}"),
}
}
}

/// Type to encapsulate the binary operation types in Brillig
Expand Down
5 changes: 3 additions & 2 deletions crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ impl DebugToString for BinaryIntOp {
BinaryIntOp::And => "&&".into(),
BinaryIntOp::Or => "||".into(),
BinaryIntOp::Xor => "^".into(),
BinaryIntOp::Shl => "<<".into(),
BinaryIntOp::Shr => ">>".into(),
BinaryIntOp::Shl | BinaryIntOp::Shr => {
unreachable!("bit shift should have been replaced")
}
}
}
}
Expand Down
7 changes: 0 additions & 7 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,13 +796,6 @@ impl Context {
bit_count,
self.current_side_effects_enabled_var,
),
BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type),
BinaryOp::Shr => self.acir_context.shift_right_var(
lhs,
rhs,
binary_type,
self.current_side_effects_enabled_var,
),
BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type),
BinaryOp::And => self.acir_context.and_var(lhs, rhs, binary_type),
BinaryOp::Or => self.acir_context.or_var(lhs, rhs, binary_type),
Expand Down
20 changes: 0 additions & 20 deletions crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,16 +748,6 @@ impl Binary {
return SimplifyResult::SimplifiedTo(zero);
}
}
BinaryOp::Shl => {
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
}
}
BinaryOp::Shr => {
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
}
}
}
SimplifyResult::None
}
Expand Down Expand Up @@ -813,8 +803,6 @@ impl BinaryOp {
BinaryOp::And => None,
BinaryOp::Or => None,
BinaryOp::Xor => None,
BinaryOp::Shl => None,
BinaryOp::Shr => None,
}
}

Expand All @@ -828,8 +816,6 @@ impl BinaryOp {
BinaryOp::And => |x, y| Some(x & y),
BinaryOp::Or => |x, y| Some(x | y),
BinaryOp::Xor => |x, y| Some(x ^ y),
BinaryOp::Shl => |x, y| x.checked_shl(y.try_into().ok()?),
BinaryOp::Shr => |x, y| Some(x >> y),
BinaryOp::Eq => |x, y| Some((x == y) as u128),
BinaryOp::Lt => |x, y| Some((x < y) as u128),
}
Expand Down Expand Up @@ -870,10 +856,6 @@ pub(crate) enum BinaryOp {
Or,
/// Bitwise xor (^)
Xor,
/// Shift lhs left by rhs bits (<<)
Shl,
/// Shift lhs right by rhs bits (>>)
Shr,
}

impl std::fmt::Display for BinaryOp {
Expand All @@ -889,8 +871,6 @@ impl std::fmt::Display for BinaryOp {
BinaryOp::And => write!(f, "and"),
BinaryOp::Or => write!(f, "or"),
BinaryOp::Xor => write!(f, "xor"),
BinaryOp::Shl => write!(f, "shl"),
BinaryOp::Shr => write!(f, "shr"),
}
}
}
Expand Down
77 changes: 61 additions & 16 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use iter_extended::vecmap;
use noirc_errors::Location;
use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters};
use noirc_frontend::monomorphization::ast::{FuncId, Program};
use noirc_frontend::Signedness;
use noirc_frontend::{BinaryOpKind, Signedness};

use crate::ssa_refactor::ir::dfg::DataFlowGraph;
use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId;
use crate::ssa_refactor::ir::function::{Function, RuntimeType};
use crate::ssa_refactor::ir::instruction::BinaryOp;
use crate::ssa_refactor::ir::instruction::{BinaryOp, Endian, Intrinsic};
use crate::ssa_refactor::ir::map::AtomicCounter;
use crate::ssa_refactor::ir::types::{NumericType, Type};
use crate::ssa_refactor::ir::value::ValueId;
Expand Down Expand Up @@ -224,6 +224,46 @@ impl<'a> FunctionContext<'a> {
Values::empty()
}

/// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs
fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let base = self.builder.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
self.builder.insert_binary(lhs, BinaryOp::Mul, pow)
}

/// Insert ssa instructions which computes lhs << rhs by doing lhs/2^rhs
fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let base = self.builder.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
self.builder.insert_binary(lhs, BinaryOp::Div, pow)
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.builder.current_function.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little));
let length = self.builder.field_constant(FieldElement::from(bit_size as i128));
let result_types = vec![Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)];
let rhs_bits = self.builder.insert_call(to_bits, vec![rhs, length], result_types)[0];
let one = self.builder.field_constant(FieldElement::one());
let mut r = one;
for i in 1..bit_size + 1 {
let r1 = self.builder.insert_binary(r, BinaryOp::Mul, r);
let a = self.builder.insert_binary(r1, BinaryOp::Mul, lhs);
let idx = self.builder.field_constant(FieldElement::from((bit_size - i) as i128));
let b = self.builder.insert_array_get(rhs_bits, idx, Type::field());
let r2 = self.builder.insert_binary(a, BinaryOp::Mul, b);
let c = self.builder.insert_binary(one, BinaryOp::Sub, b);
let r3 = self.builder.insert_binary(c, BinaryOp::Mul, r1);
r = self.builder.insert_binary(r2, BinaryOp::Add, r3);
}
r
} else {
unreachable!("Value must be unsigned in power operation");
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Insert a binary instruction at the end of the current block.
/// Converts the form of the binary instruction as necessary
/// (e.g. swapping arguments, inserting a not) to represent it in the IR.
Expand All @@ -235,17 +275,22 @@ impl<'a> FunctionContext<'a> {
mut rhs: ValueId,
location: Location,
) -> Values {
let op = convert_operator(operator);

if op == BinaryOp::Eq && matches!(self.builder.type_of_value(lhs), Type::Array(..)) {
return self.insert_array_equality(lhs, operator, rhs, location);
}

if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}

let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs);
let mut result = match operator {
BinaryOpKind::ShiftLeft => self.insert_shift_left(lhs, rhs),
BinaryOpKind::ShiftRight => self.insert_shift_right(lhs, rhs),
BinaryOpKind::Equal | BinaryOpKind::NotEqual
if matches!(self.builder.type_of_value(lhs), Type::Array(..)) =>
{
return self.insert_array_equality(lhs, operator, rhs, location)
}
_ => {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
self.builder.set_location(location).insert_binary(lhs, op, rhs)
}
};

if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate(
operator,
Expand Down Expand Up @@ -692,7 +737,6 @@ fn operator_result_max_bit_size_to_truncate(
/// checking operator_requires_not and operator_requires_swapped_operands
/// to represent the full operation correctly.
fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp {
use noirc_frontend::BinaryOpKind;
match op {
BinaryOpKind::Add => BinaryOp::Add,
BinaryOpKind::Subtract => BinaryOp::Sub,
Expand All @@ -708,8 +752,9 @@ fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp {
BinaryOpKind::And => BinaryOp::And,
BinaryOpKind::Or => BinaryOp::Or,
BinaryOpKind::Xor => BinaryOp::Xor,
BinaryOpKind::ShiftRight => BinaryOp::Shr,
BinaryOpKind::ShiftLeft => BinaryOp::Shl,
BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft => unreachable!(
"ICE - bit shift operators do not exist in SSA and should have been replaced"
),
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ impl BinaryOpKind {
BinaryOpKind::Modulo => Token::Percent,
}
}

pub fn is_bit_shift(&self) -> bool {
matches!(self, BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft)
}
}

#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)]
Expand Down
Loading