Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Common Subexpression Elimination for PhysicalExpr trees #13046

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
base64 = "0.22.1"
half = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
libc = "0.2.140"
log = { workspace = true }
object_store = { workspace = true, optional = true }
Expand Down
4 changes: 4 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ use std::error::Error;
use std::fmt::{self, Display};
use std::str::FromStr;

use crate::alias::AliasGenerator;
use crate::error::_config_err;
use crate::parsers::CompressionTypeVariant;
use crate::utils::get_available_parallelism;
use crate::{DataFusionError, Result};
use std::sync::Arc;

/// A macro that wraps a configuration struct and automatically derives
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
Expand Down Expand Up @@ -736,6 +738,8 @@ pub struct ConfigOptions {
pub explain: ExplainOptions,
/// Optional extensions registered using [`Extensions::insert`]
pub extensions: Extensions,
/// Return alias generator used to generate unique aliases
pub alias_generator: Arc<AliasGenerator>,
}

impl ConfigField for ConfigOptions {
Expand Down
65 changes: 43 additions & 22 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use crate::tree_node::{
TreeNodeVisitor,
};
use crate::Result;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
Expand Down Expand Up @@ -59,6 +58,12 @@ pub trait Normalizeable {
fn can_normalize(&self) -> bool;
}

impl<T: Normalizeable + ?Sized> Normalizeable for Arc<T> {
fn can_normalize(&self) -> bool {
(**self).can_normalize()
}
}

/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
/// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
///
Expand All @@ -71,6 +76,12 @@ pub trait NormalizeEq: Eq + Normalizeable {
fn normalize_eq(&self, other: &Self) -> bool;
}

impl<T: NormalizeEq + ?Sized> NormalizeEq for Arc<T> {
fn normalize_eq(&self, other: &Self) -> bool {
(**self).normalize_eq(other)
}
}

/// Identifier that represents a [`TreeNode`] tree.
///
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
Expand Down Expand Up @@ -161,11 +172,13 @@ enum NodeEvaluation {
}

/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
/// found to be common and got extracted.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;

/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
/// extracted during the second, rewriting traversal.
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
/// A list that contains the common [`TreeNode`]s and their alias, extracted during the
/// second, rewriting traversal.
type CommonNodes<'n, N> = Vec<(N, String)>;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for this change of type from IndexMap to Vec in CommonNodes is that physical columns works with indexes rather than names. E.g. when we repace a common subexpression to a column during rewrite, we need both the name and the index of the common subexpression in the intermediate ProjectionExec node. Storing the index in NodeStats and using Vec in CommonNodes better fits this usecase.


type ChildrenList<N> = (Vec<N>, Vec<N>);

Expand Down Expand Up @@ -193,7 +206,7 @@ pub trait CSEController {
fn generate_alias(&self) -> String;

// Replaces a node to the generated alias.
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;

// A helper method called on each node during top-down traversal during the second,
// rewriting traversal of CSE.
Expand Down Expand Up @@ -394,7 +407,7 @@ where
self.id_array[down_index].1 = Some(node_id);
self.node_stats
.entry(node_id)
.and_modify(|evaluation| {
.and_modify(|(evaluation, _)| {
if *evaluation == NodeEvaluation::SurelyOnce
|| *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
&& !self.conditional
Expand All @@ -404,11 +417,12 @@ where
}
})
.or_insert_with(|| {
if self.conditional {
let evaluation = if self.conditional {
NodeEvaluation::ConditionallyAtLeastOnce
} else {
NodeEvaluation::SurelyOnce
}
};
(evaluation, None)
});
}
self.visit_stack
Expand All @@ -428,7 +442,7 @@ where
C: CSEController<Node = N>,
{
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,
node_stats: &'a mut NodeStats<'n, N>,

/// cache to speed up second traversal
id_array: &'a IdArray<'n, N>,
Expand Down Expand Up @@ -458,7 +472,7 @@ where

// Handle nodes with identifiers only
if let Some(node_id) = node_id {
let evaluation = self.node_stats.get(&node_id).unwrap();
let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
if *evaluation == NodeEvaluation::Common {
// step index to skip all sub-node (which has smaller series number).
while self.down_index < self.id_array.len()
Expand All @@ -482,13 +496,15 @@ where
//
// This way, we can efficiently handle semantically equivalent expressions without
// incorrectly treating them as identical.
let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
{
self.controller.rewrite(&node, alias)
let rewritten = if let Some(index) = common_index {
let (_, alias) = self.common_nodes.get(*index).unwrap();
self.controller.rewrite(&node, alias, *index)
} else {
let node_alias = self.controller.generate_alias();
let rewritten = self.controller.rewrite(&node, &node_alias);
self.common_nodes.insert(node_id, (node, node_alias));
let index = self.common_nodes.len();
let alias = self.controller.generate_alias();
let rewritten = self.controller.rewrite(&node, &alias, index);
*common_index = Some(index);
self.common_nodes.push((node, alias));
rewritten
};

Expand Down Expand Up @@ -587,7 +603,7 @@ where
&mut self,
node: N,
id_array: &IdArray<'n, N>,
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<N> {
if id_array.is_empty() {
Expand All @@ -610,7 +626,7 @@ where
&mut self,
nodes_list: Vec<Vec<N>>,
arrays_list: &[Vec<IdArray<'n, N>>],
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<Vec<Vec<N>>> {
nodes_list
Expand Down Expand Up @@ -656,13 +672,13 @@ where
// nodes so we have to keep them intact.
nodes_list.clone(),
&id_arrays_list,
&node_stats,
&mut node_stats,
&mut common_nodes,
)?;
assert!(!common_nodes.is_empty());

Ok(FoundCommonNodes::Yes {
common_nodes: common_nodes.into_values().collect(),
common_nodes,
new_nodes_list,
original_nodes_list: nodes_list,
})
Expand Down Expand Up @@ -735,7 +751,12 @@ mod test {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(
&mut self,
node: &Self::Node,
alias: &str,
_index: usize,
) -> Self::Node {
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
}
}
Expand Down
67 changes: 48 additions & 19 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ impl CSEController for ExprCSEController<'_> {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
// alias the expressions without an `Alias` ancestor node
if self.alias_counter > 0 {
col(alias)
Expand Down Expand Up @@ -1030,10 +1030,14 @@ mod test {
fn subexpr_in_same_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1 + a;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
(lit(1) + col("a")).alias("first"),
(lit(1) + col("a")).alias("second"),
_1_plus_a.clone().alias("first"),
_1_plus_a.alias("second"),
])?
.build()?;

Expand All @@ -1050,8 +1054,13 @@ mod test {
fn subexpr_in_different_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1.clone() + a.clone();
let a_plus_1 = a + lit_1;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
.project(vec![_1_plus_a, a_plus_1])?
.build()?;

let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\
Expand All @@ -1067,6 +1076,8 @@ mod test {
fn cross_plans_subexpr() -> Result<()> {
let table_scan = test_table_scan()?;

let _1_plus_col_a = lit(1) + col("a");

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a")])?
.project(vec![lit(1) + col("a")])?
Expand Down Expand Up @@ -1284,10 +1295,13 @@ mod test {
fn test_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
let a = col("a");
let b = col("b");

let extracted_short_circuit = a.clone().eq(lit(0)).or(b.clone().eq(lit(0)));
let extracted_short_circuit_leg_1 = (a.clone() + b.clone()).eq(lit(0));
let not_extracted_short_circuit_leg_2 = (a.clone() - b.clone()).eq(lit(0));
let extracted_short_circuit_leg_3 = (a * b).eq(lit(0));
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
extracted_short_circuit.clone().alias("c1"),
Expand Down Expand Up @@ -1319,9 +1333,12 @@ mod test {
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let a = col("a");
let b = col("b");
let extracted_child = a + b;
let rand = rand_expr();
let not_extracted_volatile = extracted_child + rand;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile.clone().alias("c1"),
Expand All @@ -1342,13 +1359,19 @@ mod test {
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = rand_func().call(vec![]);
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
let a = col("a");
let b = col("b");
let rand = rand_expr();
let rand_eq_0 = rand.eq(lit(0));

let extracted_short_circuit_leg_1 = a.eq(lit(0));
let not_extracted_volatile_short_circuit_1 =
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
extracted_short_circuit_leg_1.or(rand_eq_0.clone());

let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
let not_extracted_volatile_short_circuit_2 =
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
rand_eq_0.or(not_extracted_short_circuit_leg_2);

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
Expand All @@ -1371,7 +1394,10 @@ mod test {
fn test_non_top_level_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let common_expr = a + b;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand All @@ -1394,8 +1420,11 @@ mod test {
fn test_nested_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let nested_common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let nested_common_expr = a + b;
let common_expr = nested_common_expr.clone() * nested_common_expr;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand Down Expand Up @@ -1671,8 +1700,8 @@ mod test {
///
/// Does not use datafusion_functions::rand to avoid introducing a
/// dependency on that crate.
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
fn rand_expr() -> Expr {
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
}

#[derive(Debug)]
Expand Down
Loading