diff --git a/plot.py b/plot.py index 2fb2d73..027a5a7 100755 --- a/plot.py +++ b/plot.py @@ -65,6 +65,9 @@ def process(js, extractors): print(f"cumulative dag cost for {e1}: {sum(d[e1]['dag'] for d in by_name.values()):.0f}") print(f"cumulative dag cost for {e2}: {sum(d[e2]['dag'] for d in by_name.values()):.0f}") + print(f"Cumulative time for {e1}: {e1_cumulative/1000:.0f}ms") + print(f"Cumulative time for {e2}: {e2_cumulative/1000:.0f}ms") + print(f"{e1} / {e2}") print("geo mean") diff --git a/src/extract/faster_bottom_up.rs b/src/extract/faster_bottom_up.rs index d4c1d19..b10ff65 100644 --- a/src/extract/faster_bottom_up.rs +++ b/src/extract/faster_bottom_up.rs @@ -1,3 +1,5 @@ +use rustc_hash::{FxHashMap, FxHashSet}; + use super::*; /// A faster bottom up extractor inspired by the faster-greedy-dag extractor. @@ -16,7 +18,7 @@ pub struct BottomUpExtractor; impl Extractor for BottomUpExtractor { fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { - let mut parents = IndexMap::>::default(); + let mut parents = IndexMap::>::with_capacity(egraph.classes().len()); let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); let mut analysis_pending = UniqueQueue::default(); @@ -39,20 +41,20 @@ impl Extractor for BottomUpExtractor { } let mut result = ExtractionResult::default(); - let mut costs = IndexMap::::default(); + let mut costs = FxHashMap::::with_capacity_and_hasher( + egraph.classes().len(), + Default::default(), + ); while let Some(node_id) = analysis_pending.pop() { let class_id = n2c(&node_id); let node = &egraph[&node_id]; - if node.children.iter().all(|c| costs.contains_key(n2c(c))) { - let prev_cost = costs.get(class_id).unwrap_or(&INFINITY); - - let cost = result.node_sum_cost(egraph, node, &costs); - if cost < *prev_cost { - result.choose(class_id.clone(), node_id.clone()); - costs.insert(class_id.clone(), cost); - analysis_pending.extend(parents[class_id].iter().cloned()); - } + let prev_cost = costs.get(class_id).unwrap_or(&INFINITY); + let cost = result.node_sum_cost(egraph, node, &costs); + if cost < *prev_cost { + result.choose(class_id.clone(), node_id.clone()); + costs.insert(class_id.clone(), cost); + analysis_pending.extend(parents[class_id].iter().cloned()); } } @@ -64,7 +66,7 @@ impl Extractor for BottomUpExtractor { Notably, insert/pop operations have O(1) expected amortized runtime complexity. -Thanks Trevor for the implementation! +Thanks @Bastacyclop for the implementation! */ #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] @@ -72,7 +74,7 @@ pub(crate) struct UniqueQueue where T: Eq + std::hash::Hash + Clone, { - set: std::collections::HashSet, // hashbrown:: + set: FxHashSet, // hashbrown:: queue: std::collections::VecDeque, } @@ -82,7 +84,7 @@ where { fn default() -> Self { UniqueQueue { - set: std::collections::HashSet::default(), + set: Default::default(), queue: std::collections::VecDeque::new(), } } diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 0b070b1..54a3753 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -1,3 +1,5 @@ +use indexmap::IndexMap; +use rustc_hash::FxHashMap; use std::collections::HashMap; pub use crate::*; @@ -22,6 +24,37 @@ pub trait Extractor: Sync { } } +pub trait MapGet { + fn get(&self, key: &K) -> Option<&V>; +} + +impl MapGet for HashMap +where + K: Eq + std::hash::Hash, +{ + fn get(&self, key: &K) -> Option<&V> { + HashMap::get(self, key) + } +} + +impl MapGet for FxHashMap +where + K: Eq + std::hash::Hash, +{ + fn get(&self, key: &K) -> Option<&V> { + FxHashMap::get(self, key) + } +} + +impl MapGet for IndexMap +where + K: Eq + std::hash::Hash, +{ + fn get(&self, key: &K) -> Option<&V> { + IndexMap::get(self, key) + } +} + #[derive(Default, Clone)] pub struct ExtractionResult { pub choices: IndexMap, @@ -118,12 +151,10 @@ impl ExtractionResult { costs.values().sum() } - pub fn node_sum_cost( - &self, - egraph: &EGraph, - node: &Node, - costs: &IndexMap, - ) -> Cost { + pub fn node_sum_cost(&self, egraph: &EGraph, node: &Node, costs: &M) -> Cost + where + M: MapGet, + { node.cost + node .children diff --git a/src/main.rs b/src/main.rs index b3a45b3..19463cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,10 @@ fn main() { let extractors: IndexMap<&str, Box> = [ ("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()), + ( + "faster-bottom-up", + extract::faster_bottom_up::BottomUpExtractor.boxed(), + ), ( "greedy-dag", extract::greedy_dag::GreedyDagExtractor.boxed(),