Skip to content

Commit

Permalink
Add optimizer rule for type coercion (binary operations only) (#3222)
Browse files Browse the repository at this point in the history
* Add binary type coercion to logical plan and do not allow CAST to change an expression name

* fix tests

* update avro tests

* add reference to GitHub issue

* unignore timestamp_add_interval_months

* fix: update tests to use correct column types

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
andygrove and alamb authored Sep 6, 2022
1 parent 9b546e7 commit 191d8b7
Show file tree
Hide file tree
Showing 21 changed files with 427 additions and 213 deletions.
4 changes: 4 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
Expand Down Expand Up @@ -1433,6 +1434,9 @@ impl SessionState {
}
rules.push(Arc::new(ReduceOuterJoin::new()));
rules.push(Arc::new(FilterPushDown::new()));
// we do type coercion after filter push down so that we don't push CAST filters to Parquet
// until https://github.com/apache/arrow-datafusion/issues/3289 is resolved
rules.push(Arc::new(TypeCoercion::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));

Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
name += "END";
Ok(name)
}
Expr::Cast { expr, data_type } => {
let expr = create_physical_name(expr, false)?;
Ok(format!("CAST({} AS {:?})", expr, data_type))
Expr::Cast { expr, .. } => {
// CAST does not change the expression name
create_physical_name(expr, false)
}
Expr::TryCast { expr, data_type } => {
let expr = create_physical_name(expr, false)?;
Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
Expr::TryCast { expr, .. } => {
// CAST does not change the expression name
create_physical_name(expr, false)
}
Expr::Not(expr) => {
let expr = create_physical_name(expr, false)?;
Expand Down
16 changes: 8 additions & 8 deletions datafusion/core/tests/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,14 @@ async fn test_fn_substr() -> Result<()> {
async fn test_cast() -> Result<()> {
let expr = cast(col("b"), DataType::Float64);
let expected = vec![
"+-------------------------+",
"| CAST(test.b AS Float64) |",
"+-------------------------+",
"| 1 |",
"| 10 |",
"| 10 |",
"| 100 |",
"+-------------------------+",
"+--------+",
"| test.b |",
"+--------+",
"| 1 |",
"| 10 |",
"| 10 |",
"| 100 |",
"+--------+",
];

assert_fn_batches!(expr, expected);
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/parquet_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ impl ContextWithParquet {
let pretty_input = pretty_format_batches(&input).unwrap().to_string();

let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan");

let physical_plan = self
.ctx
.create_physical_plan(&logical_plan)
Expand Down
29 changes: 24 additions & 5 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,11 @@ async fn csv_query_external_table_sum() {
"SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------------------------------------------+-------------------------------------------+",
"| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |",
"+-------------------------------------------+-------------------------------------------+",
"| 13060 | 3017641 |",
"+-------------------------------------------+-------------------------------------------+",
"+----------------------------+----------------------------+",
"| SUM(aggregate_test_100.c7) | SUM(aggregate_test_100.c8) |",
"+----------------------------+----------------------------+",
"| 13060 | 3017641 |",
"+----------------------------+----------------------------+",
];
assert_batches_eq!(expected, &actual);
}
Expand Down Expand Up @@ -555,6 +555,7 @@ async fn csv_query_count_one() {
}

#[tokio::test]
#[ignore] // https://github.com/apache/arrow-datafusion/issues/3353
async fn csv_query_approx_count() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
Expand All @@ -571,6 +572,24 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_approx_count_dupe_expr_aliased() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql =
"SELECT approx_distinct(c9) a, approx_distinct(c9) b FROM aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+-----+",
"| a | b |",
"+-----+-----+",
"| 100 | 100 |",
"+-----+-----+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

// This test executes the APPROX_PERCENTILE_CONT aggregation against the test
// data, asserting the estimated quantiles are ±5% their actual values.
//
Expand Down
64 changes: 32 additions & 32 deletions datafusion/core/tests/sql/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ async fn avro_query() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+-----------------------------------------+",
"| id | CAST(alltypes_plain.string_col AS Utf8) |",
"+----+-----------------------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+-----------------------------------------+",
"+----+---------------------------+",
"| id | alltypes_plain.string_col |",
"+----+---------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+---------------------------+",
];

assert_batches_eq!(expected, &actual);
Expand Down Expand Up @@ -84,26 +84,26 @@ async fn avro_query_multiple_files() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+-----------------------------------------+",
"| id | CAST(alltypes_plain.string_col AS Utf8) |",
"+----+-----------------------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+-----------------------------------------+",
"+----+---------------------------+",
"| id | alltypes_plain.string_col |",
"+----+---------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+---------------------------+",
];

assert_batches_eq!(expected, &actual);
Expand Down
106 changes: 53 additions & 53 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------------------+",
"| CAST(Float64(1.23) AS Decimal128(10, 4)) |",
"+------------------------------------------+",
"| 1.2300 |",
"+------------------------------------------+",
"+---------------+",
"| Float64(1.23) |",
"+---------------+",
"| 1.2300 |",
"+---------------+",
];
assert_batches_eq!(expected, &actual);

Expand All @@ -42,11 +42,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+---------------------------------------------------------------------+",
"| CAST(CAST(Float64(1.23) AS Decimal128(10, 3)) AS Decimal128(10, 4)) |",
"+---------------------------------------------------------------------+",
"| 1.2300 |",
"+---------------------------------------------------------------------+",
"+---------------+",
"| Float64(1.23) |",
"+---------------+",
"| 1.2300 |",
"+---------------+",
];
assert_batches_eq!(expected, &actual);

Expand All @@ -57,11 +57,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+--------------------------------------------+",
"| CAST(Float64(1.2345) AS Decimal128(24, 2)) |",
"+--------------------------------------------+",
"| 1.23 |",
"+--------------------------------------------+",
"+-----------------+",
"| Float64(1.2345) |",
"+-----------------+",
"| 1.23 |",
"+-----------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------------------+",
"| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
"+----------------------------------------------------------------+",
"| 1.000000000000 |",
"| 2.000000000000 |",
"| 2.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"+----------------------------------------------------------------+",
"+--------------------------------------+",
"| decimal_simple.c1 / Float64(0.00001) |",
"+--------------------------------------+",
"| 1.000000000000 |",
"| 2.000000000000 |",
"| 2.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -609,25 +609,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------------------+",
"| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
"+----------------------------------------------------------------+",
"| 0.0000040 |",
"| 0.0000050 |",
"| 0.0000090 |",
"| 0.0000020 |",
"| 0.0000050 |",
"| 0.0000010 |",
"| 0.0000040 |",
"| 0.0000000 |",
"| 0.0000000 |",
"| 0.0000040 |",
"| 0.0000020 |",
"| 0.0000080 |",
"| 0.0000030 |",
"| 0.0000080 |",
"| 0.0000000 |",
"+----------------------------------------------------------------+",
"+--------------------------------------+",
"| decimal_simple.c5 % Float64(0.00001) |",
"+--------------------------------------+",
"| 0.0000040 |",
"| 0.0000050 |",
"| 0.0000090 |",
"| 0.0000020 |",
"| 0.0000050 |",
"| 0.0000010 |",
"| 0.0000040 |",
"| 0.0000000 |",
"| 0.0000000 |",
"| 0.0000040 |",
"| 0.0000020 |",
"| 0.0000080 |",
"| 0.0000030 |",
"| 0.0000080 |",
"| 0.0000000 |",
"+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand All @@ -663,7 +663,7 @@ order by
\n Filter: #lineitem.l_returnflag = Utf8(\"R\")\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[#lineitem.l_returnflag = Utf8(\"R\")]\
\n TableScan: nation projection=[n_nationkey, n_name]";
assert_eq!(format!("{:?}", plan.unwrap()), expected);
assert_eq!(expected, format!("{:?}", plan.unwrap()),);

Ok(())
}
Expand Down Expand Up @@ -694,7 +694,7 @@ async fn test_physical_plan_display_indent() {
" RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 9000)",
" AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
" CoalesceBatchesExec: target_batch_size=4096",
" FilterExec: c12@1 < CAST(10 AS Float64)",
" FilterExec: c12@1 < 10",
" RepartitionExec: partitioning=RoundRobinBatch(9000)",
" CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c12]",
];
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ async fn query_not() -> Result<()> {
async fn csv_query_sum_cast() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
// c8 = i32; c9 = i64
let sql = "SELECT c8 + c9 FROM aggregate_test_100";
// c8 = i32; c6 = i64
let sql = "SELECT c8 + c6 FROM aggregate_test_100";
// check that the physical and logical schemas are equal
execute(&ctx, sql).await;
}
Expand Down
Loading

0 comments on commit 191d8b7

Please sign in to comment.