Skip to content

Commit

Permalink
feat: Emission for CallIndirect nodes (#73)
Browse files Browse the repository at this point in the history
Closes #10
  • Loading branch information
mark-koch authored Aug 16, 2024
1 parent 21b6105 commit 042ad5d
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 4 deletions.
27 changes: 24 additions & 3 deletions src/emit/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ use anyhow::{anyhow, Result};
use hugr::{
hugr::views::SiblingGraph,
ops::{
constant::Sum, Call, Case, Conditional, Const, Input, LoadConstant, LoadFunction,
MakeTuple, OpTag, OpTrait, OpType, Output, Tag, UnpackTuple, Value, CFG,
constant::Sum, Call, CallIndirect, Case, Conditional, Const, Input, LoadConstant,
LoadFunction, MakeTuple, OpTag, OpTrait, OpType, Output, Tag, UnpackTuple, Value, CFG,
},
types::{SumType, Type, TypeEnum},
HugrView, NodeIndex,
};
use inkwell::{builder::Builder, values::BasicValueEnum};
use inkwell::{
builder::Builder,
values::{BasicValueEnum, CallableValue},
};
use itertools::Itertools;
use petgraph::visit::Walker;

Expand Down Expand Up @@ -337,6 +340,23 @@ fn emit_call<'c, H: HugrView>(
args.outputs.finish(builder, call_results)
}

fn emit_call_indirect<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, CallIndirect, H>,
) -> Result<()> {
let func_ptr = match args.inputs[0] {
BasicValueEnum::PointerValue(v) => Ok(v),
_ => Err(anyhow!("emit_call_indirect: Not a pointer")),
}?;
let func =
CallableValue::try_from(func_ptr).expect("emit_call_indirect: Not a function pointer");
let inputs = args.inputs.into_iter().skip(1).map_into().collect_vec();
let builder = context.builder();
let call = builder.build_call(func, inputs.as_slice(), "")?;
let call_results = deaggregate_call_result(builder, call, args.outputs.len())?;
args.outputs.finish(builder, call_results)
}

fn emit_load_function<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, LoadFunction, H>,
Expand Down Expand Up @@ -385,6 +405,7 @@ fn emit_optype<'c, H: HugrView>(
}
OpType::LoadConstant(ref lc) => emit_load_constant(context, args.into_ot(lc)),
OpType::Call(ref cl) => emit_call(context, args.into_ot(cl)),
OpType::CallIndirect(ref cl) => emit_call_indirect(context, args.into_ot(cl)),
OpType::LoadFunction(ref lf) => emit_load_function(context, args.into_ot(lf)),
OpType::Conditional(ref co) => emit_conditional(context, args.into_ot(co)),
OpType::CFG(ref cfg) => emit_cfg(context, args.into_ot(cfg)),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
source: src/emit/test.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define void @_hl.main_void.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
call void @_hl.main_void.1()
ret void
}

define { i32, {}, {} } @_hl.main_unary.6({ i32, {}, {} } %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%1 = call { i32, {}, {} } @_hl.main_unary.6({ i32, {}, {} } %0)
ret { i32, {}, {} } %1
}

define { { i32, {}, {} }, { i32, {}, {} } } @_hl.main_binary.11({ i32, {}, {} } %0, { i32, {}, {} } %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = call { { i32, {}, {} }, { i32, {}, {} } } @_hl.main_binary.11({ i32, {}, {} } %0, { i32, {}, {} } %1)
%3 = extractvalue { { i32, {}, {} }, { i32, {}, {} } } %2, 0
%4 = extractvalue { { i32, {}, {} }, { i32, {}, {} } } %2, 1
%mrv = insertvalue { { i32, {}, {} }, { i32, {}, {} } } undef, { i32, {}, {} } %3, 0
%mrv8 = insertvalue { { i32, {}, {} }, { i32, {}, {} } } %mrv, { i32, {}, {} } %4, 1
ret { { i32, {}, {} }, { i32, {}, {} } } %mrv8
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
---
source: src/emit/test.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define void @_hl.main_void.1() {
alloca_block:
%"4_0" = alloca void ()*, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store void ()* @_hl.main_void.1, void ()** %"4_0", align 8
%"4_01" = load void ()*, void ()** %"4_0", align 8
call void %"4_01"()
ret void
}

define { i32, {}, {} } @_hl.main_unary.6({ i32, {}, {} } %0) {
alloca_block:
%"0" = alloca { i32, {}, {} }, align 8
%"7_0" = alloca { i32, {}, {} }, align 8
%"9_0" = alloca { i32, {}, {} } ({ i32, {}, {} })*, align 8
%"10_0" = alloca { i32, {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store { i32, {}, {} } %0, { i32, {}, {} }* %"7_0", align 4
store { i32, {}, {} } ({ i32, {}, {} })* @_hl.main_unary.6, { i32, {}, {} } ({ i32, {}, {} })** %"9_0", align 8
%"9_01" = load { i32, {}, {} } ({ i32, {}, {} })*, { i32, {}, {} } ({ i32, {}, {} })** %"9_0", align 8
%"7_02" = load { i32, {}, {} }, { i32, {}, {} }* %"7_0", align 4
%1 = call { i32, {}, {} } %"9_01"({ i32, {}, {} } %"7_02")
store { i32, {}, {} } %1, { i32, {}, {} }* %"10_0", align 4
%"10_03" = load { i32, {}, {} }, { i32, {}, {} }* %"10_0", align 4
store { i32, {}, {} } %"10_03", { i32, {}, {} }* %"0", align 4
%"04" = load { i32, {}, {} }, { i32, {}, {} }* %"0", align 4
ret { i32, {}, {} } %"04"
}

define { { i32, {}, {} }, { i32, {}, {} } } @_hl.main_binary.11({ i32, {}, {} } %0, { i32, {}, {} } %1) {
alloca_block:
%"0" = alloca { i32, {}, {} }, align 8
%"1" = alloca { i32, {}, {} }, align 8
%"12_0" = alloca { i32, {}, {} }, align 8
%"12_1" = alloca { i32, {}, {} }, align 8
%"14_0" = alloca { { i32, {}, {} }, { i32, {}, {} } } ({ i32, {}, {} }, { i32, {}, {} })*, align 8
%"15_0" = alloca { i32, {}, {} }, align 8
%"15_1" = alloca { i32, {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store { i32, {}, {} } %0, { i32, {}, {} }* %"12_0", align 4
store { i32, {}, {} } %1, { i32, {}, {} }* %"12_1", align 4
store { { i32, {}, {} }, { i32, {}, {} } } ({ i32, {}, {} }, { i32, {}, {} })* @_hl.main_binary.11, { { i32, {}, {} }, { i32, {}, {} } } ({ i32, {}, {} }, { i32, {}, {} })** %"14_0", align 8
%"14_01" = load { { i32, {}, {} }, { i32, {}, {} } } ({ i32, {}, {} }, { i32, {}, {} })*, { { i32, {}, {} }, { i32, {}, {} } } ({ i32, {}, {} }, { i32, {}, {} })** %"14_0", align 8
%"12_02" = load { i32, {}, {} }, { i32, {}, {} }* %"12_0", align 4
%"12_13" = load { i32, {}, {} }, { i32, {}, {} }* %"12_1", align 4
%2 = call { { i32, {}, {} }, { i32, {}, {} } } %"14_01"({ i32, {}, {} } %"12_02", { i32, {}, {} } %"12_13")
%3 = extractvalue { { i32, {}, {} }, { i32, {}, {} } } %2, 0
%4 = extractvalue { { i32, {}, {} }, { i32, {}, {} } } %2, 1
store { i32, {}, {} } %3, { i32, {}, {} }* %"15_0", align 4
store { i32, {}, {} } %4, { i32, {}, {} }* %"15_1", align 4
%"15_04" = load { i32, {}, {} }, { i32, {}, {} }* %"15_0", align 4
%"15_15" = load { i32, {}, {} }, { i32, {}, {} }* %"15_1", align 4
store { i32, {}, {} } %"15_04", { i32, {}, {} }* %"0", align 4
store { i32, {}, {} } %"15_15", { i32, {}, {} }* %"1", align 4
%"06" = load { i32, {}, {} }, { i32, {}, {} }* %"0", align 4
%"17" = load { i32, {}, {} }, { i32, {}, {} }* %"1", align 4
%mrv = insertvalue { { i32, {}, {} }, { i32, {}, {} } } undef, { i32, {}, {} } %"06", 0
%mrv8 = insertvalue { { i32, {}, {} }, { i32, {}, {} } } %mrv, { i32, {}, {} } %"17", 1
ret { { i32, {}, {} }, { i32, {}, {} } } %mrv8
}
26 changes: 25 additions & 1 deletion src/emit/test.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter;

use crate::custom::int::add_int_extensions;
use crate::types::HugrFuncType;
use hugr::builder::DataflowSubContainer;
Expand All @@ -8,7 +10,7 @@ use hugr::extension::prelude::BOOL_T;
use hugr::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG};
use hugr::ops::constant::CustomConst;
use hugr::ops::handle::FuncID;
use hugr::ops::{Tag, UnpackTuple, Value};
use hugr::ops::{CallIndirect, Tag, UnpackTuple, Value};
use hugr::std_extensions::arithmetic::int_ops::{self, INT_OPS_REGISTRY};
use hugr::std_extensions::arithmetic::int_types::ConstInt;
use hugr::types::{Signature, Type, TypeRow};
Expand Down Expand Up @@ -253,6 +255,28 @@ fn emit_hugr_call(llvm_ctx: TestContext) {
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn emit_hugr_call_indirect(llvm_ctx: TestContext) {
fn build_recursive(mod_b: &mut ModuleBuilder<Hugr>, name: &str, io: TypeRow) {
let signature = HugrFuncType::new_endo(io);
let f_id = mod_b.declare(name, signature.clone().into()).unwrap();
let mut func_b = mod_b.define_declaration(&f_id).unwrap();
let func = func_b.load_func(&f_id, &[], &EMPTY_REG).unwrap();
let inputs = iter::once(func).chain(func_b.input_wires());
let call_indirect = func_b
.add_dataflow_op(CallIndirect { signature }, inputs)
.unwrap();
func_b.finish_with_outputs(call_indirect.outputs()).unwrap();
}

let mut mod_b = ModuleBuilder::new();
build_recursive(&mut mod_b, "main_void", type_row![]);
build_recursive(&mut mod_b, "main_unary", type_row![BOOL_T]);
build_recursive(&mut mod_b, "main_binary", type_row![BOOL_T, BOOL_T]);
let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap();
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn emit_hugr_custom_op(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
Expand Down

0 comments on commit 042ad5d

Please sign in to comment.