[SPARK-29375][SQL] Exchange reuse across all subquery levels #26044

Closed
@@ -239,7 +239,9 @@ object QueryExecution {
   * are correct, insert whole stage code gen, and try to reduce the work done by reusing exchanges
   * and subqueries.
   */
-  private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] =
+  private[execution] def preparations(
+      sparkSession: SparkSession,
+      subQuery: Boolean = false): Seq[Rule[SparkPlan]] =
    Seq(
      // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op
      // as the original plan is hidden behind `AdaptiveSparkPlanExec`.
@@ -249,10 +251,9 @@
      EnsureRequirements(sparkSession.sessionState.conf),
      ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
        sparkSession.sessionState.columnarRules),
-      CollapseCodegenStages(sparkSession.sessionState.conf),
-      ReuseExchange(sparkSession.sessionState.conf),
+      CollapseCodegenStages(sparkSession.sessionState.conf)) ++
+      (if (subQuery) Nil else Seq(ReuseExchange(sparkSession.sessionState.conf))) :+
      ReuseSubquery(sparkSession.sessionState.conf)
-    )

  /**
   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal

@@ -283,7 +284,7 @@ object QueryExecution {
   * Prepare the [[SparkPlan]] for execution.
   */
  def prepareExecutedPlan(spark: SparkSession, plan: SparkPlan): SparkPlan = {
-    prepareForExecution(preparations(spark), plan)
+    prepareForExecution(preparations(spark, true), plan)
  }
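Outside the diff, the list-building pattern is easy to check in isolation. A minimal sketch, with plain strings standing in for the actual Rule[SparkPlan] instances (hypothetical, not from the PR): the subQuery flag drops ReuseExchange from the preparation pipeline of subquery plans, since the top-level ReuseExchange pass now descends into subqueries itself.

object PreparationsSketch extends App {
  // Sketch only: strings stand in for the real preparation rules.
  def preparations(subQuery: Boolean = false): Seq[String] =
    Seq("EnsureRequirements", "CollapseCodegenStages") ++
      (if (subQuery) Nil else Seq("ReuseExchange")) :+
      "ReuseSubquery"

  // Top-level plans keep ReuseExchange; subquery plans skip it.
  assert(preparations() ==
    Seq("EnsureRequirements", "CollapseCodegenStages", "ReuseExchange", "ReuseSubquery"))
  assert(preparations(subQuery = true) ==
    Seq("EnsureRequirements", "CollapseCodegenStages", "ReuseSubquery"))
}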

  /**

@@ -93,11 +93,13 @@ case class HashAggregateExec(
  // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
  // map and/or the sort-based aggregation once it has processed a given number of input rows.
  private val testFallbackStartsAt: Option[(Int, Int)] = {
-    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
-      case null | "" => None
-      case fallbackStartsAt =>
-        val splits = fallbackStartsAt.split(",").map(_.trim)
-        Some((splits.head.toInt, splits.last.toInt))
+    Option(sqlContext).flatMap {
Comment from the PR author:
This change was required because of the following UT error:

org.scalatest.exceptions.TestFailedException: udf/postgreSQL/udf-aggregates_part3.sql - Scala UDF Expected "struct<[col:bigint]>", but got "struct<[]>" Schema did not match for query #1 select udf((select udf(count(*))         from (values (1)) t0(inner_c))) as col from (values (2),(3)) t1(outer_c): QueryOutput(select udf((select udf(count(*))         from (values (1)) t0(inner_c))) as col from (values (2),(3)) t1(outer_c),struct<>,java.lang.NullPointerException null)

where the stack trace on the executor is:

02:43:13.445 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 8.0 (TID 10)
org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree:
HashAggregate(keys=[], functions=[partial_count(1)], output=[count#397L])
+- Project
   +- LocalTableScan <empty>, [col1#385]

	at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:435)
	at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:424)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:102)
	at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:63)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:132)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:261)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$doCanonicalize$1(QueryPlan.scala:259)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
	at scala.collection.immutable.List.foreach(List.scala:392)
	at scala.collection.TraversableLike.map(TraversableLike.scala:238)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
	at scala.collection.immutable.List.map(List.scala:298)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:259)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
	... (the preceding nine QueryPlan.doCanonicalize/canonicalized frames repeat eight more times while recursing down the plan) ...
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:772)
	at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:742)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
	at org.apache.spark.sql.execution.ScalarSubquery.canonicalized$lzycompute(subquery.scala:109)
	at org.apache.spark.sql.execution.ScalarSubquery.canonicalized(subquery.scala:108)
	at org.apache.spark.sql.execution.ScalarSubquery.canonicalized(subquery.scala:62)
	at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$canonicalized$1(Expression.scala:229)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
	at scala.collection.immutable.List.foreach(List.scala:392)
	at scala.collection.TraversableLike.map(TraversableLike.scala:238)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
	at scala.collection.immutable.List.map(List.scala:298)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized$lzycompute(Expression.scala:229)
	at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:228)
	at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:248)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$Expr.hashCode(EquivalentExpressions.scala:41)
	at scala.runtime.Statics.anyHash(Statics.java:122)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416)
	at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416)
	at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44)
	at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136)
	at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135)
	at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44)
	at scala.collection.mutable.HashMap.get(HashMap.scala:74)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExpr(EquivalentExpressions.scala:55)
	at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:99)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1(CodeGenerator.scala:1118)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1$adapted(CodeGenerator.scala:1118)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1118)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1170)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:64)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.generate(GenerateMutableProjection.scala:49)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:84)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:80)
	at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:47)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:95)
	at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:103)
	at org.apache.spark.sql.execution.SparkPlan.newMutableProjection(SparkPlan.scala:471)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:116)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:837)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:837)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:425)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:428)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.reflect.InvocationTargetException
	at sun.reflect.GeneratedConstructorAccessor41.newInstance(Unknown Source)
	at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
	at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$7(TreeNode.scala:468)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$1(TreeNode.scala:467)
	at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
	... 151 more
Caused by: java.lang.NullPointerException
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.<init>(HashAggregateExec.scala:96)
	... 158 more

Because this PR adds a lazy val canonicalized to ScalarSubquery, EvalPythonExec triggered canonicalization of HashAggregateExec on an executor, where SparkSession is not available.
Honestly, I'm not sure how many other SparkPlan nodes exist that can't be canonicalized on an executor for similar reasons.
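To make the failure mode concrete outside Spark, here is a minimal standalone sketch (all names hypothetical, none of this is a Spark API): an eagerly evaluated field that reads driver-only state throws when the node is re-instantiated with a null context, while the Option-based guard degrades to None.

object NullContextSketch extends App {
  // Hypothetical stand-in for sqlContext; it is null when the node is
  // re-instantiated (e.g. via makeCopy) on an executor.
  class DriverContext {
    def getConf(key: String, default: String): String = default
  }

  case class SomeExec(ctx: DriverContext) {
    // Eager variant (before this change) would throw an NPE here when ctx == null:
    //   private val fallback = ctx.getConf("some.key", null)
    // Guarded variant, mirroring the Option(sqlContext).flatMap change above:
    private val fallback: Option[String] =
      Option(ctx).flatMap(c => Option(c.getConf("some.key", null)))
  }

  SomeExec(null) // construction succeeds even without a context
  println("constructed without a NullPointerException")
}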

Comment from the PR author:
InSubqueryExec already has a lazy val canonicalized defined, so this issue could perhaps come up even without this PR in some DPP use cases.

+      _.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+        case null | "" => None
+        case fallbackStartsAt =>
+          val splits = fallbackStartsAt.split(",").map(_.trim)
+          Some((splits.head.toInt, splits.last.toInt))
+      }
    }
  }

@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
@@ -107,35 +106,39 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
-    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
-    val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
-
-    // Replace a Exchange duplicate with a ReusedExchange
-    def reuse: PartialFunction[Exchange, SparkPlan] = {
-      case exchange: Exchange =>
-        val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
-        val samePlan = sameSchema.find { e =>
-          exchange.sameResult(e)
-        }
-        if (samePlan.isDefined) {
-          // Keep the output of this exchange, the following plans require that to resolve
-          // attributes.
-          ReusedExchangeExec(exchange.output, samePlan.get)
-        } else {
-          sameSchema += exchange
-          exchange
-        }
-    }
-
-    plan transformUp {
-      case exchange: Exchange => reuse(exchange)
-    } transformAllExpressions {
-      // Lookup inside subqueries for duplicate exchanges
-      case in: InSubqueryExec =>
-        val newIn = in.plan.transformUp {
-          case exchange: Exchange => reuse(exchange)
-        }
-        in.copy(plan = newIn.asInstanceOf[BaseSubqueryExec])
-    }
+    // To avoid costly canonicalization of an exchange:
+    // - we use its schema first to check if it can be replaced with a reused exchange at all
+    // - we insert an exchange into the map of canonicalized plans only when at least 2 exchanges
+    //   have the same schema
+    val exchanges = mutable.Map[StructType, (Exchange, mutable.Map[SparkPlan, Exchange])]()
+
+    def reuse(plan: SparkPlan): SparkPlan = {
+      plan.transformUp {
+        case exchange: Exchange =>
+          val (firstSameSchemaExchange, sameResultExchanges) =
+            exchanges.getOrElseUpdate(exchange.schema, (exchange, mutable.Map()))
+          if (firstSameSchemaExchange.ne(exchange)) {
+            if (sameResultExchanges.isEmpty) {
+              sameResultExchanges +=
+                firstSameSchemaExchange.canonicalized -> firstSameSchemaExchange
+            }
+            val sameResultExchange =
+              sameResultExchanges.getOrElseUpdate(exchange.canonicalized, exchange)
+            if (sameResultExchange.ne(exchange)) {
+              ReusedExchangeExec(exchange.output, sameResultExchange)
+            } else {
+              exchange
+            }
+          } else {
+            exchange
+          }
+        case other => other.transformExpressions {
+          case sub: ExecSubqueryExpression =>
+            sub.withNewPlan(reuse(sub.plan).asInstanceOf[BaseSubqueryExec])
+        }
+      }
+    }
+
+    reuse(plan)
  }
}
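The two-level lookup in the new rule can be illustrated on plain values. A sketch under hypothetical names (not the rule itself): schema plays the role of the cheap hash key, and canonical the expensive exact key, which is computed only once at least two items share a schema. Applying this with exchanges, _.schema, and _.canonicalized would recover the bookkeeping above, minus the tree traversal and the ReusedExchangeExec wrapping.

import scala.collection.mutable

object TwoLevelDedupSketch extends App {
  // Returns each item, or the earlier item it can "reuse" (same canonical form).
  def dedup[A <: AnyRef, S, C](items: Seq[A])(schema: A => S)(canonical: A => C): Seq[A] = {
    val seen = mutable.Map[S, (A, mutable.Map[C, A])]()
    items.map { item =>
      val (first, exact) = seen.getOrElseUpdate(schema(item), (item, mutable.Map()))
      if (first ne item) {
        // A second item shares this schema: only now pay for canonicalization.
        if (exact.isEmpty) exact += canonical(first) -> first
        exact.getOrElseUpdate(canonical(item), item)
      } else {
        item // the first item with a given schema is never replaced
      }
    }
  }

  // Example: length is the cheap key, string content the exact key.
  val out = dedup(Seq(new String("aa"), new String("ab"), new String("aa")))(_.length)(s => s)
  assert(out(2) eq out(0)) // the third item reuses the first instance
  println("dedup ok")
}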
@@ -61,7 +61,7 @@ case class ShuffleExchangeExec(

  override def nodeName: String = "Exchange"

-  private val serializer: Serializer =
+  private lazy val serializer: Serializer =
Comment from the PR author:
The reasons for this change are similar to the above. This time the stack trace is:

04:31:28.110 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 8.0 (TID 10)
org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree:
Exchange SinglePartition, true, [id=#180]
+- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#397L])
   +- *(1) Project
      +- *(1) LocalTableScan <empty>, [col1#385]

        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
        at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:435)
        at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:424)
        at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:102)
        at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:63)
        at org.apache.spark.sql.catalyst.trees.TreeNode.withNewChildren(TreeNode.scala:263)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:277)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$doCanonicalize$1(QueryPlan.scala:259)
        at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
        at scala.collection.immutable.List.foreach(List.scala:392)
        at scala.collection.TraversableLike.map(TraversableLike.scala:238)
        at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
        at scala.collection.immutable.List.map(List.scala:298)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:259)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
        ... (the preceding nine QueryPlan.doCanonicalize/canonicalized frames repeat six more times while recursing down the plan) ...
        at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:772)
        at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:742)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:245)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:244)
        at org.apache.spark.sql.execution.ScalarSubquery.canonicalized$lzycompute(subquery.scala:109)
        at org.apache.spark.sql.execution.ScalarSubquery.canonicalized(subquery.scala:108)
        at org.apache.spark.sql.execution.ScalarSubquery.canonicalized(subquery.scala:62)
        at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$canonicalized$1(Expression.scala:229)
        at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
        at scala.collection.immutable.List.foreach(List.scala:392)
        at scala.collection.TraversableLike.map(TraversableLike.scala:238)
        at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
        at scala.collection.immutable.List.map(List.scala:298)
        at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized$lzycompute(Expression.scala:229)
        at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:228)
        at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:248)
        at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$Expr.hashCode(EquivalentExpressions.scala:41)
        at scala.runtime.Statics.anyHash(Statics.java:122)
        at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416)
        at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416)
        at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44)
        at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136)
        at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135)
        at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44)
        at scala.collection.mutable.HashMap.get(HashMap.scala:74)
        at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExpr(EquivalentExpressions.scala:55)
        at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:99)
        at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1(CodeGenerator.scala:1118)
        at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$subexpressionElimination$1$adapted(CodeGenerator.scala:1118)
        at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
        at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
        at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
        at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1118)
        at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1170)
        at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:64)
        at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.generate(GenerateMutableProjection.scala:49)
        at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:84)
        at org.apache.spark.sql.catalyst.expressions.MutableProjection$.createCodeGeneratedObject(Projection.scala:80)
        at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:47)
        at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:95)
        at org.apache.spark.sql.catalyst.expressions.MutableProjection$.create(Projection.scala:103)
        at org.apache.spark.sql.execution.SparkPlan.newMutableProjection(SparkPlan.scala:471)
        at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:116)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:837)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:837)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:127)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:425)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:428)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.reflect.InvocationTargetException
        at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
        at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
        at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
        at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$7(TreeNode.scala:468)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$1(TreeNode.scala:467)
        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
        ... 133 more
Caused by: java.lang.NullPointerException
        at org.apache.spark.sql.execution.SparkPlan.sparkContext(SparkPlan.scala:72)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.metrics$lzycompute(ShuffleExchangeExec.scala:57)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.metrics(ShuffleExchangeExec.scala:58)
        at org.apache.spark.sql.execution.SparkPlan.longMetric(SparkPlan.scala:149)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.<init>(ShuffleExchangeExec.scala:63)
        ... 141 more

And if serializer is not lazy, then it makes no sense for metrics to be lazy either.

    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

  @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
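The timing difference that makes lazy sufficient here can be shown in a few lines. A minimal sketch with hypothetical names (not the Spark classes): an eager val runs its initializer during construction, so an executor-side makeCopy would hit the metric lookup immediately, while a lazy val defers the lookup until the field is first read, which only happens on the driver.

object LazyValSketch extends App {
  class Node(metricSource: AnyRef /* null on executors */) {
    // Eager variant: would throw a NullPointerException during construction.
    //   val serializer: Int = metricSource.hashCode
    // Lazy variant: the initializer runs only on first access.
    lazy val serializer: Int = metricSource.hashCode
  }

  new Node(null) // constructing (e.g. via makeCopy) no longer touches metrics
  println("constructed without a NullPointerException")
}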
@@ -104,6 +104,10 @@ case class ScalarSubquery(
    require(updated, s"$this has not finished")
    Literal.create(result, dataType).doGenCode(ctx, ev)
  }
+
+  override lazy val canonicalized: ScalarSubquery = {
+    copy(plan = plan.canonicalized.asInstanceOf[BaseSubqueryExec], exprId = ExprId(0))
+  }
}
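Why the override also resets exprId can be seen with simplified shapes. A self-contained sketch (hypothetical, much simpler than the real classes): two occurrences of a semantically identical subquery carry distinct exprIds, so without normalization their canonical forms would never compare equal and reuse could not trigger.

object CanonicalizeSketch extends App {
  case class ExprId(id: Long)
  case class Scalar(planText: String, exprId: ExprId) {
    // Mirrors the override above: normalize the id away before comparing.
    lazy val canonicalized: Scalar = copy(exprId = ExprId(0))
  }

  val a = Scalar("select max(key) from t", ExprId(7))
  val b = Scalar("select max(key) from t", ExprId(8))
  assert(a != b)                             // raw comparison sees different ids
  assert(a.canonicalized == b.canonicalized) // normalized forms match
  println("canonical forms equal")
}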

  /**

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala (57 additions, 0 deletions)

@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

@@ -1389,6 +1390,62 @@ class SubquerySuite extends QueryTest with SharedSparkSession {
    }
  }

test("Exchange reuse across all subquery levels") {
Seq(true, false).foreach { reuse =>
withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> reuse.toString) {
val df = sql(
"""
|SELECT
| (SELECT max(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key),
| a.key
|FROM testData AS a
|JOIN testData AS b ON b.key = a.key
""".stripMargin)

val plan = df.queryExecution.executedPlan

val exchangeIds = plan.collectInPlanAndSubqueries { case e: Exchange => e.id }
val reusedExchangeIds = plan.collectInPlanAndSubqueries {
case re: ReusedExchangeExec => re.child.id
}

if (reuse) {
assert(exchangeIds.size == 2, "Exchange reusing not working correctly")
assert(reusedExchangeIds.size == 3, "Exchange reusing not working correctly")
assert(reusedExchangeIds.forall(exchangeIds.contains(_)),
"ReusedExchangeExec should reuse an existing exchange")
} else {
assert(exchangeIds.size == 5, "expect 5 Exchange when not reusing")
assert(reusedExchangeIds.size == 0, "expect 0 ReusedExchangeExec when not reusing")
}

val df2 = sql(
"""
SELECT
(SELECT min(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key),
(SELECT max(a.key) FROM testData AS a JOIN testData2 AS b ON b.a = a.key)
""".stripMargin)

val plan2 = df2.queryExecution.executedPlan

val exchangeIds2 = plan2.collectInPlanAndSubqueries { case e: Exchange => e.id }
val reusedExchangeIds2 = plan2.collectInPlanAndSubqueries {
case re: ReusedExchangeExec => re.child.id
}

if (reuse) {
assert(exchangeIds2.size == 4, "Exchange reusing not working correctly")
assert(reusedExchangeIds2.size == 2, "Exchange reusing not working correctly")
assert(reusedExchangeIds2.forall(exchangeIds2.contains(_)),
"ReusedExchangeExec should reuse an existing exchange")
} else {
assert(exchangeIds2.size == 6, "expect 6 Exchange when not reusing")
assert(reusedExchangeIds2.size == 0, "expect 0 ReusedExchangeExec when not reusing")
}
}
}
}

test("Scalar subquery name should start with scalar-subquery#") {
val df = sql("SELECT a FROM l WHERE a = (SELECT max(c) FROM r WHERE c = 1)".stripMargin)
var subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ class PlannerSuite extends SharedSparkSession {
      Inner,
      None,
      shuffle,
-      shuffle)
+      shuffle.copy())
Comment from the PR author:
If a plan contains the exact same instance of Exchange multiple times, then it makes no sense to replace one of the instances with a ReusedExchange.


    val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan)
    if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) {
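The distinction the test now relies on, shown standalone (hypothetical case class, not the real operator): the rewritten rule compares instances with reference equality (ne), so only a distinct instance whose canonical form matches an earlier one is eligible for replacement; the very same instance is left alone.

object SameInstanceSketch extends App {
  case class Exchange(schema: String)

  val shuffle = Exchange("int")
  assert(shuffle eq shuffle)        // the very same instance: nothing to reuse
  assert(shuffle ne shuffle.copy()) // a distinct but equal copy: eligible for reuse
  assert(shuffle == shuffle.copy()) // value equality (canonical form) still holds
  println("ok")
}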