From 2f6408746182a3bfd43c6eac2531df37272bcebd Mon Sep 17 00:00:00 2001 From: doug-q <141026920+doug-q@users.noreply.github.com> Date: Tue, 18 Jun 2024 07:17:51 +0100 Subject: [PATCH] feat: add `get_extern_func` (#28) --- src/emit.rs | 58 ++++++++++++++++++++++++++++++++++++------------ src/emit/func.rs | 13 +++++++++++ src/emit/test.rs | 13 +++++++++++ src/test.rs | 17 +++++++++++++- 4 files changed, 86 insertions(+), 15 deletions(-) diff --git a/src/emit.rs b/src/emit.rs index 3b3cf15..07834ab 100644 --- a/src/emit.rs +++ b/src/emit.rs @@ -2,11 +2,12 @@ use anyhow::{anyhow, Result}; use delegate::delegate; use hugr::{ ops::{FuncDecl, FuncDefn, NamedOp as _, OpType}, + types::PolyFuncType, HugrView, Node, NodeIndex, }; use inkwell::{ context::Context, - module::Module, + module::{Linkage, Module}, types::{BasicTypeEnum, FunctionType}, values::{BasicValueEnum, FunctionValue}, }; @@ -205,39 +206,68 @@ impl<'c, H: HugrView> EmitModuleContext<'c, H> { fn get_func_impl( &self, name: impl AsRef, - node: Node, - func_ty: &hugr::types::PolyFuncType, + func_ty: FunctionType<'c>, + linkage: Option, ) -> Result> { - let sig = (func_ty.params().is_empty()) - .then_some(func_ty.body()) - .ok_or(anyhow!("function has type params"))?; - let llvm_func_ty = self.llvm_func_type(sig)?; - let name = self.name_func(name, node); let func = self .module() - .get_function(&name) - .unwrap_or_else(|| self.module.add_function(&name, llvm_func_ty, None)); - if func.get_type() != llvm_func_ty { + .get_function(name.as_ref()) + .unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage)); + if func.get_type() != func_ty { Err(anyhow!( - "Function '{name}' has wrong type: hugr: {func_ty} expected: {llvm_func_ty} actual: {}", + "Function '{}' has wrong type: expected: {func_ty} actual: {}", + name.as_ref(), func.get_type() ))? } Ok(func) } + fn get_hugr_func_impl( + &self, + name: impl AsRef, + node: Node, + func_ty: &PolyFuncType, + ) -> Result> { + let func_ty = (func_ty.params().is_empty()) + .then_some(func_ty.body()) + .ok_or(anyhow!("function has type params"))?; + let llvm_func_ty = self.llvm_func_type(func_ty)?; + let name = self.name_func(name, node); + self.get_func_impl(name, llvm_func_ty, None) + } + /// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDefn]. /// /// The name of the result is mangled by [EmitModuleContext::name_func]. pub fn get_func_defn(&self, node: FatNode<'c, FuncDefn, H>) -> Result> { - self.get_func_impl(&node.name, node.node(), &node.signature) + self.get_hugr_func_impl(&node.name, node.node(), &node.signature) } /// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDecl]. /// /// The name of the result is mangled by [EmitModuleContext::name_func]. pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result> { - self.get_func_impl(&node.name, node.node(), &node.signature) + self.get_hugr_func_impl(&node.name, node.node(), &node.signature) + } + + /// Adds or get the [FunctionValue] in the [Module] with the given symbol + /// and function type. + /// + /// The name undergoes no mangling. The [FunctionValue] will have + /// [Linkage::External]. + /// + /// If this function is called multiple times with the same arguments it + /// will return the same [FunctionValue]. + /// + /// If a function with the given name exists but the type does not match + /// then an Error is returned. + pub fn get_extern_func( + &self, + symbol: impl AsRef, + typ: FunctionType<'c>, + ) -> Result> { + self.get_func_impl(symbol, typ, Some(Linkage::External)) } /// Consumes the `EmitModuleContext` and returns the internal [Module]. diff --git a/src/emit/func.rs b/src/emit/func.rs index f0bcfab..ec5c624 100644 --- a/src/emit/func.rs +++ b/src/emit/func.rs @@ -59,6 +59,7 @@ pub struct EmitFuncContext<'c, H: HugrView> { impl<'c, H: HugrView> EmitFuncContext<'c, H> { delegate! { to self.emit_context { + /// Returns the inkwell [Context]. fn iw_context(&self) -> &'c Context; /// Returns the internal [CodegenExtsMap] . pub fn extensions(&self) -> Rc>; @@ -78,6 +79,18 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> { /// /// The name of the result may have been mangled. pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result>; + /// Adds or get the [FunctionValue] in the [inkwell::module::Module] with the given symbol + /// and function type. + /// + /// The name undergoes no mangling. The [FunctionValue] will have + /// [inkwell::module::Linkage::External]. + /// + /// If this function is called multiple times with the same arguments it + /// will return the same [FunctionValue]. + /// + /// If a function with the given name exists but the type does not match + /// then an Error is returned. + pub fn get_extern_func(&self, symbol: impl AsRef, typ: FunctionType<'c>,) -> Result>; } } diff --git a/src/emit/test.rs b/src/emit/test.rs index ce82689..c5d8aef 100644 --- a/src/emit/test.rs +++ b/src/emit/test.rs @@ -267,3 +267,16 @@ fn emit_hugr_custom_op(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) { }); check_emission!(hugr, llvm_ctx); } + +#[rstest] +fn get_external_func(llvm_ctx: TestContext) { + llvm_ctx.with_emit_module_context(|emc| { + let func_type1 = emc.iw_context().i32_type().fn_type(&[], false); + let func_type2 = emc.iw_context().f64_type().fn_type(&[], false); + let foo1 = emc.get_extern_func("foo", func_type1).unwrap(); + assert_eq!(foo1.get_name().to_str().unwrap(), "foo"); + let foo2 = emc.get_extern_func("foo", func_type1).unwrap(); + assert_eq!(foo1, foo2); + assert!(emc.get_extern_func("foo", func_type2).is_err()); + }); +} diff --git a/src/test.rs b/src/test.rs index 5ea1731..231ece2 100644 --- a/src/test.rs +++ b/src/test.rs @@ -11,7 +11,7 @@ use rstest::fixture; use crate::{ custom::CodegenExtsMap, - emit::EmitHugr, + emit::{EmitHugr, EmitModuleContext, Namer}, types::{TypeConverter, TypingSession}, }; @@ -120,6 +120,21 @@ impl TestContext { (r, ectx.finish()) }) } + + pub fn with_emit_module_context<'c, T>( + &'c self, + f: impl FnOnce(EmitModuleContext<'c, THugrView>) -> T, + ) -> T { + self.with_context(|ctx| { + let m = ctx.create_module("test_module"); + f(EmitModuleContext::new( + m, + Namer::default().into(), + self.extensions(), + TypeConverter::new(ctx), + )) + }) + } } #[fixture]