Skip to content

Commit

Permalink
fix(Inference): Work harder in variable instantiation (#591)
Browse files Browse the repository at this point in the history
Extend the `instantiate_variables` method in extension inference to
handle the case where we end up with variables which form loops of
`Plus` constraints with other metavariables. This is required to get the
rest of the test variants in `test_cfg_loops` to work

Resolves #598

---------

Co-authored-by: Alan Lawrence <alan.lawrence@cambridgequantum.com>
  • Loading branch information
croyzor and acl-cqc authored Nov 8, 2023
1 parent a6dee37 commit 2269161
Showing 1 changed file with 151 additions and 50 deletions.
201 changes: 151 additions & 50 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{
use super::validate::ExtensionError;

use petgraph::graph as pg;
use petgraph::{Directed, EdgeType, Undirected};

use std::collections::{HashMap, HashSet, VecDeque};

Expand Down Expand Up @@ -106,53 +107,66 @@ pub enum InferExtensionError {
EdgeMismatch(#[from] ExtensionError),
}

/// A graph of metavariables which we've found equality constraints for. Edges
/// between nodes represent equality constraints.
struct EqGraph {
equalities: pg::Graph<Meta, (), petgraph::Undirected>,
/// A graph of metavariables connected by constraints.
/// The edges represent `Equal` constraints in the undirected graph and `Plus`
/// constraints in the directed case.
struct GraphContainer<Dir: EdgeType> {
graph: pg::Graph<Meta, (), Dir>,
node_map: HashMap<Meta, pg::NodeIndex>,
}

impl EqGraph {
/// Create a new `EqGraph`
fn new() -> Self {
EqGraph {
equalities: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}

impl<T: EdgeType> GraphContainer<T> {
/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
let ix = self.equalities.add_node(m);
let ix = self.graph.add_node(m);
self.node_map.insert(m, ix);
ix
})
}

/// Create an edge between two nodes on the graph, declaring that they stand
/// for metavariables which should be equal.
fn register_eq(&mut self, src: Meta, tgt: Meta) {
/// Create an edge between two nodes on the graph
fn add_edge(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
self.equalities.add_edge(src_ix, tgt_ix, ());
self.graph.add_edge(src_ix, tgt_ix, ());
}

/// Return the connected components of the graph in terms of metavariables
fn ccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.equalities)
/// Return the strongly connected components of the graph in terms of
/// metavariables. In the undirected case, return the connected components
fn sccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.graph)
.into_iter()
.map(|cc| {
cc.into_iter()
.map(|n| *self.equalities.node_weight(n).unwrap())
.map(|n| *self.graph.node_weight(n).unwrap())
.collect()
})
.collect()
}
}

impl GraphContainer<Undirected> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}
}

impl GraphContainer<Directed> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new(),
node_map: HashMap::new(),
}
}
}

type EqGraph = GraphContainer<Undirected>;

/// Our current knowledge about the extensions of the graph
struct UnificationContext {
/// A list of constraints for each metavariable
Expand Down Expand Up @@ -412,7 +426,7 @@ impl UnificationContext {
fn merge_equal_metas(&mut self) -> Result<(HashSet<Meta>, HashSet<Meta>), InferExtensionError> {
let mut merged: HashSet<Meta> = HashSet::new();
let mut new_metas: HashSet<Meta> = HashSet::new();
for cc in self.eq_graph.ccs().into_iter() {
for cc in self.eq_graph.sccs().into_iter() {
// Within a connected component everything is equal
let combined_meta = self.fresh_meta();
for m in cc.iter() {
Expand Down Expand Up @@ -476,7 +490,7 @@ impl UnificationContext {
match c {
// Just register the equality in the EqGraph, we'll process it later
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
self.eq_graph.add_edge(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
Expand Down Expand Up @@ -617,31 +631,98 @@ impl UnificationContext {
self.results()
}

/// Instantiate all variables in the graph with the empty extension set.
/// Gather all the transitive dependencies (induced by constraints) of the
/// variables in the context.
fn search_variable_deps(&self) -> HashSet<Meta> {
let mut seen = HashSet::new();
let mut new_variables: HashSet<Meta> = self.variables.clone();
while !new_variables.is_empty() {
new_variables = new_variables
.into_iter()
.filter(|m| seen.insert(*m))
.flat_map(|m| self.get_constraints(&m).unwrap())
.map(|c| match c {
Constraint::Plus(_, other) => self.resolve(*other),
Constraint::Equal(other) => self.resolve(*other),
})
.collect();
}
seen
}

/// Instantiate all variables in the graph with the empty extension set, or
/// the smallest solution possible given their constraints.
/// This is done to solve metas which depend on variables, which allows
/// us to come up with a fully concrete solution to pass into validation.
///
/// Nodes which loop into themselves must be considered as a "minimum" set
/// of requirements. If we have
/// 1 = 2 + X, ...
/// 2 = 1 + x, ...
/// then 1 and 2 both definitely contain X, even if we don't know what else.
/// So instead of instantiating to the empty set, we'll instantiate to `{X}`
pub fn instantiate_variables(&mut self) {
for m in self.variables.clone().into_iter() {
// A directed graph to keep track of `Plus` constraint relationships
let mut relations = GraphContainer::<Directed>::new();
let mut solutions: HashMap<Meta, ExtensionSet> = HashMap::new();

let variable_scope = self.search_variable_deps();
for m in variable_scope.into_iter() {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
if !self.solved.contains_key(&m) {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution = self
.get_constraints(&m)
.unwrap()
let plus_constraints =
self.get_constraints(&m)
.unwrap()
.iter()
.cloned()
.flat_map(|c| match c {
Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))),
_ => None,
});

let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
let solution = rs.iter().fold(ExtensionSet::new(), |e1, e2| e1.union(e2));
let unresolved_metas = other_ms
.into_iter()
.filter(|other_m| m != *other_m)
.collect::<Vec<_>>();

// If `m` doesn't depend on any other metas then we have all the
// information we need to come up with a solution for it.
relations.add_or_retrieve(m);
unresolved_metas
.iter()
.filter_map(|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
_ => None,
})
.fold(ExtensionSet::new(), ExtensionSet::union);
self.add_solution(m, solution);
.for_each(|other_m| relations.add_edge(m, *other_m));
solutions.insert(m, solution);
}
}
println!("{:?}", relations.node_map);
println!("{:?}", relations.graph);

// Process the strongly-connected components. We need to deal with these
// depended-upon before depender. ccs() gives them back in some order
// - this might need to be reversed????
for cc in relations.sccs() {
// Strongly connected components are looping constraint dependencies.
// This means that each metavariable in the CC has the same solution.
let combined_solution = cc
.iter()
.flat_map(|m| self.get_constraints(m).unwrap())
.filter_map(|c| match c {
Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)),
Constraint::Equal(_) => None,
})
.fold(ExtensionSet::new(), |a, b| a.union(b));

for m in cc.iter() {
self.add_solution(*m, combined_solution.clone());
solutions.insert(*m, combined_solution.clone());
}
}
self.variables = HashSet::new();
Expand Down Expand Up @@ -1465,17 +1546,14 @@ mod test {
#[test]
fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
let just_a = ExtensionSet::singleton(&A);
let variants = vec![
(
ExtensionSet::new(),
ExtensionSet::new(),
ExtensionSet::new(),
),
(just_a.clone(), ExtensionSet::new(), ExtensionSet::new()),
(ExtensionSet::new(), just_a.clone(), ExtensionSet::new()),
(ExtensionSet::new(), ExtensionSet::new(), just_a.clone()),
];

let mut variants = Vec::new();
for entry in [ExtensionSet::new(), just_a.clone()] {
for bb1 in [ExtensionSet::new(), just_a.clone()] {
for bb2 in [ExtensionSet::new(), just_a.clone()] {
variants.push((entry.clone(), bb1.clone(), bb2.clone()));
}
}
}
for (bb0, bb1, bb2) in variants.into_iter() {
let mut hugr = make_looping_cfg(bb0, bb1, bb2)?;
hugr.update_validate(&PRELUDE_REGISTRY)?;
Expand Down Expand Up @@ -1581,4 +1659,27 @@ mod test {
fn plus_on_self_10_times() {
[0; 10].iter().for_each(|_| plus_on_self().unwrap())
}

#[test]
// Test that logic for dealing with self-referential constraints doesn't
// fall over when a self-referencing group of metas also references a meta
// outside the group
fn failing_sccs_test() {
let hugr = Hugr::default();
let mut ctx = UnificationContext::new(&hugr);
let m1 = ctx.fresh_meta();
let m2 = ctx.fresh_meta();
let m3 = ctx.fresh_meta();
// Outside of the connected component
let m_other = ctx.fresh_meta();
// These 3 metavariables form a loop
ctx.add_constraint(m1, Constraint::Plus(ExtensionSet::singleton(&A), m3));
ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m1));
ctx.add_constraint(m3, Constraint::Plus(ExtensionSet::singleton(&A), m2));
// This other meta is outside the loop, but depended on by one of the loop metas
ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m_other));
ctx.variables.insert(m1);
ctx.variables.insert(m_other);
ctx.instantiate_variables();
}
}

0 comments on commit 2269161

Please sign in to comment.