From ea4a7f558dfeb40666e8fe725f133f0d0f992ef9 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 24 Nov 2024 22:25:54 +0800 Subject: [PATCH 1/7] fix: join with sort push down --- .../src/physical_optimizer/sort_pushdown.rs | 10 +++----- datafusion/sqllogictest/test_files/joins.slt | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d48c7118cb8e..8cd3040ced02 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -562,10 +562,6 @@ fn handle_custom_pushdown( // If all columns are from the maintained child, update the parent requirements if all_from_maintained_child { - let sub_offset = len_of_child_schemas - .iter() - .take(maintained_child_idx) - .sum::(); // Transform the parent-required expression for the child schema by adjusting columns let updated_parent_req = parent_required .iter() @@ -574,10 +570,10 @@ fn handle_custom_pushdown( let updated_columns = Arc::clone(&req.expr) .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { - let new_index = col.index() - sub_offset; + let index = child_schema.index_of(col.name())?; Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, + child_schema.field(index).name(), + index, )))) } else { Ok(Transformed::no(expr)) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index e636e93007a4..80a73393ee70 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4313,3 +4313,26 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] 06)--------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +CREATE TABLE test(a INT, b INT, c INT) + +query TT +explain select * from test where a not in (select a from test where b > 3) order by c desc; +---- +logical_plan +01)Sort: test.c DESC NULLS FIRST +02)--LeftAnti Join: test.a = __correlated_sq_1.a +03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)SortExec: expr=[c@2 DESC], preserve_partitioning=[false] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(a@0, a@0)] +04)------MemoryExec: partitions=1, partition_sizes=[0] +05)------CoalesceBatchesExec: target_batch_size=3 +06)--------FilterExec: b@1 > 3, projection=[a@0] +07)----------MemoryExec: partitions=1, partition_sizes=[0] From 5e5fa7d2736c5298b930190d2e6c4ed10079e342 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 24 Nov 2024 22:31:01 +0800 Subject: [PATCH 2/7] chore: insert some value --- datafusion/sqllogictest/test_files/joins.slt | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 80a73393ee70..b70fc83e68a8 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4317,6 +4317,9 @@ physical_plan statement ok CREATE TABLE test(a INT, b INT, c INT) +statement ok +insert into test values (1,2,3), (4,5,6) + query TT explain select * from test where a not in (select a from test where b > 3) order by c desc; ---- @@ -4329,10 +4332,10 @@ logical_plan 06)--------Filter: test.b > Int32(3) 07)----------TableScan: test projection=[a, b] physical_plan -01)SortExec: expr=[c@2 DESC], preserve_partitioning=[false] -02)--CoalesceBatchesExec: target_batch_size=3 -03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(a@0, a@0)] -04)------MemoryExec: partitions=1, partition_sizes=[0] -05)------CoalesceBatchesExec: target_batch_size=3 -06)--------FilterExec: b@1 > 3, projection=[a@0] -07)----------MemoryExec: partitions=1, partition_sizes=[0] +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@0)] +03)----CoalesceBatchesExec: target_batch_size=3 +04)------FilterExec: b@1 > 3, projection=[a@0] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)----SortExec: expr=[c@2 DESC], preserve_partitioning=[false] +07)------MemoryExec: partitions=1, partition_sizes=[1] From 7f8117888f15bbe6abb51894c4e242f11f855660 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Wed, 4 Dec 2024 09:13:16 +0800 Subject: [PATCH 3/7] apply suggestion --- .../src/physical_optimizer/sort_pushdown.rs | 116 ++++++++++++++++++ datafusion/sqllogictest/test_files/joins.slt | 88 ++++++------- 2 files changed, 160 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 8cd3040ced02..d48a29a26535 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -38,6 +38,8 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::PhysicalSortRequirement; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::HashJoinExec; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total @@ -294,6 +296,8 @@ fn pushdown_requirement_to_children( .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } + } else if let Some(hash_join) = plan.as_any().downcast_ref::() { + handle_hash_join(hash_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -602,6 +606,118 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() + || plan.children().is_empty() + || !matches!( + plan.join_type(), + JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti + ) + { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let left = plan.left().schema(); + let right = plan.right().schema(); + + let left_fields = || { + left.fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { + index, + side: JoinSide::Left, + }) + }; + + let right_fields = || { + right + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { + index, + side: JoinSide::Right, + }) + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + left_fields().chain(right_fields()).collect() + } + JoinType::RightSemi | JoinType::RightAnti => right_fields().collect(), + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index b70fc83e68a8..452f1228e31d 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] From 13b98b5a1657505e962023fadde4cc905bbab945 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Wed, 4 Dec 2024 09:16:05 +0800 Subject: [PATCH 4/7] recover handle_costom_pushdown change --- .../core/src/physical_optimizer/sort_pushdown.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d48a29a26535..44d61c755800 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -566,6 +566,10 @@ fn handle_custom_pushdown( // If all columns are from the maintained child, update the parent requirements if all_from_maintained_child { + let sub_offset = len_of_child_schemas + .iter() + .take(maintained_child_idx) + .sum::(); // Transform the parent-required expression for the child schema by adjusting columns let updated_parent_req = parent_required .iter() @@ -574,10 +578,10 @@ fn handle_custom_pushdown( let updated_columns = Arc::clone(&req.expr) .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { - let index = child_schema.index_of(col.name())?; + let new_index = col.index() - sub_offset; Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, + child_schema.field(new_index).name(), + new_index, )))) } else { Ok(Transformed::no(expr)) From b4e214d6e7ca37b46f981c5aa4ab5d6798274c2a Mon Sep 17 00:00:00 2001 From: Huaijin Date: Wed, 4 Dec 2024 16:02:17 +0800 Subject: [PATCH 5/7] apply suggestion --- .../src/physical_optimizer/sort_pushdown.rs | 43 ++++++------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 44d61c755800..6c761f674b3b 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow_schema::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, @@ -618,13 +619,7 @@ fn handle_hash_join( ) -> Result>>> { // If there's no requirement from the parent or the plan has no children // or the join type is not Inner, Right, RightSemi, RightAnti, return early - if parent_required.is_empty() - || plan.children().is_empty() - || !matches!( - plan.join_type(), - JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti - ) - { + if parent_required.is_empty() || !plan.maintains_input_order()[1] { return Ok(None); } @@ -633,7 +628,7 @@ fn handle_hash_join( .iter() .flat_map(|order| { collect_columns(&order.expr) - .iter() + .into_iter() .map(|col| col.index()) .collect::>() }) @@ -689,35 +684,25 @@ fn handle_hash_join( // this function is used to build the column index for the hash join // push down sort requirements to the right child fn build_join_column_index(plan: &HashJoinExec) -> Vec { - let left = plan.left().schema(); - let right = plan.right().schema(); - - let left_fields = || { - left.fields() - .iter() - .enumerate() - .map(|(index, _)| ColumnIndex { - index, - side: JoinSide::Left, - }) - }; - - let right_fields = || { - right + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema .fields() .iter() .enumerate() - .map(|(index, _)| ColumnIndex { - index, - side: JoinSide::Right, - }) + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() }; match plan.join_type() { JoinType::Inner | JoinType::Right => { - left_fields().chain(right_fields()).collect() + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) } - JoinType::RightSemi | JoinType::RightAnti => right_fields().collect(), _ => unreachable!("unexpected join type: {}", plan.join_type()), } } From 64fac8f764eafe262eacf4fcfe38b0ad204710d8 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 7 Dec 2024 15:56:04 +0800 Subject: [PATCH 6/7] add more test --- datafusion/sqllogictest/test_files/joins.slt | 45 ++++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 452f1228e31d..5da33113d727 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4314,18 +4314,20 @@ physical_plan 05)--------MemoryExec: partitions=1, partition_sizes=[1] 06)--------MemoryExec: partitions=1, partition_sizes=[1] +# Test hash join sort push down +# Issue: https://github.com/apache/datafusion/issues/13559 statement ok CREATE TABLE test(a INT, b INT, c INT) statement ok -insert into test values (1,2,3), (4,5,6) +insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null) query TT -explain select * from test where a not in (select a from test where b > 3) order by c desc; +explain select * from test where a in (select a from test where b > 3) order by c desc nulls first; ---- logical_plan 01)Sort: test.c DESC NULLS FIRST -02)--LeftAnti Join: test.a = __correlated_sq_1.a +02)--LeftSemi Join: test.a = __correlated_sq_1.a 03)----TableScan: test projection=[a, b, c] 04)----SubqueryAlias: __correlated_sq_1 05)------Projection: test.a @@ -4333,9 +4335,44 @@ logical_plan 07)----------TableScan: test projection=[a, b] physical_plan 01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@0)] +02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(a@0, a@0)] 03)----CoalesceBatchesExec: target_batch_size=3 04)------FilterExec: b@1 > 3, projection=[a@0] 05)--------MemoryExec: partitions=1, partition_sizes=[1] 06)----SortExec: expr=[c@2 DESC], preserve_partitioning=[false] 07)------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +logical_plan +01)Sort: test.c DESC NULLS LAST +02)--LeftSemi Join: test.a = __correlated_sq_1.a +03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(a@0, a@0)] +03)----CoalesceBatchesExec: target_batch_size=3 +04)------FilterExec: b@1 > 3, projection=[a@0] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)----SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[false] +07)------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls first; +---- +9 10 NULL +4 5 6 + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +4 5 6 +9 10 NULL + +statement ok +DROP TABLE test From 2f05a33e3c6f3cb71ed44d0ba3a0155bc23b9f9a Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 7 Dec 2024 16:20:04 +0800 Subject: [PATCH 7/7] add partition --- datafusion/sqllogictest/test_files/joins.slt | 48 ++++++++++++++------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 5da33113d727..62f625119897 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4322,6 +4322,9 @@ CREATE TABLE test(a INT, b INT, c INT) statement ok insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null) +statement ok +set datafusion.execution.target_partitions = 2; + query TT explain select * from test where a in (select a from test where b > 3) order by c desc nulls first; ---- @@ -4334,13 +4337,20 @@ logical_plan 06)--------Filter: test.b > Int32(3) 07)----------TableScan: test projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(a@0, a@0)] -03)----CoalesceBatchesExec: target_batch_size=3 -04)------FilterExec: b@1 > 3, projection=[a@0] -05)--------MemoryExec: partitions=1, partition_sizes=[1] -06)----SortExec: expr=[c@2 DESC], preserve_partitioning=[false] -07)------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [c@2 DESC] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] query TT explain select * from test where a in (select a from test where b > 3) order by c desc nulls last; @@ -4354,13 +4364,20 @@ logical_plan 06)--------Filter: test.b > Int32(3) 07)----------TableScan: test projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(a@0, a@0)] -03)----CoalesceBatchesExec: target_batch_size=3 -04)------FilterExec: b@1 > 3, projection=[a@0] -05)--------MemoryExec: partitions=1, partition_sizes=[1] -06)----SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[false] -07)------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [c@2 DESC NULLS LAST] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] query III select * from test where a in (select a from test where b > 3) order by c desc nulls first; @@ -4376,3 +4393,6 @@ select * from test where a in (select a from test where b > 3) order by c desc n statement ok DROP TABLE test + +statement ok +set datafusion.execution.target_partitions = 1;