diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 9024c66c96..5b5d260550 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -857,21 +857,28 @@ mod test { let [w] = mult.outputs_arr(); builder.set_outputs([w])?; - let hugr = builder.base; - // TODO: when we put new extensions onto the graph after inference, we - // can call `finish_hugr` and just look at the graph - let (solution, extra) = infer_extensions(&hugr)?; - assert!(extra.is_empty()); + let mut hugr = builder.base; + let closure = hugr.infer_extensions()?; + assert!(closure.is_empty()); assert_eq!( - *solution.get(&(src.node(), Direction::Outgoing)).unwrap(), + hugr.get_nodetype(src.node()) + .signature() + .unwrap() + .output_extensions(), rs ); assert_eq!( - *solution.get(&(mult.node(), Direction::Incoming)).unwrap(), + hugr.get_nodetype(mult.node()) + .signature() + .unwrap() + .input_extensions, rs ); assert_eq!( - *solution.get(&(mult.node(), Direction::Outgoing)).unwrap(), + hugr.get_nodetype(mult.node()) + .signature() + .unwrap() + .output_extensions(), rs ); Ok(()) diff --git a/src/hugr.rs b/src/hugr.rs index 4e1c52edf5..439888784d 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -193,7 +193,7 @@ impl Hugr { rw.apply(self) } - /// Infer extension requirements + /// Infer extension requirements and add new information to `op_types` field pub fn infer_extensions( &mut self, ) -> Result, InferExtensionError> { @@ -202,9 +202,22 @@ impl Hugr { Ok(extension_closure) } - /// TODO: Write this - fn instantiate_extensions(&mut self, _solution: ExtensionSolution) { - //todo!() + /// Add extension requirement information to the hugr in place. + fn instantiate_extensions(&mut self, solution: ExtensionSolution) { + // We only care about inferred _input_ extensions, because `NodeType` + // uses those to infer the output extensions + for ((node, _), input_extensions) in solution + .iter() + .filter(|((_, dir), _)| *dir == Direction::Incoming) + { + let nodetype = self.op_types.try_get_mut(node.index).unwrap(); + match nodetype.signature() { + None => nodetype.input_extensions = Some(input_extensions.clone()), + Some(existing_ext_reqs) => { + debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions) + } + } + } } } @@ -366,7 +379,14 @@ impl From for PyErr { #[cfg(test)] mod test { - use super::Hugr; + use super::{Hugr, HugrView, NodeType}; + use crate::extension::ExtensionSet; + use crate::hugr::hugrmut::HugrInternalsMut; + use crate::ops; + use crate::type_row; + use crate::types::{FunctionType, Type}; + + use std::error::Error; #[test] fn impls_send_and_sync() { @@ -385,4 +405,55 @@ mod test { let hugr = simple_dfg_hugr(); assert_matches!(hugr.get_io(hugr.root()), Some(_)); } + + #[test] + fn extension_instantiation() -> Result<(), Box> { + const BIT: Type = crate::extension::prelude::USIZE_T; + let r = ExtensionSet::singleton(&"R".into()); + + let root = NodeType::pure(ops::DFG { + signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), + }); + let mut hugr = Hugr::new(root); + let input = hugr.add_node_with_parent( + hugr.root(), + NodeType::pure(ops::Input { + types: type_row![BIT], + }), + )?; + let output = hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::Output { + types: type_row![BIT], + }), + )?; + let lift = hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::LeafOp::Lift { + type_row: type_row![BIT], + new_extension: "R".into(), + }), + )?; + hugr.connect(input, 0, lift, 0)?; + hugr.connect(lift, 0, output, 0)?; + hugr.infer_extensions()?; + + assert_eq!( + hugr.op_types + .get(lift.index) + .signature() + .unwrap() + .input_extensions, + ExtensionSet::new() + ); + assert_eq!( + hugr.op_types + .get(output.index) + .signature() + .unwrap() + .input_extensions, + r + ); + Ok(()) + } }