diff --git a/src/extension/infer.rs b/src/extension/infer.rs index b9525a18a..856681adf 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -21,6 +21,7 @@ use crate::{ use super::validate::ExtensionError; use petgraph::graph as pg; +use petgraph::{Directed, EdgeType, Undirected}; use std::collections::{HashMap, HashSet, VecDeque}; @@ -106,53 +107,66 @@ pub enum InferExtensionError { EdgeMismatch(#[from] ExtensionError), } -/// A graph of metavariables which we've found equality constraints for. Edges -/// between nodes represent equality constraints. -struct EqGraph { - equalities: pg::Graph, +/// A graph of metavariables connected by constraints. +/// The edges represent `Equal` constraints in the undirected graph and `Plus` +/// constraints in the directed case. +struct GraphContainer { + graph: pg::Graph, node_map: HashMap, } -impl EqGraph { - /// Create a new `EqGraph` - fn new() -> Self { - EqGraph { - equalities: pg::Graph::new_undirected(), - node_map: HashMap::new(), - } - } - +impl GraphContainer { /// Add a metavariable to the graph as a node and return the `NodeIndex`. /// If it's already there, just return the existing `NodeIndex` fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex { self.node_map.get(&m).cloned().unwrap_or_else(|| { - let ix = self.equalities.add_node(m); + let ix = self.graph.add_node(m); self.node_map.insert(m, ix); ix }) } - /// Create an edge between two nodes on the graph, declaring that they stand - /// for metavariables which should be equal. - fn register_eq(&mut self, src: Meta, tgt: Meta) { + /// Create an edge between two nodes on the graph + fn add_edge(&mut self, src: Meta, tgt: Meta) { let src_ix = self.add_or_retrieve(src); let tgt_ix = self.add_or_retrieve(tgt); - self.equalities.add_edge(src_ix, tgt_ix, ()); + self.graph.add_edge(src_ix, tgt_ix, ()); } - /// Return the connected components of the graph in terms of metavariables - fn ccs(&self) -> Vec> { - petgraph::algo::tarjan_scc(&self.equalities) + /// Return the strongly connected components of the graph in terms of + /// metavariables. In the undirected case, return the connected components + fn sccs(&self) -> Vec> { + petgraph::algo::tarjan_scc(&self.graph) .into_iter() .map(|cc| { cc.into_iter() - .map(|n| *self.equalities.node_weight(n).unwrap()) + .map(|n| *self.graph.node_weight(n).unwrap()) .collect() }) .collect() } } +impl GraphContainer { + fn new() -> Self { + GraphContainer { + graph: pg::Graph::new_undirected(), + node_map: HashMap::new(), + } + } +} + +impl GraphContainer { + fn new() -> Self { + GraphContainer { + graph: pg::Graph::new(), + node_map: HashMap::new(), + } + } +} + +type EqGraph = GraphContainer; + /// Our current knowledge about the extensions of the graph struct UnificationContext { /// A list of constraints for each metavariable @@ -412,7 +426,7 @@ impl UnificationContext { fn merge_equal_metas(&mut self) -> Result<(HashSet, HashSet), InferExtensionError> { let mut merged: HashSet = HashSet::new(); let mut new_metas: HashSet = HashSet::new(); - for cc in self.eq_graph.ccs().into_iter() { + for cc in self.eq_graph.sccs().into_iter() { // Within a connected component everything is equal let combined_meta = self.fresh_meta(); for m in cc.iter() { @@ -476,7 +490,7 @@ impl UnificationContext { match c { // Just register the equality in the EqGraph, we'll process it later Constraint::Equal(other_meta) => { - self.eq_graph.register_eq(meta, *other_meta); + self.eq_graph.add_edge(meta, *other_meta); } // N.B. If `meta` is already solved, we can't use that // information to solve `other_meta`. This is because the Plus @@ -617,31 +631,98 @@ impl UnificationContext { self.results() } - /// Instantiate all variables in the graph with the empty extension set. + /// Gather all the transitive dependencies (induced by constraints) of the + /// variables in the context. + fn search_variable_deps(&self) -> HashSet { + let mut seen = HashSet::new(); + let mut new_variables: HashSet = self.variables.clone(); + while !new_variables.is_empty() { + new_variables = new_variables + .into_iter() + .filter(|m| seen.insert(*m)) + .flat_map(|m| self.get_constraints(&m).unwrap()) + .map(|c| match c { + Constraint::Plus(_, other) => self.resolve(*other), + Constraint::Equal(other) => self.resolve(*other), + }) + .collect(); + } + seen + } + /// Instantiate all variables in the graph with the empty extension set, or /// the smallest solution possible given their constraints. /// This is done to solve metas which depend on variables, which allows /// us to come up with a fully concrete solution to pass into validation. + /// + /// Nodes which loop into themselves must be considered as a "minimum" set + /// of requirements. If we have + /// 1 = 2 + X, ... + /// 2 = 1 + x, ... + /// then 1 and 2 both definitely contain X, even if we don't know what else. + /// So instead of instantiating to the empty set, we'll instantiate to `{X}` pub fn instantiate_variables(&mut self) { - for m in self.variables.clone().into_iter() { + // A directed graph to keep track of `Plus` constraint relationships + let mut relations = GraphContainer::::new(); + let mut solutions: HashMap = HashMap::new(); + + let variable_scope = self.search_variable_deps(); + for m in variable_scope.into_iter() { + // If `m` has been merged, [`self.variables`] entry + // will have already been updated to the merged + // value by [`self.merge_equal_metas`] so we don't + // need to worry about resolving it. if !self.solved.contains_key(&m) { // Handle the case where the constraints for `m` contain a self // reference, i.e. "m = Plus(E, m)", in which case the variable // should be instantiated to E rather than the empty set. - let solution = self - .get_constraints(&m) - .unwrap() + let plus_constraints = + self.get_constraints(&m) + .unwrap() + .iter() + .cloned() + .flat_map(|c| match c { + Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))), + _ => None, + }); + + let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip(); + let solution = rs.iter().fold(ExtensionSet::new(), |e1, e2| e1.union(e2)); + let unresolved_metas = other_ms + .into_iter() + .filter(|other_m| m != *other_m) + .collect::>(); + + // If `m` doesn't depend on any other metas then we have all the + // information we need to come up with a solution for it. + relations.add_or_retrieve(m); + unresolved_metas .iter() - .filter_map(|c| match c { - // If `m` has been merged, [`self.variables`] entry - // will have already been updated to the merged - // value by [`self.merge_equal_metas`] so we don't - // need to worry about resolving it. - Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x), - _ => None, - }) - .fold(ExtensionSet::new(), ExtensionSet::union); - self.add_solution(m, solution); + .for_each(|other_m| relations.add_edge(m, *other_m)); + solutions.insert(m, solution); + } + } + println!("{:?}", relations.node_map); + println!("{:?}", relations.graph); + + // Process the strongly-connected components. We need to deal with these + // depended-upon before depender. ccs() gives them back in some order + // - this might need to be reversed???? + for cc in relations.sccs() { + // Strongly connected components are looping constraint dependencies. + // This means that each metavariable in the CC has the same solution. + let combined_solution = cc + .iter() + .flat_map(|m| self.get_constraints(m).unwrap()) + .filter_map(|c| match c { + Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)), + Constraint::Equal(_) => None, + }) + .fold(ExtensionSet::new(), |a, b| a.union(b)); + + for m in cc.iter() { + self.add_solution(*m, combined_solution.clone()); + solutions.insert(*m, combined_solution.clone()); } } self.variables = HashSet::new(); @@ -1465,17 +1546,14 @@ mod test { #[test] fn test_cfg_loops() -> Result<(), Box> { let just_a = ExtensionSet::singleton(&A); - let variants = vec![ - ( - ExtensionSet::new(), - ExtensionSet::new(), - ExtensionSet::new(), - ), - (just_a.clone(), ExtensionSet::new(), ExtensionSet::new()), - (ExtensionSet::new(), just_a.clone(), ExtensionSet::new()), - (ExtensionSet::new(), ExtensionSet::new(), just_a.clone()), - ]; - + let mut variants = Vec::new(); + for entry in [ExtensionSet::new(), just_a.clone()] { + for bb1 in [ExtensionSet::new(), just_a.clone()] { + for bb2 in [ExtensionSet::new(), just_a.clone()] { + variants.push((entry.clone(), bb1.clone(), bb2.clone())); + } + } + } for (bb0, bb1, bb2) in variants.into_iter() { let mut hugr = make_looping_cfg(bb0, bb1, bb2)?; hugr.update_validate(&PRELUDE_REGISTRY)?; @@ -1581,4 +1659,27 @@ mod test { fn plus_on_self_10_times() { [0; 10].iter().for_each(|_| plus_on_self().unwrap()) } + + #[test] + // Test that logic for dealing with self-referential constraints doesn't + // fall over when a self-referencing group of metas also references a meta + // outside the group + fn failing_sccs_test() { + let hugr = Hugr::default(); + let mut ctx = UnificationContext::new(&hugr); + let m1 = ctx.fresh_meta(); + let m2 = ctx.fresh_meta(); + let m3 = ctx.fresh_meta(); + // Outside of the connected component + let m_other = ctx.fresh_meta(); + // These 3 metavariables form a loop + ctx.add_constraint(m1, Constraint::Plus(ExtensionSet::singleton(&A), m3)); + ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m1)); + ctx.add_constraint(m3, Constraint::Plus(ExtensionSet::singleton(&A), m2)); + // This other meta is outside the loop, but depended on by one of the loop metas + ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m_other)); + ctx.variables.insert(m1); + ctx.variables.insert(m_other); + ctx.instantiate_variables(); + } }