[SPARK-43838][SQL] Fix subquery on single table with having clause can't be optimized

### What changes were proposed in this pull request?

This PR updates `DeduplicateRelations` so that it also renews duplicated expression IDs produced by non-leaf operators such as `Project`, `Generate`, `Expand`, and `Window`, instead of deduplicating only `MultiInstanceRelation` leaf nodes. For example:
```scala
sql("create view t(c1, c2) as values (0, 1), (0, 2), (1, 2)")

sql("select c1, c2, (select count(*) cnt from t t2 where t1.c1 = t2.c1 " +
"having cnt = 0) from t t1").show()
```
Running this query throws:
```
[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery in batch Operator Optimization before Inferring Filters generated an invalid plan: The plan becomes unresolved: 'Project [toprettystring(c1#224, Some(America/Los_Angeles)) AS toprettystring(c1)#238, toprettystring(c2#225, Some(America/Los_Angeles)) AS toprettystring(c2)#239, toprettystring(cnt#246L, Some(America/Los_Angeles)) AS toprettystring(scalarsubquery(c1))#240]
+- 'Project [c1#224, c2#225, CASE WHEN isnull(alwaysTrue#245) THEN 0 WHEN NOT (cnt#222L = 0) THEN null ELSE cnt#222L END AS cnt#246L]
   +- 'Join LeftOuter, (c1#224 = c1#224#244)
      :- Project [col1#226 AS c1#224, col2#227 AS c2#225]
      :  +- LocalRelation [col1#226, col2#227]
      +- Project [cnt#222L, c1#224#244, cnt#222L, c1#224, true AS alwaysTrue#245]
         +- Project [cnt#222L, c1#224 AS c1#224#244, cnt#222L, c1#224]
            +- Aggregate [c1#224], [count(1) AS cnt#222L, c1#224]
               +- Project [col1#228 AS c1#224]
                  +- LocalRelation [col1#228, col2#229]

The previous plan: Project [toprettystring(c1#224, Some(America/Los_Angeles)) AS toprettystring(c1)#238, toprettystring(c2#225, Some(America/Los_Angeles)) AS toprettystring(c2)#239, toprettystring(scalar-subquery#223 [c1#224 && (c1#224 = c1#224#244)], Some(America/Los_Angeles)) AS toprettystring(scalarsubquery(c1))#240]
:  +- Project [cnt#222L, c1#224 AS c1#224#244]
:     +- Filter (cnt#222L = 0)
:        +- Aggregate [c1#224], [count(1) AS cnt#222L, c1#224]
:           +- Project [col1#228 AS c1#224]
:              +- LocalRelation [col1#228, col2#229]
+- Project [col1#226 AS c1#224, col2#227 AS c2#225]
   +- LocalRelation [col1#226, col2#227]
```

The error is caused by an unresolved expression in the `Join` node that subquery decorrelation generates: `duplicateResolved` on the `Join` is false, which means the join's left and right sides share the same `Attribute` (in this example, `c1#224`). The right-side `c1#224` is produced from the HAVING clause's inputs, which were computed incorrectly.

This problem only occurs when the subquery contains a HAVING clause.
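For reference, `duplicateResolved` on a `Join` requires the left and right children to produce disjoint attribute sets. A minimal sketch of the invariant, runnable in a Scala REPL; this is simplified from Spark's `Join` node, whose real check compares `AttributeSet`s rather than raw expression IDs:
```scala
// Simplified model of Join.duplicateResolved: a join of a table with itself is
// only resolvable once its two sides expose disjoint expression IDs.
def duplicateResolved(leftOutput: Set[Long], rightOutput: Set[Long]): Boolean =
  leftOutput.intersect(rightOutput).isEmpty

// In the invalid plan above, both sides expose c1#224, so the Join stays unresolved.
assert(!duplicateResolved(Set(224L, 225L), Set(222L, 224L, 245L)))
```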

This PR also includes some code formatting and error-handling cleanups.
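One of those cleanups moves the "scalar subquery must return one column" error into `QueryCompilationErrors.subqueryReturnMoreThanOneColumn`, so the same error class is raised from both `CheckAnalysis` and `ScalarSubquery.dataType` (the latter previously used an internal `assert`). A hypothetical repro of that check, assuming a `SparkSession` named `spark`:
```scala
// A scalar subquery must return exactly one output column; a two-column
// subquery should fail analysis with error class
// INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN.
spark.sql("SELECT (SELECT 1 AS a, 2 AS b)").show()
```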

### Why are the changes needed?
Fix a bug in scalar subqueries over a single table when a HAVING clause is used.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Add new test

Closes #41347 from Hisoka-X/SPARK-43838_subquery_having.

Lead-authored-by: Jia Fan <fanjiaeminem@qq.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Hisoka-X and cloud-fan committed Jul 20, 2023
1 parent 11e2e42 commit e0c79c6
Showing 45 changed files with 2,168 additions and 2,041 deletions.
@@ -979,10 +979,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase
          case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
            // Scalar subquery must return one column as output.
            if (query.output.size != 1) {
-             expr.failAnalysis(
-               errorClass = "INVALID_SUBQUERY_EXPRESSION." +
-                 "SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
-               messageParameters = Map("number" -> query.output.size.toString))
+             throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
+               expr.origin)
            }

            if (outerAttrs.nonEmpty) {
@@ -82,6 +82,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

+  private def existDuplicatedExprId(
+      existingRelations: mutable.HashSet[RelationWrapper],
+      plan: RelationWrapper): Boolean = {
+    existingRelations.filter(_.cls == plan.cls)
+      .exists(_.outputAttrIds.intersect(plan.outputAttrIds).nonEmpty)
+  }

/**
* Deduplicate any duplicated relations of a LogicalPlan
* @param existingRelations the known unique relations for a LogicalPlan
@@ -95,59 +102,161 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
case p: LogicalPlan if p.isStreaming => (plan, false)

case m: MultiInstanceRelation =>
-        val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id))
-        if (existingRelations.contains(planWrapper)) {
-          val newNode = m.newInstance()
-          newNode.copyTagsFrom(m)
-          (newNode, true)
-        } else {
-          existingRelations.add(planWrapper)
-          (m, false)
-        }
+        deduplicateAndRenew[LogicalPlan with MultiInstanceRelation](
+          existingRelations,
+          m,
+          _.output.map(_.exprId.id),
+          node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation])

+      case p: Project =>
+        deduplicateAndRenew[Project](
+          existingRelations,
+          p,
+          newProject => findAliases(newProject.projectList).map(_.exprId.id).toSeq,
+          newProject => newProject.copy(newAliases(newProject.projectList)))

+      case s: SerializeFromObject =>
+        deduplicateAndRenew[SerializeFromObject](
+          existingRelations,
+          s,
+          _.serializer.map(_.exprId.id),
+          newSer => newSer.copy(newSer.serializer.map(_.newInstance())))

+      case f: FlatMapGroupsInPandas =>
+        deduplicateAndRenew[FlatMapGroupsInPandas](
+          existingRelations,
+          f,
+          _.output.map(_.exprId.id),
+          newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

+      case f: FlatMapCoGroupsInPandas =>
+        deduplicateAndRenew[FlatMapCoGroupsInPandas](
+          existingRelations,
+          f,
+          _.output.map(_.exprId.id),
+          newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

+      case m: MapInPandas =>
+        deduplicateAndRenew[MapInPandas](
+          existingRelations,
+          m,
+          _.output.map(_.exprId.id),
+          newMap => newMap.copy(output = newMap.output.map(_.newInstance())))

+      case p: PythonMapInArrow =>
+        deduplicateAndRenew[PythonMapInArrow](
+          existingRelations,
+          p,
+          _.output.map(_.exprId.id),
+          newMap => newMap.copy(output = newMap.output.map(_.newInstance())))

+      case a: AttachDistributedSequence =>
+        deduplicateAndRenew[AttachDistributedSequence](
+          existingRelations,
+          a,
+          _.producedAttributes.map(_.exprId.id).toSeq,
+          newAttach => newAttach.copy(sequenceAttr = newAttach.producedAttributes
+            .map(_.newInstance()).head))

+      case g: Generate =>
+        deduplicateAndRenew[Generate](
+          existingRelations,
+          g,
+          _.generatorOutput.map(_.exprId.id), newGenerate =>
+            newGenerate.copy(generatorOutput = newGenerate.generatorOutput.map(_.newInstance())))

+      case e: Expand =>
+        deduplicateAndRenew[Expand](
+          existingRelations,
+          e,
+          _.producedAttributes.map(_.exprId.id).toSeq,
+          newExpand => newExpand.copy(output = newExpand.output.map(_.newInstance())))

+      case w: Window =>
+        deduplicateAndRenew[Window](
+          existingRelations,
+          w,
+          _.windowExpressions.map(_.exprId.id),
+          newWindow => newWindow.copy(windowExpressions =
+            newWindow.windowExpressions.map(_.newInstance())))

+      case s: ScriptTransformation =>
+        deduplicateAndRenew[ScriptTransformation](
+          existingRelations,
+          s,
+          _.output.map(_.exprId.id),
+          newScript => newScript.copy(output = newScript.output.map(_.newInstance())))

case plan: LogicalPlan =>
-        var planChanged = false
-        val newPlan = if (plan.children.nonEmpty) {
-          val newChildren = mutable.ArrayBuffer.empty[LogicalPlan]
-          for (c <- plan.children) {
-            val (renewed, changed) = renewDuplicatedRelations(existingRelations, c)
-            newChildren += renewed
-            if (changed) {
-              planChanged = true
-            }
-          }
-
-          val planWithNewSubquery = plan.transformExpressions {
-            case subquery: SubqueryExpression =>
-              val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
-              if (changed) planChanged = true
-              subquery.withNewPlan(renewed)
-          }
-
-          if (planChanged) {
-            if (planWithNewSubquery.childrenResolved) {
-              val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
-              val attrMap = AttributeMap(
-                plan
-                  .children
-                  .flatMap(_.output).zip(newChildren.flatMap(_.output))
-                  .filter { case (a1, a2) => a1.exprId != a2.exprId }
-              )
-              if (attrMap.isEmpty) {
-                planWithNewChildren
-              } else {
-                planWithNewChildren.rewriteAttrs(attrMap)
-              }
-            } else {
-              planWithNewSubquery.withNewChildren(newChildren.toSeq)
-            }
-          } else {
-            plan
-          }
-        } else {
-          plan
-        }
-        (newPlan, planChanged)
+        deduplicate(existingRelations, plan)
+    }
+
+  private def deduplicate(
+      existingRelations: mutable.HashSet[RelationWrapper],
+      plan: LogicalPlan): (LogicalPlan, Boolean) = {
+    var planChanged = false
+    val newPlan = if (plan.children.nonEmpty) {
+      val newChildren = mutable.ArrayBuffer.empty[LogicalPlan]
+      for (c <- plan.children) {
+        val (renewed, changed) = renewDuplicatedRelations(existingRelations, c)
+        newChildren += renewed
+        if (changed) {
+          planChanged = true
+        }
+      }
+
+      val planWithNewSubquery = plan.transformExpressions {
+        case subquery: SubqueryExpression =>
+          val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
+          if (changed) planChanged = true
+          subquery.withNewPlan(renewed)
+      }
+
+      if (planChanged) {
+        if (planWithNewSubquery.childrenResolved) {
+          val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
+          val attrMap = AttributeMap(plan.children.flatMap(_.output)
+            .zip(newChildren.flatMap(_.output)).filter { case (a1, a2) => a1.exprId != a2.exprId })
+          if (attrMap.isEmpty) {
+            planWithNewChildren
+          } else {
+            planWithNewChildren.rewriteAttrs(attrMap)
+          }
+        } else {
+          planWithNewSubquery.withNewChildren(newChildren.toSeq)
+        }
+      } else {
+        plan
+      }
+    } else {
+      plan
+    }
+    (newPlan, planChanged)
+  }

+  private def deduplicateAndRenew[T <: LogicalPlan](
+      existingRelations: mutable.HashSet[RelationWrapper], plan: T,
+      getExprIds: T => Seq[Long],
+      copyNewPlan: T => T): (LogicalPlan, Boolean) = {
+    var (newPlan, planChanged) = deduplicate(existingRelations, plan)
+    if (newPlan.resolved) {
+      val exprIds = getExprIds(newPlan.asInstanceOf[T])
+      if (exprIds.nonEmpty) {
+        val planWrapper = RelationWrapper(newPlan.getClass, exprIds)
+        if (existDuplicatedExprId(existingRelations, planWrapper)) {
+          newPlan = copyNewPlan(newPlan.asInstanceOf[T])
+          newPlan.copyTagsFrom(plan)
+          (newPlan, true)
+        } else {
+          existingRelations.add(planWrapper)
+          (newPlan, planChanged)
+        }
+      } else {
+        (newPlan, planChanged)
+      }
+    } else {
+      (newPlan, planChanged)
+    }
+  }

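The `existDuplicatedExprId` / `deduplicateAndRenew` pair above is the core of the fix: a node is re-instantiated whenever any expression ID it produces was already registered for another node of the same class. A toy model of the conflict check, runnable in a Scala REPL; this is plain Scala, not Spark code, and `RelationWrapper` here only mirrors the shape used above:
```scala
import scala.collection.mutable

// Mirrors DeduplicateRelations.RelationWrapper: the node class plus the
// expression IDs its output attributes carry.
final case class RelationWrapper(cls: Class[_], outputAttrIds: Seq[Long])

// A candidate conflicts if a previously seen node of the same class already
// produced any of its expression IDs.
def existDuplicatedExprId(
    existing: mutable.HashSet[RelationWrapper],
    candidate: RelationWrapper): Boolean =
  existing.filter(_.cls == candidate.cls)
    .exists(_.outputAttrIds.intersect(candidate.outputAttrIds).nonEmpty)

val seen = mutable.HashSet.empty[RelationWrapper]
val agg1 = RelationWrapper(classOf[Object], Seq(222L, 224L)) // first Aggregate output
assert(!existDuplicatedExprId(seen, agg1))
seen.add(agg1)

// A second node of the same class reusing exprId 224 is detected and would be
// renewed with fresh IDs by deduplicateAndRenew.
val agg2 = RelationWrapper(classOf[Object], Seq(224L))
assert(existDuplicatedExprId(seen, agg2))
```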
@@ -271,7 +271,10 @@ case class ScalarSubquery(
mayHaveCountBug: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
-    assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
+    if (!plan.schema.fields.nonEmpty) {
+      throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length,
+        origin)
+    }
plan.schema.fields.head.dataType
}
override def nullable: Boolean = true
@@ -2072,6 +2072,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
messageParameters = Map("function" -> funcStr))
}

+  def subqueryReturnMoreThanOneColumn(number: Int, origin: Origin): Throwable = {
+    new AnalysisException(
+      errorClass = "INVALID_SUBQUERY_EXPRESSION." +
+        "SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
+      origin = origin,
+      messageParameters = Map("number" -> number.toString))
+  }

def unsupportedCorrelatedReferenceDataTypeError(
expr: Expression,
dataType: DataType,
@@ -437,25 +437,6 @@ class LeftSemiAntiJoinPushDownSuite extends PlanTest {
}
}

-  Seq(LeftSemi, LeftAnti).foreach { case jt =>
-    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
-      val aggregation = testRelation
-        .select($"b".as("id"), $"c")
-        .groupBy($"id")($"id", sum($"c").as("sum"))
-
-      // reference "b" exists in left leg, and the children of the right leg of the join
-      val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum")
-        .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
-      val optimized = Optimize.execute(originalQuery.analyze)
-      val correctAnswer = testRelation
-        .select($"b".as("id"), $"c")
-        .groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum"))
-        .join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
-        .analyze
-      comparePlans(optimized, correctAnswer)
-    }
-  }

Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
@@ -1977,17 +1977,16 @@ Union false, false
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
-+- Project [c1#x AS c1#x, c2#x AS c2#x, c1#x AS c1#x, c2#x AS c2#x]
-   +- Project [c1#x, c2#x, c1#x, c2#x]
-      +- Join Inner
-         :- SubqueryAlias spark_catalog.default.t1
-         :  +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
-         :     +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
-         :        +- LocalRelation [col1#x, col2#x]
-         +- SubqueryAlias spark_catalog.default.t4
-            +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
-               +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
-                  +- LocalRelation [col1#x, col2#x]
++- Project [c1#x, c2#x, c1#x, c2#x]
+   +- Join Inner
+      :- SubqueryAlias spark_catalog.default.t1
+      :  +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+      :     +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+      :        +- LocalRelation [col1#x, col2#x]
+      +- SubqueryAlias spark_catalog.default.t4
+         +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+            +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+               +- LocalRelation [col1#x, col2#x]


-- !query
@@ -2030,27 +2029,26 @@
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
-+- Project [c1#x AS c1#x, c2#x AS c2#x, c2#x AS c2#x]
-   +- Project [c1#x, c2#x, c2#x]
-      +- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
-         :  +- SubqueryAlias __auto_generated_subquery_name
-         :     +- Union false, false
-         :        :- Project [c2#x]
-         :        :  +- Filter (outer(c1#x) <= c1#x)
-         :        :     +- SubqueryAlias spark_catalog.default.t1
-         :        :        +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
-         :        :           +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
-         :        :              +- LocalRelation [col1#x, col2#x]
-         :        +- Project [c2#x]
-         :           +- Filter (c1#x < outer(c1#x))
-         :              +- SubqueryAlias spark_catalog.default.t4
-         :                 +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
-         :                    +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
-         :                       +- LocalRelation [col1#x, col2#x]
-         +- SubqueryAlias spark_catalog.default.t2
-            +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
-               +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
-                  +- LocalRelation [col1#x, col2#x]
++- Project [c1#x, c2#x, c2#x]
+   +- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
+      :  +- SubqueryAlias __auto_generated_subquery_name
+      :     +- Union false, false
+      :        :- Project [c2#x]
+      :        :  +- Filter (outer(c1#x) <= c1#x)
+      :        :     +- SubqueryAlias spark_catalog.default.t1
+      :        :        +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+      :        :           +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+      :        :              +- LocalRelation [col1#x, col2#x]
+      :        +- Project [c2#x]
+      :           +- Filter (c1#x < outer(c1#x))
+      :              +- SubqueryAlias spark_catalog.default.t4
+      :                 +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+      :                    +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+      :                       +- LocalRelation [col1#x, col2#x]
+      +- SubqueryAlias spark_catalog.default.t2
+         +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+            +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+               +- LocalRelation [col1#x, col2#x]


-- !query
@@ -1057,3 +1057,21 @@ Project [c1#xL, c2#xL]
: +- Range (1, 2, step=1, splits=None)
+- SubqueryAlias t1
+- Range (1, 3, step=1, splits=None)


+-- !query
+SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
+-- !query analysis
+Project [c1#x, c2#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL]
+:  +- Filter (cnt#xL = cast(0 as bigint))
+:     +- Aggregate [count(1) AS cnt#xL]
+:        +- Filter (outer(c1#x) = c1#x)
+:           +- SubqueryAlias t2
+:              +- SubqueryAlias t1
+:                 +- View (`t1`, [c1#x,c2#x])
+:                    +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+:                       +- LocalRelation [col1#x, col2#x]
++- SubqueryAlias t1
+   +- View (`t1`, [c1#x,c2#x])
+      +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+         +- LocalRelation [col1#x, col2#x]
@@ -255,3 +255,6 @@ select * from (
where t1.id = t2.id ) c2
from range (1, 3) t1 ) t
where t.c2 is not null;

+-- SPARK-43838: Subquery on single table with having clause
+SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
@@ -598,3 +598,12 @@ where t.c2 is not null
struct<c1:bigint,c2:bigint>
-- !query output
1 1


+-- !query
+SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
+-- !query schema
+struct<c1:int,c2:int,scalarsubquery(c1):bigint>
+-- !query output
+0 1 NULL
+1 2 NULL