Skip to content

Commit

Permalink
InlineCTE should keep not-inlined relations in the original WithCTE node
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed May 16, 2024
1 parent 0ba8ddc commit 609f972
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,50 +142,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
errorClass, missingCol, orderedCandidates, a.origin)
}

/**
 * Appends the CTE relation identified by `cteId` to `danglingCTERelations`, after first doing
 * the same for every CTE relation it references. Relations that are already marked in
 * `visited` are left untouched, so each relation is appended at most once.
 *
 * @param cteMap CTE id -> (definition, incoming reference count, outgoing reference counts).
 * @param visited Ids that have already been appended; updated in place.
 * @param danglingCTERelations Output buffer; referenced relations are appended before the
 *                             relation that references them.
 * @param cteId The id of the CTE relation to process.
 */
private def checkUnreferencedCTERelations(
    cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
    visited: mutable.Map[Long, Boolean],
    danglingCTERelations: mutable.ArrayBuffer[CTERelationDef],
    cteId: Long): Unit = {
  if (!visited(cteId)) {
    val (relationDef, _, outgoing) = cteMap(cteId)
    // Handle everything this relation references first, so that it ends up after them.
    outgoing.keysIterator.foreach { referencedId =>
      checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, referencedId)
    }
    danglingCTERelations += relationDef
    visited(cteId) = true
  }
}

def checkAnalysis(plan: LogicalPlan): Unit = {
val inlineCTE = InlineCTE(alwaysInline = true)
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
inlineCTE.buildCTEMap(plan, cteMap)
val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef]
val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false)
// If a CTE relation is never used, it will disappear after inline. Here we explicitly collect
// these dangling CTE relations, and put them back in the main query, to make sure the entire
// query plan is valid.
cteMap.foreach { case (cteId, (_, refCount, _)) =>
// If a CTE relation ref count is 0, the other CTE relations that reference it should also be
// collected. This code will also guarantee the leaf relations that do not reference
// any others are collected first.
if (refCount == 0) {
checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId)
}
}
// Inline all CTEs in the plan to help check query plan structures in subqueries.
var inlinedPlan: LogicalPlan = plan
try {
inlinedPlan = inlineCTE(plan)
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. Dangling CTE relations are still kept in the original
// `WithCTE` node, as we need to perform the analysis check on them as well.
val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true)
val inlinedPlan: LogicalPlan = try {
inlineCTE(plan)
} catch {
case e: AnalysisException =>
throw new ExtendedAnalysisException(e, plan)
}
if (danglingCTERelations.nonEmpty) {
inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq)
}
try {
checkAnalysis0(inlinedPlan)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
* query level.
*
* @param alwaysInline if true, inline all CTEs in the query plan.
* @param keepDanglingRelations if true, dangling CTE relations will be kept in the original
* `WithCTE` node.
*/
case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
case class InlineCTE(
alwaysInline: Boolean = false,
keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo]
buildCTEMap(plan, cteMap)
cleanCTEMap(cteMap)
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
val inlined = inlineCTE(plan, cteMap, notInlined)
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
// WithCTE as top node here.
if (notInlined.isEmpty) {
inlined
} else {
WithCTE(inlined, notInlined.toSeq)
}
inlineCTE(plan, cteMap)
} else {
plan
}
Expand All @@ -74,34 +70,33 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
*
* @param plan The plan to collect the CTEs from
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. The value of the map is a tuple whose elements are:
* - The CTE definition
* - The number of incoming references to the CTE. This includes references from
* other CTEs and regular places.
* - A mutable inner map that tracks outgoing references (counts) to other CTEs.
* ids.
* @param outerCTEId While collecting the map we use this optional CTE id to identify the
* current outer CTE.
*/
def buildCTEMap(
private def buildCTEMap(
plan: LogicalPlan,
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
cteMap: mutable.Map[Long, CTEReferenceInfo],
outerCTEId: Option[Long] = None): Unit = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0))
cteMap(cteDef.id) = CTEReferenceInfo(
cteDef = cteDef,
refCount = 0,
outgoingRefs = mutable.Map.empty.withDefaultValue(0),
shouldInline = true
)
}
cteDefs.foreach { cteDef =>
buildCTEMap(cteDef, cteMap, Some(cteDef.id))
}
buildCTEMap(child, cteMap, outerCTEId)

case ref: CTERelationRef =>
val (cteDef, refCount, refMap) = cteMap(ref.cteId)
cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1)
outerCTEId.foreach { cteId =>
val (_, _, outerRefMap) = cteMap(cteId)
outerRefMap(ref.cteId) += 1
cteMap(cteId).recordOutgoingReference(ref.cteId)
}

case _ =>
Expand Down Expand Up @@ -129,46 +124,58 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. Needs to be sorted to speed up cleaning.
*/
// For every CTE whose incoming reference count is 0, subtracts the references it makes to other
// CTEs from those CTEs' incoming reference counts.
// NOTE(review): this span is a rendered diff hunk that carries both the tuple-based and the
// `CTEReferenceInfo`-based variants of the signature and body (there are two
// `private def cleanCTEMap` headers below) -- confirm which variant applies before compiling.
private def cleanCTEMap(
cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
) = {
private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = {
// Visit the ids in descending order: the sorted map's keys, reversed.
cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
val (_, currentRefCount, refMap) = cteMap(currentCTEId)
if (currentRefCount == 0) {
// Nothing references the current CTE, so the references it makes do not count.
refMap.foreach { case (referencedCTEId, uselessRefCount) =>
val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
val refInfo = cteMap(currentCTEId)
if (refInfo.refCount == 0) {
// Nothing references the current CTE, so the references it makes do not count.
refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) =>
cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
}
}
}
}

private def inlineCTE(
plan: LogicalPlan,
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
val (cte, refCount, refMap) = cteMap(cteDef.id)
if (refCount > 0) {
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
cteMap(cteDef.id) = (inlined, refCount, refMap)
if (!shouldInline(inlined, refCount)) {
notInlined.append(inlined)
}
val remainingDefs = cteDefs.filter { cteDef =>
val refInfo = cteMap(cteDef.id)
if (refInfo.refCount > 0) {
val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap))
val inlineDecision = shouldInline(newDef, refInfo.refCount)
cteMap(cteDef.id) = cteMap(cteDef.id).copy(
cteDef = newDef, shouldInline = inlineDecision
)
// Retain the not-inlined CTE relations in place.
!inlineDecision
} else {
keepDanglingRelations
}
}
inlineCTE(child, cteMap, notInlined)
val inlined = inlineCTE(child, cteMap)
if (remainingDefs.isEmpty) {
inlined
} else {
WithCTE(inlined, remainingDefs)
}

case ref: CTERelationRef =>
val (cteDef, refCount, _) = cteMap(ref.cteId)
if (shouldInline(cteDef, refCount)) {
if (ref.outputSet == cteDef.outputSet) {
cteDef.child
val refInfo = cteMap(ref.cteId)
if (refInfo.shouldInline) {
if (ref.outputSet == refInfo.cteDef.outputSet) {
refInfo.cteDef.child
} else {
val ctePlan = DeduplicateRelations(
Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1)
Join(
refInfo.cteDef.child,
refInfo.cteDef.child,
Inner,
None,
JoinHint(None, None)
)
).children(1)
val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) =>
if (srcAttr.semanticEquals(tgtAttr)) {
tgtAttr
Expand All @@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {

case _ if plan.containsPattern(CTE) =>
plan
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap)))
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
case e: SubqueryExpression =>
e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
e.withNewPlan(inlineCTE(e.plan, cteMap))
}

case _ => plan
}
}
}

/**
 * The bookkeeping information for tracking CTE relation references.
 *
 * New instances are derived through `copy`, so the `outgoingRefs` map is shared between an
 * instance and the copies derived from it.
 *
 * @param cteDef The CTE relation definition
 * @param refCount The number of incoming references to this CTE relation. This includes references
 *                 from other CTE relations and regular places.
 * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations,
 *                     keyed by the id of the referenced CTE relation.
 * @param shouldInline If true, this CTE relation should be inlined in the places that reference it.
 */
case class CTEReferenceInfo(
    cteDef: CTERelationDef,
    refCount: Int,
    outgoingRefs: mutable.Map[Long, Int],
    shouldInline: Boolean) {

  /** Returns a copy whose incoming reference count is raised by `count`. */
  def withRefCountIncreased(count: Int): CTEReferenceInfo = {
    copy(refCount = refCount + count)
  }

  /** Returns a copy whose incoming reference count is lowered by `count`. */
  def withRefCountDecreased(count: Int): CTEReferenceInfo = {
    copy(refCount = refCount - count)
  }

  /**
   * Records one more outgoing reference from this CTE relation to the CTE relation with the
   * given id. This mutates `outgoingRefs` in place.
   */
  def recordOutgoingReference(id: Long): Unit = {
    // Read with an explicit fallback instead of `outgoingRefs(id) += 1`, which would throw
    // NoSuchElementException for an unseen id when the map has no default value.
    outgoingRefs(id) = outgoingRefs.getOrElse(id, 0) + 1
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.TestRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[InlineCTE]] rule, run with its default arguments.
 */
class InlineCTESuite extends PlanTest {

  /** Runs [[InlineCTE]] as a single fixed-point batch. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("inline CTE", FixedPoint(100), InlineCTE()))
  }

  test("not-inlined CTE relation in command") {
    val relationDef = CTERelationDef(OneRowRelation().select(rand(0).as("a")))
    val relationRef = CTERelationRef(
      relationDef.id,
      relationDef.resolved,
      relationDef.output,
      relationDef.isStreaming)
    // The relation is referenced twice, and the `WithCTE` node sits below the command.
    val query = WithCTE(relationRef.except(relationRef, isAll = true), Seq(relationDef))
    val command = AppendData.byName(TestRelation(Seq($"a".double)), query).analyze
    // The rule is expected to leave the plan, including the nested `WithCTE`, unchanged.
    comparePlans(Optimize.execute(command), command)
  }
}

0 comments on commit 609f972

Please sign in to comment.