Skip to content

Commit

Permalink
Merge pull request #9 from TrevorHansen/om
Browse files Browse the repository at this point in the history
An extra bottom-up recursive extractor
  • Loading branch information
oflatt authored Dec 14, 2023
2 parents c9239c2 + 6eaf8f8 commit 8b33f88
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 20 deletions.
3 changes: 3 additions & 0 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def process(js, extractors):
print(f"cumulative dag cost for {e1}: {sum(d[e1]['dag'] for d in by_name.values()):.0f}")
print(f"cumulative dag cost for {e2}: {sum(d[e2]['dag'] for d in by_name.values()):.0f}")

print(f"Cumulative time for {e1}: {e1_cumulative/1000:.0f}ms")
print(f"Cumulative time for {e2}: {e2_cumulative/1000:.0f}ms")

print(f"{e1} / {e2}")

print("geo mean")
Expand Down
30 changes: 16 additions & 14 deletions src/extract/faster_bottom_up.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use rustc_hash::{FxHashMap, FxHashSet};

use super::*;

/// A faster bottom up extractor inspired by the faster-greedy-dag extractor.
Expand All @@ -16,7 +18,7 @@ pub struct BottomUpExtractor;

impl Extractor for BottomUpExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::default();
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::with_capacity(egraph.classes().len());
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let mut analysis_pending = UniqueQueue::default();

Expand All @@ -39,20 +41,20 @@ impl Extractor for BottomUpExtractor {
}

let mut result = ExtractionResult::default();
let mut costs = IndexMap::<ClassId, Cost>::default();
let mut costs = FxHashMap::<ClassId, Cost>::with_capacity_and_hasher(
egraph.classes().len(),
Default::default(),
);

while let Some(node_id) = analysis_pending.pop() {
let class_id = n2c(&node_id);
let node = &egraph[&node_id];
if node.children.iter().all(|c| costs.contains_key(n2c(c))) {
let prev_cost = costs.get(class_id).unwrap_or(&INFINITY);

let cost = result.node_sum_cost(egraph, node, &costs);
if cost < *prev_cost {
result.choose(class_id.clone(), node_id.clone());
costs.insert(class_id.clone(), cost);
analysis_pending.extend(parents[class_id].iter().cloned());
}
let prev_cost = costs.get(class_id).unwrap_or(&INFINITY);
let cost = result.node_sum_cost(egraph, node, &costs);
if cost < *prev_cost {
result.choose(class_id.clone(), node_id.clone());
costs.insert(class_id.clone(), cost);
analysis_pending.extend(parents[class_id].iter().cloned());
}
}

Expand All @@ -64,15 +66,15 @@ impl Extractor for BottomUpExtractor {
Notably, insert/pop operations have O(1) expected amortized runtime complexity.
Thanks Trevor for the implementation!
Thanks @Bastacyclop for the implementation!
*/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub(crate) struct UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
set: std::collections::HashSet<T>, // hashbrown::
set: FxHashSet<T>, // hashbrown::
queue: std::collections::VecDeque<T>,
}

Expand All @@ -82,7 +84,7 @@ where
{
fn default() -> Self {
UniqueQueue {
set: std::collections::HashSet::default(),
set: Default::default(),
queue: std::collections::VecDeque::new(),
}
}
Expand Down
43 changes: 37 additions & 6 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use indexmap::IndexMap;
use rustc_hash::FxHashMap;
use std::collections::HashMap;

pub use crate::*;
Expand All @@ -22,6 +24,37 @@ pub trait Extractor: Sync {
}
}

pub trait MapGet<K, V> {
fn get(&self, key: &K) -> Option<&V>;
}

impl<K, V> MapGet<K, V> for HashMap<K, V>
where
K: Eq + std::hash::Hash,
{
fn get(&self, key: &K) -> Option<&V> {
HashMap::get(self, key)
}
}

impl<K, V> MapGet<K, V> for FxHashMap<K, V>
where
K: Eq + std::hash::Hash,
{
fn get(&self, key: &K) -> Option<&V> {
FxHashMap::get(self, key)
}
}

impl<K, V> MapGet<K, V> for IndexMap<K, V>
where
K: Eq + std::hash::Hash,
{
fn get(&self, key: &K) -> Option<&V> {
IndexMap::get(self, key)
}
}

#[derive(Default, Clone)]
pub struct ExtractionResult {
pub choices: IndexMap<ClassId, NodeId>,
Expand Down Expand Up @@ -118,12 +151,10 @@ impl ExtractionResult {
costs.values().sum()
}

pub fn node_sum_cost(
&self,
egraph: &EGraph,
node: &Node,
costs: &IndexMap<ClassId, Cost>,
) -> Cost {
pub fn node_sum_cost<M>(&self, egraph: &EGraph, node: &Node, costs: &M) -> Cost
where
M: MapGet<ClassId, Cost>,
{
node.cost
+ node
.children
Expand Down
4 changes: 4 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ fn main() {

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"faster-bottom-up",
extract::faster_bottom_up::BottomUpExtractor.boxed(),
),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
Expand Down

0 comments on commit 8b33f88

Please sign in to comment.