Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: force_order failing on Const nodes, add arg to rank. #1300

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions hugr-passes/src/force_order.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
//! Provides [force_order], a tool for fixing the order of nodes in a Hugr.
use std::{cmp::Reverse, collections::BinaryHeap};
use std::{cmp::Reverse, collections::BinaryHeap, iter};

use hugr_core::{
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, HierarchyView, SiblingGraph},
HugrError,
},
ops::{OpTag, OpTrait},
ops::{NamedOp, OpTag, OpTrait},
types::EdgeKind,
Direction, HugrView as _, Node,
HugrView as _, Node,
};
use itertools::Itertools as _;
use petgraph::{
Expand All @@ -36,45 +36,58 @@ use petgraph::{
/// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`).
/// Nodes of equal rank will be ordered arbitrarily, although that arbitrary
/// order is deterministic.
pub fn force_order(
hugr: &mut impl HugrMut,
pub fn force_order<H: HugrMut>(
hugr: &mut H,
root: Node,
rank: impl Fn(Node) -> i64,
rank: impl Fn(&H, Node) -> i64,
) -> Result<(), HugrError> {
force_order_by_key(hugr, root, rank)
}

/// As [force_order], but allows a generic [Ord] choice for the result of the
/// `rank` function.
pub fn force_order_by_key<K: Ord>(
hugr: &mut impl HugrMut,
pub fn force_order_by_key<H: HugrMut, K: Ord>(
hugr: &mut H,
root: Node,
rank: impl Fn(Node) -> K,
rank: impl Fn(&H, Node) -> K,
) -> Result<(), HugrError> {
let dataflow_parents = DescendantsGraph::<Node>::try_new(hugr, root)?
.nodes()
.filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent)
.collect_vec();
for dp in dataflow_parents {
// we filter out the input and output nodes from the topological sort
let [i, o] = hugr.get_io(dp).unwrap();
let rank = |n| rank(hugr, n);
let sg = SiblingGraph::<Node>::try_new(hugr, dp)?;
let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp);
let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o);
let ordered_nodes = ForceOrder::new(&petgraph, &rank)
.iter(&petgraph)
.filter(|&x| hugr.get_optype(x).tag() <= OpTag::DataflowChild)
.filter(|&x| {
let expected_edge = Some(EdgeKind::StateOrder);
let optype = hugr.get_optype(x);
if optype.other_input() == expected_edge || optype.other_output() == expected_edge {
assert_eq!(
optype.other_input(),
optype.other_output(),
"Optype does not have both input and output order edge: {}",
optype.name()
);
true
} else {
false
}
})
.collect_vec();

for (&n1, &n2) in ordered_nodes.iter().tuple_windows() {
// we iterate over the topologically sorted nodes, prepending the input
// node and suffixing the output node.
for (&n1, &n2) in iter::once(&i)
.chain(ordered_nodes.iter())
.chain(iter::once(&o))
.tuple_windows()
{
let (n1_ot, n2_ot) = (hugr.get_optype(n1), hugr.get_optype(n2));
assert_eq!(
Some(EdgeKind::StateOrder),
n1_ot.other_port_kind(Direction::Outgoing),
"Node {n1} does not support state order edges"
);
assert_eq!(
Some(EdgeKind::StateOrder),
n2_ot.other_port_kind(Direction::Incoming),
"Node {n2} does not support state order edges"
);
if !hugr.output_neighbours(n1).contains(&n2) {
hugr.connect(
n1,
Expand Down Expand Up @@ -192,10 +205,13 @@ mod test {

use super::*;
use hugr_core::builder::{endo_ft, BuildHandle, Dataflow, DataflowHugr};
use hugr_core::extension::EMPTY_REG;
use hugr_core::ops::handle::{DataflowOpID, NodeHandle};

use hugr_core::ops::Value;
use hugr_core::std_extensions::arithmetic::int_ops::{self, IntOpDef};
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
use hugr_core::types::{FunctionType, Type};
use hugr_core::{builder::DFGBuilder, hugr::Hugr};
use hugr_core::{HugrView, Wire};

Expand Down Expand Up @@ -257,7 +273,7 @@ mod test {
type RankMap = HashMap<Node, i64>;

fn force_order_test_impl(hugr: &mut Hugr, rank_map: RankMap) -> Vec<Node> {
force_order(hugr, hugr.root(), |n| *rank_map.get(&n).unwrap_or(&0)).unwrap();
force_order(hugr, hugr.root(), |_, n| *rank_map.get(&n).unwrap_or(&0)).unwrap();

let topo_sorted = Topo::new(&hugr.as_petgraph())
.iter(&hugr.as_petgraph())
Expand Down Expand Up @@ -303,4 +319,18 @@ mod test {
let topo_sort = force_order_test_impl(&mut hugr, rank_map);
assert_eq!(vec![v0, v1, v2, v3], topo_sort);
}

#[test]
fn test_force_order_const() {
let mut hugr = {
let mut builder =
DFGBuilder::new(FunctionType::new(Type::EMPTY_TYPEROW, Type::UNIT)).unwrap();
let unit = builder.add_load_value(Value::unary_unit_sum());
builder
.finish_hugr_with_outputs([unit], &EMPTY_REG)
.unwrap()
};
let root = hugr.root();
force_order(&mut hugr, root, |_, _| 0).unwrap();
}
}
Loading