Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Logical Plans more readable by removing extra aliases #10832

Merged
merged 14 commits into from
Jun 11, 2024
73 changes: 54 additions & 19 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl CommonSubexprEliminate {
fn rewrite_exprs_list(
&self,
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
arrays_list: &[&[IdArray]],
expr_stats: &ExprStats,
common_exprs: &mut CommonExprs,
) -> Result<Vec<Vec<Expr>>> {
Expand Down Expand Up @@ -159,7 +159,7 @@ impl CommonSubexprEliminate {
fn rewrite_expr(
&self,
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
arrays_list: &[&[IdArray]],
input: &LogicalPlan,
expr_stats: &ExprStats,
config: &dyn OptimizerConfig,
Expand Down Expand Up @@ -480,7 +480,7 @@ fn to_arrays(
input_schema: DFSchemaRef,
expr_stats: &mut ExprStats,
expr_mask: ExprMask,
) -> Result<Vec<Vec<(usize, String)>>> {
) -> Result<Vec<IdArray>> {
expr.iter()
.map(|e| {
let mut id_array = vec![];
Expand Down Expand Up @@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier {
fn expr_to_identifier(
expr: &Expr,
expr_stats: &mut ExprStats,
id_array: &mut Vec<(usize, Identifier)>,
id_array: &mut IdArray,
input_schema: DFSchemaRef,
expr_mask: ExprMask,
) -> Result<()> {
Expand Down Expand Up @@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> {
common_exprs: &'a mut CommonExprs,
// preorder index, starts from 0.
down_index: usize,
// how many aliases have we seen so far
alias_counter: usize,
}

impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter -= 1
}
Ok(Transformed::no(expr))
}

fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if matches!(expr, Expr::Alias(_)) {
self.alias_counter += 1;
}

if expr.short_circuits() || expr.is_volatile()? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}
Expand All @@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

let expr_name = expr.display_name()?;
self.common_exprs.insert(expr_id.clone(), expr);
// Alias this `Column` expr to it original "expr name",
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
// TODO: do we really need to alias here?
Ok(Transformed::new(
col(expr_id).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))

// alias the expressions without an `Alias` ancestor node
let rewritten = if self.alias_counter > 0 {
col(expr_id)
} else {
self.alias_counter += 1;
col(expr_id).alias(expr_name)
};

Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
} else {
Ok(Transformed::no(expr))
}
Expand All @@ -829,6 +843,7 @@ fn replace_common_expr(
id_array,
common_exprs,
down_index: 0,
alias_counter: 0,
})
.data()
}
Expand Down Expand Up @@ -962,6 +977,26 @@ mod test {
Ok(())
}

#[test]
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
fn nested_aliases() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
(col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
col("a") + col("b"),
])?
.build()?;

let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\
\n Projection: test.a + test.b AS {test.a + test.b|{test.b}|{test.a}}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);

Ok(())
}

#[test]
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -1006,7 +1041,7 @@ mod test {
)?
.build()?;

let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, {AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), {my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, {my_agg(test.c)} AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS {AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, my_agg(test.c) AS {my_agg(test.c)}]]\
\n TableScan: test";

Expand Down Expand Up @@ -1042,7 +1077,7 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";
let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);

Expand All @@ -1057,7 +1092,7 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\
let expected = "Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS col2]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand All @@ -1078,8 +1113,8 @@ mod test {
)?
.build()?;

let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a) AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}]]\
let expected = "Projection: UInt32(1) + test.a, UInt32(1) + {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col3, UInt32(1) - {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS col4, {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}, my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)}]]\
\n Projection: UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand Down Expand Up @@ -1126,7 +1161,7 @@ mod test {
])?
.build()?;

let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) + test.a AS second\
let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\
\n Projection: Int32(1) + test.a AS {Int32(1) + test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\
\n TableScan: test";

Expand Down
18 changes: 9 additions & 9 deletions datafusion/sqllogictest/test_files/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table (
c2 INT,
)
STORED AS CSV
LOCATION 'test_files/scratch/group_by/timestamp_table'
LOCATION 'test_files/scratch/group_by/timestamp_table'
OPTIONS ('format.has_header' 'true');

# Group By using date_trunc
Expand Down Expand Up @@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table;

# Table with an int column and Dict<Int8> column:
statement ok
CREATE TABLE int8_dict AS VALUES
CREATE TABLE int8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
Expand Down Expand Up @@ -4649,7 +4649,7 @@ DROP TABLE int8_dict;

# Table with an int column and Dict<Int16> column:
statement ok
CREATE TABLE int16_dict AS VALUES
CREATE TABLE int16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
Expand Down Expand Up @@ -4687,7 +4687,7 @@ DROP TABLE int16_dict;

# Table with an int column and Dict<Int32> column:
statement ok
CREATE TABLE int32_dict AS VALUES
CREATE TABLE int32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
Expand Down Expand Up @@ -4725,7 +4725,7 @@ DROP TABLE int32_dict;

# Table with an int column and Dict<Int64> column:
statement ok
CREATE TABLE int64_dict AS VALUES
CREATE TABLE int64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(Int64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
Expand Down Expand Up @@ -4763,7 +4763,7 @@ DROP TABLE int64_dict;

# Table with an int column and Dict<UInt8> column:
statement ok
CREATE TABLE uint8_dict AS VALUES
CREATE TABLE uint8_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
Expand Down Expand Up @@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict;

# Table with an int column and Dict<UInt16> column:
statement ok
CREATE TABLE uint16_dict AS VALUES
CREATE TABLE uint16_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
Expand Down Expand Up @@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict;

# Table with an int column and Dict<UInt32> column:
statement ok
CREATE TABLE uint32_dict AS VALUES
CREATE TABLE uint32_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
Expand Down Expand Up @@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict;

# Table with an int column and Dict<UInt64> column:
statement ok
CREATE TABLE uint64_dict AS VALUES
CREATE TABLE uint64_dict AS VALUES
(1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')),
(2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/tpch/q1.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ explain select
logical_plan
01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST
02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order
03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]]
03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}} * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]]
04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)|{Decimal128(Some(1),20,0) - lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus
05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02")
06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")]
Expand Down Expand Up @@ -80,7 +80,7 @@ group by
l_linestatus
order by
l_returnflag,
l_linestatus;
l_linestatus;
----
A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790
N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765
Expand Down