From 609f9728e5182cde8a79092e4be0e044a52a8d63 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 May 2024 19:41:17 +0800 Subject: [PATCH 1/3] InlineCTE should keep not-inlined relations in the original WithCTE node --- .../sql/catalyst/analysis/CheckAnalysis.scala | 45 +----- .../sql/catalyst/optimizer/InlineCTE.scala | 133 +++++++++++------- .../catalyst/optimizer/InlineCTESuite.scala | 42 ++++++ 3 files changed, 132 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e18f4d1b36e1a..cf827a6461bd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -142,50 +142,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } - private def checkUnreferencedCTERelations( - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - visited: mutable.Map[Long, Boolean], - danglingCTERelations: mutable.ArrayBuffer[CTERelationDef], - cteId: Long): Unit = { - if (visited(cteId)) { - return - } - val (cteDef, _, refMap) = cteMap(cteId) - refMap.foreach { case (id, _) => - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id) - } - danglingCTERelations.append(cteDef) - visited(cteId) = true - } - def checkAnalysis(plan: LogicalPlan): Unit = { - val inlineCTE = InlineCTE(alwaysInline = true) - val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - inlineCTE.buildCTEMap(plan, cteMap) - val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef] - val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) - // If a CTE relation is never used, it will disappear after inline. Here we explicitly collect - // these dangling CTE relations, and put them back in the main query, to make sure the entire - // query plan is valid. - cteMap.foreach { case (cteId, (_, refCount, _)) => - // If a CTE relation ref count is 0, the other CTE relations that reference it should also be - // collected. This code will also guarantee the leaf relations that do not reference - // any others are collected first. - if (refCount == 0) { - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId) - } - } - // Inline all CTEs in the plan to help check query plan structures in subqueries. - var inlinedPlan: LogicalPlan = plan - try { - inlinedPlan = inlineCTE(plan) + // We should inline all CTE relations to restore the original plan shape, as the analysis check + // may need to match certain plan shapes. For dangling CTE relations, they will still be kept + // in the original `WithCTE` node, as we need to perform analysis check for them as well. + val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true) + val inlinedPlan: LogicalPlan = try { + inlineCTE(plan) } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, plan) } - if (danglingCTERelations.nonEmpty) { - inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq) - } try { checkAnalysis0(inlinedPlan) } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 8d7ff4cbf163d..1ac1d4c02ec7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} * query level. * * @param alwaysInline if true, inline all CTEs in the query plan. + * @param keepDanglingRelations if true, dangling CTE relations will be kept in the original + * `WithCTE` node. */ -case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { +case class InlineCTE( + alwaysInline: Boolean = false, + keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { - val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] + val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo] buildCTEMap(plan, cteMap) cleanCTEMap(cteMap) - val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] - val inlined = inlineCTE(plan, cteMap, notInlined) - // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add - // WithCTE as top node here. - if (notInlined.isEmpty) { - inlined - } else { - WithCTE(inlined, notInlined.toSeq) - } + inlineCTE(plan, cteMap) } else { plan } @@ -74,22 +70,23 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * * @param plan The plan to collect the CTEs from * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE - * ids. The value of the map is tuple whose elements are: - * - The CTE definition - * - The number of incoming references to the CTE. This includes references from - * other CTEs and regular places. - * - A mutable inner map that tracks outgoing references (counts) to other CTEs. + * ids. * @param outerCTEId While collecting the map we use this optional CTE id to identify the * current outer CTE. */ - def buildCTEMap( + private def buildCTEMap( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + cteMap: mutable.Map[Long, CTEReferenceInfo], outerCTEId: Option[Long] = None): Unit = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => - cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0)) + cteMap(cteDef.id) = CTEReferenceInfo( + cteDef = cteDef, + refCount = 0, + outgoingRefs = mutable.Map.empty.withDefaultValue(0), + shouldInline = true + ) } cteDefs.foreach { cteDef => buildCTEMap(cteDef, cteMap, Some(cteDef.id)) @@ -97,11 +94,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { buildCTEMap(child, cteMap, outerCTEId) case ref: CTERelationRef => - val (cteDef, refCount, refMap) = cteMap(ref.cteId) - cteMap(ref.cteId) = (cteDef, refCount + 1, refMap) + cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1) outerCTEId.foreach { cteId => - val (_, _, outerRefMap) = cteMap(cteId) - outerRefMap(ref.cteId) += 1 + cteMap(cteId).recordOutgoingReference(ref.cteId) } case _ => @@ -129,15 +124,12 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE * ids. Needs to be sorted to speed up cleaning. */ - private def cleanCTEMap( - cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - ) = { + private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = { cteMap.keys.toSeq.reverse.foreach { currentCTEId => - val (_, currentRefCount, refMap) = cteMap(currentCTEId) - if (currentRefCount == 0) { - refMap.foreach { case (referencedCTEId, uselessRefCount) => - val (cteDef, refCount, refMap) = cteMap(referencedCTEId) - cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap) + val refInfo = cteMap(currentCTEId) + if (refInfo.refCount == 0) { + refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) => + cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount) } } } @@ -145,30 +137,45 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { private def inlineCTE( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { + cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = { plan match { case WithCTE(child, cteDefs) => - cteDefs.foreach { cteDef => - val (cte, refCount, refMap) = cteMap(cteDef.id) - if (refCount > 0) { - val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined)) - cteMap(cteDef.id) = (inlined, refCount, refMap) - if (!shouldInline(inlined, refCount)) { - notInlined.append(inlined) - } + val remainingDefs = cteDefs.filter { cteDef => + val refInfo = cteMap(cteDef.id) + if (refInfo.refCount > 0) { + val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap)) + val inlineDecision = shouldInline(newDef, refInfo.refCount) + cteMap(cteDef.id) = cteMap(cteDef.id).copy( + cteDef = newDef, shouldInline = inlineDecision + ) + // Retain the not-inlined CTE relations in place. + !inlineDecision + } else { + keepDanglingRelations } } - inlineCTE(child, cteMap, notInlined) + val inlined = inlineCTE(child, cteMap) + if (remainingDefs.isEmpty) { + inlined + } else { + WithCTE(inlined, remainingDefs) + } case ref: CTERelationRef => - val (cteDef, refCount, _) = cteMap(ref.cteId) - if (shouldInline(cteDef, refCount)) { - if (ref.outputSet == cteDef.outputSet) { - cteDef.child + val refInfo = cteMap(ref.cteId) + if (refInfo.shouldInline) { + if (ref.outputSet == refInfo.cteDef.outputSet) { + refInfo.cteDef.child } else { val ctePlan = DeduplicateRelations( - Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1) + Join( + refInfo.cteDef.child, + refInfo.cteDef.child, + Inner, + None, + JoinHint(None, None) + ) + ).children(1) val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => if (srcAttr.semanticEquals(tgtAttr)) { tgtAttr @@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { case _ if plan.containsPattern(CTE) => plan - .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined))) + .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap))) .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) { case e: SubqueryExpression => - e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined)) + e.withNewPlan(inlineCTE(e.plan, cteMap)) } case _ => plan } } } + +/** + * The bookkeeping information for tracking CTE relation references. + * + * @param cteDef The CTE relation definition + * @param refCount The number of incoming references to this CTE relation. This includes references + * from other CTE relations and regular places. + * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations. + * @param shouldInline If true, this CTE relation should be inlined in the places that reference it. + */ +case class CTEReferenceInfo( + cteDef: CTERelationDef, + refCount: Int, + outgoingRefs: mutable.Map[Long, Int], + shouldInline: Boolean) { + + def withRefCountIncreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount + count) + } + + def withRefCountDecreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount - count) + } + + def recordOutgoingReference(id: Long): Unit = { + outgoingRefs(id) += 1 + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala new file mode 100644 index 0000000000000..b92ae55899f0d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.TestRelation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class InlineCTESuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil + } + + test("not-inlined CTE relation in command") { + val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a"))) + val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming) + val plan = AppendData.byName( + TestRelation(Seq($"a".double)), + WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef)) + ).analyze + comparePlans(Optimize.execute(plan), plan) + } +} From b15ec40389b13483a0274c2dae225ed6aee91cbf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 May 2024 13:01:42 +0800 Subject: [PATCH 2/3] address comments --- .../org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 1ac1d4c02ec7c..50828b945bb40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -96,7 +96,7 @@ case class InlineCTE( case ref: CTERelationRef => cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1) outerCTEId.foreach { cteId => - cteMap(cteId).recordOutgoingReference(ref.cteId) + cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1) } case _ => @@ -225,7 +225,7 @@ case class CTEReferenceInfo( copy(refCount = refCount - count) } - def recordOutgoingReference(id: Long): Unit = { - outgoingRefs(id) += 1 + def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = { + outgoingRefs(cteDefId) += count } } From a636deb34f38e7a065a636894381a41fa064b15f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 21 May 2024 07:21:25 +0800 Subject: [PATCH 3/3] Update sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala Co-authored-by: Liang-Chi Hsieh --- .../apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala index b92ae55899f0d..9d775a5335c67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala @@ -30,7 +30,7 @@ class InlineCTESuite extends PlanTest { val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil } - test("not-inlined CTE relation in command") { + test("SPARK-48307: not-inlined CTE relation in command") { val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a"))) val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming) val plan = AppendData.byName(