Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48307][SQL] InlineCTE should keep not-inlined relations in the original WithCTE node #46617

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,50 +142,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
errorClass, missingCol, orderedCandidates, a.origin)
}

private def checkUnreferencedCTERelations(
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
visited: mutable.Map[Long, Boolean],
danglingCTERelations: mutable.ArrayBuffer[CTERelationDef],
cteId: Long): Unit = {
if (visited(cteId)) {
return
}
val (cteDef, _, refMap) = cteMap(cteId)
refMap.foreach { case (id, _) =>
checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id)
}
danglingCTERelations.append(cteDef)
visited(cteId) = true
}

def checkAnalysis(plan: LogicalPlan): Unit = {
val inlineCTE = InlineCTE(alwaysInline = true)
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
inlineCTE.buildCTEMap(plan, cteMap)
val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef]
val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false)
// If a CTE relation is never used, it will disappear after inline. Here we explicitly collect
// these dangling CTE relations, and put them back in the main query, to make sure the entire
// query plan is valid.
cteMap.foreach { case (cteId, (_, refCount, _)) =>
// If a CTE relation ref count is 0, the other CTE relations that reference it should also be
// collected. This code will also guarantee the leaf relations that do not reference
// any others are collected first.
if (refCount == 0) {
checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId)
}
}
// Inline all CTEs in the plan to help check query plan structures in subqueries.
var inlinedPlan: LogicalPlan = plan
try {
inlinedPlan = inlineCTE(plan)
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
// in the original `WithCTE` node, as we need to perform analysis check for them as well.
val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true)
val inlinedPlan: LogicalPlan = try {
inlineCTE(plan)
} catch {
case e: AnalysisException =>
throw new ExtendedAnalysisException(e, plan)
}
if (danglingCTERelations.nonEmpty) {
inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq)
}
try {
checkAnalysis0(inlinedPlan)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
* query level.
*
* @param alwaysInline if true, inline all CTEs in the query plan.
* @param keepDanglingRelations if true, dangling CTE relations will be kept in the original
* `WithCTE` node.
*/
case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
case class InlineCTE(
alwaysInline: Boolean = false,
keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo]
buildCTEMap(plan, cteMap)
cleanCTEMap(cteMap)
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
val inlined = inlineCTE(plan, cteMap, notInlined)
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
// WithCTE as top node here.
if (notInlined.isEmpty) {
inlined
} else {
WithCTE(inlined, notInlined.toSeq)
}
inlineCTE(plan, cteMap)
} else {
plan
}
Expand All @@ -74,34 +70,33 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
*
* @param plan The plan to collect the CTEs from
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. The value of the map is tuple whose elements are:
* - The CTE definition
* - The number of incoming references to the CTE. This includes references from
* other CTEs and regular places.
* - A mutable inner map that tracks outgoing references (counts) to other CTEs.
* ids.
* @param outerCTEId While collecting the map we use this optional CTE id to identify the
* current outer CTE.
*/
def buildCTEMap(
private def buildCTEMap(
plan: LogicalPlan,
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
cteMap: mutable.Map[Long, CTEReferenceInfo],
outerCTEId: Option[Long] = None): Unit = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0))
cteMap(cteDef.id) = CTEReferenceInfo(
cteDef = cteDef,
refCount = 0,
outgoingRefs = mutable.Map.empty.withDefaultValue(0),
shouldInline = true
)
}
cteDefs.foreach { cteDef =>
buildCTEMap(cteDef, cteMap, Some(cteDef.id))
}
buildCTEMap(child, cteMap, outerCTEId)

case ref: CTERelationRef =>
val (cteDef, refCount, refMap) = cteMap(ref.cteId)
cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1)
outerCTEId.foreach { cteId =>
val (_, _, outerRefMap) = cteMap(cteId)
outerRefMap(ref.cteId) += 1
cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1)
}

case _ =>
Expand Down Expand Up @@ -129,46 +124,58 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. Needs to be sorted to speed up cleaning.
*/
private def cleanCTEMap(
cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
) = {
private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = {
cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
val (_, currentRefCount, refMap) = cteMap(currentCTEId)
if (currentRefCount == 0) {
refMap.foreach { case (referencedCTEId, uselessRefCount) =>
val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
val refInfo = cteMap(currentCTEId)
if (refInfo.refCount == 0) {
refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) =>
cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
}
}
}
}

private def inlineCTE(
plan: LogicalPlan,
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
val (cte, refCount, refMap) = cteMap(cteDef.id)
if (refCount > 0) {
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
cteMap(cteDef.id) = (inlined, refCount, refMap)
if (!shouldInline(inlined, refCount)) {
notInlined.append(inlined)
}
val remainingDefs = cteDefs.filter { cteDef =>
val refInfo = cteMap(cteDef.id)
if (refInfo.refCount > 0) {
val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap))
val inlineDecision = shouldInline(newDef, refInfo.refCount)
cteMap(cteDef.id) = cteMap(cteDef.id).copy(
cteDef = newDef, shouldInline = inlineDecision
)
// Retain the not-inlined CTE relations in place.
!inlineDecision
} else {
keepDanglingRelations
}
}
inlineCTE(child, cteMap, notInlined)
val inlined = inlineCTE(child, cteMap)
if (remainingDefs.isEmpty) {
inlined
} else {
WithCTE(inlined, remainingDefs)
}

case ref: CTERelationRef =>
val (cteDef, refCount, _) = cteMap(ref.cteId)
if (shouldInline(cteDef, refCount)) {
if (ref.outputSet == cteDef.outputSet) {
cteDef.child
val refInfo = cteMap(ref.cteId)
if (refInfo.shouldInline) {
if (ref.outputSet == refInfo.cteDef.outputSet) {
refInfo.cteDef.child
} else {
val ctePlan = DeduplicateRelations(
Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1)
Join(
refInfo.cteDef.child,
refInfo.cteDef.child,
Inner,
None,
JoinHint(None, None)
)
).children(1)
val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) =>
if (srcAttr.semanticEquals(tgtAttr)) {
tgtAttr
Expand All @@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {

case _ if plan.containsPattern(CTE) =>
plan
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
.withNewChildren(plan.children.map(child => inlineCTE(child, cteMap)))
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
case e: SubqueryExpression =>
e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
e.withNewPlan(inlineCTE(e.plan, cteMap))
}

case _ => plan
}
}
}

/**
* The bookkeeping information for tracking CTE relation references.
*
* @param cteDef The CTE relation definition
* @param refCount The number of incoming references to this CTE relation. This includes references
* from other CTE relations and regular places.
* @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations.
* @param shouldInline If true, this CTE relation should be inlined in the places that reference it.
*/
case class CTEReferenceInfo(
cteDef: CTERelationDef,
refCount: Int,
outgoingRefs: mutable.Map[Long, Int],
shouldInline: Boolean) {

def withRefCountIncreased(count: Int): CTEReferenceInfo = {
copy(refCount = refCount + count)
}

def withRefCountDecreased(count: Int): CTEReferenceInfo = {
copy(refCount = refCount - count)
}

def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = {
outgoingRefs(cteDefId) += count
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.TestRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class InlineCTESuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil
}

test("SPARK-48307: not-inlined CTE relation in command") {
val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a")))
val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming)
val plan = AppendData.byName(
TestRelation(Seq($"a".double)),
WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef))
).analyze
comparePlans(Optimize.execute(plan), plan)
}
}