[SPARK-31078][SQL] Respect aliases in output ordering

### What changes were proposed in this pull request?

Currently, in the following scenario, an unnecessary `Sort` node is introduced:
```scala
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
  val df = (0 until 20).toDF("i").as("df")
  df.repartition(8, df("i")).write.format("parquet")
    .bucketBy(8, "i").sortBy("i").saveAsTable("t")
  val t1 = spark.table("t")
  val t2 = t1.selectExpr("i as ii")
  t1.join(t2, t1("i") === t2("ii")).explain
}
```
```
== Physical Plan ==
*(3) SortMergeJoin [i#8], [ii#10], Inner
:- *(1) Project [i#8]
:  +- *(1) Filter isnotnull(i#8)
:     +- *(1) ColumnarToRow
:        +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
+- *(2) Sort [ii#10 ASC NULLS FIRST], false, 0    <==== UNNECESSARY
   +- *(2) Project [i#8 AS ii#10]
      +- *(2) Filter isnotnull(i#8)
         +- *(2) ColumnarToRow
            +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
```
Notice that `Sort [ii#10 ASC NULLS FIRST], false, 0` is introduced even though the underlying data is already sorted by `i`, and `ii#10` is merely an alias of `i#8`. This happens because `outputOrdering` does not take aliases into account: the `Project` renames `i` to `ii` but still reports its ordering in terms of `i`, so the required ordering on `ii` appears unsatisfied. This PR fixes the issue by rewriting a node's reported ordering through the aliases in its output expressions.
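To make the mechanics concrete, here is a small self-contained sketch of the rewrite, using toy stand-ins for Catalyst's `Alias`/`AttributeReference`/`SortOrder` (the names `Attr`, `Aliased`, `SortKey` are illustrative only and do not exist in Spark):

```scala
// Toy model of the alias-aware ordering rewrite (illustrative only: Attr,
// Aliased and SortKey are NOT Spark's Catalyst classes).
case class Attr(name: String, exprId: Int)
case class Aliased(child: Attr, as: Attr) // stand-in for `child AS as`
case class SortKey(attr: Attr)            // stand-in for SortOrder

object AliasRewriteDemo extends App {
  // The Project's output expressions: [i#8 AS ii#10]
  val outputExpressions = Seq(Aliased(Attr("i", 8), Attr("ii", 10)))

  // Mirrors `replaceAlias` in this patch: map an input attribute to the
  // attribute produced by the alias wrapping it, if any.
  def replaceAlias(a: Attr): Option[Attr] =
    outputExpressions.collectFirst { case Aliased(child, as) if child == a => as }

  // The bucketed scan reports its ordering on i#8; rewriting it through the
  // alias yields an ordering on ii#10, which is exactly what the join needs.
  val childOrdering = Seq(SortKey(Attr("i", 8)))
  val outputOrdering = childOrdering.map(s => SortKey(replaceAlias(s.attr).getOrElse(s.attr)))
  assert(outputOrdering == Seq(SortKey(Attr("ii", 10))))
}
```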

### Why are the changes needed?

To handle aliases in `outputOrdering` so that unnecessary `Sort` nodes are not introduced.

### Does this PR introduce any user-facing change?

Yes. With the fix, `explain` prints the following plan, in which the unnecessary `Sort` node no longer appears:
```
== Physical Plan ==
*(3) SortMergeJoin [i#8], [ii#10], Inner
:- *(1) Project [i#8]
:  +- *(1) Filter isnotnull(i#8)
:     +- *(1) ColumnarToRow
:        +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
+- *(2) Project [i#8 AS ii#10]
   +- *(2) Filter isnotnull(i#8)
      +- *(2) ColumnarToRow
         +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
```
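For a programmatic check of the same behavior (this mirrors the test added to `BucketedReadSuite` below; it assumes an active `spark` session and the bucketed table `t` created as in the snippet above):

```scala
import org.apache.spark.sql.execution.SortExec

val t1 = spark.table("t")
val t2 = t1.selectExpr("i as ii")
val plan = t1.join(t2, t1("i") === t2("ii")).queryExecution.executedPlan
// With the fix, no SortExec node should appear anywhere in the plan.
assert(plan.collect { case s: SortExec => s }.isEmpty)
```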

### How was this patch tested?

Unit tests were added in `PlannerSuite` and `BucketedReadSuite` (see the diffs below).

Closes apache#27842 from imback82/alias_aware_sort_order.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 294f605)
imback82 authored and rshkv committed May 6, 2021
1 parent 7a0f23c commit f099edb
Showing 6 changed files with 87 additions and 23 deletions.
The alias-handling helpers move from `AliasAwareOutputPartitioning` into a new shared base trait, `AliasAwareOutputExpression`, and a new `AliasAwareOutputOrdering` trait applies the same rewriting to `outputOrdering`:

```diff
@@ -16,16 +16,37 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 
 /**
- * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
- * that satisfies output distribution requirements.
+ * A trait that provides functionality to handle aliases in the `outputExpressions`.
  */
-trait AliasAwareOutputPartitioning extends UnaryExecNode {
+trait AliasAwareOutputExpression extends UnaryExecNode {
   protected def outputExpressions: Seq[NamedExpression]
 
+  protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
+      case other => other
+    }
+  }
+
+  protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
+    outputExpressions.collectFirst {
+      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
+        a.toAttribute
+    }
+  }
+}
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
+ * satisfies distribution requirements.
+ */
+trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
   final override def outputPartitioning: Partitioning = {
     if (hasAlias) {
       child.outputPartitioning match {
@@ -36,20 +57,25 @@ trait AliasAwareOutputPartitioning extends UnaryExecNode {
       child.outputPartitioning
     }
   }
+}
 
-  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
-
-  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
-    exprs.map {
-      case a: AttributeReference => replaceAlias(a).getOrElse(a)
-      case other => other
-    }
-  }
-
-  private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
-    outputExpressions.collectFirst {
-      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
-        a.toAttribute
-    }
-  }
+/**
+ * A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
+ * satisfies ordering requirements.
+ */
+trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
+  protected def orderingExpressions: Seq[SortOrder]
+
+  final override def outputOrdering: Seq[SortOrder] = {
+    if (hasAlias) {
+      orderingExpressions.map { s =>
+        s.child match {
+          case a: AttributeReference => s.copy(child = replaceAlias(a).getOrElse(a))
+          case _ => s
+        }
+      }
+    } else {
+      orderingExpressions
+    }
+  }
 }
```
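One scope note (my observation from the code above, not a claim made in the PR): `replaceAlias` only matches `Alias(child: AttributeReference, _)`, so only direct attribute renames such as `i AS ii` benefit. An alias over a computed expression is not rewritten:

```scala
// Hypothetical variant of the example from the description: the alias wraps
// an expression rather than a bare attribute, so `replaceAlias` does not
// fire and EnsureRequirements would still insert a Sort on this side.
val t2 = spark.table("t").selectExpr("i + 0 as ii")
```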
In `HashAggregateExec`, only the `extends` clause is reflowed (it already mixed in `AliasAwareOutputPartitioning`):

```diff
@@ -53,7 +53,9 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec
+  with BlockingOperatorWithCodegen
+  with AliasAwareOutputPartitioning {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
```
`SortAggregateExec` additionally mixes in `AliasAwareOutputOrdering`, turning its former `outputOrdering` override into `orderingExpressions`:

```diff
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan}
+import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -38,7 +38,9 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec
+  with AliasAwareOutputPartitioning
+  with AliasAwareOutputOrdering {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -68,7 +70,7 @@ case class SortAggregateExec(
 
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
-  override def outputOrdering: Seq[SortOrder] = {
+  override protected def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
 
```
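For context on why the two aggregate operators differ (my reading of the change, not stated in the PR): `HashAggregateExec` gains only the partitioning trait because hash aggregation emits groups in no particular order, while sort-based aggregation emits groups ordered by the grouping keys, so `SortAggregateExec` can report an ordering too. A sketch of how to get a sort aggregate in `spark-shell` (assumes a `spark` session; `collect_list` cannot run in `HashAggregateExec`):

```scala
import org.apache.spark.sql.functions.collect_list

// Disable the object-hash fallback so collect_list plans as SortAggregateExec.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")
val plan = spark.range(10).selectExpr("floor(id/4) as k")
  .groupBy("k").agg(collect_list("k"))
  .queryExecution.executedPlan
// Expect SortAggregateExec nodes rather than HashAggregateExec.
println(plan)
```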
`ProjectExec` does the same; since a projection never reorders rows, it forwards the child's ordering as its `orderingExpressions`:

```diff
@@ -39,7 +39,10 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
+  extends UnaryExecNode
+  with CodegenSupport
+  with AliasAwareOutputPartitioning
+  with AliasAwareOutputOrdering {
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
@@ -80,10 +83,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
     }
   }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
   override protected def outputExpressions: Seq[NamedExpression] = projectList
 
+  override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
+
   override def verboseStringWithOperatorId(): String = {
     s"""
        |$formattedNodeName
```
A planner test covering the sort-aggregate case:

```diff
@@ -962,6 +962,25 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("aliases in the sort aggregate expressions should not introduce extra sort") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+        val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+        val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+        val agg1 = t1.groupBy("k1").agg(collect_list("k1")).withColumnRenamed("k1", "k3")
+        val agg2 = t2.groupBy("k2").agg(collect_list("k2"))
+
+        val planned = agg1.join(agg2, $"k3" === $"k2").queryExecution.executedPlan
+        assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+
+        // We expect two SortExec nodes on each side of join.
+        val sorts = planned.collect { case s: SortExec => s }
+        assert(sorts.size == 4)
+      }
+    }
+  }
 
   // Used for unit-testing EnsureRequirements
```
And a bucketed-read test covering the scenario from the description:

```diff
@@ -639,6 +639,18 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("sort should not be introduced when aliases are used") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      withTable("t") {
+        df1.repartition(1).write.format("parquet").bucketBy(8, "i").sortBy("i").saveAsTable("t")
+        val t1 = spark.table("t")
+        val t2 = t1.selectExpr("i as ii")
+        val plan = t1.join(t2, t1("i") === t2("ii")).queryExecution.executedPlan
+        assert(plan.collect { case sort: SortExec => sort }.isEmpty)
+      }
+    }
+  }
+
   test("bucket join should work with SubqueryAlias plan") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
       withTable("t") {
```
