From e92ec869af32fffce28fe52597ad65902c697ea4 Mon Sep 17 00:00:00 2001
From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com>
Date: Sat, 3 Aug 2024 11:53:33 +0200
Subject: [PATCH] feat: improve merge performance by using predicate
 non-partition columns min/max for prefiltering

---
 crates/core/src/operations/merge/mod.rs | 399 +++++++++++++++++++-----
 crates/core/tests/command_merge.rs      |  31 +-
 2 files changed, 340 insertions(+), 90 deletions(-)

diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index efc54c1869..73a97436f1 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -49,11 +49,14 @@ use datafusion::{
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
 use datafusion_expr::expr::Placeholder;
-use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
-    BinaryExpr, Distinct, Extension, LogicalPlan, LogicalPlanBuilder, Operator, Projection,
+    col, conditional_expressions::CaseBuilder, lit, max, min, when, Between, Expr, JoinType,
+};
+use datafusion_expr::{
+    Aggregate, BinaryExpr, Extension, LogicalPlan, LogicalPlanBuilder, Operator,
     UserDefinedLogicalNode, UNNAMED_TABLE,
 };
+use either::{Left, Right};
 use futures::future::BoxFuture;
 use itertools::Itertools;
 use parquet::file::properties::WriterProperties;
@@ -666,13 +669,22 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
     }
 }
 
-/// Takes the predicate provided and does two things:
+struct PredicatePlaceholder {
+    expr: Expr,
+    alias: String,
+    is_aggregate: bool,
+}
+
+/// Takes the predicate provided and does three things:
 ///
-/// 1. for any relations between a source column and a target column, if the target column is a
-///    partition column, then replace source with a placeholder matching the name of the partition
+/// 1. for any relation between a source column and a partition target column,
+///    replace the source expression with a placeholder matching the name of the partition
 ///    columns
 ///
-/// 2. for any other relation with a source column, remove them.
+/// 2. for any equality relation between a source column and a non-partition target column, replace
+///    the source expression with a BETWEEN bounded by min(source_column) and max(source_column) placeholders
+///
+/// 3. for any other relation with a source column, remove it.
 ///
 /// For example, for the predicate:
 ///
@@ -680,21 +692,17 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
 ///
 /// where `date` is a partition column, would result in the expr:
 ///
-/// `$date = target.date and frob > 42`
+/// `$date_0 = target.date and target.id between $id_1_min and $id_1_max and frob > 42`
 ///
 /// This leaves us with a predicate that we can push into delta scan after expanding it out to
 /// a conjunction between the distinct partitions in the source input.
 ///
-/// TODO: A further improvement here might be for non-partition columns to be replaced with min/max
-/// checks, so the above example could become:
-///
-/// `$date = target.date and target.id between 12345 and 99999 and frob > 42`
 fn generalize_filter(
     predicate: Expr,
     partition_columns: &Vec<String>,
     source_name: &TableReference,
     target_name: &TableReference,
-    placeholders: &mut HashMap<String, Expr>,
+    placeholders: &mut Vec<PredicatePlaceholder>,
 ) -> Option<Expr> {
     #[derive(Debug)]
     enum ReferenceTableCheck {
@@ -738,29 +746,94 @@ fn generalize_filter(
         res
     }
 
+    fn construct_placeholder(
+        binary: BinaryExpr,
+        source_left: bool,
+        is_partition_column: bool,
+        column_name: String,
+        placeholders: &mut Vec<PredicatePlaceholder>,
+    ) -> Option<Expr> {
+        if is_partition_column {
+            let placeholder_name = format!("{column_name}_{}", placeholders.len());
+            let placeholder = Expr::Placeholder(Placeholder {
+                id: placeholder_name.clone(),
+                data_type: None,
+            });
+
+            let (left, right, source_expr): (Box<Expr>, Box<Expr>, Expr) = if source_left {
+                (placeholder.into(), binary.clone().right, *binary.left)
+            } else {
+                (binary.clone().left, placeholder.into(), *binary.right)
+            };
+
+            let replaced = Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: binary.op,
+                right,
+            });
+
+            placeholders.push(PredicatePlaceholder {
+                expr: source_expr,
+                alias: placeholder_name,
+                is_aggregate: false,
+            });
+
+            Some(replaced)
+        } else {
+            match binary.op {
+                Operator::Eq => {
+                    let name_min = format!("{column_name}_{}_min", placeholders.len());
+                    let placeholder_min = Expr::Placeholder(Placeholder {
+                        id: name_min.clone(),
+                        data_type: None,
+                    });
+                    let name_max = format!("{column_name}_{}_max", placeholders.len());
+                    let placeholder_max = Expr::Placeholder(Placeholder {
+                        id: name_max.clone(),
+                        data_type: None,
+                    });
+                    let (source_expr, target_expr) = if source_left {
+                        (*binary.left, *binary.right)
+                    } else {
+                        (*binary.right, *binary.left)
+                    };
+                    let replaced = Expr::Between(Between {
+                        expr: target_expr.into(),
+                        negated: false,
+                        low: placeholder_min.into(),
+                        high: placeholder_max.into(),
+                    });
+
+                    placeholders.push(PredicatePlaceholder {
+                        expr: min(source_expr.clone()),
+                        alias: name_min,
+                        is_aggregate: true,
+                    });
+                    placeholders.push(PredicatePlaceholder {
+                        expr: max(source_expr),
+                        alias: name_max,
+                        is_aggregate: true,
+                    });
+                    Some(replaced)
+                }
+                _ => None,
+            }
+        }
+    }
+
     match predicate {
         Expr::BinaryExpr(binary) => {
             if references_table(&binary.right, source_name).has_reference() {
                 if let ReferenceTableCheck::HasReference(left_target) =
                     references_table(&binary.left, target_name)
                 {
-                    if partition_columns.contains(&left_target) {
-                        let placeholder_name = format!("{left_target}_{}", placeholders.len());
-
-                        let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
-                            id: placeholder_name.clone(),
-                            data_type: None,
-                        });
-                        let replaced = Expr::BinaryExpr(BinaryExpr {
-                            left: binary.left,
-                            op: binary.op,
-                            right: placeholder.into(),
-                        });
-
-                        placeholders.insert(placeholder_name, *binary.right);
-
-                        return Some(replaced);
-                    }
+                    return construct_placeholder(
+                        binary,
+                        false,
+                        partition_columns.contains(&left_target),
+                        left_target,
+                        placeholders,
+                    );
                 }
                 return None;
             }
             if references_table(&binary.left, source_name).has_reference() {
                 if let ReferenceTableCheck::HasReference(right_target) =
                     references_table(&binary.right, target_name)
                 {
-                    if partition_columns.contains(&right_target) {
-                        let placeholder_name = format!("{right_target}_{}", placeholders.len());
-
-                        let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
-                            id: placeholder_name.clone(),
-                            data_type: None,
-                        });
-
-                        let replaced = Expr::BinaryExpr(BinaryExpr {
-                            right: binary.right,
-                            op: binary.op,
-                            left: placeholder.into(),
-                        });
-
-                        placeholders.insert(placeholder_name, *binary.left);
-
-                        return Some(replaced);
-                    }
+                    return construct_placeholder(
+                        binary,
+                        true,
+                        partition_columns.contains(&right_target),
+                        right_target,
+                        placeholders,
+                    );
                 }
                 return None;
             }
@@ -830,12 +893,16 @@ fn generalize_filter(
             ReferenceTableCheck::HasReference(col) => {
                 let placeholder_name = format!("{col}_{}", placeholders.len());
 
-                let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
+                let placeholder = Expr::Placeholder(Placeholder {
                     id: placeholder_name.clone(),
                     data_type: None,
                 });
 
-                placeholders.insert(placeholder_name, other);
+                placeholders.push(PredicatePlaceholder {
+                    expr: other,
+                    alias: placeholder_name,
+                    is_aggregate: true,
+                });
 
                 Some(placeholder)
             }
@@ -869,7 +936,7 @@ async fn try_construct_early_filter(
     let table_metadata = table_snapshot.metadata();
     let partition_columns = &table_metadata.partition_columns;
 
-    let mut placeholders = HashMap::default();
+    let mut placeholders = Vec::default();
 
     match generalize_filter(
         join_predicate,
@@ -881,21 +948,24 @@ async fn try_construct_early_filter(
         None => Ok(None),
         Some(filter) => {
             if placeholders.is_empty() {
-                // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter
+                // if we haven't recognised any source predicates in the join predicate, return our filter with static-only predicates
                 Ok(Some(filter))
             } else {
-                // if we have some recognised partitions, then discover the distinct set of partitions in the source data and
-                // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together)
-                let distinct_partitions = LogicalPlan::Distinct(Distinct::All(
-                    LogicalPlan::Projection(Projection::try_new(
-                        placeholders
-                            .into_iter()
-                            .map(|(alias, expr)| expr.alias(alias))
-                            .collect_vec(),
-                        source.clone().into(),
-                    )?)
-                    .into(),
-                ));
+                // if we have filters that depend on the source DataFrame, collect the placeholder values from the source data:
+                // distinct partition values become the group columns, and the min/max stats bounding the dynamic filters become the aggregate columns,
+                // i.e. `SELECT partition1 AS part1_0, min(id) AS id_1_min, max(id) AS id_1_max FROM source GROUP BY partition1`
+                let (agg_columns, group_columns) = placeholders.into_iter().partition_map(|p| {
+                    if p.is_aggregate {
+                        Left(p.expr.alias(p.alias))
+                    } else {
+                        Right(p.expr.alias(p.alias))
+                    }
+                });
+                let distinct_partitions = LogicalPlan::Aggregate(Aggregate::try_new(
+                    source.clone().into(),
+                    group_columns,
+                    agg_columns,
+                )?);
                 let execution_plan = session_state
                     .create_physical_plan(&distinct_partitions)
                     .await?;
@@ -1584,7 +1654,6 @@ mod tests {
     use itertools::Itertools;
     use regex::Regex;
     use serde_json::json;
-    use std::collections::HashMap;
     use std::ops::Neg;
     use std::sync::Arc;
 
@@ -2064,7 +2133,10 @@ mod tests {
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[0];
        let parameters = last_commit.operation_parameters.clone().unwrap();
-        assert_eq!(parameters["predicate"], "modified = '2021-02-02'");
+        assert_eq!(
+            parameters["predicate"],
+            "id BETWEEN 'B' AND 'C' AND modified = '2021-02-02'"
+        );
         assert_eq!(
             parameters["mergePredicate"],
             "target.id = source.id AND target.modified = '2021-02-02'"
@@ -2205,7 +2277,7 @@ mod tests {
             extra_info["operationMetrics"],
             serde_json::to_value(&metrics).unwrap()
         );
-        assert!(!parameters.contains_key("predicate"));
+        assert_eq!(parameters["predicate"], "id BETWEEN 'B' AND 'X'");
         assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
         assert_eq!(
             parameters["matchedPredicates"],
@@ -2487,7 +2559,10 @@ mod tests {
         let last_commit = &commit_info[0];
         let parameters = last_commit.operation_parameters.clone().unwrap();
 
-        assert_eq!(parameters["predicate"], json!("modified = '2021-02-02'"));
+        assert_eq!(
+            parameters["predicate"],
+            json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'")
+        );
 
         let expected = vec![
             "+----+-------+------------+",
@@ -2591,7 +2666,7 @@ mod tests {
         let parsed_filter = col(Column::new(source.clone().into(), "id"))
             .eq(col(Column::new(target.clone().into(), "id")));
 
-        let mut placeholders = HashMap::default();
+        let mut placeholders = Vec::default();
 
         let generalized = generalize_filter(
             parsed_filter,
@@ -2623,7 +2698,7 @@ mod tests {
         let parsed_filter = (source_id.clone().eq(target_id.clone()))
             .or(source_id.clone().is_null().and(target_id.clone().is_null()));
 
-        let mut placeholders = HashMap::default();
+        let mut placeholders = Vec::default();
 
         let generalized = generalize_filter(
             parsed_filter,
@@ -2646,9 +2721,9 @@ mod tests {
             })
             .and(target_id.clone().is_null()));
 
-        assert!(placeholders.len() == 2);
+        assert_eq!(placeholders.len(), 2);
 
-        let captured_expressions = placeholders.values().collect_vec();
+        let captured_expressions = placeholders.into_iter().map(|p| p.expr).collect_vec();
 
-        assert!(captured_expressions.contains(&&source_id));
-        assert!(captured_expressions.contains(&&source_id.is_null()));
+        assert!(captured_expressions.contains(&source_id));
+        assert!(captured_expressions.contains(&source_id.is_null()));
@@ -2667,7 +2742,7 @@ mod tests {
             .neg()
             .eq(col(Column::new(target.clone().into(), "id")));
 
-        let mut placeholders = HashMap::default();
+        let mut placeholders = Vec::default();
 
         let generalized = generalize_filter(
             parsed_filter,
@@ -2687,12 +2762,13 @@ mod tests {
         assert_eq!(generalized, expected_filter);
 
         assert_eq!(placeholders.len(), 1);
-
-        let placeholder_expr = &placeholders["id_0"];
+
let placeholder_expr = placeholders.get(0).unwrap(); let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg(); - assert_eq!(placeholder_expr, &expected_placeholder); + assert_eq!(placeholder_expr.expr, expected_placeholder); + assert_eq!(placeholder_expr.alias, "id_0"); + assert_eq!(placeholder_expr.is_aggregate, false); } #[tokio::test] @@ -2705,7 +2781,7 @@ mod tests { .eq(col(Column::new(target.clone().into(), "id"))) .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); - let mut placeholders = HashMap::default(); + let mut placeholders = Vec::default(); let generalized = generalize_filter( parsed_filter, @@ -2728,15 +2804,14 @@ mod tests { } #[tokio::test] - async fn test_generalize_filter_keeps_only_static_target_references() { + async fn test_generalize_filter_with_dynamic_target_range_references() { let source = TableReference::parse_str("source"); let target = TableReference::parse_str("target"); let parsed_filter = col(Column::new(source.clone().into(), "id")) - .eq(col(Column::new(target.clone().into(), "id"))) - .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + .eq(col(Column::new(target.clone().into(), "id"))); - let mut placeholders = HashMap::default(); + let mut placeholders = Vec::default(); let generalized = generalize_filter( parsed_filter, @@ -2746,8 +2821,16 @@ mod tests { &mut placeholders, ) .unwrap(); - - let expected_filter = col(Column::new(target.clone().into(), "id")).eq(lit("C")); + let expected_filter_l = Expr::Placeholder(Placeholder { + id: "id_0_min".to_owned(), + data_type: None, + }); + let expected_filter_h = Expr::Placeholder(Placeholder { + id: "id_0_max".to_owned(), + data_type: None, + }); + let expected_filter = col(Column::new(target.clone().into(), "id")) + .between(expected_filter_l, expected_filter_h); assert_eq!(generalized, expected_filter); } @@ -2761,7 +2844,7 @@ mod tests { .eq(col(Column::new(target.clone().into(), "id"))) .and(col(Column::new(source.clone().into(), "id")).eq(lit("C"))); - let mut placeholders = HashMap::default(); + let mut placeholders = Vec::default(); let generalized = generalize_filter( parsed_filter, @@ -2879,6 +2962,158 @@ mod tests { assert_eq!(split_pred, expected_pred_parts); } + #[tokio::test] + async fn test_try_construct_early_filter_with_range() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + 
let filter = col(Column {
+            relation: Some(target_name.clone()),
+            name: "id".to_owned(),
+        })
+        .between(
+            Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))),
+            Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))),
+        );
+        assert_eq!(pred.unwrap(), filter);
+    }
+
+    #[tokio::test]
+    async fn test_try_construct_early_filter_with_partition_and_range() {
+        let schema = get_arrow_schema(&None);
+        let table = setup_table(Some(vec!["modified"])).await;
+
+        assert_eq!(table.version(), 0);
+        assert_eq!(table.get_files_count(), 0);
+
+        let ctx = SessionContext::new();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(arrow::array::StringArray::from(vec!["B", "C"])),
+                Arc::new(arrow::array::Int32Array::from(vec![10, 20])),
+                Arc::new(arrow::array::StringArray::from(vec![
+                    "2023-07-04",
+                    "2023-07-04",
+                ])),
+            ],
+        )
+        .unwrap();
+        let source = ctx.read_batch(batch).unwrap();
+
+        let source_name = TableReference::parse_str("source");
+        let target_name = TableReference::parse_str("target");
+
+        let source = LogicalPlanBuilder::scan(
+            source_name.clone(),
+            provider_as_source(source.into_view()),
+            None,
+        )
+        .unwrap()
+        .build()
+        .unwrap();
+
+        let join_predicate = col(Column {
+            relation: Some(source_name.clone()),
+            name: "id".to_owned(),
+        })
+        .eq(col(Column {
+            relation: Some(target_name.clone()),
+            name: "id".to_owned(),
+        }))
+        .and(
+            col(Column {
+                relation: Some(source_name.clone()),
+                name: "modified".to_owned(),
+            })
+            .eq(col(Column {
+                relation: Some(target_name.clone()),
+                name: "modified".to_owned(),
+            })),
+        );
+
+        let pred = try_construct_early_filter(
+            join_predicate,
+            table.snapshot().unwrap(),
+            &ctx.state(),
+            &source,
+            &source_name,
+            &target_name,
+        )
+        .await
+        .unwrap();
+
+        assert!(pred.is_some());
+
+        let filter = col(Column {
+            relation: Some(target_name.clone()),
+            name: "id".to_owned(),
+        })
+        .between(
+            Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))),
+            Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))),
+        )
+        .and(
+            Expr::Literal(ScalarValue::Utf8(Some("2023-07-04".to_string()))).eq(col(Column {
+                relation: Some(target_name.clone()),
+                name: "modified".to_owned(),
+            })),
+        );
+        assert_eq!(pred.unwrap(), filter);
+    }
+
     #[tokio::test]
     async fn test_merge_pushdowns() {
         //See https://github.com/delta-io/delta-rs/issues/2158
diff --git a/crates/core/tests/command_merge.rs b/crates/core/tests/command_merge.rs
index 10855aa0a8..76b511254b 100644
--- a/crates/core/tests/command_merge.rs
+++ b/crates/core/tests/command_merge.rs
@@ -138,17 +138,17 @@ async fn merge(
 
 #[tokio::test]
 async fn test_merge_concurrent_conflict() {
-    // No partition key or filter predicate -> Commit conflict
+    // Overlapping id ranges -> Commit conflict
     let tmp_dir = tempfile::tempdir().unwrap();
     let table_uri = tmp_dir.path().to_str().to_owned().unwrap();
 
-    let table_ref1 = create_table(table_uri, Some(vec!["event_date"])).await;
+    let table_ref1 = create_table(&table_uri.to_string(), Some(vec!["event_date"])).await;
     let table_ref2 = open_table(table_uri).await.unwrap();
-    let (df1, df2) = create_test_data();
+    let (df1, _df2) = create_test_data();
 
     let expr = col("target.id").eq(col("source.id"));
-    let (_table_ref1, _metrics) = merge(table_ref1, df1, expr.clone()).await.unwrap();
-    let result = merge(table_ref2, df2, expr).await;
+    let (_table_ref1, _metrics) = merge(table_ref1, df1.clone(), expr.clone()).await.unwrap();
+    let result = merge(table_ref2, df1, expr).await;
 
     assert!(matches!(
         result.as_ref().unwrap_err(),
@@ -159,6 +159,23 @@ async fn test_merge_concurrent_conflict() {
     }
 }
 
+#[tokio::test]
+async fn test_merge_different_range() {
+    // No overlapping id ranges -> No conflict
+    let tmp_dir = tempfile::tempdir().unwrap();
+    let table_uri = tmp_dir.path().to_str().to_owned().unwrap();
+
+    let table_ref1 = create_table(table_uri, Some(vec!["event_date"])).await;
+    let table_ref2 = open_table(table_uri).await.unwrap();
+    let (df1, df2) = create_test_data();
+
+    let expr = col("target.id").eq(col("source.id"));
+    let (_table_ref1, _metrics) = merge(table_ref1, df1, expr.clone()).await.unwrap();
+    let result = merge(table_ref2, df2, expr).await;
+
+    assert!(result.is_ok());
+}
+
 #[tokio::test]
 async fn test_merge_concurrent_different_partition() {
     // partition key in predicate -> Successful merge
@@ -175,9 +192,7 @@ async fn test_merge_concurrent_different_partition() {
     let (_table_ref1, _metrics) = merge(table_ref1, df1, expr.clone()).await.unwrap();
     let result = merge(table_ref2, df2, expr).await;
 
-    // TODO: Currently it throws a Version mismatch error, but the merge commit was successfully
-    // This bug needs to be fixed, see pull request #2280
-    assert!(result.as_ref().is_ok());
+    assert!(result.is_ok());
 }
 
 #[tokio::test]
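
Reviewer note: the aggregation that fills the new min/max placeholders can be
pictured as the following standalone DataFusion sketch, equivalent to the SQL
spelled out in the try_construct_early_filter comment. It assumes `modified` is
the partition column and `id` the join key; the helper name and column names
are illustrative, not code from this patch.

    use datafusion_expr::{col, max, min, LogicalPlan, LogicalPlanBuilder};

    /// Sketch: builds `SELECT modified AS modified_0, min(id) AS id_1_min,
    /// max(id) AS id_1_max FROM source GROUP BY modified` as a logical plan.
    fn early_filter_stats_plan(source: LogicalPlan) -> datafusion_common::Result<LogicalPlan> {
        LogicalPlanBuilder::from(source)
            .aggregate(
                // group columns: one row per distinct partition value
                vec![col("modified").alias("modified_0")],
                // aggregate columns: the bounds for the BETWEEN prefilter
                vec![
                    min(col("id")).alias("id_1_min"),
                    max(col("id")).alias("id_1_max"),
                ],
            )?
            .build()
    }

Each row of that result is substituted into the generalized predicate's
placeholders, and the per-row predicates are then OR'ed together before being
pushed into the delta scan, as described in the doc comment above.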
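For context, a merge that takes this fast path end to end might look as follows
(a sketch using the DeltaOps merge builder as documented in the crate; the
table layout, column names, and clauses are hypothetical):

    use deltalake::datafusion::prelude::{col, DataFrame};
    use deltalake::{DeltaOps, DeltaResult, DeltaTable};

    /// Upserts `source` into `table` on `id`. With this patch, the target scan
    /// is prefiltered to `target.id BETWEEN min(source.id) AND max(source.id)`
    /// even though `id` is not a partition column.
    async fn upsert(table: DeltaTable, source: DataFrame) -> DeltaResult<DeltaTable> {
        let (table, _metrics) = DeltaOps(table)
            .merge(source, "target.id = source.id")
            .with_source_alias("source")
            .with_target_alias("target")
            .when_matched_update(|update| update.update("value", col("source.value")))?
            .when_not_matched_insert(|insert| {
                insert
                    .set("id", col("source.id"))
                    .set("value", col("source.value"))
            })?
            .await?;
        Ok(table)
    }

The recorded commit predicate is correspondingly narrower, which is what
test_merge_different_range exercises: two concurrent merges over disjoint id
ranges no longer conflict.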