diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 8c314294d..8f5341b08 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -20,7 +20,7 @@ use crate::{ }; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY}; -use crate::types::{FunctionType, PolyFuncType, Type, TypeRow}; +use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -595,7 +595,9 @@ pub trait Dataflow: Container { fn call( &mut self, function: &FuncID, + type_args: &[TypeArg], input_wires: impl IntoIterator, + exts: &ExtensionRegistry, // TODO remove? ) -> Result, BuildError> { let hugr = self.hugr(); let def_op = hugr.get_optype(function.node()); @@ -609,9 +611,16 @@ pub trait Dataflow: Container { }) } }; - let signature = type_scheme.body; + // TODO either some way of returning a SignatureError without a node (as its not + // constructed yet) - or, preferably, a way to "instantiate" the PolyFuncType without + // validating against an ExtensionRegistry. + let signature = type_scheme.instantiate(type_args, exts).map_err(|e| { + BuildError::InvalidHUGR(ValidationError::SignatureError { + node: function.node(), + cause: e, + }) + })?; let const_in_port = signature.output.len(); - // TODO ALAN this is totally broken for polymorphic functions. Need to TypeApply here. let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; diff --git a/src/builder/module.rs b/src/builder/module.rs index c3bb1c3d5..92b9e396a 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -169,7 +169,7 @@ mod test { test::{n_identity, NAT}, Dataflow, DataflowSubContainer, }, - extension::EMPTY_REG, + extension::{EMPTY_REG, PRELUDE_REGISTRY}, type_row, types::FunctionType, }; @@ -186,7 +186,7 @@ mod test { )?; let mut f_build = module_builder.define_declaration(&f_id)?; - let call = f_build.call(&f_id, f_build.input_wires())?; + let call = f_build.call(&f_id, &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; f_build.finish_with_outputs(call.outputs())?; module_builder.finish_prelude_hugr() @@ -234,7 +234,8 @@ mod test { let [wire] = local_build.input_wires_arr(); let f_id = local_build.finish_with_outputs([wire])?; - let call = f_build.call(f_id.handle(), f_build.input_wires())?; + let call = + f_build.call(f_id.handle(), &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; f_build.finish_with_outputs(call.outputs())?; module_builder.finish_prelude_hugr()