diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index 42514537e28d..bb268e048d9a 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -292,6 +292,23 @@ pub trait TreeNode: Sized {
         )
     }
 
+    /// Returns true if `f` returns true for any node in the tree.
+    ///
+    /// Stops recursion as soon as a matching node is found
+    fn exists<F: FnMut(&Self) -> bool>(&self, mut f: F) -> bool {
+        let mut found = false;
+        self.apply(&mut |n| {
+            Ok(if f(n) {
+                found = true;
+                TreeNodeRecursion::Stop
+            } else {
+                TreeNodeRecursion::Continue
+            })
+        })
+        .unwrap();
+        found
+    }
+
     /// Apply the closure `F` to the node's children.
     fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
         &self,
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b11636d831b1..ad15a81a2325 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -25,7 +25,7 @@ use std::sync::Arc;
 
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
-use crate::utils::{expr_to_columns, find_out_reference_exprs};
+use crate::utils::expr_to_columns;
 use crate::window_frame;
 use crate::{
     aggregate_function, built_in_function, built_in_window_function, udaf,
@@ -1232,7 +1232,7 @@ impl Expr {
 
     /// Return true when the expression contains out reference(correlated) expressions.
     pub fn contains_outer(&self) -> bool {
-        !find_out_reference_exprs(self).is_empty()
+        self.exists(|expr| matches!(expr, Expr::OuterReferenceColumn { .. }))
     }
 
     /// Recursively find all [`Expr::Placeholder`] expressions, and
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ca8d718ec090..7ea1324d9052 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1321,6 +1321,21 @@ impl LogicalPlan {
             | LogicalPlan::Extension(_) => None,
         }
     }
+
+    /// Returns true if any of this node's expressions contain a reference to an outer subquery
+    pub fn contains_outer_reference(&self) -> bool {
+        let mut contains = false;
+        self.apply_expressions(|expr| {
+            Ok(if expr.contains_outer() {
+                contains = true;
+                TreeNodeRecursion::Stop
+            } else {
+                TreeNodeRecursion::Continue
+            })
+        })
+        .unwrap();
+        contains
+    }
 }
 
 /// This macro is used to determine continuation during combined transforming
diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs
index 79375e52da1f..002885266e2f 100644
--- a/datafusion/optimizer/src/analyzer/subquery.rs
+++ b/datafusion/optimizer/src/analyzer/subquery.rs
@@ -140,7 +140,7 @@ fn check_inner_plan(
     is_aggregate: bool,
     can_contain_outer_ref: bool,
 ) -> Result<()> {
-    if !can_contain_outer_ref && contains_outer_reference(inner_plan) {
+    if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
         return plan_err!("Accessing outer reference columns is not allowed in the plan");
     }
     // We want to support as many operators as possible inside the correlated subquery
@@ -233,13 +233,6 @@ fn check_inner_plan(
     }
 }
 
-fn contains_outer_reference(inner_plan: &LogicalPlan) -> bool {
-    inner_plan
-        .expressions()
-        .iter()
-        .any(|expr| expr.contains_outer())
-}
-
 fn check_aggregation_in_scalar_subquery(
     inner_plan: &LogicalPlan,
     agg: &Aggregate,
diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs
index dbcf02b26ba6..7eda45fb563c 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -91,7 +91,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
                     _ => Ok(Transformed::no(plan)),
                 }
             }
-            _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => {
+            _ if plan.contains_outer_reference() => {
                 // the unsupported cases, the plan expressions contain out reference columns(like window expressions)
                 self.can_pull_up = false;
                 Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))