diff --git a/src/extension/infer.rs b/src/extension/infer.rs
index b9525a18a..856681adf 100644
--- a/src/extension/infer.rs
+++ b/src/extension/infer.rs
@@ -21,6 +21,7 @@ use crate::{
use super::validate::ExtensionError;
use petgraph::graph as pg;
+use petgraph::{Directed, EdgeType, Undirected};
use std::collections::{HashMap, HashSet, VecDeque};
@@ -106,53 +107,66 @@ pub enum InferExtensionError {
EdgeMismatch(#[from] ExtensionError),
}
-/// A graph of metavariables which we've found equality constraints for. Edges
-/// between nodes represent equality constraints.
-struct EqGraph {
- equalities: pg::Graph,
+/// A graph of metavariables connected by constraints.
+/// The edges represent `Equal` constraints in the undirected graph and `Plus`
+/// constraints in the directed case.
+struct GraphContainer {
+ graph: pg::Graph,
node_map: HashMap,
}
-impl EqGraph {
- /// Create a new `EqGraph`
- fn new() -> Self {
- EqGraph {
- equalities: pg::Graph::new_undirected(),
- node_map: HashMap::new(),
- }
- }
-
+impl GraphContainer {
/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
- let ix = self.equalities.add_node(m);
+ let ix = self.graph.add_node(m);
self.node_map.insert(m, ix);
ix
})
}
- /// Create an edge between two nodes on the graph, declaring that they stand
- /// for metavariables which should be equal.
- fn register_eq(&mut self, src: Meta, tgt: Meta) {
+ /// Create an edge between two nodes on the graph
+ fn add_edge(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
- self.equalities.add_edge(src_ix, tgt_ix, ());
+ self.graph.add_edge(src_ix, tgt_ix, ());
}
- /// Return the connected components of the graph in terms of metavariables
- fn ccs(&self) -> Vec> {
- petgraph::algo::tarjan_scc(&self.equalities)
+ /// Return the strongly connected components of the graph in terms of
+ /// metavariables. In the undirected case, return the connected components
+ fn sccs(&self) -> Vec> {
+ petgraph::algo::tarjan_scc(&self.graph)
.into_iter()
.map(|cc| {
cc.into_iter()
- .map(|n| *self.equalities.node_weight(n).unwrap())
+ .map(|n| *self.graph.node_weight(n).unwrap())
.collect()
})
.collect()
}
}
+impl GraphContainer {
+ fn new() -> Self {
+ GraphContainer {
+ graph: pg::Graph::new_undirected(),
+ node_map: HashMap::new(),
+ }
+ }
+}
+
+impl GraphContainer {
+ fn new() -> Self {
+ GraphContainer {
+ graph: pg::Graph::new(),
+ node_map: HashMap::new(),
+ }
+ }
+}
+
+type EqGraph = GraphContainer;
+
/// Our current knowledge about the extensions of the graph
struct UnificationContext {
/// A list of constraints for each metavariable
@@ -412,7 +426,7 @@ impl UnificationContext {
fn merge_equal_metas(&mut self) -> Result<(HashSet, HashSet), InferExtensionError> {
let mut merged: HashSet = HashSet::new();
let mut new_metas: HashSet = HashSet::new();
- for cc in self.eq_graph.ccs().into_iter() {
+ for cc in self.eq_graph.sccs().into_iter() {
// Within a connected component everything is equal
let combined_meta = self.fresh_meta();
for m in cc.iter() {
@@ -476,7 +490,7 @@ impl UnificationContext {
match c {
// Just register the equality in the EqGraph, we'll process it later
Constraint::Equal(other_meta) => {
- self.eq_graph.register_eq(meta, *other_meta);
+ self.eq_graph.add_edge(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
@@ -617,31 +631,98 @@ impl UnificationContext {
self.results()
}
- /// Instantiate all variables in the graph with the empty extension set.
+ /// Gather all the transitive dependencies (induced by constraints) of the
+ /// variables in the context.
+ fn search_variable_deps(&self) -> HashSet {
+ let mut seen = HashSet::new();
+ let mut new_variables: HashSet = self.variables.clone();
+ while !new_variables.is_empty() {
+ new_variables = new_variables
+ .into_iter()
+ .filter(|m| seen.insert(*m))
+ .flat_map(|m| self.get_constraints(&m).unwrap())
+ .map(|c| match c {
+ Constraint::Plus(_, other) => self.resolve(*other),
+ Constraint::Equal(other) => self.resolve(*other),
+ })
+ .collect();
+ }
+ seen
+ }
+
/// Instantiate all variables in the graph with the empty extension set, or
/// the smallest solution possible given their constraints.
/// This is done to solve metas which depend on variables, which allows
/// us to come up with a fully concrete solution to pass into validation.
+ ///
+ /// Nodes which loop into themselves must be considered as a "minimum" set
+ /// of requirements. If we have
+ /// 1 = 2 + X, ...
+ /// 2 = 1 + x, ...
+ /// then 1 and 2 both definitely contain X, even if we don't know what else.
+ /// So instead of instantiating to the empty set, we'll instantiate to `{X}`
pub fn instantiate_variables(&mut self) {
- for m in self.variables.clone().into_iter() {
+ // A directed graph to keep track of `Plus` constraint relationships
+ let mut relations = GraphContainer::::new();
+ let mut solutions: HashMap = HashMap::new();
+
+ let variable_scope = self.search_variable_deps();
+ for m in variable_scope.into_iter() {
+ // If `m` has been merged, [`self.variables`] entry
+ // will have already been updated to the merged
+ // value by [`self.merge_equal_metas`] so we don't
+ // need to worry about resolving it.
if !self.solved.contains_key(&m) {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
- let solution = self
- .get_constraints(&m)
- .unwrap()
+ let plus_constraints =
+ self.get_constraints(&m)
+ .unwrap()
+ .iter()
+ .cloned()
+ .flat_map(|c| match c {
+ Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))),
+ _ => None,
+ });
+
+ let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
+ let solution = rs.iter().fold(ExtensionSet::new(), |e1, e2| e1.union(e2));
+ let unresolved_metas = other_ms
+ .into_iter()
+ .filter(|other_m| m != *other_m)
+ .collect::>();
+
+ // If `m` doesn't depend on any other metas then we have all the
+ // information we need to come up with a solution for it.
+ relations.add_or_retrieve(m);
+ unresolved_metas
.iter()
- .filter_map(|c| match c {
- // If `m` has been merged, [`self.variables`] entry
- // will have already been updated to the merged
- // value by [`self.merge_equal_metas`] so we don't
- // need to worry about resolving it.
- Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
- _ => None,
- })
- .fold(ExtensionSet::new(), ExtensionSet::union);
- self.add_solution(m, solution);
+ .for_each(|other_m| relations.add_edge(m, *other_m));
+ solutions.insert(m, solution);
+ }
+ }
+ println!("{:?}", relations.node_map);
+ println!("{:?}", relations.graph);
+
+ // Process the strongly-connected components. We need to deal with these
+ // depended-upon before depender. ccs() gives them back in some order
+ // - this might need to be reversed????
+ for cc in relations.sccs() {
+ // Strongly connected components are looping constraint dependencies.
+ // This means that each metavariable in the CC has the same solution.
+ let combined_solution = cc
+ .iter()
+ .flat_map(|m| self.get_constraints(m).unwrap())
+ .filter_map(|c| match c {
+ Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)),
+ Constraint::Equal(_) => None,
+ })
+ .fold(ExtensionSet::new(), |a, b| a.union(b));
+
+ for m in cc.iter() {
+ self.add_solution(*m, combined_solution.clone());
+ solutions.insert(*m, combined_solution.clone());
}
}
self.variables = HashSet::new();
@@ -1465,17 +1546,14 @@ mod test {
#[test]
fn test_cfg_loops() -> Result<(), Box> {
let just_a = ExtensionSet::singleton(&A);
- let variants = vec![
- (
- ExtensionSet::new(),
- ExtensionSet::new(),
- ExtensionSet::new(),
- ),
- (just_a.clone(), ExtensionSet::new(), ExtensionSet::new()),
- (ExtensionSet::new(), just_a.clone(), ExtensionSet::new()),
- (ExtensionSet::new(), ExtensionSet::new(), just_a.clone()),
- ];
-
+ let mut variants = Vec::new();
+ for entry in [ExtensionSet::new(), just_a.clone()] {
+ for bb1 in [ExtensionSet::new(), just_a.clone()] {
+ for bb2 in [ExtensionSet::new(), just_a.clone()] {
+ variants.push((entry.clone(), bb1.clone(), bb2.clone()));
+ }
+ }
+ }
for (bb0, bb1, bb2) in variants.into_iter() {
let mut hugr = make_looping_cfg(bb0, bb1, bb2)?;
hugr.update_validate(&PRELUDE_REGISTRY)?;
@@ -1581,4 +1659,27 @@ mod test {
fn plus_on_self_10_times() {
[0; 10].iter().for_each(|_| plus_on_self().unwrap())
}
+
+ #[test]
+ // Test that logic for dealing with self-referential constraints doesn't
+ // fall over when a self-referencing group of metas also references a meta
+ // outside the group
+ fn failing_sccs_test() {
+ let hugr = Hugr::default();
+ let mut ctx = UnificationContext::new(&hugr);
+ let m1 = ctx.fresh_meta();
+ let m2 = ctx.fresh_meta();
+ let m3 = ctx.fresh_meta();
+ // Outside of the connected component
+ let m_other = ctx.fresh_meta();
+ // These 3 metavariables form a loop
+ ctx.add_constraint(m1, Constraint::Plus(ExtensionSet::singleton(&A), m3));
+ ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m1));
+ ctx.add_constraint(m3, Constraint::Plus(ExtensionSet::singleton(&A), m2));
+ // This other meta is outside the loop, but depended on by one of the loop metas
+ ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m_other));
+ ctx.variables.insert(m1);
+ ctx.variables.insert(m_other);
+ ctx.instantiate_variables();
+ }
}