Skip to content

Commit

Permalink
feat(hugr-llvm): Emit ipow (#1839)
Browse files Browse the repository at this point in the history
Co-authored-by: Seyon Sivarajah <seyon.sivarajah@quantinuum.com>
  • Loading branch information
croyzor and ss2165 authored Jan 10, 2025
1 parent 74ce446 commit 89e671a
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 3 deletions.
2 changes: 1 addition & 1 deletion hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl MakeOpDef for IntOpDef {
idiv_s => "as idivmod_s but discarding the second output",
imod_checked_s => "as idivmod_checked_s but discarding the first output",
imod_s => "as idivmod_s but discarding the first output",
ipow => "raise first input to the power of second input",
ipow => "raise first input to the power of second input, the exponent is treated as an unsigned integer",
iabs => "convert signed to unsigned by taking absolute value",
iand => "bitwise AND",
ior => "bitwise OR",
Expand Down
65 changes: 65 additions & 0 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,63 @@ fn emit_icmp<'c, H: HugrView>(
})
}

/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
/// loop over the exponent, performing `imul`s instead.
/// The insertion pointer is expected to be pointing to the end of `launch_bb`.
fn emit_ipow<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let done_bb = ctx.new_basic_block("done", None);
let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));

let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
ctx.builder().build_store(acc_p, lhs)?;
ctx.builder().build_store(exp_p, rhs)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

let zero = rhs.get_type().into_int_type().const_int(0, false);
// Assumes RHS type is the same as output type (which it should be)
let one = rhs.get_type().into_int_type().const_int(1, false);

// Block for just returning one
ctx.builder().position_at_end(return_one_bb);
ctx.builder().build_store(acc_p, one)?;
ctx.builder().build_unconditional_branch(done_bb)?;

ctx.builder().position_at_end(pow_bb);
let acc = ctx.builder().build_load(acc_p, "acc")?;
let exp = ctx.builder().build_load(exp_p, "exp")?;

// Special case if the exponent is 0 or 1
ctx.builder().build_switch(
exp.into_int_value(),
pow_body_bb,
&[(one, done_bb), (zero, return_one_bb)],
)?;

// Block that performs one `imul` and modifies the values in the store
ctx.builder().position_at_end(pow_body_bb);
let new_acc =
ctx.builder()
.build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
let new_exp = ctx
.builder()
.build_int_sub(exp.into_int_value(), one, "new_exp")?;
ctx.builder().build_store(acc_p, new_acc)?;
ctx.builder().build_store(exp_p, new_exp)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

ctx.builder().position_at_end(done_bb);
let result = ctx.builder().build_load(acc_p, "result")?;
Ok(vec![result.as_basic_value_enum()])
})
}

fn emit_int_op<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
Expand Down Expand Up @@ -223,6 +280,7 @@ fn emit_int_op<'c, H: HugrView>(
.build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ipow => emit_ipow(context, args),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
}
Expand Down Expand Up @@ -364,6 +422,7 @@ mod test {
#[rstest]
#[case::iadd("iadd", 3)]
#[case::isub("isub", 6)]
#[case::ipow("ipow", 3)]
fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op(op.clone(), width);
Expand Down Expand Up @@ -398,6 +457,9 @@ mod test {
#[case::iand("iand", 6, 15, 6)]
#[case::iand("iand", 15, 6, 6)]
#[case::iand("iand", 15, 15, 15)]
#[case::ipow("ipow", 2, 3, 8)]
#[case::ipow("ipow", 42, 1, 42)]
#[case::ipow("ipow", 42, 0, 1)]
fn test_exec_unsigned_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
Expand Down Expand Up @@ -428,6 +490,9 @@ mod test {
#[case::imin("imin_s", -1, -2, -2)]
#[case::imin("imin_s", -2, -1, -2)]
#[case::imin("imin_s", -2, -2, -2)]
#[case::ipow("ipow", -2, 1, -2)]
#[case::ipow("ipow", -2, 2, 4)]
#[case::ipow("ipow", -2, 3, -8)]
fn test_exec_signed_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %0, i8* %acc_ptr, align 1
store i8 %1, i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %0
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
ret i8 %result
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca i8, align 1
%"2_1" = alloca i8, align 1
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
store i8 %1, i8* %"2_1", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%"2_12" = load i8, i8* %"2_1", align 1
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %"2_01", i8* %acc_ptr, align 1
store i8 %"2_12", i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %"2_01"
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
store i8 %result, i8* %"4_0", align 1
%"4_03" = load i8, i8* %"4_0", align 1
store i8 %"4_03", i8* %"0", align 1
%"04" = load i8, i8* %"0", align 1
ret i8 %"04"
}
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/std/_json_defs/arithmetic/int.json
Original file line number Diff line number Diff line change
Expand Up @@ -2459,7 +2459,7 @@
"ipow": {
"extension": "arithmetic.int",
"name": "ipow",
"description": "raise first input to the power of second input",
"description": "raise first input to the power of second input, the exponent is treated as an unsigned integer",
"signature": {
"params": [
{
Expand Down
2 changes: 1 addition & 1 deletion specification/std_extensions/arithmetic/int.json
Original file line number Diff line number Diff line change
Expand Up @@ -2459,7 +2459,7 @@
"ipow": {
"extension": "arithmetic.int",
"name": "ipow",
"description": "raise first input to the power of second input",
"description": "raise first input to the power of second input, the exponent is treated as an unsigned integer",
"signature": {
"params": [
{
Expand Down

0 comments on commit 89e671a

Please sign in to comment.