diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 6c40104e52a5f..3cbebca14f7dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -107,12 +107,17 @@ package object debug { */ def codegenStringSeq(plan: SparkPlan): Seq[(String, String, ByteCodeStats)] = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() - plan transform { - case s: WholeStageCodegenExec => - codegenSubtrees += s - s - case s => s + + def findSubtrees(plan: SparkPlan): Unit = { + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => + s.subqueries.foreach(findSubtrees) + } } + + findSubtrees(plan) codegenSubtrees.toSeq.sortBy(_.codegenStageId).map { subtree => val (_, source) = subtree.doCodeGen() val codeStats = try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index 2c3b37a1498ec..d58bf2c6260b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -63,11 +63,17 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSparkSession { protected def checkGeneratedCode(plan: SparkPlan, checkMethodCodeSize: Boolean = true): Unit = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() - plan foreach { - case s: WholeStageCodegenExec => - codegenSubtrees += s - case _ => + + def findSubtrees(plan: SparkPlan): Unit = { + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => + s.subqueries.foreach(findSubtrees) + } } + + findSubtrees(plan) codegenSubtrees.toSeq.foreach { subtree => val code = subtree.doCodeGen()._2 val (_, ByteCodeStats(maxMethodCodeSize, _, _)) = try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 8b7459fddb59a..bf100c0205efa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -228,6 +228,22 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } + test("SPARK-33853: explain codegen - check presence of subquery") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + withTempView("df") { + val df1 = spark.range(1, 100) + df1.createTempView("df") + + val sqlText = "EXPLAIN CODEGEN SELECT (SELECT min(id) FROM df)" + val expectedText = "Found 3 WholeStageCodegen subtrees." + + withNormalizedExplain(sqlText) { normalizedOutput => + assert(normalizedOutput.contains(expectedText)) + } + } + } + } + test("explain formatted - check presence of subquery in case of DPP") { withTable("df1", "df2") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",