From eee2ee5483323193639aea8dffb7572c10f5f4d7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 5 Nov 2024 20:01:56 +0100 Subject: [PATCH] Implement `PhysicalExpr` CSE --- Cargo.lock | 1 - datafusion/common/Cargo.toml | 1 - datafusion/common/src/config.rs | 4 + datafusion/common/src/cse.rs | 65 +- .../optimizer/src/common_subexpr_eliminate.rs | 67 +- .../physical-expr-common/src/physical_expr.rs | 40 +- .../physical-expr-common/src/sort_expr.rs | 4 + .../physical-expr/src/expressions/binary.rs | 12 +- .../physical-expr/src/expressions/case.rs | 9 +- .../physical-expr/src/expressions/cast.rs | 12 +- .../physical-expr/src/expressions/column.rs | 10 +- .../physical-expr/src/expressions/in_list.rs | 8 + .../src/expressions/is_not_null.rs | 9 +- .../physical-expr/src/expressions/is_null.rs | 9 +- .../physical-expr/src/expressions/like.rs | 12 +- .../physical-expr/src/expressions/literal.rs | 9 +- .../physical-expr/src/expressions/negative.rs | 9 +- .../physical-expr/src/expressions/no_op.rs | 7 +- .../physical-expr/src/expressions/not.rs | 9 +- .../physical-expr/src/expressions/try_cast.rs | 11 +- .../src/expressions/unknown_column.rs | 7 + .../physical-expr/src/scalar_function.rs | 17 +- .../src/eliminate_common_physical_subexprs.rs | 589 ++++++++++++++++++ datafusion/physical-optimizer/src/lib.rs | 1 + .../tests/cases/roundtrip_physical_plan.rs | 10 +- 25 files changed, 864 insertions(+), 68 deletions(-) create mode 100644 datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs diff --git a/Cargo.lock b/Cargo.lock index cb77384cb3711..f7ae551f74e61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1886,7 +1886,6 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", - "indexmap 2.7.1", "libc", "log", "object_store", diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 215a06e81c3dc..359e426f6ded9 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -56,7 +56,6 @@ arrow-schema = { workspace = true } base64 = "0.22.1" half = { workspace = true } hashbrown = { workspace = true } -indexmap = { workspace = true } libc = "0.2.140" log = { workspace = true } object_store = { workspace = true, optional = true } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c9900204b97f2..ac22337924881 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -23,10 +23,12 @@ use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; +use crate::alias::AliasGenerator; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; use crate::utils::get_available_parallelism; use crate::{DataFusionError, Result}; +use std::sync::Arc; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -736,6 +738,8 @@ pub struct ConfigOptions { pub explain: ExplainOptions, /// Optional extensions registered using [`Extensions::insert`] pub extensions: Extensions, + /// Return alias generator used to generate unique aliases + pub alias_generator: Arc, } impl ConfigField for ConfigOptions { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..a0c9ff2dae245 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -25,7 +25,6 @@ use crate::tree_node::{ TreeNodeVisitor, }; use crate::Result; -use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; use 
std::marker::PhantomData; @@ -59,6 +58,12 @@ pub trait Normalizeable { fn can_normalize(&self) -> bool; } +impl Normalizeable for Arc { + fn can_normalize(&self) -> bool { + (**self).can_normalize() + } +} + /// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing /// normalized nodes in optimizations like Common Subexpression Elimination (CSE). /// @@ -71,6 +76,12 @@ pub trait NormalizeEq: Eq + Normalizeable { fn normalize_eq(&self, other: &Self) -> bool; } +impl NormalizeEq for Arc { + fn normalize_eq(&self, other: &Self) -> bool { + (**self).normalize_eq(other) + } +} + /// Identifier that represents a [`TreeNode`] tree. /// /// This identifier is designed to be efficient and "hash", "accumulate", "equal" and @@ -161,11 +172,13 @@ enum NodeEvaluation { } /// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers. -type NodeStats<'n, N> = HashMap, NodeEvaluation>; +/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is +/// found to be common and got extracted. +type NodeStats<'n, N> = HashMap, (NodeEvaluation, Option)>; -/// A map that contains the common [`TreeNode`]s and their alias by their identifiers, -/// extracted during the second, rewriting traversal. -type CommonNodes<'n, N> = IndexMap, (N, String)>; +/// A list that contains the common [`TreeNode`]s and their alias, extracted during the +/// second, rewriting traversal. +type CommonNodes<'n, N> = Vec<(N, String)>; type ChildrenList = (Vec, Vec); @@ -193,7 +206,7 @@ pub trait CSEController { fn generate_alias(&self) -> String; // Replaces a node to the generated alias. - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node; + fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node; // A helper method called on each node during top-down traversal during the second, // rewriting traversal of CSE. @@ -394,7 +407,7 @@ where self.id_array[down_index].1 = Some(node_id); self.node_stats .entry(node_id) - .and_modify(|evaluation| { + .and_modify(|(evaluation, _)| { if *evaluation == NodeEvaluation::SurelyOnce || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce && !self.conditional @@ -404,11 +417,12 @@ where } }) .or_insert_with(|| { - if self.conditional { + let evaluation = if self.conditional { NodeEvaluation::ConditionallyAtLeastOnce } else { NodeEvaluation::SurelyOnce - } + }; + (evaluation, None) }); } self.visit_stack @@ -428,7 +442,7 @@ where C: CSEController, { /// statistics of [`TreeNode`]s - node_stats: &'a NodeStats<'n, N>, + node_stats: &'a mut NodeStats<'n, N>, /// cache to speed up second traversal id_array: &'a IdArray<'n, N>, @@ -458,7 +472,7 @@ where // Handle nodes with identifiers only if let Some(node_id) = node_id { - let evaluation = self.node_stats.get(&node_id).unwrap(); + let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap(); if *evaluation == NodeEvaluation::Common { // step index to skip all sub-node (which has smaller series number). while self.down_index < self.id_array.len() @@ -482,13 +496,15 @@ where // // This way, we can efficiently handle semantically equivalent expressions without // incorrectly treating them as identical. 
- let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id) - { - self.controller.rewrite(&node, alias) + let rewritten = if let Some(index) = common_index { + let (_, alias) = self.common_nodes.get(*index).unwrap(); + self.controller.rewrite(&node, alias, *index) } else { - let node_alias = self.controller.generate_alias(); - let rewritten = self.controller.rewrite(&node, &node_alias); - self.common_nodes.insert(node_id, (node, node_alias)); + let index = self.common_nodes.len(); + let alias = self.controller.generate_alias(); + let rewritten = self.controller.rewrite(&node, &alias, index); + *common_index = Some(index); + self.common_nodes.push((node, alias)); rewritten }; @@ -587,7 +603,7 @@ where &mut self, node: N, id_array: &IdArray<'n, N>, - node_stats: &NodeStats<'n, N>, + node_stats: &mut NodeStats<'n, N>, common_nodes: &mut CommonNodes<'n, N>, ) -> Result { if id_array.is_empty() { @@ -610,7 +626,7 @@ where &mut self, nodes_list: Vec>, arrays_list: &[Vec>], - node_stats: &NodeStats<'n, N>, + node_stats: &mut NodeStats<'n, N>, common_nodes: &mut CommonNodes<'n, N>, ) -> Result>> { nodes_list @@ -656,13 +672,13 @@ where // nodes so we have to keep them intact. nodes_list.clone(), &id_arrays_list, - &node_stats, + &mut node_stats, &mut common_nodes, )?; assert!(!common_nodes.is_empty()); Ok(FoundCommonNodes::Yes { - common_nodes: common_nodes.into_values().collect(), + common_nodes, new_nodes_list, original_nodes_list: nodes_list, }) @@ -735,7 +751,12 @@ mod test { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite( + &mut self, + node: &Self::Node, + alias: &str, + _index: usize, + ) -> Self::Node { TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4b9a83fd3e4c0..c69a14b99c495 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -699,7 +699,7 @@ impl CSEController for ExprCSEController<'_> { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node { // alias the expressions without an `Alias` ancestor node if self.alias_counter > 0 { col(alias) @@ -1030,10 +1030,14 @@ mod test { fn subexpr_in_same_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1 + a; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ - (lit(1) + col("a")).alias("first"), - (lit(1) + col("a")).alias("second"), + _1_plus_a.clone().alias("first"), + _1_plus_a.alias("second"), ])? .build()?; @@ -1050,8 +1054,13 @@ mod test { fn subexpr_in_different_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1.clone() + a.clone(); + let a_plus_1 = a + lit_1; + let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![lit(1) + col("a"), col("a") + lit(1)])? + .project(vec![_1_plus_a, a_plus_1])? 
.build()?; let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ @@ -1067,6 +1076,8 @@ mod test { fn cross_plans_subexpr() -> Result<()> { let table_scan = test_table_scan()?; + let _1_plus_col_a = lit(1) + col("a"); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1) + col("a"), col("a")])? .project(vec![lit(1) + col("a")])? @@ -1284,10 +1295,13 @@ mod test { fn test_short_circuits() -> Result<()> { let table_scan = test_table_scan()?; - let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); - let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); - let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0)); - let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0)); + let a = col("a"); + let b = col("b"); + + let extracted_short_circuit = a.clone().eq(lit(0)).or(b.clone().eq(lit(0))); + let extracted_short_circuit_leg_1 = (a.clone() + b.clone()).eq(lit(0)); + let not_extracted_short_circuit_leg_2 = (a.clone() - b.clone()).eq(lit(0)); + let extracted_short_circuit_leg_3 = (a * b).eq(lit(0)); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ extracted_short_circuit.clone().alias("c1"), @@ -1319,9 +1333,12 @@ mod test { fn test_volatile() -> Result<()> { let table_scan = test_table_scan()?; - let extracted_child = col("a") + col("b"); - let rand = rand_func().call(vec![]); + let a = col("a"); + let b = col("b"); + let extracted_child = a + b; + let rand = rand_expr(); let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile.clone().alias("c1"), @@ -1342,13 +1359,19 @@ mod test { fn test_volatile_short_circuits() -> Result<()> { let table_scan = test_table_scan()?; - let rand = rand_func().call(vec![]); - let extracted_short_circuit_leg_1 = col("a").eq(lit(0)); + let a = col("a"); + let b = col("b"); + let rand = rand_expr(); + let rand_eq_0 = rand.eq(lit(0)); + + let extracted_short_circuit_leg_1 = a.eq(lit(0)); let not_extracted_volatile_short_circuit_1 = - extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0))); - let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0)); + extracted_short_circuit_leg_1.or(rand_eq_0.clone()); + + let not_extracted_short_circuit_leg_2 = b.eq(lit(0)); let not_extracted_volatile_short_circuit_2 = - rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2); + rand_eq_0.or(not_extracted_short_circuit_leg_2); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile_short_circuit_1.clone().alias("c1"), @@ -1371,7 +1394,10 @@ mod test { fn test_non_top_level_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let common_expr = a + b; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1394,8 +1420,11 @@ mod test { fn test_nested_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let nested_common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let nested_common_expr = a + b; let common_expr = nested_common_expr.clone() * nested_common_expr; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1671,8 +1700,8 @@ mod test { /// /// Does not use datafusion_functions::rand to avoid introducing a /// dependency on that crate. 
- fn rand_func() -> ScalarUDF { - ScalarUDF::new_from_impl(RandomStub::new()) + fn rand_expr() -> Expr { + ScalarUDF::new_from_impl(RandomStub::new()).call(vec![]) } #[derive(Debug)] diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index b1b889136b35f..f7eae53c5346a 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,6 +26,7 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -55,7 +56,9 @@ pub type PhysicalExprRef = Arc; /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html -pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { +pub trait PhysicalExpr: + Send + Sync + Display + Debug + DynEq + DynHash + DynHashNode +{ /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -152,6 +155,10 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties::new_unknown()) } + + fn is_volatile(&self) -> bool { + false + } } /// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object @@ -162,7 +169,7 @@ pub trait DynEq { impl DynEq for T { fn dyn_eq(&self, other: &dyn Any) -> bool { - other.downcast_ref::() == Some(self) + other.downcast_ref::().is_some_and(|o| o == self) } } @@ -194,6 +201,35 @@ impl Hash for dyn PhysicalExpr { } } +pub trait DynHashNode { + fn dyn_hash_node(&self, state: &mut dyn Hasher); +} + +impl DynHashNode for T { + fn dyn_hash_node(&self, mut state: &mut dyn Hasher) { + self.type_id().hash(&mut state); + self.hash_node(&mut state) + } +} + +impl HashNode for dyn PhysicalExpr { + fn hash_node(&self, state: &mut H) { + self.dyn_hash_node(state); + } +} + +impl Normalizeable for dyn PhysicalExpr { + fn can_normalize(&self) -> bool { + false + } +} + +impl NormalizeEq for dyn PhysicalExpr { + fn normalize_eq(&self, other: &Self) -> bool { + self == other + } +} + /// Returns a copy of this expr if we change any child according to the pointer comparison. /// The size of `children` must be equal to the size of `PhysicalExpr::children()`. 
pub fn with_new_children_if_necessary( diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index b150d3dc9bd38..2533dc55898f7 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -42,6 +42,7 @@ use itertools::Itertools; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; +/// # use datafusion_common::cse::HashNode; /// # use arrow::compute::SortOptions; /// # use arrow::datatypes::{DataType, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; @@ -62,6 +63,9 @@ use itertools::Itertools; /// # impl Display for MyPhysicalExpr { /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") } /// # } +/// # impl HashNode for MyPhysicalExpr { +/// # fn hash_node(&self, _state: &mut H) {} +/// # } /// # fn col(name: &str) -> Arc { Arc::new(MyPhysicalExpr) } /// // Sort by a ASC /// let options = SortOptions::default(); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 1713842f410ef..0ed938b52cd0b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,7 +17,7 @@ mod kernels; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; @@ -32,6 +32,7 @@ use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::binary::BinaryTypeCoercer; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; @@ -66,7 +67,7 @@ impl PartialEq for BinaryExpr { } } impl Hash for BinaryExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.left.hash(state); self.op.hash(state); self.right.hash(state); @@ -74,6 +75,13 @@ impl Hash for BinaryExpr { } } +impl HashNode for BinaryExpr { + fn hash_node(&self, state: &mut H) { + self.op.hash(state); + self.fail_on_overflow.hash(state); + } +} + impl BinaryExpr { /// Create new binary expression pub fn new( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 78606f05ae817..0388cab9693a4 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,7 +16,7 @@ // under the License. 
use std::borrow::Cow; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::expressions::try_cast; @@ -33,6 +33,7 @@ use datafusion_common::{ use datafusion_expr::ColumnarValue; use super::{Column, Literal}; +use datafusion_common::cse::HashNode; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -98,6 +99,12 @@ pub struct CaseExpr { eval_method: EvalMethod, } +impl HashNode for CaseExpr { + fn hash_node(&self, state: &mut H) { + self.eval_method.hash(state); + } +} + impl std::fmt::Display for CaseExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "CASE ")?; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 8a093e0ae92ea..cf3b75d2a6f30 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -25,6 +25,7 @@ use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -62,13 +63,20 @@ impl PartialEq for CastExpr { } impl Hash for CastExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.expr.hash(state); self.cast_type.hash(state); self.cast_options.hash(state); } } +impl HashNode for CastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type.hash(state); + self.cast_options.hash(state); + } +} + impl CastExpr { /// Create a new CastExpr pub fn new( diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 0649cbd65d34d..5e6703204c941 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,7 +18,7 @@ //! Physical column reference: [`Column`] use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_schema::SchemaRef; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -71,6 +72,13 @@ pub struct Column { index: usize, } +impl HashNode for Column { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.index.hash(state); + } +} + impl Column { /// Create a new column expression which references the /// column with the given index in the schema. 
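
Note on the pattern above: every concrete `PhysicalExpr` now needs a `HashNode` impl in addition to `Hash`. The convention this patch follows is that `Hash` stays structural and covers the whole subtree (children included), while `hash_node` mixes in only the node-local state, because the CSE identifier accumulates the children's hashes separately. A minimal sketch of what a downstream (third-party) expression type would add, using a hypothetical `ModuloExpr` with a `divisor` field that is not part of this patch; the import path `datafusion_physical_expr::PhysicalExpr` is assumed here for illustration:

    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use datafusion_common::cse::HashNode;
    use datafusion_physical_expr::PhysicalExpr;

    /// Hypothetical user-defined expression computing `expr % divisor`.
    #[derive(Debug)]
    struct ModuloExpr {
        expr: Arc<dyn PhysicalExpr>, // child expression
        divisor: u64,                // node-local state
    }

    impl Hash for ModuloExpr {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // `Hash` remains structural: it covers the child subtree too.
            self.expr.hash(state);
            self.divisor.hash(state);
        }
    }

    impl HashNode for ModuloExpr {
        fn hash_node<H: Hasher>(&self, state: &mut H) {
            // `hash_node` must not hash children; the CSE identifier already
            // accumulates each child's identifier on its own.
            self.divisor.hash(state);
        }
    }

The built-in impls follow the same split, e.g. `BinaryExpr::hash_node` hashes only `op` and `fail_on_overflow`, and `Column::hash_node` hashes only `name` and `index`, as shown in the hunks above.
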
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index dfe9a905dfeaa..b1da2817f920e 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -44,6 +44,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; +use datafusion_common::cse::HashNode; use datafusion_common::HashMap; use hashbrown::hash_map::RawEntryMut; @@ -419,6 +420,13 @@ impl Hash for InListExpr { } } +impl HashNode for InListExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + // Add `self.static_filter` when hash is available + } +} + /// Creates a unary expression InList pub fn in_list( expr: Arc, diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 47dc53d125550..d50c3225e5be6 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,7 +17,7 @@ //! IS NOT NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -44,11 +45,15 @@ impl PartialEq for IsNotNullExpr { } impl Hash for IsNotNullExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } +impl HashNode for IsNotNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + impl IsNotNullExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 5e883dff997aa..4c313d5a19f0e 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,7 +17,7 @@ //! IS NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -44,11 +45,15 @@ impl PartialEq for IsNullExpr { } impl Hash for IsNullExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } +impl HashNode for IsNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + impl IsNullExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index d61cd63c35b1e..6053ffa7f10b8 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; @@ -45,7 +46,7 @@ impl PartialEq for LikeExpr { } impl Hash for LikeExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.negated.hash(state); self.case_insensitive.hash(state); self.expr.hash(state); @@ -53,6 +54,13 @@ impl Hash for LikeExpr { } } +impl HashNode for LikeExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + } +} + impl LikeExpr { pub fn new( negated: bool, diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 232f9769b056a..96c51c7c3e735 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,7 +18,7 @@ //! Literal expressions for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -39,6 +40,12 @@ pub struct Literal { value: ScalarValue, } +impl HashNode for Literal { + fn hash_node(&self, state: &mut H) { + self.value.hash(state); + } +} + impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 03f2111aca330..c4123c102fe9d 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,7 +18,7 @@ //! Negation (-) expression use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -28,6 +28,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -51,11 +52,15 @@ impl PartialEq for NegativeExpr { } impl Hash for NegativeExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } +impl HashNode for NegativeExpr { + fn hash_node(&self, _state: &mut H) {} +} + impl NegativeExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index c17b52f5cdfff..e672bd4fb595b 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,7 +18,7 @@ //! 
NoOp placeholder for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::{ @@ -27,6 +27,7 @@ use arrow::{ }; use crate::PhysicalExpr; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -43,6 +44,10 @@ impl NoOp { } } +impl HashNode for NoOp { + fn hash_node(&self, _state: &mut H) {} +} + impl std::fmt::Display for NoOp { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "NoOp") diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 440c4e9557bdf..bdf210911994f 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -19,12 +19,13 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; @@ -44,11 +45,15 @@ impl PartialEq for NotExpr { } impl Hash for NotExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } +impl HashNode for NotExpr { + fn hash_node(&self, _state: &mut H) {} +} + impl NotExpr { /// Create new not expression pub fn new(arg: Arc) -> Self { diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 06f4e929992e5..e7b819eff76dc 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -26,6 +26,7 @@ use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -47,12 +48,18 @@ impl PartialEq for TryCastExpr { } impl Hash for TryCastExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.expr.hash(state); self.cast_type.hash(state); } } +impl HashNode for TryCastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type.hash(state); + } +} + impl TryCastExpr { /// Create a new CastExpr pub fn new(expr: Arc, cast_type: DataType) -> Self { diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index a63caf7e13056..eacdf2e1f70b5 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -27,6 +27,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -94,6 +95,12 @@ impl Hash for UnKnownColumn { } } +impl HashNode for UnKnownColumn { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + } +} + impl PartialEq for UnKnownColumn { fn eq(&self, _other: &Self) -> bool { // UnknownColumn is 
not a valid expression, so it should not be equal to any other expression. diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index bd38fb22ccbc3..7be79003d9741 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::expressions::Literal; @@ -39,6 +39,7 @@ use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::{DataType, Schema}; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -46,6 +47,7 @@ use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, }; +use datafusion_expr_common::signature::Volatility; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -57,6 +59,15 @@ pub struct ScalarFunctionExpr { nullable: bool, } +impl HashNode for ScalarFunctionExpr { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.return_type.hash(state); + self.nullable.hash(state); + self.fun.hash(state); + } +} + impl Debug for ScalarFunctionExpr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ScalarFunctionExpr") @@ -260,6 +271,10 @@ impl PhysicalExpr for ScalarFunctionExpr { preserves_lex_ordering, }) } + + fn is_volatile(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } } /// Create a physical expression for the UDF. diff --git a/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs new file mode 100644 index 0000000000000..a01e780d54c5c --- /dev/null +++ b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs @@ -0,0 +1,589 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateCommonPhysicalSubexprs`] to avoid redundant computation of common physical +//! sub-expressions. 
+ +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column}; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; +use datafusion_physical_plan::projection::ProjectionExec; + +const CSE_PREFIX: &str = "__common_physical_expr"; + +// Optimizer rule to avoid redundant computation of common physical subexpressions +#[derive(Default, Debug)] +pub struct EliminateCommonPhysicalSubexprs {} + +impl EliminateCommonPhysicalSubexprs { + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for EliminateCommonPhysicalSubexprs { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_down(|plan| { + let plan_any = plan.as_any(); + if let Some(p) = plan_any.downcast_ref::() { + match CSE::new(PhysicalExprCSEController::new( + config.alias_generator.as_ref(), + p.input().schema().fields().len(), + )) + .extract_common_nodes(vec![p + .expr() + .iter() + .map(|(e, _)| e) + .cloned() + .collect()])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let common_exprs = p + .input() + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + ( + Arc::new(Column::new(field.name(), i)) + as Arc, + field.name().to_string(), + ) + }) + .chain(common_exprs) + .collect(); + let common = Arc::new(ProjectionExec::try_new( + common_exprs, + Arc::clone(p.input()), + )?); + + let new_exprs = new_exprs_list + .pop() + .unwrap() + .into_iter() + .zip(p.expr().iter().map(|(_, alias)| alias.to_string())) + .collect(); + let new_project = + Arc::new(ProjectionExec::try_new(new_exprs, common)?) + as Arc; + + Ok(Transformed::yes(new_project)) + } + FoundCommonNodes::No { .. } => Ok(Transformed::no(plan)), + } + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "eliminate_common_physical_subexpressions" + } + + /// This rule will change the nullable properties of the schema, disable the schema check. + fn schema_check(&self) -> bool { + false + } +} + +pub struct PhysicalExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + base_index: usize, +} + +impl<'a> PhysicalExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, base_index: usize) -> Self { + Self { + alias_generator, + base_index, + } + } +} + +impl CSEController for PhysicalExprCSEController<'_> { + type Node = Arc; + + fn conditional_children( + node: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + if let Some(s) = node.as_any().downcast_ref::() { + // In case of `ScalarFunction`s all children can be conditionally executed. + if s.fun().short_circuits() { + Some((vec![], s.args().iter().collect())) + } else { + None + } + } else if let Some(b) = node.as_any().downcast_ref::() { + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. 
+ if *b.op() == Operator::And || *b.op() == Operator::Or { + Some((vec![b.left()], vec![b.right()])) + } else { + None + } + } else { + node.as_any().downcast_ref::().map(|c| { + ( + // In case of `Case` the optional base expression and the first when + // expressions are surely executed, but we account subexpressions as + // conditional in the others. + c.expr() + .into_iter() + .chain(c.when_then_expr().iter().take(1).map(|(when, _)| when)) + .collect(), + c.when_then_expr() + .iter() + .take(1) + .map(|(_, then)| then) + .chain( + c.when_then_expr() + .iter() + .skip(1) + .flat_map(|(when, then)| [when, then]), + ) + .chain(c.else_expr()) + .collect(), + ) + }) + } + } + + fn is_valid(node: &Self::Node) -> bool { + !node.is_volatile() + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + node.children().is_empty() + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, _node: &Self::Node, alias: &str, index: usize) -> Self::Node { + Arc::new(Column::new(alias, self.base_index + index)) + } + + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +#[cfg(test)] +mod tests { + use crate::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs; + use crate::optimizer::PhysicalOptimizerRule; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::Result; + use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_physical_expr::expressions::{binary, col, lit}; + use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + use datafusion_physical_plan::memory::MemorySourceConfig; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::source::DataSourceExec; + use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; + use std::any::Any; + use std::sync::Arc; + + fn mock_data() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&[vec![]], Arc::clone(&schema), None).unwrap(), + ))) + } + + #[test] + fn subexpr_in_same_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary(lit_1, Operator::Plus, a, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&_1_plus_a), "first".to_string()), + (_1_plus_a, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as first, __common_physical_expr_1@2 as second]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, 1 + a@0 as __common_physical_expr_1]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn subexpr_in_different_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary( + Arc::clone(&lit_1), + 
Operator::Plus, + Arc::clone(&a), + &table_scan.schema(), + )?; + let a_plus_1 = binary(a, Operator::Plus, lit_1, &table_scan.schema())?; + + let exprs = vec![ + (_1_plus_a, "first".to_string()), + (a_plus_1, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[1 + a@0 as first, a@0 + 1 as second]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_short_circuits() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + + let extracted_short_circuit = binary( + binary(Arc::clone(&a), Operator::Eq, lit(0), &table_scan.schema())?, + Operator::Or, + binary(Arc::clone(&b), Operator::Eq, lit(0), &table_scan.schema())?, + &table_scan.schema(), + )?; + let extracted_short_circuit_leg_1 = binary( + binary( + Arc::clone(&a), + Operator::Plus, + Arc::clone(&b), + &table_scan.schema(), + )?, + Operator::Eq, + lit(0), + &table_scan.schema(), + )?; + let not_extracted_short_circuit_leg_2 = binary( + binary( + Arc::clone(&a), + Operator::Minus, + Arc::clone(&b), + &table_scan.schema(), + )?, + Operator::Eq, + lit(0), + &table_scan.schema(), + )?; + let extracted_short_circuit_leg_3 = binary( + binary(a, Operator::Multiply, b, &table_scan.schema())?, + Operator::Eq, + lit(0), + &table_scan.schema(), + )?; + + let exprs = vec![ + (Arc::clone(&extracted_short_circuit), "c1".to_string()), + (extracted_short_circuit, "c2".to_string()), + ( + binary( + Arc::clone(&extracted_short_circuit_leg_1), + Operator::Or, + Arc::clone(¬_extracted_short_circuit_leg_2), + &table_scan.schema(), + )?, + "c3".to_string(), + ), + ( + binary( + extracted_short_circuit_leg_1, + Operator::And, + Arc::clone(¬_extracted_short_circuit_leg_2), + &table_scan.schema(), + )?, + "c4".to_string(), + ), + ( + binary( + Arc::clone(&extracted_short_circuit_leg_3), + Operator::Or, + extracted_short_circuit_leg_3, + &table_scan.schema(), + )?, + "c5".to_string(), + ), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2, __common_physical_expr_2@3 OR a@0 - b@1 = 0 as c3, __common_physical_expr_2@3 AND a@0 - b@1 = 0 as c4, __common_physical_expr_3@4 OR __common_physical_expr_3@4 as c5]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 = 0 OR b@1 = 0 as __common_physical_expr_1, a@0 + b@1 = 0 as __common_physical_expr_2, a@0 * b@1 = 0 as __common_physical_expr_3]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let extracted_child = binary(a, Operator::Plus, b, &table_scan.schema())?; + let rand = rand_expr(); + let not_extracted_volatile = + binary(extracted_child, Operator::Plus, rand, &table_scan.schema())?; + + let exprs = vec![ + 
(Arc::clone(¬_extracted_volatile), "c1".to_string()), + (not_extracted_volatile, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 + random() as c1, __common_physical_expr_1@2 + random() as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_1]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile_short_circuits() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let rand = rand_expr(); + let rand_eq_0 = binary(rand, Operator::Eq, lit(0), &table_scan.schema())?; + + let extracted_short_circuit_leg_1 = + binary(a, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_1 = binary( + extracted_short_circuit_leg_1, + Operator::Or, + Arc::clone(&rand_eq_0), + &table_scan.schema(), + )?; + + let not_extracted_short_circuit_leg_2 = + binary(b, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_2 = binary( + rand_eq_0, + Operator::Or, + not_extracted_short_circuit_leg_2, + &table_scan.schema(), + )?; + + let exprs = vec![ + ( + Arc::clone(¬_extracted_volatile_short_circuit_1), + "c1".to_string(), + ), + (not_extracted_volatile_short_circuit_1, "c2".to_string()), + ( + Arc::clone(¬_extracted_volatile_short_circuit_2), + "c3".to_string(), + ), + (not_extracted_volatile_short_circuit_2, "c4".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 OR random() = 0 as c1, __common_physical_expr_1@2 OR random() = 0 as c2, random() = 0 OR b@1 = 0 as c3, random() = 0 OR b@1 = 0 as c4]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 = 0 as __common_physical_expr_1]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_non_top_level_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let c1 = col("c1", &plan.schema())?; + let c2 = col("c2", &plan.schema())?; + + let exprs = vec![(c1, "c1".to_string()), (c2, "c2".to_string())]; + let plan = Arc::new(ProjectionExec::try_new(exprs, plan)?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2]", + " ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as 
__common_physical_expr_1]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_nested_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let nested_common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + let common_expr = binary( + Arc::clone(&nested_common_expr), + Operator::Multiply, + nested_common_expr, + &table_scan.schema(), + )?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, __common_physical_expr_2@2 * __common_physical_expr_2@2 as __common_physical_expr_1]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_2]", + " DataSourceExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + fn rand_expr() -> Arc { + let r = RandomStub::new(); + let n = r.name().to_string(); + let t = r.return_type(&[]).unwrap(); + Arc::new(ScalarFunctionExpr::new( + &n, + Arc::new(ScalarUDF::new_from_impl(r)), + vec![], + t, + )) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index c2beab0320491..58d1677c44ce6 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -21,6 +21,7 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod combine_partial_final_agg; +pub mod eliminate_common_physical_subexprs; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7418184fcac15..d966389241563 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::Display; +use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -90,6 +91,7 @@ use datafusion::physical_plan::{ use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; +use datafusion_common::cse::HashNode; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -825,12 +827,16 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> 
{ } } - impl std::hash::Hash for CustomPredicateExpr { - fn hash(&self, state: &mut H) { + impl Hash for CustomPredicateExpr { + fn hash(&self, state: &mut H) { self.inner.hash(state); } } + impl HashNode for CustomPredicateExpr { + fn hash_node(&self, _state: &mut H) {} + } + impl Display for CustomPredicateExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "CustomPredicateExpr")