Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(Inference): Work harder in variable instantiation #591

Merged
merged 22 commits into from
Nov 8, 2023
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
61ce9a0
refactor: Make `EqGraph` generic over directedness
croyzor Oct 5, 2023
456e55b
fix: Improve instantiate_vars code
croyzor Oct 5, 2023
802b011
tests: Add more tests for looping CFGs
croyzor Oct 5, 2023
c043802
refactor: Remove unnecessary macro
croyzor Oct 10, 2023
fad983e
doc: Update comment
croyzor Oct 10, 2023
d49926d
refactor: Rename `new_{un,}directed` to `new`
croyzor Oct 10, 2023
983bb3a
cosmetic: Move comment
croyzor Oct 10, 2023
e5dbc78
refactor: Rewrite `instantiate_variables` in a functional style
croyzor Oct 10, 2023
80516c3
Reduce mutable variables in search_variable_deps
acl-cqc Oct 10, 2023
84bffa8
refactor: Redo `search_variable_deps`
croyzor Oct 10, 2023
4530168
refactor: Redo `search_variable_deps` in functional style
croyzor Oct 11, 2023
a94d2c8
doc: Move comment
croyzor Oct 11, 2023
52fe101
Fix case of dependent cycles?? Need a test, and some comments to resolve
acl-cqc Oct 11, 2023
fada318
cosmetic: Rename `ccs` to `sccs`
croyzor Oct 13, 2023
a1fb7ec
Missed `resolve`
acl-cqc Oct 23, 2023
4f62200
Drop comment - calling self.resolve enough should handle Equals const…
acl-cqc Oct 23, 2023
ea3c102
Merge remote-tracking branch 'origin/main' into fix/inference-variabl…
croyzor Nov 7, 2023
9f032d6
Add failing test of SCC logic
croyzor Nov 7, 2023
6f288cd
Merge branch 'fix/inference-variable-instantiation' into inference-va…
croyzor Nov 8, 2023
c04e302
Update test case
croyzor Nov 8, 2023
8e87d3e
tests: Add failing test of SCC logic
croyzor Nov 7, 2023
9d49c29
Merge branch 'inference-variable/fix-dependent-sccs' into fix/inferen…
croyzor Nov 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 151 additions & 50 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{
use super::validate::ExtensionError;

use petgraph::graph as pg;
use petgraph::{Directed, EdgeType, Undirected};

use std::collections::{HashMap, HashSet, VecDeque};

Expand Down Expand Up @@ -106,53 +107,66 @@ pub enum InferExtensionError {
EdgeMismatch(#[from] ExtensionError),
}

/// A graph of metavariables which we've found equality constraints for. Edges
/// between nodes represent equality constraints.
struct EqGraph {
equalities: pg::Graph<Meta, (), petgraph::Undirected>,
/// A graph of metavariables connected by constraints.
/// The edges represent `Equal` constraints in the undirected graph and `Plus`
croyzor marked this conversation as resolved.
Show resolved Hide resolved
/// constraints in the directed case.
struct GraphContainer<Dir: EdgeType> {
graph: pg::Graph<Meta, (), Dir>,
node_map: HashMap<Meta, pg::NodeIndex>,
}

impl EqGraph {
/// Create a new `EqGraph`
fn new() -> Self {
EqGraph {
equalities: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}

impl<T: EdgeType> GraphContainer<T> {
/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
let ix = self.equalities.add_node(m);
let ix = self.graph.add_node(m);
self.node_map.insert(m, ix);
ix
})
}

/// Create an edge between two nodes on the graph, declaring that they stand
/// for metavariables which should be equal.
fn register_eq(&mut self, src: Meta, tgt: Meta) {
/// Create an edge between two nodes on the graph
fn add_edge(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
self.equalities.add_edge(src_ix, tgt_ix, ());
self.graph.add_edge(src_ix, tgt_ix, ());
}

/// Return the connected components of the graph in terms of metavariables
fn ccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.equalities)
/// Return the strongly connected components of the graph in terms of
/// metavariables. In the undirected case, return the connected components
fn sccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.graph)
croyzor marked this conversation as resolved.
Show resolved Hide resolved
.into_iter()
.map(|cc| {
cc.into_iter()
.map(|n| *self.equalities.node_weight(n).unwrap())
.map(|n| *self.graph.node_weight(n).unwrap())
.collect()
})
.collect()
}
}

impl GraphContainer<Undirected> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new_undirected(),
croyzor marked this conversation as resolved.
Show resolved Hide resolved
node_map: HashMap::new(),
}
}
}

impl GraphContainer<Directed> {
fn new() -> Self {
GraphContainer {
graph: pg::Graph::new(),
node_map: HashMap::new(),
}
}
}

type EqGraph = GraphContainer<Undirected>;

/// Our current knowledge about the extensions of the graph
struct UnificationContext {
/// A list of constraints for each metavariable
Expand Down Expand Up @@ -412,7 +426,7 @@ impl UnificationContext {
fn merge_equal_metas(&mut self) -> Result<(HashSet<Meta>, HashSet<Meta>), InferExtensionError> {
let mut merged: HashSet<Meta> = HashSet::new();
let mut new_metas: HashSet<Meta> = HashSet::new();
for cc in self.eq_graph.ccs().into_iter() {
for cc in self.eq_graph.sccs().into_iter() {
// Within a connected component everything is equal
let combined_meta = self.fresh_meta();
for m in cc.iter() {
Expand Down Expand Up @@ -476,7 +490,7 @@ impl UnificationContext {
match c {
// Just register the equality in the EqGraph, we'll process it later
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
self.eq_graph.add_edge(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
Expand Down Expand Up @@ -617,31 +631,98 @@ impl UnificationContext {
self.results()
}

/// Instantiate all variables in the graph with the empty extension set.
/// Gather all the transitive dependencies (induced by constraints) of the
/// variables in the context.
fn search_variable_deps(&self) -> HashSet<Meta> {
let mut seen = HashSet::new();
let mut new_variables: HashSet<Meta> = self.variables.clone();
while !new_variables.is_empty() {
new_variables = new_variables
.into_iter()
.filter(|m| seen.insert(*m))
.flat_map(|m| self.get_constraints(&m).unwrap())
.map(|c| match c {
Constraint::Plus(_, other) => self.resolve(*other),
Constraint::Equal(other) => self.resolve(*other),
})
.collect();
}
seen
}

/// Instantiate all variables in the graph with the empty extension set, or
/// the smallest solution possible given their constraints.
/// This is done to solve metas which depend on variables, which allows
/// us to come up with a fully concrete solution to pass into validation.
///
/// Nodes which loop into themselves must be considered as a "minimum" set
/// of requirements. If we have
/// 1 = 2 + X, ...
/// 2 = 1 + x, ...
/// then 1 and 2 both definitely contain X, even if we don't know what else.
/// So instead of instantiating to the empty set, we'll instantiate to `{X}`
pub fn instantiate_variables(&mut self) {
for m in self.variables.clone().into_iter() {
// A directed graph to keep track of `Plus` constraint relationships
let mut relations = GraphContainer::<Directed>::new();
let mut solutions: HashMap<Meta, ExtensionSet> = HashMap::new();

let variable_scope = self.search_variable_deps();
for m in variable_scope.into_iter() {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
if !self.solved.contains_key(&m) {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution = self
.get_constraints(&m)
.unwrap()
let plus_constraints =
self.get_constraints(&m)
.unwrap()
.iter()
.cloned()
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
.flat_map(|c| match c {
Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))),
_ => None,
});

let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
let solution = rs.iter().fold(ExtensionSet::new(), |e1, e2| e1.union(e2));
let unresolved_metas = other_ms
.into_iter()
.filter(|other_m| m != *other_m)
.collect::<Vec<_>>();

// If `m` doesn't depend on any other metas then we have all the
// information we need to come up with a solution for it.
relations.add_or_retrieve(m);
unresolved_metas
.iter()
.filter_map(|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
_ => None,
})
.fold(ExtensionSet::new(), ExtensionSet::union);
self.add_solution(m, solution);
.for_each(|other_m| relations.add_edge(m, *other_m));
solutions.insert(m, solution);
}
}
println!("{:?}", relations.node_map);
println!("{:?}", relations.graph);

// Process the strongly-connected components. We need to deal with these
// depended-upon before depender. ccs() gives them back in some order
// - this might need to be reversed????
for cc in relations.sccs() {
// Strongly connected components are looping constraint dependencies.
// This means that each metavariable in the CC has the same solution.
let combined_solution = cc
.iter()
.flat_map(|m| self.get_constraints(m).unwrap())
.filter_map(|c| match c {
Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)),
Constraint::Equal(_) => None,
})
.fold(ExtensionSet::new(), |a, b| a.union(b));

for m in cc.iter() {
self.add_solution(*m, combined_solution.clone());
Copy link
Contributor

@acl-cqc acl-cqc Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to add to solutions here do we? What if we had, say, two SCCs with an edge between them:

A -> B -> C -> A
D -> E -> F -> D
A -> D

I note Petgraph's tarjan_scc returns the components in a defined order (reverse topsort or something), so in theory (maybe you have to reverse cc) it's soluble, but I think you need to consider not just solutions.get(m) for each m in the SCC but the solutions to the constraints upon m (which by that topsort ordering have already been computed).

Or, maybe such a structure can't occur, I dunno....????

Copy link
Contributor

@acl-cqc acl-cqc Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try branch inference-variable/fix-dependent-sccs (the last commit) - some uncertainties detailed in the comments there, and needs a test of a structure such as the ABCDEF above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooops, realize that wasn't passing tests! A missed self.resolve should make it pass now.

solutions.insert(*m, combined_solution.clone());
}
}
self.variables = HashSet::new();
Expand Down Expand Up @@ -1465,17 +1546,14 @@ mod test {
#[test]
fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
let just_a = ExtensionSet::singleton(&A);
let variants = vec![
(
ExtensionSet::new(),
ExtensionSet::new(),
ExtensionSet::new(),
),
(just_a.clone(), ExtensionSet::new(), ExtensionSet::new()),
(ExtensionSet::new(), just_a.clone(), ExtensionSet::new()),
(ExtensionSet::new(), ExtensionSet::new(), just_a.clone()),
];

let mut variants = Vec::new();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hahaha! Neat :-) You might consider flat_map (*3) to avoid mut but like it either way :)

for entry in [ExtensionSet::new(), just_a.clone()] {
for bb1 in [ExtensionSet::new(), just_a.clone()] {
for bb2 in [ExtensionSet::new(), just_a.clone()] {
variants.push((entry.clone(), bb1.clone(), bb2.clone()));
}
}
}
for (bb0, bb1, bb2) in variants.into_iter() {
let mut hugr = make_looping_cfg(bb0, bb1, bb2)?;
hugr.update_validate(&PRELUDE_REGISTRY)?;
Expand Down Expand Up @@ -1581,4 +1659,27 @@ mod test {
fn plus_on_self_10_times() {
[0; 10].iter().for_each(|_| plus_on_self().unwrap())
}

#[test]
// Test that logic for dealing with self-referential constraints doesn't
// fall over when a self-referencing group of metas also references a meta
// outside the group
fn failing_sccs_test() {
let hugr = Hugr::default();
let mut ctx = UnificationContext::new(&hugr);
let m1 = ctx.fresh_meta();
let m2 = ctx.fresh_meta();
let m3 = ctx.fresh_meta();
// Outside of the connected component
let m_other = ctx.fresh_meta();
// These 3 metavariables form a loop
ctx.add_constraint(m1, Constraint::Plus(ExtensionSet::singleton(&A), m3));
ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m1));
ctx.add_constraint(m3, Constraint::Plus(ExtensionSet::singleton(&A), m2));
// This other meta is outside the loop, but depended on by one of the loop metas
ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&A), m_other));
ctx.variables.insert(m1);
ctx.variables.insert(m_other);
ctx.instantiate_variables();
}
}