diff --git a/src/builder.rs b/src/builder.rs index 4cf84ca91..2a809e6a4 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -90,6 +90,8 @@ impl From for PyErr { pub(crate) mod test { use rstest::fixture; + use crate::hugr::{views::HugrView, HugrMut, NodeType}; + use crate::ops; use crate::types::{FunctionType, Signature, Type}; use crate::{type_row, Hugr}; @@ -130,4 +132,28 @@ pub(crate) mod test { let [i1] = dfg_builder.input_wires_arr(); dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() } + + /// A helper method which creates a DFG rooted hugr with closed resources, + /// for tests which want to avoid having open extension variables after + /// inference. + pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr { + let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + signature: signature.clone(), + })); + hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::Input { + types: signature.input, + }), + ) + .unwrap(); + hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::Output { + types: signature.output, + }), + ) + .unwrap(); + hugr + } } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 2cfdb2b9b..79627edb2 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -668,6 +668,7 @@ mod test { use std::error::Error; use super::*; + use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::{ExtensionSet, EMPTY_REG}; use crate::hugr::HugrMut; @@ -854,43 +855,44 @@ mod test { // Infer the extensions on a child node with no inputs fn dangling_src() -> Result<(), Box> { let rs = ExtensionSet::singleton(&"R".into()); - let root_signature = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let mut builder = DFGBuilder::new(root_signature)?; - let [input_wire] = builder.input_wires_arr(); + let mut hugr = closed_dfg_root_hugr( + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs), + ); + + let [input, output] = hugr.get_io(hugr.root()).unwrap(); let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let add_r = builder.add_dataflow_node( + let add_r = hugr.add_node_with_parent( + hugr.root(), NodeType::open_extensions(ops::DFG { signature: add_r_sig, }), - [input_wire], )?; - let [wl] = add_r.outputs_arr(); // Dangling thingy let src_sig = FunctionType::new(type_row![], type_row![NAT]) .with_extension_delta(&ExtensionSet::new()); - let src = builder.add_dataflow_node( + + let src = hugr.add_node_with_parent( + hugr.root(), NodeType::open_extensions(ops::DFG { signature: src_sig }), - [], )?; - let [wr] = src.outputs_arr(); - let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::new()); + let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]); // Mult has open extension requirements, which we should solve to be "R" - let mult = builder.add_dataflow_node( + let mult = hugr.add_node_with_parent( + hugr.root(), NodeType::open_extensions(ops::DFG { signature: mult_sig, }), - [wl, wr], )?; - let [w] = mult.outputs_arr(); - builder.set_outputs([w])?; - let mut hugr = builder.base; + hugr.connect(input, 0, add_r, 0)?; + hugr.connect(add_r, 0, mult, 0)?; + hugr.connect(src, 0, mult, 1)?; + hugr.connect(mult, 0, output, 0)?; + let closure = hugr.infer_extensions()?; assert!(closure.is_empty()); assert_eq!( @@ -944,37 +946,20 @@ mod test { let abc = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]); // Parent graph is closed - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { - signature: FunctionType::new(type_row![], just_bool.clone()).with_extension_delta(&abc), - })); + let mut hugr = closed_dfg_root_hugr( + FunctionType::new(type_row![], just_bool.clone()).with_extension_delta(&abc), + ); - let _input = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::Input { types: type_row![] }), - )?; - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::Output { - types: just_bool.clone(), - }), - )?; + let [_, output] = hugr.get_io(hugr.root()).unwrap(); - let child = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::DFG { + let root = hugr.root(); + let [child, _, ochild] = create_with_io( + &mut hugr, + root, + ops::DFG { signature: FunctionType::new(type_row![], just_bool.clone()) .with_extension_delta(&abc), - }), - )?; - let _ichild = hugr.add_node_with_parent( - child, - NodeType::open_extensions(ops::Input { types: type_row![] }), - )?; - let ochild = hugr.add_node_with_parent( - child, - NodeType::open_extensions(ops::Output { - types: just_bool.clone(), - }), + }, )?; let const_node = hugr.add_node_with_parent(child, NodeType::open_extensions(const_true))?; diff --git a/src/hugr.rs b/src/hugr.rs index 374077529..df99e8197 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -466,6 +466,7 @@ impl From for PyErr { #[cfg(test)] mod test { use super::{Hugr, HugrView, NodeType}; + use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; use crate::ops; @@ -497,22 +498,10 @@ mod test { const BIT: Type = crate::extension::prelude::USIZE_T; let r = ExtensionSet::singleton(&"R".into()); - let root = NodeType::pure(ops::DFG { - signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), - }); - let mut hugr = Hugr::new(root); - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::pure(ops::Input { - types: type_row![BIT], - }), - )?; - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::Output { - types: type_row![BIT], - }), - )?; + let mut hugr = closed_dfg_root_hugr( + FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), + ); + let [input, output] = hugr.get_io(hugr.root()).unwrap(); let lift = hugr.add_node_with_parent( hugr.root(), NodeType::open_extensions(ops::LeafOp::Lift { diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 5584e42a2..953253e8e 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -271,8 +271,8 @@ pub mod test { use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::{ builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, + test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, ModuleBuilder, }, extension::prelude::BOOL_T, hugr::NodeType, @@ -455,10 +455,9 @@ pub mod test { #[test] fn hierarchy_order() { - let dfg = DFGBuilder::new(FunctionType::new(vec![QB], vec![QB])).unwrap(); - let [old_in, out] = dfg.io(); - let w = dfg.input_wires(); - let mut hugr = dfg.finish_prelude_hugr_with_outputs(w).unwrap(); + let mut hugr = closed_dfg_root_hugr(FunctionType::new(vec![QB], vec![QB])); + let [old_in, out] = hugr.get_io(hugr.root()).unwrap(); + hugr.connect(old_in, 0, out, 0).unwrap(); // Now add a new input let new_in = hugr.add_op(Input::new([QB].to_vec())); @@ -466,7 +465,7 @@ pub mod test { hugr.connect(new_in, 0, out, 0).unwrap(); hugr.move_before_sibling(new_in, old_in).unwrap(); hugr.remove_node(old_in).unwrap(); - hugr.validate(&PRELUDE_REGISTRY).unwrap(); + hugr.infer_and_validate(&PRELUDE_REGISTRY).unwrap(); let ser = serde_json::to_vec(&hugr).unwrap(); let new_hugr: Hugr = serde_json::from_slice(&ser).unwrap(); diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 061bad2a4..b20ef8e08 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -695,6 +695,7 @@ mod test { use cool_asserts::assert_matches; use super::*; + use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder}; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; @@ -1030,11 +1031,12 @@ mod test { #[test] fn test_ext_edge() -> Result<(), HugrError> { - let mut h = Hugr::new(NodeType::pure(ops::DFG { - signature: FunctionType::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]), - })); - let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T, BOOL_T]))?; - let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; + let mut h = closed_dfg_root_hugr(FunctionType::new( + type_row![BOOL_T, BOOL_T], + type_row![BOOL_T], + )); + let [input, output] = h.get_io(h.root()).unwrap(); + // Nested DFG BOOL_T -> BOOL_T let sub_dfg = h.add_op_with_parent( h.root(), @@ -1056,35 +1058,32 @@ mod test { h.connect(sub_dfg, 0, output, 0)?; assert_matches!( - h.validate(&EMPTY_REG), + h.infer_and_validate(&EMPTY_REG), Err(ValidationError::UnconnectedPort { .. }) ); h.connect(input, 1, sub_op, 1)?; assert_matches!( - h.validate(&EMPTY_REG), + h.infer_and_validate(&EMPTY_REG), Err(ValidationError::InterGraphEdgeError( InterGraphEdgeError::MissingOrderEdge { .. } )) ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg)?; - h.validate(&EMPTY_REG).unwrap(); + h.infer_and_validate(&EMPTY_REG).unwrap(); Ok(()) } #[test] fn test_local_const() -> Result<(), HugrError> { - let mut h = Hugr::new(NodeType::pure(ops::DFG { - signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), - })); - let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T]))?; - let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; + let mut h = closed_dfg_root_hugr(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])); + let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_op_with_parent(h.root(), and_op())?; h.connect(input, 0, and, 0)?; h.connect(and, 0, output, 0)?; assert_eq!( - h.validate(&EMPTY_REG), + h.infer_and_validate(&EMPTY_REG), Err(ValidationError::UnconnectedPort { node: and, port: Port::new_incoming(1), @@ -1102,7 +1101,7 @@ mod test { h.connect(cst, 0, lcst, 0)?; h.connect(lcst, 0, and, 1)?; // There is no edge from Input to LoadConstant, but that's OK: - h.validate(&EMPTY_REG).unwrap(); + h.infer_and_validate(&EMPTY_REG).unwrap(); Ok(()) } @@ -1278,11 +1277,11 @@ mod test { #[test] fn dfg_with_cycles() -> Result<(), HugrError> { - let mut h = Hugr::new(NodeType::pure(ops::DFG { - signature: FunctionType::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]), - })); - let input = h.add_op_with_parent(h.root(), ops::Input::new(type_row![BOOL_T, BOOL_T]))?; - let output = h.add_op_with_parent(h.root(), ops::Output::new(type_row![BOOL_T]))?; + let mut h = closed_dfg_root_hugr(FunctionType::new( + type_row![BOOL_T, BOOL_T], + type_row![BOOL_T], + )); + let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_op_with_parent(h.root(), and_op())?; let not1 = h.add_op_with_parent(h.root(), not_op())?; let not2 = h.add_op_with_parent(h.root(), not_op())?;