Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Builder and HugrMut add_op_xxx default to open extensions #622

Merged
merged 13 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ pub(crate) mod test {
])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(1, depth(&h, exit));
for n in [split, left, right, merge, head, tail] {
Expand Down Expand Up @@ -753,7 +753,7 @@ pub(crate) mod test {
let root = h.root();
let m = SiblingMut::<CfgID>::try_new(&mut h, root).unwrap();
transform_cfg_to_nested(&mut IdentityCfgMap::new(m));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(3, depth(&h, head));
for n in [split, left, right, merge] {
Expand Down
12 changes: 6 additions & 6 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,18 @@ pub(crate) mod test {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Input {
ops::Input {
types: signature.input,
}),
},
)
.unwrap();
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
ops::Output {
types: signature.output,
}),
},
)
.unwrap();
hugr
Expand Down
5 changes: 2 additions & 3 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let case_node =
// add case before any existing subsequent cases
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
// TODO: Allow this to be non-pure
self.hugr_mut().add_node_before(sibling_node, NodeType::open_extensions(case_op))?
self.hugr_mut().add_op_before(sibling_node, case_op)?
} else {
self.add_child_node(NodeType::open_extensions(case_op))?
self.add_child_op(case_op)?
};

self.case_nodes[case] = Some(case_node);
Expand Down
163 changes: 77 additions & 86 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,6 @@ impl UnificationContext {
where
T: HugrView,
{
if hugr.root_type().signature().is_none() {
let m_input = self.make_or_get_meta(hugr.root(), Direction::Incoming);
self.variables.insert(m_input);
}

for node in hugr.nodes() {
let m_input = self.make_or_get_meta(node, Direction::Incoming);
let m_output = self.make_or_get_meta(node, Direction::Outgoing);
Expand Down Expand Up @@ -312,6 +307,14 @@ impl UnificationContext {
match node_type.signature() {
// Input extensions are open
None => {
if node == hugr.root()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check on this line is equivalent to the one removed above (lines 271-275); the || on next line extends it and is a first stab at what nodes we'd expect to be unconnected. (An alternative might be to skip creating m_input and m_output earlier in the loop...that might be better, happy to give that a go). Note #457 .

This still won't address random unconnected bits of graph, which is to say, malformed graphs, tho. (Only really an issue for tests of validation - some of these have to jump through hoops to get the error they expect rather than "CantInfer")

|| matches!(
node_type.tag(),
OpTag::Alias | OpTag::FuncDefn | OpTag::Function
)
{
self.variables.insert(m_input);
}
self.gen_union_constraint(
m_input,
m_output,
Expand All @@ -338,16 +341,16 @@ impl UnificationContext {
| Some(EdgeKind::ControlFlow)
)
}) {
let m_tgt = *self
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a driveby lifting of a loop invariant, nothing more

.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
for (src_node, _) in hugr.linked_ports(tgt_node, port) {
let m_src = self
.extensions
.get(&(src_node, Direction::Outgoing))
.unwrap();
let m_tgt = self
.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
self.add_constraint(*m_src, Constraint::Equal(*m_tgt));
self.add_constraint(*m_src, Constraint::Equal(m_tgt));
}
}
}
Expand Down Expand Up @@ -727,11 +730,11 @@ mod test {
let root_node = NodeType::open_extensions(op);
let mut hugr = Hugr::new(root_node);

let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));
let input = ops::Input::new(type_row![NAT, NAT]);
let output = ops::Output::new(type_row![NAT]);

let input = hugr.add_node_with_parent(hugr.root(), input)?;
let output = hugr.add_node_with_parent(hugr.root(), output)?;
let input = hugr.add_op_with_parent(hugr.root(), input)?;
let output = hugr.add_op_with_parent(hugr.root(), output)?;

assert_matches!(hugr.get_io(hugr.root()), Some(_));

Expand All @@ -747,29 +750,29 @@ mod test {
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&C));

let add_a = hugr.add_node_with_parent(
let add_a = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_a_sig,
}),
},
)?;
let add_b = hugr.add_node_with_parent(
let add_b = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_b_sig,
}),
},
)?;
let add_ab = hugr.add_node_with_parent(
let add_ab = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_ab_sig,
}),
},
)?;
let mult_c = hugr.add_node_with_parent(
let mult_c = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_c_sig,
}),
},
)?;

hugr.connect(input, 0, add_a, 0)?;
Expand Down Expand Up @@ -903,29 +906,26 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let add_r = hugr.add_node_with_parent(
let add_r = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_r_sig,
}),
},
)?;

// Dangling thingy
let src_sig = FunctionType::new(type_row![], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());

let src = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG { signature: src_sig }),
)?;
let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;

let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]);
// Mult has open extension requirements, which we should solve to be "R"
let mult = hugr.add_node_with_parent(
let mult = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_sig,
}),
},
)?;

hugr.connect(input, 0, add_r, 0)?;
Expand Down Expand Up @@ -985,18 +985,18 @@ mod test {
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
let node = hugr.add_op_with_parent(parent, op)?;
let input = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Input {
ops::Input {
types: op_sig.input,
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Output {
ops::Output {
types: op_sig.output,
}),
},
)?;
Ok([node, input, output])
}
Expand All @@ -1017,20 +1017,20 @@ mod test {
Into::<OpType>::into(op).signature(),
)?;

let lift1 = hugr.add_node_with_parent(
let lift1 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
},
)?;

let lift2 = hugr.add_node_with_parent(
let lift2 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
},
)?;

hugr.connect(case_in, 0, lift1, 0)?;
Expand Down Expand Up @@ -1095,17 +1095,17 @@ mod test {
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
let input = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Input {
ops::Input {
types: type_row![NAT],
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Output {
ops::Output {
types: type_row![NAT],
}),
},
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
Expand All @@ -1126,12 +1126,12 @@ mod test {
.unwrap();

let lift = hugr
.add_node_with_parent(
.add_op_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
},
)
.unwrap();

Expand Down Expand Up @@ -1178,7 +1178,7 @@ mod test {

let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?;

let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?;
let dfg = hugr.add_op_with_parent(bb, op)?;

hugr.connect(bb_in, 0, dfg, 0)?;
hugr.connect(dfg, 0, bb_out, 0)?;
Expand Down Expand Up @@ -1210,23 +1210,20 @@ mod test {
extension_delta: entry_extensions,
};

let exit = hugr.add_node_with_parent(
let exit = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::BasicBlock::Exit {
ops::BasicBlock::Exit {
cfg_outputs: exit_types.into(),
}),
},
)?;

let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?;
let entry_in = hugr.add_node_with_parent(
let entry = hugr.add_op_before(exit, dfb)?;
let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?;
let entry_out = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(ops::Input { types: inputs }),
)?;
let entry_out = hugr.add_node_with_parent(
entry,
NodeType::open_extensions(ops::Output {
ops::Output {
types: vec![entry_tuple_sum].into(),
}),
},
)?;

Ok(([entry, entry_in, entry_out], exit))
Expand Down Expand Up @@ -1277,12 +1274,12 @@ mod test {
type_row![NAT],
)?;

let mkpred = hugr.add_node_with_parent(
let mkpred = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
A,
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
)),
),
)?;

// Internal wiring for DFGs
Expand Down Expand Up @@ -1373,12 +1370,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], twoway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down Expand Up @@ -1462,12 +1456,12 @@ mod test {
type_row![NAT],
)?;

let entry_dfg = hugr.add_node_with_parent(
let entry_dfg = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
)),
),
)?;

hugr.connect(entry_in, 0, entry_dfg, 0)?;
Expand Down Expand Up @@ -1543,12 +1537,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down
Loading