Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash join with sort push down #13560

Merged
merged 12 commits into from
Dec 9, 2024
101 changes: 101 additions & 0 deletions datafusion/core/src/physical_optimizer/sort_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::tree_node::PlanContext;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use arrow_schema::SchemaRef;

use datafusion_common::tree_node::{
ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion,
Expand All @@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalSortRequirement;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::HashJoinExec;

/// This is a "data class" we use within the [`EnforceSorting`] rule to push
/// down [`SortExec`] in the plan. In some cases, we can reduce the total
Expand Down Expand Up @@ -294,6 +297,8 @@ fn pushdown_requirement_to_children(
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
}
} else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
handle_hash_join(hash_join, parent_required)
} else {
handle_custom_pushdown(plan, parent_required, maintains_input_order)
}
Expand Down Expand Up @@ -606,6 +611,102 @@ fn handle_custom_pushdown(
}
}

// A hash join only maintains the input order of its right (probe-side) child,
// and only for the join types Inner, Right, RightSemi and RightAnti.
fn handle_hash_join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be great to eventually move this kind of operator-specific logic into the operators themselves (e.g. some method on HashJoinExec). Definitely not in this PR, but keeping assumptions about an operator separate from its implementation gives us a larger chance of introducing inconsistencies, I think.

plan: &HashJoinExec,
parent_required: &LexRequirement,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
    // Tries to push the parent's sort requirement through a hash join into
    // its right child.
    //
    // * `plan` - the hash join the requirement would be pushed through.
    // * `parent_required` - the ordering required above the join, expressed
    //   against the join's output schema.
    //
    // Returns `Ok(None)` when nothing can be pushed down, otherwise
    // `Ok(Some(vec![None, Some(req)]))`, where `req` is the requirement
    // rewritten against the right child's schema (the left child never
    // receives a requirement).

    // Nothing to push down if the parent requires no ordering, or if this
    // join does not maintain the order of its right input.
    if parent_required.is_empty() || !plan.maintains_input_order()[1] {
        return Ok(None);
    }

    // Map every output column of the join (after the optional embedded
    // projection) back to the child side and child column index it comes from.
    let column_indices = build_join_column_index(plan);
    let projected_indices: Vec<&ColumnIndex> = match &plan.projection {
        Some(projection) => projection.iter().map(|&i| &column_indices[i]).collect(),
        None => column_indices.iter().collect(),
    };

    // Decide per column, from its recorded side, whether it originates from the
    // right child. Comparing the output index against the number of left-side
    // columns would only be valid if all left columns preceded all right
    // columns, which a projection that reorders columns does not guarantee.
    // An index outside the output schema is treated as "not from the right
    // child" so that we bail out instead of panicking below.
    let all_from_right_child = parent_required
        .iter()
        .flat_map(|req| collect_columns(&req.expr))
        .all(|col| {
            matches!(
                projected_indices.get(col.index()),
                Some(ci) if ci.side == JoinSide::Right
            )
        });

    if !all_from_right_child {
        return Ok(None);
    }

    // The right child's schema is the same for every requirement; fetch it once.
    let child_schema = plan.right().schema();

    // Rewrite each required expression so that its columns refer to the right
    // child's schema (name and index) instead of the join's output schema.
    let updated_parent_req = parent_required
        .iter()
        .map(|req| {
            let updated_expr = Arc::clone(&req.expr)
                .transform_up(|expr| {
                    if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                        // In bounds: every column index was validated above.
                        let index = projected_indices[col.index()].index;
                        Ok(Transformed::yes(Arc::new(Column::new(
                            child_schema.field(index).name(),
                            index,
                        ))))
                    } else {
                        Ok(Transformed::no(expr))
                    }
                })?
                .data;
            Ok(PhysicalSortRequirement::new(updated_expr, req.options))
        })
        .collect::<Result<Vec<_>>>()?;

    // Only the right child receives the (rewritten) requirement.
    Ok(Some(vec![
        None,
        Some(LexRequirement::new(updated_parent_req)),
    ]))
}

// Builds, for each output column of the hash join (before any embedded
// projection is applied), the child side and the column index within that
// child it comes from. Used when pushing sort requirements down to the
// right child.
//
// Only Inner, Right, RightSemi and RightAnti joins are handled; any other
// join type hits `unreachable!`.
fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
    // One `ColumnIndex` per field of `schema`, all tagged with `side`.
    let columns_of = |schema: SchemaRef, side: JoinSide| {
        (0..schema.fields().len()).map(move |index| ColumnIndex { index, side })
    };

    let join_type = plan.join_type();
    match join_type {
        // Output is all left columns followed by all right columns.
        JoinType::Inner | JoinType::Right => {
            let mut indices: Vec<ColumnIndex> =
                columns_of(plan.left().schema(), JoinSide::Left).collect();
            indices.extend(columns_of(plan.right().schema(), JoinSide::Right));
            indices
        }
        // Output consists of the right child's columns only.
        JoinType::RightSemi | JoinType::RightAnti => {
            columns_of(plan.right().schema(), JoinSide::Right).collect()
        }
        _ => unreachable!("unexpected join type: {}", join_type),
    }
}

/// Define the Requirements Compatibility
#[derive(Debug)]
enum RequirementsCompatibility {
Expand Down
171 changes: 127 additions & 44 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)]
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand All @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------CoalesceBatchesExec: target_batch_size=2
06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------CoalesceBatchesExec: target_batch_size=2
05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)------------MemoryExec: partitions=1, partition_sizes=[1]
08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
09)--------CoalesceBatchesExec: target_batch_size=2
10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
Expand Down Expand Up @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand All @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH
----
physical_plan
01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
03)----CoalesceBatchesExec: target_batch_size=2
04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
05)--------MemoryExec: partitions=1, partition_sizes=[1]
02)--CoalesceBatchesExec: target_batch_size=2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These plans imply that the sort has been pushed into the second (probe) input, which makes sense, I think: https://docs.rs/datafusion/latest/datafusion/physical_plan/joins/struct.HashJoinExec.html

03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------MemoryExec: partitions=1, partition_sizes=[1]

Expand Down Expand Up @@ -4313,3 +4313,86 @@ physical_plan
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)]
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)--------MemoryExec: partitions=1, partition_sizes=[1]

# Test hash join sort push down
# Issue: https://github.com/apache/datafusion/issues/13559
statement ok
CREATE TABLE test(a INT, b INT, c INT)

statement ok
insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @alamb, are these tests appropriate?


statement ok
set datafusion.execution.target_partitions = 2;

query TT
explain select * from test where a in (select a from test where b > 3) order by c desc nulls first;
----
logical_plan
01)Sort: test.c DESC NULLS FIRST
02)--LeftSemi Join: test.a = __correlated_sq_1.a
03)----TableScan: test projection=[a, b, c]
04)----SubqueryAlias: __correlated_sq_1
05)------Projection: test.a
06)--------Filter: test.b > Int32(3)
07)----------TableScan: test projection=[a, b]
physical_plan
01)SortPreservingMergeExec: [c@2 DESC]
02)--CoalesceBatchesExec: target_batch_size=3
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
04)------CoalesceBatchesExec: target_batch_size=3
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
06)----------CoalesceBatchesExec: target_batch_size=3
07)------------FilterExec: b@1 > 3, projection=[a@0]
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true]
11)--------CoalesceBatchesExec: target_batch_size=3
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
14)--------------MemoryExec: partitions=1, partition_sizes=[1]

query TT
explain select * from test where a in (select a from test where b > 3) order by c desc nulls last;
----
logical_plan
01)Sort: test.c DESC NULLS LAST
02)--LeftSemi Join: test.a = __correlated_sq_1.a
03)----TableScan: test projection=[a, b, c]
04)----SubqueryAlias: __correlated_sq_1
05)------Projection: test.a
06)--------Filter: test.b > Int32(3)
07)----------TableScan: test projection=[a, b]
physical_plan
01)SortPreservingMergeExec: [c@2 DESC NULLS LAST]
02)--CoalesceBatchesExec: target_batch_size=3
03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
04)------CoalesceBatchesExec: target_batch_size=3
05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
06)----------CoalesceBatchesExec: target_batch_size=3
07)------------FilterExec: b@1 > 3, projection=[a@0]
08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true]
11)--------CoalesceBatchesExec: target_batch_size=3
12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
14)--------------MemoryExec: partitions=1, partition_sizes=[1]

query III
select * from test where a in (select a from test where b > 3) order by c desc nulls first;
----
9 10 NULL
4 5 6

query III
select * from test where a in (select a from test where b > 3) order by c desc nulls last;
----
4 5 6
9 10 NULL

statement ok
DROP TABLE test

statement ok
set datafusion.execution.target_partitions = 1;
Loading