Skip to content

Commit

Permalink
refactor: Make EqGraph generic over directedness
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Oct 5, 2023
1 parent 09494f1 commit 0d02f8b
Showing 1 changed file with 61 additions and 33 deletions.
94 changes: 61 additions & 33 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::{
use super::validate::ExtensionError;

use petgraph::graph as pg;
use petgraph::{Directed, EdgeType, Undirected};

use std::collections::{HashMap, HashSet};

Expand Down Expand Up @@ -109,48 +110,71 @@ pub enum InferExtensionError {

/// A graph of metavariables which we've found equality constraints for. Edges
/// between nodes represent equality constraints.
struct EqGraph {
equalities: pg::Graph<Meta, (), petgraph::Undirected>,
struct GraphContainer<Dir: EdgeType> {
graph: pg::Graph<Meta, (), Dir>,
node_map: HashMap<Meta, pg::NodeIndex>,
}

impl EqGraph {
/// Create a new `EqGraph`
fn new() -> Self {
EqGraph {
equalities: pg::Graph::new_undirected(),
node_map: HashMap::new(),
macro_rules! impl_graph_container {
($dir:ty) => {
impl GraphContainer<$dir> {
/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
let ix = self.graph.add_node(m);
self.node_map.insert(m, ix);
ix
})
}

/// Create an edge between two nodes on the graph
fn add_edge(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
self.graph.add_edge(src_ix, tgt_ix, ());
}

/// Return the connected components of the graph in terms of metavariables
fn ccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.graph)
.into_iter()
.map(|cc| {
cc.into_iter()
.map(|n| *self.graph.node_weight(n).unwrap())
.collect()
})
.collect()
}
}
}
}

impl_graph_container!(Directed);
impl_graph_container!(Undirected);

/// Add a metavariable to the graph as a node and return the `NodeIndex`.
/// If it's already there, just return the existing `NodeIndex`
fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex {
self.node_map.get(&m).cloned().unwrap_or_else(|| {
let ix = self.equalities.add_node(m);
self.node_map.insert(m, ix);
ix
})
impl GraphContainer<Undirected> {
fn new_undirected() -> Self {
GraphContainer {
graph: pg::Graph::new_undirected(),
node_map: HashMap::new(),
}
}
}

/// Create an edge between two nodes on the graph, declaring that they stand
/// for metavariables which should be equal.
fn register_eq(&mut self, src: Meta, tgt: Meta) {
let src_ix = self.add_or_retrieve(src);
let tgt_ix = self.add_or_retrieve(tgt);
self.equalities.add_edge(src_ix, tgt_ix, ());
impl GraphContainer<Directed> {
fn new_directed() -> Self {
GraphContainer {
graph: pg::Graph::new(),
node_map: HashMap::new(),
}
}
}

/// Return the connected components of the graph in terms of metavariables
fn ccs(&self) -> Vec<Vec<Meta>> {
petgraph::algo::tarjan_scc(&self.equalities)
.into_iter()
.map(|cc| {
cc.into_iter()
.map(|n| *self.equalities.node_weight(n).unwrap())
.collect()
})
.collect()
type EqGraph = GraphContainer<Undirected>;
impl EqGraph {
fn new() -> Self {
EqGraph::new_undirected()
}
}

Expand Down Expand Up @@ -507,7 +531,7 @@ impl UnificationContext {
match c {
// Just register the equality in the EqGraph, we'll process it later
Constraint::Equal(other_meta) => {
self.eq_graph.register_eq(meta, *other_meta);
self.eq_graph.add_edge(meta, *other_meta);
}
// N.B. If `meta` is already solved, we can't use that
// information to solve `other_meta`. This is because the Plus
Expand Down Expand Up @@ -672,6 +696,8 @@ impl UnificationContext {
/// This is done to solve metas which depend on variables, which allows
/// us to come up with a fully concrete solution to pass into validation.
pub fn instantiate_variables(&mut self) {
let minimum_sets: HashMap<Meta, ExtensionSet> = HashMap::new();

for m in self.variables.clone().into_iter() {
if !self.solved.contains_key(&m) {
// Handle the case where the constraints for `m` contain a self
Expand All @@ -691,6 +717,7 @@ impl UnificationContext {
},
));
self.add_solution(m, solution);
//minimum_sets.insert(m, solution);
}
}
self.variables = HashSet::new();
Expand Down Expand Up @@ -1521,6 +1548,7 @@ mod test {
(just_a.clone(), ExtensionSet::new(), ExtensionSet::new()),
(ExtensionSet::new(), just_a.clone(), ExtensionSet::new()),
(ExtensionSet::new(), ExtensionSet::new(), just_a.clone()),
(ExtensionSet::new(), just_a.clone(), just_a.clone()),
];

for (bb0, bb1, bb2) in variants.into_iter() {
Expand Down

0 comments on commit 0d02f8b

Please sign in to comment.